In [1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
from flax.linen.initializers import zeros as nn_zeros
import optax
import pymbar
import sys
import jax_amber_tanh_align as jax_amber
import pickle

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import MultipleLocator
from jax.scipy.stats.multivariate_normal import logpdf

import json

from flax.linen.initializers import lecun_normal

In [2]:
# Weight initializer shared by the hidden Dense layers of the coupling MLPs.
default_kernel_init = lecun_normal()

# Thermal energy at 300 K: R = 8.3144621e-3 kJ/(mol K), so RT is in kJ/mol.
RT = jnp.float32(8.3144621E-3 * 300.0)
# Inverse thermal energy 1/RT (mol/kJ), used to reduce energies to units of kT.
beta = jnp.float32(1.0) / RT

# Length conversion factors between nanometres and angstroms.
nm2ang, ang2nm = jnp.float32(10.0), jnp.float32(0.1)

def get_energy_values(x, ener_funs, R0):
    """Evaluate three per-configuration energies over a batch.

    Args:
        x: batch of configurations; each function is vmapped over axis 0.
        ener_funs: tuple ``(ener_nHO_fun, ener_wHO_fun, ener_bond_fun)``.
            ``ener_wHO_fun`` takes ``(config, R0)``; the other two take ``config``.
        R0: second argument of ``ener_wHO_fun``, shared by the whole batch.

    Returns:
        ``(enr_bnd, enr_nHO, enr_wHO)`` -- note the order differs from ``ener_funs``.
    """
    nHO_fn, wHO_fn, bond_fn = ener_funs
    # R0 is broadcast (in_axes=None) rather than mapped over.
    batched_wHO = jax.vmap(wHO_fn, in_axes=(0, None))
    energies = (jax.vmap(bond_fn)(x), jax.vmap(nHO_fn)(x), batched_wHO(x, R0))
    return energies

def get_trajectory(fname_prmtop, fname_dcd, nsamp):
    """Load a trajectory, align it, and split it into training / held-out frames.

    Args:
        fname_prmtop: topology file passed to ``mdtraj.load``.
        fname_dcd: trajectory file.
        nsamp: number of frames, taken from the END of the trajectory, used for training.

    Returns:
        ``(train, held_out)`` jax arrays of coordinates (nm, mdtraj's unit), shape
        (frames, n_atoms, 3): ``train`` is the last ``nsamp`` frames and
        ``held_out`` all earlier frames.
        NOTE(review): nsamp == 0 puts every frame in ``train`` and none in ``held_out``.
    """
    import mdtraj as md

    traj = md.load(fname_dcd, top=fname_prmtop)
    # Superpose onto the trajectory's own frame 0 (mdtraj's default reference frame).
    aligned = traj.superpose(traj)
    coords = jnp.array(aligned.xyz)
    train, held_out = coords[-nsamp:], coords[:-nsamp]
    return train, held_out  # in nm unit

class AfflineCoupling(nn.Module):
    """Affine coupling layer that updates one Cartesian component of every atom.

    The flattened input is viewed as rows of 3 components. Component ``i_dim``
    of every row is scaled and shifted. The scale and shift come from an MLP
    that only sees the other two components, which pass through unchanged.
    Because the conditioning components are untouched, the layer is exactly
    invertible, and its log-Jacobian is the sum of the returned ``log_scale``.

    NOTE: Flax names the Dense parameters (Dense_0, Dense_1, ...) by creation
    order; do not reorder the Dense calls or previously pickled params will no
    longer line up.

    Attributes (dataclass fields):
        input_size: length of the flattened input (must be a multiple of 3).
        i_dim: index (0, 1 or 2) of the component this layer transforms.
        hidden_layers: number of hidden Dense + ReLU layers in the conditioner.
        hidden_dim: width of each hidden layer.
    """
    input_size: int
    i_dim: int
    hidden_layers: int
    hidden_dim: int

    @nn.compact
    def __call__(self, inputs, reverse=False):
        """Apply the coupling (``reverse=False``) or its exact inverse (``reverse=True``).

        Args:
            inputs: array whose last axis has length ``input_size`` (it is
                multiplied by a (1, input_size) mask); presumably
                (n_conf, input_size) -- TODO confirm for other callers.
            reverse: run the inverse transform.

        Returns:
            (outputs, log_scale): the transformed array (same shape as inputs)
            and the element-wise log-scale, which is zero on the conditioning
            components. In reverse mode ``log_scale`` is negated, so summing it
            gives the log-Jacobian of the inverse map.
        """

        # fixed_mask is 1 on the conditioning components and 0 on component i_dim.
        fixed_mask = jnp.ones((self.input_size), dtype=jnp.int32).reshape(-1, 3)
        fixed_mask = fixed_mask.at[:, self.i_dim].set(0)
        moved_mask = jnp.int32(1) - fixed_mask
        moved_mask = moved_mask.reshape(1, -1)
        fixed_mask = fixed_mask.reshape(1, -1)
        # The conditioner only sees the components that this layer does not move.
        y = inputs * fixed_mask

        for _ in range(self.hidden_layers):
            y = nn.relu(nn.Dense(features=self.hidden_dim, kernel_init=default_kernel_init)(y))
            #y = nn.leaky_relu(nn.Dense(features=self.hidden_dim, kernel_init=default_kernel_init)(y))
            #y = nn.swish(nn.Dense(features=self.hidden_dim, kernel_init=default_kernel_init)(y))

        # Zero-initialized kernels: at init the outputs equal the biases (zeros with
        # Flax's default bias init), so the layer starts out as the identity map.
        log_scale = nn.Dense(features=self.input_size, kernel_init=nn_zeros)(y)
        shift = nn.Dense(features=self.input_size, kernel_init=nn_zeros)(y)
        # Restrict the transform to component i_dim.
        shift = shift * moved_mask
        log_scale = log_scale * moved_mask

        if reverse:
            # Inverse of x * exp(s) + t is (x - t) * exp(-s); s and t are the same as in
            # the forward pass because the conditioning components are unchanged.
            log_scale = -log_scale
            outputs = (inputs - shift) * jnp.exp(log_scale)
        else:
            outputs = inputs * jnp.exp(log_scale) + shift

        return outputs, log_scale


class realNVP3(nn.Module):
    """Three-layer affine-coupling flow: one sweep of x, y and z couplings.

    The submodule attribute names (``af_x``, ``af_y``, ``af_z``) become the
    parameter-tree keys, so they must not be renamed.
    """
    input_size: int
    hidden_layers: int
    hidden_dim: int

    def setup(self):

        def coupling(i_dim):
            # Coupling layer on component i_dim with this flow's MLP settings.
            return AfflineCoupling(self.input_size, i_dim=i_dim,
                                   hidden_layers=self.hidden_layers,
                                   hidden_dim=self.hidden_dim)

        self.af_x = coupling(0)
        self.af_y = coupling(1)
        self.af_z = coupling(2)

    @nn.compact
    def __call__(self, inputs, reverse=False):
        """Map configurations through the flow or, with ``reverse=True``, its inverse.

        Args:
            inputs: array of shape (n_conf, n_atoms, n_dim).
            reverse: apply the inverse, i.e. the layers in opposite order.

        Returns:
            (outputs, log_det): outputs with the input's shape, and the
            per-configuration log-Jacobian of shape (n_conf,).
        """
        n_conf, n_atoms, n_dim = inputs.shape
        layers = (self.af_x, self.af_y, self.af_z)
        # The inverse runs the couplings last-to-first.
        order = range(len(layers) - 1, -1, -1) if reverse else range(len(layers))

        flat = inputs.reshape(n_conf, -1)
        log_jacobians = [None] * len(layers)
        for idx in order:
            flat, log_jacobians[idx] = layers[idx](flat, reverse)

        # Accumulate as log_J_x + log_J_y + log_J_z regardless of direction.
        total = log_jacobians[0]
        for term in log_jacobians[1:]:
            total = total + term

        return flat.reshape(n_conf, n_atoms, n_dim), total.sum(axis=-1)

    
class NNflows(nn.Module):
    """Nine-layer affine-coupling flow: three consecutive x/y/z coupling sweeps.

    The submodule attribute names (``af_x`` ... ``af_z3``) become the
    parameter-tree keys; pickled parameter histories depend on them, so they
    must not be renamed.
    """
    input_size: int
    hidden_layers: int
    hidden_dim: int

    def setup(self):

        def coupling(i_dim):
            # Coupling layer on component i_dim with this flow's MLP settings.
            return AfflineCoupling(self.input_size, i_dim=i_dim,
                                   hidden_layers=self.hidden_layers,
                                   hidden_dim=self.hidden_dim)

        # First sweep.
        self.af_x = coupling(0)
        self.af_y = coupling(1)
        self.af_z = coupling(2)
        # Second sweep.
        self.af_x2 = coupling(0)
        self.af_y2 = coupling(1)
        self.af_z2 = coupling(2)
        # Third sweep.
        self.af_x3 = coupling(0)
        self.af_y3 = coupling(1)
        self.af_z3 = coupling(2)

    @nn.compact
    def __call__(self, inputs, reverse=False):
        """Map configurations through the flow or, with ``reverse=True``, its inverse.

        Args:
            inputs: array of shape (n_conf, n_atoms, n_dim).
            reverse: apply the inverse, i.e. the nine layers in opposite order.

        Returns:
            (outputs, log_det): outputs with the input's shape, and the
            per-configuration log-Jacobian of shape (n_conf,).
        """
        n_conf, n_atoms, n_dim = inputs.shape
        layers = (self.af_x, self.af_y, self.af_z,
                  self.af_x2, self.af_y2, self.af_z2,
                  self.af_x3, self.af_y3, self.af_z3)
        # The inverse runs the couplings last-to-first.
        order = range(len(layers) - 1, -1, -1) if reverse else range(len(layers))

        flat = inputs.reshape(n_conf, -1)
        log_jacobians = [None] * len(layers)
        for idx in order:
            flat, log_jacobians[idx] = layers[idx](flat, reverse)

        # Sum in the fixed layer order (x, y, z, x2, ..., z3) regardless of direction.
        total = log_jacobians[0]
        for term in log_jacobians[1:]:
            total = total + term

        return flat.reshape(n_conf, n_atoms, n_dim), total.sum(axis=-1)
    
    
    
    
def rmsf(x, per_atom=False):
    """Root-mean-square fluctuation of a trajectory about its mean structure.

    Args:
        x: coordinates of shape (n_frames, n_atoms, n_dims).
        per_atom: if True, return the per-atom RMSF with shape (n_atoms,).
            Otherwise (default), return a scalar: the square root of the mean
            over atoms of the squared per-atom RMSF.

    Returns:
        Scalar or (n_atoms,) array, in the length units of ``x``. The value no
        longer depends on the number of frames; the previous version divided
        the grand total only by n_atoms, so it grew like sqrt(n_frames).

    Raises:
        ValueError: if ``x`` is not 3-dimensional.
    """
    if x.ndim != 3:
        raise ValueError(f"expected x of shape (n_frames, n_atoms, n_dims), got {x.shape}")

    # Average position of every atom over all frames: (n_atoms, n_dims).
    avg_positions = jnp.mean(x, axis=0)

    # Squared displacement of each atom from its mean, per frame: (n_frames, n_atoms).
    sq_disp = jnp.sum((x - avg_positions[None, :, :]) ** 2, axis=-1)

    # Mean over frames gives the mean-square fluctuation of each atom: (n_atoms,).
    msf = jnp.mean(sq_disp, axis=0)

    if per_atom:
        return jnp.sqrt(msf)
    return jnp.sqrt(jnp.mean(msf))


def get_gaussian_energy(x, x_A_flat_mean, x_A_flat_cov, factor):
    """Negative log-density of configurations under a scaled multivariate normal.

    Args:
        x: configurations of shape (n_conf, n_atoms, n_dims); each is flattened
            to a vector of length n_atoms * n_dims.
        x_A_flat_mean: mean vector of the Gaussian.
        x_A_flat_cov: covariance matrix before scaling.
        factor: the covariance is divided by this value (larger = narrower).

    Returns:
        Array of shape (n_conf,) with -log N(x; mean, cov / factor).
    """
    n_conf = x.shape[0]
    flat = x.reshape([n_conf, x.shape[1] * x.shape[2]])
    scaled_cov = x_A_flat_cov / factor
    return -logpdf(flat, x_A_flat_mean, scaled_cov)

def main(x_A, tx_A, R0_A, result = 'A.txt'):
    """Train an NNflows map between end state A and a sequence of Gaussian references.

    Training has 30 stages. In stage k the reference is a multivariate normal
    with the (regularized) sample mean/covariance of ``x_A``, its covariance
    divided by ``factor = 5 * 0.9**k``. The flow parameters carry over from stage
    to stage. Each stage runs up to 50000 full-batch steps. Every 500 steps the
    loss is evaluated on ``tx_A`` and the stage may stop early: rollback on
    divergence, or stop on a forward-loss jump or a plateau.

    Relies on the notebook global ``json_data`` (run configuration) and on the
    project module ``jax_amber`` for the restrained energy functions.

    Args:
        x_A: training configurations, (n_conf, n_atoms, 3).
        tx_A: held-out configurations of the same state, used for monitoring.
        R0_A: second argument of the wHO energy function (restraint reference).
        result: text log path; truncated at the start, then appended to.

    Returns:
        ``(state_hist_full, [x_A_flat_mean, x_A_flat_cov])``, where
        ``state_hist_full[k]`` is a list holding stage k's final params.
    """
    with open(result, 'w') as file:
        file.write('start\n')

    rng = jax.random.PRNGKey(0)
    rng, x_key, tx_key = jax.random.split(rng, num=3)

    x_A_flat = x_A.reshape([x_A.shape[0], x_A.shape[1] * x_A.shape[2]])
    x_A_flat_mean = jnp.mean(x_A_flat, axis=0)
    x_A_flat_cov_oring = jnp.cov(x_A_flat, rowvar=False)
    # Ridge-regularize the sample covariance so it stays positive definite.
    x_A_flat_cov = x_A_flat_cov_oring + 0.1 * jnp.diag(x_A_flat_cov_oring).min() * jnp.eye(x_A_flat_cov_oring.shape[0])

    fixed_atoms = jnp.array(json_data['fixed']['atoms']) - 1
    kval = jnp.float32(json_data['fixed']['kval'])

    input_size = x_A.shape[1] * 3
    hidden_dim = json_data['realNVP']['hidden_dim']
    hidden_layers = json_data['realNVP']['hidden_layers']

    model = NNflows(input_size=input_size,
                    hidden_layers=hidden_layers,
                    hidden_dim=hidden_dim)
    ener_funs = jax_amber.get_amber_energy_funs(json_data['fname_prmtop'],
                                                fixed_atoms[-1],
                                                kval)
    _, ener_wHO_fun, ener_bond_fun = ener_funs

    # get_energy_values returns (bond, nHO, wHO); only wHO enters the loss.
    enr_bnd_A0, _, enr_wHO_A0 = get_energy_values(x_A, ener_funs, R0_A)
    tenr_bnd_A0, _, tenr_wHO_A0 = get_energy_values(tx_A, ener_funs, R0_A)

    lr = json_data['optax']['learning_rate']
    total_steps = json_data['optax']['total_steps']
    alpha = json_data['optax']['alpha']
    scheduler = optax.cosine_decay_schedule(lr,
                                            decay_steps=total_steps,
                                            alpha=alpha)

    opt_method = optax.chain(
        optax.clip(1.0),
        optax.adam(learning_rate=scheduler)
    )

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=model.init(rng, x_A)['params'],
        tx=opt_method
    )

    fixed_R0 = R0_A

    # Held-out loss history (one entry per evaluation), shared across stages.
    loss_hist = []
    loss_f_hist = []
    loss_b_hist = []

    state_hist_full = []

    for multip in jnp.arange(30):
        # Stage-specific covariance divisor (a jax float32 scalar, since multip is a jax int).
        factor = 5 * (0.9 ** multip)
        print(factor)

        x_B_flat = jax.random.multivariate_normal(x_key, x_A_flat_mean, x_A_flat_cov / factor, (x_A.shape[0],),
                                                  method='svd')
        tx_B_flat = jax.random.multivariate_normal(tx_key, x_A_flat_mean, x_A_flat_cov / factor, (tx_A.shape[0],),
                                                   method='svd')
        x_B = x_B_flat.reshape(x_A.shape)
        tx_B = tx_B_flat.reshape(tx_A.shape)

        def gaussian_energy(x, factor=factor):
            # Reference energy of this stage: -log N(x; mean, cov / factor).
            return get_gaussian_energy(x, x_A_flat_mean, x_A_flat_cov, factor)

        enr_B0 = gaussian_energy(x_B)
        tenr_B0 = gaussian_energy(tx_B)

        # (wHO energy of A samples, Gaussian energy of B samples, bond energy of A samples)
        ener_ref0 = (enr_wHO_A0, enr_B0, enr_bnd_A0)
        tener_ref0 = (tenr_wHO_A0, tenr_B0, tenr_bnd_A0)

        def loss_value(ener_wHO_fn, ener_bond_fn, enr0_wHO, m_B, log_J_F, m_A, log_J_R, fixed_R0):
            """Forward/reverse generalized-work losses for mapped samples.

            Returns ``(loss, loss_F, loss_R)`` with per-sample loss_F / loss_R.
            NOTE(review): enr_B is -logpdf (already dimensionless) but is scaled by
            beta together with the (presumably kJ/mol) wHO energies; confirm intended.
            """
            ref_wHO_A0, ref_B0, _ = enr0_wHO

            enr_A = jax.vmap(ener_wHO_fn, in_axes=(0, None))(m_A, fixed_R0)
            enr_B = gaussian_energy(m_B)

            loss_F = beta * (enr_B - ref_wHO_A0) - log_J_F
            loss_R = beta * (enr_A - ref_B0) - log_J_R

            return loss_F.mean() + loss_R.mean(), loss_F, loss_R

        # Re-created (and re-jitted) per stage because it closes over this stage's factor.
        @jax.jit
        def train_step(state, inputs, ener_wHO_ref0, fixed_R0):
            def loss_fn(params, apply_fn):
                x_A, x_B = inputs

                m_B, log_J_F = apply_fn({'params': params}, x_A)
                m_A, log_J_R = apply_fn({'params': params}, x_B, reverse=True)

                loss, _, _ = loss_value(ener_wHO_fun, ener_bond_fun, ener_wHO_ref0,
                                        m_B, log_J_F, m_A, log_J_R, fixed_R0)
                return loss

            grads = jax.grad(loss_fn)(state.params, state.apply_fn)
            return state.apply_gradients(grads=grads)

        state_hist = []
        # Full-batch training: the inputs are the same every step of this stage.
        batch = (x_A, x_B)
        for epoch in range(50000):
            state = train_step(state, batch, ener_ref0, fixed_R0)

            if epoch % 500 != 0:
                continue

            # Monitor on the held-out samples.
            tm_B, tlog_J_F = state.apply_fn({'params': state.params}, tx_A)
            tm_A, tlog_J_R = state.apply_fn({'params': state.params}, tx_B, reverse=True)
            loss, loss_f, loss_b = loss_value(ener_wHO_fun, ener_bond_fun, tener_ref0,
                                              tm_B, tlog_J_F, tm_A, tlog_J_R, fixed_R0)

            loss_hist.append(loss.item())
            loss_f_hist.append(loss_f.mean().item())
            loss_b_hist.append(loss_b.mean().item())

            with open(result, 'a') as file:
                file.write('factor: ' + str(factor) + ' epoch: '+ str(epoch) + '\n loss: ' + str(loss) + '\n')
                file.write('forward: '+ str(loss_f.mean()) + ' backward: ' + str(loss_b.mean()) + '\n')

            if epoch > 2000:
                # Compare two overlapping 4-evaluation windows: previous vs. latest.
                prev_total, recent_total = loss_hist[-5:-1], loss_hist[-4:]
                prev_b, recent_b = loss_b_hist[-5:-1], loss_b_hist[-4:]
                prev_f, recent_f = loss_f_hist[-5:-1], loss_f_hist[-4:]

                # Divergence (huge or non-finite loss): roll back to the last checkpoint.
                latest = (loss_hist[-1], loss_f_hist[-1], loss_b_hist[-1])
                if any(not np.isfinite(v) or v >= 10000 for v in latest):
                    state = state.replace(params=test_ckpt['params'], opt_state=test_ckpt['opt_state'])
                    break

                # Forward loss jumped well above its recent average: end this stage.
                if (np.mean(prev_f) + 2*np.abs(np.mean(prev_f))) <= np.mean(recent_f):
                    break

                # Plateau: neither total nor reverse loss is still decreasing.
                if (epoch > 4000 and np.mean(prev_total) <= np.mean(recent_total)
                        and np.mean(prev_b) <= np.mean(recent_b)):
                    break

            test_ckpt = {'params': state.params,
                         'opt_state': state.opt_state}

        state_hist.append(state.params)
        state_hist_full.append(state_hist)

    return state_hist_full, [x_A_flat_mean, x_A_flat_cov]


