# Overall Strategy
1. First filter the molecules by their proximal distance to the protein.
2. For the molecules whose contact count exceeds the cutoff, identify the protein residues that they are in contact with.
3. Sort through this second list of lists and keep the molecule-residue pairs that also remain above the cutoff.
4. Renumber the resid of this molecule according to the chain and resid of the protein.

In [None]:
import MDAnalysis as mda
from MDAnalysis.analysis import distances as mdaDist
import numpy as np
from collections import Counter
import os

In [None]:
# = = = Time-keeping functions. TQDM is a simple progress bar that works in both Jupyter Notebook and terminals.
import time
from tqdm.auto import tqdm

In [None]:
# Mode switch: set to True when this notebook is exported and run as a stand-alone script
# (enables the argparse interface and the 'pdf' matplotlib backend); leave False for notebook use.
bPythonScriptExport=False

In [None]:
# = = = Plotting results
# Script export: select matplotlib's 'pdf' backend before pyplot is imported.
# Notebook use: show figures inline through the IPython magic instead.
if bPythonScriptExport:
    import matplotlib
    matplotlib.use('pdf')
else:
    %matplotlib inline
    #%matplotlib notebook
from matplotlib import pyplot as plt
import networkx as nx
from networkx.drawing.nx_pylab import draw_networkx

In [None]:
# = = = Command-line interface, only active when the notebook is run as an exported script.
if bPythonScriptExport:
    import argparse
    parser = argparse.ArgumentParser(
        description='Identify stable protein and water/chloride contact pairs in a trajectory',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--psf', type=str, dest='psfFile', default='seg.psf',
                        help='The name of the protein structure file.')
    parser.add_argument('--trj', type=str, dest='trjFile', default='sum.xtc',
                        help='The name of the trajectory file.')
    parser.add_argument('--out', type=str, dest='outputPrefix', default='contactPairs',
                        help='A prefix for all output files generated.')
    # The two literals below are concatenated by Python, so the first must end with a space.
    parser.add_argument('--sel', type=str, dest='selectionText', default='name OW and resname SOL',
                        help='An MDAnalysis selection text that chooses one atom of each solvent residue '
                             'to compute contacts from. '
                             'Examples: "name CL" and "name OW and resname SOL". '
                             'Solvent residue indices are often non-unique.')
    parser.add_argument('-r', type=float, dest='ratioCutoff', default=0.75,
                        help='Minimum fraction of simulation frames pairs must be in contact.')
    # NOTE: -k is parsed but is not copied into a module-level variable below (WIP option).
    parser.add_argument('-k', type=int, dest='numKeepCutoff', default=None,
                        help='(WIP) Alternative to the above, keep N waters with the most persistent contacts, regardless of cutoff.')
    parser.add_argument('-d', type=float, dest='distCutoff', default=5.0,
                        help='Maximum separation for pairs to be considered in contact.')
    parser.add_argument('--in_mem', action='store_true', dest='bInMemory',
                        help='Asks MDAnalysis to load the entire trajectory into memory.')
    args = parser.parse_args()

    # Copy the parsed options into the module-level names used by the rest of the workflow.
    ratioCutoff  = args.ratioCutoff
    distCutoff   = args.distCutoff
    outputPrefix = args.outputPrefix
    psfFile      = args.psfFile
    trjFile      = args.trjFile
    bInMemory    = args.bInMemory
    solvSelText  = args.selectionText

In [None]:
# = = = Hard-coded settings for interactive notebook use (no command-line parsing).
if not bPythonScriptExport:
    # Contact criteria.
    distCutoff = 3.5
    ratioCutoff = 0.25

    # Location of the simulation to analyse.
    filePathPrefix = './trajectories'
    allele = 'wt'
    temperature = '310K'
    replicate = 1
    sourceFolder = '%s/%s/%s/%i' % (filePathPrefix, allele, temperature, replicate)
    trjFile = '%s/sum.xtc' % sourceFolder
    psfFile = '%s/seg.psf' % sourceFolder
    bInMemory = True

    # Solvent atoms to track. For chloride, use "name CL" together with a
    # 'contactPairs_CL'-style output prefix instead.
    solvSelText = "name OW and resname SOL"
    outputPrefix = os.path.join(
        sourceFolder,
        'contactPairs_SOL_%g_%g' % (ratioCutoff, distCutoff),
    )
    print(outputPrefix)

In [None]:
def filter_unique_residues( atomSel ):
    """
    Return the unique residue labels present in an atom selection.

    atomSel must expose the parallel per-atom sequences segids, resnames and resids.
    Every atom is turned into a "segid:resname:resid" label; the labels are then
    de-duplicated (and sorted) by np.unique and returned as a numpy array.
    """
    perAtomFields = zip(atomSel.segids, atomSel.resnames, atomSel.resids)
    labels = [":".join((seg, resName, str(resId))) for seg, resName, resId in perAtomFields]
    return np.unique(labels)

