In [1]:
import copy
import gc
import json
import os
from pathlib import Path
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
from anndata import AnnData
import scanpy as sc
#import scvi
import numpy as np
import pandas as pd
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)


sys.path.insert(0, "../")
#import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, eval_scib_metrics, load_pretrained

sc.set_figure_params(figsize=(4, 4))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')

  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
hyperparameter_defaults = dict(
    seed=42,
    dataset_name="fibro", # Dataset name
    do_train=True, # Flag to indicate whether to do update model parameters during training
    load_model="/scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model",
    model_name="best_model.pt",
    #"/scratch/ssd004/scratch/chloexq/fibro/dev_fibro-Jun19-18-29",
    #"/scratch/ssd004/scratch/chloexq/fibro/dev_fibro-Jun23-14-13", # Path to pre-trained model
    GEPC=True,  # Gene expression modelling for cell objective
    ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=1.0, # DAR objective weight for batch correction
    mask_ratio=0.4, # Default mask ratio
    epochs=15, # Default number of epochs for fine-tuning
    n_bins=51, # Default number of bins for value binning in data pre-processing
    lr=1e-4, # Default learning rate for fine-tuning
    batch_size=64, # Default batch size for fine-tuning
    layer_size=128,
    nlayers=4,
    nhead=4, # if load model, batch_size, layer_size, nlayers, nhead will be ignored
    dropout=0.2, # Default dropout rate during model fine-tuning
    schedule_ratio=0.9,  # Default rate for learning rate decay
    save_eval_interval=5, # Default model evaluation interval
    log_interval=100, # Default log interval
    fast_transformer=True, # Default setting
    pre_norm=False, # Default setting
    amp=True,  # # Default setting: Automatic Mixed Precision
)
run = wandb.init(
    config=hyperparameter_defaults,
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

set_seed(config.seed)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchloewxq[0m ([33mscformer[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 42, 'dataset_name': 'fibro', 'do_train': True, 'load_model': '/scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model', 'model_name': 'best_model.pt', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}


In [3]:
from pathlib import Path
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = -1
pad_value = -2
n_input_bins = config.n_bins

n_hvg = 1200  # number of highly variable genes
max_seq_len = n_hvg + 1
per_seq_batch_sample = True
DSBN = False  # Domain-spec batchnorm
explicit_zero_prob = True  # whether explicit bernoulli for zeros

dataset_name = config.dataset_name
save_dir = Path(f"/scratch/ssd004/scratch/chloexq/fibro/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")

save to /scratch/ssd004/scratch/chloexq/fibro/dev_fibro-Feb17-11-19


## Load and preprocess dataset

####  ✅ Note
Perturbation datasets can be found in this path: `/scratch/ssd004/scratch/chloexq/perturb_analysis/{dataset_name}` 

In [4]:
data_dir = Path("/scratch/ssd004/scratch/chloexq/perturb_analysis")
adata = sc.read(data_dir / "adamson/perturb_processed.h5ad")
ori_batch_col = "control"
adata.obs["celltype"] = adata.obs["condition"].astype("category")
adata.obs["str_batch"] = adata.obs["control"].astype(str)
data_is_raw = False

In [5]:
if config.load_model is not None:
    model_dir = Path(config.load_model)
    model_config_file = model_dir / "args.json"
    model_file = model_dir / config.model_name
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    print(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]
    
    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    print(
        f"Resume model from {model_file}, the model args will be overriden by the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    embsize = config.layer_size 
    nhead = config.nhead
    nlayers = config.nlayers  
    d_hid = config.layer_size

match 4399/5060 genes in vocabulary of size 60697.
Resume model from /scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model/best_model.pt, the model args will be overriden by the config /scratch/ssd004/scratch/chloexq/scGPT_models/scGPT_human_model/args.json.


In [6]:
# Align perturbation condition naming
gene_names_set = [i + '+ctrl' for i in adata.var.gene_name.values]
gene_names_set = gene_names_set + ['ctrl']

####  ✅ Note
This experiment is computationally expensive, so we select 1000 cells per perturbation condition.

In [7]:
# Cap all conditions to 1000 cells
sampled_df = (
    adata.obs[adata.obs['condition'].isin(gene_names_set)]
    .groupby('condition', group_keys=False)
    .apply(lambda x: x.sample(min(len(x), 1000), random_state=42))
)
adata = adata[sampled_df.index].copy()
adata.obs.groupby('condition').count()

Unnamed: 0_level_0,cell_type,dose_val,control,condition_name,celltype,str_batch
condition,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
AMIGO3+ctrl,616,616,616,616,616,616
ARHGAP22+ctrl,406,406,406,406,406,406
ASCC3+ctrl,524,524,524,524,524,524
BHLHE40+ctrl,504,504,504,504,504,504
CAD+ctrl,242,242,242,242,242,242
...,...,...,...,...,...,...
UFM1+ctrl,591,591,591,591,591,591
XRN1+ctrl,621,621,621,621,621,621
YIPF5+ctrl,1000,1000,1000,1000,1000,1000
ZNF326+ctrl,517,517,517,517,517,517


In [8]:
# 5 conditions are capped, including ctrl
condition_counts = adata.obs.groupby('condition').count()

In [9]:
condition_names = set(adata.obs.condition.tolist())

In [10]:
condition_names.remove('ctrl')

In [11]:
condition_names_gene = [i.split('+')[0] for i in list(condition_names)]

In [12]:
condition_names_gene.sort()

####  ✅ Note
HVGs selection will filter out some perturbed genes. We manually add them back in the experiment.

In [13]:
# Do filtering
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=None,  # step 2
    normalize_total=None,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=False,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=None,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    #binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    #result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key=None)

scGPT - INFO - Filtering genes by counts ...


In [14]:
sc.pp.highly_variable_genes(
    adata,
    layer=None,
    n_top_genes=1200,
    flavor="seurat_v3" if data_is_raw else "cell_ranger",
    subset=False,
)

In [15]:
add_counter = 0
for g in condition_names_gene:
    if not adata.var.loc[adata.var[adata.var.gene_name==g].index, 'highly_variable'].values[0]:
        adata.var.loc[adata.var[adata.var.gene_name==g].index, 'highly_variable'] = True
        add_counter += 1

In [16]:
print('Manually add conditions: {}, {}'.format(add_counter, add_counter/len(condition_names_gene)))

Manually add conditions: 67, 0.8933333333333333


In [17]:
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=0,  # step 1
    filter_cell_by_counts=None,  # step 2
    normalize_total=None,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=False,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=None,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)
preprocessor(adata, batch_key=None)

scGPT - INFO - Binning data ...


In [18]:
adata = adata[:, adata.var["highly_variable"]].copy()
print(adata)

AnnData object with n_obs × n_vars = 39847 × 1267
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch'
    var: 'gene_name', 'id_in_vocab', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg'
    obsm: 'bin_edges'
    layers: 'X_binned'


## Prepare model input

In [19]:
max_len = adata.shape[1] + 1
max_len

1268

In [20]:
adata.obs['batch_id'] = adata.obs['condition'].copy()

In [21]:
input_layer_key = "X_binned"
all_counts = (
    adata.layers[input_layer_key].A
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()

celltypes_labels = adata.obs["celltype"].tolist()  # make sure count from 0
num_types = len(set(celltypes_labels))
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata.obs["batch_id"].tolist()
num_batch_types = len(set(batch_ids))
batch_ids = np.array(batch_ids)

(
    train_data,
    valid_data,
    train_celltype_labels,
    valid_celltype_labels,
    train_batch_labels,
    valid_batch_labels,
) = train_test_split(
    all_counts, celltypes_labels, batch_ids, test_size=0.1, shuffle=True
)


In [22]:
if config.load_model is None:
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(vocab(genes), dtype=int)

In [23]:
len(gene_ids)

1267

In [24]:
adata

AnnData object with n_obs × n_vars = 39847 × 1267
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch', 'batch_id'
    var: 'gene_name', 'id_in_vocab', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg'
    obsm: 'bin_edges'
    layers: 'X_binned'

In [25]:
def prepare_data(sort_seq_batch=False) -> Tuple[Dict[str, torch.Tensor]]:
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    if sort_seq_batch:
        train_sort_ids = np.argsort(train_batch_labels)
        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        input_values_train = input_values_train[train_sort_ids]
        target_values_train = target_values_train[train_sort_ids]
        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]

        valid_sort_ids = np.argsort(valid_batch_labels)
        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        "batch_labels": tensor_batch_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
    }

    return train_data_pt, valid_data_pt


# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader

## Load the pre-trained scGPT model
####  ✅ Note
Make sure to import from .model instead of .model_pcpt in __ init __.py

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    dropout=config.dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=config.GEPC,
    do_dab=False,
    use_batch_labels=False,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=DSBN,
    n_input_bins=n_input_bins,
    ecs_threshold=config.ecs_thres,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=config.fast_transformer,
    use_generative_training=True,
    pre_norm=config.pre_norm,
)
if config.load_model is not None:
    load_pretrained(model, torch.load(model_file), verbose=False)

model.to(device)
wandb.watch(model)

[]

In [27]:
model.eval()
adata_t = adata.copy()

## Expand cells
📗 Slide deck pages 66-73

In [36]:
def expand_cell(tokenized_all, key, k, select_gene_id):
    # tokenzed_all: gene expression profiles e.g., AMIGO3 and control
    # key: key for genes list in tokenized_all
    # k: current cell index
    # select_gene_id: perturbed gene index
    cell_k = tokenized_all[key][k]
    # Repeat cell_k n_genes times
    # [n_genes, n_genes]
    cell_k_expand = cell_k.repeat(n_genes).view(n_genes, n_genes)
    # Append a new column of PAD tokens to the end (page 68 in slides)
    new_column = torch.full((n_genes, 1), vocab([pad_token])[0])
    cell_k_expand = torch.cat((cell_k_expand, new_column), dim=1)
    # Create mask (page 69 in slides)
    # Set diagonal to be True
    mask = torch.eye(n_genes).bool()
    new_column_mask = torch.full((n_genes, 1), False)
    # [n_genes, n_genes+1], last column is PAD, hence set to be False
    mask = torch.cat((mask, new_column_mask), dim=1)
    # Set column of perturbed gene to be True
    mask[:, select_gene_id] = True
    # Set one PAD to be True
    mask[select_gene_id, n_genes] = True
    mask_select_expand = cell_k_expand[mask]
    select_ids_gen = mask_select_expand.view(n_genes, 2)
    select_ids_pcpt = cell_k_expand[~mask].view(n_genes, n_genes-1)
    return select_ids_gen, select_ids_pcpt

from tqdm import tqdm
def collate_cell_by_key(tokenized_all, key, select_gene_id):
    # tokenzed_all: gene expression profiles e.g., AMIGO3 and control
    # key: key for genes list in tokenized_all
    # select_gene_id: perturbed gene index
    print(key)
    select_ids_gen_list = []
    select_ids_pcpt_list = []
    for k in tqdm(range(n_cells)):
        # Iterate through each cell
        # Expand each cell to 1268 fakle cells
        select_ids_gen, select_ids_pcpt = expand_cell(tokenized_all, key, k, select_gene_id)
        select_ids_gen_list.append(select_ids_gen)
        select_ids_pcpt_list.append(select_ids_pcpt)
    select_ids_gen = torch.cat(select_ids_gen_list, dim=0)
    select_ids_pcpt = torch.cat(select_ids_pcpt_list, dim=0)
    print(select_ids_gen.shape, select_ids_pcpt.shape)
    return select_ids_gen, select_ids_pcpt

In [37]:
from torch.utils.data import DataLoader, TensorDataset

In [38]:
from sklearn.metrics.pairwise import cosine_distances
from tqdm import tqdm
import pandas as pd

In [39]:
select_gene_list = condition_names_gene

for select_gene in select_gene_list:
    # Filter on perturbed gene and control gene
    adata_t = adata[adata.obs['condition'].isin([select_gene+'+ctrl', 'ctrl'])].copy()
    print(adata_t.obs['condition'])
    # Bug fix: select_gene_id should have +1
    select_gene_id = genes.index(select_gene)+1
    print(select_gene_id)
    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    celltypes_labels = adata_t.obs["celltype"].tolist()
    celltypes_labels = np.array(celltypes_labels)

    batch_ids = adata_t.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)

    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )

    # tokenized_all output genes and values in format [CLS, G0, G1, G2 ...]
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    # [N_cells, 1268] => 1267+1
    print(tokenized_all['genes'].shape, tokenized_all['values'].shape)
    n_cells = tokenized_all['genes'].shape[0]
    n_genes = tokenized_all['genes'].shape[1]
    
    # Expand to N_cells*1268 fake cells
    # select_gene_id = perturbed gene
    collate_genes_gen, collate_genes_pcpt = collate_cell_by_key(tokenized_all, 'genes', select_gene_id)
    _, collate_values_pcpt = collate_cell_by_key(tokenized_all, 'values', select_gene_id)
    
    tokenized_all_expand = {'genes_pcpt': collate_genes_pcpt, 'genes_gen': collate_genes_gen, 'values_pcpt': collate_values_pcpt}
    print(tokenized_all_expand)
    query_id = tokenized_all['genes'][0].repeat(n_cells)
    
    cell_counter = torch.arange(0, n_cells)
    cell_counter = cell_counter.repeat(n_genes).view(n_genes, n_cells).t().flatten()
    gene_counter = torch.arange(0, n_genes).repeat(n_cells)

    dataloader = DataLoader(
        TensorDataset(tokenized_all_expand['genes_pcpt'], 
                      tokenized_all_expand['genes_gen'], 
                      tokenized_all_expand['values_pcpt'],
                      query_id,
                      cell_counter,
                      gene_counter,
                     ), 
        batch_size=512, 
        shuffle=False)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gene_embeddings = np.zeros((n_cells, n_genes, 512))
    
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        for batch_idx, batch_data in enumerate(tqdm(dataloader)):
            pcpt_genes = batch_data[0].to(device)
            gen_genes = batch_data[1].to(device)
            pcpt_values = batch_data[2].to(device)
            query_id_select = batch_data[3].to(device)
            cell_counter_batch = batch_data[4].to(device)
            gene_counter_batch = batch_data[5].to(device)
            pcpt_key_padding_mask = pcpt_genes.eq(vocab[pad_token]).to(device)
            gen_key_padding_mask = gen_genes.eq(vocab[pad_token]).to(device)
            _, gen_output = model.transformer_generate(
                pcpt_genes=pcpt_genes,
                pcpt_values=pcpt_values,
                pcpt_key_padding_mask=pcpt_key_padding_mask,
                gen_genes=gen_genes,
                gen_key_padding_mask=gen_key_padding_mask,
            )
            select_mask = (gen_genes == query_id_select.unsqueeze(1)).long()
            selected_output = gen_output[torch.arange(gen_output.shape[0]), select_mask.argmax(dim=1), :]
            selected_output_np = selected_output.detach().cpu().numpy()
            gene_embeddings[cell_counter_batch.detach().cpu().numpy(), gene_counter_batch.detach().cpu().numpy(), :] = selected_output_np
    
    conditions = adata_t.obs['condition'].values
    
    dict_sum_condition_mean = {}
    for c in np.unique(conditions):
        dict_sum_condition_mean[c] = gene_embeddings[np.where(conditions == c)[0]].mean(0)
    
    print(dict_sum_condition_mean)
        
    celltype_0 = select_gene + '+ctrl'
    celltype_1 = 'ctrl'
    gene_emb_celltype_0 = np.expand_dims(dict_sum_condition_mean[celltype_0][1:, 1:], 0)
    gene_emb_celltype_1 = np.expand_dims(dict_sum_condition_mean[celltype_1][1:, 1:], 0)
    gene_dist_dict = {}
    for i, g in tqdm(enumerate(genes)):
        gene_dist_dict[g] = cosine_distances(gene_emb_celltype_0[:, i, :], gene_emb_celltype_1[:, i, :]).mean()
    df_gene_emb_dist = pd.DataFrame.from_dict(gene_dist_dict, orient='index', columns=['cos_dist'])
    df_deg = df_gene_emb_dist.sort_values(by='cos_dist', ascending=False)
    rank_celltype_0 = np.where(df_deg.index==celltype_0.split('+')[0])[0][0]
    print(celltype_0, rank_celltype_0) 
    np.savez('/scratch/hdd001/home/haotian/perturb_data/vevo_adamson_mean_gene_emb/mean_gene_emb_{}_{}.npz'.format(select_gene, rank_celltype_0), **dict_sum_condition_mean)
    assert 0