def get_BAR(x_A, state_list, parm, R0_A, multip, inner = -1, key = None):
    """BAR free-energy estimate between end state A and the stage-``multip`` Gaussian.

    Maps ``x_A`` forward and Gaussian samples backward through the flow trained
    by ``main`` and feeds the generalized works to ``pymbar.bar``.

    Args:
        x_A: configurations of the end state, (n_conf, n_atoms, 3). Their
            wHO energies are computed here (previously the global ``tx_A`` was
            used regardless of which state was passed).
        state_list: ``state_hist_full`` returned by ``main``.
        parm: ``[mean, cov]`` Gaussian reference returned by ``main``.
        R0_A: second argument of the wHO energy function.
        multip: stage index; selects both the params and the reference width.
        inner: checkpoint index within the stage (default: last).
        key: PRNG key for the Gaussian samples; defaults to the same ``tx_key``
            that ``main`` derives for its held-out reference samples.

    Returns:
        RT * Delta_f from ``pymbar.bar`` (same energy units as RT).
    """
    # Must match main()'s schedule 5 * 0.9**stage (evaluated in jax float32, as in main).
    factor = 5 * (0.9 ** jnp.asarray(multip))

    if key is None:
        _, _, key = jax.random.split(jax.random.PRNGKey(0), num=3)

    # Same architecture main() trains. Applying NNflows params with realNVP3 would
    # silently ignore six of the nine coupling layers.
    model = NNflows(input_size=x_A.shape[1] * 3,
                    hidden_layers=json_data['realNVP']['hidden_layers'],
                    hidden_dim=json_data['realNVP']['hidden_dim'])
    params = state_list[multip][inner]

    fixed_atoms = jnp.array(json_data['fixed']['atoms']) - 1
    kval = jnp.float32(json_data['fixed']['kval'])
    # Same restrained potential as main() (last fixed atom), so the energies agree with training.
    # NOTE(review): the previous version passed all fixed atoms here; confirm which is intended.
    ener_funs = jax_amber.get_amber_energy_funs(json_data['fname_prmtop'],
                                                fixed_atoms[-1],
                                                kval)
    _, ener_wHO_fun, _ = ener_funs
    wHO_energy = jax.vmap(ener_wHO_fun, in_axes=(0, None))

    x_B_flat = jax.random.multivariate_normal(key, parm[0], parm[1] / factor, (x_A.shape[0],), method='svd')
    x_B = x_B_flat.reshape(x_A.shape)

    enr_wHO_A0 = wHO_energy(x_A, R0_A)
    enr_wHO_B0 = get_gaussian_energy(x_B, parm[0], parm[1], factor)

    m_B, log_J_F = model.apply({'params': params}, x_A)
    m_A, log_J_R = model.apply({'params': params}, x_B, reverse=True)

    enr_wHO_A = wHO_energy(m_A, R0_A)
    enr_wHO_B = get_gaussian_energy(m_B, parm[0], parm[1], factor)

    # Generalized work in each direction, matching the training loss in main().
    phi_F = beta * (enr_wHO_B - enr_wHO_A0) - log_J_F
    phi_R = beta * (enr_wHO_A - enr_wHO_B0) - log_J_R

    f_BAR_wHO = pymbar.bar(np.asarray(phi_F), np.asarray(phi_R),
                           relative_tolerance=1.0e-5,
                           verbose=False,
                           compute_uncertainty=False)

    return RT*f_BAR_wHO['Delta_f']


