In [1]:
import math
from tqdm import tqdm
import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from pytorch_lightning.callbacks import TQDMProgressBar
from datamodules import FMDataset, fm_collate, CFMDataset, SCFMDataset, cfm_collate, StratifiedBatchSampler, ot_collate
from torch.utils.data import RandomSampler
from sc_etl_utils import *
from arch import *
from flow_utils import compute_conditional_flow
import json

import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from torchcfm.conditional_flow_matching import *
import scanpy as sc
import hashlib
from llm import MAE

In [2]:
# Load the AnnData object (per its filename: raw counts restricted to HVGs plus the
# perturbed genes of the Satija IFNB screen).
adata_path = '/orcd/archive/abugoot/001/Projects/dlesman/datasets/satija_IFNB_HVG_and_perturbed_genes_raw.h5ad'
adata = sc.read_h5ad(adata_path)

In [3]:
# Perturbation vocabulary: the 'NT' label maps to 0 (the control id used as control_pert),
# and every gene maps to its column position in adata.var, so a perturbation id doubles
# as the column index of the targeted gene in the expression matrix.
# NOTE(review): enumerate() starts at 0, so adata.var.index[0] ALSO maps to 0 and is
# indistinguishable from 'NT'; gene_unmap[0] resolves to that gene, not 'NT'.
# Confirm adata.var.index[0] is never a perturbation target.
gene_map = {'NT': 0, **{gene: col for col, gene in enumerate(adata.var.index)}}
gene_unmap = {code: name for name, code in gene_map.items()}
perts = adata.obs.gene.unique().map(gene_map)
adata.obs['pert_type'] = adata.obs.gene.map(gene_map)
pert_ids = np.array(adata.obs['pert_type'])
# Column vector of every id in [0, max id].
pert_mat = np.arange(pert_ids.max() + 1).reshape(-1, 1)

In [4]:
# adata.obs columns holding the cell line label and the integer perturbation id.
cell_col, pert_col = 'cell_type', 'pert_type'
# Inputs to get_train_eval_idxs: the control id plus the cell lines / perturbation ids
# that define the held-out evaluation split.
control_pert = 0
holdout_cells = ['HT29']
holdout_perts = [gene_map['USP18']]

In [5]:
from sc_etl_utils import *
# Index sets for the control / perturbed training cells and the held-out eval cells.
control_idx, pert_idx, eval_idx, eval_cell_idx, eval_pert_idx = get_train_eval_idxs(
    adata, control_pert, holdout_cells, holdout_perts, cell_col=cell_col, pert_col=pert_col
)

_, _, cell_types = get_identity_features(
    adata, cell_col=cell_col, pert_col=pert_col, cell_type_features=False
)

adata.obsm["standard"] = adata.X
X = adata.obsm["standard"]
# adata.X may be stored sparse or dense; only sparse matrices have .toarray().
X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
# Library-size normalisation: every cell's expression vector sums to 1.
# A cell with zero total counts would otherwise become a NaN row (0 / 0) and poison the
# training loss; dividing it by 1 keeps it all-zero (assuming non-negative raw counts).
library_size = X.sum(axis=1, keepdims=True)
library_size[library_size == 0] = 1
X = X / library_size

control_train, pert_train, pert_ids_train, control_cell_types, pert_cell_types, control_eval, pert_eval, pert_ids_eval = get_train_eval(
    X, pert_ids, cell_types, control_idx, pert_idx, eval_idx, eval_cell_idx, eval_pert_idx
)


Controls: 14582, Perturbations: 313080,  Eval: 880


In [6]:
# Optimisation hyper-parameters.
batch_size = 128            # cells per training batch
base_learning_rate = 5e-5   # AdamW lr, scaled by the LambdaLR multiplier
weight_decay = 0.0
total_epoch = 5_000         # cosine-schedule length, in lr_scheduler.step() calls
warmup_epoch = 1            # warm-up length, same units as total_epoch

