In [1]:
# %%
import copy
import gc
import json
import os
from pathlib import Path
import shutil
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings
import pandas as pd
# from . import asyn
import pickle
import torch
from anndata import AnnData
import scanpy as sc
# import scvi
import seaborn as sns
import numpy as np
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from sklearn.metrics import confusion_matrix

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

sc.set_figure_params(figsize=(6, 6))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings('ignore')



In [2]:
hyperparameter_defaults = dict(
    seed=0,
    dataset_name="covid",
    do_train=True,
    load_model="../save/scGPT_human",
    mask_ratio=0.0,
    epochs=10,
    n_bins=51,
    MVC=False, # Masked value prediction for cell embedding
    ecs_thres=0.0, # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
    dab_weight=0.0,
    lr=1e-4,
    batch_size=16   ,
    layer_size=128,
    nlayers=4,  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead=4,  # number of heads in nn.MultiheadAttention
    dropout=0.2,  # dropout probability
    schedule_ratio=0.9,  # ratio of epochs for learning rate schedule
    save_eval_interval=5,
    fast_transformer=True,
    pre_norm=False,
    amp=True,  # Automatic Mixed Precision
    include_zero_gene = False,
    freeze = False, #freeze
    DSBN = False,  # Domain-spec batchnorm
)