In [3]:
# Suffix appended to this run's output file names (training logs and pickled parameters).
title_addon = 'h,d,n=128,3,3,lesslr,0.9'

In [4]:
# Run configuration: file paths, restraint settings, network and optimizer settings.
fname_json = 'F18_20_align/input_test.json'
with open(fname_json) as f:
    json_data = json.load(f)

# Line-buffered (buffering=1) log file, truncated on open.
# NOTE(review): this handle is never written to or closed in this notebook.
fout = open(json_data['fname_log'], 'w', 1)

# Restraint reference values for end states A and B (the 'fixed' section of the config).
R0_A, R0_B = jnp.array(json_data['fixed']['R0_A']), jnp.array(json_data['fixed']['R0_B'])

# Global PRNG keys, derived the same way as inside main().
rng, x_key, tx_key = jax.random.split(jax.random.PRNGKey(0), num=3)



In [5]:
# End state A: the last 8000 frames train the flow; the earlier frames are held out.
x_A, tx_A = get_trajectory(json_data['fname_prmtop'], json_data['fname_dcd_A'], 8000)

state_A, parm_A = main(x_A, tx_A, R0_A, f'A{title_addon}.txt')

# Persist the parameter history and the Gaussian reference [mean, covariance].
with open(f'A{title_addon}.pkl', 'wb') as file:
    pickle.dump([state_A, parm_A], file)