cell_barcode
CAGGCCGAGATGAA-2     AMIGO3+ctrl
CACACCTGCATCAG-4     AMIGO3+ctrl
CCAAGTGAGTATGC-10    AMIGO3+ctrl
TGCACGCTACCTTT-3     AMIGO3+ctrl
GGAGCCACTGCCCT-2     AMIGO3+ctrl
                        ...     
CTTGTATGGTATGC-3            ctrl
GTGGTAACCTACTT-8            ctrl
GCAAACTGATTCCT-1            ctrl
TTCTCAGATTCATC-5            ctrl
GACCTAGAGGAGTG-1            ctrl
Name: condition, Length: 1616, dtype: category
Categories (2, object): ['AMIGO3+ctrl', 'ctrl']
265
torch.Size([1616, 1268]) torch.Size([1616, 1268])
genes


  0%|          | 0/1616 [00:00<?, ?it/s]

265
1268
torch.Size([1268, 1269])





AssertionError: 

## Example gene - AMIGO3

In [36]:
select_gene = 'AMIGO3'
select_gene_id = [genes.index(select_gene)]
select_gene_id

[264]

In [39]:
all_counts = (
    adata_t.layers[input_layer_key].A
    if issparse(adata_t.layers[input_layer_key])
    else adata_t.layers[input_layer_key]
)

