In [None]:
import os 
import pytorch_lightning as pl
import seml
import numpy as np
from tqdm import tqdm
import torch

from sacred import SETTINGS, Experiment
from functools import partial

import scanpy as sc
import scvelo as scv
import cellrank as cr
import pandas as pd

from torchdyn.core import NeuralODE

from scCFM.datamodules.sc_datamodule import scDataModule
from scCFM.datamodules.time_sc_datamodule import TrajectoryDataModule
from scCFM.models.base.vae import VAE, AE
from scCFM.models.cfm.cfm_module import CFMLitModule
from scCFM.models.cfm.components.simple_mlp import VelocityNet
 
from conditional_flow_matching import *

from torch.optim import AdamW

import yaml

import sys 
sys.path.insert(0, "../../../" )
from paths import EXPERIMENT_FOLDER

from conditional_flow_matching import *

## Import configurations

In [184]:
# Read the autoencoder experiment configuration from YAML.
# Only the "fixed" section of the file is used in the rest of the notebook.
ae_config_path = "/nfs/homedirs/pala/scCFM/configs/ae/eb/config.yaml"
with open(ae_config_path, "r") as stream:
    full_config = yaml.safe_load(stream)
hparams_ae = full_config["fixed"]

In [185]:
class torch_wrapper(torch.nn.Module):
    """Adapter exposing a model that takes ``[x, t]`` rows through an ``f(t, x)`` call.

    The wrapped model receives a single tensor whose last column is the time,
    while the caller passes time and state as separate arguments.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        """Append ``t`` as an extra column to ``x`` and evaluate the wrapped model."""
        n_samples = x.shape[0]
        # Repeat the time value once per row and make it a column vector.
        time_column = t.repeat(n_samples)[:, None]
        model_input = torch.cat([x, time_column], 1)
        return self.model(model_input)


## Initialize and train/load autoencoder 

In [186]:
# When True, the training branch below is skipped and the weights are loaded from the checkpoint.
pretrained_ae = True
# Checkpoint file passed to torch.load when pretrained_ae is True.
pretrained_ckpt_ae = "/nfs/students/pala/scCFM/experiments/ae/eb_vae/checkpoints/epoch_0033.ckpt"

In [187]:
task_name = hparams_ae["training.training"]["task_name"]

# Fix seed for reproducibility.
# Compare against None instead of relying on truthiness so that seed=0 is honoured,
# and never hand None to a seeding call. seed_everything also seeds torch, which
# makes a separate torch.manual_seed call redundant.
seed = hparams_ae["training.training"]["seed"]
if seed is not None:
    pl.seed_everything(seed, workers=True)

# Initialize folder
current_experiment_dir = EXPERIMENT_FOLDER / "ae" / task_name
current_experiment_dir.mkdir(parents=True, exist_ok=True)

# Initialize datamodule
datamodule = scDataModule(**hparams_ae["datamodule.datamodule"])

# Initialize the model
ae_model = VAE(in_dim=datamodule.dim, **hparams_ae["model.model"])

if not pretrained_ae:
    # The Lightning classes are reached through the already-imported `pl` namespace;
    # the bare names (ModelCheckpoint, EarlyStopping, WandbLogger, Trainer) are not
    # imported anywhere in this notebook.

    # Checkpointing callback
    model_ckpt_callbacks = pl.callbacks.ModelCheckpoint(
        dirpath=current_experiment_dir / "checkpoints",
        **hparams_ae["model_checkpoint.model_checkpoint"],
    )

    # Early-stopping callback
    early_stopping_callbacks = pl.callbacks.EarlyStopping(
        **hparams_ae["early_stopping.early_stopping"]
    )

    # Initialize logger
    logger = pl.loggers.WandbLogger(
        save_dir=current_experiment_dir / "logs",
        **hparams_ae["logger.logger"],
    )

    # Initialize the lightning trainer
    trainer = pl.Trainer(
        default_root_dir=current_experiment_dir,
        callbacks=[model_ckpt_callbacks, early_stopping_callbacks],
        logger=logger,
        **hparams_ae["trainer.trainer"],
    )

    # Fit the model
    trainer.fit(
        model=ae_model,
        train_dataloaders=datamodule.train_dataloader(),
        val_dataloaders=datamodule.val_dataloader(),
    )
    train_metrics = trainer.callback_metrics

else:
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(pretrained_ckpt_ae, map_location="cpu")
    ae_model.load_state_dict(checkpoint["state_dict"])

**Analyze autoencoder latent space**

In [188]:
# Encode every batch of the training dataloader and keep the matching "cond" annotation.
# Switch to eval mode first so that any train-time-only layers in the autoencoder
# (dropout, batch-norm statistics updates, if present) do not affect the embedding.
ae_model.eval()

z_cells = []
annot = []
with torch.no_grad():
    for batch in datamodule.train_dataloader():
        annot.append(batch["cond"])
        # The "z" entry of the encoder output is what gets stored as the latent code.
        z_cells.append(ae_model.encode(batch["X"])["z"])

z_cells = torch.cat(z_cells, dim=0)
# .cpu() before .numpy() so the conversion also works if the batches live on a GPU.
annot = pd.DataFrame(torch.cat(annot).cpu().numpy(), columns=["experimental_time"])

# Latent codes as the data matrix, experimental time as the per-cell annotation.
adata_latent = sc.AnnData(X=z_cells.cpu().numpy(), obs=annot)

In [189]:
# Embed the latent cells for visualisation: PCA, then the neighbor graph, then UMAP.
sc.tl.pca(adata_latent)
sc.pp.neighbors(adata_latent)
sc.tl.umap(adata_latent)

In [190]:
# UMAP of the latent cells, colored by the "experimental_time" annotation.
sc.pl.umap(adata_latent, color="experimental_time")

## Perform CFM training, following the reference conditional-flow-matching example notebook

In [191]:
class MLP(torch.nn.Module):
    """Fully connected network with three hidden SELU layers of width ``w``.

    Args:
        dim: number of state features in the input.
        out_dim: number of output features; falls back to ``dim`` when ``None``.
        w: width of each hidden layer.
        time_varying: when ``True`` the first layer expects one extra input column.
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        out_features = dim if out_dim is None else out_dim
        # One additional input column when the network is time-conditioned.
        in_features = dim + 1 if time_varying else dim
        layers = [
            torch.nn.Linear(in_features, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, out_features),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.net(x)

First, we standardize the latent coordinates (zero mean and unit standard deviation per dimension).

In [192]:
# Distinct experimental time points, in increasing order.
times = sorted(adata_latent.obs["experimental_time"].unique())
n_times = len(times)

# Standardize coordinates (per latent dimension).
coords = adata_latent.X
coords_std = coords.std(axis=0)
# A constant latent dimension has zero standard deviation; dividing by it would
# fill that column with NaN/inf. Dividing by 1 instead leaves the centred
# (all-zero) column in place.
coords_std[coords_std == 0] = 1.0
coords = (coords - coords.mean(axis=0)) / coords_std
adata_latent.layers["X_standardized"] = coords

# One standardized matrix per time point, ordered like `times`.
# The mask is converted to a plain NumPy array before indexing the layer.
time_labels = adata_latent.obs["experimental_time"].to_numpy()
X = [adata_latent.layers["X_standardized"][time_labels == t] for t in times]

In [193]:
# Shape of the standardized matrix of each time point (one entry per element of X)
[i.shape for i in X]

In [194]:
# Run on the GPU when one is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

batch_size = 256
sigma = 0.1
# Take the state dimension from the standardized latent matrices instead of
# hard-coding it, so the velocity network always matches the autoencoder's latent size.
dim = X[0].shape[1]

# Time-conditioned velocity network (state columns plus one time column as input).
model = MLP(dim=dim, time_varying=True, w=64).to(device)
# score_model = MLP(dim=dim, time_varying=True, w=64)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# FM = ConditionalFlowMatcher(sigma=sigma)
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

Batch collecting function

In [195]:
def get_batch(FM, X, batch_size, n_times, return_noise=False):
    """Construct a batch with points from each pair of consecutive time points.

    For every ``t_start`` in ``range(n_times - 1)``, ``batch_size`` rows are drawn
    (with ``np.random.randint``, i.e. with replacement) from ``X[t_start]`` and from
    ``X[t_start + 1]`` and handed to ``FM.sample_location_and_conditional_flow``.
    The sampled time is shifted by ``t_start`` before the per-pair results are
    concatenated.

    Returns:
        ``(t, xt, ut)``, or ``(t, xt, ut, noises)`` when ``return_noise`` is true.
    """

    def _draw(population):
        # Random rows of one time point, as a float tensor on the module-level `device`.
        rows = np.random.randint(population.shape[0], size=batch_size)
        return torch.from_numpy(population[rows]).float().to(device)

    ts, xts, uts, noises = [], [], [], []
    for t_start in range(n_times - 1):
        # Source is drawn before target, which fixes the order of the random draws.
        x0 = _draw(X[t_start])
        x1 = _draw(X[t_start + 1])
        sampled = FM.sample_location_and_conditional_flow(x0, x1, return_noise=return_noise)
        if return_noise:
            t, xt, ut, eps = sampled
            noises.append(eps)
        else:
            t, xt, ut = sampled
        ts.append(t + t_start)
        xts.append(xt)
        uts.append(ut)

    batch = (torch.cat(ts), torch.cat(xts), torch.cat(uts))
    if return_noise:
        return batch + (torch.cat(noises),)
    return batch

Train the OT-CFM model

In [None]:
# Training budget and how often the displayed loss is refreshed.
n_iterations = 10000
log_every = 20

# Make sure the network is in training mode even if this cell is re-run after
# an evaluation cell switched it to eval mode.
model.train()

progress = tqdm(range(n_iterations))
for i in progress:
    optimizer.zero_grad()
    t, xt, ut = get_batch(FM, X, batch_size, n_times)
    # The network takes the state with the time appended as the last column.
    vt = model(torch.cat([xt, t[:, None]], dim=-1))
    # Mean squared error between the predicted and the target conditional flow.
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()
    if i % log_every == 0:
        # Show the scalar loss on the progress bar instead of printing the
        # graph-attached tensor repr hundreds of times.
        progress.set_postfix(loss=loss.item())

**Check the streamplots**

In [None]:
def add_velocity_to_adata(adata, model, xkey=None, batch_size=4096):
    """Evaluate ``model`` on every cell and store the result in ``adata.layers["velocity"]``.

    Each cell's coordinates are concatenated with its ``obs["experimental_time"]``
    value (as the last column) and passed through ``model``.

    Args:
        adata: object exposing ``X``, ``obs["experimental_time"]`` and ``layers``.
        model: ``torch.nn.Module`` taking rows of ``[coordinates, time]``.
        xkey: name of the layer to read coordinates from; ``None`` (default) reads
            ``adata.X``. Pass the layer the model was trained on if it differs
            from ``adata.X``.
        batch_size: number of cells evaluated per forward pass.

    Returns:
        None. ``adata.layers["velocity"]`` is written in place.
    """
    # Put model in evaluation mode
    model.eval()

    # Run on whatever device the model lives on instead of assuming CUDA;
    # a parameter-free model is evaluated on the CPU.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    coords = adata.X if xkey is None else adata.layers[xkey]
    coords = np.asarray(coords, dtype=np.float32)
    # Read the times positionally as an array; integer lookups on the obs index
    # would depend on the index labels.
    obs_times = adata.obs["experimental_time"].to_numpy(dtype=np.float32)

    velocities = []
    with torch.no_grad():
        for start in range(0, coords.shape[0], batch_size):
            stop = start + batch_size
            x = torch.from_numpy(coords[start:stop]).to(model_device)
            t = torch.from_numpy(obs_times[start:stop]).to(model_device)
            dx_dt = model(torch.cat([x, t[:, None]], dim=1))
            velocities.append(dx_dt.cpu().numpy())

    # With no cells there is nothing to concatenate; store an empty matrix instead.
    if velocities:
        adata.layers["velocity"] = np.concatenate(velocities, axis=0)
    else:
        adata.layers["velocity"] = np.empty_like(coords)

In [None]:
# Evaluate the trained velocity model on every latent cell; the result is written
# to adata_latent.layers["velocity"].
# NOTE(review): the model was trained on the "X_standardized" layer, while this call
# evaluates it on adata_latent.X — confirm which coordinates are intended.
add_velocity_to_adata(adata_latent, model)

In [None]:
# Keep a copy of the latent matrix as a layer so it can be referenced by key
# (it is passed as xkey="X_latent" to the velocity kernel).
adata_latent.layers["X_latent"] = adata_latent.X.copy()

In [None]:
# Build a CellRank velocity kernel from the "X_latent" and "velocity" layers and
# compute its transition matrix.
# NOTE(review): "X_latent" holds unstandardized coordinates while the velocity model
# was trained on standardized ones — confirm the two are meant to be combined.
vk = cr.kernels.VelocityKernel(adata_latent,
                          xkey="X_latent", 
                        vkey="velocity").compute_transition_matrix()

# Project the transitions onto the UMAP basis
# (presumably stored so that vkey="T_fwd" can find it when plotting — confirm for the installed CellRank version).
vk.compute_projection(basis="umap")

In [None]:
# Stream plot of the projected velocities on the UMAP basis, colored by experimental time.
scv.pl.velocity_embedding_stream(adata_latent, vkey="T_fwd", basis="umap", color="experimental_time")

**Propagate with neural ODE**

In [None]:
# Wrap the trained velocity network so it can be integrated as an ODE with signature f(t, x).
node = NeuralODE(
    torch_wrapper(model),
    solver="dopri5",
    sensitivity="adjoint",
    atol=1e-4,
    rtol=1e-4,
)

# Start from the cells of the first time point and keep them as the first snapshot.
X_pf = torch.from_numpy(X[0])
trajs = [X_pf.unsqueeze(0)]
with torch.no_grad():
    for t in range(n_times - 1):
        # Integrate over one unit of time, from t to t + 1, on a 400-point grid.
        t_span = torch.linspace(t, t + 1, 400)
        traj = node.trajectory(X_pf.float().to(device), t_span=t_span).cpu()
        # Only the final state is kept; it is the starting point of the next interval.
        X_pf = traj[-1]
        trajs.append(X_pf.unsqueeze(0))

# Stack the snapshots along a new leading axis.
trajs = torch.cat(trajs, dim=0)

In [None]:
# Shape of the last pushed-forward state
X_pf.shape

In [None]:
# trajs stacks one snapshot per time point along axis 0 and the cells along axis 1.
n_snapshots, n_cells = trajs.shape[0], trajs.shape[1]

# Flatten snapshots and cells into a single row axis.
X_pf = trajs.reshape(n_snapshots * n_cells, -1)

# Label every row with the index of its snapshot. The number of snapshots is read
# from trajs rather than hard-coded, so any number of time points works.
times = torch.arange(n_snapshots).unsqueeze(1).expand(n_snapshots, n_cells).ravel()
# Convert to NumPy explicitly so the column holds plain integers.
times = pd.DataFrame(times.numpy(), columns=["experimental_time"])

In [None]:
# Sanity check: one time label per flattened trajectory row
times.shape

In [None]:
# Pushed-forward states as the data matrix, their time labels as the per-row annotation.
adata_pf = sc.AnnData(X=X_pf.cpu().numpy(), 
                     obs=times)

In [None]:
# Embed the pushed-forward states: PCA, then the neighbor graph, then UMAP.
sc.tl.pca(adata_pf)
sc.pp.neighbors(adata_pf)
sc.tl.umap(adata_pf)

In [None]:
# UMAP of the pushed-forward states, colored by their time label.
sc.pl.umap(adata_pf, color="experimental_time")

Co-embed the observed latent cells and the pushed-forward states

In [None]:
# The pushed-forward states start from, and evolve in, the standardized coordinates,
# so the observed cells must be taken from the same standardized layer; mixing in
# the raw adata_latent.X would compare two different coordinate systems.
X_observed = adata_latent.layers["X_standardized"]
# Convert the torch tensor explicitly before concatenating with NumPy data.
X_simulated = X_pf.cpu().numpy()
X_total = np.concatenate([X_observed, X_simulated], axis=0)

# "True" marks observed cells, "False" marks pushed-forward states (same row order as X_total).
dataset_type = pd.DataFrame(
    {"Dataset_type": ["True"] * X_observed.shape[0] + ["False"] * X_simulated.shape[0]}
)

In [None]:
# Joint object holding observed and pushed-forward rows.
# Note: this rebinds adata_pf, replacing the trajectory-only object created earlier.
adata_pf = sc.AnnData(X=X_total, 
                     obs=dataset_type)

In [None]:
# Embed the joint data set: PCA, then the neighbor graph, then UMAP.
sc.tl.pca(adata_pf)
sc.pp.neighbors(adata_pf)
sc.tl.umap(adata_pf)

In [None]:
# Joint UMAP colored by origin ("True" = observed cell, "False" = pushed-forward state).
sc.pl.umap(adata_pf, color="Dataset_type")

In [None]:
~ZC ZZ                                