In [1]:
import os
import sys
import pickle
import numpy as np
import pandas as pd
from tqdm import trange

import matplotlib.pyplot as plt

from Bio.PDB import PDBParser, Select, PDBIO
import mdtraj

PROJECT_PATH = os.path.abspath("..")
sys.path.insert(0, PROJECT_PATH)
import analysis as A

## Input paths and settings

In [2]:
# Root directory of the Lag21 data inside the project tree.
NB_PATH = os.path.join(PROJECT_PATH, "lag21")

# MD simulation trees: one for the wild type, one for the L69F variant.
_MDSIMS_DIR = os.path.join(NB_PATH, "mdsims")
WT_PATH = os.path.join(_MDSIMS_DIR, "wild_type")
MUT_PATH = os.path.join(_MDSIMS_DIR, "L69F")

# Reference structures (conformation conf_0 of each variant) ...
WT_REFERENCE_PDB_FILE = os.path.join(NB_PATH, "rosetta_conformations/wild_type/conf_0.pdb")
MUT_REFERENCE_PDB_FILE = os.path.join(NB_PATH, "rosetta_conformations/L69F/conf_0.pdb")

# ... and the same files loaded through mdtraj for use as references.
WT_REF_TRJ = mdtraj.load(WT_REFERENCE_PDB_FILE)
MUT_REF_TRJ = mdtraj.load(MUT_REFERENCE_PDB_FILE)

# Indices of the independent runs; they select the run_<i> sub-directories.
RUNS = list(range(8))
# Simulation temperatures; they select the <T>K sub-directories.
TEMPS = np.array([290, 300, 317, 336, 355, 398, 421])

## Output paths and settings

In [3]:
# All results produced by this notebook are written below this directory.
OUTPUT_PATH = os.path.join(NB_PATH, "results")
# exist_ok=True turns this into a no-op when the directory is already present.
os.makedirs(OUTPUT_PATH, exist_ok=True)

N_FRAMES = 4000  # frame count handed to A.read_trajectory
NBLOCKS = 4      # number of blocks used for block averaging
NBINS = 15       # bin count handed to A.get_mutual_information
EPS = 1E-9       # small constant; not referenced by the cells shown in this part of the notebook

## Sequence region boundaries (Frameworks and CDRs)
The reference is a relaxed crystal structure of Lag21, with a backbone RMSD of 0.74 Å from the original crystal structure. The first residue in the crystal structure has residue id 4, so all (0-based) sequence indices here are offset by 4 to obtain residue ids.

In [4]:
# Sequence of chain A and its region boundaries (fr* / cdr* keys), given as
# start/stop indices into seq_wt.
seq_wt, regions_wt = A.get_sequence_region_boundaries(WT_REFERENCE_PDB_FILE, chain="A")
for region, bounds in regions_wt.items():
    first, last = bounds[0], bounds[1]
    # Shift sequence indices by 4 to report them as crystal-structure residue ids.
    print(region, first + 4, last + 4, seq_wt[first:last])

fr1 5 30 QVQLVESGGGLVQAGGSLRLSCAAS
fr2 35 52 MAWFRQAPGMEREFVGG
fr3 60 98 YYADFVKGRLTVDRDNVKNTVDLQMNSLKPEDTAVYYC
fr4 115 124 WGQGTQVTV
cdr1 30 35 GPTGA
cdr2 52 60 ISGSETDT
cdr3 98 115 AARRRVTLFTSRADYDF


## Parse trajectories at all temperatures
Trajectories at each temperature have 1.6 microseconds of total simulation time: 1E8 iterations with a timestep of 2 fs (200 ns per run), repeated over 8 independent instances, each started from a different conformation of the CDR3 loop. Coordinates were recorded every 1E4 iterations (i.e., every 20 ps).

We retain the last 640 ns in total as production, i.e. the last 80 ns of each of the 8 runs, which corresponds to the last 4000 frames of each independent run.

In [5]:
def _read_all_traj(base_path, ref_pdb, temps):
    trajs = {t: None for t in temps}
    for i in trange(len(temps)):
        t = temps[i]
        traj_pkl_file = os.path.join(base_path, f"{t}K", f"traj_{t}K.pkl")
        
        if not os.path.isfile(traj_pkl_file):
            traj_files = [os.path.join(base_path, str(t) + "K", f"run_{i}", "nb.nvt.dcd") for i in RUNS]
            trj = A.read_trajectory(traj_files, ref_pdb, N_FRAMES)
            with open(traj_pkl_file, "wb") as of:
                pickle.dump(trj, of)
        
        trajs[t] = traj_pkl_file
    return trajs

