In [1]:
import copy
import gc
import json
import os
from pathlib import Path
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
from anndata import AnnData
import scanpy as sc
#import scvi
import numpy as np
import pandas as pd
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)


sys.path.insert(0, "../")
#import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, eval_scib_metrics, load_pretrained

sc.set_figure_params(figsize=(4, 4))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
hyperparameter_defaults = dict(
    seed=42,
    dataset_name="norman", # Dataset name
    do_train=True, # Flag to indicate whether to do update model parameters during training
    load_model="/scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model",
    model_name="best_model.pt",
    GEPC=True,  # Gene expression modelling for cell objective
    ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=1.0, # DAR objective weight for batch correction
    mask_ratio=0.4, # Default mask ratio
    epochs=15, # Default number of epochs for fine-tuning
    n_bins=51, # Default number of bins for value binning in data pre-processing
    lr=1e-4, # Default learning rate for fine-tuning
    batch_size=64, # Default batch size for fine-tuning
    layer_size=128,
    nlayers=4,
    nhead=4, # if load model, batch_size, layer_size, nlayers, nhead will be ignored
    dropout=0.2, # Default dropout rate during model fine-tuning
    schedule_ratio=0.9,  # Default rate for learning rate decay
    save_eval_interval=5, # Default model evaluation interval
    log_interval=100, # Default log interval
    fast_transformer=True, # Default setting
    pre_norm=False, # Default setting
    amp=True,  # # Default setting: Automatic Mixed Precision
)
run = wandb.init(
    config=hyperparameter_defaults,
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

set_seed(config.seed)

[34m[1mwandb[0m: Currently logged in as: [33mandrewhz-zhang[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 42, 'dataset_name': 'norman', 'do_train': True, 'load_model': '/scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model', 'model_name': 'best_model.pt', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}


In [3]:
from pathlib import Path
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = -1
pad_value = -2
n_input_bins = config.n_bins

n_hvg = 1200  # number of highly variable genes
max_seq_len = n_hvg + 1
per_seq_batch_sample = True
DSBN = False  # Domain-spec batchnorm
explicit_zero_prob = True  # whether explicit bernoulli for zeros

dataset_name = config.dataset_name
save_dir = Path(f"/scratch/ssd004/scratch/ahz/perturb/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")

save to /scratch/ssd004/scratch/ahz/perturb/dev_norman-Feb26-00-24


In [4]:
from pathlib import Path

## Load and preprocess dataset

####  ✅ Note
Perturbation datasets can be found in this path: `/scratch/ssd004/scratch/chloexq/perturb_analysis/{dataset_name}` 

In [5]:
data_dir = Path("/scratch/ssd004/scratch/chloexq/perturb_analysis")
adata = sc.read(data_dir / "norman/perturb_processed.h5ad")

In [6]:
adata

AnnData object with n_obs × n_vars = 91205 × 5045
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name'
    var: 'gene_name'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20'
    layers: 'counts'

In [7]:
adata.var.index = pd.Index(adata.var["gene_name"])

In [8]:
np.unique(adata.obs.condition.values)

array(['AHR+FEV', 'AHR+KLF1', 'AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl',
       'ATL1+ctrl', 'BAK1+ctrl', 'BCL2L11+BAK1', 'BCL2L11+TGFBR2',
       'BCL2L11+ctrl', 'BCORL1+ctrl', 'BPGM+SAMD1', 'BPGM+ZBTB1',
       'BPGM+ctrl', 'C19orf26+ctrl', 'C3orf72+FOXL2', 'C3orf72+ctrl',
       'CBFA2T3+ctrl', 'CBL+CNN1', 'CBL+PTPN12', 'CBL+PTPN9',
       'CBL+TGFBR2', 'CBL+UBASH3A', 'CBL+UBASH3B', 'CBL+ctrl',
       'CDKN1A+ctrl', 'CDKN1B+CDKN1A', 'CDKN1B+ctrl', 'CDKN1C+CDKN1A',
       'CDKN1C+CDKN1B', 'CDKN1C+ctrl', 'CEBPA+ctrl', 'CEBPB+CEBPA',
       'CEBPB+MAPK1', 'CEBPB+OSR2', 'CEBPB+PTPN12', 'CEBPB+ctrl',
       'CEBPE+CEBPA', 'CEBPE+CEBPB', 'CEBPE+CNN1', 'CEBPE+KLF1',
       'CEBPE+PTPN12', 'CEBPE+RUNX1T1', 'CEBPE+SPI1', 'CEBPE+ctrl',
       'CELF2+ctrl', 'CITED1+ctrl', 'CKS1B+ctrl', 'CLDN6+ctrl',
       'CNN1+MAPK1', 'CNN1+UBASH3A', 'CNN1+ctrl', 'CNNM4+ctrl',
       'COL1A1+ctrl', 'COL2A1+ctrl', 'CSRNP1+ctrl', 'DLX2+ctrl',
       'DUSP9+ETS2', 'DUSP9+IGDCC3', 'DUSP9+KLF1', 'DUSP9+MAPK1',
   

In [9]:
len(np.unique(adata.obs.condition.values))

284

In [10]:
single_gene_filter = [i for i in np.unique(adata.obs.condition.values) if not ('+' in i and 'ctrl' not in i)]
print(single_gene_filter, len(single_gene_filter))

['AHR+ctrl', 'ARID1A+ctrl', 'ARRDC3+ctrl', 'ATL1+ctrl', 'BAK1+ctrl', 'BCL2L11+ctrl', 'BCORL1+ctrl', 'BPGM+ctrl', 'C19orf26+ctrl', 'C3orf72+ctrl', 'CBFA2T3+ctrl', 'CBL+ctrl', 'CDKN1A+ctrl', 'CDKN1B+ctrl', 'CDKN1C+ctrl', 'CEBPA+ctrl', 'CEBPB+ctrl', 'CEBPE+ctrl', 'CELF2+ctrl', 'CITED1+ctrl', 'CKS1B+ctrl', 'CLDN6+ctrl', 'CNN1+ctrl', 'CNNM4+ctrl', 'COL1A1+ctrl', 'COL2A1+ctrl', 'CSRNP1+ctrl', 'DLX2+ctrl', 'DUSP9+ctrl', 'EGR1+ctrl', 'ELMSAN1+ctrl', 'ETS2+ctrl', 'FEV+ctrl', 'FOSB+ctrl', 'FOXA1+ctrl', 'FOXA3+ctrl', 'FOXF1+ctrl', 'FOXL2+ctrl', 'FOXO4+ctrl', 'GLB1L2+ctrl', 'HES7+ctrl', 'HK2+ctrl', 'HNF4A+ctrl', 'HOXA13+ctrl', 'HOXB9+ctrl', 'HOXC13+ctrl', 'IER5L+ctrl', 'IGDCC3+ctrl', 'IKZF3+ctrl', 'IRF1+ctrl', 'ISL2+ctrl', 'JUN+ctrl', 'KIAA1804+ctrl', 'KIF18B+ctrl', 'KIF2C+ctrl', 'KLF1+ctrl', 'KMT2A+ctrl', 'LHX1+ctrl', 'LYL1+ctrl', 'MAML2+ctrl', 'MAP2K3+ctrl', 'MAP2K6+ctrl', 'MAP4K3+ctrl', 'MAP4K5+ctrl', 'MAP7D1+ctrl', 'MAPK1+ctrl', 'MEIS1+ctrl', 'MIDN+ctrl', 'NCL+ctrl', 'NIT1+ctrl', 'OSR2+ctrl', 

In [11]:
adata = adata[adata.obs.condition.isin(single_gene_filter)].copy()

In [12]:
adata

AnnData object with n_obs × n_vars = 55760 × 5045
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name'
    var: 'gene_name'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20'
    layers: 'counts'

In [13]:
# TODO: Update condition names 

In [14]:
ori_batch_col = "control"
adata.obs["celltype"] = adata.obs["condition"].astype("category")
adata.obs["str_batch"] = adata.obs["control"].astype(str)
data_is_raw = False

In [15]:
if config.load_model is not None:
    model_dir = Path(config.load_model)
    model_config_file = model_dir / "args.json"
    model_file = model_dir / config.model_name
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    print(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]
    
    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    print(
        f"Resume model from {model_file}, the model args will be overriden by the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    embsize = config.layer_size 
    nhead = config.nhead
    nlayers = config.nlayers  
    d_hid = config.layer_size

match 4547/5045 genes in vocabulary of size 60697.
Resume model from /scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model/best_model.pt, the model args will be overriden by the config /scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model/args.json.


In [16]:
gene_names_set = [i + '+ctrl' for i in adata.var.gene_name.values]
gene_names_set = gene_names_set + ['ctrl']

####  ✅ Note
This experiment is computationally expensive, so we select 1000 cells per perturbation condition.

In [17]:
# Cap all conditions to 1000 cells
sampled_df = (
    adata.obs[adata.obs['condition'].isin(gene_names_set)]
    .groupby('condition', group_keys=False)
    .apply(lambda x: x.sample(min(len(x), 1000), random_state=42))
)
adata = adata[sampled_df.index].copy()
adata.obs.groupby('condition').count()

Unnamed: 0_level_0,cell_type,dose_val,control,condition_name,celltype,str_batch
condition,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
AHR+ctrl,479,479,479,479,479,479
ARID1A+ctrl,182,182,182,182,182,182
ARRDC3+ctrl,405,405,405,405,405,405
ATL1+ctrl,305,305,305,305,305,305
BAK1+ctrl,534,534,534,534,534,534
...,...,...,...,...,...,...
ZBTB1+ctrl,315,315,315,315,315,315
ZBTB10+ctrl,145,145,145,145,145,145
ZBTB25+ctrl,343,343,343,343,343,343
ZC3HAV1+ctrl,436,436,436,436,436,436


In [18]:
# 5 conditions are capped, including ctrl
condition_counts = adata.obs.groupby('condition').count()
condition_counts[condition_counts == 1000].dropna()

Unnamed: 0_level_0,cell_type,dose_val,control,condition_name,celltype,str_batch
condition,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ctrl,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0


In [19]:
condition_names = set(adata.obs.condition.tolist())

In [20]:
condition_names.remove('ctrl')

In [21]:
condition_names_gene = [i.split('+')[0] for i in list(condition_names)]

In [22]:
condition_names_gene.sort()

####  ✅ Note
HVGs selection will filter out some perturbed genes. We manually add them back in the experiment.

In [23]:
# Do filtering
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=None,  # step 2
    normalize_total=None,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=False,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=None,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    #binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    #result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key=None)

scGPT - INFO - Filtering genes by counts ...


In [24]:
sc.pp.highly_variable_genes(
    adata,
    layer=None,
    n_top_genes=1200,
    flavor="seurat_v3" if data_is_raw else "cell_ranger",
    subset=False,
)

In [25]:
add_counter = 0
for g in condition_names_gene:
    if not adata.var.loc[adata.var[adata.var.gene_name==g].index, 'highly_variable'].values[0]:
        adata.var.loc[adata.var[adata.var.gene_name==g].index, 'highly_variable'] = True
        add_counter += 1

In [26]:
print('Manually add conditions: {}, {}'.format(add_counter, add_counter/len(condition_names_gene)))

Manually add conditions: 67, 0.6767676767676768


In [27]:
# This step for binning
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=0,  # step 1
    filter_cell_by_counts=None,  # step 2
    normalize_total=None,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=False,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=None,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key=None)

scGPT - INFO - Binning data ...


In [28]:
adata = adata[:, adata.var["highly_variable"]].copy()
print(adata)

AnnData object with n_obs × n_vars = 33059 × 1267
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch'
    var: 'gene_name', 'id_in_vocab', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg'
    obsm: 'bin_edges'
    layers: 'counts', 'X_binned'


#### 🔵 Optional
Create another randomly shuffled list of `condition_names_gene_match` as control, if running the control experiment. 
Note that there are many ways to construct the control list, either from perturbation targets or random from all genes.

In [29]:
# Here is an example of randomly shuffle perturbation targets
import random
random.seed(42)
condition_names_gene_match = condition_names_gene.copy()
random.shuffle(condition_names_gene_match)

In [30]:
# Here is an example of using non-targets
# This is the most recent version
genes = adata.var["gene_name"].tolist()
non_targets = list(set(genes).difference(set(condition_names_gene)))
non_targets.sort()
random.seed(42)
random.shuffle(non_targets)
non_targets
condition_names_gene_match = non_targets[:len(condition_names_gene)]

In [31]:
print(condition_names_gene)

['AHR', 'ARID1A', 'ARRDC3', 'ATL1', 'BAK1', 'BCL2L11', 'BCORL1', 'BPGM', 'CBFA2T3', 'CBL', 'CDKN1A', 'CDKN1B', 'CDKN1C', 'CEBPA', 'CEBPB', 'CEBPE', 'CELF2', 'CITED1', 'CKS1B', 'CLDN6', 'CNN1', 'CNNM4', 'COL1A1', 'COL2A1', 'CSRNP1', 'DLX2', 'DUSP9', 'EGR1', 'ETS2', 'FEV', 'FOSB', 'FOXA1', 'FOXA3', 'FOXF1', 'FOXL2', 'FOXO4', 'GLB1L2', 'HES7', 'HK2', 'HNF4A', 'HOXA13', 'HOXB9', 'HOXC13', 'IER5L', 'IGDCC3', 'IKZF3', 'IRF1', 'ISL2', 'JUN', 'KIF18B', 'KIF2C', 'KLF1', 'KMT2A', 'LHX1', 'LYL1', 'MAML2', 'MAP2K3', 'MAP2K6', 'MAP4K3', 'MAP4K5', 'MAP7D1', 'MAPK1', 'MEIS1', 'MIDN', 'NCL', 'NIT1', 'OSR2', 'PLK4', 'POU3F2', 'PRDM1', 'PRTG', 'PTPN1', 'PTPN12', 'PTPN13', 'PTPN9', 'RREB1', 'RUNX1T1', 'S1PR2', 'SAMD1', 'SET', 'SGK1', 'SLC4A1', 'SLC6A9', 'SNAI1', 'SPI1', 'STIL', 'TBX2', 'TBX3', 'TGFBR2', 'TMSB4X', 'TP73', 'TSC22D1', 'UBASH3A', 'UBASH3B', 'ZBTB1', 'ZBTB10', 'ZBTB25', 'ZC3HAV1', 'ZNF318']


In [32]:
print(condition_names_gene_match)

['RASSF4', 'FUT7', 'IL22RA2', 'IP6K3', 'MANF', 'TIMP1', 'MLC1', 'ATF4', 'OR51E1', 'CTD-2623N2.5', 'OS9', 'CYSLTR2', 'ST3GAL6', 'LINC00895', 'HEATR9', 'ANXA2R', 'PMEPA1', 'RP11-46D6.1', 'PDE4DIP', 'RP11-404F10.2', 'APOBEC3D', 'MAOB', 'RP11-90K6.1', 'CD244', 'SVEP1', 'MEIS3', 'GHRL', 'RP11-212I21.4', 'IL6ST', 'ABCA1', 'SLC2A1-AS1', 'CTSO', 'RP11-306G20.1', 'TMEM154', 'PCAT5', 'RP11-443B7.2', 'PCDH9', 'TUBB3', 'SMYD3', 'TRAC', 'PARVG', 'NUTM2G', 'ERP27', 'GDF15', 'RP11-727F15.9', 'RP11-887P2.5', 'HBG2', 'RP5-1086K13.1', 'IRF2BP2', 'PDZK1IP1', 'TEX13D', 'RP11-498P14.5', 'PLAC8', 'C20orf202', 'BTG1', 'GPC1', 'REN', 'HBA2', 'ALDH3B1', 'AC002463.3', 'IL3RA', 'CABP4', 'ICOSLG', 'TXNIP', 'TNFRSF14', 'RP1-286D6.5', 'RP11-1152H14.1', 'EVI2B', 'PPP3CA', 'HBG1', 'PRSS57', 'CD48', 'RNF213', 'EPX', 'CD2', 'TMEM150C', 'FAM166B', 'PNOC', 'IL20', 'TCP11L2', 'CLEC4D', 'AC005616.2', 'SVOPL', 'RAP2C-AS1', 'OPRL1', 'ADIRF', 'PHLDA1', 'ARMCX3', 'FAM234A', 'OSBPL10', 'VWA7', 'ID2-AS1', 'ADRB2', 'ACE', 'ATP10D

## Prepare model input

In [33]:
max_len = adata.shape[1] + 1
max_len

1268

In [34]:
if config.load_model is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)
adata.obs['batch_id'] = adata.obs['condition'].copy()
batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
input_layer_key = "X_binned"

## Load the pre-trained scGPT model

In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    dropout=config.dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=config.GEPC,
    do_dab=False,
    use_batch_labels=False,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=DSBN,
    n_input_bins=n_input_bins,
    ecs_threshold=config.ecs_thres,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=config.fast_transformer,
    use_generative_training=True,
    pre_norm=config.pre_norm,
)
if config.load_model is not None:
    load_pretrained(model, torch.load(model_file), verbose=False)

model.to(device)
wandb.watch(model)

[]

In [36]:
model.eval()
adata_t = adata.copy()

In [37]:
all_counts = (
    adata_t.layers[input_layer_key].A
    if issparse(adata_t.layers[input_layer_key])
    else adata_t.layers[input_layer_key]
)
celltypes_labels = adata_t.obs["celltype"].tolist()
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata_t.obs["batch_id"].tolist()
batch_ids = np.array(batch_ids)

tokenized_all = tokenize_and_pad_batch(
    all_counts,
    gene_ids,
    max_len=max_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=True,
)
all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])

##  Get gene embeddings (with Value Masking), Calculate Cosine Distance & Rank, and Save Results

In [38]:
def expand_cell(tokenized_all, key, k, select_gene_id):
    cell_k = tokenized_all[key][k]
    # Repeat 
    cell_k_expand = cell_k.repeat(n_genes).view(n_genes, n_genes)
    new_column = torch.full((n_genes, 1), vocab([pad_token])[0])
    cell_k_expand = torch.cat((cell_k_expand, new_column), dim=1)
    mask = torch.eye(n_genes).bool()
    new_column_mask = torch.full((n_genes, 1), False)
    mask = torch.cat((mask, new_column_mask), dim=1)
    mask[:, select_gene_id] = True
    mask[select_gene_id, n_genes] = True
    mask_select_expand = cell_k_expand[mask]
    select_ids_gen = mask_select_expand.view(n_genes, 2)
    select_ids_pcpt = cell_k_expand[~mask].view(n_genes, n_genes-1)
    return select_ids_gen, select_ids_pcpt

from tqdm import tqdm
def collate_cell_by_key(tokenized_all, key, select_gene_id):
    print(key)
    select_ids_gen_list = []
    select_ids_pcpt_list = []
    for k in tqdm(range(n_cells)):
        select_ids_gen, select_ids_pcpt = expand_cell(tokenized_all, key, k, select_gene_id)
        select_ids_gen_list.append(select_ids_gen)
        select_ids_pcpt_list.append(select_ids_pcpt)
    select_ids_gen = torch.cat(select_ids_gen_list, dim=0)
    select_ids_pcpt = torch.cat(select_ids_pcpt_list, dim=0)
    print(select_ids_gen.shape, select_ids_pcpt.shape)
    return select_ids_gen, select_ids_pcpt

In [39]:
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics.pairwise import cosine_distances
from tqdm import tqdm

In [40]:
# %%
select_gene_list = condition_names_gene

for select_gene in select_gene_list:
    adata_t = adata[adata.obs['condition'].isin([select_gene+'+ctrl', 'ctrl'])].copy()
    print(adata_t.obs['condition'])
    select_gene_id = genes.index(select_gene)+1
    print(select_gene_id)
    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    celltypes_labels = adata_t.obs["celltype"].tolist()
    celltypes_labels = np.array(celltypes_labels)

    batch_ids = adata_t.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)

    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    print(tokenized_all['genes'].shape, tokenized_all['values'].shape)
    n_cells = tokenized_all['genes'].shape[0]
    n_genes = tokenized_all['genes'].shape[1]
    
    collate_genes_gen, collate_genes_pcpt = collate_cell_by_key(tokenized_all, 'genes', select_gene_id)
    _, collate_values_pcpt = collate_cell_by_key(tokenized_all, 'values', select_gene_id)
    
    tokenized_all_expand = {'genes_pcpt': collate_genes_pcpt, 'genes_gen': collate_genes_gen, 'values_pcpt': collate_values_pcpt}
    print(tokenized_all_expand)
    query_id = tokenized_all['genes'][0].repeat(n_cells)
    
    cell_counter = torch.arange(0, n_cells)
    cell_counter = cell_counter.repeat(n_genes).view(n_genes, n_cells).t().flatten()
    gene_counter = torch.arange(0, n_genes).repeat(n_cells)

    dataloader = DataLoader(
        TensorDataset(tokenized_all_expand['genes_pcpt'], 
                      tokenized_all_expand['genes_gen'], 
                      tokenized_all_expand['values_pcpt'],
                      query_id,
                      cell_counter,
                      gene_counter,
                     ), 
        batch_size=512, 
        shuffle=False)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gene_embeddings = np.zeros((n_cells, n_genes, 512))
    
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        for batch_idx, batch_data in enumerate(tqdm(dataloader)):
            pcpt_genes = batch_data[0].to(device)
            gen_genes = batch_data[1].to(device)
            pcpt_values = batch_data[2].to(device)
            query_id_select = batch_data[3].to(device)
            cell_counter_batch = batch_data[4].to(device)
            gene_counter_batch = batch_data[5].to(device)
            pcpt_key_padding_mask = pcpt_genes.eq(vocab[pad_token]).to(device)
            gen_key_padding_mask = gen_genes.eq(vocab[pad_token]).to(device)
            _, gen_output = model.transformer_generate(
                pcpt_genes=pcpt_genes,
                pcpt_values=pcpt_values,
                pcpt_key_padding_mask=pcpt_key_padding_mask,
                gen_genes=gen_genes,
                gen_key_padding_mask=gen_key_padding_mask,
            )
            select_mask = (gen_genes == query_id_select.unsqueeze(1)).long()
            selected_output = gen_output[torch.arange(gen_output.shape[0]), select_mask.argmax(dim=1), :]
            selected_output_np = selected_output.detach().cpu().numpy()
            gene_embeddings[cell_counter_batch.detach().cpu().numpy(), gene_counter_batch.detach().cpu().numpy(), :] = selected_output_np
    
    conditions = adata_t.obs['condition'].values
    
    dict_sum_condition_mean = {}
    for c in np.unique(conditions):
        dict_sum_condition_mean[c] = gene_embeddings[np.where(conditions == c)[0]].mean(0)
    
    print(dict_sum_condition_mean)
        
    celltype_0 = select_gene + '+ctrl'
    celltype_1 = 'ctrl'
    gene_emb_celltype_0 = np.expand_dims(dict_sum_condition_mean[celltype_0][1:, 1:], 0)
    gene_emb_celltype_1 = np.expand_dims(dict_sum_condition_mean[celltype_1][1:, 1:], 0)
    gene_dist_dict = {}
    for i, g in tqdm(enumerate(genes)):
        gene_dist_dict[g] = cosine_distances(gene_emb_celltype_0[:, i, :], gene_emb_celltype_1[:, i, :]).mean()
    df_gene_emb_dist = pd.DataFrame.from_dict(gene_dist_dict, orient='index', columns=['cos_dist'])
    df_deg = df_gene_emb_dist.sort_values(by='cos_dist', ascending=False)
    rank_celltype_0 = np.where(df_deg.index==celltype_0.split('+')[0])[0][0]
    print(celltype_0, rank_celltype_0) 
    np.savez(str(save_dir)+"mean_gene_emb_{}_{}.npz".format(select_gene, rank_celltype_0), **dict_sum_condition_mean)
    print(f'Saved:\n{str(save_dir)+"mean_gene_emb_{}_{}.npz".format(select_gene, rank_celltype_0)}')
    assert(0)

cell_barcode
GCTGCGACAAACTGTC-2    AHR+ctrl
GCGCAACTCAGGTAAA-6    AHR+ctrl
TGCGTGGTCTCGATGA-1    AHR+ctrl
CGATGTATCTGGCGAC-1    AHR+ctrl
GCGAGAACAGATGGCA-8    AHR+ctrl
                        ...   
TATCAGGGTAGCTGCC-6        ctrl
AGTCTTTTCTCTTATG-4        ctrl
ACGAGGAAGGCAGGTT-3        ctrl
GACTGCGTCCTCCTAG-8        ctrl
CAACTAGGTAGCGTCC-4        ctrl
Name: condition, Length: 1479, dtype: category
Categories (2, object): ['AHR+ctrl', 'ctrl']
439
torch.Size([1479, 1268]) torch.Size([1479, 1268])
genes


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1479/1479 [00:07<00:00, 200.26it/s]


torch.Size([1875372, 2]) torch.Size([1875372, 1267])
values


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1479/1479 [00:06<00:00, 218.93it/s]


torch.Size([1875372, 2]) torch.Size([1875372, 1267])
{'genes_pcpt': tensor([[24904, 30607, 21504,  ..., 12288, 12289, 60694],
        [60695, 30607, 21504,  ..., 12288, 12289, 60694],
        [60695, 24904, 21504,  ..., 12288, 12289, 60694],
        ...,
        [60695, 24904, 30607,  ..., 12288, 12289, 60694],
        [60695, 24904, 30607,  ..., 11394, 12289, 60694],
        [60695, 24904, 30607,  ..., 11394, 12288, 60694]]), 'genes_gen': tensor([[60695,  1743],
        [24904,  1743],
        [30607,  1743],
        ...,
        [ 1743, 11394],
        [ 1743, 12288],
        [ 1743, 12289]]), 'values_pcpt': tensor([[    0.,     0.,     0.,  ...,     0.,     0., 60694.],
        [    0.,     0.,     0.,  ...,     0.,     0., 60694.],
        [    0.,     0.,     0.,  ...,     0.,     0., 60694.],
        ...,
        [    0.,     0.,     0.,  ...,     0.,     0., 60694.],
        [    0.,     0.,     0.,  ...,     0.,     0., 60694.],
        [    0.,     0.,     0.,  ...,     0.,   

 28%|███████████████████████████████████████▍                                                                                                   | 1039/3663 [34:23<1:26:52,  1.99s/it]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

