In [1]:
from datamodules import FMDataset, fm_collate, CFMDataset, cfm_collate, torch_wrapper
from arch import *
from flow_utils import compute_conditional_flow

import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from torchcfm.conditional_flow_matching import *
import scanpy as sc
from sklearn.decomposition import PCA

In [2]:
# Load the data and define the train/eval and control/perturbation masks.
# Names that pick out the control treatment and the held-out
# (cell type, perturbation) pair are declared once here so they cannot drift.
ADATA_PATH = '../../../dlesman/datasets/cellot/kaggle_HVG_cellot.h5ad'
CONTROL_SM_NAME = 'Dimethyl Sulfoxide'  # treatment used as the unperturbed control
EVAL_CELL_TYPE = 'B cells'              # held-out cell type
EVAL_SM_NAME = 'Belinostat'             # held-out perturbation

adata = sc.read_h5ad(ADATA_PATH)

# Boolean masks over the rows of adata.obs.
control_idx = adata.obs['sm_name'] == CONTROL_SM_NAME
# Perturbed cells: currently just "everything that is not the control".
pert_idx = adata.obs['sm_name'] != CONTROL_SM_NAME
# Hold-out set: the chosen cell type under the chosen perturbation.
eval_cell_idx = adata.obs['cell_type'] == EVAL_CELL_TYPE
eval_pert_idx = adata.obs['sm_name'] == EVAL_SM_NAME
eval_idx = eval_cell_idx & eval_pert_idx

# A typo in any of the names above would otherwise yield empty selections silently.
assert control_idx.any(), f"no control cells with sm_name == {CONTROL_SM_NAME!r}"
assert eval_idx.any(), f"no cells for held-out pair ({EVAL_CELL_TYPE!r}, {EVAL_SM_NAME!r})"

In [3]:
# Set up the embeddings CFM is trained on. Every candidate lives in `obsm`
# so the training loop can iterate over them by key.
adata.obsm["X"] = adata.X

# Example of a learned embedding: fit PCA on the training cells only (the
# held-out pair is excluded so it cannot leak into the embedding), then
# project every cell and run flow matching in that space.
N_PCA_COMPONENTS = 30
# control_idx | pert_idx covers every row with the masks as currently defined
# (`==` and `!=` on the same column). It is kept so the mask still means
# "control or perturbed" if pert_idx is narrowed later.
pca_train_mask = (control_idx | pert_idx) & ~eval_idx
# PCA's default "auto" solver switches to randomized SVD on large inputs.
# Pinning random_state makes the embedding reproducible across runs.
embedder = PCA(n_components=N_PCA_COMPONENTS, random_state=0).fit(adata.X[pca_train_mask])
adata.obsm["X_pca"] = embedder.transform(adata.X)

In [4]:
# Set up the perturbation featurization.
import pandas as pd

# One-hot encode each cell's small-molecule treatment. pert_ids[i] is the
# column index of cell i's treatment.
perts = pd.get_dummies(adata.obs['sm_name']).values.astype(float)
# get_dummies emits an all-zero row for a missing sm_name, and argmax would
# then silently assign that cell to perturbation 0. Require exactly one
# active column per cell instead.
assert (perts.sum(axis=1) == 1).all(), "every cell needs exactly one sm_name"
pert_ids = perts.argmax(axis=1)

# "Identity featurization": row k of pert_mat is the feature vector of
# perturbation k. This matrix can be swapped for any latent representation
# of perturbations; for now it is a non-parametric featurization.
pert_mat = np.eye(pert_ids.max() + 1).astype('float32')

In [None]:
# Train one conditional flow-matching model per embedding and save its
# prediction for the held-out (cell type, perturbation) pair.
# To try another embedding, store it in adata.obsm and add its key here.
for embedding in ["X", "X_pca"]:
    print(embedding)
    # Features for this run: the counts ("X") or a latent embedding of them.
    X = adata.obsm[embedding]

    # Training split: all control and perturbed cells outside the held-out pair.
    control_train = X[control_idx & ~eval_idx]
    pert_train = X[pert_idx & ~eval_idx]
    pert_ids_train = pert_ids[pert_idx & ~eval_idx]

    # Evaluation split: controls of the held-out cell type are the flow's
    # starting points; the held-out perturbed cells are the ground truth.
    control_eval = X[control_idx & eval_cell_idx]
    pert_eval = X[eval_idx]
    pert_ids_eval = pert_ids[eval_idx]
    # .shape[0] rather than len() so the check also works for sparse matrices.
    assert control_eval.shape[0] > 0 and pert_eval.shape[0] > 0, (
        f"empty evaluation split for embedding {embedding!r}"
    )

    # Set up data processing for CFM.
    # NOTE(review): no shuffle=True; confirm CFMDataset randomizes pairs itself.
    dset = CFMDataset(control_train, pert_train, pert_ids_train, pert_mat)
    dl = torch.utils.data.DataLoader(dset, batch_size=32, collate_fn=cfm_collate)

    # Seed per embedding so each run is reproducible on its own.
    pl.seed_everything(0)

    # Train the model. accelerator/devices replace the `gpus` argument,
    # which Lightning deprecated in 1.7 and removed in 2.0.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=1,       # number of GPUs to use
        max_epochs=10,   # maximum number of training epochs
    )
    model = CMHA(feat_dim=X.shape[1], cond_dim=pert_mat.shape[1], time_varying=True)
    trainer.fit(model, dl)

    # Predict the held-out perturbation by integrating the learned flow.
    torch.cuda.empty_cache()
    # pert_ids_eval equals pert_ids[pert_idx & eval_idx], because eval_idx
    # only selects perturbed cells.
    # NOTE(review): pert_ids_eval has one entry per held-out *perturbed* cell,
    # while control_eval has one row per held-out *control* cell. Confirm that
    # compute_conditional_flow does not expect one id per control row.
    traj = compute_conditional_flow(model, control_eval, pert_ids_eval, pert_mat)

    # Final time step of the trajectory is the predicted perturbed state.
    np.savez(
        f"kaggle.{embedding}.identity.npz",
        pred_pert=traj[-1, :, :], true_pert=pert_eval, control=control_eval
    )

X


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | feat_net  | Sequential         | 148 K 
1 | cond_net  | Sequential         | 9.5 K 
2 | mha1      | MultiheadAttention | 148 K 
3 | mha2      | MultiheadAttention | 148 K 
4 | combo_net | Sequential         | 146 K 
-------------------------------------------------
601 K     Trainable params
0         Non-trainable params
601 K     Total params
2.4

Training: 0it [00:00, ?it/s]