In [1]:
from endstate_correction.analysis import (
    plot_overlap_for_equilibrium_free_energy,
    plot_results_for_equilibrium_free_energy,
)
from endstate_correction.equ import calculate_u_kn
import pathlib
import endstate_correction
from endstate_correction.system import create_charmm_system
from openmm.app import CharmmParameterSet, CharmmPsfFile, CharmmCrdFile
import glob, pickle
import numpy as np
import os
from openmm import unit
import mdtraj as md
from endstate_correction.utils import convert_pickle_to_dcd_file
from endstate_correction.constant import kBT
from pymbar import mbar
from endstate_correction.protocol import Results, Protocol, perform_endstate_correction
from endstate_correction.analysis import plot_endstate_correction_results
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# bootstrap metric
def bootstrap_exp(fct, values, n_bootstrap: int = 1000, alpha: float = 10.0):
    """Bootstrap a one-sided free-energy estimator and return a percentile CI.

    Parameters
    ----------
    fct : callable
        Estimator called as ``fct(samples)``; must return a mapping with a
        ``'Delta_f'`` entry (e.g. ``pymbar.other_estimators.exp``).
    values : array-like
        Samples passed to ``fct`` (indexed with ``np.take``).
    n_bootstrap : int, optional
        Number of bootstrap resamples (default 1000).
    alpha : float, optional
        Two-sided significance level in percent. The returned interval spans
        the ``alpha/2`` to ``100 - alpha/2`` percentiles (default 10.0 -> 90 % CI).

    Returns
    -------
    tuple
        ``(mean, lower, upper)``. ``mean`` is ``fct`` evaluated on the full
        sample; ``lower``/``upper`` are percentiles of the bootstrap distribution.

    Raises
    ------
    ValueError
        If ``values`` is empty.
    """
    assert callable(fct), "fct must be callable"
    n = len(values)
    if n == 0:
        raise ValueError("bootstrap_exp requires at least one sample")

    bootstrapped_metric = []
    # resample with replacement to build the test distribution
    for _ in range(n_bootstrap):
        selection = np.take(values, np.random.choice(n, size=n, replace=True))
        bootstrapped_metric.append(fct(selection)['Delta_f'])

    # np.percentile interpolates; with few resamples the bounds are approximate
    lower = np.percentile(bootstrapped_metric, alpha / 2.0)
    upper = np.percentile(bootstrapped_metric, 100.0 - alpha / 2.0)
    # point estimate on the full (non-resampled) data
    mean = fct(values)['Delta_f']

    return mean, lower, upper

def bootstrap_bar(fct, fw_f, rv_f, n_bootstrap: int = 1000, alpha: float = 10.0):
    """Bootstrap a bidirectional free-energy estimator and return a percentile CI.

    Forward and reverse samples are resampled independently (with replacement)
    in every bootstrap round.

    Parameters
    ----------
    fct : callable
        Estimator called as ``fct(forward, reverse)``; must return a mapping with
        a ``'Delta_f'`` entry (e.g. ``pymbar.bar``).
    fw_f, rv_f : array-like
        Forward and reverse samples (indexed with ``np.take``).
    n_bootstrap : int, optional
        Number of bootstrap resamples (default 1000).
    alpha : float, optional
        Two-sided significance level in percent. The returned interval spans
        the ``alpha/2`` to ``100 - alpha/2`` percentiles (default 10.0 -> 90 % CI).

    Returns
    -------
    tuple
        ``(mean, lower, upper)``. ``mean`` is ``fct`` evaluated on the full
        samples; ``lower``/``upper`` are percentiles of the bootstrap distribution.

    Raises
    ------
    ValueError
        If either sample set is empty.
    """
    assert callable(fct), "fct must be callable"
    n_fw, n_rv = len(fw_f), len(rv_f)
    if n_fw == 0 or n_rv == 0:
        raise ValueError("bootstrap_bar requires non-empty forward and reverse samples")

    bootstrapped_metric = []
    for _ in range(n_bootstrap):
        # forward indices are drawn before reverse ones in every round
        selection_fw = np.take(fw_f, np.random.choice(n_fw, size=n_fw, replace=True))
        selection_rv = np.take(rv_f, np.random.choice(n_rv, size=n_rv, replace=True))
        bootstrapped_metric.append(fct(selection_fw, selection_rv)['Delta_f'])

    # np.percentile interpolates; with few resamples the bounds are approximate
    lower = np.percentile(bootstrapped_metric, alpha / 2.0)
    upper = np.percentile(bootstrapped_metric, 100.0 - alpha / 2.0)
    # point estimate on the full (non-resampled) data
    mean = fct(fw_f, rv_f)['Delta_f']

    return mean, lower, upper

