In [None]:
%ls

In [1]:
#!/usr/bin/python

from __future__ import print_function
import glob, itertools, os, subprocess, re
import sys, time, tqdm, itertools, socket
import mdtraj as md
import msmbuilder.utils
import numpy as np
import itertools
from itertools import groupby, count
import matplotlib
# Script: matplotlib.use('Agg')  | Notebook: %matplotlib inline
%matplotlib inline
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from msmbuilder.cluster import KCenters, KMedoids
from msmbuilder.decomposition import tICA
from msmbuilder.featurizer import AtomPairsFeaturizer
from msmbuilder.msm import ContinuousTimeMSM, implied_timescales, MarkovStateModel
from operator import itemgetter
from sklearn.externals import joblib
from sklearn.pipeline import Pipeline

'''This script has lots of functionality and is based on analyzing Gromacs trajectories. A list of trajectory
   files is given as trajectory_files, as well as a general structure file. Other structure files should contain
   the same name as the corresponding trajectory file, e.g. traj_001.trr traj_001.gro.
'''

# TODO: add interactivity
# ADD loader function
# do we need the pkl files in calculating components?
# add threading for parallel feature selections
# add metric for comparing overlaps given feature selections
# make sure cluster centers are saving properly
# should i plot traj end point as well?
# correct movie and eigenvector analysis functions
# add autocorrelation function

  
def calculate_distances():
    
    """Calculate pair-wise distance features and save as .npy files. Index selection is done within the function.
       Feature labels are returned to match tICA components back to the features that make them up.
    """
    
    print("\nCalculating distances...")
    for i in range(len(trajectory_files)): # For each trajectory file
        ions,indices = [],[]
        if '.gro' in trajectory_files[i]:
            traj = md.load(trajectory_files[i]) # Load trajectory as single frame
            print("\nLoaded " + trajectory_files[i] + " as a single frame")
        else:
            try: # Try to load with structure file matching trajectory name
                structure = trajectory_files[i].split('.')[0] + '.gro'
                traj = md.load(trajectory_files[i], top=structure)
                print("Loaded " + trajectory_files[i] + " with top: " + structure)
            except Exception as e: # Else try loading trajectory with general structure file
                traj = md.load(trajectory_files[i], top=structure_file)
                print("Loaded " + trajectory_files[i] + " with top: " + structure_file)

###################
            
#      INDEX SELECTION

        bk_carbon = [ a.index for a in traj.topology.atoms if a.name in ['C','CA'] and a.residue.name not in ['TFL']]
        bk_oxygen = [ a.index for a in traj.topology.atoms if a.name == 'O' and a.residue.name not in ['TFL']]
        ions = [ a.index for a in traj.topology.atoms if a.name in ['NA','K'] ]
        all_oxygen = [ a.index for a in traj.topology.atoms if a.element.symbol == 'O' and a.residue.name not in ['TFL','MOE']]
        all_carbon = [ a.index for a in traj.topology.atoms if a.element.symbol == 'C' and a.residue.name not in ['TFL']]

        indices += bk_oxygen
#        indices += all_carbon
#        indices += ions

###################

        # Transform indices into distances and save
        pairs = list(itertools.combinations(indices, 2))
        feature_labels = [[[str(traj.topology.atom(j[0]).residue.index) +
                            traj.topology.atom(j[0]).residue.name,traj.topology.atom(j[0]).name],
                          [str(traj.topology.atom(j[1]).residue.index) + 
                           traj.topology.atom(j[1]).residue.name,traj.topology.atom(j[1]).name]] for j in pairs]
        
        features = AtomPairsFeaturizer(pairs)
        transformed_data = features.fit_transform(traj)
        
        for j in range(len(transformed_data)):
            transformed_data[j] = transformed_data[j][0]
            
        print("Saved %d pair-wise distance features over %d frames.\n" %(len(pairs),len(transformed_data)))
        np.save(project_title + '/' + 'distance_out_' + str(i).zfill(3) + '.npy', transformed_data)
        
    return feature_labels


def get_bound_frames():
    
    '''Returns a list of lists (for each trajectory) of lists (for each binding event) of frame numbers that
       contain binding events based on the bound_cutoff parameter. If no binding events are found, that
       trajectory is reported as None. Binding events are characterized by indices assigned below.
    '''
    
    # Initiate lists
    all_bound_frames,all_binding_distances,_averages = [],[],[]
    
    # Load trajectory data
    for i in range(len(trajectory_files)): # For each trajectory
        try: # Try loading with general structure file
            traj = md.load(trajectory_files[i], top=structure_file)
        except Exception as e: # Else load with structure file matching trajectory name
            structure = trajectory_files[i].split('.')[0] + '.gro'
            traj = md.load(trajectory_files[i], top=structure)

            
        # Get indices target (backbone oxygens) and ligand (cations)   
        target = [ a.index for a in traj.topology.atoms if a.name == 'O' and a.residue.name in ['PR1','MOE'] and a.residue.index % 2 == 1]
        ligand = [ a.index for a in traj.topology.atoms if a.name in ['NA','K'] ]
            
        if ligand: # If this simulation contains ligands
            bound_frames,binding_indices = [],[]
            print("Checking bound frames in " + trajectory_files[i])
            for j in target: # Determine distance pairs
                for k in ligand:
                    binding_indices.append([j,k])
                    
            # Compute distances
            binding_distances = md.compute_distances(traj,binding_indices)
            for j in range(len(binding_distances)):
                for k in range(len(binding_distances[j])):
                    if binding_distances[j][k] < bound_cutoff:
                        bound_frames.append(j)
                        
            # Transforms distances to an average of target index to each ligand index.
            # This is used to calculate the distance of each ion from the center of the macrocycle
            # and can be deleted for non-similar systems.
            
            for j in range(len(binding_distances)): # For each frame
                averages = []
                for k in range(len(ligand)): # For each ligand index (4 cations)
                    distances_per_ligand = []
                    for m in range(len(binding_distances[0])): # For each binding distance {ligand(4) * target(6) = 24}
                        if m % len(ligand) == k: # Separate binding frames per ligand index (ion)
                            distances_per_ligand.append(binding_distances[j][m])
                    averages.append(np.average(distances_per_ligand)) # Calculate average binding distance
                _averages.append(averages)
            binding_distances = _averages
                        
            # Separate bound frames into lists of binding events
            all_bound_frames.append([list(g) for f, g in groupby(list(set(bound_frames)), key=lambda a,b=count(): a-next(b))])
            length = max([len(i) for i in all_bound_frames[-1]])
            all_binding_distances.append(binding_distances)
            
        else: # If no ligands are present
            all_bound_frames.append(None)
            all_binding_distances.append(None)
            
    return all_bound_frames, all_binding_distances
     

