In [1]:
from crossflow.tasks import SubprocessTask
from crossflow.clients import Client
from dask_jobqueue import SLURMCluster

In [2]:
# Describe the Slurm jobs that will host the dask workers. Constructing the
# object does not request any workers; the cluster is scaled in a later cell.
cluster = SLURMCluster(
    # --- Slurm accounting and scheduling ---
    account='e280-workshop',
    queue='standard',
    walltime="06:00:00",
    job_extra_directives=[
        '--nodes=1',
        '--reservation=e280-workshop_1018917',
        '--qos=reservation',
        '--tasks-per-node=128',
    ],
    # Presumably removes the matching auto-generated directives from the job
    # script so that only the ones listed above apply - TODO confirm against
    # the dask-jobqueue documentation.
    job_directives_skip=['--mem', '-n '],
    # --- Resources advertised by each dask worker ---
    cores=1,
    job_cpu=1,
    processes=1,
    memory='256GB',
    # --- Worker runtime environment ---
    interface='hsn0',
    python='python',
    shebang="#!/bin/bash --login",
    local_directory='$PWD',
    # Shell lines executed in the job script before the worker starts.
    # NOTE(review): the virtualenv path is user-specific and hard-coded.
    job_script_prologue=[
        'module load gromacs',
        'export OMP_NUM_THREADS=1',
        'source /work/e280/e280-workshop/yuyang/myenv/bin/activate',
    ],
)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 38457 instead


In [3]:
# Connect a crossflow client to the cluster; all later submit()/map() calls go through it.
client = Client(cluster)
cluster.scale(1) # start a single worker

In [4]:
# Show the cluster summary (address, number of workers, threads, memory).
print(cluster)

SLURMCluster(1605d29c, 'tcp://10.253.46.22:42921', workers=1, threads=1, memory=222.00 GiB)


## From the 1ake endpoint

In [5]:
# Each task wraps a shell command template. The file names given to
# set_inputs()/set_outputs() are the names that appear in the template; when
# the task is submitted, positional arguments are supplied in the same order
# as the declared inputs, and one result is returned per declared output.

# Build a run input file (system.tpr) from parameters, coordinates and topology.
grompp = SubprocessTask('gmx grompp -f x.mdp -c x.gro -p x.top -o system.tpr -maxwarn 1')
grompp.set_inputs(['x.mdp', 'x.gro', 'x.top'])
grompp.set_outputs(['system.tpr'])

# Run the simulation through srun; yields the log file and the trajectory.
mdrun = SubprocessTask('srun --distribution=block:block --hint=nomultithread gmx_mpi mdrun  -deffnm system')
mdrun.set_inputs(['system.tpr'])
mdrun.set_outputs(['system.log', 'system.trr'])

# Post-process a trajectory with "trjconv -pbc whole". Needs BOTH the
# trajectory and the tpr file; "echo 0" answers trjconv's group prompt.
makewhole = SubprocessTask('echo 0 | gmx trjconv -f broken.trr -s system.tpr -o whole.trr -pbc whole')
makewhole.set_inputs(['broken.trr', 'system.tpr'])
makewhole.set_outputs(['whole.trr'])

In [6]:
# Chain the three tasks for a single replica. Each submit() returns
# future(s) immediately, so downstream tasks are given futures as inputs.
tprA = client.submit(grompp, 'nvt.mdp', '1ake_em.gro', '1ake.top')
# mdrun declares two outputs, so two futures come back (log, trajectory).
logA, trajA = client.submit(mdrun, tprA)
wholeA = client.submit(makewhole, trajA, tprA)
# Printing the futures only shows their current status, not their results.
print(logA)
print(trajA)
print(wholeA)

<Future: pending, key: lambda-37a2bc2894235edeca226540c7ed24fc>
<Future: pending, key: lambda-e7e1cf50eedac23424d2583dd636d547>
<Future: pending, key: run-684c0391-e829-4b33-b3e8-bd167abe25e0>


In [8]:
# result() blocks until the corresponding job has finished; save() then
# writes the returned file object to the given local path.
logA.result().save('1ake_test.log')
wholeA.result().save('1ake_test.trr')

'1ake_test.trr'

In [9]:
from crossflow.filehandling import FileHandler
# Load the input files once through a FileHandler so that the resulting
# objects (rather than file names) can be passed to many tasks.
fh = FileHandler()
startcrdsA = fh.load('1ake_em.gro')
topA = fh.load('1ake.top')
mdp = fh.load('nvt.mdp')