def _plot_estimator_panel(ax, title, w_F, w_R, equ_r, equ_dDG,
                          kde_labels=(None, None), bid_alpha=0.5):
    """Draw one forward/reverse distribution panel with EXP, BAR and equilibrium estimates.

    ``w_F`` holds forward samples and ``w_R`` reverse samples. The reverse
    samples are plotted negated so both distributions share the dG axis.
    Bootstrapped intervals come from ``bootstrap_exp``/``bootstrap_bar``
    with their default ``alpha`` (90 % CI).
    """
    from pymbar import bar
    from pymbar.other_estimators import exp

    rexp_f, rexp_f_lower, rexp_f_upper = bootstrap_exp(exp, w_F)
    rexp_r, rexp_r_lower, rexp_r_upper = bootstrap_exp(exp, w_R)
    bar_bi, bar_bi_lower, bar_bi_upper = bootstrap_bar(bar, w_F, w_R)

    ax.set_title(title, fontsize=15)
    for samples, label in ((w_F, kde_labels[0]), (w_R * -1, kde_labels[1])):
        label_kws = {} if label is None else {"label": label}
        sns.kdeplot(samples, ax=ax, fill=True, alpha=.5, **label_kws)
        sns.rugplot(samples, ax=ax, lw=1, alpha=.1)

    # bidirectional (BAR), forward EXP, reverse EXP and equilibrium estimates
    ax.axvline(x=bar_bi, color='red', lw=5, ls='--', alpha=bid_alpha,
               label=r'$\Delta G_{bid}$')
    ax.axvline(x=rexp_f, color='purple', lw=5, ls=':', alpha=0.5,
               label=r'$\Delta G_{forw}$')
    ax.axvline(x=rexp_r * -1, color='green', lw=5, ls=':', alpha=0.5,
               label=r'$\Delta G_{rev}$')
    ax.axvline(x=equ_r, color='black', lw=5, ls='-', alpha=0.5,
               label=r'$\Delta G_{equ}$')

    # negating the reverse estimate swaps its lower/upper bounds
    text = "\n".join([
        rf"$\Delta G_{{forw}}$ = {rexp_f:.2f} [90% CI: {rexp_f_lower:.2f}; {rexp_f_upper:.2f}]",
        rf"$\Delta G_{{rev}}$ = {-rexp_r:.2f} [90% CI: {-rexp_r_upper:.2f}; {-rexp_r_lower:.2f}]",
        rf"$\Delta G_{{bid}}$ = {bar_bi:.2f} [90% CI: {bar_bi_lower:.2f}; {bar_bi_upper:.2f}]",
        rf"$\Delta G_{{equ}}$ = {equ_r:.2f} $\pm$ {equ_dDG:.2f}",
    ])

    # these are matplotlib.patch.Patch properties; text box goes in the upper left
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    ax.text(0.05, 0.95, text, transform=ax.transAxes, fontsize=15,
            verticalalignment='top', horizontalalignment='left', bbox=props)
    ax.legend(loc='upper right', fontsize=14)