def _get_trj(traj_pkl_dict, temp):
    with open(traj_pkl_dict[temp], "rb") as of:
        trj = pickle.load(of)
    return trj

Wild-type trajectories

In [6]:
# Build (or reuse) the wild-type trajectory caches; trajs_wt maps each
# temperature in TEMPS to the path of its pickle file, not to the trajectory.
trajs_wt = _read_all_traj(WT_PATH, WT_REFERENCE_PDB_FILE, temps=TEMPS)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 161.10it/s]


L69F mutation trajectories

In [7]:
# Build (or reuse) the L69F trajectory caches; trajs_mut maps each
# temperature in TEMPS to the path of its pickle file, not to the trajectory.
trajs_mut = _read_all_traj(MUT_PATH, MUT_REFERENCE_PDB_FILE, temps=TEMPS)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 141.04it/s]


## Backbone RMSDs per residue

In [8]:
def block_avg_rmsd(samples, nblocks):
    """Block-average a series of frames along its first axis.

    The frames are split into ``nblocks`` contiguous blocks of
    ``nframes // nblocks`` frames each; the last block additionally takes any
    leftover frames. Every block is averaged over its frames, and the mean and
    spread of those block averages are returned.

    Args:
        samples: array-like of shape ``(nframes, ...)`` with frames on axis 0.
        nblocks: number of blocks, between 1 and ``nframes``.

    Returns:
        Tuple ``(mean, err)``: the mean and the population standard deviation
        (``np.std`` with its default ``ddof=0``) of the block averages. Both
        have shape ``samples.shape[1:]``.

    Raises:
        ValueError: if ``samples`` has no frame axis, or if ``nblocks`` is not
            between 1 and the number of frames. Without this check, empty
            blocks would silently turn the result into NaN.
    """
    samples = np.asarray(samples)
    if samples.ndim == 0:
        raise ValueError("samples must have at least one dimension (frames along axis 0)")

    nframes = samples.shape[0]
    if not 1 <= nblocks <= nframes:
        raise ValueError(
            f"nblocks must be between 1 and the number of frames ({nframes}), got {nblocks}"
        )

    blocksize = nframes // nblocks

    samples_block = np.zeros((nblocks,) + samples.shape[1:])
    for i in range(nblocks):
        start = i * blocksize
        # The last block absorbs the frames left over by the integer division.
        stop = nframes if i == nblocks - 1 else start + blocksize
        samples_block[i] = np.mean(samples[start:stop], axis=0)

    mean = np.mean(samples_block, axis=0)
    err = np.std(samples_block, axis=0)
    return mean, err

In [9]:
# Residue ids: sequence index shifted by 4 (id of the first residue).
resids = list(range(4, 4 + len(seq_wt)))

# Column layout: resid, then four columns per temperature.
columns = ["resid"]
for t in TEMPS:
    columns += [f"wt_{t}K_mean", f"mut_{t}K_mean", f"wt_{t}K_err", f"mut_{t}K_err"]

# Start from a zero-filled frame; the loop below overwrites every value column.
df_rmsd = pd.DataFrame({"resid": resids, **{name: 0.0 for name in columns[1:]}})

# (column prefix, trajectory cache paths, reference structure) per variant
variants = (
    ("wt", trajs_wt, WT_REF_TRJ),
    ("mut", trajs_mut, MUT_REF_TRJ),
)

print("Calculating per residue RMSDs at:")
for i in trange(len(TEMPS)):
    t = TEMPS[i]

    for label, traj_pkls, ref_trj in variants:
        trj = _get_trj(traj_pkls, t)
        rmsd = A.get_per_residue_backbone_RMSD(trj, ref_trj)
        mean_rmsd, err_rmsd = block_avg_rmsd(rmsd, NBLOCKS)
        # convert rmsds from nm to A when writing to dataframe
        df_rmsd[f"{label}_{t}K_mean"] = 10.0 * mean_rmsd
        df_rmsd[f"{label}_{t}K_err"] = 10.0 * err_rmsd

