In [None]:
# This was originally done on the cluster
import numpy as np
import matplotlib.pyplot as plt

from conf.conf_analysis import xyzt
from utils.atomselect import select_resids_str, combine_from_intervals
from plot.plot_utilities import hist1d, edgeformat, savefig
from database.query import traj2grotop, traj_group, get_protdef

In [None]:
# Trajectory group to analyse, looked up via database.query.traj_group.
TRAJ_GROUP_ID = 3
traj_ids = traj_group(TRAJ_GROUP_ID)

# RMSD from average structure

In [None]:
# Domain definitions for protein 1 (presumably residue intervals, since they are
# later passed to combine_from_intervals -- confirm against get_protdef).
protein_domain_definitions = get_protdef(protein_id=1)
tmd1, tmd2 = (protein_domain_definitions[key] for key in ('TMD1', 'TMD2'))

# Transmembrane helices TM1 .. TM12, in order.
N_TM_HELICES = 12
tm_helices = [protein_domain_definitions['TM' + str(i)] for i in range(1, N_TM_HELICES + 1)]

In [None]:
def rmsd(traj_xyzt, ref_xyz):
    """Per-frame RMSD of a trajectory against a fixed reference structure.

    No superposition is done here: the coordinates are compared as given
    (presumably they are already aligned upstream -- TODO confirm).

    Parameters
    ----------
    traj_xyzt : ndarray, shape (n_frames, n_atoms, n_dim)
        Trajectory coordinates.
    ref_xyz : ndarray, shape (n_atoms, n_dim)
        Reference coordinates, atom order matching ``traj_xyzt``.

    Returns
    -------
    ndarray, shape (n_frames,)
        ``sqrt(sum_i |x_i(t) - r_i|^2 / n_atoms)`` for each frame, in the
        length units of the inputs.
    """
    assert len(traj_xyzt.shape) == 3, (
        f"traj_xyzt must be 3-D (frames, atoms, dims), got shape {traj_xyzt.shape}"
    )
    assert len(ref_xyz.shape) == 2, (
        f"ref_xyz must be 2-D (atoms, dims), got shape {ref_xyz.shape}"
    )
    assert traj_xyzt.shape[1] == ref_xyz.shape[0], (
        f"atom count mismatch: trajectory has {traj_xyzt.shape[1]}, "
        f"reference has {ref_xyz.shape[0]}"
    )

    # Sum squared deviations over atoms and Cartesian components, then
    # normalise by the number of atoms.
    return np.sqrt(np.sum((traj_xyzt - ref_xyz)**2, axis=(1, 2)) / len(ref_xyz))

def ca_displacement(traj_xyzt, ref_xyz):
    """Per-atom mean and sample std of the distance to a reference position.

    Parameters
    ----------
    traj_xyzt : ndarray, shape (n_frames, n_atoms, n_dim)
        Trajectory coordinates.
    ref_xyz : ndarray, shape (n_atoms, n_dim)
        Reference coordinates, atom order matching ``traj_xyzt``.

    Returns
    -------
    average_displacement : ndarray, shape (n_atoms,)
        Mean over frames of ``|x_i(t) - r_i|``.
    std_displacement : ndarray, shape (n_atoms,)
        Sample standard deviation (``ddof=1``) over frames of the same
        distances. With a single frame this is NaN (numpy also warns).
    """
    assert len(traj_xyzt.shape) == 3, (
        f"traj_xyzt must be 3-D (frames, atoms, dims), got shape {traj_xyzt.shape}"
    )
    assert len(ref_xyz.shape) == 2, (
        f"ref_xyz must be 2-D (atoms, dims), got shape {ref_xyz.shape}"
    )
    assert traj_xyzt.shape[1] == ref_xyz.shape[0], (
        f"atom count mismatch: trajectory has {traj_xyzt.shape[1]}, "
        f"reference has {ref_xyz.shape[0]}"
    )

    # Euclidean distance of every atom from its reference position, per frame.
    displacements = np.sqrt(np.sum((traj_xyzt - ref_xyz)**2, axis=2))
    average_displacement = np.mean(displacements, axis=0)
    std_displacement = np.std(displacements, axis=0, ddof=1)

    return average_displacement, std_displacement