def get_dihedral_indices(traj,structure_file,angle):
    
    '''Returns a list of indices and a list of all possible cis/trans isomers given a trajectory file,
       a structure file, and a particular dihedral angle. Choices for dihedral angles include omega, phi,
       and psi. The indices used are based on the nmer and cyclic parameters.'''
    
    # Change nmer to range of residue numbers to use, so it's more applicable to proteins.
    
    indices = []
    for i in range(nmer):
        indices.append([])
        if angle == 'omega':
            if i != nmer-1:
                for k in ['CA','C']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i:
                            indices[i].append(a.index)
                for k in ['N','CA']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i+1:
                            indices[i].append(a.index)
            elif i == nmer-1 and cyclic:
                for k in ['CA','C']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i:
                            indices[i].append(a.index)
                for k in ['N','CA']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == 0:
                            indices[i].append(a.index)
        elif angle == 'phi':
            if i != nmer-1:
                for k in ['C','N']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i:
                            indices[i].append(a.index)
                for k in ['CA','C']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i+1:
                            indices[i].append(a.index)
            elif i == nmer-1 and cyclic:
                for k in ['C','N']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i:
                            indices[i].append(a.index)
                for k in ['CA','C']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == 0:
                            indices[i].append(a.index)
        elif angle == 'psi':
            if i != nmer-1:
                for k in ['N','CA']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i:
                            indices[i].append(a.index)
                for k in ['C','N']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i+1:
                            indices[i].append(a.index)
            elif i == nmer-1 and cyclic:
                for k in ['N','CA']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == i:
                            indices[i].append(a.index)
                for k in ['C','N']:
                    for a in traj.topology.atoms:
                        if a.name == k and a.residue.index == 0:
                            indices[i].append(a.index)


    # Populate a list of all possible cis/trans conformations.
    choices = [''.join(i) for i in itertools.product(['c','t'], repeat = nmer)]
    return indices, choices


def calculate_dihedrals():
    
    """Calculate dihedral features and save as .npy files. Residue selection is done using the nmer and cyclic
       parameters. Feature labels are returned to match tICA components back to the features that make them up.
    """
    
    print("\nCalculating Dihedrals...")
          
    for i in tqdm.tqdm_notebook(range(len(trajectory_files))):
         
        # Load trajectory and get dihedral indices
        try: # Try loading with general structure file
            traj = md.load(trajectory_files[i], top=structure_file)
            indices,choices = get_dihedral_indices(traj,structure_file,angle)
            print("\nLoaded " + trajectory_files[i] + " with top: " + structure_file)
          
        except Exception as e: # Else load with structure file matching trajectory name
            indices,choices = get_dihedral_indices(traj,structure,angle)
            print("\nLoaded " + trajectory_files[i] + " with top: " + structure)

        # Compute Dihedrals
        dihedral_quartets = np.asarray(indices)
        thetas = md.compute_dihedrals(traj, dihedral_quartets)
        
        # Stack sines and cosines to preserve the 2D nature of ~*~AnGlEs~*~
        sin_cos_dihedrals = np.hstack([np.sin(thetas), np.cos(thetas)])
        # Save dihedral features as .npy
        np.save(project_title + '/' + 'dihedral_out_' + str(i).zfill(3) + '.npy', sin_cos_dihedrals)
        
        print("Saving %d dihedral features for %s" %(len(thetas),trajectory_files[i]))

        feature_labels = [[[str(traj.topology.atom(j[0]).residue.index) +
                            traj.topology.atom(j[0]).residue.name,traj.topology.atom(j[0]).name],
                          [str(traj.topology.atom(j[1]).residue.index) + 
                           traj.topology.atom(j[1]).residue.name,traj.topology.atom(j[1]).name],
                          [str(traj.topology.atom(j[2]).residue.index) + 
                           traj.topology.atom(j[2]).residue.name,traj.topology.atom(j[2]).name],
                          [str(traj.topology.atom(j[3]).residue.index) + 
                           traj.topology.atom(j[3]).residue.name,traj.topology.atom(j[3]).name]] for j in indices]
        
        feature_labels = np.hstack([['sin' + str(i) for i in feature_labels],['cos' + str(i) for i in feature_labels]])
        
    return feature_labels

          
def color_by_cistrans(traj,structure_file,angle):
          
    '''This function is particularly specialized and returns an array of color strings denoting the number of
       trans angles in each frame. It can be generalized by changing options to reflect the number of dihedrals
       and to change how nmer works.'''
    
    # Load each trajectory
    print(traj,structure_file)
    traj = md.load(traj,top=structure_file)
    # Determine dihedral indices
    indices = get_dihedral_indices(traj,structure_file,angle)[0]
    # Each color represents a number of trans angles (0-6)
    options = ['pink','red','orange','green','blue','purple','black']
    # Compute dihedrals
    dihedrals = md.compute_dihedrals(traj, np.asarray(indices))
    results = []
          
    for i in range(len(dihedrals)): # For each frame
        dihedral_string = []
        for j in range(len(dihedrals[i])): # For each angle
            # Determine whether each angle is cis or trans.
            if dihedrals[i][j] >= -1.57 and dihedrals[i][j] < 1.57:
                dihedral_string.append("c")
            elif dihedrals[i][j] < -1.57 or dihedrals[i][j] >= 1.57:
                dihedral_string.append("t")

        results.append(options[dihedral_string.count('t')])
                
    return results


    
