In [None]:
import json
import pickle
import sys
from typing import Optional, Sequence

import jax
import jax.numpy as jnp
import numpy as np
import optax
import pymbar
from flax import linen as nn
from flax.linen.initializers import lecun_normal
from flax.linen.initializers import zeros as nn_zeros
from flax.training import train_state

import jax_amber
import libs.tool_box as TB

# Kernel (weight) initializer for neural-network layers.
# flax's `lecun_normal` draws weights from a *truncated* normal distribution
# centred at 0 with variance 1/n (standard deviation 1/sqrt(n)), where n is
# the number of input units (fan-in) of the layer.
# NOTE(review): not referenced in the visible cells; presumably used by the
# realNVP3 layers defined elsewhere.
default_kernel_init = lecun_normal()

# Thermal energy and unit conversions. Energies in this notebook are logged
# in kJ/mol, so RT uses the molar gas constant in kJ/(mol*K).
GAS_CONSTANT = 8.3144621E-3   # kJ/(mol*K)
TEMPERATURE = 300.0           # K
RT = jnp.float32(GAS_CONSTANT * TEMPERATURE)   # kJ/mol
beta = jnp.float32(1.0)/RT    # 1/RT in (kJ/mol)^-1
nm2ang = jnp.float32(10.0)    # conversion nanometers -> angstroms
ang2nm = jnp.float32(0.1)     # conversion angstroms -> nanometers


# Load the JSON settings file

In [None]:
# Path to the JSON file holding every run setting used below.
fname_json = 'data/F18/input_test.json'
# JSON text is UTF-8 by specification; state it explicitly rather than
# depending on the platform's locale encoding.
with open(fname_json, encoding='utf-8') as json_file:
    json_data = json.load(json_file)

# Initialize data, energy functions, optimizer and model

In [None]:
### Main training of the model.
#   Every setting is read from the JSON file loaded above (json_data).
###
# Training log; buffering=1 selects line buffering, so each line is written out promptly.
fout = open (json_data['fname_log'], 'w', 1)


def _load_ab_trajectory (dcd_key):
    """Return TB.get_trajectory's (training, testing) pair for the DCD file at json_data[dcd_key]."""
    return TB.get_trajectory (json_data['fname_prmtop'],
                              json_data[dcd_key],
                              json_data['nsamp'])


# x_* = training data, tx_* = testing data
x_A, tx_A = _load_ab_trajectory ('fname_dcd_A')
x_B, tx_B = _load_ab_trajectory ('fname_dcd_B')


# Bundle the two end-state training sets (A first, then B).
inputs = (x_A, x_B)

# Number of training configurations = length of the first axis of x_A.
nconf = x_A.shape[0]

# Positional-restraint settings from the JSON file.
fixed_cfg = json_data['fixed']
# Shift atom indices down by one (the JSON appears to store 1-based indices).
fixed_atoms = jnp.array (fixed_cfg['atoms']) - 1
R0_A = jnp.array (fixed_cfg['R0_A'])     # reference restraint position, state A
R0_B = jnp.array (fixed_cfg['R0_B'])     # reference restraint position, state B
kval = jnp.float32 (fixed_cfg['kval'])   # restraint constant handed to jax_amber below
dR0_AB = R0_B - R0_A                     # shift of the reference position from A to B
d_lam = json_data['d_lambda']            # NOTE(review): unused in the visible cells

# The restraint acts on the last atom of the fixed-atom list.
fixed_iatom = fixed_atoms[-1]

### Energy functions built by jax_amber.get_amber_energy_funs.
#   Unpacked as (nHO, wHO, bond); judging by the '<dU>[w/no]' log label,
#   wHO / nHO are the energies with / without the restraint term.
ener_funs = jax_amber.get_amber_energy_funs (json_data['fname_prmtop'], fixed_iatom, kval)
ener_nHO_fun, ener_wHO_fun, ener_bond_fun = ener_funs

# Per-configuration reference energies of each end state;
# TB.get_energy_values yields (bond, nHO, wHO) in that order.
enr_bnd_A0, enr_nHO_A0, enr_wHO_A0 = TB.get_energy_values (x_A, ener_funs, R0_A)
enr_bnd_B0, enr_nHO_B0, enr_wHO_B0 = TB.get_energy_values (x_B, ener_funs, R0_B)
# Test-set counterparts; only the wHO component is kept.
_, _, enr_wHO_A0_test = TB.get_energy_values (tx_A, ener_funs, R0_A)
_, _, enr_wHO_B0_test = TB.get_energy_values (tx_B, ener_funs, R0_B)

# Summary statistics:
#   dE0        mean wHO energy difference B - A
#   Z_A, Z_B   mean z coordinate of the restrained atom in each state
#   enr_bnd_*  replaced by their means (scalars from here on)
dE0 = (enr_wHO_B0 - enr_wHO_A0).mean()
Z_A, Z_B = (traj[:, fixed_iatom, 2].mean() for traj in (x_A, x_B))
enr_bnd_A0, enr_bnd_B0 = enr_bnd_A0.mean(), enr_bnd_B0.mean()