celltypes_labels = adata_t.obs["celltype"].tolist()
celltypes_labels = np.array(celltypes_labels)

batch_ids = adata_t.obs["batch_id"].tolist()
batch_ids = np.array(batch_ids)

tokenized_all = tokenize_and_pad_batch(
    all_counts,
    gene_ids,
    max_len=max_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=True,
)
all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])

In [40]:
tokenized_all

{'genes': tensor([[60695, 10954, 33817,  ..., 11394, 20695, 12288],
         [60695, 10954, 33817,  ..., 11394, 20695, 12288],
         [60695, 10954, 33817,  ..., 11394, 20695, 12288],
         ...,
         [60695, 10954, 33817,  ..., 11394, 20695, 12288],
         [60695, 10954, 33817,  ..., 11394, 20695, 12288],
         [60695, 10954, 33817,  ..., 11394, 20695, 12288]]),
 'values': tensor([[ 0.,  0.,  0.,  ...,  0., 39.,  0.],
         [ 0., 33.,  0.,  ...,  0., 41.,  0.],
         [ 0.,  0.,  0.,  ...,  0., 20.,  0.],
         ...,
         [ 0.,  0.,  0.,  ...,  0., 26., 26.],
         [ 0.,  0.,  0.,  ...,  0., 25.,  0.],
         [ 0.,  0.,  0.,  ...,  0.,  0.,  0.]])}