In [12]:
from distributed import wait, as_completed
n_reps = 10 # number of independent replicas started from the same structure
startcrdsAs = [startcrdsA] * n_reps # replicate the starting structure
# One grompp job per entry of startcrdsAs; mdp and topA are single objects
# shared by every job.
tprAs = client.map(grompp, mdp, startcrdsAs, topA)
# mdrun has two declared outputs, so map() returns two lists of futures.
logAs, trajAs = client.map(mdrun, tprAs)
# makewhole declares two inputs (trajectory and tpr), so both lists must be
# passed - with the trajectories alone the task is missing its tpr input.
wholeAs = client.map(makewhole, trajAs, tprAs)
# Report jobs as they end. as_completed() yields futures that ended in error
# as well as successful ones, so the status is checked rather than assumed.
for future in as_completed(tprAs + trajAs + wholeAs):
    if future.status == 'finished':
        outcome = 'completed'
    else:
        outcome = f'ended with status {future.status}'
    if future in tprAs:
        print(f'grompp job {tprAs.index(future)} {outcome}')
    elif future in trajAs:
        print(f'mdrun job {trajAs.index(future)} {outcome}')
    else:
        print(f'make_whole job {wholeAs.index(future)} {outcome}')

grompp job 2 completed
grompp job 8 completed
grompp job 1 completed
grompp job 0 completed
grompp job 9 completed
grompp job 3 completed
mdrun job 2 completed
make_whole job 2 completed
grompp job 5 completed
mdrun job 1 completed
mdrun job 9 completed
make_whole job 1 completed
make_whole job 9 completed
mdrun job 8 completed
make_whole job 8 completed
grompp job 6 completed
grompp job 7 completed
grompp job 4 completed
mdrun job 6 completed
mdrun job 7 completed
make_whole job 6 completed
mdrun job 0 completed
mdrun job 3 completed
make_whole job 7 completed
mdrun job 4 completed
make_whole job 4 completed
mdrun job 5 completed
make_whole job 0 completed
make_whole job 3 completed
make_whole job 5 completed


In [11]:
# Ask for five workers.
# NOTE(review): this cell's execution count (11) is lower than that of the
# cell above it (12), so it was run out of order - move it before the
# replica run or re-execute the notebook top to bottom.
cluster.scale(5)

## From the 4ake endpoint

In [13]:
# Load the inputs for the second endpoint; the same mdp object is reused.
startcrdsB = fh.load('4ake_em.gro')
topB = fh.load('4ake.top')

cluster.scale(4) # Reset the size of the cluster to a moderate value
# Same grompp -> mdrun -> makewhole chain as used for the first endpoint.
tprB = client.submit(grompp, mdp, startcrdsB, topB)
logB, trajB = client.submit(mdrun, tprB)
wholeB = client.submit(makewhole, trajB, tprB)

In [14]:
import mdtraj as mdt

# Load the PBC-corrected trajectories as mdtraj trajectories, using the
# starting coordinates of each endpoint as the topology. result() blocks
# until the corresponding makewhole job has finished.
ensembleA = mdt.load(wholeA.result(), top=startcrdsA)
ensembleB = mdt.load(wholeB.result(), top=startcrdsB)

In [15]:
import numpy as np
from mdplus.pca import PCA
from scipy.spatial.distance import cdist

def rmsd2(ensembleA, ensembleB, sel):
    """
    Approximate the RMSD between every frame of A and every frame of B.

    The selected atoms of both ensembles are stacked into one coordinate
    array and transformed with PCA; the Euclidean distance between the
    resulting score vectors, divided by sqrt(number of selected atoms),
    is returned as the RMSD estimate.

    Args:

        ensembleA (mdtraj trajectory): ensemble A of nA frames
        ensembleB (mdtraj trajectory): ensemble B of nB frames
        sel (string): mdtraj selection specifier for RMSD calculation

    Returns:

        rmsd2d (numpy array[nA, nB]): RMSD matrix
    """
    atoms_a = ensembleA.topology.select(sel)
    atoms_b = ensembleB.topology.select(sel)
    n_frames_a = len(ensembleA)

    # Stack A's frames first, then B's, so the scores can be split again below.
    stacked = np.concatenate([ensembleA.xyz[:, atoms_a], ensembleB.xyz[:, atoms_b]])
    projections = PCA().fit_transform(stacked)

    pairwise = cdist(projections[:n_frames_a], projections[n_frames_a:])
    # Normalise by the atom count taken from ensemble A's selection.
    return pairwise / np.sqrt(len(atoms_a))