# Release the in-memory parameter history.
del state_A


In [None]:
# End state B: the last 8000 frames train the flow; the earlier frames are held out.
x_B, tx_B = get_trajectory(json_data['fname_prmtop'], json_data['fname_dcd_B'], 8000)

state_B, parm_B = main(x_B, tx_B, R0_B, f'B{title_addon}.txt')

# Persist the parameter history and the Gaussian reference [mean, covariance].
with open(f'B{title_addon}.pkl', 'wb') as file:
    pickle.dump([state_B, parm_B], file)

# Release the in-memory parameter history.
del state_B

In [None]:
# The training cells `del` state_A / state_B after pickling them, so reload both
# parameter histories (and their Gaussian references) from disk.
# Only load pickles this notebook wrote itself: pickle.load can execute arbitrary code.
with open(f'A{title_addon}.pkl', 'rb') as file:
    state_A, parm_A = pickle.load(file)
with open(f'B{title_addon}.pkl', 'rb') as file:
    state_B, parm_B = pickle.load(file)

# A_energy[stage][checkpoint] / B_energy[stage][checkpoint]: BAR estimates on the held-out frames.
A_energy = []
B_energy = []
# Reference factor of every evaluated checkpoint, following main()'s schedule 5 * 0.9**stage.
factor_list_A = []
factor_list_B = []
for idx in range(len(state_A)):
    stage_factor = 5 * (0.9 ** idx)

    A_energy.append([get_BAR(tx_A, state_A, parm_A, R0_A, idx, idx2)
                     for idx2 in range(len(state_A[idx]))])
    factor_list_A.extend([stage_factor] * len(state_A[idx]))

    B_energy.append([get_BAR(tx_B, state_B, parm_B, R0_B, idx, idx3)
                     for idx3 in range(len(state_B[idx]))])
    factor_list_B.extend([stage_factor] * len(state_B[idx]))