In [3]:
run = wandb.init(
    config=hyperparameter_defaults,
    project="scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
print(config)

set_seed(config.seed)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33msrks[0m ([33msrks-uc-san-diego[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111274573356948, max=1.0)…

{'seed': 0, 'dataset_name': 'covid', 'do_train': True, 'load_model': '../save/scGPT_human', 'mask_ratio': 0.0, 'epochs': 10, 'n_bins': 51, 'MVC': False, 'ecs_thres': 0.0, 'dab_weight': 0.0, 'lr': 0.0001, 'batch_size': 16, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'fast_transformer': True, 'pre_norm': False, 'amp': True, 'include_zero_gene': False, 'freeze': False, 'DSBN': False}


In [4]:
# settings for input and preprocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_ratio = config.mask_ratio
mask_value = "auto"  # for masked values, now it should always be auto

include_zero_gene = config.include_zero_gene  # if True, include zero genes among hvgs in the training
max_seq_len = 3001
n_bins = config.n_bins

# input/output representation
input_style = "binned"  # "normed_raw", "log1p", or "binned"
output_style = "binned"  # "normed_raw", "log1p", or "binned"

# settings for training
MLM = False  # whether to use masked language modeling, currently it is always on.
CLS = True  # celltype classification objective
ADV = False  # Adversarial training for batch correction
CCE = False  # Contrastive cell embedding objective
MVC = config.MVC  # Masked value prediction for cell embedding
ECS = config.ecs_thres > 0  # Elastic cell similarity objective
DAB = False  # Domain adaptation by reverse backpropagation, set to 2 for separate optimizer
INPUT_BATCH_LABELS = False  # TODO: have these help MLM and MVC, while not to classifier
input_emb_style = "continuous"  # "category" or "continuous" or "scaling"
cell_emb_style = "cls"  # "avg-pool" or "w-pool" or "cls"
adv_E_delay_epochs = 0  # delay adversarial training on encoder for a few epochs
adv_D_delay_epochs = 0
mvc_decoder_style = "inner product"
ecs_threshold = config.ecs_thres
dab_weight = config.dab_weight

explicit_zero_prob = MLM and include_zero_gene  # whether explicit bernoulli for zeros
do_sample_in_train = False and explicit_zero_prob  # sample the bernoulli in training

per_seq_batch_sample = False

# settings for optimizer
lr = config.lr  # TODO: test learning rate ratio between two tasks
lr_ADV = 1e-3  # learning rate for discriminator, used when ADV is True
batch_size = config.batch_size
eval_batch_size = config.batch_size
epochs = config.epochs
schedule_interval = 1

# settings for the model
fast_transformer = config.fast_transformer
fast_transformer_backend = "flash"  # "linear" or "flash"
embsize = config.layer_size  # embedding dimension
d_hid = config.layer_size  # dimension of the feedforward network in TransformerEncoder
nlayers = config.nlayers  # number of TransformerEncoderLayer in TransformerEncoder
nhead = config.nhead  # number of heads in nn.MultiheadAttention
dropout = config.dropout  # dropout probability

# logging
log_interval = 100  # iterations
save_eval_interval = config.save_eval_interval  # epochs
do_eval_scib_metrics = True

In [5]:
# %% validate settings
assert input_style in ["normed_raw", "log1p", "binned"]
assert output_style in ["normed_raw", "log1p", "binned"]
assert input_emb_style in ["category", "continuous", "scaling"]
if input_style == "binned":
    if input_emb_style == "scaling":
        raise ValueError("input_emb_style `scaling` is not supported for binned input.")
elif input_style == "log1p" or input_style == "normed_raw":
    if input_emb_style == "category":
        raise ValueError(
            "input_emb_style `category` is not supported for log1p or normed_raw input."
        )

if input_emb_style == "category":
    mask_value = n_bins + 1
    pad_value = n_bins  # for padding gene expr values
    n_input_bins = n_bins + 2
else:
    mask_value = -1
    pad_value = -2
    n_input_bins = n_bins

if ADV and DAB:
    raise ValueError("ADV and DAB cannot be both True.")
DAB_separate_optim = True if DAB > 1 else False

In [6]:
dataset_name = config.dataset_name
save_dir = Path(f"../save/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"save to {save_dir}")
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")

save to ../save/dev_covid-Nov07-16-41


## Understanding differences between MS and COVID datasets

In [8]:
msDataPath = Path("../data/ms") # multiple sclerosis data's path for reference
covidDataPath = Path("../data/covid") # covid data's path

msDataTest = sc.read(msDataPath / "filtered_ms_adata.h5ad") # loading MS test data
msData = sc.read(msDataPath / "c_data.h5ad") # loading MS data
covidData = sc.read(covidDataPath / "covidObj.h5ad") # loading Covid data

In [9]:
# Observe shapes for both datasets
print(f'COVID dataset \n Cells: {covidData.X.shape[0]} \n Genes: {covidData.X.shape[1]}\n')
print(f'MS dataset \n Cells: {msData.X.shape[0]} \n Genes: {msData.X.shape[1]}\n')
print(f'MS Test dataset \n Cells: {msDataTest.X.shape[0]} \n Genes: {msDataTest.X.shape[1]}\n')

COVID dataset 
 Cells: 375438 
 Genes: 14063

MS dataset 
 Cells: 7844 
 Genes: 3000

MS Test dataset 
 Cells: 13468 
 Genes: 3000



In [10]:
covidData.obs

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,Infection_status,Severity,log10GenesPerUMI,percent.mt,percent.Heme,Study,seurat_clusters,new.annot2
560_cell.588_AGTTGGAGAACG_1,560_cell,285.0,236,Covid,Severe,0.800229,11.561265,0.000000,Wilk,26,4
560_cell.637_TTGACACATACC_1,560_cell,555.0,284,Covid,Severe,0.821141,8.294501,17.334576,Wilk,34,4
560_cell.639_CACCATAAGAAT_1,560_cell,523.0,284,Covid,Severe,0.828072,2.131783,15.116279,Wilk,55,14
560_cell.647_TTGTGAAACAGT_1,560_cell,675.0,295,Covid,Severe,0.813503,7.190083,23.636364,Wilk,10,4
560_cell.652_CTAAGTTGCTTT_1,560_cell,461.0,301,Covid,Severe,0.836481,9.232264,7.580175,Wilk,50,1
...,...,...,...,...,...,...,...,...,...,...,...
TGTACCTTACACTCGAGAAAGGGTCAG_22_5,BN-28,1629.0,659,Healthy,Healthy,0.882532,21.009867,0.000000,SS_C2,15,8
AACTGTATTACTATCCTCCTAGATAGA_22_5,BN-05,709.0,390,Covid,Non-severe,0.907000,20.129870,0.779221,SS_C2,23,4
CACATTGCAGCTAACTCACAGGCATTT_22_5,BN-28,446.0,310,Healthy,Healthy,0.931754,17.892644,0.397614,SS_C2,23,4
TCTCTTCAACAATTGATCAGCCGCAAG_22_5,BN-12,893.0,586,Covid,Severe,0.940330,13.244353,0.410678,SS_C2,31,2


In [11]:
msGenes = msData.var['gene_name'] # 3000 unique genes, same for the others
covidGenes = covidData.var['features'] # 14063 unique genes

commonGenes = msGenes[msGenes.isin(covidGenes)].unique() # check all common genes
uniqueGenes = pd.concat([msGenes, covidGenes]).unique() # check all unique genes - a union

In [39]:
# From Mukund et al. supplementary information 
# URL: https://www.frontiersin.org/journals/immunology/articles/10.3389/fimmu.2021.738073/full#supplementary-material

column_names = [
'cluster',
'number of cells',
'SingleR Annotation',	
'Human_gene_atlas enrichment(of ClusterMarkers)',
'Final Cluster Identities'
]

clusterIdentities = pd.read_csv(covidDataPath / 'clusterIdentities.tsv', sep="\t", header=None, names=column_names)
cluster_identity_map = dict(zip(clusterIdentities["cluster"], clusterIdentities["Final Cluster Identities"]))

covidData.obs['celltype'] = covidData.obs["seurat_clusters"].map(cluster_identity_map).astype("category")
                
# make the celltype category column
celltype_id_labels = covidData.obs["celltype"].astype("category").cat.codes.values
celltypes = covidData.obs["celltype"].unique()
num_types = len(np.unique(celltype_id_labels))
id2type = dict(enumerate(covidData.obs["celltype"].astype("category").cat.categories))
covidData.obs["celltype_id"] = celltype_id_labels

print(f"Number of unique final cluster annotations: {num_types}")

Number of unique final cluster annotations: 20


In [None]:
# Rename features -> gene_name
covidData.var.rename(columns={'features': 'gene_name'}, inplace=True)

# Set gene_name as variable index
covidData.var.set_index(covidData.var["gene_name"], inplace=True)
covidData.var["gene_name"] = covidData.var.index.tolist()


covidData.var

# Make a copy of COVID data, use standard naming conventions
adata = covidData.copy()

In [None]:
if config.load_model is not None:
    model_dir = Path(config.load_model)
    model_config_file = model_dir / "args.json"
    model_file = model_dir / "best_model.pt"
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    shutil.copy(vocab_file, save_dir / "vocab.json")
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    adata = adata[:, adata.var["id_in_vocab"] >= 0]

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]