rmsd_out_file = os.path.join(OUTPUT_PATH, "per_residue_rmsd.csv")
df_rmsd.to_csv(rmsd_out_file, index=False)

display(df_rmsd)

Calculating per residue RMSDs at:


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:07<00:00,  9.66s/it]


Unnamed: 0,resid,wt_290K_mean,mut_290K_mean,wt_290K_err,mut_290K_err,wt_300K_mean,mut_300K_mean,wt_300K_err,mut_300K_err,wt_317K_mean,...,wt_355K_err,mut_355K_err,wt_398K_mean,mut_398K_mean,wt_398K_err,mut_398K_err,wt_421K_mean,mut_421K_mean,wt_421K_err,mut_421K_err
0,4,2.550486,2.031762,0.220964,0.338965,3.391172,2.457265,1.169631,0.646839,3.325254,...,4.683445,1.385060,3.999289,3.593526,1.715562,1.516215,6.366815,3.919510,4.095125,1.367007
1,5,1.609659,1.293696,0.465137,0.140241,2.146561,1.618622,1.098978,0.713923,2.278058,...,4.337177,1.245190,2.387437,2.368671,1.390479,1.414009,4.633117,2.356351,3.798902,0.933760
2,6,0.942044,0.955499,0.370047,0.266895,1.370142,0.923422,0.656214,0.327570,1.531486,...,3.910993,0.871783,1.386737,1.651290,0.815715,1.319211,3.314324,1.445820,3.398749,0.644391
3,7,0.694564,0.861327,0.144459,0.198142,1.054551,0.811308,0.351832,0.118152,1.180823,...,3.393992,0.277340,1.053851,1.447658,0.455263,1.158007,2.649340,0.986486,2.804850,0.165323
4,8,0.542451,0.594904,0.091018,0.069719,0.716406,0.610980,0.227861,0.055712,0.907547,...,2.832493,0.184398,0.833679,1.194849,0.341336,1.020233,2.129507,0.742868,2.262415,0.189873
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
115,119,0.450237,0.307241,0.137731,0.038006,0.501549,0.299127,0.209054,0.005436,0.480464,...,0.488554,0.069866,0.695438,0.504287,0.289829,0.211662,1.455678,0.417838,1.313383,0.023042
116,120,0.408075,0.289988,0.159790,0.016942,0.485963,0.293323,0.179726,0.012424,0.500414,...,0.453076,0.063046,0.650225,0.525054,0.258647,0.231892,1.635888,0.412995,1.527291,0.028601
117,121,0.419610,0.294074,0.060111,0.025879,0.579377,0.284177,0.253260,0.003936,0.572704,...,0.455737,0.050809,0.675418,0.529704,0.188797,0.212932,1.883870,0.427600,1.810751,0.032432
118,122,0.639162,0.449204,0.107446,0.051253,0.812716,0.449389,0.349122,0.027605,0.752288,...,0.457680,0.054889,0.910810,0.610981,0.386799,0.221544,2.272412,0.502728,2.255596,0.050293


## Write normalized residue RMSDs as bfactors into the reference structures

In [10]:
pdb_out_path = os.path.join(OUTPUT_PATH, "heatmap_pdbs")
# exist_ok=True turns this into a no-op when the directory is already present.
os.makedirs(pdb_out_path, exist_ok=True)

# RMSDs are clipped to [min_rmsd, max_rmsd] and rescaled linearly to 0-100.
min_rmsd = 0.0 # A
max_rmsd = 10.0 # A

df_rmsd = pd.read_csv(os.path.join(OUTPUT_PATH, "per_residue_rmsd.csv"))

# Column / file prefix of each variant and the structure that receives its values.
reference_pdbs = {"wt": WT_REFERENCE_PDB_FILE, "mut": MUT_REFERENCE_PDB_FILE}

for i in trange(len(TEMPS)):
    t = TEMPS[i]

    for label, src_pdb_file in reference_pdbs.items():
        rmsds = np.clip(df_rmsd[f"{label}_{t}K_mean"].values, min_rmsd, max_rmsd)
        bfs = 100.0 * (rmsds - min_rmsd) / (max_rmsd - min_rmsd)

        tar_pdb_file = os.path.join(pdb_out_path, f"{label}_{t}K.pdb")
        A.embed_residue_bfactors(src_pdb_file, tar_pdb_file, bfs, chain="A")

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  4.08it/s]


