In [1]:
import os 
import pytorch_lightning as pl
import seml
import numpy as np
import torch
from sacred import SETTINGS, Experiment
from functools import partial

import pandas as pd
import anndata

import scanpy as sc
import scvelo as scv
import cellrank as cr

from PerturbSeq_CMV.datamodules.distribution_datamodule import TrajectoryDataModule
from PerturbSeq_CMV.models.cfm_module import CFMLitModule
from PerturbSeq_CMV.models.components.augmentation import AugmentationModule
from PerturbSeq_CMV.models.components.simple_mlp import VelocityNet

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from torch.optim import AdamW

import yaml

import sys 
sys.path.insert(0, "../../../" )
from paths import EXPERIMENT_FOLDER

## Utility functions

In [2]:
def add_velocity_to_adata(adata, model):
    """Evaluate the learned vector field at every cell and store it in ``adata``.

    For each row of ``adata.X`` the model is called as ``model(t, x)``, where ``t``
    is that cell's ``adata.obs["experimental_time"]`` value and ``x`` is the cell's
    expression vector. Both are passed as float32 tensors reshaped with
    ``view(1, -1)`` and placed on ``model.device``. The per-cell outputs are
    concatenated along axis 0 and written to ``adata.layers["velocity"]``.

    Args:
        adata: AnnData-like object exposing ``X`` (a dense array or a sparse matrix
            offering ``toarray()``), ``obs["experimental_time"]`` and ``layers``.
        model: Object exposing ``eval()``, ``device`` and ``__call__(t, x)``.

    Returns:
        None. ``adata`` is modified in place; ``model`` is left in evaluation mode.
    """
    # Put model in evaluation mode
    model.eval()

    # Accept sparse and dense matrices alike: `.A` only exists on scipy sparse
    # matrices, whereas `toarray()` is the stable conversion method.
    expression = adata.X
    if hasattr(expression, "toarray"):
        expression = expression.toarray()
    else:
        expression = np.asarray(expression)

    # Read times positionally: `obs` is usually indexed by cell barcodes, so an
    # integer lookup on the Series itself would be a label lookup.
    times = adata.obs["experimental_time"].to_numpy()

    velocities = []
    with torch.no_grad():
        for cell_idx, cell_expression in enumerate(expression):
            # Time and state must be on the same device as the model's parameters.
            t = torch.tensor(float(times[cell_idx])).to(model.device).view(1, -1).float()
            x = torch.from_numpy(cell_expression).to(model.device).view(1, -1).float()
            dx_dt = model(t, x)
            velocities.append(dx_dt.cpu().numpy())

    if velocities:
        velocities = np.concatenate(velocities, axis=0)
    else:
        # No cells: np.concatenate would raise on an empty list, so store an
        # empty array with the same shape as the (empty) expression matrix.
        velocities = np.zeros(expression.shape, dtype=np.float32)

    adata.layers["velocity"] = velocities

## Run model

In [3]:
# Parse the run configuration and keep only its "fixed" section as the hyper-parameters.
with open(
    "/nfs/homedirs/pala/PerturbSeq_CMV/configs/datasets_standard_run/unperturbed_time_course_low.yaml",
    "r",
) as config_file:
    run_config = yaml.safe_load(config_file)

hparams = run_config["fixed"]

Initialization

In [4]:
task_name = hparams["training.training"]["task_name"]

# Fix seed for reproducibility.
# `pl.seed_everything` seeds Python's `random`, NumPy and PyTorch in one call, so a
# separate `torch.manual_seed` is not needed. Comparing against None (rather than
# truthiness) means a seed of 0 is honoured, and a null/missing seed skips seeding
# instead of crashing inside `torch.manual_seed(None)`.
seed = hparams["training.training"].get("seed")
if seed is not None:
    pl.seed_everything(seed, workers=True)

# Initialize folder: one directory per task below EXPERIMENT_FOLDER
current_experiment_dir = EXPERIMENT_FOLDER / task_name
current_experiment_dir.mkdir(parents=True, exist_ok=True)