In [7]:
# Training dataset of control / perturbed cells; batches are drawn by
# StratifiedBatchSampler (strata and their probabilities come from the dataset)
# and assembled by ot_collate.
dset = SCFMDataset(
    control_train,
    pert_train,
    pert_ids_train,
    pert_mat,
    control_cell_types,
    pert_cell_types,
    batch_size=batch_size,
    size=X.shape[0],
)
batch_sampler = StratifiedBatchSampler(
    RandomSampler(dset),
    batch_size=batch_size,
    drop_last=True,
    probs=dset.probs,
    num_strata=dset.num_strata,
)
dl = torch.utils.data.DataLoader(dset, batch_sampler=batch_sampler, collate_fn=ot_collate)

[0.09837271 0.2565357  0.2074058  0.17077354 0.08714468 0.17976757]


In [8]:
# MAE model sized by the number of gene columns in X.
model = MAE(X.shape[1])
# Use the GPU when one is visible; fall back to CPU instead of crashing in .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

In [9]:
optim = torch.optim.AdamW(
    model.parameters(), lr=base_learning_rate, betas=(0.9, 0.95), weight_decay=weight_decay
)


def lr_func(epoch):
    """Learning-rate multiplier for LambdaLR: linear warm-up, then half-cosine decay.

    `epoch` is LambdaLR's step counter (the training loops call lr_scheduler.step()
    once every 100 optimiser steps), so total_epoch and warmup_epoch share that unit.
    The 1e-5 keeps the warm-up term finite when warmup_epoch is 0. The counter is
    clamped at total_epoch so the multiplier stays at 0 afterwards instead of
    climbing back up along the cosine.
    """
    warmup = (epoch + 1) / (warmup_epoch + 1e-5)
    progress = min(epoch, total_epoch) / total_epoch
    cosine = 0.5 * (math.cos(progress * math.pi) + 1)
    return min(warmup, cosine)


# verbose=True is deprecated in LRScheduler and printed an "Adjusting learning rate"
# line on every step, interleaving with the tqdm progress bars.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func)

Adjusting learning rate of group 0 to 5.0000e-05.


In [None]:
# Control-only reconstruction run (checkpoints transformer_v3.*). Steps alternate between
# mask=1 and mask=0, and the model's second output is trained toward each gene's
# non-zero indicator.
step_count = 0
optim.zero_grad()
pert_task = 0  # this run never trains the perturbation task
for e in range(1):
    model.train()
    # Loss history keyed by (pert_task, mask_task); only the (0, *) entries are filled here.
    losses = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
    # Pass the DataLoader itself (not iter(dl)) so tqdm can read its length for total/ETA.
    for (control, pert, pert_index) in (pbar := tqdm(dl)):
        control_sparsity = (control > 0).float().to(device)
        # Non-zero genes get weight 11, zero genes weight 1.
        active_weights = 10 * control_sparsity + 1
        control = control.float().to(device)
        step_count += 1
        mask_task = step_count % 2
        control_recon, sparsity_probs, _ = model(control, mask=mask_task)
        # Weighted L1 reconstruction loss + weighted L1 between predicted probabilities
        # and the observed non-zero mask, both averaged per cell.
        loss = torch.sum(active_weights * torch.abs(control_recon - control)) / batch_size
        loss += torch.sum(active_weights * torch.abs(control_sparsity - sparsity_probs)) / batch_size
        loss.backward()
        optim.step()
        optim.zero_grad()
        if step_count % 100 == 0:
            lr_scheduler.step()
        losses[(pert_task, mask_task)].append(loss.item())
        if pert_task == 0 and mask_task == 0:
            pbar.set_description(
                f"tv: {np.array(losses[(0, 0)])[-100:].mean():.3f}, mtv: {np.array(losses[(0, 1)])[-100:].mean():.3f}")

    # Epoch average over unmasked steps; NaN rather than ZeroDivisionError if none ran.
    unmasked_losses = losses[(0, 0)]
    avg_loss = sum(unmasked_losses) / len(unmasked_losses) if unmasked_losses else float('nan')
    torch.save(model, f"transformer_v3.{e}")
    # writer.add_scalar('mae_loss', avg_loss, global_step=e)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