def calculate_tica_components():
          
    '''Load in the features, calculate a given number of tICA components (tica_components) given a
       lagtime (lag_time), and save tICA coordinates and eigenvector data. It then creates and populates
       a list for each desired component, clusters the data, saving normalized populations as populations.dat
       and saving each cluster center as a .pdb. tICA plots are created and saved, and implied timescales are
       calculated, saved, and plotted.
    '''
          
    print("\nCalculating tICA components...")
          
    # Load in feature files
    feature_files = sorted(glob.glob(project_title + '/' + "*out*npy"))
    features = [ np.load(filename) for filename in feature_files]

    # Perform tICA calculation
    tica_coordinates = tICA(lag_time=tica_lagtime,
        n_components=int(tica_components)).fit_transform(features)
    eigen_data = tICA(lag_time=tica_lagtime,
        n_components=int(tica_components)).fit(features)
          
    np.save(project_title + '/' + 'lag_%d_comp_%d.npy' %(tica_lagtime, tica_components), tica_coordinates)
    np.save(project_title + '/' + 'lag_%d_eigen.npy' %tica_lagtime, eigen_data)
          
    # Extract tICA eigenvectors
    eigenvectors = np.transpose(eigen_data.eigenvectors_)
          
    # Initiate and populate an array for each component    
    for i in range(tica_components):
        exec('tica_' + str(i+1) + ' = []')
          
    for i in tqdm.tqdm(range(len(features))):
        for j in range(len(tica_coordinates[i])):
            for k in range(tica_components):
                exec('tica_' + str(k+1) + '.append(tica_coordinates[i][j][k])')
            
    # Perform clustering based on the cluster_method parameter.
    if cluster_method == 'kcenters':
        print("Clustering via KCenters...")
        clusters = KCenters(n_clusters)
    elif cluster_method == 'kmeans':
        print("Clustering via KMeans...")
        clusters = KMeans(n_clusters)
    else:
        sys.exit("Invalid cluster_method. Use kmeans or kcenters.")
     
    # Determine cluster assignment for each frame.      
    sequences = clusters.fit_transform(tica_coordinates)
    np.save(project_title + '/' + 'lag_%d_clusters_%d_sequences.npy' %(tica_lagtime, n_clusters), sequences)
    np.save(project_title + '/' + 'lag_%d_clusters_%d_center.npy' %(tica_lagtime, n_clusters),
        clusters.cluster_centers_)

#    if enspara_msm:
#        from enspara.msm import MSM, builders

        # build the MSM fitter with a lag time of 100 (frames) and
        # using the transpose method
#        msm = MSM(lag_time=100, method=builders.transpose)

        # fit the MSM to your assignments (a numpy ndarray or ragged array)
#        msm.fit(assignments)

#        print(msm.tcounts_)
#        print(msm.tprobs_)
#        print(msm.eq_probs_)
        
    # Determine cluster populations, normalize the counts, and save as percentages for
    # labeling if a cluster contains more than cluster_percentage_cutoff percent of the data.
    # Finally, save normalized counts.
    print("\nDetermining cluster populations...")
    
    if not os.path.exists(project_title + '/cluster_centers'):
        os.makedirs(project_title + '/cluster_centers')
    counts = np.array([len(np.where(np.concatenate(sequences)==i)[0]) for i in range(n_clusters)])
    normalized_counts =  counts/float(counts.sum())
    percentages = [ i*100 for i in normalized_counts ]
    population_labels = [ [i,"%.2f"%percentages[i]] for i in range(len(percentages)) if percentages[i] > cluster_percentage_cutoff ]
    np.savetxt(project_title + '/cluster_centers/populations.dat', normalized_counts)


    # Plot all unique combinations of tICA components
    print("\nPlotting tICA components with cluster centers...")
    all_ticas = list(itertools.permutations(range(1,tica_components+1), 2))
    for j in tqdm.tqdm(range(len(all_ticas))): # For each pair
        if all_ticas[j][0] < all_ticas[j][1]:
            plt.figure(j, figsize=(20,16))
            plt.hexbin(eval("tica_"+str(all_ticas[j][0])), eval("tica_"+str(all_ticas[j][1])), bins='log')
            x_centers = [clusters.cluster_centers_[i][all_ticas[j][0]-1] for i in range(len(clusters.cluster_centers_))]
            y_centers = [clusters.cluster_centers_[i][all_ticas[j][1]-1] for i in range(len(clusters.cluster_centers_))]
            high_pop_x_centers = [ x_centers[i] for i in range(len(x_centers)) if percentages[i] > cluster_percentage_cutoff ]
            high_pop_y_centers = [ y_centers[i] for i in range(len(y_centers)) if percentages[i] > cluster_percentage_cutoff ]
            plt.plot(x_centers, y_centers, color='y', linestyle="", marker="o")
            plt.plot(eval("tica_"+str(all_ticas[j][0])+'[0]'), eval("tica_"+str(all_ticas[j][1])+'[0]'), color='k', marker='*',markersize=24)
            plt.xlabel('tic'+str(all_ticas[j][0]))
            plt.ylabel('tic'+str(all_ticas[j][1]))
            plt.title(project_title)
            # Add labels for high-population cluster centers
            for label, x, y in zip(population_labels, high_pop_x_centers, high_pop_y_centers):
                plt.annotate(
                  label,
                  xy = (x, y), xytext = (-15, 15),
                  textcoords = 'offset points', ha = 'right', va = 'bottom',
                  bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
                  arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))
            plt.savefig(project_title + '/' + 'tica_'+str(all_ticas[j][0])+'_'+str(all_ticas[j][1])+'.png')
            plt.close()


    # Calculate and save cluster entropy
    print("\nDetermining cluster entropy")
    cluster_entropy = (-1.0*normalized_counts*np.log(normalized_counts)).sum()
    print(np.shape(cluster_entropy))
    print(cluster_entropy)
