In [1]:
# Papermill params
# Defaults for the values papermill may override when running this notebook.
ratio = 0.9          # Train-Test split ratio
attempt = 0        # Index of this training attempt (used in the model file names and as HDF5 group name)
width = 512          # Passed to the network as `width` (via `nnargs`)
depth = 2            # Passed to the network as `depth` (via `nnargs`)
learning_rate = 1e-2 # Passed to the network as `learning_rate` (via `nnargs`)
dropout = 0.05       # Passed to the network as `dropout` (via `nnargs`)
regularization = 1e-7 # Passed to the network as `regularization` (via `nnargs`)
epsilon = 1e-7       # Floating point noise

# Neural network

In this notebook we set up neural networks with VAMPNet scoring functions, train them for all ensembles with different output sizes, and estimate errors by bootstrap aggregation. This notebook can be used with `papermill` to run all cells automatically with given parameters. We first define the imports and useful utility functions.

In [2]:
%run model.py



In [3]:
def statdist(X: np.ndarray) -> np.ndarray:
    """
    Calculate the equilibrium distribution of a transition matrix.
    
    Parameters
    ----------
    X
        Row-stochastic transition matrix
    
    Returns
    -------
    mu
        Stationary distribution, i.e. the left eigenvector associated
        with the eigenvalue of largest real part (1 for a stochastic
        matrix), normalized to sum to one. Always real-valued.
    
    """
    ev, evec = eig(X, left=True, right=False)

    # `eig` hands back complex arrays as soon as any eigenvalue of X is
    # complex. The eigenvalue we want is real, so select it by real part
    # and drop the (zero) imaginary part of its eigenvector; otherwise
    # the returned "distribution" would be complex-valued.
    mu = evec[:, ev.real.argmax()].real

    # Dividing by the sum fixes both the scale and the arbitrary sign
    return mu / mu.sum()

In [4]:
def test_split(data: MaybeListType[np.ndarray], lag: int, p: float=0.1, mask: Optional[np.ndarray]=None):
    """
    Split a random fraction of time-lagged frame pairs off as a test set.
    
    Parameters
    ----------
    data
        One trajectory or a list of trajectories.
    lag
        Offset in frames between the two members of a pair.
    p
        Fraction of all valid frame pairs to move into the test set.
    mask
        Optional flat array with one entry per frame (all trajectories
        concatenated); nonzero / True marks unwanted frames. Test pairs
        containing an unwanted frame are replaced with NaNs.
    
    Returns
    -------
    data_train_valid
        Copy of the data, with every frame that is part of the test set
        replaced by NaNs, split back into trajectories.
    (test_xt, test_xttau)
        Test frames and their time-lagged counterparts.
    
    Raises
    ------
    ValueError
        If no trajectory is longer than the lag time.
    
    """
    data = make_list(data)
    lengths = np.array([len(d) for d in data])
    nframes = lengths.sum()
    
    # One bookkeeping row per frame:
    #   0: trajectory index, 1: frame index within the trajectory,
    #   2: flat frame index, 3: mask flag (nonzero = unwanted frame)
    # `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent.
    inds = np.empty((nframes, 4), dtype=int)
    inds[:, 0] = np.repeat(np.arange(len(data), dtype=int), lengths)
    inds[:, 1] = np.concatenate([np.arange(n) for n in lengths])
    inds[:, 2] = np.arange(nframes, dtype=int)
    inds[:, 3] = np.zeros_like(inds[:, 0]) if mask is None else mask
    inds = unflatten(inds, lengths=[lengths])
    
    # Local (frame) shuffling
    # NOTE(review): FRAMES is defined outside this cell, presumably the
    # index of the frame axis (0) - confirm in model.py
    shuf_traj_inds = [np.random.choice(
        d[:, 1], size=d.shape[FRAMES], replace=False) for d in inds]
    
    # Sort out too short trajectories, split out lagged part
    n_pairs = 0
    xt, xttau = [], []
    for i, traj in enumerate(inds):
        n_points = traj.shape[FRAMES]

        # Only frames with a lagged partner inside the trajectory can
        # start a pair; we'll just skip super short trajectories for now
        shuf_traj_inds[i] = shuf_traj_inds[i][shuf_traj_inds[i] < (n_points - lag)]
        if n_points <= lag:
            continue
                
        n_pairs += n_points - lag
        xt.append(traj[:n_points - lag][shuf_traj_inds[i]])
        xttau.append(traj[lag:n_points][shuf_traj_inds[i]])
    
    if not xt:
        raise ValueError(
            "No trajectory is longer than the lag time ({0} frames)".format(lag))
        
    # Shuffle externally
    shuf_full_inds = np.random.choice(
        np.arange(n_pairs, dtype=int), size=n_pairs, replace=False)
    xt_shuf = np.vstack(xt)[shuf_full_inds]
    xttau_shuf = np.vstack(xttau)[shuf_full_inds]
    
    # These are the entries for the test set
    n_frames_test = int(xt_shuf.shape[FRAMES] * p)
    inds_t = xt_shuf[:n_frames_test]
    inds_ttau = xttau_shuf[:n_frames_test]
    data_flat = np.vstack(data)
    test_xt, test_xttau = data_flat[inds_t[:, 2]], data_flat[inds_ttau[:, 2]]
    
    # Mask out unwanted frames with NaNs. The flags live in an integer
    # array, so they have to be converted to booleans first: indexing
    # with the raw 0/1 integers would be integer (fancy) indexing and
    # overwrite rows 0 and 1 instead of the flagged rows.
    mask_pair = (xt_shuf[:, 3] | xttau_shuf[:, 3]).astype(bool)
    test_xt[mask_pair[:n_frames_test]] = np.nan
    test_xttau[mask_pair[:n_frames_test]] = np.nan
    
    # We can't just remove our test frame pairs, as the training set
    # would then be out of sync! So we replace the test samples with
    # NaNs instead, we can check for those later in the DataGenerator.
    data_flat[np.union1d(inds_t[:, 2], inds_ttau[:, 2])] = np.nan
    data_train_valid = unflatten(data_flat, lengths=[lengths])
        
    return data_train_valid, (test_xt, test_xttau)

