In [1]:
import os
import numpy as np
import MDAnalysis as mda
from MDAnalysis.analysis.base import AnalysisFromFunction
from MDAnalysis.coordinates.memory import MemoryReader
from MDAnalysis.analysis.distances import distance_array
from collections import Counter

In [2]:
import time
from tqdm.auto import tqdm

In [3]:
bPythonScriptExport=False

In [4]:
if bPythonScriptExport:
    import argparse
    parser = argparse.ArgumentParser(description='Analyse a system trajectory and match solvent molecules to cluster sites, '
                                                 'using sets of residues within X Angs as the site definitons, and '
                                                 'occupancy rates to rank water additions.',
                                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--top', type=str, dest='topologyFile', default='seg.psf',
                        help='The name of the protein structure file.')
    parser.add_argument('--trj', type=str, dest='trajectoryFile', default='sum.xtc',
                        help='The name of the trajectory file.')
    parser.add_argument('--clust', type=str, dest='clustDefFile', default='cluster_definitions.txt',
                        help='The name of the cluster definitions file. Can be fed from calculate_contact_persistence.py. '
                             'Although only resids are currently being read in, the expected file format is:\n'
                             '0 Intersection: <segid:resid1> <segid:resid2> ...\n'
                             '0 Union: <segid:resid1> <segid:resid2> ...\n'                       
                             '1 Intersection: <segid:resid1> <segid:resid2> ...\n'
                             '1 Union: <segid:resid1> <segid:resid2> ...\n'
                             '...\n')
    parser.add_argument('-n', '--n_clust', type=int, dest='numCluster', default=None,
                        help='Consider only the top N clusters read in the file. If not given, compute for all clusters read.')    
    parser.add_argument('-d', type=float, dest='distCutoff', default=3.5,
                        help='Maximum separation for pairs to be considered in contact. Needs to match the value defined for clustering.')
    parser.add_argument('--out', type=str, dest='outputPrefix', default='clustered',
                        help='A prefix for all output files generated.')
    parser.add_argument('--sel_site', type=str, dest='selectionTextSite', default='name OW and resname SOL',
                        help='An MDAnalysis selection text that refers to the atoms by which clustering was conducted. '
                       'Examples: "name CL" and "name OW and resname SOL".')    
    parser.add_argument('--sel_host', type=str, dest='selectionTextHost', default='protein',
                        help='An MDAnalysis selection text that chooses the host solute by which clustering was conducted.')    
    parser.add_argument('--sel_solvent', type=str, dest='selectionTextSolvent', default='resname SOL',
                        help='(Output) An MDAnalysis selection text that chooses the solvents to create a universe for.')
    parser.add_argument('--sel_system', type=str, dest='selectionTextSystem', default=None,
                        help='(Output) An MDAnalysis selection text that chooses atoms from the host system to include iin the output.'
                        'If not given, then only the cluster solvent atoms are written.')
    parser.add_argument('--in_mem', action='store_true', dest='bInMemory',
                        help='Asks MDAnlaysis to load the entire trajectory into memory.')
    args = parser.parse_args()
    
    #ratioCutoff  = args.ratioCutoff
    topFile      = args.topologyFile
    trjFile      = args.trajectoryFile
    clustDefFile = args.clustDefFile
    nClusts      = args.numCluster
    distCutoff   = args.distCutoff
    
    outputPrefix = args.outputPrefix

    seltxtSite   = args.selectionTextSite
    seltxtHost   = args.selectionTextHost
    seltxtSolv   = args.selectionTextSolvent
    seltxtSyst   = args.selectionTextSystem
    bInMemory    = args.bInMemory    
    MDAbackend   = 'serial'

In [16]:
if not bPythonScriptExport:
    %cd ..
    allele='wt' ; temperature='310K' ; repl=5
    workDir = './trajectories/%s/%s/%i' % (allele, temperature, repl)
    topFile      = os.path.join(workDir, 'seg.psf')
    trjFile      = os.path.join(workDir, 'sum.xtc')
    clustDefFile = 'Stable_Solvent_Clustering.cluster_definitions_d3.5_r0.50.txt'
    distCutoff   = 3.5
    
    #outputPrefix = os.path.join(workDir, 'clustered')
    outputPrefix = os.path.join(workDir, 'clustered_d3.5_r0.50')
    
    seltxtSite   = "name OW"
    seltxtHost   = "protein"
    seltxtSolv   = "resname SOL"
    seltxtSyst   = "(protein or resname ATP MG) and not name MN? 1MN? 2MN? 1MC? 2MC?"
    bInMemory    = True
    nClusts      = 33
    MDAbackend   = 'openMP'

