In [None]:
from typing import Dict, Optional, Any
import os
import pickle

import mdtraj as md
import numpy as np
import pyemma
from scipy.spatial import distance
from statsmodels.tsa import stattools
import pandas as pd
import matplotlib.pyplot as plt
import lovelyplots

plt.style.use("ipynb")

import utils as analysis_utils
from jamun import utils

### Environment Variables

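The results directory set below must point to the precomputed analysis results (one pickle file per peptide); these include the reference MD data used for comparison.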

In [None]:
# Root directory of the pickled analysis results. The loading cell expects the
# layout <results_dir>/<traj_name>/ref=<ref_traj_name>/<peptide>.pkl.
# Setting the JAMUN_ANALYSIS_RESULTS_DIR environment variable overrides the
# default below, so the notebook can run on another machine without edits.
results_dir = os.environ.get(
    "JAMUN_ANALYSIS_RESULTS_DIR",
    "/homefs/home/daigavaa/jamun/analysis_results",
)
print(f"results_dir: {results_dir}")

Specify which peptides to analyze:

In [3]:
# Peptide codes selected for analysis.
# NOTE(review): no cell visible in this notebook reads `filter_codes`; the
# loading cell picks up every .pkl file it finds. Confirm whether filtering
# by these codes is still intended.
filter_codes = ["FAFG"]

### Load All Trajectories

In [None]:
# Each entry pairs a sampled trajectory with the reference it was analysed against.
trajs_with_reference = [
    # [traj_name, ref_traj_name]
    # ["JAMUN", "5AA_JAMUNReference"],
    ["JAMUN", "TimewarpReference"],
    # ["TimewarpReference", "TimewarpReference"],
]

# Column order of the summary table built below.
result_columns = ["traj", "ref_traj", "peptide", "results", "args"]

results = []
for traj_name, ref_traj_name in trajs_with_reference:
    results_path = os.path.join(results_dir, traj_name, f"ref={ref_traj_name}")
    print(f"Searching in {results_path}")

    # Sorted so that the row order of the table is deterministic across runs.
    for results_file in sorted(os.listdir(results_path)):
        # The file stem is used as the peptide name; anything that is not a pickle is skipped.
        peptide, ext = os.path.splitext(results_file)
        if ext != ".pkl":
            continue

        # NOTE: pickle.load can execute arbitrary code. Only point results_dir
        # at files produced by a trusted analysis run.
        with open(os.path.join(results_path, results_file), "rb") as f:
            all_results = pickle.load(f)

        results.append({
            "traj": traj_name,
            "ref_traj": ref_traj_name,
            "peptide": peptide,
            "results": all_results["results"],
            "args": all_results["args"],
        })

# Passing the columns explicitly keeps the schema stable even when no result
# file was found, so later filters such as results_df["traj"] do not raise KeyError.
results_df = pd.DataFrame(results, columns=result_columns)
results_df


In [6]:
# Disabled scratch cell (every line is commented out, so nothing runs).
# It loads a single trajectory with mdtraj, computes the phi/psi angles, builds
# a 2D-histogram-based PMF per dihedral, and draws it as a contour plot, using
# the same plotting calls as plot_ramachandran_contour.
# import mdtraj as md

# traj = md.load(
#     "/data/bucket/vanib/5AA/ALA_MET_GLU_TYR_ALA/equilNVT_0.xtc",
#     top="/data/bucket/vanib/5AA/ALA_MET_GLU_TYR_ALA/restrainedNVT_0.pdb",
# )
# nbins = 50
# _, phi = md.compute_phi(traj)
# _, psi = md.compute_psi(traj)
# num_dihedrals = phi.shape[1]
# pmf = np.zeros((num_dihedrals, nbins - 1, nbins - 1))
# xedges = np.linspace(-np.pi, np.pi, nbins)
# yedges = np.linspace(-np.pi, np.pi, nbins)

# for dihedral_index in range(num_dihedrals):
#     H, _, _ = np.histogram2d(
#         phi[:, dihedral_index], psi[:, dihedral_index],
#         bins=np.linspace(-np.pi, np.pi, nbins)
#     )
#     pmf[dihedral_index] = -np.log(H.T) + np.max(np.log(H.T))

#     fig, ax = plt.subplots()
#     im = ax.contourf(xedges[:-1], yedges[:-1], pmf[dihedral_index], cmap="viridis", levels=50)
#     contour = ax.contour(xedges[:-1], yedges[:-1], pmf[dihedral_index], colors="white", linestyles="solid", levels=30, linewidths=0.25)

#     ax.set_aspect("equal", adjustable="box")
#     ax.set_xlabel("$\phi$")
#     ax.set_ylabel("$\psi$")

#     tick_eps = 0.1
#     ticks = [-np.pi + tick_eps, -np.pi / 2, 0, np.pi / 2, np.pi - tick_eps]
#     tick_labels = ["$-\pi$", "$-\pi/2$", "$0$", "$\pi/2$", "$\pi$"]
#     ax.set_xticks(ticks, tick_labels)
#     ax.set_yticks(ticks, tick_labels)

#     plt.show()


### Ramachandran Plots

In [7]:
def plot_ramachandran_contour(results: Dict[str, Any], dihedral_index: int, ax: Optional[plt.Axes] = None) -> plt.Axes:
    """Plots the Ramachandran contour plot of a trajectory.

    Args:
        results: Mapping with the entries "pmf", "xedges" and "yedges".
            `pmf[dihedral_index]` is drawn on the grid spanned by `xedges[:-1]`
            and `yedges[:-1]` (presumably a phi/psi histogram-based PMF per
            dihedral -- TODO confirm against the analysis script).
        dihedral_index: Index into `pmf` selecting which dihedral to draw.
        ax: Axes to draw on. A new 10x10 inch figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))

    pmf, xedges, yedges = results["pmf"], results["xedges"], results["yedges"]
    # Filled contours first, then thin white contour lines on top for readability.
    # The last edge is dropped so the coordinate arrays match the grid passed to contourf.
    ax.contourf(xedges[:-1], yedges[:-1], pmf[dihedral_index], cmap="viridis", levels=50)
    ax.contour(xedges[:-1], yedges[:-1], pmf[dihedral_index], colors="white", linestyles="solid", levels=30, linewidths=0.25)

    ax.set_aspect("equal", adjustable="box")
    # Raw strings: "\p" is not a valid escape sequence and triggers a
    # SyntaxWarning on recent Python versions when written in a normal string.
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$\psi$")

    # The outermost ticks are pulled in slightly from +/- pi.
    tick_eps = 0.1
    ticks = [-np.pi + tick_eps, -np.pi / 2, 0, np.pi / 2, np.pi - tick_eps]
    tick_labels = [r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"]
    ax.set_xticks(ticks, tick_labels)
    ax.set_yticks(ticks, tick_labels)

    return ax


def format_for_plot(peptide: str) -> str:
    """Builds the plot label for a peptide.

    Every element obtained by iterating over `peptide` is passed through
    `utils.convert_to_one_letter_code`, and the results are concatenated.
    """
    one_letter_codes = []
    for residue in peptide:
        one_letter_codes.append(utils.convert_to_one_letter_code(residue))
    return "".join(one_letter_codes)

In [None]:
# Number of dihedrals plotted per peptide.
n_dihedrals = 3
# Horizontal position (in axes coordinates of the column n_dihedrals // 2) of
# the group label: the middle of the central column for an odd count, the left
# edge of that column (the boundary between the two central columns) for an even count.
label_offset = 0.0 if n_dihedrals % 2 == 0 else 0.5

JAMUN_results_df = results_df[(results_df["traj"] == "JAMUN")]

# One row per peptide; the left n_dihedrals columns show the reference, the right ones JAMUN.
fig, axs = plt.subplots(len(JAMUN_results_df), 2 * n_dihedrals, figsize=(32, 8), squeeze=False)

# enumerate() supplies the positional subplot row. The DataFrame index label
# returned by iterrows() need not start at 0 or be contiguous after filtering.
for i, (_, row) in enumerate(JAMUN_results_df.iterrows()):
    peptide = row["peptide"]

    for j in range(n_dihedrals):
        plot_ramachandran_contour(row["results"]["PMFs"]["ref_traj_pmf_internal"], j, axs[i, j])
        plot_ramachandran_contour(row["results"]["PMFs"]["traj_pmf_internal"], j, axs[i, j + n_dihedrals])

    # Peptide name, rotated, to the right of the last column of this row.
    axs[i, -1].text(
        1.1,
        0.5,
        format_for_plot(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[i, -1].transAxes,
    )

# Column-group labels above the first row. Drawn once, outside the loop, so the
# same text is not stacked on top of itself for every peptide.
axs[0, n_dihedrals // 2].text(
    label_offset,
    1.1,
    "Reference",
    horizontalalignment="center",
    verticalalignment="center",
    transform=axs[0, n_dihedrals // 2].transAxes,
)
axs[0, n_dihedrals // 2 + n_dihedrals].text(
    label_offset,
    1.1,
    "JAMUN",
    horizontalalignment="center",
    verticalalignment="center",
    transform=axs[0, n_dihedrals // 2 + n_dihedrals].transAxes,
)

plt.tight_layout()
plt.show()

In [None]:
# Topology file of one sampled peptide, used to inspect the backbone torsion definitions.
topology_path = "/data/bucket/kleinhej/jamun-runs/outputs/sample/dev/runs/52710300c5520ca8503d7d47/sampler/IDRL/topology.pdb"

# md.load returns a Trajectory; the dihedral index helpers walk the chains of a
# Topology, so the topology is extracted explicitly.
topology = md.load(topology_path).topology
print(md.geometry.indices_phi(topology), md.geometry.indices_psi(topology))

# A single featurizer built from the topology loaded above (previously it was
# constructed twice, with the first instance immediately discarded).
feat = pyemma.coordinates.featurizer(topology)
feat.add_backbone_torsions(cossin=False)
feat.describe()

# TODO(review): sorting the features by resSeq is not implemented here.

### Feature Histograms

In [None]:
JAMUN_results_df = results_df[(results_df["traj"] == "JAMUN")]

# One row per peptide: reference features on the left, JAMUN features on the right.
fig, axs = plt.subplots(nrows=len(JAMUN_results_df), ncols=2, figsize=(14, 8), squeeze=False)

# enumerate() supplies the positional subplot row. The DataFrame index label
# returned by iterrows() need not start at 0 or be contiguous after filtering.
for i, (_, row) in enumerate(JAMUN_results_df.iterrows()):
    peptide = row["peptide"]

    ref_results = row["results"]["featurization"]["ref_traj"]
    pyemma.plots.plot_feature_histograms(
        ref_results["traj_featurized"],
        feature_labels=ref_results["feats"].describe(),
        ax=axs[i, 0]
    )

    traj_results = row["results"]["featurization"]["traj"]
    pyemma.plots.plot_feature_histograms(
        traj_results["traj_featurized"],
        feature_labels=traj_results["feats"].describe(),
        ax=axs[i, 1]
    )

    # Peptide name, rotated, to the right of the last column of this row.
    axs[i, -1].text(
        1.1,
        0.5,
        format_for_plot(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[i, -1].transAxes,
    )

axs[0, 0].set_title("Reference")
axs[0, 1].set_title("JAMUN")
plt.tight_layout()

In [None]:
JAMUN_results_df = results_df[(results_df["traj"] == "JAMUN")]

# Same layout as the feature histograms, but for the "*_dists" featurization entries.
fig, axs = plt.subplots(nrows=len(JAMUN_results_df), ncols=2, figsize=(14, 8), squeeze=False)

# enumerate() supplies the positional subplot row. The DataFrame index label
# returned by iterrows() need not start at 0 or be contiguous after filtering.
for i, (_, row) in enumerate(JAMUN_results_df.iterrows()):
    peptide = row["peptide"]

    ref_results = row["results"]["featurization"]["ref_traj"]
    pyemma.plots.plot_feature_histograms(
        ref_results["traj_featurized_dists"],
        feature_labels=ref_results["feats_dists"].describe(),
        ax=axs[i, 0]
    )

    traj_results = row["results"]["featurization"]["traj"]
    pyemma.plots.plot_feature_histograms(
        traj_results["traj_featurized_dists"],
        feature_labels=traj_results["feats_dists"].describe(),
        ax=axs[i, 1]
    )

    # Peptide name, rotated, to the right of the last column of this row.
    axs[i, -1].text(
        1.1,
        0.5,
        format_for_plot(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[i, -1].transAxes,
    )

axs[0, 0].set_title("Reference")
axs[0, 1].set_title("JAMUN")
plt.tight_layout()

In [None]:
JAMUN_results_df = results_df[(results_df["traj"] == "JAMUN")]

# Print the stored JSD statistics of every peptide, one line per feature group.
for _, row in JAMUN_results_df.iterrows():
    peptide = row["peptide"]
    print(f"Peptide: {peptide}")
    # The loop variable is deliberately not called `feat`, so it does not
    # overwrite the featurizer object of that name created in an earlier cell.
    for feature_group in ["backbone", "sidechain", "all_torsions"]:
        print(feature_group, "JSD:", row["results"]["JSD_stats"][feature_group])

### TICA Analysis

In [12]:
JAMUN_results_df = results_df[(results_df["traj"] == "JAMUN")]

# One row per peptide: reference free energy on the left, JAMUN on the right.
fig, axs = plt.subplots(nrows=len(JAMUN_results_df), ncols=2, figsize=(12, 6), squeeze=False)

# enumerate() supplies the positional subplot row. The DataFrame index label
# returned by iterrows() need not start at 0 or be contiguous after filtering.
for i, (_, row) in enumerate(JAMUN_results_df.iterrows()):
    peptide = row["peptide"]

    # Show the precomputed TICA statistics.
    print(row["results"]["TICA_stats"])

    # Plot free energy over the first two TICA components.
    ref_tica = row["results"]["TICA"]["ref_tica"]
    pyemma.plots.plot_free_energy(ref_tica[:, 0], ref_tica[:, 1], cmap="plasma", ax=axs[i, 0])
    axs[i, 0].set_title("Reference")

    traj_tica = row["results"]["TICA"]["traj_tica"]
    pyemma.plots.plot_free_energy(traj_tica[:, 0], traj_tica[:, 1], cmap="plasma", ax=axs[i, 1])
    axs[i, 1].set_title("JAMUN")

    # Set the same limits for both plots (the reference panel defines them).
    axs[i, 1].set_xlim(axs[i, 0].get_xlim())
    axs[i, 1].set_ylim(axs[i, 0].get_ylim())

    # Peptide name, rotated, to the right of the last column of this row.
    axs[i, -1].text(
        1.4,
        0.5,
        format_for_plot(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[i, -1].transAxes,
    )

plt.tight_layout()

In [13]:
JAMUN_results_df = results_df[(results_df["traj"] == "JAMUN")]
# One column per peptide.
fig, axs = plt.subplots(nrows=1, ncols=len(JAMUN_results_df), figsize=(16, 6), squeeze=False)

# enumerate() supplies the positional subplot column. The DataFrame index label
# returned by iterrows() need not start at 0 or be contiguous after filtering.
for i, (_, row) in enumerate(JAMUN_results_df.iterrows()):
    peptide = row["peptide"]
    ref_autocorr = row["results"]["autocorrelation_stats"]["ref_autocorr"]
    traj_autocorr = row["results"]["autocorrelation_stats"]["traj_autocorr"]
    axs[0, i].plot(ref_autocorr, label="Reference")
    axs[0, i].plot(traj_autocorr, label="JAMUN")
    axs[0, i].set_xlabel("Lag")
    axs[0, i].set_ylabel("Autocorrelation")
    axs[0, i].grid(linestyle='--')
    # Peptide name centred above the panel.
    axs[0, i].text(
        0.5,
        1.05,
        format_for_plot(peptide),
        rotation=0,
        verticalalignment="center",
        horizontalalignment="center",
        transform=axs[0, i].transAxes,
        fontsize="x-large",
    )

# Place legend outside plot (pyplot calls act on the current axes).
plt.suptitle("TICA-0 Autocorrelation", fontsize="xx-large")
plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.ticklabel_format(useOffset=False)
plt.tight_layout()
plt.show()

In [14]:
JAMUN_results_df = results_df[(results_df["traj"] == "JAMUN")]

# enumerate() supplies the positional counter. The DataFrame index label
# returned by iterrows() need not start at 0 after filtering.
for i, (_, row) in enumerate(JAMUN_results_df.iterrows()):
    print(f"Peptide: {row['peptide']}")
    # Only the first curve carries a label, so the legend shows a single entry
    # per series instead of none (one peptide) or one per extra peptide.
    if i == 0:
        Timewarp_label = "Reference"
        JAMUN_label = "JAMUN"
    else:
        Timewarp_label = None
        JAMUN_label = None

    steps = list(row["results"]["JSD_stats_against_time"].keys())
    js_divs = list(row["results"]["JSD_stats_against_time"].values())
    print(steps, js_divs)
    # Normalise the step counts to the fraction of the trajectory consumed.
    progress = np.asarray(steps) / np.max(steps)
    # plt.plot(progress, Timewarp_js_divs[peptide], color="C0", label=Timewarp_label)
    plt.plot(progress, js_divs, color="C1", label=JAMUN_label)

plt.title("Jensen-Shannon Divergences")
plt.xlabel("Fraction of Trajectory Progress", fontsize=12)
plt.ylabel("Jensen-Shannon Divergence", fontsize=12)
plt.ticklabel_format(useOffset=False, style="plain")
plt.legend(fontsize=10)
plt.xticks(fontsize=8)
plt.yticks(fontsize=8)
plt.show()