def plot_dist(r, system_name, axs):
    """Plot FEP (``axs[0]``) and NEQ (``axs[1]``) distributions for one system.

    The FEP panel uses ``r.dE_mm_to_qml`` / ``r.dE_qml_to_mm``. The NEQ panel
    uses ``r.W_mm_to_qml`` / ``r.W_qml_to_mm``. Each panel also shows the
    equilibrium MBAR estimate taken from ``r.equ_mbar``.

    All values are shifted by one constant ``min_F``, derived from the FEP
    samples, so the plotted numbers stay small. The equilibrium estimate gets
    the same shift, so all vertical lines remain comparable.

    Parameters
    ----------
    r : results object
        Must provide ``dE_mm_to_qml``, ``dE_qml_to_mm``, ``W_mm_to_qml``,
        ``W_qml_to_mm`` (numeric arrays, presumably in kT — confirm) and
        ``equ_mbar`` (a ``pymbar`` MBAR instance).
    system_name : str
        Used in the panel titles.
    axs : sequence of two matplotlib Axes
        Target axes for the FEP and NEQ panels.
    """
    min_F = abs(np.min(np.append(r.dE_mm_to_qml, r.dE_qml_to_mm * -1)))

    # solve the MBAR free energy differences once and reuse them for both panels
    equ_results = r.equ_mbar.compute_free_energy_differences()
    equ_r = equ_results["Delta_f"][0][-1] + min_F
    equ_dDG = equ_results["dDelta_f"][0][-1]  # uncertainty, shown as +/- sigma

    # FEP: forward samples shifted by +min_F, reverse by -min_F (plotted negated)
    _plot_estimator_panel(
        axs[0], f'FEP - {system_name}',
        r.dE_mm_to_qml + min_F, r.dE_qml_to_mm - min_F,
        equ_r, equ_dDG,
        kde_labels=('f_forw', 'f_rev'), bid_alpha=None,
    )
    # NEQ: deliberately reuses the FEP-derived min_F so it shares the equilibrium shift
    _plot_estimator_panel(
        axs[1], f'NEQ - {system_name}',
        r.W_mm_to_qml + min_F, r.W_qml_to_mm - min_F,
        equ_r, equ_dDG,
    )
        
def load_equ_samples(
    system_name: str, base: str, nr_of_samples: int = 5_000, nr_of_steps: int = 1_000, run_id: int = 1
) -> list:
    """Load the equilibrium-sampling DCD trajectories of one system for all lambda states.

    For each lambda in ``np.linspace(0, 1, 11)`` the file
    ``{system_name}_samples_{nr_of_samples}_steps_{nr_of_steps}_lamb_{lamb:.4f}.dcd``
    is read from ``{base}/{system_name}/sampling_charmmff/run{run_id:02d}/``.

    Returns
    -------
    list
        One entry per lambda, in increasing lambda order. Each entry is the first
        element returned by the DCD file's ``read()``, multiplied by ``unit.angstrom``.

    Raises
    ------
    RuntimeError
        If a lambda state has no trajectory file or more than one match.
    """
    trajs = []
    path = f"{base}/{system_name}/sampling_charmmff/run{run_id:02d}/"

    for lamb in np.linspace(0, 1, 11):
        pattern = f"{path}/{system_name}_samples_{nr_of_samples}_steps_{nr_of_steps}_lamb_{lamb:.4f}.dcd"
        files = glob.glob(pattern)
        if len(files) > 1:
            raise RuntimeError("Multiple traj files present. Abort.")
        if not files:
            raise RuntimeError(
                f"Incomplete equ sampling: missing trajectory {pattern}. Abort."
            )

        # use the file as a context manager so the DCD handle is closed after reading
        with md.open(files[0]) as traj_file:
            xyz = traj_file.read()[0]
        trajs.append(xyz * unit.angstrom)
    return trajs

def convert_pickle_files(system_name: str, base: str, path_to_psf: str, path_to_crd: str, nr_of_samples: int = 5_000, nr_of_steps: int = 1_000, run_id: int = 1):
    """Convert pickled equilibrium samples of every lambda state to DCD files.

    For each lambda in ``np.linspace(0, 1, 11)`` the pickle
    ``{system_name}_samples_{nr_of_samples}_steps_{nr_of_steps}_lamb_{lamb:.4f}.pickle``
    in ``{base}/{system_name}/sampling_charmmff/run{run_id:02d}/`` is passed to
    ``convert_pickle_to_dcd_file``. The output is a ``.dcd`` with the same stem
    plus ``{system_name}.pdb``. Lambda states without a pickle are skipped.
    Exceptions raised by ``convert_pickle_to_dcd_file`` propagate to the caller.
    """
    path = f"{base}/{system_name}/sampling_charmmff/run{run_id:02d}/"
    stem = f"{system_name}_samples_{nr_of_samples}_steps_{nr_of_steps}"
    pdb_output_path = f"{path}/{system_name}.pdb"

    for lamb in np.linspace(0, 1, 11):
        files = glob.glob(f"{path}/{stem}_lamb_{lamb:.4f}.pickle")
        if not files:
            # no pickle for this lambda state -> nothing to convert
            continue
        dcd_output_path = f"{path}/{stem}_lamb_{lamb:.4f}.dcd"
        convert_pickle_to_dcd_file(files[0], path_to_psf, path_to_crd, dcd_output_path, pdb_output_path)

