In [1]:
from typing import  Dict, Any
import os

import mdtraj as md
import wandb
import pyemma
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import distance 
from statsmodels.tsa import stattools

In [2]:
def get_wandb_run_config(wandb_run_path: str) -> Dict[str, Any]:
    """Fetch the ``"cfg"`` entry stored in a W&B run's config.

    Args:
        wandb_run_path: Run path in ``entity/project/run_id`` form.

    Returns:
        The value stored under ``"cfg"`` in the run's config.
    """
    api = wandb.Api()
    wandb_run = api.run(wandb_run_path)
    # NOTE(review): the message mentions a checkpoint, but this function only reads the run config.
    print(f"Loading checkpoint corresponding to wandb run {wandb_run.name} at {wandb_run.url}")
    run_config = wandb_run.config
    return run_config["cfg"]

In [3]:
# W&B run path ("entity/project/run_id") of the sampling run to analyze.
wandb_sample_run_path = "prescient-design/wjpp/fmralmpb"

In [4]:
# Pull the run's stored config; the bare expression below displays it.
cfg = get_wandb_run_config(wandb_run_path=wandb_sample_run_path)
cfg

True


Loading checkpoint corresponding to wandb run lemon-butterfly-1481 at https://genentech.wandb.io/prescient-design/wjpp/runs/fmralmpb


