# Pretraining on the merged dataset

In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import gc
import copy
import json
import os
from pathlib import Path
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
import scanpy as sc
import scvi
import numpy as np
import wandb
import matplotlib.pyplot as plt
from anndata import AnnData
from scipy.sparse import issparse
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import Vocab as VocabPybind

  warn(f"Failed to load image Python extension: {e}")
Global seed set to 0


In [3]:
sys.path.insert(0, "../")
from scgpt.utils import set_seed
from scgpt.utils import category_str2int, eval_scib_metrics
from scgpt import SubsetsBatchSampler
from scgpt.preprocess import Preprocessor
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.model import TransformerModel, AdversarialDiscriminator
import scgpt as scg
from scgpt.tokenizer.gene_tokenizer import GeneVocab
sc.set_figure_params(figsize=(4, 4))
os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings("ignore")

  IPython.display.set_matplotlib_formats(*ipython_format)


In [4]:
sys.path.append(".")
sys.path.append("../code")

from config import hyperparameter_defaults
from pretrain_all_in_one_0801 import load_tokenized, prepare_data

In [8]:
%%time
# dataset_name="cre329_tokenized_merged_numds10"
# file_path="/home/qiliu02/GHDDI/DS-group/ghddixcre_singlecell_gpt/dev_pretrain_on_cre_all_in_one/script/save/cre329_tokenized_merged/20230801_175734/cre329_tokenized_merged_numds10.pt"

dataset_name="cre329_tokenized_merged_numds2"
file_path="/home/qiliu02/GHDDI/DS-group/ghddixcre_singlecell_gpt/dev_pretrain_on_cre_all_in_one/script/save/cre329_tokenized_merged/20230803_104338/cre329_tokenized_merged_numds2.pt"

tokenized_data = torch.load(file_path)

CPU times: user 0 ns, sys: 790 ms, total: 790 ms
Wall time: 3.84 s


In [9]:
tokenized_data

{'genes': tensor([[60695,  3213,  3832,  ..., 12229, 19237, 35831],
         [60695,  2520, 35271,  ..., 19445, 19851, 33698],
         [60695,  1955,  3291,  ...,  3434, 18244, 20874],
         ...,
         [60695, 17577, 36641,  ..., 60694, 60694, 60694],
         [60695, 32624,  3330,  ..., 16400, 20098,  7593],
         [60695, 11107,  1495,  ..., 33851, 20059,  4986]]),
 'values': tensor([[ 0, 15,  1,  ..., 28, 50, 49],
         [ 0, 10, 11,  ..., 15, 11, 45],
         [ 0,  3,  8,  ..., 46, 44,  9],
         ...,
         [ 0, 37, 45,  ..., -2, -2, -2],
         [ 0, 18,  9,  ...,  7, 17, 33],
         [ 0,  3, 13,  ..., 49, 41, 17]])}

In [10]:
tokenized_data['genes']

tensor([[60695,  3213,  3832,  ..., 12229, 19237, 35831],
        [60695,  2520, 35271,  ..., 19445, 19851, 33698],
        [60695,  1955,  3291,  ...,  3434, 18244, 20874],
        ...,
        [60695, 17577, 36641,  ..., 60694, 60694, 60694],
        [60695, 32624,  3330,  ..., 16400, 20098,  7593],
        [60695, 11107,  1495,  ..., 33851, 20059,  4986]])

In [27]:
tokenized_data['genes'].shape

torch.Size([20358, 1201])

In [11]:
tokenized_data['values']

tensor([[ 0, 15,  1,  ..., 28, 50, 49],
        [ 0, 10, 11,  ..., 15, 11, 45],
        [ 0,  3,  8,  ..., 46, 44,  9],
        ...,
        [ 0, 37, 45,  ..., -2, -2, -2],
        [ 0, 18,  9,  ...,  7, 17, 33],
        [ 0,  3, 13,  ..., 49, 41, 17]])

In [12]:
from sklearn.model_selection import train_test_split

