In [None]:
# Proteins to analyse. Uncomment an entry to include it; the list is handed to
# FastFoldingDataset and iterated by the loading and plotting cells.
proteins = [
# 'chignolin',
# 'trpcage', 
# 'bba',
# 'wwdomain',
# 'villin',
# 'ntl9',
# 'bbl',
'proteinb',
# 'homeodomain',
# 'proteing',
# 'a3D',
# 'lambda',
]

# Registry run ids to load; each id is passed to fetch_run() in the loading cell.
# ids = ['opgxy06w', 'mwxf79u9', 'zr3x0ui1', 'k76lu3f8']
ids = ['xqm3fyio']

# models = ['k76lu3f8']
# models = ['zcun2wc3']
# `mode` and `checkpoint` are interpolated into the sample path
# samples/{checkpoint}/{protein}/{mode}/ when trajectories are read.
mode = 'crystal'
# mode = 'default'
checkpoint = -1

In [None]:
from learnax.registry import Registry
from tqdm import tqdm
from collections import defaultdict
import pickle 
import sys
sys.path.append('..')

from data.fast_folding import FastFoldingDataset


TRAINAX_REGISTRY_PATH = '/mas/projects/molecularmachines/experiments/trainax'

# Containers shared with the analysis cells:
#   ref_metrics[protein]       -> value returned by dataset.get_metrics(protein)
#   trajs[protein][model_key]  -> value returned by run.read_all(...) for that run
#   metrics[protein][...]      -> model metrics
# NOTE(review): nothing in this cell writes to `metrics` (the tica_metrics.pyd
# read below is disabled), yet later cells index into it. On a fresh kernel
# those cells raise KeyError -- confirm where `metrics` is meant to be filled.
metrics = defaultdict(dict)
ref_metrics = dict()
trajs = defaultdict(dict)

dataset = FastFoldingDataset(proteins=proteins)

for protein in tqdm(proteins):
    ref_metrics[protein] = dataset.get_metrics(protein)

# The registry handle does not depend on the run id, so it is built once
# instead of once per run (assumes constructing it is stateless -- TODO confirm).
registry = Registry('deepjump', base_path=TRAINAX_REGISTRY_PATH)

# `models` collects one key per run: the last entry of the run's configured
# temperatures. NOTE(review): two runs sharing that value would overwrite each
# other's entry in `trajs`.
models = []
for model in ids:
    run = registry.fetch_run(model)
    cfg = run.get_config()
    model_key = cfg.data.temperatures[-1]
    models.append(model_key)

    for protein in tqdm(proteins):
        trajs[protein][model_key] = run.read_all(f'samples/{checkpoint}/{protein}/{mode}/*_traj.pyd')
        # Disabled in the original notebook:
        # metrics[protein][model_key] = run.read(f'samples/{checkpoint}/{protein}/{mode}/tica_metrics.pyd')


In [None]:
# The original line had unclosed brackets (a syntax error). Show which entries
# were loaded for the first protein under the most recent model key, rather
# than dumping the trajectories themselves.
list(trajs[proteins[0]][models[-1]].keys())

In [None]:
# Display the whole nested mapping filled by the loading cell:
# trajs[protein][model_key] -> result of run.read_all(...).
trajs

In [None]:
# Keys of the 'clusters' entry for the first configured protein and the most
# recent model key. This was hard-coded to 'bbl' / 450, which need not be
# present in `proteins` / `models`.
metrics[proteins[0]][models[-1]]['clusters'].keys()

### Plot Markov Models and Timescales

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np 
import deeptime


# Imported once up front instead of on every pass through the inner loop.
from deeptime.util.validation import implied_timescales

k = 32  # key into the 'clusters' entries used for every protein below