In [None]:
def count_indices_of_selection_over_trajectory(u, selectionText):
    """
    Tally how often each atom index appears in a selection across the whole trajectory.

    u             : the Universe whose trajectory is walked frame by frame.
    selectionText : selection string, created with updating=True and read once per frame.

    Returns a Counter mapping atom index -> number of appearances summed over all frames.
    """
    dynamicSel = u.select_atoms(selectionText, updating=True)
    print("= = (Time accounting) Looking for contacts using selection text: %s" % selectionText )
    print("    ...Starting at: %s" % time.ctime() )
    tally = Counter()
    for frame in tqdm(range(u.trajectory.n_frames)):
        # Indexing the trajectory moves the universe to that frame.
        u.trajectory[frame]
        tally.update(dynamicSel.indices)
    return tally

In [None]:
def obtain_protein_contacts_of_highest_counts(u, counter, ratioCutoff, distCutoff):
    """
    For each sufficiently persistent atom index, tally the protein residues found near it.

    u           : the Universe whose trajectory is walked frame by frame.
    counter     : Counter of atom index -> count, e.g. from count_indices_of_selection_over_trajectory.
    ratioCutoff : an index is kept when its count is at least ratioCutoff * n_frames.
    distCutoff  : distance passed to the "around" keyword of the protein selection.

    Returns a dict mapping each kept atom index to a Counter of residue labels
    (as produced by filter_unique_residues) -> number of frames in which that residue was nearby.
    """
    threshold = ratioCutoff * u.trajectory.n_frames

    # = = = One updating protein selection per atom index that passes the threshold.
    nearbySelections = {}
    for atomIndex, nHits in counter.most_common():
        if not (nHits >= threshold):
            # most_common() is sorted by descending count, so nothing further can pass.
            break
        nearbySelections[atomIndex] = u.select_atoms(
            "protein and around %f index %i" % (distCutoff, atomIndex), updating=True)
    residueTallies = {atomIndex: Counter() for atomIndex in nearbySelections}

    print("= = (Time accounting) Looking for protein contacts of indices with highest occupancy")
    print("    ...Starting at: %s" % time.ctime() )
    for frame in tqdm(range(u.trajectory.n_frames)):
        u.trajectory[frame]
        for atomIndex, proteinSel in nearbySelections.items():
            # Each residue counts at most once per frame for a given atom index.
            residueTallies[atomIndex].update( filter_unique_residues( proteinSel ) )

    return residueTallies

In [None]:
def convert_contacts_to_edgelist(u, dictCounters, ratioCutoff=0.0):
    """
    Flatten per-index residue tallies into a list of weighted edges.

    u            : object exposing u.trajectory.n_frames, used to normalise the counts.
    dictCounters : dict of atom index -> Counter of residue label -> frame count.
    ratioCutoff  : pairs with a count below ratioCutoff * n_frames are dropped.

    Returns a list of ("Index_<i>", residueLabel, count / n_frames) tuples.
    """
    nFrames = u.trajectory.n_frames
    threshold = ratioCutoff * nFrames
    edges = []
    for atomIndex, residueCounter in dictCounters.items():
        for residueLabel, nHits in residueCounter.most_common():
            if not (nHits >= threshold):
                # Counts are in descending order; the remaining ones are smaller still.
                break
            edges.append( ("Index_%i" % atomIndex, residueLabel, nHits/nFrames) )
    return edges

In [None]:
def print_edgelist(fileName, edgeList):
    """
    Write an edge list to a text file, one edge per line.

    fileName : path of the file to create (overwritten if it exists).
    edgeList : iterable of 3-item entries (a, b, c); each is written as "a b c".

    The file is opened in a with-block so that it is closed even if an entry
    fails to unpack or print part-way through.
    """
    with open(fileName,'w') as fp:
        for a,b,c in edgeList:
            print(a,b,c, file=fp)
    