#    np.savetxt(project_title + '/' + 'cluster_entropy.dat', cluster_entropy)

          
    # Write out PDBs for each cluster center
    print("Performing cluster analytics and saving center PDBs...\n")

    for i in range(len(features)):
        n_snapshots = len(clusters.distances_[i])
          
        # Determine frames that are cluster centers
        cluster_indices = np.arange(n_snapshots)[ (clusters.distances_[i] < 1e-6) ]
        print('cluster_indices',cluster_indices)
        
        # Determine number of each cluster, correlates to populations.dat
        cluster_labels = sequences[i][cluster_indices]
        print('cluster_labels',cluster_labels)

        print(cluster_labels,cluster_indices)
        # Print information on each cluster center
        if cluster_indices.size != 0 and verbose:
            for j in range(len(cluster_labels)): # for each cluster center found in this trajectory
                print('Cluster center', cluster_labels[j], 'was found in trajectory ' + str(features[i]) + '.')
                print('It is found on frame', cluster_indices[j], 'and has a relative population of %.3f'%normalized_counts[cluster_labels[0]]*100 + '%.')

        # Save each cluster center as a pdb
        for j in range(len(cluster_indices)): # actually saving the snapshots
            try:
                print("Attempting to save cluster center %d from frame %d of %s"%(cluster_labels[j],cluster_indices[j],trajectory_files[i]))
            except:
                print(cluster_labels,cluster_indices,j,trajectory_files,i)
            try: # Catches any other errors
                try: # Catches invalid structure file
                    cluster_traj = md.load_frame(trajectory_files[i], cluster_indices[j], top=structure_file)
                    cluster_traj.save_pdb(project_title + '/cluster_centers/state_%d_%.3f.pdb'%(cluster_labels[j],normalized_counts[j]))
                except Exception as e:
                    structure = trajectory_files[i].split('.')[0] + '.gro'
                    cluster_traj = md.load_frame(trajectory_files[i], cluster_indices[j], top=structure)
                    cluster_traj.save_pdb(project_title + '/cluster_centers/state_%d_%.3f.pdb'%(cluster_labels[j],normalized_counts[j]))
            except Exception as e:
                print(e)
                pass

   # Calculate and save Implied Timescales
    print("\nCalculating Implied Timescales...")
    timescales = implied_timescales(sequences, lagtimes, n_timescales=n_timescales,
        msm=MarkovStateModel(verbose=False))
          
    numpy_timescale_data = project_title + '/' + 'lag_%d_clusters_%d_timescales.npy' %(tica_lagtime, n_clusters)
    np.savetxt(project_title + '/' + 'lagtimes.txt', lagtimes)
    np.save(numpy_timescale_data, timescales)
   
    # Plot Implied Timescales per lagtime
    print("\nPlotting Implied Timescales...")
    for i in tqdm.tqdm(range(n_timescales)):
        plt.figure(42)
        plt.plot(lagtimes * time_step, timescales[:, i] * time_step, 'o-')
        plt.yscale('log')
        plt.xlabel('lagtime (1e-1 ns)') # Is this scale always true?
        plt.ylabel('Implied timescales (1e-1 ns)') # This one?
        plt.title(project_title + ' Implied timescales')
        plt.savefig(project_title + '/' + 'lag_%d_clusters_%d_.png' %(tica_lagtime, n_clusters))

    return eigenvectors

          
def trajectory_trace():
    
    '''This function will load in tICA coordinates for a multiple-trajectory tICA landscape and
       project individual trajectories upon it. It will also search for bound frames, as determined
       by the get_bound_frames function, plotting these frames on top of the trajectory trace.
       The projection is plotted upon the landscape with a stride variable denoted below.'''      
        
    print("\nEvaluating Trajectory Traces...")
    
    # Determine bound frames, if any
    bound_frames = get_bound_frames()[0]
          
    # Load in tICA coordinates
    tica_coordinates = np.load(project_title + '/' + 'lag_%d_comp_%d.npy' %(tica_lagtime, tica_components))
          
    # Populate an array for each tICA component
    for i in range(tica_components):
        exec('tica_' + str(i+1) + ' = []')
    for i in range(len(tica_coordinates)):
        for j in range(len(tica_coordinates[i])):
            for k in range(tica_components):
                exec('tica_' + str(k+1) + '.append(tica_coordinates[i][j][k])')
    
    # Determine stride of projection, and a different color for each trajectory
    stride = int(len(tica_coordinates[0])/1000) + 1
    colors = cm.rainbow(np.linspace(0,1,len(tica_coordinates)))
    # Font-size changed for paper
    matplotlib.rcParams.update({'font.size': 32})
        
    # Create a plot for each desired tICA component combination
    for j in tqdm.tqdm_notebook(range(len(all_ticas))):
        if all_ticas[j][0] < all_ticas[j][1]:
            if not os.path.exists(project_title + "/%d_%d" %(all_ticas[j][0],all_ticas[j][1])):
                os.makedirs(project_title + "/%d_%d" %(all_ticas[j][0],all_ticas[j][1]))
                
            # Create a plot for each trajectory
            for k in range(len(trajectory_files)):
                legend_elements,traj_tic1,traj_tic2 = [],[],[]
                plt.figure(j, figsize=(20,15))
                plt.hexbin(eval("tica_"+str(all_ticas[j][0])), eval("tica_"+str(all_ticas[j][1])), bins='log')
                
                # Populate arrays with tICA coordinates for a particular trajectory
                for l in range(len(tica_coordinates[k])):
                    traj_tic1.append(tica_coordinates[k][l][all_ticas[j][0]-1])
                    traj_tic2.append(tica_coordinates[k][l][all_ticas[j][1]-1])
                
                if len(tica_coordinates[k]) == 1: # If this trajectory represents a single frame, plot it as a star
                    plt.plot(traj_tic1, traj_tic2, color='w', marker='*')
                else: # Otherwise, project it onto the landscape and plot its starting point as a star
                    plt.plot(traj_tic1, traj_tic2, color=colors[k], linestyle="", marker="o")
                    plt.plot(traj_tic1[0], traj_tic2[0], color='k', marker='*')
                print(np.shape(bound_frames), np.shape(bound_frames[k]), len(traj_tic1))
                if bound_frames[k]: # If bound frames exist
                    for m in range(len(bound_frames[k])):
                        try:
                            bound_x_centers = [ traj_tic1[x] for x in bound_frames[k][m] ]
                            bound_y_centers = [ traj_tic2[x] for x in bound_frames[k][m] ]
                        except Exception as e:
                            print(e,k,m,trajectory_files[k])
                        # Project them on top of the trajectory projection in white
                        plt.plot(bound_x_centers[::stride], bound_y_centers[::stride], color='w')
                # Create a legend denoting which trajectory file is plotted in which color
                legend_elements.append(Line2D([0], [0], marker='o', color=colors[k], label=trajectory_files[k],
                  markerfacecolor=colors[k], markersize=8))
                plt.xlabel("tica_"+str(all_ticas[j][0]))
                plt.ylabel("tica_"+str(all_ticas[j][1]))
                plt.legend(handles=legend_elements, loc='best')
                plt.savefig(project_title + "/%d_%d/"%(all_ticas[j][0],all_ticas[j][1]) + "traj_%d_trace.png" %k)
                plt.close()
                