/home/zharmad/projects/cftr


In [6]:
def process_list_items( ll ):
    for l in ll:
        l[0].sort()
        for item in l[0]:
            if item in l[1]:
                l[1].remove(item)
        l[1].sort()
    
#0 Intersection: ND1:573 ND1:603 ND1:464 ND1:465 ND1:572
#0 Union: ND1:573 ND1:603 ND1:464 ND1:465 ND1:572 ND2:1348 ND1:493
#1 Intersection: ND1:549 ND2:1251 ND1:550
#1 Union: ND2:1370 ND1:549 ND1:551 ND2:1371 ND1:553 ND1:548 ND2:1251 ND1:550 ND2:1291
#2 Intersection: LAS:28 TD2:1032 TD2:1036 TD2:1033 LAS:24
#2 Union: TD2:1037 TD2:1032 TD2:1036 LAS:27 LAS:24 LAS:28 TD2:1034 TD2:1033 TD2:1035 LAS:25                                                                    3
#...
def read_cluster_definitions(fileName):
    out=[]
    items=[]    
    with open(fileName,'r') as fp:
        for line in fp:
            l=line.split()
            if len(l)<2:
                continue
            elif len(l)==2:
                items.append([])
            else:
                temp=[]
                for i in range(2,len(l)):
                    temp.append( int(l[i].split(':')[-1]) )
                items.append(temp)
            if len(items)==2:
                out.append(items)
                items=[]
    return out

In [7]:
def parse_list(l, pref="", suff="", sep=" "):
    out=""
    for item in l:
        out+="%s%s%s" % (pref, item, suff)
        if item != l[-1]:
            out+=sep
    return out

def write_cluster_selection_text(l, selSite="name OW", selSolute="protein", d=5.0):
    out=[]
    for i in l:
        s1=parse_list(i[0], pref="(around %g (resid " % d, suff=" and %s))" % selSolute, sep=" and ")
        if s1 !="":
            s1+=" and "
        s2 = parse_list(i[1], pref="", sep=" ")
        s2b = "(around %g (resid %s and protein))" % (d, s2)
        out.append("%s and %s %s" % (selSite, s1, s2b))
    return out

In [8]:
def load_subuniverse_inmem(u, selectionText):
    a = u.select_atoms(selectionText)
    return mda.Merge(a).load_new(AnalysisFromFunction(lambda ag: ag.positions.copy(), a).run().results, format=MemoryReader)

In [9]:
selectionResids = read_cluster_definitions(clustDefFile)
process_list_items(selectionResids) 
stringSelectionResids = [ [str(s[0]).strip('[]'),str(s[1]).strip('[]')] for s in selectionResids ]
x = write_cluster_selection_text(selectionResids, selSite=seltxtSite, selSolute=seltxtHost, d=distCutoff)
print("= = cluster definition file processed with distance cutoff %g Angs." % distCutoff)
print("   ...selection text for cluster 0:", x[0])

if nClusts is None:
    nClusts = len(selectionResids)

= = cluster definition file processed with distance cutoff 3.5 Angs.
   ...selection text for cluster 0: name OW and (around 3.5 (resid 549 and protein)) and (around 3.5 (resid 1251 and protein)) and  (around 3.5 (resid 550 1291 and protein))


In [10]:
bWriteOnlyWaters=True
if seltxtSyst is not None:
    bWriteOnlyWaters=False

In [11]:
bInMemory=True

In [12]:
u = mda.Universe(topFile,trjFile, in_memory=False)
nFrames=u.trajectory.n_frames
if bInMemory:
    print("= = Loading the trajectories into memory for subselections %s and %s ..." % (seltxtHost, seltxtSite) )
    # The in_memory version loads just the protein and solvent atoms for computation of distance matrices
    uSite  = load_subuniverse_inmem(u, seltxtSite)
    uHost  = load_subuniverse_inmem(u, seltxtHost)



In [13]:
atomSelSite=u.select_atoms(seltxtSite)
nSites = atomSelSite.n_atoms
dictGlobalIndex = { a:b for a,b in enumerate(atomSelSite.indices) }
drevGlobalIndex = { b:a for a,b in dictGlobalIndex.items()}