for protein in proteins:

    ref_tics = np.concatenate(ref_metrics[protein]['feats'], axis=0).reshape(-1, 2)

    fig, axes = plt.subplots(2, 3, figsize=(8, 5))
    fig.suptitle(protein)

    # Reference surface: -log(histogram counts), shifted so its minimum is 0.
    H, xbins, ybins = np.histogram2d(ref_tics[:, 0], ref_tics[:, 1], bins=40)
    with np.errstate(divide='ignore'):  # empty bins map to +inf; expected here
        H = -np.log(H)
    H -= np.min(H)
    extent = [xbins.min(), xbins.max(), ybins.min(), ybins.max()]

    ref_clusters = ref_metrics[protein]['clusters'][k]
    centers = ref_clusters['kmeans'].cluster_centers
    helix_tics = ref_metrics[protein]['helix_tics']
    crystal_tics = ref_metrics[protein]['crystal_tics']
    print('helix_cluster', ref_clusters['helix_cluster'])

    # Row 0 shows the reference MSMs, row 1 the MSMs stored under the most
    # recent model key.
    msms_per_row = [
        ref_clusters['msms'],
        metrics[protein][models[-1]]['clusters'][k]['msms'],
    ]
    for i, msms in enumerate(msms_per_row):
        msm = msms[-1]  # the last model of the list is the one drawn
        ax_network, ax_matrix, ax_timescales = axes[i]

        ax_network.contour(H.T, extent=extent, levels=5, alpha=0.2, colors='black', linewidths=0.7, linestyles='solid')

        pi = msm.stationary_distribution
        colors = plt.cm.Spectral(pi / np.max(pi))

        # Only edges whose flux exceeds the value at the 7/10 position of the
        # sorted flux list are drawn.
        fluxes = (msm.transition_matrix * pi[:, None]).flatten()
        fluxes.sort()
        min_flux = fluxes[7 * len(fluxes) // 10].item()

        # plot clusters and matrix
        deeptime.plots.plot_markov_model(
            msm,
            ax=ax_network,
            pos=centers,
            edge_labels=None,
            edge_scale=0.01,
            state_colors=colors,
            size=0,
            minflux=min_flux,
        )

        # Rescale patches that expose set_width to data units of the histogram
        # grid; every other patch is simply recoloured.
        for patch in ax_network.patches:
            if hasattr(patch, 'set_width'):
                r = patch.get_radius()
                patch.set_width(np.abs(xbins[1] - xbins[0]) * r * 20)
                patch.set_height(np.abs(ybins[1] - ybins[0]) * r * 20)
                patch.set_alpha(0.7)
            else:
                patch.set_color('darkgray')

        ax_network.plot(helix_tics[0], helix_tics[1], '^', alpha=1.0, markersize=6, color='black')
        ax_network.plot(crystal_tics[0], crystal_tics[1], 'v', alpha=1.0, markersize=6, color='black')

        # plot transition matrix
        ax_matrix.imshow(msm.transition_matrix, vmin=0, vmax=1, cmap='Greys')

        # plot timescales
        timescales = implied_timescales(msms)
        deeptime.plots.plot_implied_timescales(timescales, ax=ax_timescales, n_its=3)

    plt.show()

In [None]:
import numpy as np
import pandas as pd 

def kl_div(hist1: np.ndarray, hist2: np.ndarray, reduce=True) -> "float | np.ndarray":
    """Kullback-Leibler style divergence sum(hist1 * log(hist1 / hist2)).

    Zero entries are handled explicitly so that sparse histograms work:
    bins where ``hist2`` is 0 are treated as if ``hist2`` were 1 there, and
    bins where ``hist1`` is 0 contribute exactly 0.

    Parameters
    ----------
    hist1, hist2 : array-like
        Inputs of broadcast-compatible shape.
    reduce : bool
        If true (default) return the sum over all bins, otherwise return the
        per-bin contributions as an array.
    """
    hist1 = np.asarray(hist1)
    hist2 = np.asarray(hist2)
    hist2 = np.where(hist2 == 0, 1, hist2)
    # log(0) for empty hist1 bins is expected and masked out just below, so the
    # floating-point warnings it would raise are suppressed here.
    with np.errstate(divide='ignore', invalid='ignore'):
        div = hist1 * np.nan_to_num(np.log(hist1 / hist2))
    div = np.where(hist1 == 0, 0, div)
    if reduce:
        return np.sum(div)
    return div

def js_div(hist1, hist2):
    """Jensen-Shannon style divergence built from ``kl_div``.

    Both inputs are compared against their element-wise average and the two
    resulting divergences are combined with equal weight.
    """
    midpoint = 0.5 * (hist1 + hist2)
    toward_first = kl_div(hist1, midpoint)
    toward_second = kl_div(hist2, midpoint)
    return 0.5 * toward_first + 0.5 * toward_second

# MAKE A TABLE

from deeptime.util.validation import implied_timescales

# Column layout of the per-protein comparison table. Rows are assigned as plain
# lists further down, so values must be supplied in exactly this order.
columns = ['π (JS)', 'P (JS)', 'T1 (ns)', 'T1 error', 'MFPT (ns)', 'MFPT Error', 'ΔG', 'ΔG Error', 'Lifetime (ns)', 'Lifetime Error']


def compute_fe_diff(msm, helix_state=None, crystal_state=None):
    """Difference between the helix-state and crystal-state free energies.

    For a state with stationary probability ``p`` the quantity used is
    ``-log(p / (1 - p))``; the result is the helix value minus the crystal
    value.

    Parameters
    ----------
    msm : object
        Must expose ``stationary_distribution``, indexable by state.
    helix_state, crystal_state : int, optional
        State indices to compare. When omitted, the first entry of the
        notebook-level ``helix_cluster`` / ``crystal_cluster`` variables is
        used, which is how the function behaved before these arguments
        existed. Passing them explicitly removes that hidden dependency.
    """
    if helix_state is None:
        helix_state = helix_cluster[0]
    if crystal_state is None:
        crystal_state = crystal_cluster[0]

    pi = msm.stationary_distribution
    crystal_fe = -np.log(pi[crystal_state] / (1 - pi[crystal_state]))
    helix_fe = -np.log(pi[helix_state] / (1 - pi[helix_state]))
    return helix_fe - crystal_fe


# Loop-invariant settings, defined once so the summary step after the loop does
# not depend on names created inside it.
k = 32
cols_to_highlight = ['π (JS)', 'P (JS)', 'T1 error', 'MFPT Error', 'ΔG Error', 'Lifetime Error']
value_columns = ['T1 (ns)', 'MFPT (ns)', 'Lifetime (ns)', 'ΔG']

protein_dfs = []

for protein in proteins:

    df = pd.DataFrame(columns=columns, index=['reference'] + models)

    ref_clusters = ref_metrics[protein]['clusters'][k]
    ref_msms = ref_clusters['msms']
    ref_msm = ref_msms[-1]  # last MSM of the list is the one compared
    pi_ref = ref_msm.stationary_distribution
    P_ref = ref_msm.transition_matrix

    timescales_ref = implied_timescales(ref_msms).timescales_for_process(0)[-1]
    # compute_fe_diff() may read these two names from the notebook namespace,
    # so they are deliberately kept as top-level variables.
    helix_cluster = ref_clusters['helix_cluster']
    crystal_cluster = ref_clusters['crystal_cluster']
    reference_mfpt = ref_msm.mfpt(helix_cluster, crystal_cluster)
    ref_lifetime = (-ref_msm.lagtime / np.log(np.diag(P_ref)))[crystal_cluster[0]]
    ref_fe = compute_fe_diff(ref_msm)

    # Error columns are undefined for the reference row itself.
    df.loc['reference'] = [None, None, timescales_ref, None, reference_mfpt, None, ref_fe, None, ref_lifetime, None]

    for model in models:
        model_msms = metrics[protein][model]['clusters'][k]['msms']
        model_msm = model_msms[-1]
        pi = model_msm.stationary_distribution
        P = model_msm.transition_matrix

        # Divergence of the stationary distributions, and the mean row-wise
        # divergence of the transition matrices.
        pi_js = js_div(pi, pi_ref)
        P_js = np.mean([js_div(row, ref_row) for row, ref_row in zip(P, P_ref)])

        timescales = implied_timescales(model_msms).timescales_for_process(0)[-1]
        model_mfpt = model_msm.mfpt(helix_cluster, crystal_cluster)
        model_fe = compute_fe_diff(model_msm)
        model_lifetime = (-model_msm.lagtime / np.log(np.diag(P)))[crystal_cluster[0]]

        # Order must match `columns`.
        df.loc[model] = [
            pi_js,
            P_js,
            timescales,
            np.abs(timescales - timescales_ref),
            model_mfpt,
            np.abs(reference_mfpt - model_mfpt),
            model_fe,
            np.abs(model_fe - ref_fe),
            model_lifetime,
            np.abs(ref_lifetime - model_lifetime),
        ]

    df.index.name = protein

    display(
        df.style.highlight_min(subset=cols_to_highlight, props='font-weight: bold;').format(precision=2)
    )

    protein_dfs.append(df)

# Mean of the error columns per model across proteins. The reference rows
# (which hold None in the error columns) and the raw-value columns are removed
# first, and the remainder is cast to float so the mean runs on numeric data
# rather than on an object-dtype frame.
dfs = pd.concat(protein_dfs)
df = (
    dfs
    .drop(index=['reference'])
    .drop(columns=value_columns)
    .astype(float)
    .groupby(level=0)
    .mean()
)

display(
    df.style.highlight_min(subset=cols_to_highlight, props='font-weight: bold;').format(precision=2)
)



In [None]:
import matplotlib.pyplot as plt

# One small line plot per column of the summary table, with the table index on
# the x axis.
for column_name in df.columns:
    fig, ax = plt.subplots(figsize=(4, 2))
    ax.plot(df.index, df[column_name], label=column_name)
    ax.set_title(column_name)
    plt.show()

### Plot TICA Free Energy

In [None]:
import matplotlib.pyplot as plt
import numpy as np 

N_BINS = 40
N_COUNTOURS = 15  # number of contour levels (name kept as originally spelled)


def _free_energy(counts):
    """Return -log(counts), shifted so the minimum is 0; empty bins become +inf."""
    with np.errstate(divide='ignore'):
        surface = -np.log(counts)
    return surface - np.min(surface)


def _draw_free_energy(ax, surface, extent):
    """Draw a surface as a colour image with thin black contour lines on top."""
    ax.imshow(surface.T, extent=extent, cmap=plt.cm.turbo, aspect='auto', origin='lower')
    ax.contour(surface.T, extent=extent, levels=N_COUNTOURS, colors='black', linewidths=0.3)


for protein in proteins:
    ref_tics = np.concatenate(ref_metrics[protein]['feats'], axis=0)
    ref_weights = ref_metrics[protein]['clusters'][128]['weights']

    n_panels = len(models) + 1
    # squeeze=False keeps `axes` indexable even when there is a single panel
    # (no models loaded); the first row is the panel list.
    fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3.7), dpi=80, squeeze=False)
    axes = axes[0]

    counts, xbins, ybins = np.histogram2d(ref_tics[:, 0], ref_tics[:, 1], bins=N_BINS, weights=ref_weights)
    extent = [xbins.min(), xbins.max(), ybins.min(), ybins.max()]

    H = _free_energy(counts)
    _draw_free_energy(axes[0], H, extent)

    helix_tics = ref_metrics[protein]['helix_tics']
    axes[0].plot(helix_tics[0], helix_tics[1], '^', alpha=1.0, markersize=8, color='black')
    crystal_tics = ref_metrics[protein]['crystal_tics']
    axes[0].plot(crystal_tics[0], crystal_tics[1], 'v', alpha=1.0, markersize=8, color='black')

    axes[0].set_title('Reference')
    axes[0].set_xlabel('TIC1')
    axes[0].set_ylabel('TIC2')

    for idx, model in enumerate(models, start=1):
        # NOTE(review): only the first two entries of 'tics' are pooled here --
        # confirm that is intended.
        model_tics = np.concatenate(metrics[protein][model]['tics'][:2], axis=0)

        # Histogram over the reference range so all panels share one grid.
        counts, _, _ = np.histogram2d(
            model_tics[:, 0], model_tics[:, 1],
            bins=N_BINS,
            range=[[xbins.min(), xbins.max()], [ybins.min(), ybins.max()]],
        )

        H = _free_energy(counts)
        _draw_free_energy(axes[idx], H, extent)

        axes[idx].set_title(model)
        axes[idx].set_xlabel('TIC1')

    fig.suptitle(protein)
    plt.tight_layout()
    plt.show()