In [None]:
# Final BAR estimate (last checkpoint) of every training stage, end state A.
A_final = [sublist[-1] for sublist in A_energy]
plt.plot(A_final, marker = 'o')
# One point per stage, so the x axis is the stage index, not a training iteration.
plt.xlabel('stage index')
plt.ylabel('-LBAR')
plt.gca().xaxis.set_major_locator(MultipleLocator(base=1))
plt.title('absolute free energy with different factor for A')

In [None]:
# Final BAR estimate (last checkpoint) of every training stage, end state B.
B_final = [sublist[-1] for sublist in B_energy]
plt.plot(B_final, marker = 'o')
# Plotted against the stage index (no x values given), not the factor itself.
plt.xlabel('stage index')
plt.ylabel('-LBAR')
plt.gca().xaxis.set_major_locator(MultipleLocator(base=1))
plt.title('absolute free energy with different factor for B')

In [None]:
# Difference of the final per-stage estimates (A - B) against the reference factor.
A_final = [sublist[-1] for sublist in A_energy]
B_final = [sublist[-1] for sublist in B_energy]
diff = jnp.array(A_final) - jnp.array(B_final)
# One factor per stage, on main()'s schedule 5 * 0.9**stage; a hard-coded length
# would not match the number of stages and plt.plot would raise.
factor_list = [5 * (0.9 ** idx) for idx in range(len(diff))]
plt.plot(factor_list, diff, marker = 'o')
plt.xlabel('factor')
plt.ylabel('-LBAR')
plt.xlim(0, max(factor_list))
plt.gca().xaxis.set_major_locator(MultipleLocator(base=1))
plt.title('free energy difference A - B with different factor')

