In [185]:
# Notebook-wide imports and plotting defaults.
from typing import Dict, Optional, Any
import os
import pickle
import collections

import mdtraj as md
import numpy as np
import scipy.stats
import pyemma
import pandas as pd
# NOTE(review): presumably imported for its side effect of registering the
# "ipynb" matplotlib style used just below — confirm before removing.
import lovelyplots
import matplotlib.pyplot as plt
plt.style.use("ipynb")

import matplotlib as mpl
# Never shift tick labels by a common additive offset.
mpl.rcParams['axes.formatter.useoffset'] = False
mpl.rcParams['axes.formatter.limits'] = (-10000, 10000)  # Controls range before scientific notation is used

# Local analysis helpers (presumably modules shipped alongside this notebook).
import load_trajectory
import utils as analysis_utils
import pyemma_helper

from jamun import utils

### Paths

Configure where the analysis results for the selected experiment (trajectory and reference trajectory) are loaded from, and where the plots are saved.

In [None]:
# Root directory of the pickled analysis results, laid out as
# <results_dir>/<experiment>/<traj_name>/ref=<ref_traj_name>/<peptide>.pkl.
# Set the JAMUN_RESULTS_DIR environment variable to read from elsewhere
# (e.g. "/homefs/home/daigavaa/jamun/analysis_results").
results_dir = os.environ.get("JAMUN_RESULTS_DIR", "/data/bucket/kleinhej/jamun-analysis-new/")

print(f"Results directory: {results_dir}")

In [187]:
# Known experiments and their (trajectory, reference trajectory) pairing.
# Selecting an experiment picks the matching pair, so the three settings cannot
# drift out of sync when switching experiments.
EXPERIMENT_TRAJECTORIES = {
    "Our_2AA": ("JAMUN", "JAMUNReference_2AA"),
    "Our_5AA": ("JAMUN", "JAMUNReference_5AA"),
    "MDGen_4AA_newer": ("JAMUN", "MDGenReference"),
    "Timewarp_4AA": ("JAMUN", "TimewarpReference"),
    "Timewarp_2AA": ("JAMUN", "TimewarpReference"),
}

experiment = "Our_2AA"
traj_name, ref_traj_name = EXPERIMENT_TRAJECTORIES[experiment]

In [None]:
# Figures are written to <plots root>/<experiment>/<traj_name>/ref=<ref_traj_name>.
# Set JAMUN_PLOTS_DIR to change the root (e.g. os.path.join(results_dir, "plots")).
plots_root = os.environ.get("JAMUN_PLOTS_DIR", "/data/bucket/kleinhej/jamun-plots")
output_dir = os.path.join(plots_root, experiment, traj_name, f"ref={ref_traj_name}")
os.makedirs(output_dir, exist_ok=True)
print(f"Plots will be saved to {output_dir}")

### Load Analysis Results

In [None]:
# Main results for this experiment/trajectory pair. The optional
# same-sampling-time results live in a sibling directory with a suffix.
results_path = os.path.join(results_dir, experiment, traj_name, "ref=" + ref_traj_name)
same_sampling_time_path = results_path + "_same_sampling_time"
print(f"Searching for results in {results_path}")

In [None]:
def load_results_path(
    results_path: str,
    traj: Optional[str] = None,
    ref_traj: Optional[str] = None,
) -> pd.DataFrame:
    """Loads every ``<peptide>.pkl`` file in ``results_path`` into a DataFrame.

    Args:
        results_path: Directory holding one pickle per peptide. Files with any
            other extension are ignored.
        traj: Value stored in the ``traj`` column. Defaults to the notebook-level
            ``traj_name``.
        ref_traj: Value stored in the ``ref_traj`` column. Defaults to the
            notebook-level ``ref_traj_name``.

    Returns:
        One row per pickle (in sorted file-name order) with columns ``traj``,
        ``ref_traj``, ``peptide`` (the file stem), ``results`` and ``args``.
        The columns are present even when no pickles are found, so downstream
        column lookups and merges do not fail with ``KeyError``.
    """
    # Fall back to the notebook globals only when not given explicitly.
    if traj is None:
        traj = traj_name
    if ref_traj is None:
        ref_traj = ref_traj_name

    records = []
    for results_file in sorted(os.listdir(results_path)):
        peptide, ext = os.path.splitext(results_file)
        if ext != ".pkl":
            continue

        # NOTE: pickle.load can execute arbitrary code; only point this at result
        # files produced by our own analysis pipeline.
        with open(os.path.join(results_path, results_file), "rb") as f:
            all_results = pickle.load(f)

        records.append({
            "traj": traj,
            "ref_traj": ref_traj,
            "peptide": peptide,
            "results": all_results["results"],
            "args": all_results["args"],
        })
    return pd.DataFrame(records, columns=["traj", "ref_traj", "peptide", "results", "args"])