def eigenvalue_analysis():
    
    '''Plots eigenvectors / corresponding features for each tICA component and sorts them and...
   
    '''
    print('Performing Eigenvector Analysis')
    
    tica_coordinates = np.load(project_title + '/' + 'lag_%d_comp_%d.npy' %(tica_lagtime, tica_components))
    
    for i in range(tica_components):
        exec('tica_' + str(i+1) + ' = []')
    for i in range(len(tica_coordinates)): # this takes a while
        for j in range(len(tica_coordinates[i])):
            for k in range(tica_components):
                exec('tica_' + str(k+1) + '.append(tica_coordinates[i][j][k])')
                
    try:
        sorted_eigenvectors = [sorted(x) for x in eigenvectors]
    except NameError:
        eigenvectors = calculate_tica_components()
        sorted_eigenvectors = [sorted(x) for x in eigenvectors]
        
    sorted_eigenindices = [np.argsort(x) for x in eigenvectors]
    
    plt.figure(0)
    for i in range(tica_components):
        plt.subplot(int("42" + str(i+1)))
        plt.bar(range(len(eigenvectors[i])),eigenvectors[i])
    plt.title('Unsorted Eigenvectors')
    plt.tight_layout()
    plt.savefig(project_title + '/' + "unsorted_eigenvectors")
    plt.close()

    plt.figure(1)
    for i in range(tica_components):
        plt.subplot(int("42" + str(i+1)))
        plt.bar(range(len(sorted_eigenvectors[i])),sorted_eigenvectors[i])
    plt.title('Sorted Eigenvectors')
    plt.tight_layout()
    plt.savefig(project_title + '/' + "sorted_eigenvectors")
    plt.close('all')
    
        
    for i in range(len(feature_labels)):
        exec('y_' + str(i+1) + ' = []')
        exec('labels_' + str(i+1) + ' = []')

    for i in range(tica_components):
        for j in range(len(feature_labels)): #list(range(5)) + list(range(-1,-6,-1)):
            exec('y_' + str(i+1) + '.append(sorted_eigenvectors[i][j])')
            exec('labels_' + str(i+1) + '.append(feature_labels[sorted_eigenindices[i][j]])')
        
    for j in range(len(all_ticas)):
        if all_ticas[j][0] < all_ticas[j][1]:
            plt.figure(j+777, figsize=(15,30)) # plotting tica_1, tica_2
            plt.subplot(211)
            plt.hexbin(eval("tica_"+str(all_ticas[j][0])), eval("tica_"+str(all_ticas[j][1])), bins='log') #, cmap=cmaps.viridis
            plt.subplot(413)
            plt.bar(range(len(eval("y_" + str(all_ticas[j][0])))), eval("y_" + str(all_ticas[j][0])))
            plt.xticks(range(len(eval("labels_" + str(all_ticas[j][0])))),eval("labels_" + str(all_ticas[j][0])),rotation=90)
            plt.subplot(414)
            plt.bar(range(len(eval("y_" + str(all_ticas[j][1])))), eval("y_" + str(all_ticas[j][1])))
            plt.xticks(range(len(eval("labels_" + str(all_ticas[j][0])))),eval("labels_" + str(all_ticas[j][1])),rotation=90)
            plt.title('Eigenvalue Analysis for tIC' + str(all_ticas[j][0]) + "/tIC" + str(all_ticas[j][1]))
            plt.tight_layout()
            print()
            plt.savefig(project_title + '/' + "eigen_analysis_" + str(all_ticas[j][0]) + "_" + str(all_ticas[j][1]) + ".png")
            plt.close()
    