########################################################
########################################################
# ----------------- vacuum -----------------------------
# For every system: equilibrium MBAR estimate, FEP results (cached as pickle)
# and NEQ works are combined and plotted; one figure is saved per batch.

# Root directory holding the sampling/switching data of all systems.
DATA_BASE = "/data/shared/projects/endstate_rew"
# Fixed seed so the bootstrapped confidence intervals in plot_dist are reproducible.
np.random.seed(42)

# get all relevant files
path = pathlib.Path(endstate_correction.__file__).resolve().parent
hipen_testsystem = f"{path}/data/hipen_data"

all_systems = [
    ["ZINC00086442", "ZINC00079729", "ZINC00077329", "ZINC00087557", "ZINC00107550"],
    ["ZINC00107778", "ZINC00123162", "ZINC00133435", "ZINC00140610", "ZINC00164361"],
    ["ZINC00167648", "ZINC00169358", "ZINC01036618", "ZINC01755198", "ZINC01867000"],
    ["ZINC03127671", "ZINC04344392", "ZINC04363792", "ZINC06568023", "ZINC33381936"],
]


def _load_pickle(file_name):
    """Unpickle ``file_name`` and close the handle (trusted, locally produced files only)."""
    with open(file_name, "rb") as fh:
        return pickle.load(fh)


for batch_idx, list_of_systems in enumerate(all_systems):
    # squeeze=False keeps axs 2D, so axs[idx] is a row of two axes for any batch size
    fig, axs = plt.subplots(len(list_of_systems), 2, figsize=(15.0, 13.0), dpi=600, squeeze=False)
    for idx, system_name in enumerate(list_of_systems):
        print(f"{idx}: {system_name}")
        path_to_psf = f"{hipen_testsystem}/{system_name}/{system_name}.psf"
        path_to_crd = f"{hipen_testsystem}/{system_name}/{system_name}.crd"
        psf = CharmmPsfFile(path_to_psf)
        params = CharmmParameterSet(
            f"{hipen_testsystem}/top_all36_cgenff.rtf",
            f"{hipen_testsystem}/par_all36_cgenff.prm",
            f"{hipen_testsystem}/{system_name}/{system_name}.str",
        )
        # define region that should be treated with the qml (atoms of the first chain)
        chains = list(psf.topology.chains())
        ml_atoms = [atom.index for atom in chains[0].atoms()]
        # create openmm simulation system
        sim = create_charmm_system(
            psf=psf, parameters=params, env="vacuum", ml_atoms=ml_atoms
        )

        # convert pickle files if necessary
        print('converting pickle file ...')
        try:
            convert_pickle_files(system_name, DATA_BASE, path_to_psf, path_to_crd)
        except OSError:
            # NOTE(review): the DCD writer raises when it cannot write the output
            # (presumably an existing, non-writable file); treated as "already converted".
            print('Pickle files already converted ...')
        # load all trajectories and calculate input for MBAR
        trajs = load_equ_samples(system_name, DATA_BASE)
        N_k, u_kn = calculate_u_kn(
            trajs=trajs,
            every_nth_frame=10,
            sim=sim,
        )

        # calculate equilibrium free energy difference (solved once, reused for plotting)
        print('#-----------------------------------------#')
        equ_mbar = mbar.MBAR(u_kn, N_k)
        equ_results = equ_mbar.compute_free_energy_differences()
        print(f'Free energy difference: {equ_results["Delta_f"][0][-1]}')
        print(f'Free energy uncertainty: {equ_results["dDelta_f"][0][-1]}')
        print('#-----------------------------------------#')

        # perform FEP (or load cached results)
        switching_dir = f"{DATA_BASE}/{system_name}/switching_charmmff"
        fep_file_name = f'{switching_dir}/{system_name}_FEP_results_5000_switches.pickle'
        if os.path.isfile(fep_file_name):
            r = _load_pickle(fep_file_name)
        else:
            # pool the endstate trajectories of three independent runs
            mm_trajs = []
            qml_trajs = []
            for run_id in (1, 2, 3):
                try:
                    convert_pickle_files(system_name, DATA_BASE, path_to_psf, path_to_crd, run_id=run_id)
                except OSError:
                    print('Pickle files already converted ...')

                run_trajs = load_equ_samples(system_name, DATA_BASE, run_id=run_id)
                mm_trajs.extend(run_trajs[0])
                qml_trajs.extend(run_trajs[-1])

            fep_protocol = Protocol(
                method="FEP",
                direction="bidirectional",
                sim=sim,
                trajectories=[mm_trajs, qml_trajs],
                nr_of_switches=5_000,
            )
            r = perform_endstate_correction(fep_protocol)
            with open(fep_file_name, 'wb') as fh:
                pickle.dump(r, fh)

        # load NEQ results (divided by kBT, as the other estimates expect)
        r.W_mm_to_qml = _load_pickle(f'{switching_dir}/{system_name}_neq_ws_from_mm_to_qml_500_5001.pickle') / kBT
        r.W_qml_to_mm = _load_pickle(f'{switching_dir}/{system_name}_neq_ws_from_qml_to_mm_500_5001.pickle') / kBT
        # put the already-solved mbar in the results object
        r.equ_mbar = equ_mbar

        plot_dist(r, system_name, axs[idx])

    fig.tight_layout()
    fig.savefig(f'{batch_idx}_batch_dist.png')
    plt.show()
    # release the figure; one large 600-dpi figure is created per batch
    plt.close(fig)