In [41]:
tokenized_all['genes'].shape, tokenized_all['values'].shape

(torch.Size([3292, 1268]), torch.Size([3292, 1268]))

In [42]:
select_gene_id = genes.index(select_gene)
select_gene_id

ValueError: ['AMIGO3', 'ARHGAP22', 'ASCC3', 'BHLHE40', 'CAD'] is not in list

In [39]:
n_cells = tokenized_all['genes'].shape[0]
n_genes = tokenized_all['genes'].shape[1]

In [40]:
vocab(['<pad>'])

[60694]

In [43]:
collate_genes_gen, collate_genes_pcpt = collate_cell_by_key(tokenized_all, 'genes')
_, collate_values_pcpt = collate_cell_by_key(tokenized_all, 'values')

genes


100%|██████████| 1616/1616 [00:20<00:00, 80.69it/s]


torch.Size([2049088, 2]) torch.Size([2049088, 1267])
values


100%|██████████| 1616/1616 [00:15<00:00, 101.07it/s]


torch.Size([2049088, 2]) torch.Size([2049088, 1267])


In [44]:
tokenized_all_expand = {'genes_pcpt': collate_genes_pcpt, 'genes_gen': collate_genes_gen, 'values_pcpt': collate_values_pcpt}

In [45]:
tokenized_all_expand