In [None]:
# Joint run (checkpoints transformer_v4.*): alternates control reconstruction and
# perturbation prediction, each with and without masking, using expression-weighted L1.
step_count = 0
optim.zero_grad()
pert_task = 0
for e in range(10):
    model.train()
    # Loss history keyed by (pert_task, mask_task).
    losses = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
    # Pass the DataLoader itself (not iter(dl)) so tqdm can read its length for total/ETA.
    for (control, pert, pert_index) in (pbar := tqdm(dl)):
        # NOTE(review): weights come from the *control* non-zero pattern even when the
        # loss target is `pert` (pert_task == 1); confirm this is intended rather than
        # weighting by (pert > 0).
        active_weights = 10 * (control > 0).float().to(device) + 1
        control = control.float().to(device)
        pert = pert.float().to(device)
        step_count += 1
        mask_task = step_count % 2
        # pert_task flips on every masked step, so (pert_task, mask_task) cycles through
        # (1, 1), (1, 0), (0, 1), (0, 0).
        pert_task = (pert_task + mask_task) % 2
        if pert_task:
            # Observed value of each cell's targeted gene (the perturbation id is its column).
            pert_expr = pert[torch.arange(pert.size(0)), pert_index[:, 0], None]
            pert_recon, _, _ = model(control, pert_expr=pert_expr, pert_index=pert_index[:, 0], mask=mask_task)
            loss = torch.sum(active_weights * torch.abs(pert_recon - pert)) / batch_size
        else:
            control_recon, _, _ = model(control, mask=mask_task)
            loss = torch.sum(active_weights * torch.abs(control_recon - control)) / batch_size
        loss.backward()
        optim.step()
        optim.zero_grad()
        if step_count % 100 == 0:
            lr_scheduler.step()
        losses[(pert_task, mask_task)].append(loss.item())
        if pert_task == 0 and mask_task == 0:
            pbar.set_description(
                f"tv: {np.array(losses[(0, 0)])[-100:].mean():.3f}, mtv: {np.array(losses[(0, 1)])[-100:].mean():.3f}, ptv: {np.array(losses[(1, 0)])[-100:].mean():.3f}, pmtv: {np.array(losses[(1, 1)])[-100:].mean():.3f}")

    # Epoch average over unmasked control steps; NaN rather than ZeroDivisionError if none ran.
    unmasked_losses = losses[(0, 0)]
    avg_loss = sum(unmasked_losses) / len(unmasked_losses) if unmasked_losses else float('nan')
    torch.save(model, f"transformer_v4.{e}")
    # writer.add_scalar('mae_loss', avg_loss, global_step=e)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

In [None]:
# Joint run (checkpoints transformer_v5.*): same task schedule as the v4 run, but with
# an unweighted L1 loss (every gene counts equally, no up-weighting of non-zero genes).
step_count = 0
optim.zero_grad()
pert_task = 0
for e in range(10):
    model.train()
    # Loss history keyed by (pert_task, mask_task).
    losses = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
    # Pass the DataLoader itself (not iter(dl)) so tqdm can read its length for total/ETA.
    for (control, pert, pert_index) in (pbar := tqdm(dl)):
        control = control.float().to(device)
        pert = pert.float().to(device)
        step_count += 1
        mask_task = step_count % 2
        # pert_task flips on every masked step, so (pert_task, mask_task) cycles through
        # (1, 1), (1, 0), (0, 1), (0, 0).
        pert_task = (pert_task + mask_task) % 2
        if pert_task:
            # Observed value of each cell's targeted gene (the perturbation id is its column).
            pert_expr = pert[torch.arange(pert.size(0)), pert_index[:, 0], None]
            pert_recon, _, _ = model(control, pert_expr=pert_expr, pert_index=pert_index[:, 0], mask=mask_task)
            loss = torch.sum(torch.abs(pert_recon - pert)) / batch_size
        else:
            control_recon, _, _ = model(control, mask=mask_task)
            loss = torch.sum(torch.abs(control_recon - control)) / batch_size
        loss.backward()
        optim.step()
        optim.zero_grad()
        if step_count % 100 == 0:
            lr_scheduler.step()
        losses[(pert_task, mask_task)].append(loss.item())
        if pert_task == 0 and mask_task == 0:
            pbar.set_description(
                f"tv: {np.array(losses[(0, 0)])[-100:].mean():.3f}, mtv: {np.array(losses[(0, 1)])[-100:].mean():.3f}, ptv: {np.array(losses[(1, 0)])[-100:].mean():.3f}, pmtv: {np.array(losses[(1, 1)])[-100:].mean():.3f}")

    # Epoch average over unmasked control steps; NaN rather than ZeroDivisionError if none ran.
    unmasked_losses = losses[(0, 0)]
    avg_loss = sum(unmasked_losses) / len(unmasked_losses) if unmasked_losses else float('nan')
    torch.save(model, f"transformer_v5.{e}")
    # writer.add_scalar('mae_loss', avg_loss, global_step=e)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