def ca_rmsf(traj_xyzt, ref_xyz):
    """Per-atom root-mean-square fluctuation about a fixed reference structure.

    The deviations are taken with respect to ``ref_xyz``, not with respect to
    the trajectory's own mean positions.

    Parameters
    ----------
    traj_xyzt : ndarray, shape (n_frames, n_atoms, n_dim)
        Trajectory coordinates.
    ref_xyz : ndarray, shape (n_atoms, n_dim)
        Reference coordinates, atom order matching ``traj_xyzt``.

    Returns
    -------
    ndarray, shape (n_atoms,)
        ``sqrt(mean_t |x_i(t) - r_i|^2)`` for each atom.
    """
    assert len(traj_xyzt.shape) == 3, (
        f"traj_xyzt must be 3-D (frames, atoms, dims), got shape {traj_xyzt.shape}"
    )
    assert len(ref_xyz.shape) == 2, (
        f"ref_xyz must be 2-D (atoms, dims), got shape {ref_xyz.shape}"
    )
    assert traj_xyzt.shape[1] == ref_xyz.shape[0], (
        f"atom count mismatch: trajectory has {traj_xyzt.shape[1]}, "
        f"reference has {ref_xyz.shape[0]}"
    )

    # Squared distance per (frame, atom), averaged over frames, then rooted.
    rmsf = np.sqrt(np.mean(np.sum((traj_xyzt - ref_xyz)**2, axis=2), axis=0))

    return rmsf

class rmsd2average_structure(xyzt):
    """Per-frame RMSD of a group of trajectories to their own average structure.

    The selected atoms' coordinates from every trajectory in ``traj_ids`` are
    concatenated along axis 0 (the frame axis, assuming ``getcoords`` returns
    ``(n_frames, n_atoms, n_dim)`` -- TODO confirm). They are averaged over all
    frames, and each frame's RMSD to that average is computed with the
    module-level ``rmsd``.

    Parameters
    ----------
    traj_ids : iterable
        Trajectory ids (presumably ints) forwarded to ``xyzt``.
    selection_string : str
        Atom selection passed to ``self.refs[...].top.select``.
    **kwargs
        Forwarded to ``xyzt.__init__``.

    Attributes
    ----------
    selection_string : str
        The selection used.
    average_xyz : ndarray
        Mean coordinates over all collected frames.
    rmsd : ndarray
        RMSD of each collected frame to ``average_xyz``.
    """

    def __init__(self, traj_ids, selection_string, **kwargs):
        super().__init__(traj_ids=traj_ids, **kwargs)
        self.selection_string = selection_string

        self.load_refs()
        self.open_nc()

        xyz_collect = self._collect_selected_xyz()

        self.average_xyz = np.mean(xyz_collect, axis=0)
        self.rmsd = rmsd(xyz_collect, self.average_xyz)

    def _collect_selected_xyz(self):
        """Load the selected atoms from every trajectory and stack them along axis 0.

        Raises
        ------
        ValueError
            If there are no trajectories, or if the per-trajectory arrays
            cannot be stacked (e.g. the selection matches a different number
            of atoms in different topologies).
        """
        loaded_ids = []
        xyz_per_traj = []
        for t in self.traj_ids:
            if t % 10 == 0:
                print(t)  # coarse progress indicator
            # The selection is resolved against this trajectory's own reference
            # topology, so atom indices may differ between trajectories.
            atom_index = self.refs[traj2grotop(t)].top.select(self.selection_string)
            xyz_per_traj.append(self.getcoords(t, atom_index, df=False))
            loaded_ids.append(t)

        if not xyz_per_traj:
            raise ValueError("no trajectories to load: traj_ids is empty")
        try:
            return np.vstack(xyz_per_traj)
        except ValueError as err:
            shapes = [(t, np.shape(xyz)) for t, xyz in zip(loaded_ids, xyz_per_traj)]
            raise ValueError(
                f"cannot stack coordinates for selection {self.selection_string!r}; "
                f"per-trajectory shapes: {shapes}"
            ) from err
        
class rmsd2custom_structure(xyzt):
    """Per-frame RMSD of a group of trajectories to a caller-supplied reference.

    The selected atoms' coordinates from every trajectory in ``traj_ids`` are
    concatenated along axis 0 (the frame axis, assuming ``getcoords`` returns
    ``(n_frames, n_atoms, n_dim)`` -- TODO confirm). Each frame's RMSD to
    ``ref_xyz`` is computed with the module-level ``rmsd``.

    Parameters
    ----------
    traj_ids : iterable
        Trajectory ids (presumably ints) forwarded to ``xyzt``.
    selection_string : str
        Atom selection passed to ``self.refs[...].top.select``.
    ref_xyz : array_like, shape (n_atoms, n_dim)
        Reference coordinates in the same atom order as the selection.
        Converted with ``np.asarray``, so nested lists are accepted too.
    **kwargs
        Forwarded to ``xyzt.__init__``.

    Attributes
    ----------
    selection_string : str
        The selection used.
    ref_xyz : ndarray
        The reference coordinates.
    rmsd : ndarray
        RMSD of each collected frame to ``ref_xyz``.
    """

    def __init__(self, traj_ids, selection_string, ref_xyz, **kwargs):
        super().__init__(traj_ids=traj_ids, **kwargs)
        self.selection_string = selection_string

        self.load_refs()
        self.open_nc()

        xyz_collect = self._collect_selected_xyz()

        self.ref_xyz = np.asarray(ref_xyz)
        self.rmsd = rmsd(xyz_collect, self.ref_xyz)

    def _collect_selected_xyz(self):
        """Load the selected atoms from every trajectory and stack them along axis 0.

        Raises
        ------
        ValueError
            If there are no trajectories, or if the per-trajectory arrays
            cannot be stacked (e.g. the selection matches a different number
            of atoms in different topologies).
        """
        loaded_ids = []
        xyz_per_traj = []
        for t in self.traj_ids:
            if t % 10 == 0:
                print(t)  # coarse progress indicator
            # The selection is resolved against this trajectory's own reference
            # topology, so atom indices may differ between trajectories.
            atom_index = self.refs[traj2grotop(t)].top.select(self.selection_string)
            xyz_per_traj.append(self.getcoords(t, atom_index, df=False))
            loaded_ids.append(t)

        if not xyz_per_traj:
            raise ValueError("no trajectories to load: traj_ids is empty")
        try:
            return np.vstack(xyz_per_traj)
        except ValueError as err:
            shapes = [(t, np.shape(xyz)) for t, xyz in zip(loaded_ids, xyz_per_traj)]
            raise ValueError(
                f"cannot stack coordinates for selection {self.selection_string!r}; "
                f"per-trajectory shapes: {shapes}"
            ) from err

