# Analysis

In this notebook we perform most of the analysis and all the plotting.

In [None]:
from collections import defaultdict
import gc
from glob import glob
from multiprocessing import Pool
import itertools
import os
from typing import List, Tuple, Sequence
import warnings

import holoviews as hv
from holoviews import opts, dim
import h5py
import matplotlib.pyplot as plt
from matplotlib import rc, ticker
from matplotlib.colors import ListedColormap
from msmtools.analysis import stationary_distribution, mfpt
from msmtools.flux import tpt
import mdtraj as md
import numpy as np
import pyemma as pe
from scipy.linalg import eig
from scipy.stats import gaussian_kde
import seaborn as sns

# Plot settings
sns.set_palette("husl", 8)
rc("font", **{"family": "Helvetica",
              "sans-serif": ["Helvetica"]})
# Emit SVG text as text elements (not glyph paths) so labels stay editable
rc("svg", **{"fonttype": "none"})
# Palette indexed explicitly by the plotting code below (same husl palette as the default)
colors = sns.color_palette("husl", 8)
hv.extension("matplotlib")

# NOTE(review): this silences *all* warnings for the whole notebook, including
# numerical ones (e.g. complex-to-real casts, deprecations); consider narrowing it.
warnings.filterwarnings('ignore')

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Utility functions

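Shared helpers are loaded by running `data.py` in the next cell; further utility functions are defined in the cells after it. Note that the version of Keras we're using unfortunately doesn't implement `restore_best_weights`, so I copied it from a newer version.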

In [None]:
# Execute data.py in this namespace; presumably it provides `DataGenerator`
# (used by the contact/secondary-structure cells below) — confirm
%run data.py

In [None]:
def unflatten(source: np.ndarray, lengths: Sequence[Sequence[int]]) -> List[np.ndarray]:
    """
    Split a flat, frame-concatenated array back into per-trajectory arrays.
    
    Parameters
    ----------
    source
        Array to be unflattened; it is split along its first axis.
    lengths
        Nested sequence of trajectory lengths, e.g. one inner sequence per
        simulation round (as produced by `sort_lengths`). The inner lengths
        are consumed in order and should sum to ``len(source)``; this is not
        checked (surplus rows are ignored, missing rows give short or empty
        slices).
    
    Returns
    -------
    unflat
        Flat list of consecutive slices of `source`, one per inner length,
        in the order they appear in `lengths`.
    
    """
    pieces = []
    offset = 0
    for group in lengths:
        for le in group:
            pieces.append(source[offset:offset + le])
            offset += le
    return pieces

In [None]:
def sort_lengths(flatlengths: Sequence[int], shapes: Sequence[int]) -> List[List[int]]:
    """
    Group a flat sequence of lengths into consecutive chunks.
    
    Parameters
    ----------
    flatlengths
        Flat sequence of lengths, e.g. one per trajectory.
    shapes
        Number of entries that go into each chunk, e.g. the number of
        trajectories per simulation round. Raises IndexError if the
        chunks need more entries than `flatlengths` provides.
    
    Returns
    -------
    lengths
        One list per entry of `shapes`, holding the next `count` elements.
    
    """
    grouped = []
    start = 0
    for count in shapes:
        stop = start + count
        grouped.append([flatlengths[k] for k in range(start, stop)])
        start = stop
    return grouped

In [None]:
def triu_inverse(x: np.ndarray, n: int) -> np.ndarray:
    """
    Rebuild full symmetric matrices from their flattened strict upper triangles.
    
    Parameters
    ----------
    x
        Flattened strict upper triangles (``np.triu_indices(n, k=1)`` order),
        shape (length, n * (n - 1) / 2). For a 1-D input, numpy broadcasting
        yields ``len(x)`` identical copies of the single matrix.
    n
        Size of the n * n matrix
    
    Returns
    -------
    mat
        Array of shape (length, n, n), mirrored across the diagonal,
        with a zero diagonal.
    
    """
    rows, cols = np.triu_indices(n, k=1)
    out = np.zeros((x.shape[0], n, n))
    # Fill the upper triangle and its mirror image; the diagonal stays zero
    out[:, rows, cols] = x
    out[:, cols, rows] = x
    return out

In [None]:
def statdist(X: np.ndarray) -> np.ndarray:
    """
    Calculate the equilibrium distribution of a transition matrix.
    
    Parameters
    ----------
    X
        Row-stochastic transition matrix of shape (n, n).
    
    Returns
    -------
    mu
        Real-valued stationary distribution of shape (n,), i.e. the left
        eigenvector associated with the eigenvalue of largest real part
        (1 for a stochastic matrix), normalized to sum to one.
    
    """
    ev, evec = eig(X, left=True, right=False)
    # scipy returns complex arrays whenever any eigenvalue is complex, so pick
    # the eigenvalue by its real part rather than by complex (lexicographic) order.
    k = np.argmax(ev.real)
    # Left eigenvectors are the columns of `evec`; drop the (zero) imaginary
    # part so callers receive a real vector instead of an implicit cast.
    mu = np.real(evec[:, k])
    return mu / mu.sum()

In [None]:
def unique_sorting(rmsd: np.ndarray) -> np.ndarray:
    """
    Greedily match reference states to trial states by lowest RMSD.
    
    Parameters
    ----------
    rmsd
        Array of shape (n, n) with interstate differences; rows index the
        reference states and columns the trial states. This matrix should
        be acquired by calculating the RMSD between a reference
        decomposition and a trial decomposition.
    
    Returns
    -------
    sorter
        Integer permutation of length n; ``sorter[i]`` is the trial state
        assigned to reference state ``i``.
    
    """
    size = rmsd.shape[0]
    
    # -1 marks an unassigned reference state. Use a platform int rather than
    # int8 so state indices of 128 and above cannot overflow.
    sorter = np.full(size, -1, dtype=int)
    taken = np.zeros(size, dtype=bool)
    remaining = size
    
    # Walk through the RMSDs from low to high. A pair is accepted only if
    # neither its reference state nor its trial state is assigned yet, so
    # each reference state gets the lowest still-available match.
    rows, cols = np.unravel_index(rmsd.argsort(axis=None), (size, size))
    for i, j in zip(rows, cols):
        if sorter[i] < 0 and not taken[j]:
            sorter[i] = j
            taken[j] = True
            remaining -= 1
            if remaining == 0:
                break
    return sorter

In [None]:
def idx_to_traj(idx: int, lengths: List[int]) -> Tuple[int, int]:
    """
    Given a trajectory index, find the round and trajectory file number.
    
    Parameters
    ----------
    idx
        Global, zero-based trajectory index across all rounds.
    lengths
        Number of trajectories in each round (per round, not cumulative).
    
    Returns
    -------
    round, number
        Simulation round and zero-based trajectory number within that round.
    
    Raises
    ------
    IndexError
        If `idx` is negative or not smaller than ``sum(lengths)``.
    
    """
    lcs = np.cumsum(lengths)
    if lcs.size == 0 or not 0 <= idx < lcs[-1]:
        raise IndexError("trajectory index {0} out of range".format(idx))
    # The round is the number of cumulative counts <= idx. Using <= means an
    # index lying exactly on a round boundary belongs to the next round.
    i = int(np.count_nonzero(lcs <= idx))
    start = lcs[i - 1] if i > 0 else 0
    return i, int(idx - start)

In [None]:
# msmtools doesn't think our matrices are normalized enough,
# so we normalize until we hit the default tolerance requirement.
def renormalize(mat, tol=1e-12, axis=1, max_iter=1000):
    """
    Rescale a matrix until its sums along `axis` equal one within `tol`.
    
    Parameters
    ----------
    mat
        Matrix to normalize (e.g. a transition matrix).
    tol
        Maximum allowed absolute deviation of any sum from one.
    axis
        Axis to sum over; 1 makes rows sum to one, 0 makes columns sum to one.
    max_iter
        Maximum number of rescaling passes before giving up.
    
    Returns
    -------
    mat
        `mat` itself if it already meets the tolerance; otherwise a new
        array of absolute values normalized along `axis`.
    
    Raises
    ------
    ValueError
        If a row/column of absolute values sums to zero.
    RuntimeError
        If the tolerance is not reached within `max_iter` passes.
    
    """
    mat = np.asarray(mat)
    for _ in range(max_iter):
        if np.abs(1.0 - mat.sum(axis=axis)).max() <= tol:
            return mat
        absmat = np.abs(mat)
        # keepdims so each row (axis=1) or column (axis=0) is divided by its
        # *own* sum; a plain 1-D sum would broadcast along the last axis and
        # divide by the wrong entries.
        sums = absmat.sum(axis=axis, keepdims=True)
        if np.any(sums == 0):
            raise ValueError("cannot renormalize: a row/column sums to zero")
        mat = absmat / sums
    raise RuntimeError("renormalize did not reach tol={0} within {1} iterations".format(tol, max_iter))

In [None]:
def plot_its(its, lags, dt=1.0):
    """
    Plot implied timescales against lag time on a logarithmic y-axis.
    
    Parameters
    ----------
    its
        Implied timescales of shape (nits, nlags), or (nruns, nits, nlags)
        for several runs; in the latter case the mean is drawn together
        with a 2.5-97.5 percentile band.
    lags
        Lag times in frames, array of length nlags.
    dt
        Frame time step; the x-axis shows ``lags * dt``.
    
    Returns
    -------
    fig
        The created figure.
    
    """
    multi = its.ndim == 3
    nits = its.shape[-2]
    tau = lags * dt
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111)
    
    if multi:
        itsm = its.mean(axis=0)
        cfl, cfu = np.percentile(its, q=(2.5, 97.5), axis=0)
    else:
        itsm = its
    
    # Reference line t = tau, with the area below it shaded
    ax.semilogy(tau, tau, color="k")
    ax.fill_between(tau, ax.get_ylim()[0] * np.ones(len(lags)),
                    tau, color="k", alpha=0.2)
    for i in range(nits):
        color = colors[-(i + 2)]
        # One solid line per process; a second, identical dashed line drawn
        # underneath would be completely hidden by it.
        ax.plot(tau, itsm[i], marker="o", linewidth=1.5, color=color)
        if multi:
            ax.fill_between(tau, cfl[i], cfu[i],
                            interpolate=True, color=color, alpha=0.2)
    loc = ticker.LogLocator(base=10.0, subs=(0.2, 0.4, 0.6, 0.8), numticks=12)
    ax.set_ylim(1, 100000)
    ax.set_yticks(10 ** np.arange(6))
    ax.yaxis.set_minor_locator(loc)
    ax.yaxis.set_minor_formatter(ticker.NullFormatter())
    ax.set_xlabel(r"$\tau$ [ns]", fontsize=24)
    ax.set_ylabel(r"$t_i$ [ns]", fontsize=24)
    ax.tick_params(labelsize=24)
    sns.despine(ax=ax)
    return fig

In [None]:
def plot_ck(cke, ckp, lag, dt=0.25):
    """
    Plot a Chapman-Kolmogorov test as an n x n grid of transition panels.
    
    Parameters
    ----------
    cke
        CK "estimate" curves (drawn dashed), shape (n, n, steps), or
        (nruns, n, n, steps) for several runs; multi-run input is shown as
        the mean with a 2.5-97.5 percentile band.
    ckp
        CK "prediction" curves, same shape as `cke`; drawn solid, or with
        percentile error bars for multi-run input.
    lag
        Lag time in frames between consecutive steps.
    dt
        Frame time step in ns, used for the x tick labels. The default
        matches the notebook's ``dt = 0.25``, which this function used to
        read from the global namespace.
    
    Returns
    -------
    fig
        The created figure.
    
    """
    multi = cke.ndim == 4
    n = cke.shape[-2]
    steps = cke.shape[-1]
    
    if multi:
        ckem = cke.mean(axis=0)
        ckpm = ckp.mean(axis=0)
        ckep = np.percentile(cke, q=(2.5, 97.5), axis=0)
        ckpp = np.percentile(ckp, q=(2.5, 97.5), axis=0)
    else:
        ckem = cke
        ckpm = ckp
    
    # Lag-time grid (frames) shared by every panel, plus every other step for labels
    x = np.arange(0, steps * lag, lag)
    major = np.arange(0, steps * lag, 2 * lag)
    
    # squeeze=False keeps `axes` two-dimensional even for a single state
    fig, axes = plt.subplots(n, n, figsize=(4 * n, 4 * n), sharex=True, squeeze=False)
    for i in range(n):
        for j in range(n):
            ax = axes[i, j]
            if multi:
                ax.errorbar(x, ckpm[i, j], yerr=[ckpm[i, j] - ckpp[0, i, j], ckpp[1, i, j] - ckpm[i, j]],
                            linewidth=2, elinewidth=2)
                ax.fill_between(x, ckep[0, i, j], ckep[1, i, j],
                                alpha=0.2, interpolate=True, color=colors[1])
            else:
                ax.plot(x, ckpm[i, j], linestyle="-", color=colors[0], linewidth=2)
            ax.plot(x, ckem[i, j], linestyle="--", color=colors[1], linewidth=2)
            
            # Self-transition panels get a window near 1, the others one near 0
            if i == j:
                ax.set_ylim(0.78, 1.02)
                ax.text(0, 0.8, r"{0} $\to$ {1}".format(i, j), fontsize=24, verticalalignment="center")
            else:
                ax.set_ylim(-0.02, 0.22)
                ax.text(0, 0.2, r"{0} $\to$ {1}".format(i, j), fontsize=24, verticalalignment="center")
            ax.set_xticks(x, minor=True)
            ax.set_xticks(major)
            ax.set_xticklabels((major * dt).astype(int))
            ax.tick_params(labelsize=24)
    fig.text(0.5, 0.01 * 1.5 * n, r"$\tau$ [ns]", ha="center", fontsize=24)
    fig.text(0.01 * 1.5 * n, 0.5, r"$P$", va="center", rotation="vertical", fontsize=24)
    fig.subplots_adjust(wspace=0.25)
    return fig

## Data
### Trajectories
Trajectories were acquired in five rounds of 1024 simulations each, totalling 5119 runs (one simulation failed to run) at 278 K in the $NVT$ ensemble. Postprocessing involved removing water, subsampling to 250 ps timesteps, and making molecules whole.

In [None]:
# Postprocessed (water-free, 250 ps) trajectories of all rounds and their shared topology
trajs = sorted(glob("trajectories/red/r?/traj*.xtc"))
top = "trajectories/red/topol.gro"
# k_B T in kJ/mol: 0.0083145 kJ/(mol K) * 278 K
KBT = 2.311420 # 278 K
# Cumulative number of trajectories after each simulation round
traj_rounds = [1024, 2047, 3071, 4095, 5119]
# Number of residues
nres = 42

# This is only really necessary for the residues in the plots
topo = md.load_topology(top)

We use minimum distances as features for the neural network:

In [None]:
# Minimum inter-residue distances (pyemma's default residue pairs) as network input
feat = pe.coordinates.featurizer(top)
feat.add_residue_mindist()
inpcon = pe.coordinates.source(trajs, feat)

# Switch for full version:
# lengths = sort_lengths(inpcon.trajectory_lengths(), [1024, 1023, 1024, 1024, 1024])
# Nested list of trajectory lengths as expected by unflatten(); here all
# trajectories form a single group
lengths = [inpcon.trajectory_lengths()]
# Total number of frames over all trajectories
nframes = inpcon.trajectory_lengths().sum()

In [None]:
# Dataset summary; 0.00025 converts frames (0.25 ns each) to µs
print("Trajectories: {0}".format(len(trajs)))
print("Frames: {0}".format(nframes))
print("Time: {0:5.3f} µs".format(inpcon.trajectory_lengths().sum() * 0.00025))

### Experimental
Chemical shifts were acquired at 278 K. We have C$\alpha$, C$\beta$, C', H$\alpha$, HN, and N shifts available. The chemical shifts were back-calculated from the trajectories using CamShift [2]; the errors are given by the CamShift test-set errors.

[2]	Kohlhoff, K. J., Robustelli, P., Cavalli, A., Salvatella, X. & Vendruscolo, M. Fast and accurate predictions of protein NMR chemical shifts from interatomic distances. J. Am. Chem. Soc. 131, 13894–13895 (2009).

In [None]:
# Back-calculated chemical shift files; presumably one per trajectory in the
# same sorted order as `trajs` — TODO confirm the two globs pair up
csraw = sorted(glob("trajectories/red/r?cs/cs*.dat"))
# Columns 1..206 hold the 206 back-calculated shifts per frame
# (column 0 presumably time — TODO confirm)
cs = [np.loadtxt(f)[:, 1:207] for f in csraw]
# Columns 207+ of the first row of the first file: presumably the experimental values
csexp = np.loadtxt(csraw[0])[0, 207:]

# Per-nucleus back-calculation errors (ppm) and number of shifts per nucleus
err = {"NH": 3.01, "HN": 0.56, "HA": 0.28, "CA": 1.3, "CB": 1.36, "CO": 1.38}
nerr = {"CA": 37, "CB": 27, "CO": 37, "HA": 31, "HN": 37, "NH": 37}
# Error per shift column, in the order CA, CB, CO, HA, HN, NH (37+27+37+31+37+37 = 206)
experr = np.array([err["CA"]] * 37 + [err["CB"]] * 27 + [err["CO"]] * 37 +
                  [err["HA"]] * 31 + [err["HN"]] * 37 + [err["NH"]] * 37)

## VAMPNet
VAMPNet[1] is composed of two lobes, one reading the system features $\mathbf{x}$ at a timepoint $t$ and the other after some lag time $\tau$. In this case the network reads all minimum inter-residue distances (780 values) and sends them through 5 layers with 256 nodes each. The final layer uses between 2 and 8 *softmax* outputs to yield a state assignment vector $\chi: \mathbb{R}^m \to \Delta^{n}$ where $\Delta^{n} = \{ s \in \mathbb{R}^n \mid 0 \le s_i \le 1, \sum_i^n s_i = 1 \}$ representing the probability of a state assignment. One lobe thus transforms a system state into a state occupation probability. We can also view this value as a kind of reverse ambiguity, i.e. how sure the network is that the system is part of a certain cluster. These outputs are then used as the input for the VAMP scoring function. We use the new enhanced version with physical constraints[2], particularly the ones for positive entries and reversibility.

[1] Mardt, A., Pasquali, L., Wu, H. & Noé, F. VAMPnets for deep learning of molecular kinetics. Nat Comms 1–11 (2017). doi:10.1038/s41467-017-02388-1

[2] Mardt, A., Pasquali, L., Noé, F. & Wu, H. Deep learning Markov and Koopman models with physical constraints. arXiv:1912.07392 [physics] (2019).

### Data preparation
We use minimum residue distances as input ($\frac{N(N-1)}{2}$ values, where $N$ is the number of residues) for the neural network, but remove the 2nd and 3rd off-diagonals:

In [None]:
# Cached network input features; delete the file to force recomputation
filename = "intermediate/mindist-780-red.npy"
if os.path.exists(filename):
    print("Loading existing file for ensemble: {0}".format(filename))
    input_flat = np.load(filename)
else:
    print("No mindist file for ensemble, calculating from scratch...")
    input_flat = np.vstack(inpcon.get_output())
    np.save(filename, input_flat)
# Split the frame-concatenated features back into per-trajectory arrays
input_data = unflatten(input_flat, lengths)

We also use the full minimum inter-residue distances for some analysis:

In [None]:
# All residue pairs (i < j), in the same row-major order as np.triu_indices(nres, k=1),
# so rows of mindist_flat can be unpacked with triu_inverse
allpairs = np.asarray(list(itertools.combinations(range(nres), 2)))
filename = "intermediate/mindist-all-red.npy"
if os.path.exists(filename):
    print("Loading existing file for ensemble: {0}".format(filename))
    mindist_flat = np.load(filename)
else:
    print("No mindist file for ensemble, calculating from scratch...")
    # Note: rebinds the global `feat` only on this (uncached) path
    feat = pe.coordinates.featurizer(top)
    feat.add_residue_mindist(residue_pairs=allpairs)
    inpmindist = pe.coordinates.source(trajs, feat)
    mindist_flat = np.vstack(inpmindist.get_output())
    np.save(filename, mindist_flat)
mindist = unflatten(mindist_flat, lengths)

### Neural network hyperparameters
To allow for a larger hyperparameter search space, we use the self-normalizing neural network approach by Klambauer *et al.* [2], thus using SELU units, `AlphaDropout` and normalized `LeCun` weight initialization. The remaining hyperparameters are defined in the cell below.

[2] Klambauer, G., Unterthiner, T., Mayr, A. & Hochreiter, S. Self-Normalizing Neural Networks. arXiv.org cs.LG, (2017).

In [None]:
# Lag times are in frames (0.25 ns each), e.g. lag = 50 corresponds to 12.5 ns
lag = 50                         # Lag time
n_dims = input_data[0].shape[1]  # Input dimension
nres = 42                        # Number of residues
dt = 0.25                        # Trajectory timestep in ns
steps = 6                        # CK test steps
bs_frames = 900000               # Number of frames in the bootstrap sample
attempts = 20                    # Number of times to run

# Numbers of output states analysed, and lag times (frames) for the ITS plots
outsizes = np.array([2, 3, 4, 5, 6])
lags = np.array([1, 2, 5, 10, 20, 50, 100])

# Comment for full version:
# (these overrides reduce the analysis to the small example data set)
bs_frames = nframes
attempts = 2
outsizes = np.array([4])

# Analysis

## Model validation
We load the outputs of the previously trained neural network models, together with the precomputed implied timescales, Chapman–Kolmogorov tests, and Koopman operators.

### State sorting
As every run of the neural network will generate a different ordering of (mostly the same) classes, we need to reorder them to be internally consistent. We do this by calculating the root-mean-square deviation between all average inter-residue minimum distance matrices for the individual states and matching the lowest values. We also make sure the sorting is unique, i.e. all states are accounted for, even if a certain RMSD is lower.

In [None]:
# You do not need to run this if you have the data.hdf5 file from the repo
def _replace_dataset(group, name, data):
    """Write `data` to `group[name]`, replacing any existing dataset so the cell can be re-run."""
    if name in group:
        del group[name]
    group.create_dataset(name, data=data)


# "r+" opens the existing file for writing. h5py >= 3 opens files read-only
# by default, which would make every dataset creation below fail.
with h5py.File("intermediate/data.hdf5", "r+") as store:
    for n in outsizes:
        # First attempt is the reference dataset
        ref_group = store["red/{0}/{1}".format(0, n)]
        ref = ref_group["full"][:, :n]
        
        # This is the average contact map for every
        # state for the first training attempt
        conwref = (ref / ref.sum(axis=0)).T @ mindist_flat
        # conwrefbt[a, b] == conwref[a]: reference state along the rows
        conwrefbt = np.broadcast_to(conwref, (n, n, mindist_flat.shape[1])).swapaxes(0, 1)
        # The reference keeps its own order. Store an identity sorter so every
        # attempt, including attempt 0, has the "sorter" dataset the loading cells read.
        _replace_dataset(ref_group, "sorter", np.arange(n))
        _replace_dataset(ref_group, "full_sorted", ref)
        _replace_dataset(ref_group, "bootstrap_sorted", ref_group["bootstrap"][:, :n])
        
        for i in range(1, attempts):
            print("Processing n={0} i={1}...".format(n, i), end="\r")
            group = store["red/{0}/{1}".format(i, n)]
            pf = group["full"][:, :n]
            
            # Average contact map for each state for attempt `i`
            conw = (pf / pf.sum(axis=0)).T @ mindist_flat
            # conwbt[a, b] == conw[b]: trial state along the columns
            conwbt = np.broadcast_to(conw, (n, n, mindist_flat.shape[1]))
            
            # Distance to the reference attempt: rmsd[a, b] compares reference
            # state a with trial state b (a Euclidean norm over residue pairs,
            # not divided by their number, which does not change the ranking)
            rmsd = np.sqrt(((conwbt - conwrefbt) ** 2).sum(axis=-1))
            
            # Get a permutation vector and store it
            sorter = unique_sorting(rmsd)
            _replace_dataset(group, "sorter", sorter)
            _replace_dataset(group, "full_sorted", pf[:, sorter])
            _replace_dataset(group, "bootstrap_sorted", group["bootstrap"][:, :n][:, sorter])

In [None]:
# Create example data
# Copies a subset of data.hdf5 (both ensembles, attempts 0 and 1, n = 4) into
# data-mini.hdf5, which the following cells read.
# NOTE(review): the source file is opened without an explicit mode; passing "r"
# would make the read-only intent explicit across h5py versions.
with h5py.File("intermediate/data.hdf5") as store:
    with h5py.File("intermediate/data-mini.hdf5", "w") as write:
        for k in ("red", "ox"):
            ens = write.create_group(k)
            for i in [0, 1]:
                att = ens.create_group(str(i))
                for n in [4]:
                    out = att.create_group(str(n))
                    # Copied in full
                    for key in ("k", "its", "cke", "ckp"):
                        out.create_dataset(key, data=store["{0}/{1}/{2}/{3}".format(k, i, n, key)][:], compression=9)
                    # Cut to the first sum(lengths[0][:6]) rows, i.e. the frame count of
                    # the first six trajectories (a no-op for the n-element "sorter")
                    for key in ("mu", "bootstrap", "full", "bootstrap_sorted", "full_sorted", "sorter"):
                        out.create_dataset(key, data=store["{0}/{1}/{2}/{3}".format(k, i, n, key)][:sum(lengths[0][:6])], compression=9)

In [None]:
# Per number of states n, one entry per attempt: the state sorter, the sorted
# state assignments (full data and bootstrap sample), their column-normalized
# versions, the sorted Koopman operators and their stationary distributions
sorters = {n: np.empty((attempts, n), dtype=int) for n in outsizes}
pfs = {n: np.empty((attempts, nframes, n)) for n in outsizes}
pfsn = {n: np.empty((attempts, nframes, n)) for n in outsizes}
pfs_boot = {n: np.empty((attempts, bs_frames, n)) for n in outsizes}
pfsn_boot = {n: np.empty((attempts, bs_frames, n)) for n in outsizes}
koops = {n: np.empty((attempts, n, n)) for n in outsizes}
pis = {n: np.empty((attempts, n)) for n in outsizes}
# NOTE(review): no explicit mode; consider passing "r"
with h5py.File("intermediate/data-mini.hdf5") as read:
    store = read["red"]
    for i in range(attempts):
        for n in outsizes:
            sorters[n][i] = store["{0}/{1}/sorter".format(i, n)][:]
            pfs[n][i] = store["{0}/{1}/full_sorted".format(i, n)][:]
            # Each state column normalized to sum to one over frames
            pfsn[n][i] = pfs[n][i] / pfs[n][i].sum(axis=0)
            pfs_boot[n][i] = store["{0}/{1}/bootstrap_sorted".format(i, n)][:]
            pfsn_boot[n][i] = pfs_boot[n][i] / pfs_boot[n][i].sum(axis=0)
            # Permute rows and columns of the stored operator with the same sorter
            koops[n][i] = store["{0}/{1}/k".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
            pis[n][i] = statdist(koops[n][i])

We calculate an equilibrium weighting vector, analogous to discrete Markov models:

\begin{equation}
w_t = \frac{\langle \chi(\mathbf{x}_t) | \boldsymbol{\pi} \rangle}{\sum_{t'=1}^{N} \langle \chi(\mathbf{x}_{t'}) | \boldsymbol{\pi} \rangle}
\end{equation}

In [None]:
# Frame weights from the equation above: state memberships projected onto the
# stationary distribution, normalized over all frames (one row per attempt)
weights = {}
for n in outsizes:
    weights[n] = np.empty((attempts, nframes))
    for i in range(attempts):
        w = pfs[n][i] @ pis[n][i]
        weights[n][i] = w / w.sum()

#### Global sorting
Because state assignments will not only be different within a choice of number of states, but also globally, we sort the states based on the previous (coarser) state decomposition.

In [None]:
# Overlap between the attempt-averaged assignments of consecutive decompositions:
# prods[n][u, v] sums, over frames, membership in coarse state u (n - 1 states)
# times membership in fine state v (n states).
prods = {}
for n in outsizes[1:]:
    prods[n] = np.einsum("jk,jl->kl", pfs[n - 1].mean(axis=0), pfs[n].mean(axis=0))

# The coarsest decomposition keeps its own order and anchors all finer ones.
# Seeding from outsizes[0] (rather than a hard-coded 2) also covers runs that
# start above two states or analyse a single decomposition.
n0 = outsizes[0]
global_sorter = {n0: np.arange(n0, dtype=int)}
for n in outsizes[1:]:
    # NOTE: assumes consecutive sizes, i.e. n - 1 is also in outsizes
    global_sorter[n] = np.full(n, -1, dtype=int)
    mass = prods[n][global_sorter[n - 1]]
    # Position i follows coarse state i: take its unused fine state with the largest overlap
    for i in range(n - 1):
        sort = mass[i].argsort()[::-1]
        for j in range(n):
            if sort[j] not in global_sorter[n]:
                global_sorter[n][i] = sort[j]
                break
    # The one remaining fine state goes last
    global_sorter[n][-1] = [i for i in range(n) if i not in global_sorter[n][:-1]][0]

### Implied timescales
The implied timescales are computed from the eigenvalues $\lambda_i$ of the Koopman operator $\mathbf{K}$ and the selected lag time $\tau$:
$$ t_i = \frac{-\tau}{\log | \lambda_i(\tau) |} $$
We compute the implied timescales for several lag times $\tau$. Ideally, we want to choose $\tau$ so that we're in a regime where the timescales are mostly independent of $\tau$. This is the case for lag times longer than $\tau = 12.5\,\mathrm{ns}$:

In [None]:
with h5py.File("intermediate/data-mini.hdf5", "r") as store:
    for n in outsizes:
        # (attempts, nits, nlags) implied timescales. np.stack needs a list:
        # recent NumPy versions raise TypeError for generators.
        its = np.stack([store["red/{0}/{1}/its".format(i, n)][:] for i in range(attempts)])
        fig = plot_its(its, lags, dt=dt)
        fig.savefig("figs/its-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
        fig.savefig("figs/its-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Chapman-Kolmogorov Test
Now that we have found a Koopman operator $\mathbf{K}$, how do we know if it adequately describes the dynamical system?

There are a number of tests we can perform to ensure we have found a good approximation. One of the more stringent ones is the Chapman–Kolmogorov test (CK test). It is based on the assumption that applying the operator estimated at lag time $\tau$ to our system state $n$ times should be equivalent to estimating the operator directly at the longer lag time $n\tau$:

$$ \mathbf{K}(\tau)^n \approx \mathbf{K}(n\tau) $$

I.e., multiple applications of the operator should produce the same result as using an operator with a multiple of a certain lag time, within an error. This is a more stringent test of model quality than the implied timescales test described above.

In [None]:
with h5py.File("intermediate/data-mini.hdf5", "r") as store:
    for n in outsizes:
        # Per attempt, reorder both state axes with that attempt's sorter, then
        # stack (np.stack needs a list: recent NumPy versions reject generators)
        # and apply the global sorter to both state axes.
        cke = np.stack([store["red/{0}/{1}/cke".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
                        for i in range(attempts)])[:, global_sorter[n]][:, :, global_sorter[n]]
        ckp = np.stack([store["red/{0}/{1}/ckp".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
                        for i in range(attempts)])[:, global_sorter[n]][:, :, global_sorter[n]]
        fig = plot_ck(cke, ckp, lag=lag)
        fig.savefig("figs/ck-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
        fig.savefig("figs/ck-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Experimental backcalculation
We calculate the root-mean-square deviation of backcalculated values (using CamShift) to experimental observables for the whole ensemble:

In [None]:
# Per nucleus: ((start, stop) column slice into the 206 shift columns,
#               (lo, hi) range — presumably ppm plot limits; unused in the cells shown)
labels = {
    "CA": ((0, 37), (30, 70)),
    "CB": ((37, 64), (10, 70)),
    "CO": ((64, 101), (165, 185)),
    "HA": ((101, 132), (0, 6)),
    "HN": ((132, 169), (0, 12)),
    "NH": ((169, 206), (90, 140))
}

In [None]:
n = 4
# Weighted ensemble-average shifts, one row per attempt: (attempts, 206)
csm = np.einsum("ij,jk->ik", weights[n], np.vstack(cs))
sqdev = (csm - csexp) ** 2
# RMSD to experiment per nucleus and attempt
rmsd = {k: np.sqrt(sqdev[:, v[0][0]:v[0][1]].mean(axis=1)) for k, v in labels.items()}
# (nucleus, mean RMSD over attempts, 2.5/97.5 percentiles over attempts), sorted by nucleus name
rmsderr = [(k, rmsd[k].mean(axis=0), np.percentile(rmsd[k], (2.5, 97.5))) for k in sorted(labels.keys())]

In [None]:
# Camshift errors and simulation RMSDs are paired by position, so both must be
# in the same (alphabetical) nucleus order
assert [k for k, _ in sorted(err.items())] == [d[0] for d in rmsderr]
rmsd_mean = np.array([d[1] for d in rmsderr])
rmsd_ci = np.array([d[2] for d in rmsderr])  # (6, 2): 2.5 and 97.5 percentiles
# Asymmetric error bars from the mean to the percentile bounds across attempts
# (clipped at zero, as matplotlib rejects negative error lengths)
rmsd_yerr = np.clip(np.vstack((rmsd_mean - rmsd_ci[:, 0], rmsd_ci[:, 1] - rmsd_mean)), 0, None)

fig = plt.figure(figsize=(5, 4))
ax = fig.add_subplot(111)
ax.bar(np.arange(6), [d[1] for d in sorted(tuple(err.items()))], capsize=8,
       color=colors[5], alpha=0.5, label="Camshift Error")
ax.bar(np.arange(6), rmsd_mean, yerr=rmsd_yerr,
       capsize=8, color=colors[5], label="Simulation Error")
ax.set_xticks(np.arange(6))
ax.set_xticklabels([d[0] for d in rmsderr])
ax.set_ylabel("RMSD [ppm]", fontsize=24)
ax.set_ylim(0, 3.5)
ax.tick_params(axis="x", length=0, pad=10, labelsize=24)
ax.tick_params(axis="y", labelsize=24)
sns.despine(ax=ax)
ax.legend(fontsize=16)
plt.savefig("figs/rmsd-tot.pdf", bbox_inches="tight", transparent=True)
plt.savefig("figs/rmsd-tot.svg", bbox_inches="tight", transparent=True)

### Convergence
We would ideally like to see how converged our ensemble is with respect to the timescales and stationary distribution given by our model. We thus build trial models with different numbers of trajectories:

In [None]:
n = 4
# Koopman operators of models trained on growing subsets of trajectories,
# indexed as k_conv[round, attempt] by the next cell
k_conv = np.load("intermediate/k-conv-red-{0}.npy".format(n))

#### Timescales
We can examine how the timescales change when estimating with more trajectories:

In [None]:
# Implied timescales t_i = -tau / ln|lambda_i| for every trajectory round and attempt
its_conv = np.empty((len(traj_rounds), attempts, n - 1))
for j in range(len(traj_rounds)):
    for i in range(attempts):
        # Eigenvalues only: with left=False and right=False, eig returns just `w`
        ev = eig(k_conv[j, i], left=False, right=False)
        order = np.abs(ev).argsort()[::-1]
        # Drop the stationary eigenvalue; `lag` is the model lag time in frames
        # and |lambda| matches the implied-timescale formula stated earlier
        its_conv[j, i] = -lag * dt / np.log(np.abs(ev[order][1:]))
its_conv_m = its_conv.mean(axis=1)
its_conv_p = np.percentile(its_conv, (2.5, 97.5), axis=1)

In [None]:
# Mean implied timescales vs number of trajectories, with 2.5-97.5 percentile bands
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111)
x = np.array(traj_rounds)
for i in range(n - 1):
    ax.semilogy(x, its_conv_m[:, i], color=colors[i + 4], linewidth=2, marker="o")
    ax.fill_between(x, its_conv_p[0, :, i], its_conv_p[1, :, i],
                    interpolate=True, color=colors[i + 4], alpha=0.2)

# Same logarithmic y-axis layout as plot_its
loc = ticker.LogLocator(base=10.0, subs=(0.2, 0.4, 0.6, 0.8), numticks=12)
ax.set_ylim(1, 100000)
ax.set_yticks(10 ** np.arange(6))
ax.yaxis.set_minor_locator(loc)
ax.yaxis.set_minor_formatter(ticker.NullFormatter())
ax.set_xticks(x)
ax.tick_params(labelsize=16)
ax.set_xlabel("# Trajectories", fontsize=24)
# NOTE(review): unlike plot_its ("$t_i$ [ns]"), this label has no units
ax.set_ylabel("$t$", fontsize=24)
sns.despine(ax=ax)
plt.savefig("figs/its-conv-coarse-{0}.pdf".format(n), transparent=True, bbox_inches="tight")
plt.savefig("figs/its-conv-coarse-{0}.svg".format(n), transparent=True, bbox_inches="tight")

### Metainference ensemble
#### Radius of gyration
We calculate the radius of gyration $R_\mathrm{gyr}$ of a previously performed metainference metadynamics simulation [3] to compare to our kinetic ensemble:

[3] Heller, G. T. et al. Small molecule sequestration of amyloid-β as a drug discovery strategy for Alzheimer’s disease. bioRxiv 729392 (2019) doi:10.1101/729392.

In [None]:
# Radius of gyration per frame of the metainference ensemble
# NOTE(review): `traj_mi` is not defined in the cells shown — presumably it comes
# from data.py; also confirm that `top` matches those trajectories
feat = pe.coordinates.featurizer(top)
feat.add_custom_func(lambda t: md.compute_rg(t).astype(np.float32).reshape(-1, 1), 1)
inp = pe.coordinates.source(traj_mi, feat)
gyrmi = np.vstack(inp.get_output())

In [None]:
# Weighted R_gyr histogram of the metainference ensemble
# NOTE(review): `nbins`, `xmin`, `xmax` and `weights_mi` are not assigned in any
# earlier cell shown here (xmin/xmax are assigned later, in the tICA KDE cell, as
# tIC bounds), so this cell depends on data.py or on execution order — confirm
gyrs_mi, _ = np.histogram(
    gyrmi.flatten(), bins=nbins,
    range=(xmin, xmax), weights=weights_mi)

### State decomposition
By looking at how states split when choosing finer state decompositions, we can indirectly get a feeling for the spectral gaps, and also see the consistency of the neural network.

In [None]:
# Overlaps between consecutive decompositions, with both state axes in global order
prods = {}
for n in outsizes[1:]:
    prods[n] = np.einsum("jk,jl->kl", pfs[n - 1].mean(axis=0)[:, global_sorter[n - 1]],
                         pfs[n].mean(axis=0)[:, global_sorter[n]])

# This is just formatting for holoviews
# Flatten each overlap matrix to {(coarse u, fine v): overlap}
prods_format = {}
for n in outsizes[1:]:
    prods_format[n] = {}
    for u in range(n - 1):
        for v in range(n):
            prods_format[n][(u, v)] = prods[n][u, v]
            
# Encode node ids as the integer "<size><state>", e.g. state 1 of the 3-state
# decomposition -> 31, so every decomposition gets distinct Sankey nodes
# (only unambiguous while sizes and state indices are single digits)
prods_format = {n: {(int("{0}{1}".format(n - 1, u)),
                     int("{0}{1}".format(n, v))): val
                    for (u, v), val in prods_format[n].items()}
                for n in outsizes[1:]}

# Sorted (source, target, weight) edge list for hv.Sankey
sankeydata = sorted(list(itertools.chain(
    *[[(u, v, val) for (u, v), val in prods_format[n].items()]
      for n in outsizes[1:]])))

In [None]:
# Colours: the first n husl colours for each decomposition, concatenated
cmap = ListedColormap(list(itertools.chain(*[sns.color_palette("husl", 8)[:n] for n in outsizes])))
sank = hv.Sankey(sankeydata)
sank = sank.options(node_width=30, cmap=cmap, edge_alpha=0.35, node_size=0.1, edge_linewidth=0, node_linewidth=0)
fig = hv.render(sank)

ax = fig.axes[0]
# Node labels are the encoded ids (e.g. "31"); keep only the second character,
# the state index within its decomposition
for t in ax.texts:
    text = t.get_text()
    t.set_text(text[1])
    t.set_fontsize(16)
fig.set_dpi(600)
# NOTE(review): the file name hard-codes "6" regardless of outsizes
fig.savefig("figs/sankey-ab-6.pdf")
fig

## Structure
### TICA
Time-lagged independent component analysis is a special case of Koopman operator estimation using a linear projection [4]. We solve the following generalized eigenvalue problem:

$$ \mathbf{C}_{01}v = \lambda \mathbf{C}_{00} v $$

The eigenvectors encode the slowest dynamics of the system, and we use them as a convenient visualization technique.

[4]	Pérez-Hernández, G., Paul, F., Giorgino, T., De Fabritiis, G. & Noé, F. Identification of slow molecular order parameters for Markov model construction. The Journal of Chemical Physics 139, 015102–14 (2013).

In [None]:
ticacon = pe.coordinates.tica(mindist, lag=4, dim=-1, kinetic_map=True)
ticscon = ticacon.get_output()
ycon = np.vstack(ticscon)

print("tIC Dimensions: {0}".format(ycon.shape[1]))
# cumvar[k] is the cumulative variance of the first k + 1 tICs, so the number of
# tICs needed to reach 90 % is the count of entries still below 0.9, plus one.
print("Required dimensions for 90 %: {0}".format((ticacon.cumvar < 0.9).sum() + 1))

#### Free energy surface
We also show the free energy surface projected onto the two slowest tICs in the form of a kernel density estimate:

In [None]:
# 2D Gaussian KDE on the first two tICs, fitted on every 10th frame
kernel = gaussian_kde(ycon[::10, :2].T)
# Note: these rebind the globals xmin/xmax (also read by the R_gyr histogram cell)
xmin, ymin, *_ = ycon.min(axis=0)
xmax, ymax, *_ = ycon.max(axis=0)
# 100 x 100 evaluation grid spanning the data range
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
posi = np.vstack((X.ravel(), Y.ravel()))
Z = kernel(posi).reshape(X.shape)
# Rotated density with values below 0.01 masked as NaN (not used in the cells shown)
mat = np.rot90(Z.copy())
mat[mat < 0.01] = np.nan

In [None]:
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
cmap = "plasma"
# Free energy F = -k_B T ln p (units of KBT), shifted so its minimum is zero
F = -KBT * np.log(Z)
F -= F.min()
# Filled and line contours at levels 0, 1, ..., 9
ax.contourf(X, Y, F, np.arange(0.0, 10, 1), cmap=cmap)
# NOTE(review): matplotlib's contour keyword is `linewidths`; depending on the
# version, `linewidth` may be ignored — confirm the intended line width
ax.contour(X, Y, F, np.arange(0.0, 10, 1), cmap=cmap, linewidth=10)
ax.tick_params(labelsize=24)
ax.set_xlabel(r"tIC 1", fontsize=24, labelpad=10)
ax.set_ylabel(r"tIC 2", fontsize=24, labelpad=10)
sns.despine(ax=ax)

### Contact maps
We calculate contact maps from inter-residue minimum distance matrices, using the bootstrapped state assignments:

In [None]:
cutoff = 0.8
inds = unflatten(np.arange(nframes).reshape(-1, 1), lengths)
contacts = {n: np.empty((attempts, n, nres, nres)) for n in outsizes}
for i in range(attempts):
    generator = DataGenerator.from_state(inds, "models/model-idx-red-{0}.hdf5".format(i))
    # Frame indices of this attempt's first training split (via data.py's
    # DataGenerator; presumably the bootstrap sample — confirm)
    idx = generator(n=2, lag=20).trains[0].flatten().astype(int)
    # Per-frame contact indicator for every residue pair; independent of n
    incontact = mindist_flat[idx] < cutoff
    for n in outsizes:
        print("Processing n={0} i={1}...".format(n, i), end="\r")
        # Membership-weighted mean contact probability per state: (n, npairs)
        con = (np.einsum("jk,jl->kl", pfs[n][i, idx], incontact) /
               pfs[n][i, idx].sum(axis=0).reshape(-1, 1))
        # Unpack all n states at once into symmetric (n, nres, nres) maps; calling
        # triu_inverse on a single 1-D row would build one copy per residue pair
        contacts[n][i] = triu_inverse(con, nres)

We choose a more stringent cutoff of 0.45 nm to compare to Meng's ensemble:

In [None]:
cutoff = 0.45
n = 4
inds = unflatten(np.arange(nframes).reshape(-1, 1), lengths)
enscon = np.empty((attempts, nres, nres))
for i in range(attempts):
    generator = DataGenerator.from_state(inds, "models/model-idx-red-{0}.hdf5".format(i))
    idx = generator(n=2, lag=20).trains[0].flatten().astype(int)
    print("Processing i={0}...".format(i), end="\r")
    # Equilibrium-weighted contact probability per residue pair
    con = np.einsum("j,jk->k", weights[n][i, idx] / weights[n][i, idx].sum(), (mindist_flat[idx] < cutoff))
    # Add a leading axis so triu_inverse builds a single matrix rather than one
    # identical copy per residue pair
    enscon[i] = triu_inverse(con[np.newaxis], nres)[0]

In [None]:
n = 4
fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
# Same colormap for both panels
cmap = sns.light_palette(colors[4], as_cmap=True)
# Left: mean contact probability over attempts; right: its standard deviation.
# Raw strings avoid the invalid "\s" escape in the LaTeX title.
for ax, dat, title in zip(axes.flatten(),
                          [enscon.mean(axis=0), enscon.std(axis=0)],
                          [r"$P$", r"$\sigma(P)$"]):
    im = ax.matshow(dat, vmin=0.0, vmax=1.0, interpolation="nearest", cmap=cmap)
    ax.set_title(title, fontsize=24)
    ax.xaxis.tick_bottom()
    ax.set_xticks(np.arange(0, 50, 10))
    ax.set_xticks(np.arange(0, 42, 1), minor=True)
    ax.set_yticks(np.arange(0, 50, 10))
    ax.set_yticks(np.arange(0, 42, 1), minor=True)
    ax.set_xticklabels(np.arange(0, 50, 10))
    ax.set_xlim(0, 41)
    ax.set_ylim(41, 0)
    ax.tick_params(labelsize=24)

In [None]:
# Residue indices 0 and 41 (ASP1 and ALA42); `cutoff` is the value set in the ensemble-contacts cell
print("ASP1 - ALA42 contact ({0:4.2f} nm) probability: {1:4.4f} +/- {2:4.4f}".format(
    cutoff, enscon.mean(axis=0)[0, 41], enscon.std(axis=0)[0, 41]))

### Secondary structure
Secondary structure assignment with DSSP:

In [None]:
# Custom func makes this a lot easier
# Encode each full (non-simplified) DSSP letter as a digit:
# H=0, B=1, E=2, G=3, I=4, T=5, S=6, ' '=7
dsspcodes = "HBEGITS "
dssptable = str.maketrans(dsspcodes, "01234567")
def dssp_enc(traj):
    # NOTE(review): atom_slice(range(627)) presumably keeps only the protein atoms;
    # any DSSP letter outside `dsspcodes` (e.g. "NA") would break the float cast
    raw = md.compute_dssp(traj.atom_slice(range(627)), simplified=False)
    return np.char.translate(raw, table=dssptable).astype(np.float32)

feat = pe.coordinates.featurizer(top)
feat.add_custom_func(dssp_enc, dim=42)
inp = pe.coordinates.source(trajs, feat)
dssp = np.vstack(inp.get_output()).astype(np.int32)

# One-hot encoding over the full code alphabet. The size is fixed (not
# dssp.max() + 1) so that index 7 used below exists even if coil never occurs.
nvals = len(dsspcodes)
dsspoh = np.eye(nvals, dtype=np.int32)[dssp]

# We could use the simplified DSSP scheme, but this gives us a bit more flexibility
# 0: helix (H, G, I), 1: strand/bridge (B, E), 2: turn/bend (T, S), 3: coil (' ')
dssplow = np.empty((nframes, nres, 4))
dssplow[:, :, 0] = dsspoh[:, :, [0, 3, 4]].sum(axis=-1)
dssplow[:, :, 1] = dsspoh[:, :, [1, 2]].sum(axis=-1)
dssplow[:, :, 2] = dsspoh[:, :, [5, 6]].sum(axis=-1)
dssplow[:, :, 3] = dsspoh[:, :, 7]

In [None]:
# State-averaged coarse secondary structure per residue, for every attempt:
# sec[n][i, state, residue, class] = membership-weighted mean of `dssplow`.
inds = unflatten(np.arange(nframes).reshape(-1, 1), lengths)
sec = {n: np.empty((attempts, n, nres, 4)) for n in outsizes}
for i in range(attempts):
    generator = DataGenerator.from_state(inds, "models/model-idx-red-{0}.hdf5".format(i))
    idx = generator(n=2, lag=20).trains[0].flatten().astype(int)
    dssp_sample = dssplow[idx]
    for n in outsizes:
        print("Processing n={0} i={1}...".format(n, i), end="\r")
        membership = pfs[n][i, idx]
        weighted = np.einsum("ij,ikl->jkl", membership, dssp_sample)
        sec[n][i] = weighted / membership.sum(axis=0).reshape(-1, 1, 1)

# Mean and 95 % percentile band over attempts, in the global state order
secm = {n: sec[n].mean(axis=0)[global_sorter[n]] for n in outsizes}
secs = {n: np.percentile(sec[n], q=(2.5, 97.5), axis=0)[:, global_sorter[n]] for n in outsizes}

In [None]:
# One figure per model size: for every state (column) show the per-residue helix and
# strand probabilities (rows 0-1), the state-averaged contact map (row 2) and a colorbar (row 3).
for n in outsizes:
    fig, axes = plt.subplots(4, n, figsize=(4 * n, 7.5), sharex=True,
                             gridspec_kw={"height_ratios": [5, 5, 15, 1], "width_ratios": [3] * n})
    # Iterate over columns, i.e. over the globally sorted states
    for i, axs in enumerate(axes.T):
        # Rows 0 and 1: helix (j=0) and strand (j=1) probability per residue,
        # with the 2.5 / 97.5 percentiles over attempts as asymmetric error bars
        for j, ax in enumerate(axs[:2]):
            ax.bar(np.arange(nres), secm[n][i, :, j], yerr=[secm[n][i, :, j] - secs[n][0, i, :, j],
                                                            secs[n][1, i, :, j] - secm[n][i, :, j]],
                   capsize=2, color=colors[i])
            sns.despine(ax=ax)
            ax.set_ylim(0, 1)
            # Only the first column carries y labels
            if i == 0:
                ax.set_ylabel([r"$P(\alpha)$", r"$P(\beta)$"][j], fontsize=24)
            else:
                ax.set_yticklabels([])
            ax.set_yticks([0, 0.5, 1.0])
            ax.set_xticks(np.arange(0, 50, 10))
            ax.set_xticks(np.arange(0, 42, 1), minor=True)
            ax.set_xticklabels([])
            ax.tick_params(labelsize=24)
            ax.tick_params(axis="x", length=0, labelsize=12)

        # Row 2: contact map of this state, averaged over attempts
        ax = axs[2]
        dat = contacts[n].mean(axis=0)[global_sorter[n]][i]
        cmap = sns.light_palette(colors[i], as_cmap=True)
        im = ax.matshow(dat, vmin=0.0, vmax=1.0, interpolation="nearest", cmap=cmap)
        ax.xaxis.tick_bottom()
        ax.set_xticks(np.arange(0, 50, 10))
        ax.set_xticks(np.arange(0, 42, 1), minor=True)
        ax.set_yticks(np.arange(0, 50, 10))
        ax.set_yticks(np.arange(0, 42, 1), minor=True)
        ax.set_xticklabels(np.arange(0, 50, 10))
        if i != 0:
            ax.set_yticklabels([])
        ax.set_xlim(0, 41)
        ax.set_ylim(41, 0)
        ax.tick_params(labelsize=24)

        # Row 3: horizontal colorbar for the contact map. It must not share the
        # residue x-range created by sharex=True, so its sharing group is dissolved.
        # NOTE(review): get_siblings() includes cax itself and all other panels, so
        # this removes every axis from the grouper. It relies on Grouper.remove, which
        # newer Matplotlib no longer exposes here (read-only view) — confirm version.
        cax = axs[3]
        shared = cax.get_shared_x_axes()
        for a in shared.get_siblings(cax):
            shared.remove(a)
        # fraction / pad have no effect when an explicit cax is passed
        fig.colorbar(im, cax=cax, fraction=0.046, pad=0.04, orientation="horizontal")
        cax.xaxis.set_ticks(np.arange(0, 1.5, 0.5))
        cax.set_xticklabels(np.arange(0, 1.5, 0.5))
        cax.tick_params(labelsize=16)
        # NOTE(review): the tick settings below repeat the ones already applied to the
        # contact map (ax is still axs[2]) and the colorbar. They may be deliberate —
        # shared-axis ticker objects can be clobbered by the colorbar call — so verify
        # the rendered figure before removing them.
        ax.set_xticks(np.arange(0, 50, 10))
        ax.set_xticks(np.arange(0, 42, 1), minor=True)
        ax.set_yticks(np.arange(0, 50, 10))
        ax.set_yticks(np.arange(0, 42, 1), minor=True)
        ax.set_xticklabels(np.arange(0, 50, 10))

        cax.xaxis.set_ticks(np.arange(0, 1.5, 0.5))
        cax.set_xticklabels(np.arange(0, 1.5, 0.5))
        cax.tick_params(labelsize=16)
    plt.savefig("figs/structure-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/structure-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Radius of gyration
We can calculate the radius of gyration $R_\mathrm{gyr}$ for the whole ensemble:

In [None]:
def _radius_of_gyration(traj):
    """Per-frame radius of gyration of `traj` as a float32 column vector (n_frames, 1)."""
    return md.compute_rg(traj).astype(np.float32).reshape(-1, 1)


feat = pe.coordinates.featurizer(top)
feat.add_custom_func(_radius_of_gyration, 1)
inp = pe.coordinates.source(trajs, feat)
gyr = np.vstack(inp.get_output())

In [None]:
# Weighted R_gyr histograms (normalized weights, so bins sum to 1) for every
# attempt and model size, on a common binning over the full data range.
nbins = 100
xmin, xmax = gyr.min(), gyr.max()
inds = unflatten(np.arange(nframes).reshape(-1, 1), lengths)
gyr_flat = gyr.flatten()
gyrs = {n: np.empty((attempts, nbins)) for n in outsizes}
for i in range(attempts):
    generator = DataGenerator.from_state(inds, "models/model-idx-red-{0}.hdf5".format(i))
    idx = generator(n=2, lag=20).trains[0].flatten().astype(int)
    rg_sample = gyr_flat[idx]
    for n in outsizes:
        print("Processing n={0} i={1}...".format(n, i), end="\r")
        frame_w = weights[n][i, idx]
        gyrs[n][i], edges = np.histogram(
            rg_sample, bins=nbins,
            range=(xmin, xmax), weights=frame_w / frame_w.sum())

# Mean histogram and 95 % band over attempts. Note that `gyrs` is rebound
# here from the per-attempt histograms to the percentile band.
gyrm = {n: hists.mean(axis=0) for n, hists in gyrs.items()}
gyrs = {n: np.percentile(hists, q=(2.5, 97.5), axis=0) for n, hists in gyrs.items()}

In [None]:
n = 4
fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot(111)
# Bin centres of the shared histogram binning
x = 0.5 * (edges[:-1] + edges[1:])

# Our model (MSM) versus the MI reference distribution, drawn as line + shaded area
for series, label, color in ((gyrm[n], "MSM", colors[4]), (gyrs_mi, "MI", colors[6])):
    ax.plot(x, series, linewidth=2, label=label, color=color)
    ax.fill_between(x, series, alpha=0.3, color=color)

ax.set_xlim(0.9, 2.1)
ax.set_xlabel(r"$R_\mathrm{gyr}$ [nm]", fontsize=24)
ax.set_ylabel(r"Density", fontsize=24)
ax.tick_params(labelsize=24)
ax.legend(fontsize=16)
sns.despine(ax=ax)
plt.savefig("figs/gyr.pdf", bbox_inches="tight", transparent=True)
plt.savefig("figs/gyr.svg", bbox_inches="tight", transparent=True)

In [None]:
# Membership-weighted mean R_gyr of every state, for every attempt
inds = unflatten(np.arange(nframes).reshape(-1, 1), lengths)
gyr_flat = gyr.flatten()
gyrs_av = {n: np.empty((attempts, n)) for n in outsizes}
for i in range(attempts):
    generator = DataGenerator.from_state(inds, "models/model-idx-red-{0}.hdf5".format(i))
    idx = generator(n=2, lag=20).trains[0].flatten().astype(int)
    rg_sample = gyr_flat[idx]
    for n in outsizes:
        print("Processing n={0} i={1}...".format(n, i), end="\r")
        membership = pfs[n][i, idx]
        gyrs_av[n][i] = (membership.T @ rg_sample) / membership.sum(axis=0)

gyr_avm = {n: per_state.mean(axis=0) for n, per_state in gyrs_av.items()}
gyr_avs = {n: np.percentile(per_state, q=(2.5, 97.5), axis=0) for n, per_state in gyrs_av.items()}

In [None]:
for n in outsizes:
    order = global_sorter[n]
    rg_mean = gyr_avm[n][order]
    rg_lo = gyr_avs[n][0][order]
    rg_hi = gyr_avs[n][1][order]
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(n), rg_mean, yerr=[rg_mean - rg_lo, rg_hi - rg_mean],
           color=colors, capsize=8)
    # Print the mean value just above each upper error bar
    for i in range(n):
        ax.text(i, rg_hi[i] + 0.02, "{:.2f}".format(rg_mean[i]),
                fontsize=20, ha="center", va="center")
    ax.set_ylim(1, 1.3)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel(r"$R_\mathrm{gyr}$ [nm]", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/gyr-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/gyr-{0}.svg".format(n), bbox_inches="tight", transparent=True)

## Kinetics
### Koopman operators
We can look at the Koopman operators and their errors directly:

In [None]:
n = 4
order = global_sorter[n]
# Mean and standard deviation over attempts of the Koopman matrix, in global state order
koop_panels = (
    (koops[n].mean(axis=0)[order][:, order], "$P$"),
    (koops[n].std(axis=0)[order][:, order], r"$\sigma(P)$"),
)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, (mat, title) in zip(axes, koop_panels):
    ax.matshow(mat, vmin=0.0, vmax=0.02, interpolation="none", cmap="GnBu")
    for i, j in itertools.product(range(n), repeat=2):
        ax.text(j, i, "{0:2.4f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)

### Mean first passage times
We calculate mean first-passage times from our Koopman matrix $\mathbf{K}(\tau)$:

In [None]:
# Transition path theory between every ordered pair of distinct states.
# mfpts[n][i, u, v]: MFPT from u to v, scaled by 50 * dt * 0.001 (presumably
# 50 is the model lag in steps and dt is in ns, giving µs — TODO confirm).
# rates[n][i, u, v]: TPT rate per lag step.
mfpts, rates = {}, {}
for n in outsizes:
    mfpts[n] = np.zeros((attempts, n, n))
    rates[n] = np.zeros((attempts, n, n))
    for i in range(attempts):
        # The renormalized matrix only depends on (n, i): compute it once per attempt
        koop = renormalize(koops[n][i])
        for u, v in itertools.permutations(range(n), 2):
            f = tpt(koop, [u], [v])
            rates[n][i, u, v] = f.rate
            mfpts[n][i, u, v] = f.mfpt * 50 * dt * 0.001

In [None]:
n = 4
order = global_sorter[n]
mfpt_panels = (
    (mfpts[n].mean(axis=0)[order][:, order], r"$\mathrm{MFPT}$ [µs]"),
    (mfpts[n].std(axis=0)[order][:, order], r"$\sigma(\mathrm{MFPT})$"),
)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, (mat, title) in zip(axes, mfpt_panels):
    im = ax.matshow(mat, vmin=0.0, vmax=60, interpolation="nearest", cmap="GnBu")
    # Annotate every cell with its value (row-major order)
    for (i, j), value in np.ndenumerate(mat):
        ax.text(j, i, "{0:2.2f}".format(value), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)
plt.savefig("figs/mfpt-red-tmp-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
plt.savefig("figs/mfpt-red-tmp-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Transition rates
The transition rates are the inverse of the mean first-passage times (MFPTs):

In [None]:
n = 4
# TPT rates are per lag step; 1e6 / (50 * dt) converts them to 1/ms
# (presumably 50 is the lag in steps and dt is in ns — TODO confirm).
perms = 1e6 * rates[n] / (50 * dt)
order = global_sorter[n]
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
# Raw strings: "\s" is not a valid escape sequence in a normal string literal
for ax, mat, title in zip(axes, (perms.mean(axis=0)[order][:, order],
                                 perms.std(axis=0)[order][:, order]),
                          [r"$k_{ij}$ [1/ms]", r"$\sigma(k_{ij})$ [1/ms]"]):
    ax.matshow(mat, vmin=0.0, vmax=1000, interpolation="none", cmap="GnBu")
    for i in range(n):
        for j in range(n):
            ax.text(j, i, "{0:2.2f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)

### Lifetimes
The lifetimes depend on the self-transition probabilities and are given by:

$$ \overline{t}_i = \frac{-\tau}{\log K_{ii}} $$

where $K_{ii}$ are the diagonal elements of the transition matrix and $\tau$ is the model lagtime.

In [None]:
# State lifetimes t_i = -tau / log(K_ii), scaled by 1e-3 (plotted in µs).
# lifetimes[n] = (mean, 2.5th percentile, 97.5th percentile), in global state order.
lifetimes = {}
for n in outsizes:
    lt = np.empty((attempts, n))
    for i, koop_i in enumerate(koops[n]):
        lt[i] = -lag * dt / np.log(np.diag(koop_i)) * 1e-3
    order = global_sorter[n]
    lt_lo, lt_hi = np.percentile(lt, q=(2.5, 97.5), axis=0)[:, order]
    lifetimes[n] = (lt.mean(axis=0)[order], lt_lo, lt_hi)

In [None]:
ylimit = 6
for n in outsizes:
    lt_mean, lt_lo, lt_hi = lifetimes[n]
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(n), lt_mean, yerr=[lt_mean - lt_lo, lt_hi - lt_mean],
           color=colors, capsize=8)
    # Only annotate bars whose upper bound stays below the axis limit
    for i in range(n):
        if not lt_hi[i] < ylimit:
            continue
        ax.text(i, lt_hi[i] + 0.5, "{:.2f}".format(lt_mean[i]),
                fontsize=20, ha="center", va="center")
    ax.set_ylim(0, ylimit)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel(r"$\overline{t}_i$ [µs]", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/lifetime-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/lifetime-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Timescales
We just plot the relaxation timescales here for clarity:

In [None]:
# Implied timescales of the constrained models, scaled by 1e-3 (plotted in µs).
# timescales[n] = (mean, 2.5th percentile, 97.5th percentile) at lag index -2.
timescales = {}
# Explicit read-only mode: newer h5py versions warn/error on a missing mode
with h5py.File("intermediate/data.hdf5", "r") as read:
    store = read["red"]
    for n in outsizes:
        # np.stack requires a sequence; passing a generator is deprecated/rejected by NumPy
        its = np.stack([store["{0}/{1}/its".format(i, n)] for i in range(attempts)])[:, ::-1] * 1e-3
        its_at_lag = its[:, :, -2]
        timescales[n] = (its_at_lag.mean(axis=0), *(np.percentile(its_at_lag, q=(2.5, 97.5), axis=0)))

In [None]:
for n in outsizes:
    ts_mean, ts_lo, ts_hi = timescales[n]
    positions = np.arange(n - 1)
    fig = plt.figure(figsize=(n - 1, 4))
    ax = fig.add_subplot(111)
    ax.bar(positions, ts_mean, yerr=[ts_mean - ts_lo, ts_hi - ts_mean],
           color=colors[4:], capsize=8)
    # Print each mean just above its upper error bar
    for i in range(n - 1):
        ax.text(i, ts_hi[i] + 0.2, "{:.2f}".format(ts_mean[i]),
                fontsize=20, ha="center", va="center")
    ax.set_ylim(0, 3)
    ax.set_xticks(positions)
    ax.set_ylabel(r"$t_i$ [µs]", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/timescales-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/timescales-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Equilibrium distribution
We can look at the equilibrium distributions $\mathbf{\pi}$:

In [None]:
# Equilibrium populations: mean over attempts with the 95 % percentile band
for n in outsizes:
    order = global_sorter[n]
    pm = pis[n].mean(axis=0)[order]
    pv = np.percentile(pis[n], q=(2.5, 97.5), axis=0)[:, order]
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(n), pm, yerr=[pm - pv[0], pv[1] - pm], color=colors, capsize=8)
    for i, (label_y, value) in enumerate(zip(pv[1] + 0.05, pm)):
        ax.text(i, label_y, "{:.2f}".format(value), fontsize=20, ha="center", va="center")
    ax.set_ylim(0, 1)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel("Probability", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/pops-{0}.pdf".format(n), transparent=True, bbox_inches="tight")
    plt.savefig("figs/pops-{0}.svg".format(n), transparent=True, bbox_inches="tight")

In [None]:
n = 4
order = global_sorter[n]
pm = pis[n].mean(axis=0)[order]
pv = np.percentile(pis[n], q=(2.5, 97.5), axis=0)[:, order]
highlight_grey = (0.8, 0.8, 0.8)
# One figure per state, with that state's bar coloured and all others grey
for i in range(n):
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    cols = [colors[k] if k == i else highlight_grey for k in range(n)]
    ax.bar(np.arange(n), pm, yerr=[pm - pv[0], pv[1] - pm], color=cols, capsize=8)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_xticks(np.arange(n))
    ax.set_ylabel("P", fontsize=32, labelpad=10)
    ax.tick_params(labelsize=32)
    ax.tick_params(axis="x", length=0, pad=10)
    ax.text(0.7, 0.9, "{0}: {1:2.0f} %".format(i, pm[i] * 100), fontsize=32)
    sns.despine(ax=ax)
    plt.savefig("figs/pops-fine-{0}-{1}.pdf".format(n, i), transparent=True, bbox_inches="tight")
    plt.savefig("figs/pops-fine-{0}-{1}.svg".format(n, i), transparent=True, bbox_inches="tight")

### Entropy
We can get some idea of the "entropy" of each state by calculating the normalized information entropy $S_i = -\frac{1}{\log_2 N} \sum_{t=1}^{N} \chi_i(\mathbf{x}_t) \log_2 \chi_i(\mathbf{x}_t)$, where $N$ is the number of frames. In some sense, this encodes the ambiguity in the state assignment, or how "wide" the state is:

In [None]:
# Shannon entropy of every state's assignment over the bootstrap frames,
# normalized by log2 of the number of frames (axis 1 of pfsn_boot).
# ents[n] rows: mean, 2.5th percentile, 97.5th percentile (global state order).
ents = {}
for n in outsizes:
    memb = pfsn_boot[n]
    order = global_sorter[n]
    ent = -np.nansum(memb * np.log2(memb) / np.log2(memb.shape[1]), axis=1)
    ent_lo, ent_hi = np.percentile(ent, (2.5, 97.5), axis=0)[:, order]
    ents[n] = np.array([ent.mean(axis=0)[order], ent_lo, ent_hi])

In [None]:
for n in outsizes:
    ent_mean, ent_lo, ent_hi = ents[n]
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(n), ent_mean, yerr=[ent_mean - ent_lo, ent_hi - ent_mean],
           color=colors, capsize=8)
    for i in range(n):
        ax.text(i, ent_hi[i] + 0.05, "{:.2f}".format(ent_mean[i]), fontsize=20, ha="center", va="center")
    ax.set_ylim(0, 1)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel("Normalized entropy", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)

### Graph
We will now look at the model in the classic graph format. A good way of projecting the states is on the space of the two slowest time-lagged independent components (tICs), as they separate the states very well.

In [None]:
# State centroids in the plane of the two slowest tICs, averaged over attempts
# and put in the global state order.
pos = {
    n: np.einsum("ijk,jl->ikl", pfsn[n], ycon[:, :2]).mean(axis=0)[global_sorter[n]]
    for n in outsizes
}

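Transitions whose equilibrium flux $\pi_i K_{ij}$ falls below a threshold (`minflux`) can be masked out, although this masking is currently disabled. We define the crispness as $\mathscr{c}_i := S_i^{-1}$, scaled so that the crispest state has $\mathscr{c}_i = 1$, i.e. a crisper state is less ambiguous in its state assignments.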

In [None]:
n = 4

# Transitions whose equilibrium flux pi_i * K_ij falls below `minflux` can be
# removed from the graph. Masking is disabled by default (as in the published
# analysis); set `mask_low_flux = True` to enable it.
minflux = 3e-4
mask_low_flux = False
order = global_sorter[n]
ps = np.empty((attempts, n, n))
for i in range(attempts):
    p = koops[n][i][order][:, order].copy()
    u, v = np.where((np.diag(pis[n][i][order]) @ p) < minflux)
    if mask_low_flux:
        p[u, v] = 0.0
    ps[i, :, :] = p

# Crispness = inverse normalized entropy, scaled so the crispest state has 1
crisp = 1 / ents[n][0]
crisp /= crisp.max()
posi = pos[n]

In [None]:
# This is just to get the arrow thickness proportional to the flux,
# when using external software like Illustrator...
psm = ps.mean(axis=0)
# Only the transitions are scaled: zero entries and the large (> 0.9) entries,
# presumably the self-transitions, are excluded and set to NaN.
pmin, pmax = psm[psm > 0].min(), psm[psm < 0.9].max()
psm[(psm == 0.0) | (psm > 0.9)] = np.nan
# Min-max scale to [0, 1], so the widths below range from 1 (weakest) to 4 (strongest)
psm = (psm - pmin) / (pmax - pmin)
psm * 3 + 1

In [None]:
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
fig, posi = pe.plots.plot_network(
    ps.mean(axis=0), pos=posi, state_sizes=pis[n].mean(axis=0)[global_sorter[n]], show_frame=True, ax=ax, arrow_curvature=1.2)
plt.close();

# Padded x / y extents of the node positions, used to scale the node sizes
figpadding = 0.2
xmin = posi[:, 0].min()
xmax = posi[:, 0].max()
Dx = xmax - xmin
xmin -= Dx * figpadding
xmax += Dx * figpadding
Dx *= 1 + figpadding
# The y extent comes from the second coordinate (tIC 1), not the first
ymin = posi[:, 1].min()
ymax = posi[:, 1].max()
Dy = ymax - ymin
ymin -= Dy * figpadding
ymax += Dy * figpadding
Dy *= 1 + figpadding
sizes = min(Dx, Dy) ** 2 * 0.5 * pis[n].mean(axis=0)[global_sorter[n]] / (pis[n].mean(axis=0).max() * n)
crispness = min(Dx, Dy) ** 2 * 0.5 * crisp

# We need to redraw these to be able to show the crispness.
# Remove the previously drawn nodes. `ax.artists = []` fails on Matplotlib >= 3.5
# (read-only property), where circles added via add_artist are listed in
# `ax.patches` instead of `ax.artists`.
for artist in list(ax.artists) + [patch for patch in ax.patches if isinstance(patch, plt.Circle)]:
    artist.remove()
for i in range(n):
    # Outer circle: population; inner filled circle: crispness
    ax.add_artist(plt.Circle(posi[i], radius=0.5 * np.sqrt(0.5 * sizes[i]),
                             facecolor=colors[i], edgecolor=colors[i], alpha=0.5))
    ax.add_artist(plt.Circle(posi[i], radius=0.1 * np.sqrt(0.5 * crispness[i]), facecolor=colors[i]))
ax.tick_params(labelsize=24)
ax.set_xticks(np.arange(-2, 3, 1))
ax.set_yticks(np.arange(-2, 3, 1))
ax.set_xticks(np.arange(-2, 2.1, 0.1), minor=True)
ax.set_yticks(np.arange(-2, 2.1, 0.1), minor=True)
ax.set_xlim(-2.1, 2.1)
ax.set_ylim(-2.1, 2.1)
ax.set_xlabel(r"tIC 0", fontsize=24, labelpad=10)
ax.set_ylabel(r"tIC 1", fontsize=24, labelpad=10)
sns.despine(ax=ax)
fig.savefig("figs/graph-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
fig.savefig("figs/graph-{0}.svg".format(n), bbox_inches="tight", transparent=True)
fig

## Comparison to unconstrained model
By comparing the unconstrained model, estimated directly from the state assignments $\chi(\mathbf{x}_t)$, with the constrained model, we can get an idea of the impact of constraining reversibility and positive matrix elements.

In [None]:
def estimate_koopman(data: List[np.ndarray], lag: int) -> np.ndarray:
    """
    Estimate an unconstrained Koopman matrix from state trajectories.

    The matrix is ``pinv(C00) @ C0t``, built from the instantaneous and
    time-lagged covariance matrices that PyEMMA computes with empirical
    weights, reversible symmetrization and Bessel correction.

    Parameters
    ----------
    data
        List of state vector trajectories
    lag
        Lag time (in frames) for estimating the matrix

    Returns
    -------
    K
        The Koopman matrix

    """
    covs = pe.coordinates.covariance_lagged(
        data=data, lag=lag, weights="empirical", reversible=True, bessel=True)
    c00_pinv = np.linalg.pinv(covs.C00_)
    return c00_pinv @ covs.C0t_

In [None]:
# Unconstrained Koopman matrices at a lag of 50 frames, one per attempt
ukoops = {n: np.empty((attempts, n, n)) for n in outsizes}
for n, mats in ukoops.items():
    for i in range(attempts):
        ppf = unflatten(pfs[n][i], lengths)
        mats[i] = estimate_koopman(ppf, lag=50)

### Transition matrix comparison
We compare the transition matrices of the constrained and unconstrained VAMPNets. We expect shorter timescales in the constrained case, as the constraints hamper the estimation of the eigenfunctions.

In [None]:
n = 4
order = global_sorter[n]
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, mat, title in zip(axes, (koops[n].mean(axis=0)[order][:, order],
                                 ukoops[n].mean(axis=0)[order][:, order]),
                          ("Constrained $P$", r"Unconstrained $P$")):
    ax.matshow(mat, vmin=0.0, vmax=0.02, interpolation="nearest", cmap="GnBu")
    for i in range(n):
        for j in range(n):
            # Diagonal entries presumably exceed vmax and saturate the colormap,
            # so they get white text
            col = "white" if i == j else "black"
            ax.text(j, i, "{0:2.4f}".format(mat[i, j]), ha="center", va="center", fontsize=12, color=col)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)
# Save once, after both panels are drawn
fig.savefig("figs/k-comp-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
fig.savefig("figs/k-comp-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Timescales
Timescales are a bit easier to interpret than raw matrices:

In [None]:
# Implied timescales of the unconstrained models for every lag in `lags`:
# mitss[n][attempt, k, lag_index] = -tau * dt / log(lambda_{k+1}), skipping
# the stationary eigenvalue.
nlags = lags.shape[0]
mitss = {}
for n in outsizes:
    mitss[n] = np.empty((attempts, n - 1, nlags))
    for i in range(attempts):
        print("Reading {0}/{1}  Size {2}...".format(i, attempts, n), end="\r")
        ppf = unflatten(pfs[n][i], lengths)
        for t, tau in enumerate(lags):
            koop = estimate_koopman(ppf, tau)
            # The reversible estimate has real eigenvalues up to numerical noise.
            # Take the real part explicitly: the log of a complex value followed by an
            # implicit float cast gives a spurious timescale for negative eigenvalues
            # instead of NaN.
            evals = np.sort(np.linalg.eigvals(koop).real)[::-1]
            mitss[n][i, :, t] = -tau * dt / np.log(evals[1:])

In [None]:
# Unconstrained timescales at lag index 5, scaled by 1e-3 (plotted in µs):
# timescales_uncon[n] = (mean, 2.5th percentile, 97.5th percentile)
timescales_uncon = {}
for n in outsizes:
    its_at_lag = mitss[n][:, :, 5] * 1e-3
    its_lo, its_hi = np.percentile(its_at_lag, q=(2.5, 97.5), axis=0)
    timescales_uncon[n] = (its_at_lag.mean(axis=0), its_lo, its_hi)

In [None]:
for n in outsizes:
    ts_mean, ts_lo, ts_hi = timescales_uncon[n]
    positions = np.arange(n - 1)
    fig = plt.figure(figsize=(n - 1, 4))
    ax = fig.add_subplot(111)
    ax.bar(positions, ts_mean, yerr=[ts_mean - ts_lo, ts_hi - ts_mean],
           color=colors[4:], capsize=8)
    for i in range(n - 1):
        ax.text(i, ts_hi[i] + 0.2, "{:.2f}".format(ts_mean[i]),
                fontsize=20, ha="center", va="center")
    ax.set_ylim(0, 3)
    ax.set_xticks(positions)
    ax.set_ylabel(r"$t_i$ [µs]", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/timescales-uncon-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/timescales-uncon-{0}.svg".format(n), bbox_inches="tight", transparent=True)

In [None]:
# Implied-timescale plots of the unconstrained models (plot_its is defined elsewhere)
for n in outsizes:
    its_fig = plot_its(mitss[n], lags)
    its_fig.savefig("figs/its-uncon-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    its_fig.savefig("figs/its-uncon-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Lifetimes
The lifetimes depend on the self-transition probabilities and are given by:

$$ \overline{t}_i = \frac{-\tau}{\log K_{ii}} $$

where $K_{ii}$ are the diagonal elements of the transition matrix and $\tau$ is the model lagtime.

In [None]:
# Lifetimes of the unconstrained models, t_i = -tau / log(K_ii), scaled by 1e-3.
# lifetimes_uncon[n] = (mean, 2.5th percentile, 97.5th percentile), global state order.
lifetimes_uncon = {}
for n in outsizes:
    lt = np.empty((attempts, n))
    for i, ukoop_i in enumerate(ukoops[n]):
        lt[i] = -lag * dt / np.log(np.diag(ukoop_i)) * 1e-3
    order = global_sorter[n]
    lt_lo, lt_hi = np.percentile(lt, q=(2.5, 97.5), axis=0)[:, order]
    lifetimes_uncon[n] = (lt.mean(axis=0)[order], lt_lo, lt_hi)

In [None]:
ylimit = 6
for n in outsizes:
    lt_mean, lt_lo, lt_hi = lifetimes_uncon[n]
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(n), lt_mean, yerr=[lt_mean - lt_lo, lt_hi - lt_mean],
           color=colors, capsize=8)
    # Only annotate bars whose upper bound stays below the axis limit
    for i in range(n):
        if not lt_hi[i] < ylimit:
            continue
        ax.text(i, lt_hi[i] + 0.5, "{:.2f}".format(lt_mean[i]),
                fontsize=20, ha="center", va="center")
    ax.set_ylim(0, ylimit)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel(r"$\overline{t}_i$ [µs]", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/lifetime-uncon-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/lifetime-uncon-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Mean first passage times
We calculate mean first-passage times from our Koopman matrix $\mathbf{K}(\tau)$:

In [None]:
# TPT between every ordered pair of distinct states for the unconstrained models.
# MFPTs are scaled by 50 * dt * 0.001 (presumably lag 50 steps, dt in ns -> µs; TODO confirm).
mfpts_uncon, rates_uncon = {}, {}
for n in outsizes:
    mfpts_uncon[n] = np.zeros((attempts, n, n))
    rates_uncon[n] = np.zeros((attempts, n, n))
    for i in range(attempts):
        # The renormalized matrix only depends on (n, i): compute it once per attempt
        koop = renormalize(ukoops[n][i])
        for u, v in itertools.permutations(range(n), 2):
            f = tpt(koop, [u], [v])
            rates_uncon[n][i, u, v] = f.rate
            mfpts_uncon[n][i, u, v] = f.mfpt * 50 * dt * 0.001

In [None]:
n = 4
order = global_sorter[n]
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, mat, title in zip(axes, (mfpts[n].mean(axis=0)[order][:, order],
                                 mfpts_uncon[n].mean(axis=0)[order][:, order]),
                          (r"Constr. $\mathrm{MFPT}$ [µs]", r"Unconstr. $\mathrm{MFPT}$ [µs]")):
    ax.matshow(mat, vmin=0.0, vmax=50, interpolation="none", cmap="GnBu")
    for i in range(n):
        for j in range(n):
            ax.text(j, i, "{0:2.2f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)
# Save once, after both panels are drawn
plt.savefig("figs/mfpt-comp-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
plt.savefig("figs/mfpt-comp-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Trajectories
We can visualize the state occupation over time, at least for a few trajectories:

In [None]:
n = 4
flatlengths = np.fromiter(itertools.chain(*lengths), dtype=np.int32)
# Global frame index at which each trajectory ends
borders = flatlengths.cumsum()

t0, t1 = 0, 10000
fig = plt.figure(figsize=(16, 4))
ax = fig.add_subplot(111)
# Mean state assignment over attempts, in the same global state order as the other figures
y = pfs[n].mean(axis=0)[t0:t1][:, global_sorter[n]]
ax.matshow(y.T, interpolation="none", aspect=200, cmap="Blues")
# Mark trajectory ends inside the window, labelled with the index of the trajectory
# that ends there (correct also when t0 > 0)
visible = np.nonzero((t0 < borders) & (borders < t1))[0]
for i, b in zip(visible, borders[visible]):
    ax.plot(np.repeat(b - t0, 2), [-0.5, n - 0.5], linewidth=0.5, color="#333333", alpha=0.5)
    ax.text(b - t0 - 100, n + 0.5, str(i), fontsize=16)

ax.set_xlabel("Time [steps]", fontsize=24, labelpad=35)
ax.set_ylabel("$P$", fontsize=24)
ax.tick_params(labelsize=24)

#### Transitions
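We look for trajectories containing a full state transition using a simple heuristic: the first and last `window` frames must both be confidently assigned (more than `minn` frames with a state probability above `frac`), and the first and last frames must belong to different states. Such trajectories are useful for visualisation with VMD, Chimera, etc.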

In [None]:
n = 4
upf = unflatten(pfs[n].mean(axis=0)[:, global_sorter[n]], lengths=lengths)
window, frac, minn = 10, 0.8, 5
transitions = defaultdict(list)


def _confidently_assigned(segment):
    """True if more than `minn` frames of `segment` have a maximum state probability above `frac`."""
    return (segment.max(axis=1) > frac).sum() > minn


# Collect trajectories that start and end confidently in *different* states,
# keyed by (initial state, final state)
for i, t in enumerate(upf):
    start_state, end_state = t[0].argmax(), t[-1].argmax()
    if start_state == end_state:
        continue
    if _confidently_assigned(t[:window]) and _confidently_assigned(t[-window:]):
        transitions[(start_state, end_state)].append(i)

Example use:

In [None]:
# First trajectory with a 1 -> 0 transition, looked up with the per-round trajectory counts
round_sizes = [1024, 1023, 1024, 1024, 1024]
first_1_to_0 = transitions[(1, 0)][0]
idx_to_traj(first_1_to_0, lengths=round_sizes)

### Example structures
We extract representative structures for each state, for example just the ones with highest weight:

In [None]:
# Save the 50 frames with the highest (attempt-averaged) membership of every state.
flatlengths = np.array(list(itertools.chain(*lengths)))
# Global index one past the last frame of each trajectory
traj_ends = flatlengths.cumsum()
for n in outsizes:
    spfs = pfs[n].mean(axis=0)[:, global_sorter[n]]
    sids = spfs.argsort(axis=0)[-50:]
    for s in range(n):
        # Map global frame indices to (trajectory, frame within trajectory).
        # The frame offset is counted from the trajectory start; counting from
        # the end (cumsum - index - 1) loads the wrong frame.
        traj_idx = np.searchsorted(traj_ends, sids[:, s], side="right")
        local_idx = sids[:, s] - (traj_ends[traj_idx] - flatlengths[traj_idx])
        trjind = list(zip(traj_idx.tolist(), local_idx.tolist()))

        frames = md.join([md.load_frame(trajs[trjidx], trjnr, top=top)
                          for trjidx, trjnr in trjind])
        frames.save_pdb("structures/state-{0}-{1}-top.pdb".format(n, s))
        np.savetxt("structures/state-{0}-{1}-top-p.dat".format(n, s), spfs[:, s][sids[:, s]])

or by sampling randomly based on these weights:

In [None]:
# Save `nsamples` frames per state, drawn without replacement with probability
# proportional to the (attempt-averaged) state membership.
nsamples = 50
flatlengths = np.array(list(itertools.chain(*lengths)))
# Global index one past the last frame of each trajectory
traj_ends = flatlengths.cumsum()
allinds = np.arange(nframes, dtype=np.int64)
for n in outsizes:
    spfs = pfs[n].mean(axis=0)[:, global_sorter[n]]
    for s in range(n):
        p = spfs[:, s] / spfs[:, s].sum()
        sids = np.random.choice(allinds, size=nsamples, replace=False, p=p)
        # Map global frame indices to (trajectory, frame within trajectory),
        # counting the frame offset from the trajectory start
        traj_idx = np.searchsorted(traj_ends, sids, side="right")
        local_idx = sids - (traj_ends[traj_idx] - flatlengths[traj_idx])
        trjind = list(zip(traj_idx.tolist(), local_idx.tolist()))

        frames = md.join([md.load_frame(trajs[trjidx], trjnr, top=top)
                          for trjidx, trjnr in trjind])
        frames.save_pdb("structures-alt/state-{0}-{1}-top.pdb".format(n, s))
        np.savetxt("structures-alt/state-{0}-{1}-top-p.dat".format(n, s), spfs[:, s][sids])

# Aβ42-MetSO Analysis

## Data
### Trajectories
Trajectories were acquired in three rounds of 1024 simulations each, totalling 3071 runs (one simulation failed to run) at 278 K in the $NVT$ ensemble. Postprocessing involved removing water, subsampling to 250 ps timesteps, and making molecules whole.

In [None]:
trajs_ox = sorted(glob("trajectories/ox/r?/traj*.xtc"))
top_ox = "trajectories/ox/topol.gro"
# Cumulative trajectory count after each simulation round (1024, 1023, 1024 runs)
traj_rounds_ox = list(itertools.accumulate([1024, 1023, 1024]))
outsizes = [3, 4]

# This is only really necessary for the residues in the plots
topo_ox = md.load_topology(top_ox)

We use minimum distances as features for the neural network:

In [None]:
# Residue minimum distances (PyEMMA's default residue pair selection) as network input
feat = pe.coordinates.featurizer(top_ox)
feat.add_residue_mindist()
inpcon_ox = pe.coordinates.source(trajs_ox, feat)

traj_lengths_ox = inpcon_ox.trajectory_lengths()
nframes_ox = traj_lengths_ox.sum()
lengths_ox = sort_lengths(traj_lengths_ox, [1024, 1023, 1024])

In [None]:
# 250 ps per frame -> 0.00025 µs per frame
total_time_ox = nframes_ox * 0.00025
print("Trajectories: {0}".format(len(trajs_ox)))
print("Frames: {0}".format(nframes_ox))
print("Time: {0:5.3f} µs".format(total_time_ox))

### Data preparation
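We use minimum inter-residue distances as input for the neural network. Of the $\frac{N(N-1)}{2}$ residue pairs, where $N$ is the number of residues, we leave out neighbouring and next-neighbouring residues ($|i - j| < 3$, i.e. the first two off-diagonals of the distance matrix). This leaves $\frac{N(N-1)}{2} - (2N - 3)$ features (780 for $N = 42$):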

In [None]:
# Cached network input: compute once from the trajectories, then reuse the .npy file
filename = "intermediate/mindist-780-ox.npy"
if not os.path.exists(filename):
    print("No mindist file for ensemble, calculating from scratch...")
    input_flat_ox = np.vstack(inpcon_ox.get_output())
    np.save(filename, input_flat_ox)
else:
    print("Loading existing file for ensemble: {0}".format(filename))
    input_flat_ox = np.load(filename)
input_data_ox = unflatten(input_flat_ox, lengths_ox)

We also use the full minimum inter-residue distances for some analysis:

In [None]:
# All residue pairs (i < j), including neighbours, for the full contact analysis
allpairs = np.asarray(list(itertools.combinations(range(nres), 2)))
filename = "intermediate/mindist-all-ox.npy"
if not os.path.exists(filename):
    print("No mindist file for ensemble, calculating from scratch...")
    feat = pe.coordinates.featurizer(top_ox)
    feat.add_residue_mindist(residue_pairs=allpairs)
    inpmindist_ox = pe.coordinates.source(trajs_ox, feat)
    mindist_flat_ox = np.vstack(inpmindist_ox.get_output())
    np.save(filename, mindist_flat_ox)
else:
    print("Loading existing file for ensemble: {0}".format(filename))
    mindist_flat_ox = np.load(filename)
mindist_ox = unflatten(mindist_flat_ox, lengths_ox)

## Model validation
We load the previously trained neural network models and calculate the implied timescales, Chapman-Kolmogorov test, and the Koopman operators. This can take a long time, as the constraint vectors have to be re-estimated for every lag time, so we save the intermediate results.

### State sorting
As every run of the neural network will generate a different ordering of (mostly the same) classes, we need to reorder them to be internally consistent. We do this by calculating the Euclidean distance (a root-sum-square deviation, proportional to the RMSD) between all pairs of average inter-residue minimum distance matrices of the individual states, and matching the lowest values. We also make sure the sorting is unique, i.e. all states are accounted for, even if a certain distance is lower.

In [None]:
def _overwrite_dataset(group, name, data):
    """Create dataset `name` in `group`, replacing an existing one so this cell can be re-run."""
    if name in group:
        # NOTE: HDF5 does not reclaim the space of deleted datasets; repack the file if it grows
        del group[name]
    group.create_dataset(name, data=data)


# Append mode is required because sorted datasets are written back into the store
# (h5py >= 3 opens files read-only by default).
with h5py.File("intermediate/data.hdf5", "a") as store:
    for n in outsizes:
        # First attempt is the reference dataset
        ref_group = store["ox/{0}/{1}".format(0, n)]
        ref = ref_group["full"][:, :n]

        # This is the average contact map for every
        # state for the first training attempt, shape (n, n_features)
        conwref = (ref / ref.sum(axis=0)).T @ mindist_flat_ox
        _overwrite_dataset(ref_group, "sorter", np.arange(n, dtype="int8"))
        _overwrite_dataset(ref_group, "full_sorted", ref)
        _overwrite_dataset(ref_group, "bootstrap_sorted", ref_group["bootstrap"][:, :n])

        for i in range(1, attempts):
            print("Processing n={0} i={1}...".format(n, i), end="\r")
            group = store["ox/{0}/{1}".format(i, n)]
            pf = group["full"][:, :n]

            # Average contact map for each state for attempt `i`
            conw = (pf / pf.sum(axis=0)).T @ mindist_flat_ox

            # rmsd[a, b]: distance between reference state a and state b of attempt i
            rmsd = np.sqrt(((conw[np.newaxis, :, :] - conwref[:, np.newaxis, :]) ** 2).sum(axis=-1))

            # Get a permutation vector and store it
            sorter = unique_sorting(rmsd)
            _overwrite_dataset(group, "sorter", sorter)
            _overwrite_dataset(group, "full_sorted", pf[:, sorter])
            _overwrite_dataset(group, "bootstrap_sorted", group["bootstrap"][:, :n][:, sorter])

In [None]:
# Per-n containers, indexed by training attempt
sorters = {n: np.empty((attempts, n), dtype=int) for n in outsizes}
pfs = {n: np.empty((attempts, nframes_ox, n)) for n in outsizes}
pfsn = {n: np.empty((attempts, nframes_ox, n)) for n in outsizes}
pfs_boot = {n: np.empty((attempts, bs_frames, n)) for n in outsizes}
pfsn_boot = {n: np.empty((attempts, bs_frames, n)) for n in outsizes}
koops = {n: np.empty((attempts, n, n)) for n in outsizes}
pis = {n: np.empty((attempts, n)) for n in outsizes}
# Explicit read-only mode: this cell only reads what the sorting cell wrote
with h5py.File("intermediate/data.hdf5", "r") as read:
    store = read["ox"]
    for i in range(attempts):
        for n in outsizes:
            sorters[n][i] = store["{0}/{1}/sorter".format(i, n)][:]
            pfs[n][i] = store["{0}/{1}/full_sorted".format(i, n)][:]
            # Normalize each state's memberships to sum to one over frames
            pfsn[n][i] = pfs[n][i] / pfs[n][i].sum(axis=0)
            pfs_boot[n][i] = store["{0}/{1}/bootstrap_sorted".format(i, n)][:]
            pfsn_boot[n][i] = pfs_boot[n][i] / pfs_boot[n][i].sum(axis=0)
            # Apply the same permutation to rows and columns of the operator
            koops[n][i] = store["{0}/{1}/k".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
            pis[n][i] = statdist(koops[n][i])

We calculate an equilibrium weighting vector, analogous to discrete Markov models:

\begin{equation}
w_t = \frac{\langle \chi(\mathbf{x}_t) | \mathbf{\pi} \rangle}{\sum_{t'=1}^{N} \langle \chi(\mathbf{x}_{t'}) | \mathbf{\pi} \rangle}
\end{equation}

In [None]:
# Per-frame equilibrium weights w_t = <chi(x_t)|pi>, normalized over frames, per attempt
weights = {}
for n in outsizes:
    unnormalized = [pfs[n][i] @ pis[n][i] for i in range(attempts)]
    weights[n] = np.array([w / w.sum() for w in unnormalized])

#### Global sorting
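The state ordering is arbitrary not only between training attempts, but also between models with different numbers of states. We therefore also sort the states globally: for every number of states, the states are ordered by their mean equilibrium population, most populated first.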

In [None]:
# Order states by decreasing mean equilibrium population across attempts
global_sorter = {n: np.argsort(pis[n].mean(axis=0))[::-1] for n in outsizes}

### Implied timescales
The implied timescales are computed from the eigenvalues $\lambda_i$ of the Koopman operator $\mathbf{K}$ and the selected lag time $\tau$:
$$ t_i = \frac{-\tau}{\log | \lambda_i(\tau) |} $$
We compute the implied timescales for several lag times $\tau$. Ideally, we want to choose $\tau$ so that we're in a regime where the timescales are mostly independent of $\tau$. This is the case for lag times longer than $\tau = 12.5\,\mathrm{ns}$:

In [None]:
with h5py.File("intermediate/data.hdf5", "r") as store:
    for n in outsizes:
        # np.stack requires a sequence; passing a generator is deprecated since
        # NumPy 1.16 and raises TypeError in current releases
        its = np.stack([store["ox/{0}/{1}/its".format(i, n)][:] for i in range(attempts)])
        fig = plot_its(its, lags, dt=dt)
        plt.savefig("figs/its-ox-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
        plt.savefig("figs/its-ox-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Chapman-Kolmogorov Test
Now that we have found a Koopman operator $\mathbf{K}$, how do we know if it adequately describes the dynamical system?

There are a number of tests we can perform to ensure we have found a good approximation. One of the more stringent ones is the Chapman-Kolmogorov test (CK test). It is based on the assumption that applying the operator estimated at lag time $\tau$ $n$ times should be equivalent to estimating the operator at the longer lag time $n\tau$:

$$ \mathbf{K}(\tau)^n \approx \mathbf{K}(n\tau) $$

I.e., multiple applications of the operator should produce the same result as using an operator with a multiple of a certain lag time, within an error. This is a more stringent test of model quality than the implied timescales test described above.

In [None]:
with h5py.File("intermediate/data.hdf5", "r") as store:
    for n in outsizes:
        # Sort each attempt internally (sorters), then all attempts globally (global_sorter).
        # np.stack requires a sequence (list), not a generator.
        cke = np.stack([store["ox/{0}/{1}/cke".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
                        for i in range(attempts)])[:, global_sorter[n]][:, :, global_sorter[n]]
        ckp = np.stack([store["ox/{0}/{1}/ckp".format(i, n)][:][sorters[n][i]][:, sorters[n][i]]
                        for i in range(attempts)])[:, global_sorter[n]][:, :, global_sorter[n]]
        fig = plot_ck(cke, ckp, lag=50)
        plt.savefig("figs/ck-ox-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
        plt.savefig("figs/ck-ox-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Convergence
We would ideally like to see how converged our ensemble is with respect to the timescales and stationary distribution given by our model. We thus build trial models with different numbers of trajectories:

In [None]:
# Koopman operators estimated from growing trajectory subsets (4 states, oxidized ensemble)
n, k = 4, "ox"
k_conv = np.load("intermediate/k-conv-{0}-{1}.npy".format(k, n))

#### Timescales
We can examine how the timescales change when estimating with more trajectories:

In [None]:
its_conv = np.empty((len(traj_rounds_ox), attempts, n - 1))
for j, _ in enumerate(traj_rounds_ox):
    for i in range(attempts):
        ev, _ = eig(k_conv[j, i], left=True, right=False)
        # Sort by modulus, largest first; the first one is the stationary eigenvalue
        order = np.abs(ev).argsort()[::-1]
        # t_i = -tau / log|lambda_i| (tau = 50 frames). Using the modulus, as in the
        # implied-timescale formula, avoids NaNs for negative or complex eigenvalues.
        its_conv[j, i] = -50 * dt / np.log(np.abs(ev[order][1:]))
its_conv_m = its_conv.mean(axis=1)
its_conv_p = np.percentile(its_conv, (2.5, 97.5), axis=1)

In [None]:
# Mean implied timescales vs. number of trajectories, with 2.5-97.5 percentile bands
fig, ax = plt.subplots(figsize=(4, 4))
x = np.array(traj_rounds_ox)
for i in range(n - 1):
    color = colors[i + 4]
    ax.semilogy(x, its_conv_m[:, i], color=color, linewidth=2, marker="o")
    ax.fill_between(x, its_conv_p[0, :, i], its_conv_p[1, :, i],
                    interpolate=True, color=color, alpha=0.2)

# Major ticks at every decade, unlabeled minor ticks in between
ax.set_ylim(1, 100000)
ax.set_yticks(10 ** np.arange(6))
ax.yaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=(0.2, 0.4, 0.6, 0.8), numticks=12))
ax.yaxis.set_minor_formatter(ticker.NullFormatter())
ax.set_xticks(x)
ax.tick_params(labelsize=16)
ax.set_xlabel("# Trajectories", fontsize=24)
ax.set_ylabel("$t$", fontsize=24)
sns.despine(ax=ax)
for ext in ("pdf", "svg"):
    plt.savefig("figs/its-conv-coarse-ox-{0}.{1}".format(n, ext), transparent=True, bbox_inches="tight")

## Structure
### TICA
Time-lagged independent component analysis is a special case of Koopman operator estimation using a linear projection [4]. We solve the following generalized eigenvalue problem:

$$ \mathbf{C}_{01}v = \lambda \mathbf{C}_{00} v $$

The eigenvectors encode the slowest dynamics of the system, and we use them as a convenient visualization technique.

[4]	Pérez-Hernández, G., Paul, F., Giorgino, T., De Fabritiis, G. & Noé, F. Identification of slow molecular order parameters for Markov model construction. The Journal of Chemical Physics 139, 015102–14 (2013).

In [None]:
ticacon = pe.coordinates.tica(mindist_ox, lag=4, dim=-1, kinetic_map=True)
ticscon = ticacon.get_output()
ycon = np.vstack(ticscon)

# cumvar[k] is the cumulative variance of the first k + 1 tICs, so the number of
# dimensions required is the index of the first entry reaching 0.9, plus one.
# (Counting the entries below 0.9 was one too few.)
ndim_90 = min(int(np.searchsorted(ticacon.cumvar, 0.9)) + 1, len(ticacon.cumvar))

print("tIC Dimensions: {0}".format(ycon.shape[1]))
print("Required dimensions for 90 %: {0}".format(ndim_90))

#### Free energy surface
We also show the free energy surface projected onto the two slowest tICs in the form of a kernel density estimate:

In [None]:
# KDE of the two slowest tICs, fit on every 10th frame to keep it affordable
kernel = gaussian_kde(ycon[::10, :2].T)
lows, highs = ycon.min(axis=0), ycon.max(axis=0)
xmin, ymin = lows[:2]
xmax, ymax = highs[:2]
# Evaluate on a 100 x 100 grid spanning the data range
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
posi = np.vstack((X.ravel(), Y.ravel()))
Z = kernel(posi).reshape(X.shape)
# Rotated copy with low-density regions blanked out
mat = np.rot90(Z).copy()
mat[mat < 0.01] = np.nan

In [None]:
# Free energy F = -kT log(density), shifted so the minimum is zero; contours every 1 unit
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
cmap = "plasma"
F = -KBT * np.log(Z)
F -= F.min()
ax.contourf(X, Y, F, np.arange(0.0, 10, 1), cmap=cmap)
# `contour` takes `linewidths` (plural); `linewidth` is not a documented contour keyword
ax.contour(X, Y, F, np.arange(0.0, 10, 1), cmap=cmap, linewidths=10)
ax.tick_params(labelsize=24)
ax.set_xlabel(r"tIC 1", fontsize=24, labelpad=10)
ax.set_ylabel(r"tIC 2", fontsize=24, labelpad=10)
sns.despine(ax=ax)

### Contact maps
We calculate contact maps from inter-residue minimum distance matrices, using the bootstrapped state assignments:

In [None]:
# Membership-weighted contact frequencies (distance < cutoff) per state, per attempt
cutoff = 0.8
inds = unflatten(np.arange(nframes_ox).reshape(-1, 1), lengths_ox)
contacts = {n: np.empty((attempts, n, nres, nres)) for n in outsizes}
for i in range(attempts):
    generator = DataGenerator.from_state(inds, "models/model-idx-ox-{0}.hdf5".format(i))
    idx = generator(n=2, lag=20).trains[0].flatten().astype(int)
    # The contact indicator only depends on the selected frames, not on n
    in_contact = mindist_flat_ox[idx] < cutoff
    for n in outsizes:
        print("Processing n={0} i={1}...".format(n, i), end="\r")
        memberships = pfs[n][i, idx]
        con = np.einsum("jk,jl->kl", memberships, in_contact) / memberships.sum(axis=0)[:, np.newaxis]
        # Expand each flat upper-triangle vector into a full nres x nres matrix
        contacts[n][i] = np.asarray([triu_inverse(row, nres)[0] for row in con])

### Secondary structure
Secondary structure assignment with DSSP:

In [None]:
# Custom func makes this a lot easier
# Full DSSP codes -> digits so the result can be cast to numbers:
# H=0, B=1, E=2, G=3, I=4, T=5, S=6, ' '=7
dssptable = str.maketrans("HBEGITS ", "01234567")
def dssp_enc(traj):
    """Return the full-scheme DSSP assignment of `traj` encoded as float32 codes 0-7."""
    # NOTE(review): atoms 0-626 / 42 residues presumably select the protein; confirm
    raw = md.compute_dssp(traj.atom_slice(range(627)), simplified=False)
    return np.char.translate(raw, table=dssptable).astype(np.float32)

feat = pe.coordinates.featurizer(top_ox)
feat.add_custom_func(dssp_enc, dim=42)
inp = pe.coordinates.source(trajs_ox, feat)
dssp = np.vstack(inp.get_output()).astype(np.int32)

# One-hot encoding over the full code alphabet. Using dssp.max() + 1 would shrink
# the last axis whenever the highest code never occurs, breaking the fixed indices below.
nvals = len(dssptable)
dsspoh = np.eye(nvals, dtype=np.int32)[dssp]

# We could use the simplified DSSP scheme, but this gives us a bit more flexibility
dssplow = np.empty((nframes_ox, nres, 4))
dssplow[:, :, 0] = dsspoh[:, :, [0, 3, 4]].sum(axis=-1)  # helix: H, G, I
dssplow[:, :, 1] = dsspoh[:, :, [1, 2]].sum(axis=-1)     # strand: B, E
dssplow[:, :, 2] = dsspoh[:, :, [5, 6]].sum(axis=-1)     # turn/bend: T, S
dssplow[:, :, 3] = dsspoh[:, :, 7]                       # ' ' (no assignment)

In [None]:
# Membership-weighted secondary-structure propensities per state, per attempt
inds = unflatten(np.arange(nframes_ox).reshape(-1, 1), lengths_ox)
sec = {n: np.empty((attempts, n, nres, 4)) for n in outsizes}
for i in range(attempts):
    generator = DataGenerator.from_state(inds, "models/model-idx-ox-{0}.hdf5".format(i))
    idx = generator(n=2, lag=20).trains[0].flatten().astype(int)
    selected = dssplow[idx]
    for n in outsizes:
        print("Processing n={0} i={1}...".format(n, i), end="\r")
        memberships = pfs[n][i, idx]
        sec[n][i] = (np.einsum("ij,ikl->jkl", memberships, selected) /
                     memberships.sum(axis=0)[:, np.newaxis, np.newaxis])

# Mean and 2.5/97.5 percentiles over attempts, globally sorted
secm = {n: sec[n].mean(axis=0)[global_sorter[n]] for n in outsizes}
secs = {n: np.percentile(sec[n], q=(2.5, 97.5), axis=0)[:, global_sorter[n]] for n in outsizes}

In [None]:
for n in outsizes:
    # Layout: rows = P(alpha), P(beta), contact map, colorbar; one column per state.
    # The three data rows share one x-axis (limits and tick locator/formatter). The
    # colorbar row is deliberately not shared, because a colorbar sets its own 0-1
    # limits and ticks, which would otherwise leak into the data rows. Creating the
    # axes with explicit sharing avoids un-sharing them afterwards, which is impossible
    # via get_shared_x_axes().remove() on matplotlib >= 3.6 (read-only GrouperView).
    fig = plt.figure(figsize=(4 * n, 7.5))
    grid = fig.add_gridspec(4, n, height_ratios=[5, 5, 15, 1], width_ratios=[3] * n)
    axes = np.empty((4, n), dtype=object)
    for row in range(4):
        for col in range(n):
            share = axes[0, 0] if row < 3 and (row, col) != (0, 0) else None
            axes[row, col] = fig.add_subplot(grid[row, col], sharex=share)

    for i, axs in enumerate(axes.T):
        # Secondary-structure propensities with 2.5-97.5 percentile error bars
        for j, ax in enumerate(axs[:2]):
            ax.bar(np.arange(nres), secm[n][i, :, j], yerr=[secm[n][i, :, j] - secs[n][0, i, :, j],
                                                            secs[n][1, i, :, j] - secm[n][i, :, j]],
                   capsize=2, color=colors[i])
            sns.despine(ax=ax)
            ax.set_ylim(0, 1)
            if i == 0:
                ax.set_ylabel([r"$P(\alpha)$", r"$P(\beta)$"][j], fontsize=24)
            else:
                ax.set_yticklabels([])
            ax.set_yticks([0, 0.5, 1.0])
            ax.set_xticks(np.arange(0, 50, 10))
            ax.set_xticks(np.arange(0, 42, 1), minor=True)
            ax.set_xticklabels([])
            ax.tick_params(labelsize=24)
            # Only the contact map shows the (shared) x tick labels
            ax.tick_params(axis="x", length=0, labelsize=12, labelbottom=False)

        # Mean contact map of state i
        ax = axs[2]
        dat = contacts[n].mean(axis=0)[global_sorter[n]][i]
        cmap = sns.light_palette(colors[i], as_cmap=True)
        im = ax.matshow(dat, vmin=0.0, vmax=1.0, interpolation="nearest", cmap=cmap)
        ax.xaxis.tick_bottom()
        ax.set_xticks(np.arange(0, 50, 10))
        ax.set_xticks(np.arange(0, 42, 1), minor=True)
        ax.set_yticks(np.arange(0, 50, 10))
        ax.set_yticks(np.arange(0, 42, 1), minor=True)
        # The x formatter is shared, so this restores the labels cleared by the bar rows
        ax.set_xticklabels(np.arange(0, 50, 10))
        if i != 0:
            ax.set_yticklabels([])
        ax.set_xlim(0, 41)
        ax.set_ylim(41, 0)
        ax.tick_params(labelsize=24)

        # Independent colorbar axis
        cax = axs[3]
        fig.colorbar(im, cax=cax, orientation="horizontal")
        cax.xaxis.set_ticks(np.arange(0, 1.5, 0.5))
        cax.set_xticklabels(np.arange(0, 1.5, 0.5))
        cax.tick_params(labelsize=16)
    plt.savefig("figs/structure-ox-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/structure-ox-{0}.svg".format(n), bbox_inches="tight", transparent=True)

## Kinetics
### Koopman operators
We can look at the Koopman operators and their errors directly:

In [None]:
# Mean and standard deviation (over attempts) of the globally sorted Koopman operator
n = 4
state_order = global_sorter[n]
koop_mean = koops[n].mean(axis=0)[state_order][:, state_order]
koop_std = koops[n].std(axis=0)[state_order][:, state_order]
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, mat, title in zip(axes, (koop_mean, koop_std), ("$P$", r"$\sigma(P)$")):
    ax.matshow(mat, vmin=0.0, vmax=0.02, interpolation="none", cmap="GnBu")
    # Annotate every cell with its value (row-major order)
    for i, j in itertools.product(range(n), repeat=2):
        ax.text(j, i, "{0:2.4f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)

### Mean first passage times
We calculate mean first-passage times from our Koopman matrix $\mathbf{K}(\tau)$:

In [None]:
# TPT rates and mean first-passage times between every ordered pair of distinct states
mfpts, rates = {}, {}
for n in outsizes:
    mfpts[n] = np.zeros((attempts, n, n))
    rates[n] = np.zeros((attempts, n, n))
    for i in range(attempts):
        # The renormalized operator only depends on (n, i): compute it once,
        # not once per (source, target) pair
        koop = renormalize(koops[n][i])
        for u, v in itertools.permutations(range(n), 2):
            f = tpt(koop, [u], [v])
            rates[n][i, u, v] = f.rate
            # Lag of 50 frames, scaled by 1e-3 (presumably ns -> µs)
            mfpts[n][i, u, v] = f.mfpt * 50 * dt * 0.001

In [None]:
# Mean and standard deviation (over attempts) of the globally sorted MFPT matrix
n = 4
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, mat, title in zip(axes, (mfpts[n].mean(axis=0)[global_sorter[n]][:, global_sorter[n]],
                                 mfpts[n].std(axis=0)[global_sorter[n]][:, global_sorter[n]]),
                          (r"$\mathrm{MFPT}$ [µs]", r"$\sigma(\mathrm{MFPT})$")):
    ax.matshow(mat, vmin=0.0, vmax=60, interpolation="nearest", cmap="GnBu")
    for i in range(n):
        for j in range(n):
            ax.text(j, i, "{0:2.2f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)
# Save once, after both panels are drawn (saving inside the loop wrote an incomplete figure first)
plt.savefig("figs/mfpt-ox-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
plt.savefig("figs/mfpt-ox-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Transition rates
The transition rates are the inverse of the mean first-passage times (MFPTs):

In [None]:
# Rates per lag step -> per dt unit (lag = 50 frames), scaled by 1e6 to the 1/ms shown in the titles
n = 4
perms = 1e6 * rates[n] / (50 * dt)
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, mat, title in zip(axes, (perms.mean(axis=0)[global_sorter[n]][:, global_sorter[n]],
                                 perms.std(axis=0)[global_sorter[n]][:, global_sorter[n]]),
                          # raw strings: "\s" is an invalid escape sequence in a normal literal
                          [r"$k_{ij}$ [1/ms]", r"$\sigma(k_{ij})$ [1/ms]"]):
    ax.matshow(mat, vmin=0.0, vmax=1000, interpolation="none", cmap="GnBu")
    for i in range(n):
        for j in range(n):
            ax.text(j, i, "{0:2.2f}".format(mat[i, j]), ha="center", va="center", fontsize=12)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_title(title, fontsize=24)
    ax.tick_params(length=0)

### Lifetimes
The lifetimes depend on the probability of remaining in the same state and are given by:

$$ \overline{t}_i = \frac{-\tau}{\log K_{ii}} $$

where $K_{ii}$ are the diagonal elements of the Koopman matrix $\mathbf{K}$ and $\tau$ is the model lag time.

In [None]:
# lifetimes[n] = (mean, 2.5th percentile, 97.5th percentile) over attempts, globally sorted
lifetimes = {}
for n in outsizes:
    lt = np.empty((attempts, n))
    for i, koop in enumerate(koops[n]):
        # -tau / log K_ii, scaled by 1e-3 (presumably ns -> µs)
        lt[i] = -lag * dt / np.log(np.diag(koop)) * 1e-3
    state_order = global_sorter[n]
    lo, hi = np.percentile(lt, q=(2.5, 97.5), axis=0)[:, state_order]
    lifetimes[n] = (lt.mean(axis=0)[state_order], lo, hi)

In [None]:
# Bar charts of the state lifetimes with 2.5-97.5 percentile error bars.
# The y-axis is capped at `ylimit`; bars whose upper percentile reaches it get no value label.
ylimit = 15
for n in outsizes:
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    # lifetimes[n] = (mean, lower, upper) -> asymmetric error bars
    ax.bar(np.arange(n), lifetimes[n][0], yerr=[lifetimes[n][0] - lifetimes[n][1],
                                                lifetimes[n][2] - lifetimes[n][0]],
           color=colors, capsize=8)
    for i in range(n):
        # Label only bars that fit inside the axis, just above their upper error bar
        if lifetimes[n][2][i] < ylimit:
            ax.text(i, lifetimes[n][2][i] + 1.0, "{:.2f}".format(lifetimes[n][0][i]),
                    fontsize=20, ha="center", va="center")
    ax.set_ylim(0, ylimit)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel(r"$\overline{t}_i$ [µs]", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/lifetime-ox-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/lifetime-ox-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Timescales
We just plot the relaxation timescales here for clarity:

In [None]:
timescales = {}
with h5py.File("intermediate/data.hdf5", "r") as read:
    store = read["ox"]
    for n in outsizes:
        # np.stack requires a sequence (list), not a generator; 1e-3 presumably ns -> µs
        its = np.stack([store["{0}/{1}/its".format(i, n)][:] for i in range(attempts)])[:, ::-1] * 1e-3
        # (mean, 2.5th, 97.5th percentile) over attempts, taken at index -2 of the last axis
        timescales[n] = (its[:, :, -2].mean(axis=0), *(np.percentile(its[:, :, -2], q=(2.5, 97.5), axis=0)))

In [None]:
# Bar charts of the n - 1 relaxation timescales with 2.5-97.5 percentile error bars
for n in outsizes:
    fig = plt.figure(figsize=(n - 1, 4))
    ax = fig.add_subplot(111)
    # timescales[n] = (mean, lower, upper) -> asymmetric error bars
    ax.bar(np.arange(n - 1), timescales[n][0], yerr=[timescales[n][0] - timescales[n][1],
                                                     timescales[n][2] - timescales[n][0]],
           color=colors[4:], capsize=8)
    for i in range(n - 1):
        # Value label just above the upper error bar
        ax.text(i, timescales[n][2][i] + 0.2, "{:.2f}".format(timescales[n][0][i]),
                fontsize=20, ha="center", va="center")
    # NOTE(review): fixed limit, so timescales above 3 µs would be clipped
    ax.set_ylim(0, 3)
    ax.set_xticks(np.arange(n - 1))
    ax.set_ylabel(r"$t_i$ [µs]", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    plt.savefig("figs/timescales-ox-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
    plt.savefig("figs/timescales-ox-{0}.svg".format(n), bbox_inches="tight", transparent=True)

### Equilibrium distribution
We can look at the equilibrium distributions $\mathbf{\pi}$:

In [None]:
# Equilibrium populations with 2.5-97.5 percentile error bars, one figure per n
for n in outsizes:
    state_order = global_sorter[n]
    pm = pis[n].mean(axis=0)[state_order]
    pv = np.percentile(pis[n], q=(2.5, 97.5), axis=0)[:, state_order]
    fig, ax = plt.subplots(figsize=(n * 1, 4))
    ax.bar(np.arange(n), pm, yerr=[pm - pv[0], pv[1] - pm], color=colors, capsize=8)
    # Value label just above each upper error bar
    for i, (upper, mean) in enumerate(zip(pv[1], pm)):
        ax.text(i, upper + 0.05, "{:.2f}".format(mean), fontsize=20, ha="center", va="center")
    ax.set_ylim(0, 1)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel("Probability", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)
    for ext in ("pdf", "svg"):
        plt.savefig("figs/pops-ox-{0}.{1}".format(n, ext), transparent=True, bbox_inches="tight")

In [None]:
# One population bar chart per state (n = 4), highlighting that state in its colour
# and greying out the others; the highlighted population is printed as a percentage.
n = 4
pm = pis[n].mean(axis=0)[global_sorter[n]]
pv = np.percentile(pis[n], q=(2.5, 97.5), axis=0)[:, global_sorter[n]]
for i in range(n):
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    # Grey for every state except the highlighted one
    cols = [(0.8, 0.8, 0.8)] * n
    cols[i] = colors[i]
    ax.bar(np.arange(n), pm, yerr=[pm - pv[0], pv[1] - pm], color=cols, capsize=8)
    ax.set_ylim(0, 1)
    ax.set_yticks([0.0, 0.5, 1.0])
    ax.set_xticks(np.arange(n))
    ax.set_ylabel("P", fontsize=32, labelpad=10)
    ax.tick_params(labelsize=32)
    ax.tick_params(axis="x", length=0, pad=10)
    # Text position is in data coordinates (x = 0.7, y = 0.9)
    ax.text(0.7, 0.9, "{0}: {1:2.0f} %".format(i, pm[i] * 100), fontsize=32)
    sns.despine(ax=ax)
    plt.savefig("figs/pops-fine-ox-{0}-{1}.pdf".format(n, i), transparent=True, bbox_inches="tight")
    plt.savefig("figs/pops-fine-ox-{0}-{1}.svg".format(n, i), transparent=True, bbox_inches="tight")

### Entropy
We can get some idea of the "entropy" of each state by calculating the information entropy $S_i = -\sum_t^N \chi_i(\mathbf{x}_t) \log_2(\chi_i(\mathbf{x}_t))$. Here $\chi_i$ is normalized to sum to one over the $N$ frames, and $S_i$ is divided by its maximum value $\log_2 N$. In some sense, this encodes ambiguity in the state assignment, or how "wide" the state is:

In [None]:
# Per-state Shannon entropy of the (frame-normalized) bootstrap memberships, divided by
# log2(number of frames); ents[n] rows = (mean, 2.5th, 97.5th percentile), globally sorted
ents = {}
for n in outsizes:
    probs = pfsn_boot[n]
    norm = np.log2(probs.shape[1])
    ent = -np.nansum(probs * np.log2(probs) / norm, axis=1)  # nansum treats 0 * log 0 as 0
    state_order = global_sorter[n]
    ents[n] = np.array([ent.mean(axis=0)[state_order],
                        *np.percentile(ent, (2.5, 97.5), axis=0)[:, state_order]])

In [None]:
# Bar charts of the normalized state entropies with 2.5-97.5 percentile error bars
for n in outsizes:
    fig = plt.figure(figsize=(n * 1, 4))
    ax = fig.add_subplot(111)
    # ents[n] rows: mean, lower, upper
    ax.bar(np.arange(n), ents[n][0], yerr=[ents[n][0] - ents[n][1], ents[n][2] - ents[n][0]],
           color=colors, capsize=8)
    for i in range(n):
        ax.text(i, ents[n][2, i] + 0.05, "{:.2f}".format(ents[n][0, i]), fontsize=20, ha="center", va="center")
    ax.set_ylim(0, 1)
    ax.set_xticks(np.arange(n))
    ax.set_ylabel("Normalized entropy", fontsize=24, labelpad=10)
    ax.tick_params(labelsize=24)
    ax.tick_params(axis="x", length=0, pad=10)
    sns.despine(ax=ax)

### Graph
We will now look at the model in the classic graph format. A good way of projecting the states is on the space of the two slowest time-lagged independent components (tICs), as they separate the states very well.

In [None]:
# State positions in the plane of the two slowest tICs: membership-weighted mean over
# frames (pfsn is normalized per state), averaged over attempts and globally sorted
pos = {
    n: np.einsum("ijk,jl->ikl", pfsn[n], ycon[:, :2]).mean(axis=0)[global_sorter[n]]
    for n in outsizes
}

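We can mask out transitions whose flux falls below a certain threshold (this masking step is currently disabled in the code below). We define the crispness as $\mathscr{c}_i := S_i^{-1}$, scaled so that the crispest state has a value of one, i.e. a crisper state is less ambiguous in its state assignments.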

In [None]:
# Globally sorted Koopman operators per attempt (n = 4), plus node crispness and positions
n = 4

minflux = 3e-4
ps = np.empty((attempts, n, n))
for i in range(attempts):
    p = koops[n][i][global_sorter[n]][:, global_sorter[n]].copy()
    # Entries whose flux pi_i * K_ij falls below `minflux`
    u, v = np.where((np.diag(pis[n][i][global_sorter[n]]) @ p) < minflux)
    # NOTE(review): the masking line below is commented out, so `minflux`, `u` and `v`
    # currently have no effect and `ps` holds the unmasked operators.
#     p[u, v] = 0.0
    ps[i, :, :] = p

# Crispness = inverse mean normalized entropy, scaled so the crispest state is 1
crisp = 1 / ents[n][0]
crisp /= crisp.max()
posi = pos[n]

In [None]:
# This is just to get the arrow thickness proportional to the flux,
# when using external software like Illustrator...
psm = ps.mean(axis=0)
# Scaling range excludes zero entries and large (> 0.9) entries
pmin, pmax = psm[psm > 0].min(), psm[psm < 0.9].max()
psm[(psm == 0.0) | (psm > 0.9)] = np.nan
# Min-max scale to [0, 1]: divide by the range, not by the maximum
psm = (psm - pmin) / (pmax - pmin)
# Displayed: arrow widths in [1, 4]
psm * 3 + 1

In [None]:
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
fig, posi = pe.plots.plot_network(
    ps.mean(axis=0), pos=posi, state_sizes=pis[n].mean(axis=0)[global_sorter[n]], show_frame=True, ax=ax, arrow_curvature=1.2)
# Close so the inline backend does not show the intermediate figure; `fig` is shown at the end
plt.close();

# Node sizes: padded extent of the positions, apparently mirroring the scaling used by
# pyemma's plot_network (verify against the installed pyemma version)
figpadding = 0.2
xmin = posi[:, 0].min()
xmax = posi[:, 0].max()
Dx = xmax - xmin
xmin -= Dx * figpadding
xmax += Dx * figpadding
Dx *= 1 + figpadding
# The y-extent comes from the second coordinate (it was previously copied from x)
ymin = posi[:, 1].min()
ymax = posi[:, 1].max()
Dy = ymax - ymin
ymin -= Dy * figpadding
ymax += Dy * figpadding
Dy *= 1 + figpadding
sizes = min(Dx, Dy) ** 2 * 0.5 * pis[n].mean(axis=0)[global_sorter[n]] / (pis[n].mean(axis=0).max() * n)
crispness = min(Dx, Dy) ** 2 * 0.5 * crisp

# We need to redraw these to be able to show the crispness.
# Remove the existing node circles: depending on the matplotlib version, circles added
# via add_artist are listed in ax.artists or ax.patches, and ax.artists can no longer
# be assigned to (matplotlib >= 3.5).
for artist in list(ax.artists) + list(ax.patches):
    if isinstance(artist, plt.Circle):
        artist.remove()
for i in range(n):
    # Outer translucent disc ~ population, inner solid disc ~ crispness
    ax.add_artist(plt.Circle(posi[i], radius=0.5 * np.sqrt(0.5 * sizes[i]),
                             facecolor=colors[i], edgecolor=colors[i], alpha=0.5))
    ax.add_artist(plt.Circle(posi[i], radius=0.1 * np.sqrt(0.5 * crispness[i]), facecolor=colors[i]))
ax.tick_params(labelsize=24)
ax.set_xticks(np.arange(-2, 3, 1))
ax.set_yticks(np.arange(-2, 3, 1))
ax.set_xticks(np.arange(-2, 2.1, 0.1), minor=True)
ax.set_yticks(np.arange(-2, 2.1, 0.1), minor=True)
ax.set_xlim(-2.1, 2.1)
ax.set_ylim(-2.1, 2.1)
ax.set_xlabel(r"tIC 0", fontsize=24, labelpad=10)
ax.set_ylabel(r"tIC 1", fontsize=24, labelpad=10)
sns.despine(ax=ax)
fig.savefig("figs/graph-ox-{0}.pdf".format(n), bbox_inches="tight", transparent=True)
fig.savefig("figs/graph-ox-{0}.svg".format(n), bbox_inches="tight", transparent=True)
fig

### Example structures
We extract representative structures for each state, for example just the ones with highest weight:

In [None]:
# Save the 50 highest-weight frames of every state as PDB files (plus their weights)
flatlengths = np.array(list(itertools.chain(*lengths_ox)))
# Global frame index range [start, end) of every trajectory in the flattened data
traj_ends = flatlengths.cumsum()
traj_starts = traj_ends - flatlengths
os.makedirs("structures", exist_ok=True)
for n in outsizes:
    spfs = pfs[n].mean(axis=0)[:, global_sorter[n]]
    sids = spfs.argsort(axis=0)[-50:]
    for s in range(n):
        # Map global frame indices to (trajectory, frame within that trajectory).
        # side="right" assigns a frame at a trajectory's first index to that trajectory.
        # (The previous `cumsum - ind - 1` counted from the trajectory's end, i.e. loaded
        # the mirrored frame.)
        traj_ids = np.searchsorted(traj_ends, sids[:, s], side="right")
        frame_ids = sids[:, s] - traj_starts[traj_ids]
        assert ((frame_ids >= 0) & (frame_ids < flatlengths[traj_ids])).all()

        frames = md.join([md.load_frame(trajs_ox[t], f, top=top_ox)
                          for t, f in zip(traj_ids, frame_ids)])
        frames.save_pdb("structures/state-metso-{0}-{1}-top.pdb".format(n, s))
        np.savetxt("structures/state-metso-{0}-{1}-top-p.dat".format(n, s), spfs[:, s][sids[:, s]])

or by sampling randomly based on these weights:

In [None]:
# Sample frames for every state with probability proportional to its membership
nsamples = 50
# Fixed seed so that re-running the notebook reproduces the same structures
sample_rng = np.random.RandomState(42)
flatlengths = np.array(list(itertools.chain(*lengths_ox)))
# Global frame index range [start, end) of every trajectory in the flattened data
traj_ends = flatlengths.cumsum()
traj_starts = traj_ends - flatlengths
allinds = np.arange(nframes_ox, dtype=np.int64)
os.makedirs("structures-alt", exist_ok=True)
for n in outsizes:
    spfs = pfs[n].mean(axis=0)[:, global_sorter[n]]
    for s in range(n):
        p = spfs[:, s] / spfs[:, s].sum()
        sids = sample_rng.choice(allinds, size=nsamples, replace=False, p=p)
        # Map global frame indices to (trajectory, frame within that trajectory).
        # (The previous `cumsum - ind - 1` loaded the mirrored frame.)
        traj_ids = np.searchsorted(traj_ends, sids, side="right")
        frame_ids = sids - traj_starts[traj_ids]
        assert ((frame_ids >= 0) & (frame_ids < flatlengths[traj_ids])).all()

        frames = md.join([md.load_frame(trajs_ox[t], f, top=top_ox)
                          for t, f in zip(traj_ids, frame_ids)])
        frames.save_pdb("structures-alt/state-metso-{0}-{1}-top.pdb".format(n, s))
        np.savetxt("structures-alt/state-metso-{0}-{1}-top-p.dat".format(n, s), spfs[:, s][sids])