In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pca.pca import prot_pca
from pca.pcaIO import write_pca
from database.query import traj_group, get_protdef
from utils.atomselect import select_domains
from plot.plot_utilities import hist2d, hist1d

# Perform calculation

In [None]:
# Inputs for the PCA: which trajectories to load and which atoms to keep.
# NOTE(review): atomstride is forwarded to prot_pca below; presumably 1 means
# "keep every selected atom" — confirm against prot_pca.
atomstride = 1

# Trajectory ids for group 3, as returned by the database query helper.
traj_ids = traj_group(3)

# Domain definitions for protein definition 1; looked up by domain name below.
domain_dict = get_protdef(1)
domains_measure = ['TM' + str(n) for n in range(1, 13)]  # 'TM1' ... 'TM12'

# domain_dict.get() is used on purpose-as-written: a name that is missing from
# domain_dict is passed on as None rather than raising here.
tm_domains = [domain_dict.get(name) for name in domains_measure]
atomselect_str = select_domains(tm_domains, package='mdtraj') + " and backbone"

In [None]:
# Build the PCA object for the selected trajectories and atom selection.
main = prot_pca(traj_ids, atomselect=atomselect_str, atomstride=atomstride)

# Frame table with one row per frame: the id of the trajectory the frame
# belongs to, and the frame's index within that trajectory (0 .. nframes-1).
frames_per_traj = [main.nframes[t] for t in main.traj_ids]
tf = pd.DataFrame({
    'traj_id': np.repeat(main.traj_ids, frames_per_traj),
    'timestep': np.concatenate([np.arange(n) for n in frames_per_traj]),
})

# Initialise the PCA with the explicit frame table
# (alternative kept from the original notebook: main.pca_init() with no table).
main.pca_init(tf=tf)

# Mean of the PCA input data along axis 0; reused when writing the results.
# NOTE(review): presumably the per-frame average coordinates — confirm the
# layout of main.pca.xyz_data.
xyzcenter = np.mean(main.pca.xyz_data, axis=0)

# Visualize states in the space of two principal components

In [None]:
# Column indices into main.pca.pca_output for the two axes of the plot.
comp1 = 0
comp2 = 1

# Histogram window (also used as the axis limits) and number of bins.
xrange = [-100, 100]
yrange = [-100, 100]
nbins = 30

fig, axs = plt.subplots()

# Projections of all frames onto the two chosen components.
pc_x = main.pca.pca_output[:, comp1]
pc_y = main.pca.pca_output[:, comp2]

# 2D histogram of the projections, drawn as contours on the axes.
density = hist2d(pc_x, pc_y, range=[xrange, yrange], bins=nbins)
density.hist2d_contour(axs)

# Restrict the view to the window that was histogrammed.
plt.xlim(*xrange)
plt.ylim(*yrange)

# Trajectory projected onto PC

In [None]:
# Projection onto PCs: plot the time series of one principal component for
# each selected trajectory, one figure per trajectory.
# NOTE(review): comp is passed straight to main.get_trajpc; whether it is
# 0- or 1-based is not visible here — confirm against get_trajpc.
comp = 1
trajsel = [70]

for t in trajsel:
    # Explicit figure/axes, matching the other plotting cells in this notebook.
    fig, axs = plt.subplots()
    axs.plot(main.get_trajpc(t, comp), c='black')
    axs.set_xlim(-10, 1010)
    # axs.set_ylim(-25, 25)
    # NOTE(review): the x values are the sample index of the plotted series;
    # the 'ns' label assumes one sample per nanosecond — confirm.
    axs.set_xlabel('Time [ns]')
    axs.set_ylabel('Principal Component '+str(comp))
    axs.grid(True)
    axs.set_title(t)  # title is the trajectory id

# Residue SSW

In [None]:
# 1D histogram of each of the first ten columns of the PCA output,
# drawn on a separate figure per component.
for comp in range(10):
    fig, axs = plt.subplots()
    pc_values = main.pca.pca_output[:, comp]
    # NOTE(review): the same [80, 100] window is applied to every component —
    # confirm this range is intended for all of them.
    pc_hist = hist1d(pc_values, bins=100, range=[80, 100])
    pc_hist.plot(axs)

In [None]:
# Per-residue SSW profile for components 1..10, one figure per component,
# all drawn in red with a shared y-axis window for easy comparison.
# NOTE(review): components are numbered from 1 here, whereas the pca_output
# columns elsewhere are indexed from 0 — confirm plot_residue_ssw's convention.
ssw_ylim = (0, 0.2)
for comp in np.arange(1, 11):
    fig, axs = plt.subplots()
    main.plot_residue_ssw(axs, comp, 'red')
    plt.ylim(*ssw_ylim)

# Explained variance analysis

In [None]:
# Report the per-component variance value and its running total for the
# leading components, then plot both curves.
variances = main.pca.variances
# Cumulative sum computed once instead of once per printed component.
cumulative = np.cumsum(variances)

# Print at most 20 components; clamp so a PCA with fewer components does not
# raise IndexError.
n_report = min(20, len(cumulative))
for i in range(n_report):
    print('pc_explained: '+str(variances[i]))
    print('cumul: '+str(cumulative[i]))

# The second argument (50) is forwarded unchanged to the plotting helpers.
# NOTE(review): presumably the number of components shown — confirm.
fig, axs = plt.subplots()
main.pca.plot_explained_variance(axs, 50)

fig, axs = plt.subplots()
main.pca.plot_cumulative_variance(axs, 50)

# Write to h5

In [None]:
# Save the PCA object together with the mean coordinates computed earlier.
# NOTE(review): the section heading says the output is h5; the exact file
# name/extension is decided inside write_pca — confirm.
output_name = '6msm_tmpc/all_tmpc.stride1.realign.240702'
# Alternative output name kept from the original notebook: '6msm_tmpc/all_tmpc.post100'
write_pca(main, output_name, xyz_center=xyzcenter)