In [14]:
if bInMemory:
    indexSelectionResids = []
    for selClust in selectionResids:
        # Each residue in the intersect selection needs its own list of indices for slicing.
        # The whole residue in the union selection is added in a single set of incides for slicing.
        temp = [ uHost.select_atoms("resid %i" % r).indices for r in selClust[0] ]
        if len(selClust[1])>0:
            temp.append( uHost.select_atoms("resid %s" % ' '.join(map(str,selClust[1])) ).indices )
        indexSelectionResids.append( temp )
else:
    listAtomSels = []
    for i in range(nClusts):
        listAtomSels.append( u.select_atoms(x[i], updating=True, periodic=False) )

In [17]:
"""
The matBoolIndices holds the information about which water ID is within the cutoffdistance from the protein residue atoms.
One can access the fram and index information via np.where( matBoolIndices[frame,cluster,waterID] )
"""
matBoolIndices=np.full( (nFrames, nClusts, nSites), False)
matTemps=[]
for i in range(nClusts):
    nSel = len(indexSelectionResids[i])
    matTemps.append( np.full( (nSel, nSites), False) )

if bInMemory:
    print("= = Entering lengthy distance matrix computation to determine water site occupancies...")
    print("    ...Starting at: %s" % time.ctime() )        
    for f in tqdm(range(nFrames)):
        for c in range(nClusts):
            nSel = len(indexSelectionResids[c])
            #matTemp = np.full( (nSel, uSite.n_atoms), False)
            matTemps[c][:] = False
            # = = Get site atoms within X of any host atoms in sub-selection, e.g. any atom of protein resid X
            for j in range(nSel):
                posHost=uHost.trajectory.coordinate_array[f,indexSelectionResids[c][j]]
                distMat = mda.analysis.distances.distance_array(posHost,
                                                                uSite.trajectory.coordinate_array[f],
                                                                backend=MDAbackend)
                matTemps[c][j] = np.any( distMat<distCutoff, axis=0 )
            # = = Filter only site atoms that are within the cutoff or all selections.
            matBoolIndices[f,c] = np.all( matTemps[c], axis=0 )
    matTemps = None    
else:
    print("= = Entering lengthy distance matrix computation to determine water site occupancies...")
    print("    ...Starting at: %s" % time.ctime() )
    for f in tqdm(range(nFrames)):
        u.trajectory[f]
        for c in range(nClusts):
            for i in listAtomSels[c].indices:
                matBoolIndices[f,c,drevGlobalIndex[i]]=True
    print("    ...Ending at: %s" % time.ctime() )    

= = Entering lengthy distance matrix computation to determine water site occupancies...
    ...Starting at: Wed Aug  4 13:57:10 2021


  0%|          | 0/2001 [00:00<?, ?it/s]

In [18]:
listCounters=[]
for i in range(nClusts):
    c =  Counter( np.where(matBoolIndices[:,i])[1] )
    print( "Cluster %i, 5 most common waters:" % i, c.most_common(5) )
    listCounters.append( c )

Cluster 0, 5 most common waters: [(11146, 1956), (12148, 1)]
Cluster 1, 5 most common waters: [(11110, 1995), (40904, 1)]
Cluster 2, 5 most common waters: [(12466, 1656), (40691, 344), (40584, 4), (19784, 3), (3616, 2)]
Cluster 3, 5 most common waters: [(34432, 953), (25042, 196), (25942, 140), (28434, 96), (9005, 81)]
Cluster 4, 5 most common waters: [(5548, 1867), (43938, 1)]
Cluster 5, 5 most common waters: [(25139, 2001), (34432, 1289), (32950, 276), (29062, 170), (35661, 144)]
Cluster 6, 5 most common waters: [(25139, 1879), (14894, 2), (26986, 2), (31849, 2), (5875, 2)]
Cluster 7, 5 most common waters: [(40256, 533), (34253, 304), (18161, 212), (17045, 186), (42059, 183)]
Cluster 8, 5 most common waters: [(34432, 1057), (25942, 333), (25937, 328), (33519, 256), (25042, 195)]
Cluster 9, 5 most common waters: [(18240, 1986), (15493, 53), (23732, 38), (30218, 33), (10955, 32)]
Cluster 10, 5 most common waters: [(370, 636), (43534, 499), (14888, 275), (18145, 257), (15794, 135)]
Clus