genes_train, genes_valid, values_train, values_valid = train_test_split(
    tokenized_data['genes'], 
    tokenized_data['values'],
    test_size=0.8,
    shuffle=False
)
print(f"genes_train: {len(genes_train)}")
print(f"genes_valid: {len(genes_valid)}")
print(f"values_train: {len(values_train)}")
print(f"values_valid: {len(values_valid)}")


tokenized_train = {
    'genes': genes_train,
    'values': values_train
}

tokenized_valid = {
    'genes': genes_valid,
    'values': values_valid
}

genes_train: 19747
genes_valid: 611
values_train: 19747
values_valid: 611


In [13]:
genes_train

tensor([[60695,  3213,  3832,  ..., 12229, 19237, 35831],
        [60695,  2520, 35271,  ..., 19445, 19851, 33698],
        [60695,  1955,  3291,  ...,  3434, 18244, 20874],
        ...,
        [60695, 12118, 17632,  ..., 19488,  8714, 18912],
        [60695,  9242,  8714,  ..., 19090, 16134, 16244],
        [60695,  4533, 33092,  ...,  8591,  2106, 12087]])

In [14]:
genes_valid

tensor([[60695,  1418, 19185,  ..., 21054, 20678, 10964],
        [60695, 17714,  5291,  ..., 33470, 35192, 32859],
        [60695, 32162, 17324,  ..., 36550, 19433,  8251],
        ...,
        [60695, 17577, 36641,  ..., 60694, 60694, 60694],
        [60695, 32624,  3330,  ..., 16400, 20098,  7593],
        [60695, 11107,  1495,  ..., 33851, 20059,  4986]])

In [15]:
# tokenized_train, tokenized_valid = load_tokenized(
#     dataset_name="cre329_tokenized_merged_numds10",
#     file_path="/home/qiliu02/GHDDI/DS-group/ghddixcre_singlecell_gpt/dev_pretrain_on_cre_all_in_one/script/save/cre329_tokenized_merged/20230801_175734/cre329_tokenized_merged_numds10.pt",
#     valid_size=0.03
# )

## Configurations

In [16]:
do_train = True
if do_train:
    config = {
        'seed': 42,
        'dataset_name': "cre329_tokenized_merged_numds10",
        'dataset_filepath': "../data/cre329_tokenized_merged/cre329_tokenized_merged_numds10.pt",
        
        'do_train': True,
        'load_model': None,

        'layer_size': 128,
        'nhead': 4,
        'nlayers': 4,

        'n_hvg': None,
        'max_seq_len': 1200,

        'fast_transformer': False,
        'pre_norm': False,
        'ecs_thres': 0.8,

        'GEPC': False,
        'DSBN': False,

        'batch_size': 32,
        'epochs': 6,
        'lr': 1e-4,
        "dropout": 0.2,

        'n_bins': 51,
        'mask_ratio': 0.4,

        'amp': False,
        'schedule_ratio': 0.9,
        'save_eval_interval': 5,
        'log_interval': 100,
    }
    embsize = config['layer_size']
    d_hid = config['layer_size']
    nhead = config['nhead']
    nlayers = config['nlayers']

    batch_size = config['batch_size']
    dropout = config['dropout']
    pre_norm = config['pre_norm']
    ecs_thres = config['ecs_thres']

    n_hvg = None
    max_seq_len = 1201

    DSBN = False
    GEPC = False
    do_dab = False
    use_batch_labels = False
    num_batch_types = None

    per_seq_batch_sample = False
    explicit_zero_prob = False  # whether explicit bernoulli for zeros
else:
    config = {
        'seed': 42,
        'dataset_name': "PBMC_10K",
        'do_train': True,
        'load_model': "./save/scGPT_human",
        'GEPC': True,
        'ecs_thres': 0.8,
        'dab_weight': 1.0,
        'mask_ratio': 0.4,
        'epochs': 15,
        'n_bins': 51,
        'lr': 1e-4,
        'batch_size': 16,

        'layer_size': 128,
        'nlayers': 4,
        'nhead': 4,

        'dropout': 0.2,
        'schedule_ratio': 0.9,
        'save_eval_interval': 5,
        'log_interval': 100,

        'fast_transformer': False,
        'pre_norm': False,
        'amp': True,
    }
    n_hvg = 1200  # number of highly variable genes
    max_seq_len = n_hvg + 1        