In [16]:
def get_pair_distances(ensembleA, ensembleB, selection='name CA', max_pairs=10):
    '''
    Find the closest inter-ensemble snapshot pairs by RMSD.

    Args:
       ensembleA (mdtraj trajectory): first ensemble
       ensembleB (mdtraj trajectory): second ensemble
       selection (string): mdtraj selection for RMSD calculation
       max_pairs (int): number of closest pairs to return

    Returns:
       pairlist: dictionary ordered by increasing RMSD, with
           (index in A, index in B) snapshot tuples as keys and RMSDs as values
    '''
    rmsd_matrix = rmsd2(ensembleA, ensembleB, selection)

    # Enumerate every (A frame, B frame) combination in row-major order.
    candidates = [
        ((frame_a, frame_b), rmsd_matrix[frame_a, frame_b])
        for frame_a in range(ensembleA.n_frames)
        for frame_b in range(ensembleB.n_frames)
    ]

    # Stable sort by increasing RMSD, then keep only the leading entries.
    candidates.sort(key=lambda entry: entry[1])
    return dict(candidates[:max_pairs])

In [17]:
# Now run it:
closest_pairs = get_pair_distances(ensembleA, ensembleB, max_pairs=5)
print(closest_pairs)

{(0, 8): 0.7100329136073691, (0, 9): 0.7110713569172359, (0, 7): 0.7130101411327233, (0, 0): 0.7151266022313763, (0, 6): 0.716788622193025}


## The complete workflow

In [18]:
# Iterative path search: each cycle restarts MD from the snapshots of the
# closest A/B pairs, appends the new frames to both ensembles and repeats the
# pair search, until the ensembles come within min_rmsd of each other or
# max_cycles is reached.
# NOTE: ensembleA/ensembleB grow on every cycle and closest_pairs is
# overwritten, so re-running this cell continues from the current state
# rather than starting afresh.
max_cycles = 10 # Maximum number of workflow iterations
min_rmsd = 0.2 # Target minimum RMSD between structures from each ensemble
max_pairs = 5 # Number of shortest inter-ensemble pairs to take forward to next iteration
cluster.scale(max_pairs) # Scale the SLURMCluster up to max_pairs workers

try:
    for icycle in range(max_cycles):
        print(f'Starting cycle {icycle}...')
        shortestA = [pair[0] for pair in closest_pairs] # indices of chosen structures from A
        shortestB = [pair[1] for pair in closest_pairs] # ditto for B
        tprAs = client.map(grompp, mdp, [ensembleA[i] for i in shortestA], topA) # run grompp jobs in parallel
        tprBs = client.map(grompp, mdp, [ensembleB[i] for i in shortestB], topB) # ditto
        logAs, trajAs = client.map(mdrun, tprAs) # run MD jobs in parallel
        logBs, trajBs = client.map(mdrun, tprBs) # ditto
        wholeAs = client.map(makewhole, trajAs, tprAs) # remove PBC artifacts
        wholeBs = client.map(makewhole, trajBs, tprBs)
        wait(wholeAs + wholeBs) # Block here until all jobs (A and B) are done
        print('MD runs finished, finding closest pairs...')
        ensembleA += mdt.load([t.result() for t in wholeAs], top='1ake_em.gro') # Add to ensembles
        ensembleB += mdt.load([t.result() for t in wholeBs], top='4ake_em.gro')
        closest_pairs = get_pair_distances(ensembleA, ensembleB, max_pairs=max_pairs)
        print(f'Cycle {icycle}: closest pairs = {closest_pairs}')
        if min(closest_pairs.values()) < min_rmsd: # The two ends of the path have "met" - stop.
            break

    print('Path search completed')
finally:
    # Release the workers even if a cycle failed, so that no Slurm jobs are
    # left running after an error.
    cluster.scale(0) # Scale the cluster back to having no workers at all.
ensembleA.save('ensembleA.xtc') # Save the ensembles in Gromacs xtc format for later analysis.
ensembleB.save('ensembleB.xtc')

Starting cycle 0...
MD runs finished, finding closest pairs...
Cycle 0: closest pairs = {(45, 35): 0.6757456151493262, (30, 35): 0.6782170197500583, (32, 35): 0.6784142521007794, (31, 35): 0.6807478663070266, (26, 35): 0.6812686199036484}
Starting cycle 1...
MD runs finished, finding closest pairs...
Cycle 1: closest pairs = {(105, 80): 0.6455366545591223, (106, 80): 0.6482641749946749, (105, 116): 0.6499597005582308, (104, 80): 0.6501932100793748, (103, 80): 0.6504947367456139}
Starting cycle 2...
MD runs finished, finding closest pairs...
Cycle 2: closest pairs = {(105, 148): 0.6375325320816776, (121, 148): 0.6375493175191181, (143, 148): 0.637549317546703, (105, 131): 0.6386808669186296, (143, 131): 0.6386977400689988}
Starting cycle 3...
MD runs finished, finding closest pairs...
Cycle 3: closest pairs = {(218, 219): 0.56417450617941, (219, 219): 0.566036488928158, (215, 219): 0.5665269341855169, (214, 219): 0.5706958725524629, (105, 219): 0.5716329409663421}
Starting cycle 4...
MD