# Initialize datamodule
datamodule = TrajectoryDataModule(**hparams["datamodule.datamodule"])

# Initialize augmentations
augmentations = AugmentationModule(**hparams["augmentations.augmentations"])

# Neural network: a partially-applied constructor, not an instance; the remaining
# arguments are supplied wherever `net` is eventually called.
net = partial(VelocityNet, **hparams["net.net"])

# Initialize the model
model = CFMLitModule(
    net=net,
    datamodule=datamodule,
    augmentations=augmentations,
    **hparams["model.model"],
)

[rank: 0] Global seed set to 42


Training

In [5]:
# Checkpointing: files go to <experiment dir>/checkpoints, other options come from the config
model_ckpt_callbacks = ModelCheckpoint(
    dirpath=current_experiment_dir / "checkpoints",
    **hparams["model_checkpoint.model_checkpoint"],
)

# Early stopping, configured entirely from the config
early_stopping_callbacks = EarlyStopping(**hparams["early_stopping.early_stopping"])

# Weights & Biases logger writing into the experiment folder
logger = WandbLogger(save_dir=current_experiment_dir, **hparams["logger.logger"])

# Lightning trainer using the callbacks and logger created above
trainer = Trainer(
    default_root_dir=current_experiment_dir,
    callbacks=[model_ckpt_callbacks, early_stopping_callbacks],
    logger=logger,
    **hparams["trainer.trainer"],
)

# Fit on the datamodule's train/validation loaders and keep a handle on the logged metrics
trainer.fit(
    model=model,
    train_dataloaders=datamodule.train_dataloader(),
    val_dataloaders=datamodule.val_dataloader(),
)
train_metrics = trainer.callback_metrics

# Test using the best checkpoint; an empty best-model path is passed on as None
ckpt_path = trainer.checkpoint_callback.best_model_path or None
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
test_metrics = trainer.callback_metrics