# Write the summary rows (format, values) to the log file.
for row_fmt, row_vals in (
        (' Fixed_Z:           {:12.6f} {:12.6f}', (Z_A, Z_B)),
        (' <U_wHO0>(kJ/mol):  {:12.6f} {:12.6f}', (enr_wHO_A0.mean(), enr_wHO_B0.mean())),
        (' <U_nHO0>(kJ/mol):  {:12.6f} {:12.6f}', (enr_nHO_A0.mean(), enr_nHO_B0.mean())),
        ('<dU>[w/no](kJ/mol): {:12.6f} {:12.6f}', (dE0, (enr_nHO_B0 - enr_nHO_A0).mean())),
        ('<enr_bond>(kJ/mol): {:12.6f} {:12.6f}', (enr_bnd_A0, enr_bnd_B0)),
):
    print (row_fmt.format (*row_vals), file=fout)

# Reference energies handed to the loss:
#   ener_ref0[0] = (nHO_A, nHO_B)
#   ener_ref0[1] = (wHO_A, wHO_B, <bond_A>, <bond_B>)
ener_ref0 = ((enr_nHO_A0, enr_nHO_B0),
             (enr_wHO_A0, enr_wHO_B0, enr_bnd_A0, enr_bnd_B0))
# Test-set analogue; note it reuses the training-set mean bond energies.
ener_wHO_ref0_test = (enr_wHO_A0_test, enr_wHO_B0_test, enr_bnd_A0, enr_bnd_B0)

# Optimizer: Adam driven by a cosine-decay learning-rate schedule.
# optax.cosine_decay_schedule decays from `learning_rate` over `total_steps`
# optimizer updates (one per mini-batch); `alpha` is the minimum multiplier,
# i.e. the final learning rate is alpha * learning_rate.
optax_cfg = json_data['optax']
lr = optax_cfg['learning_rate']
total_steps = optax_cfg['total_steps']
alpha = optax_cfg['alpha']
scheduler = optax.cosine_decay_schedule (lr,
                                         decay_steps=total_steps,
                                         alpha=alpha)
opt_method = optax.adam (learning_rate=scheduler)

# PRNG key for parameter initialisation. The seed can be set via an optional
# top-level 'seed' entry in the JSON file; it defaults to 0 (the previous value).
rng = jax.random.PRNGKey (json_data.get ('seed', 0))
rng, key = jax.random.split (rng)

# RealNVP flow acting on flattened coordinates: x_A.shape[1] * 3 inputs
# (atoms x 3 Cartesian components, judging by the [:, i, 2] z-indexing above).
nvp_cfg = json_data['realNVP']
input_size = x_A.shape[1] * 3
hidden_dim = nvp_cfg['hidden_dim']
hidden_layers = nvp_cfg['hidden_layers']
# Same one-based -> zero-based shift as for the restraint atoms.
mask_fixed = jnp.array (nvp_cfg['mask_fixed']) - 1
model = realNVP3 (input_size=input_size,
                  hidden_layers=hidden_layers,
                  hidden_dim=hidden_dim,
                  fixed_atoms=mask_fixed)

# Initial training state: parameters from model.init on x_A, optimizer from above.
init_vars = model.init (key, x_A)
state = train_state.TrainState.create (apply_fn=model.apply,
                                       params=init_vars['params'],
                                       tx=opt_method)


# Scaling factor for the reference-position / energy shift (used for fixed_R0 below).
lam_max = jnp.float32(1.0)
lam = lam_max

### Optionally resume from a checkpoint (json_data['restart_nn']).
if json_data['restart_nn']['run']:
    # NOTE(review): checkpoint_load is defined outside the visible cells; if it
    # unpickles the file, only load checkpoints from trusted sources.
    ckpt = checkpoint_load (json_data['restart_nn']['fname_nn_pkl'])

    # replace() keeps the current step; only params and optimizer state are restored.
    state = state.replace (params=ckpt['params'],
                           opt_state=ckpt['opt_state'])
    # Checkpoints saved without 'lam' (e.g. the best-on-test one) fall back
    # to the current value instead of raising KeyError.
    lam = ckpt.get ('lam', lam)

# Best-on-test checkpoint, seeded from the (possibly restored) starting state so
# that a restart never falls back to freshly initialised weights.
test_ckpt = {'params': state.params,
             'opt_state': state.opt_state}



# Training loop

In [None]:
@jax.jit
def train_step (state, inputs, ener_wHO_ref0, fixed_R0):
    """Run one gradient step on a mini-batch and return the updated TrainState.

    inputs        -- (x_A, x_B) batch of end-state configurations.
    ener_wHO_ref0 -- reference energies forwarded to TB.loss_value.
    fixed_R0      -- (R_A, R_B, dE) restraint targets forwarded to TB.loss_value.

    Only the first loss returned by TB.loss_value (the bond-including one)
    is differentiated; the second is discarded.
    """
    batch_A, batch_B = inputs
    apply_fn = state.apply_fn

    def objective (params):
        variables = {'params': params}
        # x_A pushed forward (-> m_B) and x_B pulled back (-> m_A), with log-Jacobians.
        mapped_B, logdet_fwd = apply_fn (variables, batch_A)
        mapped_A, logdet_rev = apply_fn (variables, batch_B, reverse=True)
        loss_with_bond, _ = TB.loss_value (ener_wHO_fun, ener_bond_fun, ener_wHO_ref0,
                                           mapped_B, logdet_fwd, mapped_A, logdet_rev,
                                           fixed_R0)
        return loss_with_bond

    gradients = jax.grad (objective) (state.params)
    return state.apply_gradients (grads=gradients)


