In [1]:
import copy
import gc
import json
import os
from pathlib import Path
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
from anndata import AnnData
import scanpy as sc
#import scvi
import numpy as np
import pandas as pd
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)


sys.path.insert(0, "../")
#import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, eval_scib_metrics, load_pretrained

sc.set_figure_params(figsize=(4, 4))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
hyperparameter_defaults = dict(
    seed=42,
    dataset_name="fibro", # Dataset name
    do_train=True, # Flag to indicate whether to do update model parameters during training
    load_model="/scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model",
    model_name="best_model.pt",
    GEPC=True,  # Gene expression modelling for cell objective
    ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=1.0, # DAR objective weight for batch correction
    mask_ratio=0.4, # Default mask ratio
    epochs=15, # Default number of epochs for fine-tuning
    n_bins=51, # Default number of bins for value binning in data pre-processing
    lr=1e-4, # Default learning rate for fine-tuning
    batch_size=64, # Default batch size for fine-tuning
    layer_size=128,
    nlayers=4,
    nhead=4, # if load model, batch_size, layer_size, nlayers, nhead will be ignored
    dropout=0.2, # Default dropout rate during model fine-tuning
    schedule_ratio=0.9,  # Default rate for learning rate decay
    save_eval_interval=5, # Default model evaluation interval
    log_interval=100, # Default log interval
    fast_transformer=True, # Default setting
    pre_norm=False, # Default setting
    amp=True,  # # Default setting: Automatic Mixed Precision
)
run = wandb.init(
    config=hyperparameter_defaults,
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

set_seed(config.seed)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchloewxq[0m ([33mscformer[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 42, 'dataset_name': 'fibro', 'do_train': True, 'load_model': '/scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model', 'model_name': 'best_model.pt', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}


In [3]:
from pathlib import Path
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = -1
pad_value = -2
n_input_bins = config.n_bins

n_hvg = 1200  # number of highly variable genes
max_seq_len = n_hvg + 1
per_seq_batch_sample = True
DSBN = True  # Domain-spec batchnorm
explicit_zero_prob = True  # whether explicit bernoulli for zeros

dataset_name = config.dataset_name
save_dir = Path(f"/scratch/ssd004/scratch/chloexq/fibro/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")

save to /scratch/ssd004/scratch/chloexq/fibro/dev_fibro-Feb06-15-12


In [4]:
from pathlib import Path

## Load and preprocess dataset

####  ✅ Note
Perturbation datasets can be found in this path: `/scratch/ssd004/scratch/chloexq/perturb_analysis/{dataset_name}` 

In [5]:
# Perturbation datasets can be found in this path
data_dir = Path("/scratch/ssd004/scratch/chloexq/perturb_analysis")
adata = sc.read(data_dir / "adamson/perturb_processed.h5ad")
ori_batch_col = "control"
adata.obs["celltype"] = adata.obs["condition"].astype("category")
adata.obs["str_batch"] = adata.obs["control"].astype(str)
data_is_raw = False

In [6]:
if config.load_model is not None:
    model_dir = Path(config.load_model)
    model_config_file = model_dir / "args.json"
    model_file = model_dir / config.model_name
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    print(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]
    
    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    print(
        f"Resume model from {model_file}, the model args will be overriden by the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    embsize = config.layer_size 
    nhead = config.nhead
    nlayers = config.nlayers  
    d_hid = config.layer_size

match 4399/5060 genes in vocabulary of size 60697.
Resume model from /scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model/best_model.pt, the model args will be overriden by the config /scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model/args.json.


In [7]:
# Align perturbation condition naming
gene_names_set = [i + '+ctrl' for i in adata.var.gene_name.values]
gene_names_set = gene_names_set + ['ctrl']

####  ✅ Note
This experiment is computationally expensive, so we select 1000 cells per perturbation condition.

In [8]:
# Cap all conditions to 1000 cells
sampled_df = (
    adata.obs[adata.obs['condition'].isin(gene_names_set)]
    .groupby('condition', group_keys=False)
    .apply(lambda x: x.sample(min(len(x), 1000), random_state=42))
)
adata = adata[sampled_df.index].copy()
adata.obs.groupby('condition').count()

Unnamed: 0_level_0,cell_type,dose_val,control,condition_name,celltype,str_batch
condition,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
AMIGO3+ctrl,616,616,616,616,616,616
ARHGAP22+ctrl,406,406,406,406,406,406
ASCC3+ctrl,524,524,524,524,524,524
BHLHE40+ctrl,504,504,504,504,504,504
CAD+ctrl,242,242,242,242,242,242
...,...,...,...,...,...,...
UFM1+ctrl,591,591,591,591,591,591
XRN1+ctrl,621,621,621,621,621,621
YIPF5+ctrl,1000,1000,1000,1000,1000,1000
ZNF326+ctrl,517,517,517,517,517,517


In [9]:
# 5 conditions are capped, including ctrl
condition_counts = adata.obs.groupby('condition').count()
condition_counts[condition_counts == 1000].dropna()

Unnamed: 0_level_0,cell_type,dose_val,control,condition_name,celltype,str_batch
condition,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
HSPA5+ctrl,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
IER3IP1+ctrl,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
SCYL1+ctrl,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
YIPF5+ctrl,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
ctrl,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0


In [10]:
condition_names = set(adata.obs.condition.tolist())

In [11]:
condition_names.remove('ctrl')

In [12]:
condition_names_gene = [i.split('+')[0] for i in list(condition_names)]

In [13]:
condition_names_gene.sort()

####  ✅ Note
HVGs selection will filter out some perturbed genes. We manually add them back in the experiment.

In [14]:
# Do filtering
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=None,  # step 2
    normalize_total=None,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=False,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=None,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    #binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    #result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key=None)

scGPT - INFO - Filtering genes by counts ...


In [15]:
sc.pp.highly_variable_genes(
    adata,
    layer=None,
    n_top_genes=1200,
    flavor="seurat_v3" if data_is_raw else "cell_ranger",
    subset=False,
)

In [16]:
add_counter = 0
for g in condition_names_gene:
    if not adata.var.loc[adata.var[adata.var.gene_name==g].index, 'highly_variable'].values[0]:
        adata.var.loc[adata.var[adata.var.gene_name==g].index, 'highly_variable'] = True
        add_counter += 1

In [17]:
print('Manually add conditions: {}, {}'.format(add_counter, add_counter/len(condition_names_gene)))

Manually add conditions: 67, 0.8933333333333333


In [18]:
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=0,  # step 1
    filter_cell_by_counts=None,  # step 2
    normalize_total=None,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=False,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=None,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key=None)

scGPT - INFO - Binning data ...


In [19]:
adata = adata[:, adata.var["highly_variable"]].copy()
print(adata)

AnnData object with n_obs × n_vars = 39847 × 1267
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch'
    var: 'gene_name', 'id_in_vocab', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg'
    obsm: 'bin_edges'
    layers: 'X_binned'


#### 🔵 Optional
Create another randomly shuffled list of `condition_names_gene_match` as control, if running the control experiment. 
Note that there are many ways to construct the control list, either from perturbation targets or random from all genes.

In [20]:
# Here is an example of randomly shuffle perturbation targets
import random
random.seed(42)
condition_names_gene_match = condition_names_gene.copy()
random.shuffle(condition_names_gene_match)
condition_names_gene, condition_names_gene_match

(['AMIGO3',
  'ARHGAP22',
  'ASCC3',
  'BHLHE40',
  'CAD',
  'CCND3',
  'CHERP',
  'COPB1',
  'COPZ1',
  'CREB1',
  'DAD1',
  'DDIT3',
  'DDOST',
  'DDRGK1',
  'DERL2',
  'DHDDS',
  'DNAJC19',
  'EIF2B2',
  'EIF2B3',
  'EIF2B4',
  'EIF2S1',
  'FARSB',
  'FECH',
  'GBF1',
  'GMPPB',
  'GNPNAT1',
  'HSD17B12',
  'HSPA5',
  'HSPA9',
  'HYOU1',
  'IARS2',
  'IDH3A',
  'IER3IP1',
  'KCTD16',
  'MANF',
  'MRGBP',
  'MRPL39',
  'MTHFD1',
  'NEDD8',
  'OST4',
  'P4HB',
  'PDIA6',
  'PPWD1',
  'PSMD4',
  'PTDSS1',
  'SAMM50',
  'SCYL1',
  'SEC61A1',
  'SEC61B',
  'SEC61G',
  'SEC63',
  'SEL1L',
  'SLC35B1',
  'SLC39A7',
  'SOCS1',
  'SPCS2',
  'SPCS3',
  'SRP68',
  'SRP72',
  'SRPRB',
  'STT3A',
  'SYVN1',
  'TELO2',
  'TIMM23',
  'TIMM44',
  'TMED10',
  'TMED2',
  'TMEM167A',
  'TTI1',
  'TTI2',
  'UFL1',
  'UFM1',
  'XRN1',
  'YIPF5',
  'ZNF326'],
 ['SCYL1',
  'EIF2B3',
  'MRPL39',
  'COPZ1',
  'COPB1',
  'SRPRB',
  'UFM1',
  'STT3A',
  'PSMD4',
  'OST4',
  'DHDDS',
  'TMED10',
  'SRP68',
  '

In [22]:
# Here is an example of using non-targets
# This is the most recent version
genes = adata.var["gene_name"].tolist()
non_targets = list(set(genes).difference(set(condition_names_gene)))
non_targets.sort()
random.seed(42)
random.shuffle(non_targets)
non_targets
condition_names_gene_match = non_targets[:len(condition_names_gene)]

## Prepare model input

In [23]:
max_len = adata.shape[1] + 1
max_len

1268

In [24]:
if config.load_model is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
genes = adata.var["gene_name"].tolist()
gene_ids = np.array(vocab(genes), dtype=int)
adata.obs['batch_id'] = adata.obs['condition'].copy()
batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
input_layer_key = "X_binned"

In [25]:
def prepare_data(sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor]]:
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    if sort_seq_batch:
        train_sort_ids = np.argsort(train_batch_labels)
        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        input_values_train = input_values_train[train_sort_ids]
        target_values_train = target_values_train[train_sort_ids]
        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]

        valid_sort_ids = np.argsort(valid_batch_labels)
        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        "batch_labels": tensor_batch_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
    }

    return train_data_pt, valid_data_pt


# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader

## Load the pre-trained scGPT model

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    dropout=config.dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=config.GEPC,
    do_dab=True,
    use_batch_labels=True,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=DSBN,
    n_input_bins=n_input_bins,
    ecs_threshold=config.ecs_thres,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=config.fast_transformer,
    pre_norm=config.pre_norm,
)
if config.load_model is not None:
    load_pretrained(model, torch.load(model_file), verbose=False)

model.to(device)
wandb.watch(model)

Use domain specific batchnorm with affine=False


[]

In [27]:
model.eval()
adata_t = adata.copy()

In [28]:
all_counts = (
    adata_t.layers[input_layer_key].A
    if issparse(adata_t.layers[input_layer_key])
    else adata_t.layers[input_layer_key]
)
celltypes_labels = adata_t.obs["celltype"].tolist()
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata_t.obs["batch_id"].tolist()
batch_ids = np.array(batch_ids)

tokenized_all = tokenize_and_pad_batch(
    all_counts,
    gene_ids,
    max_len=max_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=True,
)
all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])