# Merge train and test metrics (test values win on duplicate keys)
metric_dict = dict(train_metrics)
metric_dict.update(test_metrics)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mallepalma[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                 | Params
-----------------------------------------------------------
0 | net               | VelocityNet          | 471 K 
1 | augmentations     | AugmentationModule   | 0     
2 | val_augmentations | AugmentationModule   | 0     
3 | aug_net           | AugmentedVectorField | 471 K 
4 | val_aug_net       | AugmentedVectorField | 471 K 
5 | node              | NeuralODE            | 471 K 
6 | aug_node          | Sequential           | 471 K 
7 | val_aug_node      | Sequential           | 471 K 
8 | criterion         | MSELoss              | 0     
-----------------------------------------------------------
471 K     Trainable params
0    

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

torch.Size([3072, 3588])
tensor([0.0000, 0.0100, 0.0200, 0.0300, 0.0400, 0.0500, 0.0600, 0.0700, 0.0800,
        0.0900, 0.1000, 0.1100, 0.1200, 0.1300, 0.1400, 0.1500, 0.1600, 0.1700,
        0.1800, 0.1900, 0.2000, 0.2100, 0.2200, 0.2300, 0.2400, 0.2500, 0.2600,
        0.2700, 0.2800, 0.2900, 0.3000, 0.3100, 0.3200, 0.3300, 0.3400, 0.3500,
        0.3600, 0.3700, 0.3800, 0.3900, 0.4000, 0.4100, 0.4200, 0.4300, 0.4400,
        0.4500, 0.4600, 0.4700, 0.4800, 0.4900, 0.5000, 0.5100, 0.5200, 0.5300,
        0.5400, 0.5500, 0.5600, 0.5700, 0.5800, 0.5900, 0.6000, 0.6100, 0.6200,
        0.6300, 0.6400, 0.6500, 0.6600, 0.6700, 0.6800, 0.6900, 0.7000, 0.7100,
        0.7200, 0.7300, 0.7400, 0.7500, 0.7600, 0.7700, 0.7800, 0.7900, 0.8000,
        0.8100, 0.8200, 0.8300, 0.8400, 0.8500, 0.8600, 0.8700, 0.8800, 0.8900,
        0.9000, 0.9100, 0.9200, 0.9300, 0.9400, 0.9500, 0.9600, 0.9700, 0.9800,
        0.9900, 1.0000])
torch.Size([3072, 3588])
tensor([0.0000, 0.0100, 0.0200, 0.0300, 0.040

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.15 GiB (GPU 0; 10.92 GiB total capacity; 9.01 GiB already allocated; 1.07 GiB free; 9.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Analyze vector field with CellRank

In [7]:
# Load the processed dataset and keep only the genes flagged as highly variable.
adata = sc.read_h5ad("/nfs/homedirs/pala/PerturbSeq_CMV/project_folder/data/processed/unperturbed_time_course_low.h5ad")
# Slicing an AnnData returns a view. Later cells write to this object
# (`adata.layers["velocity"]`), so materialise the subset explicitly instead of
# relying on an implicit view-to-copy conversion at first write.
adata = adata[:, adata.var.highly_variable].copy()



In [None]:
# Evaluate the model at every cell; the result is stored in adata.layers["velocity"]
add_velocity_to_adata(adata, model)

In [None]:
# Velocity kernel, reading xkey="X" and the "velocity" layer written above
vk = cr.kernels.VelocityKernel(adata, xkey="X", vkey="velocity")
vk = vk.compute_transition_matrix()

# Connectivity kernel
ck = cr.kernels.ConnectivityKernel(adata)
ck = ck.compute_transition_matrix()

# Weighted combination: 0.8 velocity + 0.2 connectivity
combined_kernel = 0.8 * vk + 0.2 * ck

# Compute the combined transition matrix and project it onto the UMAP basis
combined_kernel.compute_transition_matrix().compute_projection(basis="umap")

In [None]:
# Stream plot on the UMAP embedding using the "T_fwd" key, coloured by the "cluster" annotation
# NOTE(review): presumably "T_fwd" is what compute_projection(basis="umap") wrote — confirm
scv.pl.velocity_embedding_stream(adata, vkey="T_fwd", basis="umap", color="cluster")

In [None]:
# GPCCA estimator built on the combined (velocity + connectivity) kernel
gpcc = cr.estimators.GPCCA(combined_kernel)

In [None]:
# Schur decomposition with 20 components, then plot the spectrum (real part only)
gpcc.compute_schur(n_components=20)
gpcc.plot_spectrum(real_only=True)

In [None]:
# Compute 4 macrostates, using the "cluster" annotation as cluster key
gpcc.compute_macrostates(n_states=4, cluster_key="cluster")

In [None]:
# Show the macrostates as discrete labels on the UMAP embedding
gpcc.plot_macrostates(
    discrete=True, size=100, basis="umap", title="Macrostates - labeled"
)

In [None]:
# Declare four macrostates as terminal states and compute absorption probabilities towards them.
# NOTE(review): the names must match macrostate names produced above — confirm.
# NOTE(review): newer CellRank releases appear to have renamed both methods — confirm
# against the installed version.
gpcc.set_terminal_states_from_macrostates(names=["infected_6", "infected_abortive", "bystander", 
                                                        "naive"])
gpcc.compute_absorption_probabilities()

In [None]:
# Plot the absorption probabilities on the UMAP embedding (same_plot=False)
gpcc.plot_absorption_probabilities(same_plot=False, size=50, basis="umap")

## Plot trajectories 

In [None]:
# Lookup from time-point index (float keys) to experimental time.
# NOTE(review): the values look like hours — confirm units against the dataset.
idx2time = dict(
    zip(
        (0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0),
        (0.0, 6.0, 20.0, 28.0, 48.0, 72.0, 96.0, 120.0),
    )
)

In [None]:
# Initialize the model 
ckpt_path = "/nfs/homedirs/pala/PerturbSeq_CMV/project_folder/experiments/unperturbed_time_course_low/checkpoints/last-v7.ckpt"


model.load_state_dict(torch.load(ckpt_path)["state_dict"])

In [None]:
# Accumulator for the simulated trajectories, filled batch by batch below
trajectories = [] 

In [None]:
# Reuse the datamodule's training loader as the source of cells to simulate from
cell_loader = datamodule.train_dataloader()

In [None]:
# Collect trajectories 
t_interp = torch.linspace(0, 1, 1).unsqueeze(0)
t_ext = torch.arange(8).unsqueeze(1)
t_int = (t_interp+t_ext).ravel()
obs = {"experimental_time": []}

with torch.no_grad():
    for batch in cell_loader:
        batch = model.unpack_batch(batch)
        x_start = batch[:,0,:]
        times = [idx2time[idx] for idx in range(8)]
        obs["experimental_time"].append(torch.tensor(times).unsqueeze(1).repeat(1, batch.shape[0]))
        _, traj = model.val_aug_node(x_start, t_int)
        trajectories.append(traj[:,:,3:].clamp(min=0).cpu())

In [None]:
# Join the per-batch results along axis 1.
# NOTE(review): the second line replaces the list in obs["experimental_time"] with a
# tensor, so this cell presumably cannot be re-run without re-running the collection
# cell first — confirm.
traj_cat = torch.cat(trajectories, dim=1)
obs["experimental_time"] = torch.cat(obs["experimental_time"], dim=1)

In [None]:
# Collapse the leading axes so each row is one feature vector of length traj_cat.shape[2];
# the time labels are flattened the same way so rows and labels stay aligned.
# NOTE(review): assumes traj_cat is (time points, cells, features) — confirm
traj_ravel = traj_cat.view(-1, traj_cat.shape[2])
obs["experimental_time"] = obs["experimental_time"].view(-1)

In [None]:
# Wrap the simulated states in an AnnData, with the experimental time as the only obs column
adata_extrap = anndata.AnnData(X=traj_ravel.numpy(), 
                              obs=pd.DataFrame(obs))

In [None]:
# Embed the simulated cells: PCA -> neighbour graph on 30 PCs -> UMAP
sc.tl.pca(adata_extrap, svd_solver="arpack")
sc.pp.neighbors(adata_extrap, n_pcs=30)
sc.tl.umap(adata_extrap)

In [None]:
# UMAP of the simulated cells coloured by experimental time
sc.pl.umap(adata_extrap, color="experimental_time")

## Put together real and generated

In [None]:
# Stack real cells (first) and generated cells (second) into one dense matrix.
# `adata.X` may be sparse or dense: `toarray()` is used when available instead of the
# sparse-only `.A` attribute, so the cell works for both.
X_real = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
X_concat = np.concatenate([X_real, adata_extrap.X], axis=0)

# Matching obs table: experimental time plus a flag that is 1 for real cells and 0 for
# generated ones, in the same row order as X_concat
obs_concat = pd.DataFrame(pd.concat([adata.obs.experimental_time, adata_extrap.obs.experimental_time]))
obs_concat["true_generated"] = np.array([1] * adata.n_obs + [0] * adata_extrap.n_obs)

In [None]:
# AnnData holding real and generated cells together
adata_joint = anndata.AnnData(X=X_concat,
                             obs=obs_concat)

In [None]:
# Display the combined obs table (experimental_time and true_generated columns)
adata_joint.obs

In [None]:
# Joint embedding of real and generated cells: PCA -> neighbour graph on 30 PCs -> UMAP
sc.tl.pca(adata_joint, svd_solver="arpack")
sc.pp.neighbors(adata_joint, n_pcs=30)
sc.tl.umap(adata_joint)

In [None]:
# Joint UMAP coloured by experimental time and by the real (1) vs generated (0) flag
sc.pl.umap(adata_joint, color=["experimental_time", "true_generated"])