In [None]:
# TODO: add checks for whether the selection string is mdtraj or MDAnalysis
# CA atoms restricted to the residues of the TMD1 and TMD2 intervals.
tmd_resids = select_resids_str(combine_from_intervals(tmd1, tmd2), 'mdtraj')
selection_string = f"name CA and {tmd_resids}"
rmsd_tmdca = rmsd2average_structure(traj_ids, selection_string)

In [None]:
# Persist the average structure of the TMD1+TMD2 CA selection.
# NOTE(review): the file name says "tmca" although the selection is TMD-based;
# it may be read elsewhere under this name -- confirm before renaming.
average_tmdca_file = "average_tmca.npy"
np.save(average_tmdca_file, rmsd_tmdca.average_xyz.data)

In [None]:
# TODO: add checks for whether it's mdtraj or MDAnalysis
# Every atom named CA in the topology, with no residue restriction.
selection_string = "name CA"
rmsd_allca = rmsd2average_structure(traj_ids, selection_string=selection_string)

In [None]:
# Persist the average structure of the all-CA selection.
average_allca_file = "average_allca.npy"
np.save(average_allca_file, rmsd_allca.average_xyz.data)

In [None]:
# TODO: add checks for whether it's mdtraj or MDAnalysis
# CA atoms restricted to the residues of the TM1..TM12 helix intervals.
tm_helix_resids = select_resids_str(combine_from_intervals(*tm_helices), 'mdtraj')
selection_string = f"name CA and {tm_helix_resids}"
rmsd_tmhelixca = rmsd2average_structure(traj_ids, selection_string)

In [None]:
# Persist the average structure of the TM-helix CA selection.
average_tmhelixca_file = "average_tmhelixca.npy"
np.save(average_tmhelixca_file, rmsd_tmhelixca.average_xyz.data)

In [None]:
# Compare the distributions of CA RMSD to the average structure for the three
# atom selections. Each legend entry carries the selection's mean RMSD.
RMSD_HIST_RANGE = [0, 6]   # histogram support; wider than the displayed x-range,
                           # so values in (5, 6] are binned but off-screen
RMSD_HIST_BINS = 60
RMSD_XLIM = (0, 5)
RMSD_YLIM = (0, 2.5)
ANGSTROM = r"$\mathrm{\AA}$"

# Insertion order fixes both plotting order and legend order.
rmsd_by_selection = {
    "all": rmsd_allca.rmsd,
    "TMDs": rmsd_tmdca.rmsd,
    "TM helices": rmsd_tmhelixca.rmsd,
}

fig, axs = plt.subplots()
edgeformat(axs)

for selection_label, rmsd_values in rmsd_by_selection.items():
    mean_rmsd = np.mean(rmsd_values)
    hist1d(rmsd_values, range=RMSD_HIST_RANGE, bins=RMSD_HIST_BINS).plot(
        axs, label=f"{selection_label} ({round(mean_rmsd, 1)} " + ANGSTROM + ")"
    )
    # mean / max / min summary for this selection
    print(mean_rmsd, np.max(rmsd_values), np.min(rmsd_values))

axs.legend()
axs.set_xlim(*RMSD_XLIM)
axs.set_ylim(*RMSD_YLIM)
axs.grid(True, ls='--')

axs.set_xlabel(r"C$\alpha$ RMSD [$\mathrm{\AA}$]", fontsize=14)
axs.set_ylabel(r"Prob. Density [A.U.]", fontsize=14)
plt.show()