def plot_clusters(color_by):
    
    '''This function projects additional information on top of the tICA landscape. If color_by is set to
       cluster, frames are colored by cluster number assignment with stride defined below.
       If color_by is set to angle, frames are colored by the number of trans dihedral angles found, as
       defined by color_by_cistrans.
    '''
          
    print("\nPlotting clusters...")
    
    # Determine bound frames, if any
    bound_frames = get_bound_frames()[0]
    
    # Load tICA coordinates and cluster assignments
    tica_coordinates = np.load(project_title + '/' + 'lag_%d_comp_%d.npy' %(tica_lagtime, tica_components))
    sequences = np.load(project_title + '/' + 'lag_%d_clusters_%d_sequences.npy' %(tica_lagtime, n_clusters))
    
    # Initiate and populate arrays for each component
    for i in range(tica_components):
        exec('tica_' + str(i+1) + ' = []')
    for i in range(len(tica_coordinates)):
        for j in range(len(tica_coordinates[i])):
            for k in range(tica_components):
                exec('tica_' + str(k+1) + '.append(tica_coordinates[i][j][k])')
    
    # Denote stride and a color array for the number of clusters
    stride = 128 #int(len(tica_coordinates[0])/100) + 1
    cluster_colors = cm.rainbow(np.linspace(0,1,n_clusters))
    
    # Create and save plots for each tICA component combination, for each trajectory
    for j in tqdm.tqdm_notebook(range(len(all_ticas))):
        if all_ticas[j][0] < all_ticas[j][1]:
            if not os.path.exists(project_title + "/%d_%d" %(all_ticas[j][0],all_ticas[j][1])):
                os.makedirs(project_title + "/%d_%d" %(all_ticas[j][0],all_ticas[j][1]))
            angle_colors = []
            for k in range(len(trajectory_files)):
                legend_elements = []
                try: # Retrieve color associated with number of trans dihedral angles
                    structure_file = re.sub('.trr','.gro',re.sub('.xtc','.gro',trajectory_files[k]))
                    angle_colors += color_by_cistrans(trajectory_files[k],structure_file=structure_file,angle='omega')
                except Exception as e:
                    angle_colors += color_by_cistrans(trajectory_files[k],structure_file='xtc.gro',angle='omega')

                traj_tic1,traj_tic2 = [],[]
                plt.figure(j, figsize=(8,5))
                # Plot the underlying tICA map
                plt.hexbin(eval("tica_"+str(all_ticas[j][0])), eval("tica_"+str(all_ticas[j][1])), bins='log')
                for l in range(len(tica_coordinates[k])):
                    traj_tic1.append(tica_coordinates[k][l][all_ticas[j][0]-1])
                    traj_tic2.append(tica_coordinates[k][l][all_ticas[j][1]-1])

                # Plot the strided projection
                for l in tqdm.tqdm(range(len(traj_tic1))[::stride]):
                    if color_by == 'cluster':
                        plt.plot(traj_tic1[l], traj_tic2[l], color=cluster_colors[sequences[k][l]], linestyle="", marker="o")
                    if color_by == 'angle':
                        try:
                            plt.plot(traj_tic1[l], traj_tic2[l], color=angle_colors[l], linestyle="", marker="o")
                        except Exception as e:
                            print(l,len(angle_colors), len(traj_tic1))
                if color_by == 'angle':
                    colors = ['pink','red','orange','green','blue','purple','black']
                    for l in range(len(colors)): # Create a legend containing colors corresponding to number of angles
                        legend_elements.append(Line2D([0], [0], marker='o', color=colors[l], label="%d trans"%l,
                            markerfacecolor=colors[l], markersize=8))
                if bound_frames[k]: # If bound frames exist
                    for m in range(len(bound_frames[k])):
                        try:
                            bound_x_centers = [ traj_tic1[x] for x in bound_frames[k][m] ]
                            bound_y_centers = [ traj_tic2[x] for x in bound_frames[k][m] ]
                        except Exception as e:
                            print(e,k,m,trajectory_files[k])
                        # Project them on top of the trajectory projection in white
                        plt.plot(bound_x_centers[::stride], bound_y_centers[::stride], color='w')
                # Plot the trajectory's starting point
                plt.plot(traj_tic1[0],traj_tic2[0],color='w',marker='*',markersize=24)
            plt.xlabel("tica_"+str(all_ticas[j][0]))
            plt.ylabel("tica_"+str(all_ticas[j][1]))
#            plt.title(trajectory_files[k])
            #plt.legend(handles=legend_elements, loc='best')
            plt.savefig(project_title + "/%d_%d/"%(all_ticas[j][0],all_ticas[j][1]) + "all_clusters_%d.png"%k)
            plt.close()
            