### Plot TICA Trajectories

In [None]:
import matplotlib.pyplot as plt
import numpy as np 

import warnings

import matplotlib.pyplot as plt
import numpy as np

from matplotlib.collections import LineCollection
def colored_line(x, y, c, ax, **lc_kwargs):
    """Draw the polyline (x, y) on ``ax`` with one colour value per point.

    Each point gets its own three-vertex segment running from the midpoint
    towards the previous point, through the point, to the midpoint towards the
    next point; the first and last segments start/end on the end points
    themselves. ``c`` is attached to the collection as its colour array.

    Extra keyword arguments are forwarded to ``LineCollection`` on top of the
    default ``capstyle="butt"``. A caller-supplied ``array`` triggers a warning
    because it is overridden by ``c``.

    Returns whatever ``ax.add_collection`` returns.
    """
    if "array" in lc_kwargs:
        warnings.warn('The provided "array" keyword argument will be overridden')

    collection_kwargs = {"capstyle": "butt", **lc_kwargs}

    def _padded_midpoints(values):
        # [first value, midpoints between neighbours..., last value]
        return np.hstack((values[0], 0.5 * (values[1:] + values[:-1]), values[-1]))

    xs = np.asarray(x)
    ys = np.asarray(y)
    x_knots = _padded_midpoints(xs)
    y_knots = _padded_midpoints(ys)

    starts = np.column_stack((x_knots[:-1], y_knots[:-1]))
    middles = np.column_stack((xs, ys))
    ends = np.column_stack((x_knots[1:], y_knots[1:]))
    # Stack along a new middle axis: (n_points, 3 vertices, 2 coordinates).
    segments = np.stack((starts, middles, ends), axis=1)

    collection = LineCollection(segments, **collection_kwargs)
    collection.set_array(c)  # one colour value per segment
    return ax.add_collection(collection)