results_df = load_results_path(results_path)

if os.path.exists(same_sampling_time_path):
    same_sampling_results_df = load_results_path(same_sampling_time_path)
    # Inner join on the identifying columns: peptides missing from either side are
    # dropped, and the same-sampling-time columns get a "_same_sampling_time" suffix.
    results_df = results_df.merge(
        same_sampling_results_df,
        on=["peptide", "traj", "ref_traj"],
        suffixes=("", "_same_sampling_time"),
    )

results_df

In [None]:
# Restrict the analysis to a fixed set of peptides. They are listed as one-letter
# sequences and converted to the "_"-joined three-letter names used in results_df.
peptide_sequences = ["AP", "EW", "HK", "LY", "MY", "SD", "FF", "VV"]
peptides = ["_".join([utils.convert_to_three_letter_code(aa) for aa in sequence]) for sequence in peptide_sequences]
print(peptides)

# Report requested peptides that have no results instead of dropping them silently.
missing_peptides = sorted(set(peptides) - set(results_df["peptide"]))
if missing_peptides:
    print(f"Warning: no results found for peptides {missing_peptides}")

# Reset the index so row labels are 0..n-1 again after the boolean-mask filter.
results_df = results_df[results_df["peptide"].isin(peptides)].reset_index(drop=True)
results_df

In [None]:
# Pick up to 5 peptides at random (fixed seed for reproducibility) for the
# per-peptide figures below. DataFrame.sample raises ValueError when n exceeds the
# number of rows, so n is capped at the number of available peptides.
num_sampled_peptides = min(5, len(results_df))
sampled_results_df = results_df.sample(n=num_sampled_peptides, random_state=42)
# Positional 0..n-1 index: later cells use the row index to address subplot rows.
sampled_results_df = sampled_results_df.reset_index(drop=True)
sampled_results_df

### Ramachandran Plots

