In [None]:
import pyemma
pyemma.__version__
import random as rndm

In [None]:
import matplotlib as mpltlib
import matplotlib.pylab as plt
import numpy as np
#import nglview as nv
%pylab inline
import mdtraj
import numpy as np

import pyemma.util.contexts

In [None]:
import pyemma.coordinates as coor
import pyemma.msm as msm
import pyemma.plots as mplt

In [None]:
import os

# Input locations: directory scanned for trajectories and the reference
# structure used as topology for loading/featurizing them.
indir = '.'
topfile = 'hdim-oct.inpcrd.pdb'

# Collect every *.nc trajectory file in `indir`. os.listdir returns entries in
# arbitrary order, so sort them: the trajectory order determines how the
# trajectories are concatenated in every later plot.
traj_list = sorted(
    os.path.join(indir, filename)
    for filename in os.listdir(indir)
    if filename.endswith('.nc')
)
if not traj_list:
    raise FileNotFoundError('no *.nc trajectories found in {!r}'.format(indir))


In [None]:
# System size constants: presumably 2 molecules of 145 atoms each.
# NOTE(review): neither value is used in the cells shown here.
num_mol, atom_mol = 2, 145

# mdtraj Topology of the reference structure file.
topology = mdtraj.load(topfile).topology

In [None]:
# Build a featurizer on the reference topology and describe every frame by its
# backbone torsion angles, reported in degrees (deg=True).
feat = coor.featurizer(topfile)
feat.add_backbone_torsions(deg=True)
# Load all trajectories through the featurizer; `inp` holds the feature data
# passed to score_cv and TICA below (presumably one array per trajectory when
# several files are given -- TODO confirm for a single-file traj_list).
inp = coor.load(traj_list, features=feat)

In [None]:
def score_cv(data, dim, lag, number_of_splits=10, validation_fraction=0.5,
             random_state=None):
    """Compute a cross-validated VAMP2 score.

    The list of independent trajectories is randomly split into a training
    and a validation set; a VAMP model is estimated on the training set and
    scored on the validation set. This is repeated ``number_of_splits`` times.

    Parameters
    ----------
    data : list of numpy.ndarray or numpy.ndarray
        The input trajectories. A single array is treated as one trajectory.
    dim : int
        Number of processes to score; equivalent to the dimension
        after projecting the data with VAMP2.
    lag : int
        Lag time (in trajectory steps) for the VAMP2 scoring.
    number_of_splits : int, optional, default=10
        How often the splitting and score calculation is repeated.
    validation_fraction : float, optional, default=0.5
        Fraction of trajectories which go into the validation set during a
        split. The resulting count is clamped so that both the training and
        the validation set contain at least one trajectory.
    random_state : int or None, optional, default=None
        Seed for the random splits. ``None`` draws from NumPy's global
        random state, so ``np.random.seed`` still controls reproducibility.

    Returns
    -------
    numpy.ndarray, shape (number_of_splits,)
        One validation score per split.

    Raises
    ------
    ValueError
        If fewer than two trajectories are given (no split is possible).
    """
    if isinstance(data, np.ndarray):
        # A bare array is a single trajectory, not a list of frames to split.
        data = [data]
    n_traj = len(data)
    if n_traj < 2:
        raise ValueError(
            'score_cv needs at least two trajectories to split, got {}'.format(n_traj))

    # int() truncates: a small fraction could give an empty validation set and
    # a fraction near 1 an empty training set; keep at least one in each.
    nval = min(max(int(n_traj * validation_fraction), 1), n_traj - 1)
    if random_state is None:
        choose = np.random.choice
    else:
        choose = np.random.RandomState(random_state).choice

    scores = np.zeros(number_of_splits)
    for n in range(number_of_splits):
        ival = choose(n_traj, size=nval, replace=False)
        # Boolean mask gives O(1) membership instead of rescanning ival
        # for every trajectory; list order of data is preserved.
        is_val = np.zeros(n_traj, dtype=bool)
        is_val[ival] = True
        train = [d for d, v in zip(data, is_val) if not v]
        val = [d for d, v in zip(data, is_val) if v]
        vamp = pyemma.coordinates.vamp(train, lag=lag, dim=dim)
        scores[n] = vamp.score(val)
    return scores



In [None]:


# Scan lag times and projection dimensions: for each lag, plot the mean
# cross-validated VAMP2 score (band = +/- 1 sample std over the random splits)
# against the number of dimensions.
lags = [100, 250, 375, 500, 750, 1000]
dims = [2, 4, 6, 8, 10]
# Time between consecutive frames in ns; only used to label lags in the legend.
# NOTE(review): value taken from the original legend label -- confirm it
# matches the trajectory output frequency.
FRAME_NS = 0.02
# Seed the global RNG used by score_cv's random splits so the scan is reproducible.
CV_SEED = 0
np.random.seed(CV_SEED)

fig, ax = plt.subplots()
for i, lag in enumerate(lags):
    print('Calculation for Lag Time (in MD steps):' + str(lag))
    # rows: dims, columns: individual CV splits
    scores_ = np.array([score_cv(inp, dim, lag) for dim in dims])
    scores = np.mean(scores_, axis=1)
    errors = np.std(scores_, axis=1, ddof=1)
    color = 'C{}'.format(i)
    ax.fill_between(dims, scores - errors, scores + errors, alpha=0.3, facecolor=color)
    ax.plot(dims, scores, '--o', color=color, label='lag={:.1f}ns'.format(lag * FRAME_NS))
ax.legend()
ax.set_xlabel('number of dimensions')
ax.set_ylabel('VAMP2 score')
ax.set_title('Cross-validated VAMP2 score vs. dimension')
fig.tight_layout()

In [None]:
# Final TICA model: lag (in trajectory steps) and number of ICs kept,
# presumably chosen from the VAMP2 scan above.
lag, dim = 250, 4
tica_obj = coor.tica(inp, lag=lag, dim=dim, kinetic_map=True)

In [None]:
# Project the data onto the ICs; Y is indexed per trajectory (Y[j][:, k]).
Y = tica_obj.get_output()
# Stack all trajectories frame-wise into one (frames, ICs) array for plotting.
Y1 = np.concatenate(Y, axis=0)

In [None]:
# Time series of every IC over the concatenated trajectories, one stacked
# panel per component (number of panels follows the TICA dimension).
n_ics = Y1.shape[1]
fig, axes = plt.subplots(n_ics, 1, figsize=(9, 7), sharex=True, squeeze=False)
for k, ax in enumerate(axes[:, 0]):
    ax.plot(Y1[:, k])
    ax.set_ylabel('IC {}'.format(k + 1))
# NOTE(review): the x axis is the frame index of the concatenated trajectories;
# confirm the unit stated in this label.
axes[-1, 0].set_xlabel('time (40 micros)')
plt.show()

In [None]:
# Marginal distribution of every IC; labels are derived from the number of
# columns so they stay in sync with the TICA dimension.
mplt.plot_feature_histograms(Y1, feature_labels=['IC{}'.format(k + 1) for k in range(Y1.shape[1])]);

In [None]:
def free_energy_surface(x_data, y_data, bins=200):
    """Histogram two coordinates and convert counts to a free energy in kT.

    Each bin gets F = -ln(count), i.e. the free energy in units of kT up to an
    additive constant. Empty bins have no defined free energy; they are filled
    with the largest finite value so they render as the highest level.

    Parameters
    ----------
    x_data, y_data : array_like, 1-D, same length
        Coordinates of the samples.
    bins : int or sequence, optional, default=200
        Passed to numpy.histogram2d.

    Returns
    -------
    F : numpy.ndarray
        Free energy per bin, indexed [x_bin, y_bin] like numpy.histogram2d.
    extent : tuple
        (xmin, xmax, ymin, ymax) of the bin edges, for imshow/contourf.

    Raises
    ------
    ValueError
        If no bin contains any sample.
    """
    counts, x_edges, y_edges = np.histogram2d(x_data, y_data, bins=bins)
    occupied = counts > 0
    if not occupied.any():
        raise ValueError('no samples to histogram')
    F = np.empty_like(counts)
    F[occupied] = -np.log(counts[occupied])
    F[~occupied] = F[occupied].max()
    extent = (x_edges.min(), x_edges.max(), y_edges.min(), y_edges.max())
    return F, extent


# Pool IC1 and IC2 over all trajectories (single concatenation, no
# quadratic array growth) and plot the free-energy landscape.
F, extent = free_energy_surface(
    np.concatenate([traj[:, 0] for traj in Y]),
    np.concatenate([traj[:, 1] for traj in Y]),
    bins=200,
)
plt.figure(figsize=(6, 5))
# F is indexed [x_bin, y_bin]; contourf expects [row=y, col=x], hence F.T.
plt.contourf(F.T, 50, cmap=plt.cm.afmhot, extent=extent)
cbar = plt.colorbar()
cbar.set_label('free energy / kT')
plt.ylabel('IC2')
plt.xlabel('IC1')
plt.show()