for protein in proteins:
    # Reference surface: -log of the 2D histogram over the first two feature
    # columns, shifted so its minimum is 0. Only its contours are drawn.
    ref_tics = np.concatenate(ref_metrics[protein]['feats'], axis=0)
    H, xbins, ybins = np.histogram2d(ref_tics[:, 0], ref_tics[:, 1], bins=40)
    H = -np.log(H)
    H -= np.min(H)

    fig, axes = plt.subplots(1, 1, figsize=(3, 3))

    extent = [xbins.min(), xbins.max(), ybins.min(), ybins.max()]
    axes.contour(H.T, extent=extent, levels=10, colors='black', linewidths=0.3)

    # Every stored entry whose name does not contain 'ref' is drawn.
    # NOTE(review): this reads metrics[protein]['tics'] as a name -> array
    # mapping, while the free-energy cell reads metrics[protein][model]['tics']
    # as a sequence -- confirm which layout is current.
    model_tics = []
    for name, tics in metrics[protein]['tics'].items():
        if 'ref' in name:
            continue
        model_tics.append(tics)

    # Colour runs along each trajectory by frame index.
    for tics in model_tics:
        frame_index = np.arange(len(tics))
        colored_line(
            tics[:, 0], tics[:, 1],
            ax=axes, linewidth=0.5, alpha=0.3, c=frame_index, cmap=plt.cm.coolwarm_r,
        )

    helix_tics = ref_metrics[protein]['helix_tics']
    crystal_tics = ref_metrics[protein]['crystal_tics']
    marker_style = dict(alpha=1.0, markersize=6, color='black')
    axes.plot(helix_tics[0], helix_tics[1], '^', **marker_style)
    axes.plot(crystal_tics[0], crystal_tics[1], 'v', **marker_style)

    fig.suptitle(protein)

    plt.show()