##  Get attention vectors

In [29]:
from einops import rearrange
from tqdm import tqdm

In [30]:
condition_ids = np.array(adata_t.obs["condition"].tolist())

In [31]:
torch.cuda.empty_cache()
dict_sum_condition = {}
num_attn_layers = 11 
batch_size = 16

In [32]:
all_gene_ids

tensor([[60695, 10954, 33817,  ..., 11394, 20695, 12288],
        [60695, 10954, 33817,  ..., 11394, 20695, 12288],
        [60695, 10954, 33817,  ..., 11394, 20695, 12288],
        ...,
        [60695, 10954, 33817,  ..., 11394, 20695, 12288],
        [60695, 10954, 33817,  ..., 11394, 20695, 12288],
        [60695, 10954, 33817,  ..., 11394, 20695, 12288]])

In [33]:
model.eval()
with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
    M = all_gene_ids.size(1)
    N = all_gene_ids.size(0)
    device = next(model.parameters()).device
    for i in tqdm(range(0, N, batch_size)):
        batch_size = all_gene_ids[i : i + batch_size].size(0)
        outputs = np.zeros((batch_size, M, M), dtype=np.float32)
        # Replicate the operations in model forward pass
        src_embs = model.encoder(torch.tensor(all_gene_ids[i : i + batch_size], dtype=torch.long).to(device))
        val_embs = model.value_encoder(torch.tensor(all_values[i : i + batch_size], dtype=torch.float).to(device))
        total_embs = src_embs + val_embs
        #total_embs = model.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)
        # Send total_embs to attention layers for attention operations
        # Retrieve the output from second to last layer
        for layer in model.transformer_encoder.layers[:num_attn_layers]:
            total_embs = layer(total_embs, src_key_padding_mask=src_key_padding_mask[i : i + batch_size].to(device))
        # Send total_embs to the last layer in flash-attn
        # https://github.com/HazyResearch/flash-attention/blob/1b18f1b7a133c20904c096b8b222a0916e1b3d37/flash_attn/flash_attention.py#L90
        qkv = model.transformer_encoder.layers[num_attn_layers].self_attn.Wqkv(total_embs)
        # Retrieve q, k, and v from flast-attn wrapper
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=8)
        q = qkv[:, :, 0, :, :]
        k = qkv[:, :, 1, :, :]
        v = qkv[:, :, 2, :, :]
        # https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
        # q = [batch, gene, n_heads, n_hid]
        # k = [batch, gene, n_heads, n_hid]
        # attn_scores = [batch, n_heads, gene, gene]
        attn_scores = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)
        # Rank normalization by row
        attn_scores = attn_scores.reshape((-1, M))
        order = torch.argsort(attn_scores, dim=1)
        rank = torch.argsort(order, dim=1)
        attn_scores = rank.reshape((-1, 8, M, M))/M
        # Rank normalization by column
        attn_scores = attn_scores.permute(0, 1, 3, 2).reshape((-1, M))
        order = torch.argsort(attn_scores, dim=1)
        rank = torch.argsort(order, dim=1)
        attn_scores = (rank.reshape((-1, 8, M, M))/M).permute(0, 1, 3, 2)
        # Average 8 attention heads
        attn_scores = attn_scores.mean(1)
        #attn_scores = attn_scores[:, head_num, :, :].squeeze(1)
        
        outputs = attn_scores.detach().cpu().numpy()
        
        for index in range(batch_size):
            # Keep track of sum per condition
            c = condition_ids[i : i + batch_size][index]
            if c not in dict_sum_condition:
                dict_sum_condition[c] = np.zeros((M, M), dtype=np.float32)
                #dict_sum_condition[c] = [np.expand_dims(outputs[index, :, :], 0)]
            else:
                #dict_sum_condition[c].append(np.expand_dims(outputs[index, :, :], 0))
                dict_sum_condition[c] += outputs[index, :, :]