In [17]:
# logging
dataset_name = "cre329_tokenized_merged_numds10"  # config['dataset_name']
save_dir = Path(f"./save/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")

run = wandb.init(
    config=config,
    project="Pretraining scGPT",
    reinit=True,
    settings=wandb.Settings(start_method="fork"),
)
config = wandb.config
set_seed(config.seed)
print(config)
print("="*100)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mqiliu-ghddi[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 42, 'dataset_name': 'cre329_tokenized_merged_numds10', 'dataset_filepath': '../data/cre329_tokenized_merged/cre329_tokenized_merged_numds10.pt', 'do_train': True, 'load_model': None, 'layer_size': 128, 'nhead': 4, 'nlayers': 4, 'n_hvg': None, 'max_seq_len': 1200, 'fast_transformer': False, 'pre_norm': False, 'ecs_thres': 0.8, 'GEPC': False, 'DSBN': False, 'batch_size': 32, 'epochs': 6, 'lr': 0.0001, 'dropout': 0.2, 'n_bins': 51, 'mask_ratio': 0.4, 'amp': False, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100}


In [18]:
# preprocessing
mask_ratio = config['mask_ratio']
mask_value = -1
pad_value = -2
n_input_bins = config['n_bins']
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
vocab_file = "default_census_vocab.json"
vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s) 
ntokens = len(vocab)  # size of vocabulary, 60694

logger.info(f"ntokens {ntokens}")
print(f"ntokens {ntokens}")

scGPT - INFO - ntokens 60697
ntokens 60697


In [19]:
# training    
device = torch.device("cuda:7" if torch.cuda.is_available() else "cpu")

model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=False,
    do_dab=False,
    use_batch_labels=False,
    num_batch_labels=None,
    domain_spec_batchnorm=False,
    n_input_bins=n_input_bins,
    ecs_threshold=ecs_thres,
    explicit_zero_prob=True,
    use_fast_transformer=False,
    pre_norm=pre_norm,
)

Using simple batchnorm instead of domain specific batchnorm


In [20]:
model.to(device)
wandb.watch(model)

[]

In [21]:
from pretrain_all_in_one_0801 import define_wandb_metrcis

lr = config['lr']
eps = 1e-4 if config['amp'] else 1e-8
criterion = masked_mse_loss
# criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), 
    lr=lr,
    eps=eps
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, 
    1, 
    gamma=config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)


## Step 4: Finetune scGPT with task-specific objectives
best_val_loss = float("inf")
# best_avg_bio = 0.0

best_model = None
define_wandb_metrcis()

per_seq_batch_sample = False
sort_seq_batch = False

In [22]:
print(config)

{'seed': 42, 'dataset_name': 'cre329_tokenized_merged_numds10', 'dataset_filepath': '../data/cre329_tokenized_merged/cre329_tokenized_merged_numds10.pt', 'do_train': True, 'load_model': None, 'layer_size': 128, 'nhead': 4, 'nlayers': 4, 'n_hvg': None, 'max_seq_len': 1200, 'fast_transformer': False, 'pre_norm': False, 'ecs_thres': 0.8, 'GEPC': False, 'DSBN': False, 'batch_size': 32, 'epochs': 6, 'lr': 0.0001, 'dropout': 0.2, 'n_bins': 51, 'mask_ratio': 0.4, 'amp': False, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100}


In [25]:
from pretrain_all_in_one_0801 import train, evaluate, prepare_dataloader

In [26]:
    