In [19]:
from bokeh.io import output_notebook, output_file, show
from bokeh.plotting import figure
from bokeh.layouts import row, column
import bokeh.models as bokehModels
import bokeh.palettes as bokehPalettes

def map_palette(p, v, vMin, vMax):
    x=int((v-vMin)/(vMax-vMin)*(len(p)-1))
    return p[x]

def create_JS_update_visibility(listGlyphs, div="", divText=None):
    return bokehModels.CustomJS(args=dict(gs=listGlyphs, d = div, dt=divText),
             code="""
             //var s = this.item;
             var s = this.value_throttled;
             for (var i = 0; i < gs.length; i++) {
                 if ( i == s ) {
                     gs[i].visible = true
                 } else if ( gs[i].visible ) {
                     gs[i].visible = false
                 }
             }
             if ( d != "" ) {
                 d.text = '<strong>Core resids:</strong> ' + dt[s][0] + '<br><strong>Peripheral resids:</strong> ' + dt[s][1]
             }
             """
            )

def export_site_occupancy_as_bokehHTML(outputFile,
                                       listXs, listYs, listColours,
                                       listClusterDefs=None,
                                       titleText='Water occupancy',
                                       bNoteBook=False,
                                       palette=bokehPalettes.Viridis256
                                      ):
    nPlots = len(listXs)
            
    fig = figure(title=titleText, plot_width=800, plot_height=400,
                 x_axis_label='Trajectory frame', y_axis_label='solvent index')
    listSources = []
    listGlyphs   = []
    for i in range(nPlots):       
        source = bokehModels.ColumnDataSource({'x': listXs[i], 'y' : listYs[i],
                                               'color': listColours[i],
                                              })
        glyph = fig.scatter(x = 'x', y = 'y', size=2, color='color', source = source)          
        listSources.append(source)
        glyph.visible=(i==0)
        listGlyphs.append(glyph)

    colourMapper = bokehModels.mappers.LinearColorMapper(palette=palette, low=0.0, high=1.0)
    colourBar = bokehModels.ColorBar(title='Occupancy rate',
                                     color_mapper=colourMapper, label_standoff=12)
    fig.add_layout(colourBar, 'above')

    if listClusterDefs is None:
        div = bokehModels.Div(width=600, text=trjFile)
    else:
        div = bokehModels.Div(width=600, text='<strong>Core resids:</strong> %s<br><strong>Peripheral resids:</strong> %s' % \
                              (listClusterDefs[0][0], listClusterDefs[0][1]))
        
    sliderWidget = bokehModels.Slider(width=200,title='Show cluster...', start=0, end=nClusts-1, step=1, value=0)
    sliderWidget.js_on_change('value_throttled',
                              create_JS_update_visibility(listGlyphs, div, listClusterDefs))

    output_file(outputFile)
    if not bPythonScriptExport:
        output_notebook()
    show(column(row(sliderWidget,div),fig))            

In [22]:
listAtomColours=[]
for i in range(nClusts):
    #c=listDictAtomData[i]['counter']
    c=listCounters[i]
    l=[]
    for index in np.where(matBoolIndices[:,i])[1]:
         l.append( map_palette( bokehPalettes.Viridis256, c[index], 0, nFrames )  )
    listAtomColours.append( l )

export_site_occupancy_as_bokehHTML( outputPrefix+'_site_occupancy_all.html',
                                   [ np.where(matBoolIndices[:,i])[0] for i in range(nClusts) ],
                                   [ np.where(matBoolIndices[:,i])[1] for i in range(nClusts) ],
                                   [ listAtomColours[i] for i in range(nClusts) ],
                                   listClusterDefs=stringSelectionResids,
                                   titleText='Waters satisfying consensus cluster residues sets, using distance cutoff %g Angs.' % distCutoff,
                                   bNoteBook=not bPythonScriptExport,
                                   palette=bokehPalettes.Viridis256
                                   )

## Populate the assignment of a water molecule to every site

In [23]:
def count_duplicate_assignments(arr, emptyID=0, bList=False):
    d=[]
    for j in range(arr.shape[1]):
        check = arr[np.nonzero(arr[...,j]-emptyID)[0],j]
        if len(np.unique(check))<len(check):
            d.append(j)
    if bList:
        return len(d), d
    else:
        return len(d)