## Data
### Trajectories
Trajectories were acquired in multiple rounds of 1024 simulations each at 278 K in the $NVT$ ensemble yielding approximately 300 µs per ensemble. Postprocessing involved removing water, subsampling to 250 ps timesteps, and making molecules whole.

In [5]:
# Names of the simulated ensembles; the dictionaries below are keyed by them
sim_names = ("apo", "holo", "control")

# Topology file and sorted list of trajectory files of every ensemble
top = dict(
    (name, "trajectories/{0}/topol.gro".format(name)) for name in sim_names
)
trajs = dict(
    (name, sorted(glob("trajectories/{0}/r?/traj*.xtc".format(name))))
    for name in sim_names
)

KBT = 2.311420  # k_B * T at 278 K (presumably in kJ/mol - TODO confirm)
nres = 42

# Number of trajectories in each simulation round of every ensemble
traj_rounds = dict(
    apo=[1024, 1023, 1024, 1024, 1024],
    holo=[1023, 1024, 32],
    control=[1024, 1023],
)

We use minimum distances as features for the neural network:

In [6]:
# Every unordered residue pair (i < j) as an (n_pairs, 2) index array
allpairs = np.array(list(itertools.combinations(range(nres), 2)))

# One data source per ensemble, yielding the minimum distance of each
# residue pair for all trajectories of that ensemble
inpcon = dict.fromkeys(sim_names)
for k in inpcon:
    feat = pe.coordinates.featurizer(top[k])
    feat.add_residue_mindist(residue_pairs=allpairs)
    inpcon[k] = pe.coordinates.source(trajs[k], feat)