{'m': 1,
 'seed': 42,
 'delta': 0.06,
 'model': {'_target_': 'wjpp.model.e3NoiseConditionedScoreModel.load_from_checkpoint',
  'checkpoint_path': None},
 'paths': {'run_path': '/data/bucket/wjpp//outputs/sample/dev/runs/dc712bf33e7377ae36204cac',
  'data_path': '/data/bucket/wjpp//data',
  'root_path': '/data/bucket/wjpp/',
  'run_path_s3': 's3://prescient-data-dev/kleinhej/wjpp/outputs/sample/dev/runs/dc712bf33e7377ae36204cac',
  'root_path_s3': 's3://prescient-data-dev/kleinhej/wjpp'},
 'sigma': 0.06,
 'logger': {'wandb': {'tags': ['867192', 'sample'],
   'group': 'sampling_tetra',
   'entity': None,
   'offline': False,
   'project': 'wjpp',
   '_target_': 'lightning.pytorch.loggers.WandbLogger'}},
 'run_key': 'dc712bf33e7377ae36204cac',
 'sampler': {'devices': 1,
  '_target_': 'wjpp.sampling.SamplerMD',
  '_convert_': 'partial',
  'precision': '32-true'},
 'trainer': {'devices': 1, 'num_nodes': 1, 'limit_train_batches': 1},
 'friction': 1,
 'root_dir': '.',
 'callbacks': {'timing':

In [5]:
# Peptide sequence (one-letter codes) to analyze; used as the per-peptide subdirectory name.
peptide = "AMIG"

In [6]:
# Directory for every figure saved by this notebook; created if it does not exist yet.
figures_dir = f"mdgen_analysis/{peptide}"
os.makedirs(figures_dir, exist_ok=True)

In [7]:
# Per-peptide sample locations, rooted at the output_dir of the `save_trajectory`
# callback entry in the run config.
output_dir = cfg["callbacks"]["save_trajectory"]["output_dir"] + f"/{peptide}"
true_samples_dir = f"{output_dir}/true_samples"
predicted_samples_dir = f"{output_dir}/predicted_samples"

ref_pdb, ref_dcd = f"{true_samples_dir}/pdb/0.pdb", f"{true_samples_dir}/dcd/0.dcd"
pred_pdb, pred_dcd = f"{predicted_samples_dir}/pdb/joined.pdb", f"{predicted_samples_dir}/dcd/joined.dcd"

print(pred_pdb)

/data/bucket/wjpp//outputs/sample/dev/runs/dc712bf33e7377ae36204cac/saved//AMIG/predicted_samples/pdb/joined.pdb


In [8]:
def get_featurized_traj(pdb_file: str, dcd_file: str, cossin: bool = False):
    """Featurize a trajectory by its backbone and side-chain torsion angles.

    Args:
        pdb_file: Topology file used to build the PyEMMA featurizer.
        dcd_file: Trajectory file loaded through that featurizer.
        cossin: Forwarded to both torsion featurizers; in PyEMMA, True encodes
            each angle as cos/sin features instead of the raw angle.

    Returns:
        A ``(featurizer, features)`` tuple: the configured PyEMMA featurizer and
        whatever ``pyemma.coordinates.load`` returns for ``dcd_file``.
    """
    featurizer = pyemma.coordinates.featurizer(pdb_file)
    # Backbone torsions first, then side-chain torsions; this fixes the column order.
    for add_torsions in (featurizer.add_backbone_torsions, featurizer.add_sidechain_torsions):
        add_torsions(cossin=cossin)
    features = pyemma.coordinates.load(dcd_file, features=featurizer)
    return featurizer, features


In [9]:
# Featurize reference MD and JAMUN samples identically (backbone + side-chain
# torsions; cossin keeps its default of False, so columns are raw angles).
ref_feat, ref_traj = get_featurized_traj(pdb_file=ref_pdb, dcd_file=ref_dcd)
feat, traj = get_featurized_traj(pdb_file=pred_pdb, dcd_file=pred_dcd)

  indices = np.vstack(valid.values())


In [10]:
# Keep at most as many JAMUN frames as the reference has, so both ensembles
# contribute comparable sample counts to the comparisons below.
traj = traj[: ref_traj.shape[0]]

In [11]:
# Human-readable labels of the featurized columns, in column order.
ref_feat.describe()

['PHI 0 MET 2',
 'PSI 0 ALA 1',
 'PHI 0 ILE 3',
 'PSI 0 MET 2',
 'PHI 0 GLY 4',
 'PSI 0 ILE 3',
 'CHI1 0 MET 2',
 'CHI1 0 ILE 3',
 'CHI2 0 MET 2',
 'CHI2 0 ILE 3',
 'CHI3 0 MET 2']

In [12]:
# Later cells index both trajectories using ref_feat's column order, so the
# two featurizations must describe exactly the same features in the same order.
assert ref_feat.describe() == feat.describe(), (
    "Reference and predicted featurizations differ; per-feature comparisons would be misaligned."
)

In [13]:
# Sanity check: both should be (n_frames, n_features) with matching shapes.
ref_traj.shape, traj.shape

((50000, 11), (50000, 11))

In [None]:
# Pairwise torsion densities: one figure per (column 2k, column 2k+1) pair,
# drawn for JAMUN and then for the reference MD.
# NOTE(review): per ref_feat.describe(), column 2k is PHI of residue k+2 and
# column 2k+1 is PSI of residue k+1, i.e. adjacent residues rather than a
# same-residue Ramachandran pair; confirm this pairing is intended.
for dihedral in range(3):
    x_col, y_col = 2 * dihedral, 2 * dihedral + 1
    for data, label, stem in (
        (traj, "JAMUN", "density_plot"),
        (ref_traj, "Reference MD", "density_plot_ref"),
    ):
        pyemma.plots.plot_density(data[:, x_col], data[:, y_col])
        # Only the middle pair is titled (presumably for a side-by-side layout).
        if dihedral == 1:
            plt.title(label, fontsize="xx-large")
        plt.savefig(f"{figures_dir}/{stem}_{dihedral}.png", dpi=500)

In [None]:
# Pairwise torsion free-energy surfaces: one figure per (column 2k, column 2k+1)
# pair, drawn for JAMUN and then for the reference MD.
# NOTE(review): per ref_feat.describe(), these pairs combine PHI of residue k+2
# with PSI of residue k+1; confirm this pairing is intended.
for dihedral in range(3):
    x_col, y_col = 2 * dihedral, 2 * dihedral + 1
    for data, label, stem in (
        (traj, "JAMUN", "free_energy_plot"),
        (ref_traj, "Reference MD", "free_energy_plot_ref"),
    ):
        pyemma.plots.plot_free_energy(data[:, x_col], data[:, y_col], cmap="inferno")
        # Only the middle pair is titled (presumably for a side-by-side layout).
        if dihedral == 1:
            plt.title(label, fontsize="xx-large")
        plt.savefig(f"{figures_dir}/{stem}_{dihedral}.png", dpi=500)

In [None]:
# Marginal histogram of every featurized torsion, JAMUN first, then reference MD.
for data, featurizer, label, stem in (
    (traj, feat, "JAMUN", "feature_histograms"),
    (ref_traj, ref_feat, "Reference MD", "feature_histograms_ref"),
):
    pyemma.plots.plot_feature_histograms(data, feature_labels=featurizer)
    plt.title(label, fontsize="xx-large")
    plt.tight_layout()
    plt.savefig(f"{figures_dir}/{stem}.png", dpi=500)

In [None]:
# Per-torsion Jensen-Shannon distance between reference and JAMUN marginals,
# using shared bins over [-pi, pi] (raw angles, since cossin=False).
# The loop variable must not be named `feat`: that name holds the predicted-
# trajectory featurizer, and shadowing it breaks re-running earlier cells
# (e.g. the feature-histogram plot and the describe() assertion).
jsd_bins = 100
for i, feature_name in enumerate(ref_feat.describe()):
    ref_p = np.histogram(ref_traj[:, i], range=(-np.pi, np.pi), bins=jsd_bins)[0]
    traj_p = np.histogram(traj[:, i], range=(-np.pi, np.pi), bins=jsd_bins)[0]
    print(feature_name, "JSD:", distance.jensenshannon(ref_p, traj_p))


In [None]:
# Fit TICA on the reference MD features and project both trajectories onto
# those slow coordinates. Fitting on the reference (not on the JAMUN samples)
# keeps the projection basis independent of the model being evaluated.
# NOTE(review): the inputs are raw torsion angles (cossin=False), which wrap at
# +/-pi; TICA is linear, so cos/sin features may be more appropriate — confirm.
tica_lag = 1000  # TICA lag time, in trajectory frames
tica = pyemma.coordinates.tica(ref_traj, lag=tica_lag, kinetic_map=True)
ref_tica = tica.transform(ref_traj)
traj_tica = tica.transform(traj)

In [None]:
# Magnitude of each TICA eigenvalue, one marker per component index.
abs_eigenvalues = np.abs(tica.eigenvalues)
plt.plot(abs_eigenvalues, marker='o')
plt.gca().set(xlabel="index", ylabel="|eigenvalue|", title="Absolute values of TICA eigenvalues")
plt.show()

In [None]:
# Jensen-Shannon distances between reference and JAMUN distributions in TICA
# space: 1D over TICA-0 and 2D over (TICA-0, TICA-1). Within each comparison
# both histograms share bin edges spanning the union of the two data ranges.
tica_0_min = min(ref_tica[:, 0].min(), traj_tica[:, 0].min())
tica_0_max = max(ref_tica[:, 0].max(), traj_tica[:, 0].max())

tica_1_min = min(ref_tica[:, 1].min(), traj_tica[:, 1].min())
tica_1_max = max(ref_tica[:, 1].max(), traj_tica[:, 1].max())

tica_0_range = (tica_0_min, tica_0_max)
tica_01_range = (tica_0_range, (tica_1_min, tica_1_max))

ref_p = np.histogram(ref_tica[:, 0], range=tica_0_range, bins=100)[0]
traj_p = np.histogram(traj_tica[:, 0], range=tica_0_range, bins=100)[0]
# Labelled "JSD" to match the other distance printouts in this notebook.
print("TICA-0 JSD", distance.jensenshannon(ref_p, traj_p))

ref_p = np.histogram2d(*ref_tica[:, :2].T, range=tica_01_range, bins=50)[0]
traj_p = np.histogram2d(*traj_tica[:, :2].T, range=tica_01_range, bins=50)[0]
print("TICA-0,1 JSD", distance.jensenshannon(ref_p.ravel(), traj_p.ravel()))


In [None]:
# Free-energy surface over the first two TICA components, JAMUN then reference MD.
for data, label, stem in (
    (traj_tica, "JAMUN", "tica_01"),
    (ref_tica, "Reference MD", "tica_01_ref"),
):
    pyemma.plots.plot_free_energy(data[:, 0], data[:, 1], cmap="plasma")
    plt.title(label, fontsize="xx-large")
    plt.savefig(f"{figures_dir}/{stem}.png", dpi=500)

In [None]:
# Lagged TICA-0 statistics for reference MD and JAMUN, up to `nlag` frames.
# NOTE(review): stattools.acovf returns an (un-normalized, non-demeaned here)
# autocovariance, although the plot labels call it autocorrelation.
nlag = 1000
acovf_kwargs = dict(nlag=nlag, adjusted=True, demean=False)
ref_autocorr = stattools.acovf(ref_tica[:, 0], **acovf_kwargs)
traj_autocorr = stattools.acovf(traj_tica[:, 0], **acovf_kwargs)

for series, label in ((ref_autocorr, "Reference MD"), (traj_autocorr, "JAMUN")):
    plt.plot(series, label=label)
plt.title("TICA-0 Autocorrelation")
plt.xlabel("Lag")
plt.ylabel("Autocorrelation")
plt.legend()
plt.show()