In [24]:
# The complete non-redundant assignment of water molecules to each cluster for each frame.
emptyID=-1
clusterAssignments=np.zeros((nClusts,nFrames), dtype=int)
if emptyID != 0:
    clusterAssignments[:]=emptyID

# List of waters to be ignored as they are adjacent clusters.
listAssigned=[]

# Copy of counter lists to pop and modify as necessary.
cTest=[]
for i in range(nClusts):
    cTest.append( [ (k,v) for k,v in listCounters[i].most_common() ] )
    

### Step 1.
Assign waters with maximium occupancy in each cluster. Order by clusters with higher prevalence among simulations (lowID to highID).

In [25]:
for i in range(nClusts):
    for k, v in cTest[i]:
        cTest[i].pop(0)
        if k in listAssigned:
            print("Water %i repeated in cluster %i" % (k,i))
            continue
        listAssigned.append(k)
        # = = = Assign the frames in where water ID k is found in cluster i
        clusterAssignments[ i, np.where( matBoolIndices[:,i,k] )[0] ] = k
        #for j in range(nFrames):
        #    if k in listDictAtomData[i]['indices'][j]:
        #        clusterAssignments[i,j] = k
        print("Water %i (occ. %i) assigned to cluster %i by maximum occupancy" % (k,v,i))
        break

rAssigned=100*np.count_nonzero(clusterAssignments-emptyID)/(nClusts*nFrames)
print("   ...%.1f%% assigned so far." % rAssigned)

Water 11146 (occ. 1956) assigned to cluster 0 by maximum occupancy
Water 11110 (occ. 1995) assigned to cluster 1 by maximum occupancy
Water 12466 (occ. 1656) assigned to cluster 2 by maximum occupancy
Water 34432 (occ. 953) assigned to cluster 3 by maximum occupancy
Water 5548 (occ. 1867) assigned to cluster 4 by maximum occupancy
Water 25139 (occ. 2001) assigned to cluster 5 by maximum occupancy
Water 25139 repeated in cluster 6
Water 26986 (occ. 2) assigned to cluster 6 by maximum occupancy
Water 40256 (occ. 533) assigned to cluster 7 by maximum occupancy
Water 34432 repeated in cluster 8
Water 25937 (occ. 328) assigned to cluster 8 by maximum occupancy
Water 18240 (occ. 1986) assigned to cluster 9 by maximum occupancy
Water 370 (occ. 636) assigned to cluster 10 by maximum occupancy
Water 22690 (occ. 873) assigned to cluster 11 by maximum occupancy
Water 11110 repeated in cluster 12
Water 11184 (occ. 11) assigned to cluster 12 by maximum occupancy
Water 800 (occ. 1112) assigned to cl

In [26]:
print( "...Debug: frames with duplicate assignments:", count_duplicate_assignments(clusterAssignments, emptyID=-1) )

...Debug: frames with duplicate assignments: 0


### Step 2.
Starting populating the assignment matrix by working through the cluster-specific counters, ordered by highest occupancy.

In [27]:
cTest2 = np.array([ [i,k,v] for i,l in enumerate(cTest) for k,v in l ])
cTest2 = cTest2[ np.flip(np.argsort(cTest2[...,2])) ]

In [28]:
minFrameBreak=0.01*nFrames
#minFrameBreak=1
for cID, k, v in cTest2:
    if v<minFrameBreak:
        break
    if k in listAssigned:
        #print("Water %i repeated in cluster %i" % (k,cID))
        continue
        
    # = = = Assign the frames in where water ID k is found in cluster i
    # clusterAssignments[ cID, np.where( matBoolIndices[:,cID,k] )[0] ] = k 
    # Two checks are performed.
    # - The ==emptyID checks that cluster assignments is currenty empty.
    # - The np.any() checks that water ID k hasn't been assigned to another cluster yet in that same frame.
    indices = np.where(np.logical_and(np.logical_and(matBoolIndices[:,cID,k],
                                                     clusterAssignments[cID] == emptyID),
                                                     np.all(clusterAssignments!=k,axis=0)
                      ))[0]
    clusterAssignments[cID, indices ] = k
    #countOverlap=0    
    #for j in range(nFrames):
    #    if k in clusterAssignments[...,j]:
    #        # = = = Preassigned to another cluster
    #        continue
    #    if k in listDictAtomData[cID]['indices'][j]:
    #        if clusterAssignments[cID,j]==emptyID:
    #            clusterAssignments[cID,j] = k
    #        else:
    #            countOverlap+=1
    #if countOverlap<v:
    #    print("Water %i (n: %i of %i) assigned to cluster %i by sequential fill. " % (k,v-countOverlap,v,cID))
    