{'genes_pcpt': tensor([[10954, 33817, 33823,  ..., 20695, 12288, 60694],
         [60695, 33817, 33823,  ..., 20695, 12288, 60694],
         [60695, 10954, 33823,  ..., 20695, 12288, 60694],
         ...,
         [60695, 10954, 33817,  ..., 20695, 12288, 60694],
         [60695, 10954, 33817,  ..., 11394, 12288, 60694],
         [60695, 10954, 33817,  ..., 11394, 20695, 60694]]),
 'genes_gen': tensor([[60695, 17061],
         [10954, 17061],
         [33817, 17061],
         ...,
         [17061, 11394],
         [17061, 20695],
         [17061, 12288]]),
 'values_pcpt': tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.9000e+01, 0.0000e+00,
          6.0694e+04],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.9000e+01, 0.0000e+00,
          6.0694e+04],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.9000e+01, 0.0000e+00,
          6.0694e+04],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
          6.0694e+04],
         [0.

In [46]:
query_id = tokenized_all['genes'][0].repeat(n_cells)
query_id.shape

torch.Size([2049088])

In [47]:
cell_counter = torch.arange(0, n_cells)
cell_counter = cell_counter.repeat(n_genes).view(n_genes, n_cells).t().flatten()
gene_counter = torch.arange(0, n_genes).repeat(n_cells)

In [53]:
from torch.utils.data import DataLoader, TensorDataset

dataloader = DataLoader(
    TensorDataset(tokenized_all_expand['genes_pcpt'], 
                  tokenized_all_expand['genes_gen'], 
                  tokenized_all_expand['values_pcpt'],
                  query_id,
                  cell_counter,
                  gene_counter,
                 ), 
    batch_size=256, 
    shuffle=False)

In [40]:
gene_embeddings = np.load('/scratch/hdd001/home/haotian/perturb_data/gene_emb_AMIGO3.npy')

In [41]:
gene_embeddings

array([[[ 0.14471959, -0.31810772, -0.15156181, ...,  0.38157901,
          0.18877536,  0.00387936],
        [ 0.61960208, -0.11534177, -0.06519999, ...,  0.77028453,
          0.33847851, -0.34249339],
        [ 0.69855273, -0.27113497, -0.60498726, ...,  0.87949604,
          0.19560406,  0.19362614],
        ...,
        [ 0.57298452, -0.40124443,  0.15548812, ...,  1.01221168,
          0.26631504, -0.27922907],
        [ 0.53930181, -0.35991946, -0.17062013, ...,  0.83171099,
         -0.26467252, -0.49455971],
        [ 0.05785051, -0.16414733, -0.59651417, ...,  0.6242736 ,
          0.00651832,  0.33623916]],

       [[-0.02670202, -0.32364917, -0.17526059, ...,  0.2777673 ,
          0.19717763,  0.05834019],
        [ 0.50269228, -0.11272967, -0.0869033 , ...,  0.63911295,
          0.25286543, -0.34779724],
        [ 0.57340115, -0.21534836, -0.64286149, ...,  0.76529098,
          0.13185714,  0.21523918],
        ...,
        [ 0.55179942, -0.46997151,  0.14870609, ...,  

In [48]:
gene_embeddings.shape

(1616, 1268, 512)

In [54]:
adata_t

AnnData object with n_obs × n_vars = 1616 × 1267
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch', 'n_counts', 'batch_id'
    var: 'gene_name', 'id_in_vocab', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg'
    obsm: 'bin_edges'
    layers: 'X_binned'

In [43]:
conditions = adata_t.obs['condition'].values
conditions

['AMIGO3+ctrl', 'AMIGO3+ctrl', 'AMIGO3+ctrl', 'AMIGO3+ctrl', 'AMIGO3+ctrl', ..., 'ctrl', 'ctrl', 'ctrl', 'ctrl', 'ctrl']
Length: 1616
Categories (2, object): ['AMIGO3+ctrl', 'ctrl']

In [44]:
dict_sum_condition_mean = {}
for c in np.unique(conditions):
    dict_sum_condition_mean[c] = gene_embeddings[np.where(conditions == c)[0]].mean(0)

In [45]:
dict_sum_condition_mean

{'AMIGO3+ctrl': array([[ 0.02896031, -0.30605496, -0.18769101, ...,  0.31369529,
          0.18078589,  0.01247744],
        [ 0.53777679, -0.08083235, -0.09220441, ...,  0.7052932 ,
          0.30027999, -0.3661572 ],
        [ 0.61212808, -0.18003001, -0.64855161, ...,  0.81548358,
          0.16254684,  0.16346903],
        ...,
        [ 0.55028727, -0.41949252,  0.12546903, ...,  0.95711756,
          0.22197948, -0.30449767],
        [ 0.47133369, -0.32766302, -0.2268176 , ...,  0.76627743,
         -0.30820553, -0.49919339],
        [-0.02416596, -0.08885051, -0.64301324, ...,  0.56321243,
          0.01604781,  0.35337052]]),
 'ctrl': array([[ 0.03563262, -0.31076524, -0.1945659 , ...,  0.3190965 ,
          0.19469241,  0.03929696],
        [ 0.55548337, -0.08925866, -0.10811319, ...,  0.70242681,
          0.31922594, -0.3347161 ],
        [ 0.60411533, -0.15919195, -0.6666798 , ...,  0.81715135,
          0.17643361,  0.18269647],
        ...,
        [ 0.55992617, -0.431781

In [54]:
celltype_0 = select_gene + '+ctrl'
celltype_1 = 'ctrl'
gene_emb_celltype_0 = np.expand_dims(dict_sum_condition_mean[celltype_0][1:, 1:], 0)
gene_emb_celltype_1 = np.expand_dims(dict_sum_condition_mean[celltype_1][1:, 1:], 0)
gene_dist_dict = {}
for i, g in tqdm(enumerate(genes)):
    gene_dist_dict[g] = cosine_distances(gene_emb_celltype_0[:, i, :], gene_emb_celltype_1[:, i, :]).mean()
df_gene_emb_dist = pd.DataFrame.from_dict(gene_dist_dict, orient='index', columns=['cos_dist'])
df_deg = df_gene_emb_dist.sort_values(by='cos_dist', ascending=False)
rank_celltype_0 = np.where(df_deg.index==celltype_0.split('+')[0])[0][0]
print(celltype_0, rank_celltype_0)

1267it [00:00, 4388.30it/s]

AMIGO3+ctrl 406





In [59]:
select_gene

'AMIGO3'

In [65]:
np.savez('/scratch/hdd001/home/haotian/perturb_data/vevo_adamson_mean_gene_emb/mean_gene_emb_{}.npz'.format(select_gene), **dict_sum_condition_mean)

In [61]:
loaded_data = np.load('/scratch/hdd001/home/haotian/perturb_data/mean_gene_emb_{}.npz'.format(select_gene))

In [64]:
loaded_data['AMIGO3+ctrl']

array([[ 0.02896031, -0.30605496, -0.18769101, ...,  0.31369529,
         0.18078589,  0.01247744],
       [ 0.53777679, -0.08083235, -0.09220441, ...,  0.7052932 ,
         0.30027999, -0.3661572 ],
       [ 0.61212808, -0.18003001, -0.64855161, ...,  0.81548358,
         0.16254684,  0.16346903],
       ...,
       [ 0.55028727, -0.41949252,  0.12546903, ...,  0.95711756,
         0.22197948, -0.30449767],
       [ 0.47133369, -0.32766302, -0.2268176 , ...,  0.76627743,
        -0.30820553, -0.49919339],
       [-0.02416596, -0.08885051, -0.64301324, ...,  0.56321243,
         0.01604781,  0.35337052]])

In [None]:
/scratch/hdd001/home/haotian/perturb_data/gene_emb_AMIGO3.npy'

In [58]:
from sklearn.metrics.pairwise import cosine_distances
from tqdm import tqdm
import pandas as pd

In [70]:
perturb_conditions = ['AMIGO3+ctrl']

rank_list = []
for i, c in enumerate(perturb_conditions):
    celltype_0 = c
    celltype_1 = 'ctrl'
    gene_emb_celltype_0 = np.expand_dims(dict_sum_condition_mean[celltype_0][1:, 1:], 0)
    gene_emb_celltype_1 = np.expand_dims(dict_sum_condition_mean[celltype_1][1:, 1:], 0)
    gene_dist_dict = {}
    for i, g in tqdm(enumerate(genes)):
        gene_dist_dict[g] = cosine_distances(gene_emb_celltype_0[:, i, :], gene_emb_celltype_1[:, i, :]).mean()
    df_gene_emb_dist = pd.DataFrame.from_dict(gene_dist_dict, orient='index', columns=['cos_dist'])
    df_deg = df_gene_emb_dist.sort_values(by='cos_dist', ascending=False)
    print(c, np.where(df_deg.index==c.split('+')[0])[0][0])
    rank_list.append(np.where(df_deg.index==c.split('+')[0])[0][0])

1267it [00:00, 3860.98it/s]

AMIGO3+ctrl 406