In [None]:
# Predict each held-out (cell type, perturbation) pair from that cell type's control cells
# and save predictions next to the observed perturbed cells and the controls.
cell_type_names = adata.obs[cell_col]
pert_type_names = adata.obs[pert_col]
# NOTE(review): `hvg_idx` is not defined anywhere in this notebook; when it is missing,
# save every gene column instead of failing with a NameError after inference.
save_cols = globals().get('hvg_idx', slice(None))
out_dir = "Satija_IFNB_HVG/llm"
os.makedirs(out_dir, exist_ok=True)
pair_rng = np.random.default_rng(0)  # fixed seed so the control/perturbed pairing is reproducible

model.eval()  # switch off training-mode behaviour (e.g. dropout) for inference
# Save the predicted perturbation
for cell_type, pert_type in zip(holdout_cells, holdout_perts):
    in_cell_type = np.asarray(cell_type_names == cell_type)
    # Model inputs are only the unperturbed (control_pert) cells of this cell type.
    control_eval = X[in_cell_type & np.asarray(pert_type_names == control_pert)]
    true_pert = X[in_cell_type & np.asarray(pert_type_names == pert_type)]
    if len(control_eval) == 0 or len(true_pert) == 0:
        print(f"Skipping {pert_type}: no control or perturbed cells for {cell_type}")
        continue
    # cheating right now, just trying to get a mvp together: each control cell is fed the
    # observed expression of the targeted gene from a randomly paired real perturbed cell
    # (sampled with replacement), so batch sizes of inputs and pert_expr always match.
    pair_idx = pair_rng.integers(len(true_pert), size=len(control_eval))
    pert_expr_all = true_pert[pair_idx, pert_type][:, None]
    pred_chunks = []
    with torch.no_grad():
        # Mini-batch inference bounds device memory; call signature mirrors training.
        for start in range(0, len(control_eval), batch_size):
            stop = start + batch_size
            control_batch = torch.as_tensor(control_eval[start:stop], dtype=torch.float32, device=device)
            pert_expr = torch.as_tensor(pert_expr_all[start:stop], dtype=torch.float32, device=device)
            # One perturbation id per cell, like pert_index[:, 0] in the training loops.
            pert_index = torch.full((control_batch.size(0),), pert_type, dtype=torch.long)
            pred_batch, _, _ = model(
                control_batch, pert_expr=pert_expr, pert_index=pert_index, mask=False
            )
            pred_chunks.append(pred_batch.cpu())
    pred_pert = torch.cat(pred_chunks).numpy()
    print(f"Saving {pert_type} predictions")
    np.savez(
        f"{out_dir}/pred_{gene_unmap[pert_type]}_{cell_type}.npz",
        pred_pert=pred_pert[:, save_cols],
        true_pert=true_pert[:, save_cols],
        control=control_eval[:, save_cols],
    )