for epoch in range(1, config.epochs + 1):
    epoch_start_time = time.time()

    train_data_pt, valid_data_pt = prepare_data(
        tokenized_train,
        tokenized_valid,
        mask_ratio,
        mask_value,
        pad_value
    )

    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=batch_size,
        shuffle=False,
        intra_domain_shuffle=True,
        drop_last=False,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=batch_size,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
    )

    logger.info("===")
    if do_train:
        logger.info("config.do_train")
        # training one epoch
        train(
            config,
            epoch,
            model,
            device,
            loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            scaler=scaler,
            scheduler=scheduler,
            logger=logger
        )

    # evaluation of one epoch
    val_loss, val_mre = evaluate(
        config,
        epoch,
        model,
        device,
        loader=valid_loader,
        criterion=criterion
    )
    elapsed = time.time() - epoch_start_time
    logger.info("-" * 89)
    logger.info(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
    )
    logger.info("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        logger.info(f"Best model with score {best_val_loss:5.4f}")

    if epoch % config.save_eval_interval == 0 or epoch == config.epochs:
        logger.info(f"Saving model to {save_dir}")
        torch.save(best_model.state_dict(), save_dir / f"model_e{best_model_epoch}.pt")

#         # eval on testdata
#         results = eval_testdata(
#             best_model,
#             adata_t=adata_sorted if per_seq_batch_sample else adata,
#             include_types=["cls"],
#         )
#         results["batch_umap"].savefig(
#             save_dir / f"embeddings_batch_umap[cls]_e{best_model_epoch}.png", dpi=300
#         )

#         results["celltype_umap"].savefig(
#             save_dir / f"embeddings_celltype_umap[cls]_e{best_model_epoch}.png", dpi=300
#         )
#         metrics_to_log = {"test/" + k: v for k, v in results.items()}
#         metrics_to_log["test/batch_umap"] = wandb.Image(
#             str(save_dir / f"embeddings_batch_umap[cls]_e{best_model_epoch}.png"),
#             caption=f"celltype avg_bio epoch {best_model_epoch}",
#         )

#         metrics_to_log["test/celltype_umap"] = wandb.Image(
#             str(save_dir / f"embeddings_celltype_umap[cls]_e{best_model_epoch}.png"),
#             caption=f"celltype avg_bio epoch {best_model_epoch}",
#         )
#         metrics_to_log["test/best_model_epoch"] = best_model_epoch
#         wandb.log(metrics_to_log)
#         wandb.log({"avg_bio": results.get("avg_bio", 0.0)})

    scheduler.step()

# save the best model
torch.save(best_model.state_dict(), save_dir / "best_model.pt")

artifact = wandb.Artifact(f"best_model", type="model")

glob_str = os.path.join(save_dir, "best_model.pt")
artifact.add_file(glob_str)
run.log_artifact(artifact)


random masking at epoch , ratio of masked values in train:  0.3997
scGPT - INFO - ===
scGPT - INFO - config.do_train
scGPT - INFO - | epoch   1 | 100/618 batches | lr 0.0001 | ms/batch 350.97 | loss 733.94 | mse 733.94 | mre 2151.12 |
scGPT - INFO - | epoch   1 | 200/618 batches | lr 0.0001 | ms/batch 252.12 | loss 363.45 | mse 363.45 | mre 11535.80 |
scGPT - INFO - | epoch   1 | 300/618 batches | lr 0.0001 | ms/batch 255.45 | loss 203.13 | mse 203.13 | mre 18604.52 |
scGPT - INFO - | epoch   1 | 400/618 batches | lr 0.0001 | ms/batch 264.96 | loss 188.15 | mse 188.15 | mre 13388.81 |
scGPT - INFO - | epoch   1 | 500/618 batches | lr 0.0001 | ms/batch 254.44 | loss 182.55 | mse 182.55 | mre 6651.91 |
scGPT - INFO - | epoch   1 | 600/618 batches | lr 0.0001 | ms/batch 253.04 | loss 188.15 | mse 188.15 | mre 4009.97 |
scGPT - INFO - -----------------------------------------------------------------------------------------
scGPT - INFO - | end of epoch   1 | time: 174.60s | valid loss/mse 

<wandb.sdk.wandb_artifacts.Artifact at 0x2b41f32de400>