rAssigned=100*np.count_nonzero(clusterAssignments-emptyID)/(nClusts*nFrames)
print("   ...%.1f%% assigned so far." % rAssigned)

   ...64.0% assigned so far.


In [31]:
print( "...Debug: frames with duplicate assignments:", count_duplicate_assignments(clusterAssignments, emptyID=-1) )

cA=0 ; cB=0 ; inds=np.where(clusterAssignments==emptyID)
for i,j in zip(inds[0],inds[1]):
    if listCounters[i][j] == 0:
        cA+=1
    else:
        cB+=1
print (np.count_nonzero(clusterAssignments-emptyID), cA, cB, nClusts*nFrames)

...Debug: frames with duplicate assignments: 0
42236 23212 585 66033


### Step 3.
After this, extend empty assignments by looking at existing populated entries on either side. Now ignoring whether the assignment is already taken.

In [32]:
def return_discrete_ranges(arr):
    if len(arr)==0:
        return [],[]
    left=[] ; right=[]
    l=arr[0] ; r=arr[0]
    for v in arr[1:]:
        if v==r+1:
            r+=1
            continue
        left.append(l) ; right.append(r)
        l=v ; r=v
    left.append(l) ; right.append(r)
    return left, right

# = = = Fill in voids from existing water IDs, based on adjacency.
for i in range(nClusts):
    l,r = return_discrete_ranges( np.where(clusterAssignments[i]==emptyID)[0] )
    for a,b in zip(l,r):        
        if a==0:
            if b==nFrames-1:
                # Completely empty
                continue
            else:
                clusterAssignments[i,a:b+1]=clusterAssignments[i,b+1]
        elif b==nFrames-1:
            clusterAssignments[i,a:b+1]=clusterAssignments[i,a-1]
        else:
            m=int(0.5*(a+b))
            clusterAssignments[i,m:b+1]=clusterAssignments[i,b+1]            
            clusterAssignments[i,a:m]=clusterAssignments[i,a-1]   

In [33]:
cA=0 ; cB=0 ; inds=np.where(clusterAssignments==emptyID)
for i,j in zip(inds[0],inds[1]):
    if listCounters[i][j] == 0:
        cA+=1
    else:
        cB+=1
print (np.count_nonzero(clusterAssignments), cA, cB, nClusts*nFrames)

print( "...Debug: frames with duplicate assignments:", count_duplicate_assignments(clusterAssignments, emptyID=-1) )

66033 0 0 66033
...Debug: frames with duplicate assignments: 871


### Plot the final assignment, taking note of duplicate assignemnts due to the final step.

In [34]:
colours = np.chararray(clusterAssignments.shape, itemsize=7)
for j in range(clusterAssignments.shape[1]):
    for i in range(clusterAssignments.shape[0]):
        if clusterAssignments[i,j]==emptyID:
            colours[i,j]='#4422FF'
            continue
        if np.count_nonzero(clusterAssignments[:,j]==clusterAssignments[i,j])>1:
            colours[i,j]='#FF4422'
        else:
            colours[i,j]='#000000'
colours = colours.astype(str)

In [35]:
export_site_occupancy_as_bokehHTML(outputPrefix+'_site_occupancy_filled.html',
                                   np.tile( np.arange(nFrames), (nClusts,1) ),
                                   clusterAssignments,
                                   colours,
                                   listClusterDefs=stringSelectionResids,
                                   titleText='Final Waters assigned to cluster, using distance cutoff %g Angs.' % distCutoff,
                                   bNoteBook=not bPythonScriptExport,
                                   palette=bokehPalettes.Viridis256
                                   )

### Convert the solvent indices back to global indices.

In [36]:
k = np.array(list(dictGlobalIndex.keys()))
v = np.array(list(dictGlobalIndex.values()))
mapping_ar = np.zeros(k.max()+1,dtype=v.dtype) #k,v from approach #1
mapping_ar[k] = v
clusterAssignmentsOut = mapping_ar[clusterAssignments]

