In [1]:
import math
from tqdm import tqdm
import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from pytorch_lightning.callbacks import TQDMProgressBar
from datamodules import  SCFMDataset, cfm_collate, StratifiedBatchSampler, ot_collate
from torch.utils.data import RandomSampler
from sc_etl_utils import *
from arch import *
from flow_utils import compute_conditional_flow
import json

import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from torchcfm.conditional_flow_matching import *
import scanpy as sc
import hashlib
from llm import MAE

In [19]:
# Load the Satija IFN-beta screen (per the file name: raw counts, HVGs + perturbed genes).
# NOTE(review): absolute cluster path — consider moving it to a configuration cell.
h5ad_path = '/orcd/archive/abugoot/001/Projects/dlesman/datasets/satija_IFNB_HVG_and_perturbed_genes_raw.h5ad'
adata = sc.read_h5ad(h5ad_path)

In [20]:
# Map each gene name to its column index in adata.var, then append one extra id for
# the non-targeting control label 'NT' (it has no gene column of its own).
gene_map = {k: i for i, k in enumerate(adata.var.index)}
gene_map = gene_map | {'NT': max(gene_map.values()) + 1}
gene_unmap = {gene_map[k]: k for k in gene_map}  # inverse lookup: id -> gene name
perts = adata.obs.gene.unique().map(gene_map)
adata.obs['pert_type'] = adata.obs.gene.map(gene_map)

# Labels with no entry in gene_map become NaN in 'pert_type'; report them here instead
# of failing obscurely further down (e.g. a NaN max() fed to np.arange).
unmapped = adata.obs['pert_type'].isna()
if unmapped.any():
    missing = sorted(map(str, adata.obs.gene[unmapped].unique()))
    raise ValueError(
        f"{len(missing)} perturbation label(s) not found in adata.var: {missing[:10]}"
    )

pert_ids = np.array(adata.obs['pert_type'])
# Row j is [j]: since perturbation ids are gene column indices, each row holds the
# target gene index of that perturbation (presumably consumed by SCFMDataset).
pert_mat = np.arange(pert_ids.max() + 1)[:, None]

In [23]:
# adata.obs columns holding the cell-line label and the integer perturbation id.
cell_col = 'cell_type'
pert_col = 'pert_type'

# Id of the non-targeting control, plus the cell lines / perturbation ids passed to
# get_train_eval_idxs as held out.
control_pert = gene_map['NT']
holdout_cells = ['HT29']
holdout_perts = [gene_map['USP18']]

In [24]:
from sc_etl_utils import *

# Row-index split computed by the project helper get_train_eval_idxs from the control id
# and the held-out cell lines / perturbations.
control_idx, pert_idx, eval_idx, eval_cell_idx, eval_pert_idx = get_train_eval_idxs(
    adata, control_pert, holdout_cells, holdout_perts, cell_col=cell_col, pert_col=pert_col
)

_, _, cell_types = get_identity_features(
    adata, cell_col=cell_col, pert_col=pert_col, cell_type_features=False
)

adata.obsm["standard"] = adata.X
X = adata.obsm["standard"]
# Densify; adata.X may be scipy-sparse (has .toarray()) or already a dense array.
X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
# Normalise each cell (row) to sum to 1. A cell with zero counts over the selected genes
# would give 0/0 = NaN and poison the training loss, so those rows are divided by 1
# instead and stay all-zero. Output dtype matches the plain X / X.sum(...) form.
library_size = X.sum(axis=1, keepdims=True)
library_size[library_size == 0] = 1
X = X / library_size

control_train, pert_train, pert_ids_train, control_cell_types, pert_cell_types, control_eval, pert_eval, pert_ids_eval = get_train_eval(
    X, pert_ids, cell_types, control_idx, pert_idx, eval_idx, eval_cell_idx, eval_pert_idx
)


Controls: 14582, Perturbations: 313080,  Eval: 880


In [26]:
batch_size = 512
# Training dataset built from the control / perturbed training splits (project class).
dset = SCFMDataset(
    control_train, pert_train,
    pert_ids_train, pert_mat,
    control_cell_types, pert_cell_types,
    batch_size=batch_size, size=X.shape[0]
)
# ns[i][j] = number of rows in dset.target[i][j]; presumably the stratum sizes the
# sampler uses to weight its draws — confirm against StratifiedBatchSampler.
ns = np.array([[t.shape[0] for t in ts] for ts in dset.target])
dl = torch.utils.data.DataLoader(
    dset, collate_fn=ot_collate,
    batch_sampler=StratifiedBatchSampler(
        # Previously a hard-coded 512 that silently ignored `batch_size` above.
        ns=ns, batch_size=batch_size
    )
)

Strata probs [0.         0.0001086  0.00030344 0.00039606 0.00040565 0.00041523
 0.00042162 0.00043439 0.00045356 0.00045995 0.00049189 0.00052063
 0.00054938 0.00056535 0.00061007 0.00062284 0.00062604 0.00062604
 0.00063562 0.00068673 0.00075061 0.00076977 0.00077297 0.00077935
 0.00079852 0.0008081  0.00084323 0.0008624  0.00087837 0.00088156
 0.00088795 0.00088795 0.00089753 0.00089753 0.0009167  0.00092628
 0.00093906 0.00094864 0.00099016 0.00099016 0.00100294 0.00104127
 0.00104446 0.00105085 0.00109237 0.00110834 0.00110834 0.00111154
 0.00111473 0.00111473 0.0011307  0.0011339  0.00113709 0.00116584
 0.00117222 0.00117222 0.00121375 0.00121694 0.00122333 0.00122652
 0.00124569 0.00124888 0.00124888 0.00126485 0.00128082 0.00129999
 0.00130318 0.00130318 0.00131596 0.00131596 0.00132554 0.00133193
 0.00133512 0.0013447  0.00135109 0.00137026 0.00137984 0.00137984
 0.00139262 0.0014022  0.00140539 0.00140539 0.00141178 0.00141497
 0.00141817 0.00143733 0.00146288 0.00146608 0.00

In [43]:
# Project model class MAE (from llm), sized to the X.shape[1] gene columns.
model = MAE(
    X.shape[1], emb_dim=24, decoder_layer=4, encoder_layer=4, encoder_head=3, decoder_head=3, ff_dim=128,
    true_sparsity=False, expr_activation="sigmoid"
)
# To resume instead of starting fresh, replace the constructor with
# `model = torch.load(<checkpoint path>)` (the training cells save with torch.save).

# Fall back to CPU when no GPU is visible instead of crashing in .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
    print("CUDA not available; running on CPU (training will be slow).")
model = model.to(device)

In [44]:
# Optimisation / checkpoint settings shared by the cells below.
# NOTE(review): base_learning_rate is reassigned to 5e-5 in the optimizer cell.
base_learning_rate, weight_decay = 5e-4, 0.0
total_epoch, warmup_epoch = 500, 1   # LR-schedule horizon / warm-up, in lr_scheduler.step() calls
minibatch_size = 128                 # each DataLoader batch is split into chunks of this many cells
save_dir = "llm/v"                   # checkpoint path prefix; the epoch number is appended

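Below are two training loops: the first trains only the autoencoder (reconstructing control cells), and the second trains the autoencoder jointly with the perturbation-prediction task. I usually run the autoencoder-only loop for a couple of iterations first, then switch to the joint task.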

In [45]:
# Overrides the 5e-4 set in the hyper-parameter cell.
base_learning_rate = 5e-5
optim = torch.optim.AdamW(model.parameters(), lr=base_learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay)


def lr_func(epoch):
    """LR multiplier for LambdaLR: linear warm-up, then a half-cosine decay to 0.

    `epoch` is the number of lr_scheduler.step() calls so far; the training cells step
    the scheduler every `lr_step` minibatches, not once per epoch. Progress is clamped at
    `total_epoch`. Without the clamp, cos() is periodic, so the multiplier would climb
    back towards 1 after `total_epoch` steps, which is an unintended warm restart.
    """
    warmup = (epoch + 1) / (warmup_epoch + 1e-5)
    progress = min(epoch, total_epoch) / total_epoch
    cosine = 0.5 * (math.cos(progress * math.pi) + 1)
    return min(warmup, cosine)


# NOTE: `verbose` is deprecated in recent PyTorch releases (prints "Adjusting learning
# rate ..." on every step); use lr_scheduler.get_last_lr() if it is removed.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

Adjusting learning rate of group 0 to 5.0000e-05.


In [None]:
# Autoencoder-only phase: learn to reconstruct control cells. The perturbed cells in each
# batch are not used here.
use_sparsity_loss = False
use_mask_task = False
use_active_weights = True
lr_step = 32  # step the LR scheduler (and window the progress-bar loss) every lr_step minibatches

step_count = 0
optim.zero_grad()
pert_task = 0
for e in range(2):
    model.train()
    losses = {'control': [], 'pert': []}
    for (bcontrol, bpert, bpert_index) in (pbar := tqdm(iter(dl))):
        # Split the DataLoader batch into chunks of minibatch_size cells; a trailing
        # remainder smaller than that is dropped.
        curr_batch_size = bcontrol.shape[0]
        for i in range(curr_batch_size // minibatch_size):
            control = bcontrol[(i * minibatch_size):((i + 1) * minibatch_size)]
            # Non-zero entries weigh 11 (10 + 1) versus 1 for zeros in the L1 loss.
            active_weights = 10 * (control > 0).float().to(device) + 1 if use_active_weights else 1
            control = control.float().to(device)
            step_count += 1

            control_results = model(control, mask=use_mask_task)
            control_recon = control_results[0]

            # Weighted L1 reconstruction error summed over all entries, divided by minibatch size.
            control_loss = torch.sum(active_weights * torch.abs(control_recon - control)) / minibatch_size
            if use_sparsity_loss and len(control_results) == 3:
                # Optional term: the model's second output against the binary non-zero mask.
                control_sparsity = control_results[1]
                control_loss += torch.sum(active_weights * torch.abs(control_sparsity - (control > 0).float())) / minibatch_size

            loss = control_loss
            loss.backward()
            optim.step()
            optim.zero_grad()
            losses['control'].append(control_loss.item())
            if step_count % lr_step == 0:
                lr_scheduler.step()

            # Mean of the last lr_step losses. Slicing the list before converting avoids
            # copying the whole epoch's history into a new array on every step.
            pbar.set_description(
                f"tv: {np.mean(losses['control'][-lr_step:]):.3f}"
            )

    # NaN rather than ZeroDivisionError if the epoch produced no full minibatch.
    avg_loss = sum(losses['control']) / len(losses['control']) if losses['control'] else float('nan')
    # NOTE(review): the joint-training cell saves to the same f"{save_dir}{e}" paths and
    # overwrites these checkpoints.
    torch.save(model, f"{save_dir}{e}")
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

tv: 2123.005:   0%|                                                                                                                                                                              | 0/611 [00:03<?, ?it/s]

In [33]:
# Joint phase: reconstruct the control cells AND predict the paired perturbed cells.
use_sparsity_loss = False
use_mask_task = False
use_active_weights = True
lr_step = 32  # step the LR scheduler (and window the progress-bar losses) every lr_step minibatches

step_count = 0
optim.zero_grad()
pert_task = 0
for e in range(2):
    model.train()
    losses = {'control': [], 'pert': []}
    for (bcontrol, bpert, bpert_index) in (pbar := tqdm(iter(dl))):
        # Split the DataLoader batch into chunks of minibatch_size cells; a trailing
        # remainder smaller than that is dropped.
        curr_batch_size = bcontrol.shape[0]
        for i in range(curr_batch_size // minibatch_size):
            control = bcontrol[(i * minibatch_size):((i + 1) * minibatch_size)]
            pert = bpert[(i * minibatch_size):((i + 1) * minibatch_size)]
            pert_index = bpert_index[(i * minibatch_size):((i + 1) * minibatch_size)]
            pert_index = pert_index.squeeze()
            # Non-zero entries weigh 11 (10 + 1) versus 1 for zeros in the L1 losses.
            active_weights = 10 * (control > 0).float().to(device) + 1 if use_active_weights else 1
            pert_active_weights = 10 * (pert > 0).float().to(device) + 1 if use_active_weights else 1
            control = control.float().to(device)
            pert = pert.float().to(device)
            step_count += 1

            # For each cell, the value in column pert_index[k] of its perturbed profile,
            # shape (minibatch, 1). Assumes pert_index is 1-D after squeeze() and holds
            # perturbation ids, which are gene column indices via gene_map.
            pert_expr = pert[torch.arange(pert.size(0)), pert_index, None]
            control_results, pert_results = model(
                control, pert_expr=pert_expr, pert_index=pert_index, mask=use_mask_task, recon_and_pert=True
            )

            # Weighted L1 errors summed over all entries, divided by minibatch size.
            control_recon, pert_recon = control_results[0], pert_results[0]
            control_loss = torch.sum(active_weights * torch.abs(control_recon - control)) / minibatch_size
            pert_loss = torch.sum(pert_active_weights * torch.abs(pert_recon - pert)) / minibatch_size

            if use_sparsity_loss and len(control_results) == 3:
                # Optional term: the model's second outputs against the binary non-zero masks.
                control_sparsity, pert_sparsity = control_results[1], pert_results[1]
                control_loss += torch.sum(active_weights * torch.abs(control_sparsity - (control > 0).float())) / minibatch_size
                pert_loss += torch.sum(pert_active_weights * torch.abs(pert_sparsity - (pert > 0).float())) / minibatch_size

            loss = (pert_loss + control_loss)
            loss.backward()
            optim.step()
            optim.zero_grad()
            losses['control'].append(control_loss.item())
            losses['pert'].append(pert_loss.item())
            if step_count % lr_step == 0:
                lr_scheduler.step()
            # Means of the last lr_step losses. Slice the lists before converting so the
            # whole history is not copied on every step.
            pbar.set_description(
                f"tv: {np.mean(losses['control'][-lr_step:]):.3f}, ptv: {np.mean(losses['pert'][-lr_step:]):.3f}"
            )

    # NaN rather than ZeroDivisionError if the epoch produced no full minibatch. Report
    # both halves of the optimised objective, not just the reconstruction term.
    n_steps = len(losses['control'])
    avg_loss = sum(losses['control']) / n_steps if n_steps else float('nan')
    avg_pert_loss = sum(losses['pert']) / n_steps if n_steps else float('nan')
    # NOTE(review): same f"{save_dir}{e}" paths as the autoencoder-only cell, so this
    # overwrites those checkpoints.
    torch.save(model, f"{save_dir}{e}")
    print(f'In epoch {e}, average traning loss is {avg_loss} (pert: {avg_pert_loss}).')

tv: 1986.251, ptv: 2031.380:   1%|█▉                                                                                                                                                   | 8/611 [01:14<1:32:10,  9.17s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 1774.949, ptv: 1799.242:   3%|███▉                                                                                                                                                | 16/611 [02:29<1:32:45,  9.35s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 1655.834, ptv: 1673.001:   4%|██████                                                                                                                                              | 25/611 [03:45<1:31:38,  9.38s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 1296.967, ptv: 1319.913:   5%|███████▉                                                                                                                                            | 33/611 [05:01<1:23:12,  8.64s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 1112.240, ptv: 1138.553:   7%|█████████▉                                                                                                                                          | 41/611 [06:17<1:29:47,  9.45s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 944.278, ptv: 962.950:   8%|████████████▎                                                                                                                                         | 50/611 [07:33<1:23:31,  8.93s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 823.816, ptv: 844.687:   9%|██████████████▏                                                                                                                                       | 58/611 [08:49<1:20:52,  8.78s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 719.038, ptv: 726.769:  11%|████████████████▏                                                                                                                                     | 66/611 [10:05<1:24:25,  9.29s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 626.998, ptv: 631.751:  12%|██████████████████▍                                                                                                                                   | 75/611 [11:21<1:23:43,  9.37s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 590.837, ptv: 604.652:  14%|████████████████████▍                                                                                                                                 | 83/611 [12:37<1:23:21,  9.47s/it]

Adjusting learning rate of group 0 to 5.0000e-05.


tv: 546.040, ptv: 554.647:  15%|██████████████████████▎                                                                                                                               | 91/611 [13:52<1:22:05,  9.47s/it]

Adjusting learning rate of group 0 to 4.9999e-05.


tv: 532.089, ptv: 542.118:  16%|████████████████████████▎                                                                                                                             | 99/611 [15:08<1:19:48,  9.35s/it]

Adjusting learning rate of group 0 to 4.9999e-05.


tv: 501.481, ptv: 511.651:  18%|██████████████████████████                                                                                                                           | 107/611 [16:24<1:14:43,  8.90s/it]

Adjusting learning rate of group 0 to 4.9999e-05.


tv: 508.120, ptv: 515.469:  19%|████████████████████████████▎                                                                                                                        | 116/611 [17:40<1:14:28,  9.03s/it]

Adjusting learning rate of group 0 to 4.9999e-05.


tv: 519.402, ptv: 521.996:  20%|██████████████████████████████▏                                                                                                                      | 124/611 [18:56<1:16:44,  9.46s/it]

Adjusting learning rate of group 0 to 4.9999e-05.


tv: 474.387, ptv: 482.025:  22%|████████████████████████████████▏                                                                                                                    | 132/611 [20:12<1:15:42,  9.48s/it]

Adjusting learning rate of group 0 to 4.9999e-05.


tv: 478.445, ptv: 487.510:  22%|█████████████████████████████████▏                                                                                                                   | 136/611 [20:57<1:13:11,  9.24s/it]


KeyboardInterrupt: 

In [47]:
cell_type_names = adata.obs[cell_col]
pert_type_names = adata.obs[pert_col]
# For each held-out (cell type, perturbation) pair, predict the perturbed profiles from
# that cell type's non-targeting control cells and save them next to the ground truth.
out_dir = "Satija_IFNB_HVG/llm"
os.makedirs(out_dir, exist_ok=True)
model.eval()  # inference mode; the training cells call model.train() again each epoch
for cell_type, pert_type in zip(holdout_cells, holdout_perts):
    is_cell_type = (cell_type_names == cell_type).to_numpy()
    # Only non-targeting controls are valid inputs. Masking on cell type alone also
    # pulled in perturbed cells, including cells carrying the target perturbation itself.
    is_control = (pert_type_names == control_pert).to_numpy()
    is_target = (pert_type_names == pert_type).to_numpy()
    control_eval = torch.tensor(X[is_cell_type & is_control]).float()
    true_pert = torch.tensor(X[is_cell_type & is_target]).float()
    # Control row k is paired with perturbed row k; extra rows on either side are unused.
    curr_batch_size = min(control_eval.shape[0], true_pert.shape[0])
    pred_perts = []
    with torch.no_grad():  # no autograd graph needed for inference
        # range(0, n, step) never yields an empty trailing batch, unlike n // step + 1.
        for start in tqdm(range(0, curr_batch_size, minibatch_size), desc=f"{gene_unmap[pert_type]} / {cell_type}"):
            stop = min(curr_batch_size, start + minibatch_size)
            control = control_eval[start:stop].to(device)
            pert = true_pert[start:stop].to(device)
            # Value of the perturbed gene's column in each true perturbed cell, shape (batch, 1).
            pert_expr = pert[torch.arange(pert.shape[0]), pert_type, None]
            pred_pert, _, _, = model(
                control, pert_expr=pert_expr, pert_index=pert_type, mask=False
            )
            pred_perts.append(pred_pert.cpu().numpy())
    torch.cuda.empty_cache()
    if not pred_perts:
        print(f"Skipping {gene_unmap[pert_type]} in {cell_type}: no control/perturbed cell pairs")
        continue
    pred_pert = np.vstack(pred_perts)
    print(f"Saving {gene_unmap[pert_type]} predictions")
    np.savez(
        f"{out_dir}/pred_{gene_unmap[pert_type]}_{cell_type}.npz",
        pred_pert=pred_pert,
        true_pert=true_pert.numpy(),
        control=control_eval.numpy(),
    )

0
1
2
3
4
5
6
Saving USP18 predictions