def show_transitions():
    
    # Load tICA coordinates and cluster assignments
    n_transitions = 24
    tica_coordinates = np.load(project_title + '/' + 'lag_%d_comp_%d.npy' %(tica_lagtime, tica_components))
    sequences = np.load(project_title + '/' + 'lag_%d_clusters_%d_sequences.npy' %(tica_lagtime, n_clusters))
    centers = np.load(project_title + '/' + 'lag_%d_clusters_%d_center.npy' %(tica_lagtime, n_clusters))
    
    # Initiate and populate arrays for each component
    for i in range(tica_components):
        exec('tica_' + str(i+1) + ' = []')
    for i in range(len(tica_coordinates)):
        for j in range(len(tica_coordinates[i])):
            for k in range(tica_components):
                exec('tica_' + str(k+1) + '.append(-tica_coordinates[i][j][k])')
    
    def largest_indices(arr, n):
        """Returns the n largest indices from a numpy array."""
        flat = arr.flatten()
        indices = np.argpartition(flat, -n)[-n:]
        indices = indices[np.argsort(-flat[indices])]
        return np.unravel_index(indices, arr.shape)
    
    transitions = np.zeros(shape=(n_clusters,n_clusters))
    for i in range(len(sequences)):
        for j in range(len(sequences[i])):
            try:
                transitions[sequences[i][j]][sequences[i][j+1]] += 1
            except IndexError as e:
                pass

    top_transition_indices = largest_indices(transitions,n_transitions)[::-1]
    top_transition_counts = transitions[largest_indices(transitions,n_transitions)]
    normalized_transition_counts = [x/sum(top_transition_counts) for x in top_transition_counts][::-1]
    
    counts = np.array([float(len(np.where(np.concatenate(sequences)==i)[0])) for i in range(n_clusters)])
    normalized_counts =  counts/float(counts.sum())
    percentages = [x*100 for x in normalized_counts]
    cluster_colors = cm.rainbow(np.linspace(0,1,n_clusters))
    
    for j in tqdm.tqdm(range(len(all_ticas))): # For each pair
        if all_ticas[j][0] < all_ticas[j][1]:
            plt.figure(j, figsize=(20,16))
            plt.hexbin(eval("tica_"+str(all_ticas[j][0])), eval("tica_"+str(all_ticas[j][1])), bins='log')
            
            # plot centers, colored by population
            x_centers = [-centers[i][all_ticas[j][0]-1] for i in range(len(centers))]
            y_centers = [-centers[i][all_ticas[j][1]-1] for i in range(len(centers))]
            
            for k in range(len(x_centers)):
                plt.plot(x_centers[k], y_centers[k], linestyle='', marker='*', markersize=256*normalized_counts[k],
                         color=cluster_colors[int(normalized_counts[k]*100.)])
                print("Plotting cluster center %d with size %f"%(k,256.*normalized_counts[k]))
                
            print("\nPlotting Top Cluster Transitions")
            for k in range(len(top_transition_counts)): # plot top non-self transitions
                if top_transition_indices[0][k] != top_transition_indices[1][k]:
                    try:
                        dx = x_centers[top_transition_indices[1][k]] - x_centers[top_transition_indices[0][k]]
                        dy = y_centers[top_transition_indices[1][k]] - y_centers[top_transition_indices[0][k]]
                        plt.arrow(x_centers[top_transition_indices[0][k]],dx,
                                  y_centers[top_transition_indices[0][k]],dy,
                                  color=cluster_colors[k], width=normalized_transition_counts[k],
                                  head_width=3.*normalized_transition_counts[k], length_includes_head=True)
#                        print("\tPlotting transition Between Clusters %d and %d with head %f and width %f"
#                              %(top_transition_indices[0][k],top_transition_indices[1][k],
#                                3.*normalized_transition_counts[k],normalized_transition_counts[k]))
                    except IndexError as e:
                        pass
 #               else:
 #                   print("\tSkipping self-transition of Cluster %d"%top_transition_indices[0][k])
            
            plt.plot(eval("tica_"+str(all_ticas[j][0])+'[0]'), eval("tica_"+str(all_ticas[j][1])+'[0]'),
             color='w', marker='o',markersize=16)
            plt.xlabel('tic'+str(all_ticas[j][0]))
            plt.ylabel('tic'+str(all_ticas[j][1]))
            plt.title(project_title)
            plt.savefig(project_title + '/' + 'tica_'+str(all_ticas[j][0])+'_'+str(all_ticas[j][1])+'_transitions.png')
            plt.close()
    
            



In [2]:
# Creates a directory by the name project_title, to populate with results
project_title = 'Na_ACN'
# nmer and cyclic used for single molecule trajectories, to calculate dihedrals
# and will likely be deprecated soonish.
nmer = 6
cyclic = True
# cut-off for binding distance in nm
bound_cutoff = 0.30
# general structure file and trajectory files denoted below
structure_file = 'xtc.gro'
trajectory_files = sorted(glob.glob('*xtc')) + sorted(glob.glob('*trr')) #sorted(glob.glob("1_a_n*.trr")) #['traj_00.xtc','traj_01.xtc','traj_02.xtc'] + sorted(glob.glob("0*.trr")) + sorted(glob.glob("H030*trr"))
# tICA parameters
tica_lagtime = 100
tica_components = 8
n_clusters = 100
n_timescales = 8
time_step = 1.0 # ns multiplier of timescales and lagtimes in implied timescale plot
lagtimes = np.array([1,2,4,8,16,32,64,128,256,512,1024])
# which features to use
tica_metric = 'distance' # 'distance / dihedral'
if tica_metric in ['dihedral', 'mixed']:
    angle = 'omega' # phi/psi/omega
# cluster method
cluster_method = 'kcenters' # 'kcenters/kmeans'

# select which functions to run (boolean toggle)
calculate_features = True # calculate and save features
calculate_components = True # calculate and save tICA components / eigenvectors/values
enspara_msm = 'False'
multiple_trajectory_analysis = False # create a plot for each trajectory in trajectory_files, 
                                     # showing its projection over the combined tICA space
movie_analysis = False # creates a movie based on the trajectory data and metric specified
eigen_analysis = False # outputs plots of eigenvectors and corresponding features for each component in all_ticas
cluster_plot = True # outputs a plot with colors showing number of trans angles or cluster identity
color_by = 'angle' # angle/cluster
verbose = False # let's you choose to have a bunch of junk printed to stdout (debug)

all_ticas = list(itertools.permutations(range(1,tica_components+1), 2)) # all combinations
all_ticas = [[1,2]]  #[1,3],[2,3],[3,4],[5,6]] # just show analysis for first two components
# analyzing more than one set may generate thousands more images (movie analysis)

cluster_percentage_cutoff = n_clusters/64 # clusters with a relative population less than this
                              # number will not be labeled i.e. 0 : all clusters labeled
    
if socket.gethostname() == 'syzygy':
    debug_prefix = '/media/matt/ext'
else:
    debug_prefix = ''


In [3]:
# This cell runs the analysis given the parameters entered above

#deprecated
#%install_ext https://raw.github.com/cpcloud/ipython-autotime/master/autotime.py
#%load_ext autotime

if not os.path.exists(project_title):
    os.makedirs(project_title)

if calculate_features:
    if tica_metric == 'distance':
        feature_labels = calculate_distances()
    elif tica_metric == 'dihedral':
        feature_labels = calculate_dihedrals()
    else:
        print("Specify either distance or dihedral as the tica_metric")
        sys.exit()
if calculate_components:
    eigenvectors = calculate_tica_components()