0
Generating charmm system in vacuum




nnpops CUDA
platform='CUDA'
env='vacuum'
ml_atoms=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
/data/shared/software/python_env/anaconda3/envs/rew/lib/python3.9/site-packages/torchani/resources/
converting pickle file ...
Pickle files already converted ...dcdplugin) Could not open file '/data/shared/projects/endstate_rew/ZINC00086442/sampling_charmmff/run01//ZINC00086442_samples_5000_steps_1000_lamb_0.0000.dcd' for writing

Number of samples loaded: 4400


100%|██████████| 4400/4400 [00:17<00:00, 245.24it/s]
100%|██████████| 4400/4400 [00:14<00:00, 312.02it/s]
100%|██████████| 4400/4400 [00:13<00:00, 332.75it/s]
100%|██████████| 4400/4400 [00:13<00:00, 335.39it/s]
100%|██████████| 4400/4400 [00:13<00:00, 331.71it/s]
100%|██████████| 4400/4400 [00:12<00:00, 339.45it/s]
100%|██████████| 4400/4400 [00:12<00:00, 347.24it/s]
100%|██████████| 4400/4400 [00:12<00:00, 339.97it/s]
100%|██████████| 4400/4400 [00:13<00:00, 335.49it/s]
100%|██████████| 4400/4400 [00:13<00:00, 330.51it/s]
100%|██████████| 4400/4400 [00:14<00:00, 299.25it/s]


#-----------------------------------------#




Free energy difference: -656584.5315802457
Free energy uncertainty: 0.10930165820912931
#-----------------------------------------#




1
Generating charmm system in vacuum
nnpops CUDA
platform='CUDA'
env='vacuum'
ml_atoms=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
/data/shared/software/python_env/anaconda3/envs/rew/lib/python3.9/site-packages/torchani/resources/
converting pickle file ...
Pickle files already converted ...
dcdplugin) Could not open file '/data/shared/projects/endstate_rew/ZINC00079729/sampling_charmmff/run01//ZINC00079729_samples_5000_steps_1000_lamb_0.0000.dcd' for writing
Number of samples loaded: 4400


100%|██████████| 4400/4400 [00:13<00:00, 323.48it/s]
100%|██████████| 4400/4400 [00:13<00:00, 334.92it/s]
100%|██████████| 4400/4400 [00:11<00:00, 372.13it/s]
100%|██████████| 4400/4400 [00:11<00:00, 386.57it/s]
100%|██████████| 4400/4400 [00:12<00:00, 362.20it/s]
100%|██████████| 4400/4400 [00:12<00:00, 340.83it/s]
100%|██████████| 4400/4400 [00:12<00:00, 347.02it/s]
100%|██████████| 4400/4400 [00:12<00:00, 346.32it/s]
100%|██████████| 4400/4400 [00:12<00:00, 346.66it/s]
100%|██████████| 4400/4400 [00:12<00:00, 346.90it/s]
100%|██████████| 4400/4400 [00:12<00:00, 340.81it/s]


#-----------------------------------------#
Free energy difference: -2105811.309851527
Free energy uncertainty: 0.09195730712750406
#-----------------------------------------#