In [193]:
def plot_ramachandran_contour(results: Dict[str, Any], dihedral_index: int, ax: Optional[plt.Axes] = None) -> plt.Axes:
    """Draws a filled Ramachandran (phi/psi) contour plot of a precomputed PMF.

    Args:
        results: Mapping with keys ``"pmf"`` (indexable by ``dihedral_index``),
            ``"xedges"`` and ``"yedges"`` (histogram bin edges).
        dihedral_index: Which entry of ``results["pmf"]`` to draw.
        ax: Axes to draw on. A new 10x10 figure is created when omitted.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))

    pmf = results["pmf"][dihedral_index]
    # NOTE(review): the left bin edges are used as grid coordinates, which shifts
    # the surface by half a bin (bin centres would be (edges[:-1] + edges[1:]) / 2).
    # This also assumes pmf is laid out the way contourf expects (rows indexed
    # by y) — confirm against how the PMF is built.
    x, y = results["xedges"][:-1], results["yedges"][:-1]
    ax.contourf(x, y, pmf, cmap="viridis", levels=50)
    # Thin white iso-lines drawn over the filled contours.
    ax.contour(x, y, pmf, colors="white", linestyles="solid", levels=30, linewidths=0.25)

    ax.set_aspect("equal", adjustable="box")
    # Raw strings: "\p" is an invalid escape sequence in a normal string literal
    # (DeprecationWarning, and SyntaxWarning from Python 3.12).
    ax.set_xlabel(r"$\phi$")
    ax.set_ylabel(r"$\psi$")

    # The outermost ticks are pulled slightly inside +/-pi, presumably so they stay
    # within the plotted range, which ends at the last left bin edge.
    tick_eps = 0.1
    ticks = [-np.pi + tick_eps, -np.pi / 2, 0, np.pi / 2, np.pi - tick_eps]
    tick_labels = [r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"]
    ax.set_xticks(ticks, tick_labels)
    ax.set_yticks(ticks, tick_labels)

    return ax


def format_for_plot(peptide: str) -> str:
    """Returns a display label for a peptide name.

    A leading ``"uncapped_"`` marker is stripped. Underscore-separated names
    (e.g. ``"ALA_PRO"``) are shown dash-separated (``"ALA-PRO"``). Any other name
    is mapped character by character through ``utils.convert_to_one_letter_code``.
    """
    prefix = "uncapped_"
    name = peptide[len(prefix):] if peptide.startswith(prefix) else peptide
    if "_" not in name:
        return "".join(map(utils.convert_to_one_letter_code, name))
    return name.replace("_", "-")

In [None]:
if os.path.exists(same_sampling_time_path):
    # Which dihedral pairs the PMFs cover:
    #   "internal" for psi_2 - phi_2, psi_3 - phi_3, etc.
    #   "all" for psi_1 - phi_2, psi_2 - phi_3, etc.
    pmf_type = "all"

    if experiment == "Our_2AA":
        num_dihedrals = 1
    elif "2AA" in experiment:
        num_dihedrals = 0
    elif "4AA" in experiment:
        num_dihedrals = 2
    elif "5AA" in experiment:
        num_dihedrals = 3
    else:
        # Fail clearly instead of with a NameError (or a stale value left over
        # from an earlier run) further down.
        raise ValueError(f"Cannot infer the number of dihedral pairs for experiment {experiment!r}.")

    if pmf_type == "all":
        num_dihedrals += 1

    # Each column-group title is anchored on the group's middle panel (index
    # num_dihedrals // 2 within the group). For an odd group size that panel's
    # centre (x=0.5) is the group centre. For an even size the group centre lies
    # between the two middle panels, i.e. at the anchor panel's left edge (x=0.0).
    label_offset = 0.5 if num_dihedrals % 2 == 1 else 0.0

    # One column group per PMF source, each with one panel per dihedral pair.
    group_titles = ["Reference", "JAMUN", "Reference (Benchmark)"]
    num_groups = len(group_titles)

    fig, axs = plt.subplots(
        len(sampled_results_df),
        num_groups * num_dihedrals,
        figsize=(4 * num_groups * num_dihedrals, 4 * len(sampled_results_df)),
        squeeze=False,
    )
    for i, row in sampled_results_df.iterrows():
        peptide = row["peptide"]

        # PMF sources in the same order as group_titles. The "Reference (Benchmark)"
        # group uses the reference PMFs from the same-sampling-time results.
        pmf_sources = [
            row["results"]["PMFs"]["ref_traj"],
            row["results"]["PMFs"]["traj"],
            row["results_same_sampling_time"]["PMFs"]["ref_traj"],
        ]
        for group, pmfs in enumerate(pmf_sources):
            for j in range(num_dihedrals):
                plot_ramachandran_contour(pmfs[f"pmf_{pmf_type}"], j, axs[i, group * num_dihedrals + j])

        # Peptide name to the right of each row.
        axs[i, -1].text(
            1.1,
            0.5,
            format_for_plot(peptide),
            rotation=90,
            verticalalignment="center",
            horizontalalignment="center",
            transform=axs[i, -1].transAxes,
        )

    # Column-group titles above the first row. They are added once, after the loop:
    # inside the loop they were re-added for every row, stacking duplicate text.
    for group, title in enumerate(group_titles):
        anchor_ax = axs[0, group * num_dihedrals + num_dihedrals // 2]
        anchor_ax.text(
            label_offset,
            1.1,
            title,
            horizontalalignment="center",
            verticalalignment="center",
            transform=anchor_ax.transAxes,
            fontsize=12,
        )

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "ramachandran_contours.pdf"), dpi=300)
    plt.show()

In [None]:
if not os.path.exists(same_sampling_time_path):
    # Which dihedral pairs the PMFs cover:
    #   "internal" for psi_2 - phi_2, psi_3 - phi_3, etc.
    #   "all" for psi_1 - phi_2, psi_2 - phi_3, etc.
    pmf_type = "all"

    if experiment == "Our_2AA":
        num_dihedrals = 1
    elif "2AA" in experiment:
        num_dihedrals = 0
    elif "4AA" in experiment:
        num_dihedrals = 2
    elif "5AA" in experiment:
        num_dihedrals = 3
    else:
        # Fail clearly instead of with a NameError (or a stale value left over
        # from an earlier run) further down.
        raise ValueError(f"Cannot infer the number of dihedral pairs for experiment {experiment!r}.")

    if pmf_type == "all":
        num_dihedrals += 1

    # Each column-group title is anchored on the group's middle panel (index
    # num_dihedrals // 2 within the group). For an odd group size that panel's
    # centre (x=0.5) is the group centre. For an even size the group centre lies
    # between the two middle panels, i.e. at the anchor panel's left edge (x=0.0).
    label_offset = 0.5 if num_dihedrals % 2 == 1 else 0.0

    # One column group per PMF source, each with one panel per dihedral pair.
    group_titles = ["Reference", "JAMUN"]
    num_groups = len(group_titles)

    fig, axs = plt.subplots(
        len(sampled_results_df),
        num_groups * num_dihedrals,
        figsize=(4 * num_groups * num_dihedrals, 4 * len(sampled_results_df)),
        squeeze=False,
    )
    for i, row in sampled_results_df.iterrows():
        peptide = row["peptide"]

        # PMF sources in the same order as group_titles.
        pmf_sources = [
            row["results"]["PMFs"]["ref_traj"],
            row["results"]["PMFs"]["traj"],
        ]
        for group, pmfs in enumerate(pmf_sources):
            for j in range(num_dihedrals):
                plot_ramachandran_contour(pmfs[f"pmf_{pmf_type}"], j, axs[i, group * num_dihedrals + j])

        # Peptide name to the right of each row.
        axs[i, -1].text(
            1.1,
            0.5,
            format_for_plot(peptide),
            rotation=90,
            verticalalignment="center",
            horizontalalignment="center",
            transform=axs[i, -1].transAxes,
        )

    # Column-group titles above the first row. They are added once, after the loop:
    # inside the loop they were re-added for every row, stacking duplicate text.
    for group, title in enumerate(group_titles):
        anchor_ax = axs[0, group * num_dihedrals + num_dihedrals // 2]
        anchor_ax.text(
            label_offset,
            1.1,
            title,
            horizontalalignment="center",
            verticalalignment="center",
            transform=anchor_ax.transAxes,
            fontsize=12,
        )

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "ramachandran_contours.pdf"), dpi=300)
    plt.show()

### Feature Histograms

In [None]:
fig, axs = plt.subplots(nrows=len(sampled_results_df), ncols=2, figsize=(14, 4 * len(sampled_results_df)), squeeze=False)
for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]

    feats = row["results"]["featurization"]
    histograms = row["results"]["feature_histograms"]

    # Left column: reference trajectory; right column: JAMUN.
    for col, source in enumerate(("ref_traj", "traj")):
        torsion_hists = histograms[source]["torsions"]
        pyemma_helper.plot_feature_histograms(
            torsion_hists["histograms"],
            torsion_hists["edges"],
            feature_labels=feats[source]["feats"]["torsions"].describe(),
            ax=axs[i, col],
        )

    # Peptide name to the right of the row.
    label_ax = axs[i, -1]
    label_ax.text(
        1.1,
        0.5,
        format_for_plot(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=label_ax.transAxes,
    )

for col, title in enumerate(("Reference", "JAMUN")):
    axs[0, col].set_title(title)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "feature_histograms.pdf"), dpi=300)
plt.show()

In [None]:
# Only the first 10 distance features are plotted per peptide.
num_distance_features = 10

fig, axs = plt.subplots(nrows=len(sampled_results_df), ncols=2, figsize=(14, 4 * len(sampled_results_df)), squeeze=False)
for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]

    feats = row["results"]["featurization"]
    histograms = row["results"]["feature_histograms"]

    # Left column: reference trajectory; right column: JAMUN.
    for col, source in enumerate(("ref_traj", "traj")):
        distance_hists = histograms[source]["distances"]
        pyemma_helper.plot_feature_histograms(
            distance_hists["histograms"][:num_distance_features],
            distance_hists["edges"][:num_distance_features],
            feature_labels=feats[source]["feats"]["distances"].describe()[:num_distance_features],
            ax=axs[i, col],
        )

    # Peptide name to the right of the row.
    label_ax = axs[i, -1]
    label_ax.text(
        1.1,
        0.5,
        format_for_plot(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=label_ax.transAxes,
    )

for col, title in enumerate(("Reference", "JAMUN")):
    axs[0, col].set_title(title)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "distance_histograms.pdf"), dpi=300)
plt.show()

In [None]:
# Collect the per-peptide torsion JSDs for each feature group.
all_JSDs = collections.defaultdict(list)
for i, row in results_df.iterrows():
    peptide = row["peptide"]
    peptide_results = row["results"]
    # Prefer "JSD_torsion_stats"; fall back to "JSD_stats" when it is absent.
    key = "JSD_torsion_stats" if "JSD_torsion_stats" in peptide_results else "JSD_stats"
    for feat in ["backbone_torsions", "sidechain_torsions", "all_torsions"]:
        all_JSDs[feat].append(peptide_results[key][feat])

for feat in all_JSDs:
    print(feat, "mean JSD:", np.mean(all_JSDs[feat]))

# Write the summary to JSDs.txt, overwriting any previous file.
with open(os.path.join(output_dir, "JSDs.txt"), "w") as f:
    f.writelines(f"Mean JSD {feat}: {np.mean(all_JSDs[feat])}\n" for feat in all_JSDs)


### TICA Analysis

In [None]:
fig, axs = plt.subplots(nrows=len(sampled_results_df), ncols=2, figsize=(12, 3 * len(sampled_results_df)), squeeze=False)
for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]
    results = row["results"]["TICA_stats"]["TICA-0,1 histograms"]

    ref_ax, traj_ax = axs[i]
    # TICA-0,1 free-energy plots: reference on the left, JAMUN on the right.
    for ax, source, title in ((ref_ax, "ref_traj", "Reference"), (traj_ax, "traj", "JAMUN")):
        pyemma_helper.plot_free_energy(*results[source], cmap="plasma", ax=ax)
        ax.set_title(title)
        ax.ticklabel_format(useOffset=False, style="plain")

    # Reuse the reference's axis ranges for JAMUN so the panels are directly comparable.
    traj_ax.set_xlim(ref_ax.get_xlim())
    traj_ax.set_ylim(ref_ax.get_ylim())
    # Peptide name to the right of the row.
    traj_ax.text(
        1.4,
        0.5,
        format_for_plot(peptide),
        rotation=90,
        verticalalignment="center",
        horizontalalignment="center",
        transform=traj_ax.transAxes,
    )

plt.suptitle("TICA-0,1 Projections", fontsize="x-large")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "tica_projections.pdf"), dpi=300)
plt.show()

In [None]:

fig, axs = plt.subplots(nrows=1, ncols=len(sampled_results_df), figsize=(8 * len(sampled_results_df), 8), squeeze=False)

for i, row in sampled_results_df.iterrows():
    peptide = row["peptide"]
    results = row["results"]["autocorrelation_stats"]

    ax = axs[0, i]
    for series_key, series_label in (("ref_autocorr", "Reference"), ("traj_autocorr", "JAMUN")):
        ax.plot(results[series_key], label=series_label)
    ax.set_xlabel("Lag")
    ax.set_ylabel("Autocorrelation")
    # Peptide name centred just above each panel.
    ax.text(
        0.5,
        1.04,
        format_for_plot(peptide),
        rotation=0,
        verticalalignment="center",
        horizontalalignment="center",
        transform=ax.transAxes,
        fontsize=20,
    )

plt.suptitle("TICA-0 Autocorrelation", fontsize=32)
# One legend, placed outside the right edge of pyplot's current axes
# (presumably the last subplot created by plt.subplots).
plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.ticklabel_format(useOffset=False)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "tica_autocorrelation.pdf"), dpi=300)
plt.show()

### JSD Against Time

In [None]:
fig, axs = plt.subplots(ncols=2, sharey=True, sharex=True)

for i, row in sampled_results_df.iterrows():
    # Bind this row's peptide explicitly. It was previously never set in this loop,
    # so the sampling-rate lookups below used whichever `peptide` leaked from an
    # earlier cell (the same value for every row).
    peptide = row["peptide"]

    if "JSD_torsion_stats_against_time" in row["results"]:
        key = "JSD_torsion_stats_against_time"
    else:
        key = "JSD_stats_against_time"

    results = row["results"][key]
    # step -> backbone-torsion JSD, for the reference and for JAMUN.
    ref_traj_results = {step: result["backbone_torsions"] for step, result in results["ref_traj"].items()}
    traj_results = {step: result["backbone_torsions"] for step, result in results["traj"].items()}

    ref_traj_steps = np.array(list(ref_traj_results.keys()))
    ref_traj_js_divs = np.array(list(ref_traj_results.values()))
    ref_traj_progress = ref_traj_steps / np.max(ref_traj_steps)

    traj_steps = np.array(list(traj_results.keys()))
    traj_js_divs = np.array(list(traj_results.values()))
    traj_progress = traj_steps / np.max(traj_steps)

    # Load sampling rates.
    traj_samples_per_sec = load_trajectory.get_sampling_rate(traj_name, peptide, experiment)
    ref_traj_samples_per_sec = load_trajectory.get_sampling_rate(ref_traj_name, peptide, experiment)

    traj_frames = np.max(traj_steps)
    ref_traj_frames = np.max(ref_traj_steps)

    # NOTE(review): if get_sampling_rate returns samples per second, as the variable
    # names say, the time to produce N frames is N / rate rather than rate * N.
    # Left unchanged until the helper's units are confirmed.
    traj_time = traj_samples_per_sec * traj_frames
    ref_traj_time = ref_traj_samples_per_sec * ref_traj_frames

    # Keep only the reference checkpoints within the matched sampling time. The
    # factor is capped at 1, so the reference is never extended.
    factor = min(traj_time / ref_traj_time, 1)
    ref_traj_frames_new = int(ref_traj_frames * factor)
    ref_traj_results_new = {step: val for step, val in ref_traj_results.items() if step <= ref_traj_frames_new}

    # Label only one peptide's curves so the legend lists each series once.
    if i == len(sampled_results_df) - 1:
        ref_label = "Reference"
        traj_label = "JAMUN"
    else:
        ref_label = None
        traj_label = None

    if ref_traj_results_new:
        ref_traj_steps_new = np.array(list(ref_traj_results_new.keys()))
        ref_traj_js_divs_new = np.array(list(ref_traj_results_new.values()))
        ref_traj_progress_new = ref_traj_steps_new / np.max(ref_traj_steps_new)
        axs[0].plot(ref_traj_progress_new, ref_traj_js_divs_new, color="C0", label=ref_label)
    else:
        # np.max would fail on an empty selection; skip the truncated reference curve.
        print(f"Warning: no reference JSD checkpoints within the matched sampling time for {peptide}.")
    axs[0].plot(traj_progress, traj_js_divs, color="C1", label=traj_label)

    axs[1].plot(ref_traj_progress, ref_traj_js_divs, color="C0", label=ref_label)
    axs[1].plot(traj_progress, traj_js_divs, color="C1", label=traj_label)

for ax in axs:
    ax.set_xscale("log")

fig.suptitle("JSD on Backbone Torsions")
axs[0].set_title("Same Sampling Time", fontsize=10)
axs[1].set_title("Full Trajectories", fontsize=10)
axs[0].set_xlabel("Fraction of Trajectory Progress", fontsize=12)
axs[0].set_ylabel("Jensen-Shannon Distance", fontsize=12)
axs[0].legend(fontsize=10)
axs[1].legend(fontsize=10)
plt.savefig(os.path.join(output_dir, "jsd_against_time.pdf"), dpi=300)
plt.show()

### MSM Analysis

In [None]:
# Pool the metastable-state probabilities of every peptide.
msm_stats = [row["results"]["MSM_stats"] for _, row in results_df.iterrows()]
all_ref_metastable_probs = np.concatenate([stats["ref_metastable_probs"] for stats in msm_stats])
all_traj_metastable_probs = np.concatenate([stats["traj_metastable_probs"] for stats in msm_stats])

ax = plt.gca()
ax.scatter(all_ref_metastable_probs, all_traj_metastable_probs, alpha=0.3, edgecolors="none")

# Least-squares fit of JAMUN against reference probabilities.
slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(
    all_ref_metastable_probs, all_traj_metastable_probs
)

# Draw the fitted line across (and beyond) the visible [0, 1] window.
x_line = np.array([-0.5, 1.5])
y_line = slope * x_line + intercept
ax.plot(x_line, y_line, color='red', linestyle='--')
ax.text(0.45, 0.90, f'R² = {r_value**2:.3f}', transform=ax.transAxes, color='red')

ax.set_title("Metastable State Probabilities")
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
ax.set_xlabel("Reference")
ax.set_ylabel("JAMUN")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "metastable_probs.pdf"), dpi=300)
plt.show()

In [None]:
# Per-peptide JSD between the reference and JAMUN metastable-state probabilities.
JSD_msms = []
for i, row in results_df.iterrows():
    results = row["results"]["MSM_stats"]
    JSD_msms.append(results["JSD_metastable_probs"])

JSD_msms = np.array(JSD_msms)
mean_JSD_msm = np.mean(JSD_msms)
print("Mean JSD MSM:", mean_JSD_msm)

# Record the mean in JSDs.txt next to the torsion JSDs. Any "Mean JSD MSM:" line
# from a previous run of this cell is dropped first: plain appending added a
# duplicate line every time the cell was re-run.
jsd_summary_path = os.path.join(output_dir, "JSDs.txt")
summary_lines = []
if os.path.exists(jsd_summary_path):
    with open(jsd_summary_path) as f:
        summary_lines = [line for line in f if not line.startswith("Mean JSD MSM:")]
summary_lines.append(f"Mean JSD MSM: {mean_JSD_msm}\n")
with open(jsd_summary_path, "w") as f:
    f.writelines(summary_lines)

plt.hist(JSD_msms)
plt.title("Jensen-Shannon Distances of Metastable State Probabilities")
plt.xlabel("JSD")
plt.ylabel("Frequency")
plt.ticklabel_format(useOffset=False, style="plain")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "jsd_metastable_probs.pdf"), dpi=300)
plt.show()

In [None]:
# Only the first plotted peptide gets legend labels. The previous check,
# `i == len(results_df) - 1`, never matched when results_df's index is not 0..n-1
# (e.g. after a boolean-mask filter) or when the last row was skipped below,
# which left the legend empty.
labels_added = False
for i, row in results_df.iterrows():
    # Not every results file contains the MSM-JSD-vs-time statistics.
    if "JSD_MSM_stats_against_time" not in row["results"]:
        continue

    results = row["results"]["JSD_MSM_stats_against_time"]
    ref_traj_results = results["ref_traj"]
    ref_traj_steps = np.array(list(ref_traj_results.keys()))
    ref_traj_js_divs = np.array(list(ref_traj_results.values()))
    ref_traj_progress = ref_traj_steps / np.max(ref_traj_steps)

    traj_results = results["traj"]
    traj_steps = np.array(list(traj_results.keys()))
    traj_js_divs = np.array(list(traj_results.values()))
    traj_progress = traj_steps / np.max(traj_steps)

    if labels_added:
        ref_label = None
        traj_label = None
    else:
        ref_label = "Reference"
        traj_label = "JAMUN"
        labels_added = True

    plt.plot(ref_traj_progress, ref_traj_js_divs, color="C0", label=ref_label)
    plt.plot(traj_progress, traj_js_divs, color="C1", label=traj_label)

plt.title("Jensen-Shannon Distances on Metastable State Probabilities")
plt.xlabel("Fraction of Trajectory Progress", fontsize=12)
plt.ylabel("Jensen-Shannon Distance", fontsize=12)
plt.ticklabel_format(useOffset=False, style="plain")
if labels_added:
    # Skip the legend when no peptide had these statistics (nothing was labelled).
    plt.legend(fontsize=10)
plt.xticks(fontsize=8)
plt.yticks(fontsize=8)
plt.savefig(os.path.join(output_dir, "jsd_msms_against_time.pdf"), dpi=300)
plt.show()