### Plot Trajectory

In [None]:
from colour import Color
from tqdm import tqdm

from biotite.structure.io.pdb import PDBFile
from biotite.structure import stack

def traj_to_pdb(traj, downsample=1, tail=10):
    """Serialise a trajectory into one multi-model PDB string.

    Parameters
    ----------
    traj : iterable
        Frames, each exposing ``to_pdb_str()``. Any iterable is accepted
        (list, tuple, generator, ...).
    downsample : int
        Keep every ``downsample``-th frame.
    tail : int
        Number of extra copies of the last frame appended at the end, so that
        an animation lingers on the final structure.

    Returns
    -------
    str
        ``MODEL i`` / ``ENDMDL`` delimited records, numbered from 1. An empty
        trajectory yields an empty string.
    """
    frames = list(traj)
    if not frames:
        # There is no last frame to repeat, and nothing to write.
        return ""

    frames = frames[::downsample] + [frames[-1]] * tail
    # Join once at the end instead of growing a string frame by frame.
    return "".join(
        f"MODEL {i + 1}\n{frame.to_pdb_str()}\nENDMDL\n"
        for i, frame in enumerate(tqdm(frames))
    )

import py3Dmol

def plot_py3dmol_traj(view, traj, viewer):
    """Load a trajectory into one panel of a py3Dmol view as an animation.

    The frames are stacked, written out as a PDB file, read back as text and
    handed to ``view.addModelsAsFrames``. Chain A is then styled as sticks plus
    a spectrum-coloured cartoon, the background is set to white, and a
    forward-looping animation is started.

    Parameters
    ----------
    view : the py3Dmol view to draw into (returned again for chaining)
    traj : sequence of structures accepted by ``biotite.structure.stack``
    viewer : grid position of the target panel, passed through as ``viewer=``
    """
    import os
    import tempfile

    file = PDBFile()
    file.set_structure(stack(traj))

    # Round-trip through a private temporary directory rather than a fixed
    # './temp.pdb', so nothing is left behind in (or overwritten in) the
    # working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdb_path = os.path.join(tmp_dir, 'traj.pdb')
        file.write(pdb_path)
        with open(pdb_path, 'r') as f:
            models = f.read()

    print(len(models))
    view.addModelsAsFrames(models, 'pdb', viewer=viewer)
    view.setStyle({'chain': 'A'}, {'stick': {}, 'cartoon': {'color': 'spectrum'}}, viewer=viewer)
    view.setBackgroundColor("rgb(255,255,255)", 0)
    view.animate({'loop': 'forward'}, viewer=viewer)
    return view