In [None]:
# Per-checkpoint BAR estimates within the 4th-from-last stage, for A and B.
stage = len(A_energy) - 4
# Report the stage's actual factor (main()'s schedule) instead of a hard-coded value.
stage_factor = 5 * (0.9 ** stage)
plt.plot(-jnp.array(A_energy[-4]), marker = 'o', label = 'A')
plt.plot(-jnp.array(B_energy[-4]), marker = 'o', label = 'B')
plt.xlabel('checkpoint index within stage')
plt.ylabel('-LBAR')
plt.yscale('log')
plt.title(f'absolute free energy with factor {stage_factor:.2f} for A and B')
plt.legend()

In [None]:
# Per-checkpoint BAR estimates in the first stage (largest factor, i.e. narrowest Gaussian).
stage_factor = 5 * (0.9 ** 0)
plt.plot(-jnp.array(A_energy[0]), marker = 'o', label = 'A')
plt.plot(-jnp.array(B_energy[0]), marker = 'o', label = 'B')
plt.xlabel('checkpoint index within stage')
plt.ylabel('-LBAR')
plt.yscale('log')
plt.title(f'absolute free energy with factor {stage_factor:g} for A and B')
plt.legend()

In [None]:
# Every per-checkpoint BAR estimate, concatenated stage after stage.
flattened_A = [element for sublist in A_energy for element in sublist]
flattened_B = [element for sublist in B_energy for element in sublist]
plt.plot(-jnp.array(flattened_A), marker = 'o', label = 'A')
plt.plot(-jnp.array(flattened_B), marker = 'o', label = 'B')

plt.xlabel('checkpoint index (all stages)')
plt.ylabel('-LBAR')
plt.yscale('log')
plt.title('over all absolute free energy for A and B')
plt.legend()