# Fit a dynamical model to sc-flow data

In this notebook, we will

1. Import and visualize simulated single cell flow cytometry data
2. Develop a conditional variational autoencoder that finds a latent space representation, and estimates differentiation rates
3. Visualize the results and compare the inferred values with the ground truth
4. Restrict the model space and improve our estimates

We start by importing some Python modules. We will use `torch` and `pyro` for deep learning and variational inference. This is faster on a GPU, but also works on a CPU. We therefore check whether a GPU is available.

In [None]:
import pandas as pd # for importing the dataset
import matplotlib.pyplot as plt # for plotting
from sklearn.decomposition import PCA # for dimensionality reduction
import numpy as np # for numerical computations
import tqdm # for progress bars
import torch # PyTorch for tensor computations
import pyro # probabilistic programming library
# import objects for stochastic variational inference
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO 
from pyro.optim import Adam


# allow imports from parent directory (for scdynsys package)

import sys
sys.path.append("..")


# import the dynamical VAE model
from scdynsys.dynamic_vae import VAEgmmdyn
from scdynsys.utilities import cell_type_colors

# check if CUDA (GPU) is available and set device accordingly
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

print("CUDA available:", USE_CUDA)


## 1. Import the data and visualize

The dataset contains values for 12 flow cytometry markers, and time and celltype information for each cell. The celltype is typically not known, but as this data is simulated, we actually know the celltype. Of course, we should not use this information in our model fitting...

In [None]:
# import the (simulated) dataset
dataset = pd.read_csv('../data/simulated_time_series_data.csv')

# look at the first few rows
dataset.head()

### Inspection of the data

The dataset contains: 

- 12 flow cytometry markers 
- sampling time information

Let's plot 1D flow plots (marginal densities)
and see how these change with time.

In [None]:
# plot marginal densities of markers at different time points
time_points = sorted(dataset['Timepoint'].unique())
num_timepoints = len(time_points)

# marker names:
markers = [col for col in dataset.columns if col not in ['Timepoint', 'CellType']]
num_markers = len(markers)

fig, axs = plt.subplots(
    num_timepoints, num_markers,
    figsize=(2*num_markers, 2*num_timepoints), 
    sharex='col', sharey=True
)

# remove space between subplots
fig.subplots_adjust(hspace=0.0, wspace=0.0)


# split data by marker and time point and plot histograms (1D FACS plots)
for j, marker in enumerate(markers):
    for i, tp in enumerate(time_points):
        data_tp = dataset[dataset['Timepoint'] == tp][marker]
        ax = axs[i,j]
        ax.hist(data_tp, bins=30, density=True, color='k', alpha=0.7)
        if j == 0:
            ax.set_ylabel(f'Time {tp}')
        if i == 0:
            ax.set_title(marker)
        ax.set(xticks=[], yticks=[])


The dataset is quite large, and so we'll subsample to speed up
computations. 
Next, we'll define some arrays and tensors that are required below.

These are some of the key elements:

* `xs`: these are the (scaled) marker expression measured by flow cytometry.
* `ts`: time point for each cell.
* `celltypes`: the ground-truth cell type for each cell (only known because we simulate the data)
* `utime`: these are the unique time points in the data set. The reason for using these instead of `ts` is that we only have to evaluate our dynamical model at these time points. If we evaluated the model at `ts`, we would be doing a lot of redundant work.

Torch and Pyro work with so-called "tensors" instead of e.g. NumPy arrays. So we will convert the "raw" arrays to tensors. We also move the data to the specified device (the CPU or the GPU).
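
To make the relation between `ts`, `utime` and the time index concrete, here is a minimal sketch with made-up sampling times. The names `ts_demo`, `utime_demo` and `xtime_demo` are hypothetical and only used for illustration; the notebook itself builds the index in a slightly different way.

```python
import numpy as np

# toy sampling times for six cells (hypothetical values)
ts_demo = np.array([0.0, 3.0, 0.0, 7.0, 3.0, 7.0])

# unique (sorted) time points and, for each cell, the index of its time point
utime_demo, xtime_demo = np.unique(ts_demo, return_inverse=True)

print(utime_demo)   # the unique time points
print(xtime_demo)   # integer index into utime_demo for every cell

# indexing the unique times with the index array recovers the original times
assert np.array_equal(utime_demo[xtime_demo], ts_demo)
```

The dynamical model then only needs to be evaluated at the three entries of `utime_demo`, and the result is spread over all cells by indexing with `xtime_demo`.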

In [None]:
# subsample the dataset for faster experimentation

dataset = dataset.sample(20000, random_state=42)

# prepare tensors for model fitting: xs and ts (marker expression and time)

xs_raw = dataset.iloc[:, 0:-1].values
ts_raw = dataset["Timepoint"].values
celltypes = dataset["CellType"].values
celltypes_unique = np.unique(celltypes)
print(f"Unique cell types: {celltypes_unique}")

# torch and Pyro work with tensors: these are arrays with additional functionality
# to support GPU computations and automatic differentiation

xs_tensor = torch.tensor(xs_raw, dtype=torch.float32, device=DEVICE)
ts_tensor = torch.tensor(ts_raw, dtype=torch.float32, device=DEVICE)

# we need to know the data dimension and number of samples

data_dim = xs_tensor.shape[1]  # number of features
n_samples = xs_tensor.shape[0]

print(f"Number of features: {data_dim}, Number of samples: {n_samples}")

# create unique time tensor

utime_raw = np.sort(dataset["Timepoint"].unique())
utime_tensor = torch.tensor(utime_raw, dtype=torch.float32, device=DEVICE)

num_timepoints = len(utime_raw)
print(f"Number of unique time points: {num_timepoints}")

print(f"Unique time points: {utime_raw}")

# create time index tensor

xtime_raw = np.sum([i * (ts_raw == t) for i, t in enumerate(utime_raw)], axis=0)
xtime_tensor = torch.tensor(xtime_raw, dtype=torch.long, device=DEVICE)

# check that the time index tensor is correct

assert torch.all(utime_tensor[xtime_tensor] == ts_tensor)


## 2. Develop a conditional VAE

Using Pyro for building VAE models is relatively easy, as we don't have to manually calculate the ELBO, or worry about ["reparameterization tricks"](https://en.wikipedia.org/wiki/Reparameterization_trick).
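
For the curious, the idea behind the reparameterization trick can be sketched without Pyro. The example below is an illustration under simple assumptions (a one-dimensional Gaussian and the test function $f(z) = z^2$, both chosen here and not taken from the model): a sample is written as a deterministic function of the parameters and parameter-free noise, so gradients of an expectation can be estimated by averaging derivatives over the noise.

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 1.5, 0.8                 # hypothetical variational parameters
eps = rng.standard_normal(200_000)   # noise that does not depend on mu or sigma

# reparameterized sample: z ~ Normal(mu, sigma), written as a function of eps
z = mu + sigma * eps

# pathwise gradients of E[f(z)] for f(z) = z**2:
#   d/dmu    f(mu + sigma*eps) = 2*z
#   d/dsigma f(mu + sigma*eps) = 2*z*eps
grad_mu = np.mean(2 * z)
grad_sigma = np.mean(2 * z * eps)

# analytic values follow from E[z**2] = mu**2 + sigma**2
print(grad_mu, 2 * mu)
print(grad_sigma, 2 * sigma)
```

Pyro and PyTorch do this bookkeeping automatically for reparameterizable distributions, which is why we do not have to write such code ourselves.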

However, Pyro has a couple of quirks, leading to a slightly steep learning curve. We therefore provided a working VAE model to make this workshop feasible.

Here, we give an overview of the main components of the VAE model, and we do encourage you to have a look at the model files. They are given in the `scdynsys` folder:

* `dynamic_vae.py`: this contains the main Python class defining our VAE. It has a `model` and a `guide` method, representing the model (including the decoder) and the variational distribution (which Pyro refers to as a "guide").
* `dynamic_model.py`: this is the (solved) ODE model, which also has a `model` and `guide` method such that it can be integrated in the same variational inference loop.
* `mixture_model.py`: As the name suggests: this defines the GMM.
* `nets.py`: These are the encoder and decoder networks.

### Methods of VAEgmmdyn

```python
# the init method...
def __init__(self, data_dim, z_dim, hidden_dim, num_clus, ...):
    # initialize the PyroModule
    super().__init__()

    # define the encoder and decoder networks
    self.decoder_x = DecoderX(
        z_dim, 
        hidden_dim, 
        data_dim, 
    )
    self.encoder_z = CondEncoderZ(
        z_dim, 
        hidden_dim, 
        data_dim,
    )

    # define the mixture model
    self.mix = GaussMix(z_dim, num_clus, weighted=False)

    # define the dynamic model
    self.dynamic_model = DynamicModel(num_clus)


# the model method...
def model(self, x, xtime, utime, ...):
    # use GaussMix object to sample parameters for the mixture model
    clus_locs, clus_chol_fact = self.mix()
    # use the dynamic model to get time-dependent weights
    weights = self.dynamic_model(utime)[..., xtime, :]
    logweights = torch.log(weights + 1e-10)
            
    # plates indicate independence of samples
    with pyro.plate("unobserved", x.shape[-2]):
        # setup a mixture model...
        mix = dist.Categorical(logits=logweights)
        comp = dist.MultivariateNormal(clus_locs, scale_tril=clus_chol_fact)
        # and sample latent vectors.
        z = pyro.sample("latent", dist.MixtureSameFamily(mix, comp))

    # decode the latent code z
    x_loc, x_scale = self.decoder_x(z)
            
    with pyro.plate("xdata", x.shape[-2]):
        ## score reconstructed x against actual expression data
        pyro.sample("xobs", dist.Normal(x_loc, x_scale).to_event(1), obs=x)


# the guide method
def guide(self, x, xtime, utime, ...):
    # use the guide method of GaussMix
    clus_locs, clus_chol_fact = self.mix.guide()

    # and the guide method of dynamic_model 
    weights = self.dynamic_model.guide(utime)
        
    # use the encoder to get the parameters used to define q(z|x)
    z_loc, z_scale = self.encoder_z(x, utime[xtime])
    
    # sample latent vectors
    with pyro.plate("unobserved", x.shape[-2]):
        # sample the latent code z
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

```


### Methods of the `DynamicModel` class

This class samples initial values and rate parameters, and then 
solves the ODE model, and returns the time-dependent weights.
For more complex models, you'd have to use a numerical ODE integrator...

```python
def __init__(self, num_clus, ...):
    super().__init__()
        
    self.num_clus = num_clus
    
    # setup an auto-guide
    self.auto_guide = AutoMultivariateNormal(
        self, init_loc_fn=init_fn, init_scale=init_scale
    )
                
def forward(self, time: torch.Tensor) -> torch.Tensor:
    # use a typical symmetric Dirichlet prior with concentration 0.5
    halves = torch.full((self.num_clus,), 0.5, device=self.device)
    X0 = pyro.sample("X0", dist.Dirichlet(halves))

    # and sample the off-diagonal elements of Q
    shp_Q = (*X0.shape[:-1], self.num_clus-1, self.num_clus)
    rate_Q = 10*torch.ones(shp_Q, device=self.device)
    Qoffdiag = pyro.sample("Qoffdiag", dist.Exponential(rate_Q).to_event(2))

    # solve the ODE explicitly using the matrix exponential:
    #    Xt = exp(t*Q) * X0
    Xt = dynamic_model(time, X0, Qoffdiag)
    return Xt

def guide(self, time: torch.Tensor) -> torch.Tensor:
    # calls the auto-guide and returns the logweights using a HACK
    guide_trace = poutine.trace(self.auto_guide).get_trace(time)
    logweights = poutine.block(poutine.replay(self, guide_trace))(time)

    return logweights
```
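
The closed-form solution $X_t = \exp(tQ) X_0$ mentioned in the code above can be sketched in plain NumPy. This is only an illustration, not the implementation in `dynamic_model.py`: the helper `expm_taylor`, the example rates, and the convention that `Q[i, j]` is the rate from type `j` to type `i` (with the diagonal chosen so that every column sums to zero) are assumptions made for this sketch.

```python
import numpy as np

def expm_taylor(A, order=20, squarings=10):
    """Matrix exponential via scaling-and-squaring with a truncated Taylor series."""
    A_scaled = A / 2**squarings
    term = np.eye(A.shape[0])
    result = np.eye(A.shape[0])
    for k in range(1, order + 1):
        term = term @ A_scaled / k      # A_scaled**k / k!
        result = result + term
    for _ in range(squarings):
        result = result @ result        # undo the scaling
    return result

# hypothetical rates for 3 cell types: 0 -> 1 at rate 0.3, 1 -> 2 at rate 0.1
Q = np.zeros((3, 3))
Q[1, 0] = 0.3
Q[2, 1] = 0.1
Q -= np.diag(Q.sum(axis=0))   # diagonal = minus the total outflow of each type

X0 = np.array([0.7, 0.2, 0.1])             # initial fractions, sum to 1
times = np.array([0.0, 1.0, 5.0, 20.0])

# one row of fractions per time point
Xt = np.stack([expm_taylor(t * Q) @ X0 for t in times])
print(Xt.round(3))
```

Because the columns of `Q` sum to zero, the fractions keep summing to one at every time point. In practice you would use a library routine such as `torch.linalg.matrix_exp` or `scipy.linalg.expm` rather than a hand-written series.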

In this notebook (interactive python session), you might want to run cells multiple times. However, Pyro keeps track of your parameters in a so-called "parameter store". To reset all estimates, you can call the next cell:

In [None]:
# clear Pyro parameter store: erase any previously learned parameters

pyro.clear_param_store()

Next, we initialize the VAE model imported from `scdynsys/dynamic_vae.py`.

Printing the vae gives an overview of all the components. Can you figure out what they mean?

In [None]:
vae = VAEgmmdyn(
    data_dim = data_dim, 
    z_dim = 3, 
    hidden_dim = 12,
    num_clus = 4,
    time_scaling = 0.5, # scale time axis for cond. encoder
    use_cuda = USE_CUDA,
)

# print the model architecture:

print(vae)

### Fit the model

Next, we'll define a function that fits the model to the single-cell data. It takes an `SVI` object (defined below), which bundles the model and guide with the optimizer and the method used to compute the ELBO (Pyro provides multiple options).

Of course, we must also provide the data (including the time points) and the number of training epochs. The following function trains the model with mini-batches: each gradient step only uses `batch_size` cells. Training is fastest on a GPU, and may be slow on a CPU or a less capable computer.

In [None]:
def train_test_loop(
    svi: SVI,
    xs_raw: np.ndarray,
    xtime_raw: np.ndarray,
    utime_raw: np.ndarray,
    batch_size: int,
    num_epochs: int,
) -> list[tuple[int, float]]:
    from torch.utils.data import DataLoader

    # set up the data loader for mini-batch training

    kwargs = {'num_workers': 1, 'pin_memory': USE_CUDA}
    xs_tens = torch.tensor(xs_raw, dtype=torch.float32)
    xtime_tens = torch.tensor(xtime_raw, dtype=torch.long)
    raw_data_combined = list(zip(xs_tens, xtime_tens))

    data_loader = DataLoader(
        dataset=raw_data_combined,
        batch_size=batch_size, shuffle=True, **kwargs
    )

    # the model requires the total number of samples N 
    # and unique time points utime

    N = torch.tensor(len(data_loader.dataset), device=DEVICE)
    utime = torch.tensor(utime_raw, dtype=torch.float32, device=DEVICE)

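    # set up progress bar (import the submodule explicitly, so this also
    # works when `tqdm.notebook` has not been loaded elsewhere)
    from tqdm.notebook import trange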
    
    # return loss list for convergence tracking
    train_elbo = []
    for epoch in (pbar := trange(num_epochs)):
        epoch_loss = 0.0
        for xs, ts in data_loader:
            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(xs.to(DEVICE), ts.to(DEVICE), utime, N)
        total_epoch_loss = epoch_loss / len(data_loader)
        train_elbo.append((epoch, total_epoch_loss))
        pbar.set_description(f"average train loss: {total_epoch_loss:0.2f}")

    return train_elbo

### Now we can finally fit the model

We have to choose some hyperparameters like batch size, number of epochs, 
and the learning rate ("lr"). 

The class JitTrace_ELBO is used to automatically compute the ELBO (and gradient).
The "Jit" in the name stands for just-in-time and refers to "just-in-time compilation".
This is a method to speed up runtime without using a more low-level programming language.
HOWEVER: JIT requires very precise coding, and it is easy to break code that uses JIT.

You'll see some warnings (that you can ignore) that are a testament to this tedious 
JIT-compatible programming.

If `JitTrace_ELBO` does not work, try using `Trace_ELBO` instead. This is slower but more forgiving.

**Skip this step if you don't want to wait! We also included a pre-trained model**

In [None]:
# fit the model using mini-batch training

# training hyperparameters
BATCH_SIZE = 2**10
NUM_EPOCHS = 500

# setup the optimizer
adam_args = {"lr": 2e-3}
optimizer = Adam(adam_args)
# setup the inference algorithm
loss_method = JitTrace_ELBO(num_particles=10, vectorize_particles=True)
svi = SVI(vae.model, vae.guide, optimizer, loss=loss_method)

# START TRAINING!!!
elbo_loss = train_test_loop(
    svi,
    xs_raw,
    xtime_raw,
    utime_raw,
    BATCH_SIZE,
    NUM_EPOCHS
)

# plot the training loss curve
train_elbo_vals = [x[1] for x in elbo_loss]

fig, ax = plt.subplots(1, 1, figsize=(3, 3))
ax.plot(train_elbo_vals, linewidth=0.5, color='k', label='Train ELBO')
ax.legend()


In [None]:
# to save the trained model parameters, uncomment the following line:
# pyro.get_param_store().save("../data/vae_dyn_model_params.pt")

### Load a pre-trained model

As training might take 20 minutes, depending on your hardware,
we've included a pre-trained model. We do advise you to also run the training loop,
even for a smaller number of epochs. You'll see that the model does not always converge
to the correct result.

In [None]:
## load pre-fitted model

# HACK:
# register safe globals for loading, account for recent torch/pyro changes
torch.serialization.add_safe_globals([
    torch.distributions.constraints._Real,
    pyro.distributions.constraints._SoftplusPositive,
    pyro.distributions.constraints._UnitLowerCholesky,
])

# load the trained model parameters
store = pyro.get_param_store()
store.clear()
param_file = "../data/vae_dyn_model_params.pt"
store.load(param_file, map_location=DEVICE)

# load parameters into model
state_dict = vae.state_dict()
state_dict.update({k : store[k] for k in state_dict if k in store})
vae.load_state_dict(state_dict)

## 3. Visualize the results

We'll first have a look at the latent space embedding the model came up with. As our latent space is 3D, we can easily project this to 2D graphics, using PCA to get the most informative projection first (i.e. the first 2 PCs).

We can also use the VAE model to classify cells: the GMM components correspond to clusters, and for each cell we can probabilistically sample a cluster that the cell belongs to. This is probabilistic, as the GMM components have overlap, and so for intermediate cell states, we can't be quite sure which cluster the cell belongs to.
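
To illustrate why the assignment is probabilistic, here is a minimal sketch for a one-dimensional mixture with two overlapping components. All numbers and the helper `responsibilities` are hypothetical, and `vae.classifier` may well be implemented differently; the point is only that a cell halfway between two cluster centers gets comparable probabilities for both clusters.

```python
import numpy as np

rng = np.random.default_rng(1)

# hypothetical 1D mixture with two overlapping components
locs = np.array([-1.0, 1.0])
scales = np.array([1.0, 1.0])
weights = np.array([0.5, 0.5])   # in the VAE these depend on the time point

def responsibilities(z):
    """Posterior probability of each component given latent position z."""
    log_dens = -0.5 * ((z - locs) / scales) ** 2 - np.log(scales)
    log_post = np.log(weights) + log_dens
    log_post -= log_post.max()           # stabilize before exponentiating
    post = np.exp(log_post)
    return post / post.sum()

p_mid = responsibilities(0.0)   # halfway between the two centers
p_far = responsibilities(3.0)   # well inside component 1

print(p_mid, p_far)

# a probabilistic classifier samples the cluster instead of taking the argmax
sampled = rng.choice(len(weights), size=5, p=p_mid)
print(sampled)
```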

One important issue with GMMs is that due to random initialization of the NN weights and the centers of the GMM components, we can't know a priori which cluster (0, 1, 2, 3) is going to belong to which cell type (T1, T2, MZ, FM). 
At this point, we might look at the mean marker expression of clusters, and link clusters to meaningful cell types. In our case with simulated data, we can just compare the inferred cluster with the ground truth cell type, and come up with a matching between them. For example: 0 -> T2, 1 -> MZ, 2 -> T1, 3 -> FM.

For visualization in cells below, we define a `permutation` such that we can re-order model output to match with the ground truth.
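
The matching step can be sketched on a toy example. The labels below are made up, and the variable names are hypothetical; the sketch uses the same majority-vote idea as the next cell, and adds a check that the matching is one-to-one. If two clusters were matched to the same cell type, the permutation would not be a valid relabeling.

```python
import numpy as np

# hypothetical ground-truth labels and inferred clusters for eight cells
true_types = np.array(["FM", "FM", "MZ", "T1", "T1", "T2", "T2", "MZ"])
clusters = np.array([3, 3, 1, 2, 2, 0, 0, 1])

type_names, type_idx = np.unique(true_types, return_inverse=True)
n_clus = clusters.max() + 1

# contingency table: rows are clusters, columns are cell types
counts = np.zeros((n_clus, len(type_names)), dtype=int)
np.add.at(counts, (clusters, type_idx), 1)

# majority vote per cluster
best = counts.argmax(axis=1)
translation = {k: str(type_names[best[k]]) for k in range(n_clus)}
print(translation)

# the matching is only a valid relabeling if every cell type is used exactly once
is_one_to_one = len(set(best.tolist())) == len(type_names)
print("one-to-one:", is_one_to_one)
```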


In [None]:
## infer cell types 

# get the latent space representation of all cells
zs_tensor = vae.dimension_reduction(xs_tensor, ts_tensor).cpu()
zs_raw = zs_tensor.numpy()

# do PCA on the latent space
pca = PCA(n_components=3)
pca.fit(zs_raw)
zs_pca = pca.transform(zs_raw)

# the VAEgmmdyn class has a classifier method to assign clusters to cells
with torch.no_grad():
    clus = vae.classifier(xs_tensor, ts_tensor)
clus = clus.cpu().numpy()

# find the ground truth cell types that are most represented in each inferred cluster
# make a dictionary to translate cluster indices to cell type names
translation_dict = {}

for clus_idx in range(vae.num_clus):
    # find overlap between inferred clusters and ground truth cell types
    celltype_counts = {
        ct : np.sum((clus == clus_idx) & (celltypes == ct))
        for ct in celltypes_unique
    }
    # find the most common cell type in this cluster
    most_common_ct = max(celltype_counts, key=celltype_counts.get)
    count = celltype_counts[most_common_ct]
    print(f"Cluster {clus_idx}: most common cell type: {most_common_ct} ({count} cells)")
    translation_dict[clus_idx] = most_common_ct

# map inferred cluster indices to cell type names
clus_names = [translation_dict[clus_idx] for clus_idx in clus]

# define a permutation of cluster indices based on alphabetical 
# order of cell type names
permutation = sorted(translation_dict.keys(), key=lambda k: translation_dict[k])

# re-index clusters to have consistent colors
new_clus = np.zeros_like(clus)
for new_idx, clus_idx in enumerate(permutation):
    new_clus[clus == clus_idx] = new_idx
clus = new_clus

# define a color for each cell based on inferred cell type
cols = [cell_type_colors[nm] for nm in clus_names]

# visualize inferred cell types in PCA space
fig, axs = plt.subplots(2, 2, figsize=(6, 6))

ax = axs[0, 0]
ax.scatter(zs_pca[:, 0], zs_pca[:, 1], s=3, linewidths=0, c=cols)
ax.set(xlabel='PCA 1', ylabel='PCA 2')
ax = axs[1, 0]
ax.scatter(zs_pca[:, 0], zs_pca[:, 2], s=3, linewidths=0, c=cols)
ax.set(xlabel='PCA 1', ylabel='PCA 3')
ax = axs[0, 1]
ax.scatter(zs_pca[:, 2], zs_pca[:, 1], s=3, linewidths=0, c=cols)
ax.set(xlabel='PCA 3', ylabel='PCA 2')

axs[1, 1].set_visible(False)

fig.tight_layout()

### Inspect parameter estimates

Next, we'll check what parameter values we actually estimated.
The dynamical model is quite simple and contains only initial values ($X_0$)
and differentiation rates $Q_{ij}$ with $i \neq j$ (called `Qoffdiag` in the model).

We will use the `Predictive` class from Pyro to simulate our model, and sample from the posterior.

In [None]:
## sample Q and X0 values from the fitted model

from pyro.infer import Predictive

# setup the Predictive object to sample from the posterior
predictive = Predictive(
    vae.model, guide=vae.guide, num_samples=100, 
    return_sites=["X0", "Qoffdiag"],
    parallel=True
)

# sample from the posterior distribution
samples = predictive(
    xs_tensor, 
    xtime_tensor, 
    utime_tensor, 
    N=torch.tensor(n_samples, device=DEVICE)
)

# extract Q and X0 samples
Qoffdiag_samples = samples["Qoffdiag"].cpu().numpy().squeeze()
weights_samples = samples["X0"].cpu().numpy().squeeze()

# apply the permutation to the Q matrices and weights
weights_samples_permuted = weights_samples[:, permutation]

# (we'll do this permutation for Q matrices below)


Plot the values of $X_0$ (the initial conditions)
together with the ground truth.

In [None]:
fig, ax = plt.subplots(figsize=(2, 2))

# plot violin plots of inferred initial weights
violins = ax.violinplot(
    weights_samples_permuted, 
    showmeans=False, showextrema=False
)
for pc in violins['bodies']:
    pc.set(facecolor='k', alpha=1)

ax.set(xlabel="cell type", ylabel="Fraction", 
       title="Initial conditions $X_0$")

# compare with true weights at time 0
cts_t0 = dataset[dataset['Timepoint'] == utime_raw[0]]['CellType'].values

true_weights = [
    np.sum(cts_t0 == ct) / len(cts_t0) 
    for ct in sorted(celltypes_unique)
]

pos = np.arange(1, vae.num_clus + 1)
ax.scatter(pos, true_weights, s=10, color='red', 
           label='True Weights', zorder=1)

ax.set_xticks(pos)
ax.set_xticklabels(celltypes_unique)

pass

### Plot the Q matrix

We've saved the ground truth differentiation rates in a file in the data folder.
Let's see how well we were able to reconstruct these values.

In [None]:
from scdynsys.dynamic_model import build_Q_mat

# import ground truth Q matrix

file_Q = '../data/simulated_time_series_Q_matrix.csv'
Q_gt = pd.read_csv(file_Q, index_col=0).values

# set diagonal entries to zero for better visualization
np.fill_diagonal(Q_gt, 0.0)

# plot the inferred Q
Qoffdiag_samples = samples["Qoffdiag"].cpu().squeeze()
Q_samples = build_Q_mat(Qoffdiag_samples).numpy()

# apply permutation to Q matrices
# (to match inferred clusters to ground truth cell types)
Q_samples_permuted = Q_samples[:, permutation, :][:, :, permutation]

meanQ = np.mean(Q_samples_permuted, axis=0)
# set diagonal entries to zero 
np.fill_diagonal(meanQ, 0)

fig, axs = plt.subplots(1, 2, figsize=(6, 3))
ax = axs[0]
im = ax.pcolor(meanQ, cmap='hot', vmin=0, vmax=np.max(Q_gt))
fig.colorbar(im, ax=ax)

# plot the ground truth Q matrix for comparison
ax = axs[1]
im = ax.pcolor(Q_gt, cmap='hot', vmin=0, vmax=np.max(Q_gt))
fig.colorbar(im, ax=ax)

axs[0].set_title("Inferred Q matrix")
axs[1].set_title("Ground Truth Q matrix")

# set axis labels and ticks: cell type names
for ax in axs:
    ax.set_xlabel("From Cluster")
    ax.set_ylabel("To Cluster")
    ticks = np.arange(0.5, vae.num_clus + 0.5)
    ax.set(xticks=ticks, yticks=ticks)
    ax.set(xticklabels=(ctu:=celltypes_unique), yticklabels=ctu)

fig.tight_layout()


### Timecourse of latent space

Next, we will plot the latent cell distributions as a function of time.

In [None]:
## plot latent space embeddings for each time point

zs_tensor = vae.dimension_reduction(xs_tensor, ts_tensor).cpu()
zs_raw = zs_tensor.numpy()

# do PCA on the latent space
pca = PCA(n_components=3)
pca.fit(zs_raw)
zs_pca = pca.transform(zs_raw)

# plot PCA 1 and 2 in different panels as a function of time
fig, axs = plt.subplots(
    3, num_timepoints, figsize=(3*num_timepoints, 9), 
    sharey=True, sharex=True
)

# prepare colors for splitting by time point
cols_array = np.array(cols)

for t_idx in range(num_timepoints):
    ax1, ax2, ax3 = axs[:, t_idx]
    time_mask = (xtime_raw == t_idx)
    ax1.scatter(zs_pca[time_mask, 0], zs_pca[time_mask, 1], s=2, alpha=1,
                   linewidths=0, c=cols_array[time_mask])
    ax2.scatter(zs_pca[time_mask, 0], zs_pca[time_mask, 2], s=2, alpha=1,
                   linewidths=0, c=cols_array[time_mask])
    ax3.scatter(zs_pca[time_mask, 2], zs_pca[time_mask, 1], s=2, alpha=1,
                   linewidths=0, c=cols_array[time_mask])
    ax1.set_title(f'Time {utime_raw[t_idx]}')
    ax1.set(xlabel='PCA 1', ylabel='PCA 2')
    ax2.set(xlabel='PCA 1', ylabel='PCA 3')
    ax3.set(xlabel='PCA 3', ylabel='PCA 2')


### Plot the inferred trajectories

In addition to the time-dependent latent space plot above,
we can also look how well the inferred trajectories correspond with
the ground truth. We saved the simulated trajectories in a file
in the data folder.

The `VAEgmmdyn` class provides a method to simulate trajectories.

In [None]:
# import ground truth weights for comparison

file = '../data/simulated_time_series_ground_truth_weights.csv'
ground_truth_weights = pd.read_csv(file)

# simulate trajectories from the fitted model
ws_samples = vae.sample_trajectories(
    n=100, ts=utime_tensor, 
)

# convert to numpy array
ws_samples_raw = ws_samples.cpu().numpy()

# apply permutation to the weights
ws_samples_raw = ws_samples_raw[:, :, permutation]

# compute mean and credible intervals
mean_ws = np.mean(ws_samples_raw, axis=0)
lower_ws = np.percentile(ws_samples_raw, 5, axis=0)
upper_ws = np.percentile(ws_samples_raw, 95, axis=0)

fig, ax = plt.subplots(figsize=(6, 4))
for i in range(vae.num_clus):
    ct = translation_dict[permutation[i]]
    ct_color = cell_type_colors[ct]
    ax.plot(utime_raw, mean_ws[:, i], label=ct, color=ct_color)
    ax.fill_between(utime_raw, lower_ws[:, i], upper_ws[:, i], 
                    color=ct_color, alpha=0.3)

    ax.scatter(ground_truth_weights['Timepoint'], 
               ground_truth_weights[f'Cluster_{i}_Weight'],
               color=ct_color, marker='x')
    
ax.set(xlabel="Time", ylabel="Cluster Weights")
ax.set_title("Inferred GMM Cluster Weights over Time")

# add ground truth legend
ax.scatter([], [], color='k', marker='x', label='Ground Truth')

ax.legend(fontsize='small', loc='upper left')

### Show differentiation pathways between clusters

Next, we will give a visual representation of how cells develop. For this, we will plot the cluster centers (locations) in the latent space (after applying PCA),
together with arrows indicating differentiation between the cell types.

This is essentially the same information as in the Q matrix plots above,
but if you had more clusters, you could investigate how cell populations
differentiate to neighboring (or distant) other cell types. 
Here distance is defined in the latent cell state space.

In [None]:
# make a diagram with differentiation rates and
# the cluster locations in latent space

loc_pred = Predictive(vae.mix, guide=vae.mix.guide,
                      num_samples=100, parallel=True, 
                      return_sites=["clus_locs"])
loc_samples = loc_pred()["clus_locs"].cpu().numpy()

# apply the PCA transformation to the loc samples
locs_pca = pca.transform(loc_samples.reshape(-1, vae.z_dim))
locs_pca = locs_pca.reshape(-1, vae.num_clus, vae.z_dim)

# create the plot
fig, axs = plt.subplots(1, 2, figsize=(8, 4))

# plot the estimated arrows between the cluster centers in PCA space
ax = axs[0]

# aux function to plot locs in PCA space
def plot_locs(ax, locs_pca):
    for i in range(vae.num_clus):
        ct = translation_dict[permutation[i]]
        ct_color = cell_type_colors[ct]
        ax.scatter(locs_pca[:, i, 0], locs_pca[:, i, 1], s=5, alpha=0.5,
                color=ct_color, label=ct)
        
    ax.set(xlabel="PCA 1", ylabel="PCA 2")
    ax.legend(fontsize='small')

plot_locs(axs[0], locs_pca)

# plot arrows between the locs according to the Q matrix
Q_mean = Q_samples_permuted.mean(axis=0)

# aux function to plot arrows
def plot_arrows(ax, Q_mean, locs_pca, threshold=0.05):
    for i in range(vae.num_clus):
        for j in range(vae.num_clus):
            if i == j or Q_mean[i, j] <= threshold:
                continue
            # Q_mean[i, j] is drawn as an arrow from cluster j to cluster i
            # (annotate draws the arrow from xytext to xy)
            target = np.mean(locs_pca[:, i, 0]), np.mean(locs_pca[:, i, 1])
            source = np.mean(locs_pca[:, j, 0]), np.mean(locs_pca[:, j, 1])
            # make sure that the width of the arrows
            # is proportional to the Q values
            ap = dict(
                arrowstyle="->", color='k', lw=Q_mean[i, j]*5,
                shrinkA=10, shrinkB=10, connectionstyle="arc3,rad=0.1"
            )
            # plot the arrow
            ax.annotate(
                "",
                xy=target, xycoords='data',
                xytext=source, textcoords='data',
                arrowprops=ap
            )

plot_arrows(axs[0], Q_mean, locs_pca)

# now plot the ground truth arrows

plot_locs(axs[1], locs_pca)
plot_arrows(axs[1], Q_gt, locs_pca)

axs[0].set_title("inferred differentiation rates")
axs[1].set_title("ground truth differentiation rates")

fig.tight_layout()

## 4. Restrict possible differentiation pathways

As you can see in the above diagram, the model predicts differentiation pathways that we did not put into the simulated data. Our fitted model allows for $4\times 3 = 12$ possible differentiation pairs (even loops).

This freedom might not be biologically plausible, and maybe you have some information to restrict some of these pathways. The goal is to put such restrictions in the model. 

The simplest way to do this is to multiply the matrix $Q$ with a binary "mask" matrix $M$, with a 1 indicating that a path is allowed and a 0 that it is not feasible. Note that the multiplication of $Q$ and $M$ is elementwise.

1. Modify the dynamical model to accept a mask matrix $M$.
2. Add the mask $M$ to the initialization function of the VAE, and make sure it finds its way to the right functions.
3. Figure out which index pair of $M$ corresponds with a cluster index in the model. What happens if you reset parameters and fit the restricted model directly?
4. Experiment with first fitting the unrestricted model ($M_{ij} = 1$), and then re-fitting with a different $M$. 
5. Compare different restricted models in terms of ELBO and goodness of fit (visually). Can you find/select the best fitting model?
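
As a starting point, the elementwise masking described above can be sketched in NumPy. This is a hint, not the solution: the rates are made up, the convention that entry `[i, j]` is the rate from cluster `j` to cluster `i` is an assumption, and the sketch works on a full square matrix, whereas `Qoffdiag` in the model has a different shape. Mapping the mask onto `Qoffdiag` is part of the exercise.

```python
import numpy as np

# hypothetical off-diagonal rates for 3 clusters;
# assumption: Q_rates[i, j] is the rate from cluster j to cluster i
Q_rates = np.array([
    [0.0, 0.2, 0.1],
    [0.3, 0.0, 0.4],
    [0.5, 0.6, 0.0],
])

# binary mask: 1 = pathway allowed, 0 = pathway forbidden
# here only 0 -> 1 and 1 -> 2 are allowed
M = np.zeros((3, 3))
M[1, 0] = 1
M[2, 1] = 1

# elementwise (Hadamard) product, not a matrix product
Q_masked = Q_rates * M

# rebuild the diagonal so that every column sums to zero
Q_full = Q_masked - np.diag(Q_masked.sum(axis=0))
print(Q_full)
```

With PyTorch tensors the same `*` operator performs the elementwise product, so the mask can be stored as a tensor on the same device as the rates.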

In [None]:
# your code here