# Animate sampled trajectories for the last configured protein.
prot = proteins[-1]

# Only the first stored trajectory under the most recent model key is shown.
sample_trajs = list(trajs[prot][models[-1]].values())[:1]

# One 300x300 panel per trajectory, laid out in a single row with linked cameras.
n_panels = len(sample_trajs)
v = py3Dmol.view(
    viewergrid=(1, n_panels),
    linked=True,
    width=300 * n_panels,
    height=300,
)

for i, traj in enumerate(sample_trajs):
    # Frames 1..99 only: the first frame is skipped and the animation is capped.
    frames = list(traj)[1:100]
    v = plot_py3dmol_traj(v, frames, viewer=(0, i))

v.zoomTo()
v.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np 
import einops as ein
 
# Imported once rather than on every pass through the protein loop.
import plotly.graph_objects as go

factor = 100  # number of consecutive frames pooled into one time window
N_DENSITY_BINS = 20

for protein in proteins:
    ref_tics = np.concatenate(ref_metrics[protein]['feats'], axis=0)

    # The reference histogram fixes the bin edges shared by every window below.
    H, xbins, ybins = np.histogram2d(ref_tics[:, 0], ref_tics[:, 1], bins=N_DENSITY_BINS)
    with np.errstate(divide='ignore'):  # empty bins map to +inf; expected here
        H = -np.log(H)
    H -= np.min(H)

    x_max, x_min = xbins.max(), xbins.min()
    y_max, y_min = ybins.max(), ybins.min()

    # Stack all non-reference entries into (n_trajs, n_frames, n_components).
    # np.stack requires them to have equal length.
    # NOTE(review): this reads metrics[protein]['tics'] as a name -> array
    # mapping, while the free-energy cell reads metrics[protein][model]['tics']
    # as a sequence -- confirm which layout is current.
    model_tics = np.stack([tics for name, tics in metrics[protein]['tics'].items() if 'ref' not in name], axis=0)

    n_windows = model_tics.shape[1] // factor
    if n_windows == 0:
        raise ValueError(
            f'{protein}: trajectories have {model_tics.shape[1]} frames, fewer than one window of {factor}'
        )

    # Cut the time axis into consecutive, non-overlapping windows of `factor`
    # frames and pool the frames of all trajectories inside each window:
    # (n_windows, n_trajs * factor, n_components). The previous slice i:i+factor
    # advanced by a single frame per window and left a 3-D array whose [:, 0]
    # selected a frame rather than a component.
    n_components = model_tics.shape[-1]
    model_tics = np.stack([
        model_tics[:, w * factor:(w + 1) * factor, :].reshape(-1, n_components)
        for w in range(n_windows)
    ])

    # Histogram every window on the reference grid so the slices are comparable.
    window_counts = np.stack([
        np.histogram2d(tics[:, 0], tics[:, 1], bins=[xbins, ybins])[0]
        for tics in model_tics
    ])

    # `model_densities` was used below without ever being defined. It is taken
    # here to be the per-window fraction of samples in each bin (each window
    # sums to 1) -- NOTE(review): confirm this is the intended normalisation
    # for the 0.0-0.2 iso range.
    window_totals = window_counts.sum(axis=(1, 2), keepdims=True)
    model_densities = window_counts / np.where(window_totals == 0, 1, window_totals)

    # Per-window free energies, kept available as before (-log, minimum at 0).
    with np.errstate(divide='ignore'):
        model_fes = [-np.log(density) for density in model_densities]
    model_fes = [fe - np.min(fe) for fe in model_fes]

    # 'ij' indexing yields grids shaped (n_windows, n_x, n_y), the same layout
    # as `model_densities`, so the flattened coordinates line up with the values.
    Z, X, Y = np.meshgrid(np.arange(n_windows), xbins[:-1], ybins[:-1], indexing='ij')

    print(model_densities.min(), model_densities.max())

    fig = go.Figure(data=go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=model_densities.flatten(),
        isomin=0.0,
        isomax=0.2,
        opacity=0.1, # needs to be small to see through all surfaces
        surface_count=12, # needs to be a large number for good volume rendering
        )
    )

    fig.update_layout(
        title=dict(
            text=f'min {model_densities.min()} max {model_densities.max()}'
        ),
    )
    # One file per protein; a single fixed name was overwritten on every iteration.
    fig.write_html(f'./density_{protein}.html')

    
    # tic_volume = 


### Tabulate 