2
Generating charmm system in vacuum
nnpops CUDA
platform='CUDA'
env='vacuum'
ml_atoms=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
/data/shared/software/python_env/anaconda3/envs/rew/lib/python3.9/site-packages/torchani/resources/
converting pickle file ...
Number of samples loaded: 4400


100%|██████████| 4400/4400 [00:12<00:00, 362.84it/s]
100%|██████████| 4400/4400 [00:13<00:00, 326.59it/s]
100%|██████████| 4400/4400 [00:12<00:00, 348.68it/s]
100%|██████████| 4400/4400 [00:12<00:00, 340.01it/s]
100%|██████████| 4400/4400 [00:12<00:00, 352.02it/s]
100%|██████████| 4400/4400 [00:12<00:00, 353.75it/s]
100%|██████████| 4400/4400 [00:12<00:00, 346.07it/s]
100%|██████████| 4400/4400 [00:12<00:00, 353.53it/s]
100%|██████████| 4400/4400 [00:12<00:00, 364.27it/s]
100%|██████████| 4400/4400 [00:11<00:00, 381.29it/s]
100%|██████████| 4400/4400 [00:12<00:00, 360.19it/s]


#-----------------------------------------#
Free energy difference: -940543.7199070829
Free energy uncertainty: 0.07795397242121609
#-----------------------------------------#




3
Generating charmm system in vacuum
nnpops CUDA
platform='CUDA'
env='vacuum'
ml_atoms=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
/data/shared/software/python_env/anaconda3/envs/rew/lib/python3.9/site-packages/torchani/resources/
converting pickle file ...
Pickle files already converted ...
dcdplugin) Could not open file '/data/shared/projects/endstate_rew/ZINC00087557/sampling_charmmff/run01//ZINC00087557_samples_5000_steps_1000_lamb_0.0000.dcd' for writing
Number of samples loaded: 4400


100%|██████████| 4400/4400 [00:15<00:00, 284.67it/s]
100%|██████████| 4400/4400 [00:14<00:00, 301.81it/s]
100%|██████████| 4400/4400 [00:15<00:00, 289.70it/s]
100%|██████████| 4400/4400 [00:15<00:00, 283.80it/s]
100%|██████████| 4400/4400 [00:15<00:00, 284.50it/s]
100%|██████████| 4400/4400 [00:15<00:00, 281.40it/s]
100%|██████████| 4400/4400 [00:14<00:00, 298.35it/s]
100%|██████████| 4400/4400 [00:15<00:00, 279.67it/s]
100%|██████████| 4400/4400 [00:14<00:00, 299.76it/s]
100%|██████████| 4400/4400 [00:14<00:00, 311.60it/s]
100%|██████████| 4400/4400 [00:14<00:00, 309.38it/s]


#-----------------------------------------#
Free energy difference: -839346.6945415696
Free energy uncertainty: 0.15244159798833487
#-----------------------------------------#




4
Generating charmm system in vacuum
nnpops CUDA
platform='CUDA'
env='vacuum'
ml_atoms=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
/data/shared/software/python_env/anaconda3/envs/rew/lib/python3.9/site-packages/torchani/resources/
converting pickle file ...
Pickle files already converted ...
dcdplugin) Could not open file '/data/shared/projects/endstate_rew/ZINC00107550/sampling_charmmff/run01//ZINC00107550_samples_5000_steps_1000_lamb_0.0000.dcd' for writing
Number of samples loaded: 4400


100%|██████████| 4400/4400 [00:13<00:00, 330.42it/s]
100%|██████████| 4400/4400 [00:12<00:00, 350.60it/s]
100%|██████████| 4400/4400 [00:12<00:00, 344.22it/s]
100%|██████████| 4400/4400 [00:13<00:00, 335.24it/s]
100%|██████████| 4400/4400 [00:13<00:00, 337.23it/s]
100%|██████████| 4400/4400 [00:13<00:00, 322.38it/s]
100%|██████████| 4400/4400 [00:13<00:00, 317.96it/s]
100%|██████████| 4400/4400 [00:13<00:00, 335.14it/s]
100%|██████████| 4400/4400 [00:12<00:00, 349.09it/s]
100%|██████████| 4400/4400 [00:12<00:00, 355.28it/s]
100%|██████████| 4400/4400 [00:12<00:00, 350.77it/s]


#-----------------------------------------#
Free energy difference: -560516.4238925535
Free energy uncertainty: 0.09373930924631345
#-----------------------------------------#