## Inter-residue mutual information of correlations in backbone dihedrals

Compute the total deviation of the backbone dihedrals from their reference values for every residue (written here for the $i^{\mathrm{th}}$ residue):
$$\Delta_{\varphi\psi}^i = \sqrt{ \left( \varphi^i-\varphi_{\mathrm{ref}}^i \right)^2 + 
                                  \left( \psi^i-\psi_{\mathrm{ref}}^i \right)^2}$$

In [11]:
# Per-temperature dihedral deviations, keyed by temperature.
dphipsi_wt = {}
dphipsi_mut = {}

for i in trange(len(TEMPS)):
    t = TEMPS[i]

    # wild type, measured against the wild-type reference
    trj_wt = _get_trj(trajs_wt, t)
    dphipsi_wt[t] = A.get_phipsi_deviations(trj_wt, WT_REF_TRJ)

    # mutant, measured against the mutant reference
    trj_mut = _get_trj(trajs_mut, t)
    dphipsi_mut[t] = A.get_phipsi_deviations(trj_mut, MUT_REF_TRJ)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:14<00:00, 10.67s/it]


In [12]:
# Common value range handed to A.get_mutual_information as xmin / xmax.
# It must span the wild-type AND the mutant deviations, because the same range
# is applied to both; a range taken from the wild type alone can leave mutant
# values outside of it.
# NOTE(review): how out-of-range values are treated depends on
# A.get_mutual_information, which is not visible here - confirm.
_dphipsi_all = list(dphipsi_wt.values()) + list(dphipsi_mut.values())
min_dphipsi = min(np.min(v) for v in _dphipsi_all)
max_dphipsi = max(np.max(v) for v in _dphipsi_all)

Compute all mutual-information-based generalized correlations

In [13]:
correlation_map_out_path = os.path.join(OUTPUT_PATH, "correlation_maps")
# exist_ok=True already covers the case of an existing directory.
os.makedirs(correlation_map_out_path, exist_ok=True)

# Array dimensions are read from the first entry of TEMPS instead of a
# hard-coded 300 key, so editing TEMPS cannot break this cell.
# NOTE(review): every wild-type and mutant array is assumed to share this
# shape, since the same block boundaries are applied to all of them.
nframes = dphipsi_wt[TEMPS[0]].shape[0]
nres = dphipsi_wt[TEMPS[0]].shape[1]

blocksize = nframes // NBLOCKS

for i in trange(len(TEMPS)):
    t = TEMPS[i]
    MI_wt_blocks = np.zeros([NBLOCKS, nres, nres])
    MI_mut_blocks = np.zeros([NBLOCKS, nres, nres])

    for n in range(NBLOCKS):
        start = n * blocksize
        # The last block absorbs the frames left over by the integer division.
        stop = nframes if n == NBLOCKS - 1 else (n + 1) * blocksize

        # Only the second return value is kept per block; the first one
        # (I_wt / I_mut) is not used further in this cell.
        I_wt, C_wt = A.get_mutual_information(
            x=dphipsi_wt[t][start:stop],
            nbins=NBINS,
            xmin=min_dphipsi,
            xmax=max_dphipsi
        )
        MI_wt_blocks[n] = C_wt

        I_mut, C_mut = A.get_mutual_information(
            x=dphipsi_mut[t][start:stop],
            nbins=NBINS,
            xmin=min_dphipsi,
            xmax=max_dphipsi
        )
        MI_mut_blocks[n] = C_mut

    # Mean over blocks is saved as C, the standard deviation over blocks as Cerr.
    MI_wt_mean = np.mean(MI_wt_blocks, axis=0)
    MI_wt_err = np.std(MI_wt_blocks, axis=0)
    fn_wt = os.path.join(correlation_map_out_path, f"wt_{t}K.npz")
    np.savez_compressed(fn_wt, C=MI_wt_mean, Cerr=MI_wt_err)

    MI_mut_mean = np.mean(MI_mut_blocks, axis=0)
    MI_mut_err = np.std(MI_mut_blocks, axis=0)
    fn_mut = os.path.join(correlation_map_out_path, f"mut_{t}K.npz")
    np.savez_compressed(fn_mut, C=MI_mut_mean, Cerr=MI_mut_err)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:50<00:00,  7.19s/it]