# Restraint targets at the current lam; for lam == 1 they reduce (up to rounding)
# to R_A = R0_A and R_B = R0_B, since dR0_AB = R0_B - R0_A.
R_A, R_B = R0_B - lam*dR0_AB, R0_A + lam*dR0_AB
dE = lam*dE0
fixed_R0 = (R_A, R_B, dE)

# Bookkeeping for loss monitoring / early stopping in the loop below.
loss_old = 0.0           # previous full-training-set loss (for the logged diff)
loss_test_min = 1000.0   # lowest test loss seen so far
loss_test_list = []      # test losses collected since the last 200-epoch check

# Mini-batch size (configurations per gradient step); optional JSON entry
# 'batch_size', defaulting to the previously hard-coded 1000.
batch_size = json_data.get ('batch_size', 1000)

for epoch in range (json_data['nepoch']):

    # loop over mini-batches; the last one is truncated at nconf
    # (a smaller final batch costs one extra jit compilation of train_step)
    for ist0 in range (0, nconf, batch_size):
        # plain Python int: avoids a device op + host sync just to bound a slice
        ied0 = min (ist0 + batch_size, nconf)
        batch = (x_A[ist0:ied0], x_B[ist0:ied0])
        ener_wHO_ref0 = (enr_wHO_A0[ist0:ied0], enr_wHO_B0[ist0:ied0],
                         enr_bnd_A0, enr_bnd_B0)

        state = train_step (state, batch, ener_wHO_ref0, fixed_R0)

    # every 10 epochs evaluate and log the loss on the full training and test sets
    if (epoch+1) % 10 == 0:
        m_B, log_J_F = state.apply_fn ({'params': state.params}, x_A)
        m_A, log_J_R = state.apply_fn ({'params': state.params}, x_B, reverse=True)

        loss_Wbnd, loss = TB.loss_value (ener_wHO_fun, ener_bond_fun, ener_ref0[1],
                                         m_B, log_J_F, m_A, log_J_R, fixed_R0)
        diff = loss_Wbnd - loss_old   # change since the previous evaluation
        loss_old = loss_Wbnd

        m_B, log_J_F = state.apply_fn ({'params': state.params}, tx_A)
        m_A, log_J_R = state.apply_fn ({'params': state.params}, tx_B, reverse=True)
        _, loss_test = TB.loss_value (ener_wHO_fun, ener_bond_fun, ener_wHO_ref0_test,
                                      m_B, log_J_F, m_A, log_J_R, fixed_R0)
        print ('loss {:8d} {:12.4f} {:12.4f} {:12.4f} {:14.4f}'.format(
            epoch+1, loss, loss_Wbnd, diff, loss_test), file=fout)

        # keep the parameters with the lowest test loss seen so far
        if loss_test < loss_test_min:
            loss_test_min = loss_test
            test_ckpt = {'params': state.params,
                         'opt_state': state.opt_state}

        # stop training once the (second) loss becomes negative
        if loss < jnp.float32(0.0):
            break

        loss_test_list.append (loss_test)

    # every 200 epochs: progress report and early-stopping check
    if (epoch+1) % 200 == 0:
        TB.print_progress (state, inputs,
                           ener_funs,
                           ener_ref0,
                           fixed_iatom,
                           fixed_R0, fout)
        # lowest test loss within the last 200-epoch window
        loss_test = jnp.array (loss_test_list).min()

        # stop if the whole window stayed more than 3.0 above the best test loss
        if loss_test > loss_test_min + 3.0:
            print ('loss_test_min', loss_test_min, file=fout)
            break
        loss_test_list = []   # start a fresh window for the next 200 epochs


# Save the final training state together with the lam it was trained at.
ckpt = {'params': state.params, 'opt_state': state.opt_state, 'lam': lam}
checkpoint_save (json_data['fname_nn_pkl'], ckpt)

# Save the best-on-test state. 'lam' is included so this file can be used
# with the restart path, which reads ckpt['lam'].
checkpoint_save (json_data['fname_nn_test_pkl'], {**test_ckpt, 'lam': lam})

# Continue with the best-on-test parameters for the summary below.
state = state.replace (step=0,
                       params=test_ckpt['params'],
                       opt_state=test_ckpt['opt_state'])

print ("===SUMMARY===", file=fout)

# Progress report for the best-on-test state. A bare `print_progress` is not
# defined in this notebook; the helper used in the training loop is TB.print_progress.
TB.print_progress (state, inputs,
                   ener_funs,
                   ener_ref0,
                   fixed_iatom,
                   fixed_R0, fout)