HBox(children=(HBox(children=(Label(value='Obtaining file info'),), layout=Layout(max_width='35%', min_width='…



HBox(children=(HBox(children=(Label(value='Obtaining file info'),), layout=Layout(max_width='35%', min_width='…



HBox(children=(HBox(children=(Label(value='Obtaining file info'),), layout=Layout(max_width='35%', min_width='…



HBox(children=(HBox(children=(Label(value='Obtaining file info'),), layout=Layout(max_width='35%', min_width='…



In [7]:
# Trajectory lengths (arranged by simulation round via `sort_lengths`)
# and the total number of frames of every ensemble
lengths, nframes = {}, {}
for k in sim_names:
    # Query the source once and reuse the result for both quantities
    traj_lengths = inpcon[k].trajectory_lengths()
    lengths[k] = sort_lengths(traj_lengths, traj_rounds[k])
    nframes[k] = traj_lengths.sum()

In [8]:
# Header row with one tab-separated column per ensemble
print("\t\t", "\t\t".join(sim_names), sep="")

# Number of trajectories, total number of frames and aggregate simulation
# time (one frame corresponds to 0.00025 µs) of every ensemble
print(
    "Trajs: \t\t" + "\t\t".join(["{0}".format(len(trajs[k])) for k in sim_names]),
    "Frames: \t" + "\t\t".join(["{0}".format(nframes[k]) for k in sim_names]),
    "Time: \t\t" + "\t".join([
        "{0:5.3f} µs".format(inpcon[k].trajectory_lengths().sum() * 0.00025)
        for k in sim_names
    ]),
    sep="\n",
)

		apo		holo		control		phen
Trajs: 		5119		2079		2047		2048
Frames: 	1259172		1225868		1114503		1236792
Time: 		314.793 µs	306.467 µs	278.626 µs	309.198 µs


## VAMPNet
VAMPNet[1] is composed of two lobes, one reading the system features $\mathbf{x}$ at a timepoint $t$ and the other after some lag time $\tau$. In this case the network reads all minimum inter-residue distances (780 values) and sends them through 5 layers with 256 nodes each. The final layer uses between 2 and 8 *softmax* outputs to yield a state assignment vector $\chi: \mathbb{R}^m \to \Delta^{n}$ where $\Delta^{n} = \{ s \in \mathbb{R}^n \mid 0 \le s_i \le 1, \sum_{i=1}^n s_i = 1 \}$ representing the probability of a state assignment. One lobe thus transforms a system state into a state occupation probability. We can also view this value as a kind of reverse ambiguity, i.e. how sure the network is that the system is part of a certain cluster. These outputs are then used as the input for the VAMP scoring function. We use the new enhanced version with physical constraints[2], particularly the ones for positive entries and reversibility.

[1] Mardt, A., Pasquali, L., Wu, H. & Noé, F. VAMPnets for deep learning of molecular kinetics. Nat Comms 1–11 (2017). doi:10.1038/s41467-017-02388-1

[2] Mardt, A., Pasquali, L., Noé, F. & Wu, H. Deep learning Markov and Koopman models with physical constraints. arXiv:1912.07392 [physics] (2019).

### Data preparation
We use minimum residue distances as input ($\frac{N(N-1)}{2}$ values, where $N$ is the number of residues) and first normalize the data:

In [9]:
# Compute the minimum-distance features once per ensemble and keep them
# on disk; ensembles with an existing file are left untouched
for k in sim_names:
    filename = "intermediate/mindist-all-{0}.npy".format(k)
    if os.path.exists(filename):
        continue

    print("No mindist file for {0} ensemble, calculating from scratch...".format(k))
    # Stack the output of all trajectories into one array
    con = np.vstack(inpcon[k].get_output())
    np.save(filename, con)

In [10]:
# Lookup matrix from a residue pair (i, j), i < j, to the column of that
# pair in the stored feature array.
# `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent.
idx = np.triu_indices(nres, k=1)
mat = np.zeros((nres, nres), dtype=int)
full_flat, full_data = {}, {}
for k in sim_names:
    raw = np.load("intermediate/mindist-all-{0}.npy".format(k))
    mat[idx] = np.arange(raw.shape[1])

    # Keep only the columns of pairs at least 3 residues apart in sequence
    redinds = mat[np.triu_indices_from(mat, k=3)]

    # Standardize every feature (zero mean, unit standard deviation over
    # all frames of the ensemble), then split back into trajectories
    full_flat[k] = ((raw - raw.mean(axis=0)) / raw.std(axis=0))[:, redinds]
    full_data[k] = unflatten(full_flat[k], lengths[k])

### Neural network hyperparameters
To allow for a larger hyperparameter search space, we use the self-normalizing neural network approach by Klambauer *et al.* [3], thus using SELU units, `AlphaDropout` and normalized `LeCun` weight initialization. The other hyperparameters are defined at the beginning of this notebook.

[3] Klambauer, G., Unterthiner, T., Mayr, A. & Hochreiter, S. Self-Normalizing Neural Networks. arXiv.org cs.LG, (2017).

In [11]:
activation = "selu"                 # NN activation function
init = "lecun_normal"               # NN weight initialization
lag = 20                            # Lag time (in frames)
n_epoch = 100                       # Max. number of epochs
n_epoch_s = 10000                   # Max. number of epochs for S optimization
n_batch = 10000                     # Training batch size

# Input dimension. Look the ensemble up by name instead of relying on a
# loop variable leaked by whichever cell happened to run last; the number
# of selected features depends only on `nres`, so any ensemble will do.
n_dims = full_data[sim_names[0]][0].shape[1]

# `nres` and `epsilon` are deliberately not reassigned here: both are
# defined earlier in the notebook, and reassigning `epsilon` would
# silently discard a value injected through papermill.
dt = 0.25                           # Trajectory timestep in ns
steps = 6                           # CK test steps
bs_frames = 1000000                 # Number of frames in the bootstrap sample
n_tries = 1                         # Number of training attempts for each model, we pick the best scoring one

outsizes = np.array([4, 2, 3, 5, 6])              # Numbers of output states to train models for
lags = np.array([1, 2, 5, 10, 20, 50, 100])       # Lag times for the implied timescales

### Bound ensemble
We filter out unbound frames to avoid issues with potential non-equilibrium conditions, as we see essentially no unbinding events.

In [12]:
# Pair every residue 0..nres-1 with the entry at index `nres`
# (presumably the bound molecule - TODO confirm against the topology)
inter_pairs = [(res, nres) for res in range(nres)]

# Minimum-distance sources for the ensembles that contain that entry
intercon = dict.fromkeys(("holo", "control"))
for k in intercon:
    feat = pe.coordinates.featurizer(top[k])
    feat.add_residue_mindist(residue_pairs=inter_pairs)
    intercon[k] = pe.coordinates.source(trajs[k], feat)

HBox(children=(HBox(children=(Label(value='Obtaining file info'),), layout=Layout(max_width='35%', min_width='…



HBox(children=(HBox(children=(Label(value='Obtaining file info'),), layout=Layout(max_width='35%', min_width='…



HBox(children=(HBox(children=(Label(value='Obtaining file info'),), layout=Layout(max_width='35%', min_width='…



In [13]:
# Load the cached inter-molecular minimum distances of every ensemble,
# computing and caching them first if no file exists yet
inter_con = {}
for k in ("holo", "control"):
    filename = "intermediate/mindist-inter-all-{0}.npy".format(k)
    if os.path.exists(filename):
        inter_con[k] = np.load(filename)
    else:
        print("No mindist file for {0} ensemble, calculating from scratch...".format(k))
        inter_con[k] = np.vstack(intercon[k].get_output())
        # Store this ensemble's array only. Passing the whole `inter_con`
        # dictionary would write a pickled object array that the
        # `np.load` call above cannot read back as the distance array.
        np.save(filename, inter_con[k])

In [40]:
# Distance threshold: a holo frame is kept if at least one of its
# inter-molecular minimum distances is below it
cutoff = 0.5  # Compromise given the above plot

# Boolean mask per ensemble with one entry per frame, True = keep
masks = {"holo": (inter_con["holo"] < cutoff).any(axis=1)}
# All frames of the apo and control ensembles are kept
masks.update(
    (k, np.ones(full_flat[k].shape[0], dtype=bool)) for k in ("apo", "control")
)

In [45]:
# Training / validation data (test frames and masked frames set to NaN)
# and test pairs for every ensemble, cached on disk between runs
input_data, input_flat, test_data = {}, {}, {}
for k in ("holo", "apo", "control"):
    filename = "intermediate/input-mask-{0}.npz".format(k)
    if not os.path.exists(filename):
        print("No input file for {0} ensemble, resplitting...".format(k))
        split_data, test_data[k] = test_split(full_data[k], lag=lag, mask=~masks[k])
        # `np.vstack` copies, so the NaNs written below only end up in
        # the flat array, not in the trajectories returned by the split
        input_flat[k] = np.vstack(split_data)
        input_flat[k][~masks[k]] = np.nan
        np.savez(filename, data=input_flat[k], test_t=test_data[k][0], test_ttau=test_data[k][1])
    else:
        print("Reading existing input file for {0} ensemble...".format(k))
        # Close the archive again once the arrays have been read
        with np.load(filename) as raw:
            input_flat[k], test_data[k] = raw["data"], (raw["test_t"], raw["test_ttau"])

    # Always derive the trajectories from the masked flat array, so that a
    # fresh split and a cached one yield the same (masked) training data
    input_data[k] = unflatten(input_flat[k], lengths=lengths[k])

Reading existing input file for holo ensemble...
Reading existing input file for phen ensemble...
Reading existing input file for apo ensemble...
Reading existing input file for control ensemble...


### Run
We run the training several times with different train/test splits to get an error estimate; this is referred to as bootstrap aggregating (*bagging*).

In [19]:
# Train one model per output size for the current `attempt` and store all
# derived quantities in the HDF5 file under <ensemble>/<attempt>/<n>/.
# The file is opened in append mode, so results of other attempts are kept.
with h5py.File("intermediate/data.hdf5", "a") as write:
    # Only the holo ensemble is processed in this cell
    for k in ("holo",):
        
        # Create HDF5 groups
        ens = write.require_group(k)
        att = ens.require_group(str(attempt))
        
        # Generate or read previously generated
        # (the generator state is stored per ensemble and attempt, so a
        # rerun of the same attempt picks up the saved state again)
        index_file = "models/model-idx-{0}-{1}.hdf5".format(k, attempt)
        if os.path.exists(index_file):
            generator = DataGenerator.from_state(input_data[k], index_file)
        else:
            generator = DataGenerator(input_data[k], dt=dt)
            generator.save(index_file)
        
        for n in outsizes:
            print("Training {0} n={1} i={2}...".format(k, n, attempt + 1))
            out = att.require_group(str(n))
            
            # Test pairs together with an all-zero array of shape
            # (n_test, 2 * n), presumably dummy targets - TODO confirm
            tests = test_data[k], np.zeros((test_data[k][0].shape[0], 2 * n))
            koops, scores = [], np.empty(n_tries)
            # Train `n_tries` independent models and score each on the test set
            for i in range(n_tries):
                koop = KoopmanModel(n=n, network_lag=lag, verbose=1, nnargs=dict(
                    width=width, depth=depth, learning_rate=learning_rate,
                    regularization=regularization, dropout=dropout,
                    batchnorm=False, lr_factor=2e-2))
                koop.fit(generator)
                scores[i] = koop.score(tests)
                koops.append(koop)
                
            # Keep the attempt with the highest test score
            koop = koops[scores.argmax()]
            koop.save("models/model-ve-{0}-{1}-{2}.hdf5".format(k, n, attempt))
            print("Estimating Koopman operator...")
            ko = out.require_dataset("k", shape=(n, n), dtype="float32")
            # NOTE(review): the lag of 50 is hard-coded and differs from
            # `lag`, which is passed as `network_lag` - confirm this is intended
            ko[:] = koop.estimate_koopman(lag=50)
            print("Estimating mu...")
            mu = out.require_dataset("mu", shape=(koop.data.n_train,), dtype="float32")
            mu[:] = koop.mu
            print("Estimating implied timescales...")
            its = out.require_dataset("its", shape=(n - 1, len(lags)), dtype="float32")
            its[:] = koop.its(lags)
            print("Performing CK-test...")
            cke = out.require_dataset("cke", shape=(n, n, steps), dtype="float32")
            ckp = out.require_dataset("ckp", shape=(n, n, steps), dtype="float32")
            cke[:], ckp[:] = koop.cktest(steps)
            print("Estimating chi...")
            # Network output for the training frames and for all frames
            bootstrap = out.require_dataset("bootstrap", shape=(koop.data.n_train, 2 * n), dtype="float32")
            bootstrap[:] = koop.transform(koop.data.trains[0])
            full = out.require_dataset("full", shape=(nframes[k], 2 * n), dtype="float32")
            full[:] = koop.transform(generator.data_flat)
            # Drop the models before the next output size to free memory
            del koop, koops
            gc.collect()

Training holo n=3 i=1...


KeyboardInterrupt: 