def print_graph(fileName, edgeList):
    """
    Draw the contact network described by edgeList and save it to fileName.

    fileName : output path handed to plt.savefig.
    edgeList : iterable of (nodeA, nodeB, weight) entries; edges are coloured by weight.

    With fewer connected components than grid columns (3), the whole graph is drawn
    on a single 8x6 figure. Otherwise each connected component gets its own panel
    in a grid that is 3 panels wide.
    """
    delta=0.1 ; nCols=3
    G = nx.Graph(name='Contact pairs')
    for nodeA, nodeB, weight in edgeList:
        G.add_edge(nodeA, nodeB, weight=weight)
    subGraphs = [G.subgraph(c).copy() for c in nx.connected_components(G)]
    nSub = len(subGraphs)

    if nSub < nCols:
        fig = plt.figure(figsize=(8, 6))
        wEdges = [G.edges[x]['weight'] for x in G.edges() ]
        draw_networkx(G, font_size=9, node_color='white',
                      edge_color=wEdges, style='dashed')
        axThis = fig.get_axes()[0]
        # Pad the limits so that node labels near the border are not clipped.
        xLim = axThis.get_xlim() ; axThis.set_xlim( xLim[0]-delta, xLim[1]+delta )
        yLim = axThis.get_ylim() ; axThis.set_ylim( yLim[0]-delta, yLim[1]+delta )
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(fileName)
        return

    nRows = int(np.ceil(nSub/nCols))
    # squeeze=False keeps `axes` two-dimensional even for a single row; without it,
    # exactly nCols components give a 1-D array and axes[ii,jj] raises IndexError.
    fig, axes = plt.subplots(nRows, nCols, figsize=(4*nCols, 4*nRows), squeeze=False)
    for i, subGraph in enumerate(subGraphs):
        ii, jj = divmod(i, nCols)
        axThis = axes[ii,jj]
        wEdges = [subGraph.edges[x]['weight'] for x in subGraph.edges() ]
        draw_networkx(subGraph, ax=axThis, font_size=9, node_color='white',
                      edge_color=wEdges, style='dashed')
        xLim = axThis.get_xlim() ; axThis.set_xlim( xLim[0]-delta*2, xLim[1]+delta*2 )
        yLim = axThis.get_ylim() ; axThis.set_ylim( yLim[0]-delta, yLim[1]+delta )
        axThis.set_axis_off()
    # Panels beyond the last component would otherwise show empty framed axes.
    for axUnused in axes.flat[nSub:]:
        axUnused.set_axis_off()
    plt.tight_layout()
    plt.savefig(fileName)

In [None]:
# Build the MDAnalysis Universe from the structure file and the trajectory file.
# in_memory=bInMemory controls whether the whole trajectory is loaded into memory.
u = mda.Universe(psfFile, trjFile, in_memory=bInMemory)

In [None]:
# Sanity check: report the shape of the in-memory coordinate array.
# coordinate_array is provided by the in-memory trajectory reader only, so skip the
# check when the trajectory is streamed from disk (e.g. script mode without --in_mem).
if bInMemory:
    print(u.trajectory.coordinate_array.shape)

In [None]:
# First pass: for every solvent atom matched by solvSelText, count its appearances
# within distCutoff of the protein over the whole trajectory.
counterNearbyWaters = count_indices_of_selection_over_trajectory(u,
    "%s and around %g protein" % (solvSelText, distCutoff) )

In [None]:
# = = = Histogram of the number of frames each solvent atom spent near the protein.
vals = np.array([ x[1] for x in counterNearbyWaters.most_common() ])
plt.figure(figsize=(8, 4))
# Roughly one bin per 20 frames, but never zero bins: a trajectory shorter than
# 20 frames would otherwise make plt.hist raise.
plt.hist(vals, bins=max(1, u.trajectory.n_frames // 20))
plt.axvline(x=u.trajectory.n_frames*ratioCutoff, color='grey', linestyle='--', label='Persistence cutoff ratio.')
plt.yscale('log')
plt.title('Histogram count of all solvents ever found nearby the protein')
plt.xlabel('Number of frames spent within %g Angs. of protein' % distCutoff); plt.ylabel('Count')
# Draw the legend so that the label given to the cutoff line is actually shown.
plt.legend()
plt.savefig(outputPrefix+'_histogram.pdf')

In [None]:
# Second pass: for the solvent atoms that pass ratioCutoff, tally the protein residues
# found within distCutoff of each one, frame by frame.
dictProteinsNearbyWater = obtain_protein_contacts_of_highest_counts(u, counterNearbyWaters, ratioCutoff, distCutoff)

In [None]:
# Keep only the solvent-residue pairs that themselves pass ratioCutoff, as
# (solvent label, residue label, fraction of frames) edges.
edgeListWater = convert_contacts_to_edgelist(u, dictProteinsNearbyWater, ratioCutoff)

In [None]:
# Write the edge list as plain text, one "solvent residue fraction" line per edge.
print_edgelist(outputPrefix+'.txt', edgeListWater)

In [None]:
# Draw the contact network and save it as a PDF next to the text output.
print_graph(outputPrefix+'_graph.pdf', edgeListWater)