100%|██████████| 2491/2491 [14:44<00:00,  2.82it/s]


#### ✅ Note
Each perturbation group has one n_gene x n_gene matrix associated with it

In [34]:
groups = adata_t.obs.groupby('condition').groups
dict_sum_condition_mean = dict_sum_condition.copy()
for i in groups.keys():
    dict_sum_condition_mean[i] = dict_sum_condition_mean[i]/len(groups[i])
gene_vocab_idx = all_gene_ids[0].clone().detach().cpu().numpy()
dict_sum_condition_mean

{'AMIGO3+ctrl': array([[0.5310546 , 0.3011772 , 0.61151224, ..., 0.66566384, 0.2138397 ,
         0.19172029],
        [0.60797244, 0.55480415, 0.6012709 , ..., 0.59466916, 0.48264793,
         0.5755632 ],
        [0.61890405, 0.54929096, 0.43206474, ..., 0.5508291 , 0.59692013,
         0.53999317],
        ...,
        [0.57601035, 0.40707183, 0.46932662, ..., 0.6160896 , 0.53505355,
         0.38854548],
        [0.5352424 , 0.32389387, 0.4493968 , ..., 0.6112954 , 0.35057646,
         0.49225843],
        [0.48438302, 0.39183864, 0.5697442 , ..., 0.55788374, 0.36243856,
         0.5271803 ]], dtype=float32),
 'ARHGAP22+ctrl': array([[0.5289558 , 0.31310806, 0.610498  , ..., 0.65732014, 0.22263235,
         0.20019177],
        [0.60951245, 0.5595481 , 0.5937757 , ..., 0.5969046 , 0.480081  ,
         0.57757056],
        [0.6138089 , 0.54193395, 0.43774462, ..., 0.55455554, 0.5955508 ,
         0.53952247],
        ...,
        [0.5747428 , 0.4027575 , 0.46881834, ..., 0.62020457,

In [35]:
perturb_conditions = list(dict_sum_condition_mean.keys())
len(perturb_conditions)

76

In [36]:
perturb_conditions.remove('ctrl')

In [37]:
assert 'ctrl' not in perturb_conditions

#### ✅ Note
Calculate consine distance between control and each perturbation conditions

In [38]:
from sklearn.metrics.pairwise import cosine_distances

In [39]:
rank_list = []

In [40]:
for c in perturb_conditions:
    celltype_0 = c
    celltype_1 = 'ctrl'
    gene_emb_celltype_0 = np.expand_dims(dict_sum_condition_mean[celltype_0][1:, 1:], 0)
    gene_emb_celltype_1 = np.expand_dims(dict_sum_condition_mean[celltype_1][1:, 1:], 0)
    gene_dist_dict = {}
    for i, g in tqdm(enumerate(genes)):
        gene_dist_dict[g] = cosine_distances(gene_emb_celltype_0[:, i, :], gene_emb_celltype_1[:, i, :]).mean()
    df_gene_emb_dist = pd.DataFrame.from_dict(gene_dist_dict, orient='index', columns=['cos_dist'])
    df_deg = df_gene_emb_dist.sort_values(by='cos_dist', ascending=False)
    print(c, np.where(df_deg.index==c.split('+')[0])[0][0])
    rank_list.append(np.where(df_deg.index==c.split('+')[0])[0][0])

1267it [00:00, 3083.87it/s]


AMIGO3+ctrl 569


1267it [00:00, 3054.00it/s]


ARHGAP22+ctrl 175


1267it [00:00, 3114.26it/s]


ASCC3+ctrl 15


1267it [00:00, 3008.22it/s]


BHLHE40+ctrl 584


1267it [00:00, 3071.66it/s]


CAD+ctrl 257


1267it [00:00, 2748.08it/s]


CCND3+ctrl 14


1267it [00:00, 2996.84it/s]


CHERP+ctrl 3


1267it [00:00, 3051.56it/s]


COPB1+ctrl 312


1267it [00:00, 2928.79it/s]


COPZ1+ctrl 8


1267it [00:00, 3036.30it/s]


CREB1+ctrl 85


1267it [00:00, 3067.71it/s]


DAD1+ctrl 0


1267it [00:00, 2983.96it/s]


DDIT3+ctrl 231


1267it [00:00, 2906.21it/s]


DDOST+ctrl 18


1267it [00:00, 2940.82it/s]


DDRGK1+ctrl 6


1267it [00:00, 2977.28it/s]


DERL2+ctrl 23


1267it [00:00, 3002.14it/s]


DHDDS+ctrl 385


1267it [00:00, 3206.87it/s]


DNAJC19+ctrl 3


1267it [00:00, 3399.98it/s]


EIF2B2+ctrl 345


1267it [00:00, 3337.59it/s]


EIF2B3+ctrl 12


1267it [00:00, 3501.81it/s]


EIF2B4+ctrl 21


1267it [00:00, 3396.39it/s]


EIF2S1+ctrl 30


1267it [00:00, 3282.66it/s]


FARSB+ctrl 6


1267it [00:00, 3418.82it/s]


FECH+ctrl 65


1267it [00:00, 3380.86it/s]


GBF1+ctrl 15


1267it [00:00, 3269.42it/s]


GMPPB+ctrl 416


1267it [00:00, 3318.29it/s]


GNPNAT1+ctrl 0


1267it [00:00, 3313.07it/s]


HSD17B12+ctrl 12


1267it [00:00, 3465.52it/s]


HSPA5+ctrl 5


1267it [00:00, 3282.56it/s]


HSPA9+ctrl 7


1267it [00:00, 3435.40it/s]


HYOU1+ctrl 26


1267it [00:00, 3096.23it/s]


IARS2+ctrl 138


1267it [00:00, 2974.29it/s]


IDH3A+ctrl 11


1267it [00:00, 3078.08it/s]


IER3IP1+ctrl 33


1267it [00:00, 3038.53it/s]


KCTD16+ctrl 267


1267it [00:00, 2862.06it/s]


MANF+ctrl 6


1267it [00:00, 2914.48it/s]


MRGBP+ctrl 0


1267it [00:00, 3187.70it/s]


MRPL39+ctrl 3


1267it [00:00, 3066.49it/s]


MTHFD1+ctrl 53


1267it [00:00, 2818.07it/s]


NEDD8+ctrl 5


1267it [00:00, 2992.69it/s]


OST4+ctrl 3


1267it [00:00, 2950.90it/s]


P4HB+ctrl 2


1267it [00:00, 2946.31it/s]


PDIA6+ctrl 0


1267it [00:00, 3033.95it/s]


PPWD1+ctrl 3


1267it [00:00, 3216.80it/s]


PSMD4+ctrl 35


1267it [00:00, 2954.79it/s]


PTDSS1+ctrl 10


1267it [00:00, 2974.08it/s]


SAMM50+ctrl 21


1267it [00:00, 3116.42it/s]


SCYL1+ctrl 1


1267it [00:00, 3064.53it/s]


SEC61A1+ctrl 82


1267it [00:00, 3084.89it/s]


SEC61B+ctrl 0


1267it [00:00, 2993.49it/s]


SEC61G+ctrl 1


1267it [00:00, 3040.64it/s]


SEC63+ctrl 7


1267it [00:00, 3063.29it/s]


SEL1L+ctrl 18


1267it [00:00, 2992.03it/s]


SLC35B1+ctrl 11


1267it [00:00, 2970.52it/s]


SLC39A7+ctrl 77


1267it [00:00, 2950.89it/s]


SOCS1+ctrl 754


1267it [00:00, 2967.80it/s]


SPCS2+ctrl 2


1267it [00:00, 2971.31it/s]


SPCS3+ctrl 12


1267it [00:00, 2991.84it/s]


SRP68+ctrl 63


1267it [00:00, 3045.77it/s]


SRP72+ctrl 1


1267it [00:00, 2841.09it/s]


SRPRB+ctrl 81


1267it [00:00, 2817.32it/s]


STT3A+ctrl 20


1267it [00:00, 2907.96it/s]


SYVN1+ctrl 4


1267it [00:00, 3071.55it/s]


TELO2+ctrl 11


1267it [00:00, 3217.80it/s]


TIMM23+ctrl 4


1267it [00:00, 2910.08it/s]


TIMM44+ctrl 53


1267it [00:00, 2979.27it/s]


TMED2+ctrl 0


1267it [00:00, 2965.33it/s]


TMED10+ctrl 15


1267it [00:00, 3087.15it/s]


TMEM167A+ctrl 6


1267it [00:00, 2938.55it/s]


TTI1+ctrl 77


1267it [00:00, 3100.57it/s]


TTI2+ctrl 295


1267it [00:00, 2953.48it/s]


UFL1+ctrl 3


1267it [00:00, 3136.05it/s]


UFM1+ctrl 4


1267it [00:00, 3119.04it/s]


XRN1+ctrl 163


1267it [00:00, 3078.76it/s]


YIPF5+ctrl 16


1267it [00:00, 3183.95it/s]


ZNF326+ctrl 218


In [41]:
print(perturb_conditions)

['AMIGO3+ctrl', 'ARHGAP22+ctrl', 'ASCC3+ctrl', 'BHLHE40+ctrl', 'CAD+ctrl', 'CCND3+ctrl', 'CHERP+ctrl', 'COPB1+ctrl', 'COPZ1+ctrl', 'CREB1+ctrl', 'DAD1+ctrl', 'DDIT3+ctrl', 'DDOST+ctrl', 'DDRGK1+ctrl', 'DERL2+ctrl', 'DHDDS+ctrl', 'DNAJC19+ctrl', 'EIF2B2+ctrl', 'EIF2B3+ctrl', 'EIF2B4+ctrl', 'EIF2S1+ctrl', 'FARSB+ctrl', 'FECH+ctrl', 'GBF1+ctrl', 'GMPPB+ctrl', 'GNPNAT1+ctrl', 'HSD17B12+ctrl', 'HSPA5+ctrl', 'HSPA9+ctrl', 'HYOU1+ctrl', 'IARS2+ctrl', 'IDH3A+ctrl', 'IER3IP1+ctrl', 'KCTD16+ctrl', 'MANF+ctrl', 'MRGBP+ctrl', 'MRPL39+ctrl', 'MTHFD1+ctrl', 'NEDD8+ctrl', 'OST4+ctrl', 'P4HB+ctrl', 'PDIA6+ctrl', 'PPWD1+ctrl', 'PSMD4+ctrl', 'PTDSS1+ctrl', 'SAMM50+ctrl', 'SCYL1+ctrl', 'SEC61A1+ctrl', 'SEC61B+ctrl', 'SEC61G+ctrl', 'SEC63+ctrl', 'SEL1L+ctrl', 'SLC35B1+ctrl', 'SLC39A7+ctrl', 'SOCS1+ctrl', 'SPCS2+ctrl', 'SPCS3+ctrl', 'SRP68+ctrl', 'SRP72+ctrl', 'SRPRB+ctrl', 'STT3A+ctrl', 'SYVN1+ctrl', 'TELO2+ctrl', 'TIMM23+ctrl', 'TIMM44+ctrl', 'TMED2+ctrl', 'TMED10+ctrl', 'TMEM167A+ctrl', 'TTI1+ctrl', 'TT

#### 🔵 Optional
For control experiment, run the following block

In [42]:
rank_list = []
for i, c in enumerate(perturb_conditions):
    celltype_0 = c
    celltype_1 = 'ctrl'
    gene_emb_celltype_0 = np.expand_dims(dict_sum_condition_mean[celltype_0][1:, 1:], 0)
    gene_emb_celltype_1 = np.expand_dims(dict_sum_condition_mean[celltype_1][1:, 1:], 0)
    gene_dist_dict = {}
    for i, g in tqdm(enumerate(genes)):
        gene_dist_dict[g] = cosine_distances(gene_emb_celltype_0[:, i, :], gene_emb_celltype_1[:, i, :]).mean()
    df_gene_emb_dist = pd.DataFrame.from_dict(gene_dist_dict, orient='index', columns=['cos_dist'])
    df_deg = df_gene_emb_dist.sort_values(by='cos_dist', ascending=False)
    c = condition_names_gene_match.pop(0)
    print(c, np.where(df_deg.index==c.split('+')[0])[0][0])
    rank_list.append(np.where(df_deg.index==c.split('+')[0])[0][0])

1267it [00:00, 3041.27it/s]


RP5-1159O4.2 497


1267it [00:00, 3384.85it/s]


RP11-584P21.2 69


1267it [00:00, 2880.84it/s]


FBXO32 651


1267it [00:00, 2861.64it/s]


NUSAP1 867


1267it [00:00, 3119.23it/s]


PCNA 1194


1267it [00:00, 3254.93it/s]


SKIL 420


1267it [00:00, 3412.56it/s]


MYOZ1 924


1267it [00:00, 3369.22it/s]


GADD45G 1237


1267it [00:00, 3355.00it/s]


EEF1A1 724


1267it [00:00, 3175.43it/s]


MCF2L 607


1267it [00:00, 3307.21it/s]


DLX4 991


1267it [00:00, 2935.36it/s]


CCNE2 134


1267it [00:00, 3132.50it/s]


COCH 1202


1267it [00:00, 3358.02it/s]


NAV2 602


1267it [00:00, 3138.28it/s]


KCNQ1 592


1267it [00:00, 3166.87it/s]


SLC39A10 596


1267it [00:00, 2980.64it/s]


AC023590.1 574


1267it [00:00, 3022.95it/s]


CAPN2 765


1267it [00:00, 2842.75it/s]


PTPN7 93


1267it [00:00, 3051.15it/s]


TUBB1 1173


1267it [00:00, 2948.59it/s]


SLC6A20 317


1267it [00:00, 3076.71it/s]


RP11-96K19.2 675


1267it [00:00, 2973.82it/s]


STAR 308


1267it [00:00, 2950.75it/s]


OR8D1 620


1267it [00:00, 2825.25it/s]


TRIM24 69


1267it [00:00, 2968.36it/s]


KDM7A 213


1267it [00:00, 2787.28it/s]


HSPA8 1009


1267it [00:00, 2940.79it/s]


SLC7A8 865


1267it [00:00, 2996.65it/s]


PSRC1 858


1267it [00:00, 3011.77it/s]


TGIF1 416


1267it [00:00, 2959.52it/s]


NR4A1 1223


1267it [00:00, 2986.28it/s]


PTH2 993


1267it [00:00, 3044.42it/s]


DLGAP5 1093


1267it [00:00, 3066.02it/s]


NMU 478


1267it [00:00, 3046.79it/s]


RBPMS 956


1267it [00:00, 3046.47it/s]


LGALS3 1193


1267it [00:00, 3080.03it/s]


PSAT1 298


1267it [00:00, 3211.95it/s]


AQP1 76


1267it [00:00, 2985.50it/s]


RP11-867G2.8 228


1267it [00:00, 3022.93it/s]


MYLIP 1237


1267it [00:00, 3050.76it/s]


F8 394


1267it [00:00, 2903.10it/s]


PTTG1 1136


1267it [00:00, 2945.93it/s]


HNMT 616


1267it [00:00, 3142.30it/s]


FNDC4 1229


1267it [00:00, 3059.73it/s]


RP11-320G24.1 212


1267it [00:00, 2968.97it/s]


BAIAP3 695


1267it [00:00, 3129.99it/s]


CTC-203F4.2 866


1267it [00:00, 3208.04it/s]


HOTAIRM1 1227


1267it [00:00, 3357.40it/s]


SMC4 528


1267it [00:00, 3432.51it/s]


MRAP2 816


1267it [00:00, 3404.84it/s]


FEV 41


1267it [00:00, 3201.23it/s]


RP11-1277A3.1 755


1267it [00:00, 3423.46it/s]


HMGCS1 1107


1267it [00:00, 2985.54it/s]


RP11-422P24.10 1150


1267it [00:00, 3193.05it/s]


RIMS3 373


1267it [00:00, 3260.06it/s]


IFITM1 786


1267it [00:00, 3241.10it/s]


VIM 5


1267it [00:00, 3003.26it/s]


PPP3CA 947


1267it [00:00, 3418.20it/s]


HBZ 139


1267it [00:00, 3263.39it/s]


IL4R 815


1267it [00:00, 3407.74it/s]


TNFAIP2 1047


1267it [00:00, 3332.14it/s]


SQSTM1 700


1267it [00:00, 3389.85it/s]


PCK2 490


1267it [00:00, 3187.13it/s]


LGALS1 253


1267it [00:00, 2964.30it/s]


MUC4 46


1267it [00:00, 3090.29it/s]


AQP3 852


1267it [00:00, 2972.41it/s]


AP000640.10 1212


1267it [00:00, 2959.95it/s]


AC034243.1 822


1267it [00:00, 3066.94it/s]


KALRN 184


1267it [00:00, 2835.31it/s]


SEC24D 21


1267it [00:00, 2944.18it/s]


RP11-727F15.9 287


1267it [00:00, 2867.33it/s]


ABTB2 405


1267it [00:00, 3137.11it/s]


LIMS1 59


1267it [00:00, 3021.16it/s]


PRSS57 477


1267it [00:00, 2802.01it/s]

ARL6IP1 989





In [46]:
df_results = pd.read_csv('/fs01/home/chloexq/scGPT-release/tutorials/vevo_adamson_ranks_Nov21.csv', index_col=0)
df_results

Unnamed: 0,conditions,wilcoxon,scGPT_rank (attn),scGPT_rank (gene emb),"scGPT_rank (attn, mean)","scGPT_rank (gene emb, mean)",wilcoxon (random),"scGPT_rank (attn, random)","scGPT_rank (gene emb, random)","scGPT_rank (attn, random, null)","scGPT_rank (gene emb, random, null)"
0,AMIGO3+ctrl,1098,569,409,569,409,1098,569,409,1,1
1,ARHGAP22+ctrl,318,175,126,0,0,837,0,0,5,3
2,ASCC3+ctrl,36,15,48,0,0,834,5,7,14,2
3,BHLHE40+ctrl,3,584,558,0,0,1075,25,1,91,211
4,CAD+ctrl,93,257,557,0,0,1257,59,28,58,58
...,...,...,...,...,...,...,...,...,...,...,...
70,UFL1+ctrl,19,3,39,0,0,962,6,5,369,53
71,UFM1+ctrl,0,4,2,0,0,612,21,1,4,1
72,XRN1+ctrl,216,163,206,0,0,789,71,127,6,2
73,YIPF5+ctrl,11,16,55,0,0,1061,12,3,0,0


In [61]:
df_results['scGPT_rank (attn, null, non-targets)'] = rank_list

In [62]:
df_results

Unnamed: 0,conditions,wilcoxon,scGPT_rank (attn),scGPT_rank (gene emb),"scGPT_rank (attn, mean)","scGPT_rank (gene emb, mean)",wilcoxon (random),"scGPT_rank (attn, random)","scGPT_rank (gene emb, random)","scGPT_rank (attn, random, null)","scGPT_rank (gene emb, random, null)","scGPT_rank (attn, null, targets)","scGPT_rank (attn, null, non-targets)"
0,AMIGO3+ctrl,1098,569,409,569,409,1098,569,409,1,1,586,497
1,ARHGAP22+ctrl,318,175,126,0,0,837,0,0,5,3,565,69
2,ASCC3+ctrl,36,15,48,0,0,834,5,7,14,2,248,651
3,BHLHE40+ctrl,3,584,558,0,0,1075,25,1,91,211,56,867
4,CAD+ctrl,93,257,557,0,0,1257,59,28,58,58,819,1194
...,...,...,...,...,...,...,...,...,...,...,...,...,...
70,UFL1+ctrl,19,3,39,0,0,962,6,5,369,53,1012,287
71,UFM1+ctrl,0,4,2,0,0,612,21,1,4,1,743,405
72,XRN1+ctrl,216,163,206,0,0,789,71,127,6,2,832,59
73,YIPF5+ctrl,11,16,55,0,0,1061,12,3,0,0,319,477


In [63]:
df_results.mean()

wilcoxon                                 69.920000
scGPT_rank (attn)                        83.160000
scGPT_rank (gene emb)                   110.320000
scGPT_rank (attn, mean)                  70.506667
scGPT_rank (gene emb, mean)              68.320000
wilcoxon (random)                       848.866667
scGPT_rank (attn, random)               122.093333
scGPT_rank (gene emb, random)            83.560000
scGPT_rank (attn, random, null)         151.720000
scGPT_rank (gene emb, random, null)     102.866667
scGPT_rank (attn, null, targets)        594.253333
scGPT_rank (attn, null, non-targets)    649.573333
dtype: float64

## Misc/Baseline Wilcoxon

In [53]:
df_results = pd.DataFrame()

In [49]:
df_results['conditions'] = perturb_conditions

In [55]:
adata_t_sub_copy = adata_t.copy()
adata_t_sub_copy.X = adata_t_sub_copy.X
adata_t_sub_copy.var.index = genes

In [56]:
adata_t_sub_copy

AnnData object with n_obs × n_vars = 39847 × 1267
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch', 'n_counts', 'batch_id'
    var: 'gene_name', 'id_in_vocab', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg'
    obsm: 'bin_edges'
    layers: 'X_binned'

In [57]:
sc.tl.rank_genes_groups(adata_t_sub_copy, 'condition', method='wilcoxon', key_added = "wilcoxon", n_genes=max_len-1, reference
='ctrl')

In [58]:
adata_t_sub_copy

AnnData object with n_obs × n_vars = 39847 × 1267
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch', 'n_counts', 'batch_id'
    var: 'gene_name', 'id_in_vocab', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg', 'wilcoxon'
    obsm: 'bin_edges'
    layers: 'X_binned'

In [59]:
adata_t_sub_copy.uns['wilcoxon']['names']

rec.array([('CITED2', 'SH3BGRL3', 'SH3BGRL3', 'EEF1A1', 'BLVRB', 'ALAS2', 'GYPB', 'SH3BGRL3', 'SH3BGRL3', 'EEF1A1', 'PDIA3', 'EEF1A1', 'PDIA6', 'CITED2', 'SH3BGRL3', 'PDIA6', 'SH3BGRL3', 'PRSS57', 'PHGDH', 'PRSS57', 'S100A11', 'PHGDH', 'CITED2', 'CITED2', 'CITED2', 'CITED2', 'SH3BGRL3', 'PDIA6', 'S100A11', 'CITED2', 'PHGDH', 'KLF1', 'PDIA6', 'CITED2', 'CITED2', 'HES4', 'CITED2', 'BLVRB', 'SH3BGRL3', 'CITED2', 'CITED2', 'CITED2', 'CLCA1', 'BLVRB', 'CITED2', 'SERPINF1', 'CALR', 'SLC3A2', 'SLC3A2', 'SLC3A2', 'CITED2', 'HSPA5', 'CITED2', 'PDIA6', 'PDIA6', 'CITED2', 'PDIA6', 'SLC3A2', 'CITED2', 'SLC3A2', 'CITED2', 'PDIA3', 'SH3BGRL3', 'SH3BGRL3', 'PHGDH', 'SH3BGRL3', 'CITED2', 'PDIA6', 'RHCE', 'CITED2', 'PDIA6', 'CITED2', 'CLCA1', 'PDIA6', 'EEF1A1'),
           ('PDIA3', 'LIMS1', 'CITED2', 'AIF1', 'ALAS2', 'GYPB', 'CITED2', 'NFKBIA', 'NFKBIA', 'RPSAP58', 'PDIA6', 'RPSAP58', 'PDIA3', 'PDIA6', 'PDIA3', 'HSP90B1', 'PRSS57', 'PHGDH', 'S100A11', 'PHGDH', 'UBE2L6', 'SERPINF1', 'SH3BGRL3', 'FCER1G

In [60]:
baseline_rank = []

In [61]:
for c in perturb_conditions:
    hvg_list = adata_t_sub_copy.uns['wilcoxon']['names'][c]
    p_val = adata_t_sub_copy.uns['wilcoxon']['pvals_adj'][c]
    df_gene_emb_dist = pd.DataFrame()
    df_gene_emb_dist['gene'] = hvg_list
    df_gene_emb_dist['p_val'] = p_val
    df_gene_emb_dist = df_gene_emb_dist.sort_values(by='p_val')
    print(c, np.where(df_gene_emb_dist.gene.values==c.split('+')[0])[0][0])
    baseline_rank.append(np.where(df_gene_emb_dist.gene.values==c.split('+')[0])[0][0])

AMIGO3+ctrl 1098
ARHGAP22+ctrl 837
ASCC3+ctrl 834
BHLHE40+ctrl 1075
CAD+ctrl 1257
CCND3+ctrl 417
CHERP+ctrl 921
COPB1+ctrl 671
COPZ1+ctrl 746
CREB1+ctrl 1242
DAD1+ctrl 515
DDIT3+ctrl 1063
DDOST+ctrl 1263
DDRGK1+ctrl 1094
DERL2+ctrl 537
DHDDS+ctrl 932
DNAJC19+ctrl 959
EIF2B2+ctrl 803
EIF2B3+ctrl 714
EIF2B4+ctrl 1239
EIF2S1+ctrl 556
FARSB+ctrl 1001
FECH+ctrl 967
GBF1+ctrl 794
GMPPB+ctrl 1211
GNPNAT1+ctrl 1265
HSD17B12+ctrl 822
HSPA5+ctrl 1142
HSPA9+ctrl 1075
HYOU1+ctrl 908
IARS2+ctrl 1037
IDH3A+ctrl 716
IER3IP1+ctrl 654
KCTD16+ctrl 1028
MANF+ctrl 415
MRGBP+ctrl 784
MRPL39+ctrl 827
MTHFD1+ctrl 1155
NEDD8+ctrl 658
OST4+ctrl 643
P4HB+ctrl 634
PDIA6+ctrl 984
PPWD1+ctrl 744
PSMD4+ctrl 814
PTDSS1+ctrl 916
SAMM50+ctrl 782
SCYL1+ctrl 710
SEC61A1+ctrl 760
SEC61B+ctrl 822
SEC61G+ctrl 923
SEC63+ctrl 678
SEL1L+ctrl 748
SLC35B1+ctrl 881
SLC39A7+ctrl 673
SOCS1+ctrl 838
SPCS2+ctrl 731
SPCS3+ctrl 760
SRP68+ctrl 659
SRP72+ctrl 1012
SRPRB+ctrl 1009
STT3A+ctrl 692
SYVN1+ctrl 469
TELO2+ctrl 951
TIMM23+ctrl 

In [62]:
df_results

Unnamed: 0,conditions,wilcoxon,scGPT_rank (attn),scGPT_rank (gene emb),"scGPT_rank (attn, mean)","scGPT_rank (gene emb, mean)"
0,AMIGO3+ctrl,1098,569,409,569,409
1,ARHGAP22+ctrl,318,175,126,0,0
2,ASCC3+ctrl,36,15,48,0,0
3,BHLHE40+ctrl,3,584,558,0,0
4,CAD+ctrl,93,257,557,0,0
...,...,...,...,...,...,...
70,UFL1+ctrl,19,3,39,0,0
71,UFM1+ctrl,0,4,2,0,0
72,XRN1+ctrl,216,163,206,0,0
73,YIPF5+ctrl,11,16,55,0,0


In [63]:
df_results['wilcoxon (random)'] = baseline_rank

In [64]:
df_results

Unnamed: 0,conditions,wilcoxon,scGPT_rank (attn),scGPT_rank (gene emb),"scGPT_rank (attn, mean)","scGPT_rank (gene emb, mean)",wilcoxon (random)
0,AMIGO3+ctrl,1098,569,409,569,409,1098
1,ARHGAP22+ctrl,318,175,126,0,0,837
2,ASCC3+ctrl,36,15,48,0,0,834
3,BHLHE40+ctrl,3,584,558,0,0,1075
4,CAD+ctrl,93,257,557,0,0,1257
...,...,...,...,...,...,...,...
70,UFL1+ctrl,19,3,39,0,0,962
71,UFM1+ctrl,0,4,2,0,0,612
72,XRN1+ctrl,216,163,206,0,0,789
73,YIPF5+ctrl,11,16,55,0,0,1061


In [95]:
df_results.mean()

scGPT_rank    83.16
wilcoxon      69.92
dtype: float64

In [64]:
#df_results.to_csv('/fs01/home/chloexq/scGPT-release/tutorials/vevo_adamson_ranks_Nov28.csv')