## Prepare the coordinates.
Note: Completely empty asignments will remain empty at this stage. Need to give them default treatment.

In [37]:
def create_water_universe(nSol, listNames=['O', 'H', 'H'], resName='SOL', res0=1, segID='SOL', bTraj=False):
    nAtoms = nSol*3
    resIndices = np.repeat(range(nSol), 3)
    segIndices = [0]*nSol
    sol = mda.Universe.empty(nAtoms, n_residues=nSol, atom_resindex=resIndices, residue_segindex=segIndices, trajectory=bTraj)
    sol.add_TopologyAttr('name', listNames*nSol)
    sol.add_TopologyAttr('type', ['O', 'H', 'H']*nSol)
    sol.add_TopologyAttr('resname', [resName]*nSol)
    sol.add_TopologyAttr('resid', list(range(res0,nSol+res0)))
    sol.add_TopologyAttr('segid', [segID])
    
    bonds = []
    for i in range(0, nAtoms, 3):
        bonds.extend([(i, i+1), (i, i+2)])
    sol.add_TopologyAttr('bonds', bonds)
    return sol

def create_universe_from_selection(u, atomSel):
    """
    Need an In-memory version in order to write-out the trajectory
    """
    from MDAnalysis.analysis.base import AnalysisFromFunction
    xyz = AnalysisFromFunction(lambda ag: ag.positions.copy(),
                                   atomSel).run().results
    u2 = mda.Merge(atomSel)
    u2.load_new(xyz, format=MemoryReader)

def get_water_site_positions(mdaU, oxyIDs, emptyID=0):
    nSol, nFrames = oxyIDs.shape
    posOut=np.zeros( (nFrames,nSol*3,3), dtype=float)
    for j in range(nFrames):
        mdaU.trajectory[j]
        for i in range(nSol):
            oxyID = oxyIDs[i,j]
            if oxyID > emptyID:
                posOut[j,i*3:(i+1)*3,:]=mdaU.select_atoms("index %i %i %i" % (oxyID,oxyID+1,oxyID+2)).positions
    return posOut


In [38]:
# = = = Create water universe from positions.
memReader = MemoryReader(get_water_site_positions(u, clusterAssignmentsOut),
                         dt=u.trajectory.dt)
uWat = create_water_universe(nClusts, listNames=['OW', 'HW1', 'HW2'], res0=0, segID='CRY', bTraj=True)

# = = Directly replace the trajectory reader object. Small memory leak?
uWat.trajectory = memReader

# Use mda.Merge() as needed to write

In [39]:
if bWriteOnlyWaters:
    selWatersOut = uWat.select_atoms("all")
    selWatersOut.universe.trajectory[0]
    selWatersOut.write(outputPrefix+'_solvent.pdb')   
    selWatersOut.write(outputPrefix+'_solvent.xtc', frames='all')    
else:
    selSoluteIn = u.select_atoms(seltxtSyst, updating=True)
    selWatersIn = uWat.select_atoms("all", updating=True)
    uSystemOut   = mda.Merge(selSoluteIn,selWatersIn)
    nAtomsSolute = selSoluteIn.n_atoms
    nAtomsWater  = nClusts*3
  
    dimSystemOut = (nFrames, nAtomsSolute+nAtomsWater, 3 )
    memReaderOutputSystem = MemoryReader(np.zeros( dimSystemOut, dtype=float ), dt=u.trajectory.dt)
    uSystemOut.trajectory = memReaderOutputSystem
    
    # = = = Transfer positions to the in-memory output system from the source system (which is probably from file)
    selSoluteOut = uSystemOut.select_atoms(seltxtSyst, updating=True)
    selWatersOut = uSystemOut.select_atoms(seltxtSolv, updating=True)
    for f in range(uSystemOut.trajectory.n_frames):
        u.trajectory[f] ; uWat.trajectory[f] ; uSystemOut.trajectory[f]
        selSoluteOut.positions = selSoluteIn.positions
        selWatersOut.positions = selWatersIn.positions

    selSystemOut = uSystemOut.select_atoms("all")
    selSystemOut.universe.trajectory[0]
    selSystemOut.write(outputPrefix+'.pdb')
    selSystemOut.write(outputPrefix+'.xtc', frames='all')