if multiple_trajectory_analysis:
    trajectory_trace()
if movie_analysis:
    if tica_metric == 'distance':
        thetas = distance_tica_movie()
    if tica_metric == 'dihedral':
        dihedral_tica_movie()
if eigen_analysis:
    eigenvalue_analysis()
if cluster_plot:
    plot_clusters(color_by=color_by)


Calculating distances...


UnboundLocalError: local variable 'feature_labels' referenced before assignment

# choices = [[['c', 'c', 'c', 'c', 'c', 'c']],
     [['c', 'c', 'c', 'c', 'c', 't'],['c', 'c', 'c', 't', 'c', 'c'],['c', 't', 'c', 'c', 'c', 'c']],
     [['c', 'c', 'c', 'c', 't', 'c'],['c', 'c', 't', 'c', 'c', 'c'],['t', 'c', 'c', 'c', 'c', 'c']],
     [['c', 'c', 'c', 'c', 't', 't'],['c', 'c', 't', 't', 'c', 'c'],['t', 't', 'c', 'c', 'c', 'c']],
     [['c', 'c', 'c', 't', 't', 'c'],['c', 't', 't', 'c', 'c', 'c'],['t', 'c', 'c', 'c', 'c', 't']],
     [['c', 'c', 'c', 't', 'c', 't'],['c', 't', 'c', 't', 'c', 'c'],['c', 't', 'c', 'c', 'c', 't']],
     [['c', 'c', 't', 'c', 't', 'c'],['t', 'c', 'c', 'c', 't', 'c'],['t', 'c', 't', 'c', 'c', 'c']],
     [['c', 'c', 't', 'c', 'c', 't'],['t', 'c', 'c', 't', 'c', 'c']],
     [['c', 't', 'c', 'c', 't', 'c']],
     [['c', 'c', 'c', 't', 't', 't'],['c', 't', 't', 't', 'c', 'c'],['t', 't', 'c', 'c', 'c', 't']],
     [['c', 'c', 't', 't', 't', 'c'],['t', 'c', 'c', 'c', 't', 't'],['t', 't', 't', 'c', 'c', 'c']],
     [['c', 'c', 't', 'c', 't', 't'],['t', 'c', 't', 't', 'c', 'c'],['t', 't', 'c', 'c', 't', 'c']],
     [['c', 't', 'c', 't', 't', 'c'],['c', 't', 't', 'c', 'c', 't'],['t', 'c', 'c', 't', 'c', 't']],
     [['c', 'c', 't', 't', 'c', 't'],['c', 't', 'c', 'c', 't', 't'],['t', 't', 'c', 't', 'c', 'c']],
     [['c', 't', 't', 'c', 't', 'c'],['t', 'c', 'c', 't', 't', 'c'],['t', 'c', 't', 'c', 'c', 't']],
     [['c', 't', 'c', 't', 'c', 't']],
     [['t', 'c', 't', 'c', 't', 'c']],
     [['c', 'c', 't', 't', 't', 't'],['t', 't', 'c', 'c', 't', 't'],['t', 't', 't', 't', 'c', 'c']],
     [['c', 't', 't', 't', 't', 'c'],['t', 'c', 'c', 't', 't', 't'],['t', 't', 't', 'c', 'c', 't']],
     [['c', 't', 'c', 't', 't', 't'],['c', 't', 't', 't', 'c', 't'],['t', 't', 'c', 't', 'c', 't']],
     [['t', 'c', 't', 'c', 't', 't'],['t', 'c', 't', 't', 't', 'c'],['t', 't', 't', 'c', 't', 'c']],
     [['c', 't', 't', 'c', 't', 't']],
     [['t', 'c', 't', 't', 'c', 't'],['t', 't', 'c', 't', 't', 'c']],
     [['c', 't', 't', 't', 't', 't'],['t', 't', 'c', 't', 't', 't'],['t', 't', 't', 't', 'c', 't']],
     [['t', 'c', 't', 't', 't', 't'],['t', 't', 't', 'c', 't', 't'],['t', 't', 't', 't', 't', 'c']],
     [['t', 't', 't', 't', 't', 't']]]

In [None]:
show_transitions()

In [None]:
# Let's figure this vamp shit out
import pyemma
print(sorted(glob.glob('0/*npy')))
data = [np.load(i) for i in sorted(glob.glob('0/*out*npy'))]
cluster = pyemma.coordinates.cluster_kmeans(data, k=100, max_iter=512)

In [None]:
#its = pyemma.msm.its(cluster.dtrajs, lags=[1, 2, 4,8,16,32,64,128,256,512,1024,2048,4196,8392], nits=3, errors='bayes')
#pyemma.plots.plot_implied_timescales(its, ylog=False)
clusters = [64, 128, 256, 512]
fig, axes = plt.subplots(2, len(clusters), figsize=(12, 6))
for i, k in enumerate(clusters):
    cluster = pyemma.coordinates.cluster_kmeans(data, k=k, max_iter=128, stride=10)
    pyemma.plots.plot_implied_timescales(
        pyemma.msm.its(cluster.dtrajs, lags=[1,2,4,8,16,32,64,128], nits=4, errors='bayes'),
        ax=axes[1, i], units='ns')
    axes[1, i].set_ylim(1, 150)
fig.tight_layout()

In [None]:
VAMP_score = pyemma.msm.estimators.score(data)

In [5]:
from enspara.msm import MSM, builders
from enspara.msm import MSM, builders

# build the MSM fitter with a lag time of 100 (frames) and
# using the transpose method
msm = MSM(lag_time=100, method=builders.transpose)

# fit the MSM to your assignments (a numpy ndarray or ragged array)
msm.fit(assignments)

print(msm.tcounts_)
print(msm.tprobs_)
print(msm.eq_probs_)

In [1]:
import nglview as ng
import mdtraj as md
import ipywidgets
v = ng.demo()

In [2]:
v

In [4]:
v._ngl_version

''