In [1]:
from google.colab import drive
import os

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks/scGPT_finetune')

Mounted at /content/drive


In [None]:
!git reset --hard origin/main

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks/scGPT_finetune')

!pip install -r ./requirements.txt
!pip install scgpt "flash-attn<1.0.5"

Mounted at /content/drive
Collecting torch==2.3.1 (from -r ./requirements.txt (line 6))
  Downloading torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl.metadata (26 kB)
Collecting torchvision==0.18.1 (from -r ./requirements.txt (line 7))
  Downloading torchvision-0.18.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting torchaudio==2.3.1 (from -r ./requirements.txt (line 8))
  Downloading torchaudio-2.3.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting torchtext==0.18.0 (from -r ./requirements.txt (line 9))
  Downloading torchtext-0.18.0-cp312-cp312-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting scanpy<2.0,>=1.9.1 (from -r ./requirements.txt (line 12))
  Downloading scanpy-1.11.4-py3-none-any.whl.metadata (9.2 kB)
Collecting scvi-tools<1.0,>=0.16.0 (from -r ./requirements.txt (line 13))
  Downloading scvi_tools-0.20.3-py3-none-any.whl.metadata (9.8 kB)
Collecting torch_geometric (from -r ./requirements.txt (line 16))
  Downloading torch_geometric-2.6.1-py3-no

## Load pretrained/finetuned models

In [None]:
import json
import os
import sys
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings

import torch
import numpy as np
import matplotlib
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction

sys.path.insert(0, "../")

import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.loss import (
    masked_mse_loss,
    criterion_neg_log_bernoulli,
    masked_relative_error,
)
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id, compute_perturbation_metrics

matplotlib.rcParams["savefig.transparent"] = False
warnings.filterwarnings("ignore")

set_seed(42)




In [None]:
# settings for data prcocessing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value = 0  # for padding values
pert_pad_id = 0
include_zero_gene = "all"
max_seq_len = 1536

# settings for training
MLM = True  # whether to use masked language modeling, currently it is always on.
CLS = False  # celltype classification objective
CCE = False  # Contrastive cell embedding objective
MVC = False  # Masked value prediction for cell embedding
ECS = False  # Elastic cell similarity objective
amp = True
load_model = "./save/scGPT_human"
load_param_prefixs = [
    "encoder",
    "value_encoder",
    "transformer_encoder",
]

# settings for optimizer
lr = 1e-4  # or 1e-4
batch_size = 64
eval_batch_size = 64
epochs = 15
schedule_interval = 1
early_stop = 10

# settings for the model
embsize = 512  # embedding dimension
d_hid = 512  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 12  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # number of heads in nn.MultiheadAttention
n_layers_cls = 3
dropout = 0  # dropout probability
use_fast_transformer = True  # whether to use fast transformer

# logging
log_interval = 100

# dataset and evaluation choices
data_name = "adamson"
split = "simulation"
if data_name == "norman":
    perts_to_plot = ["SAMD1+ZBTB1"]
elif data_name == "adamson":
    perts_to_plot = ["KCTD16+ctrl"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
save_dir = Path(f"./save/dev_perturb_{data_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)
print(f"saving to {save_dir}")

logger = scg.logger
scg.utils.add_file_handler(logger, save_dir / "run.log")
# log running date and current git commit
logger.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")



saving to save/dev_perturb_adamson-Sep22-00-13
scGPT - INFO - Running on 2025-09-22 00:13:34


In [None]:
pert_data = PertData("./data")
pert_data.load(data_name=data_name)
pert_data.prepare_split(split=split, seed=1)
pert_data.get_dataloader(batch_size=batch_size, test_batch_size=eval_batch_size)


Found local copy...
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:22
Done!
Creating dataloaders....
Done!


In [None]:
if load_model is not None:
    model_dir = Path("./save/scGPT_human")
    model_config_file = model_dir / "args.json"
    model_file = model_dir / "best_model.pt"
    vocab_file = model_dir / "vocab.json"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    pert_data.adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in pert_data.adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    genes = pert_data.adata.var["gene_name"].tolist()

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    genes = pert_data.adata.var["gene_name"].tolist()
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)



scGPT - INFO - match 4399/5060 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from save/scGPT_human/best_model.pt, the model args will override the config save/scGPT_human/args.json.


In [None]:
ntokens = len(vocab)  # size of vocabulary
model_pretrain = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=use_fast_transformer,
)
if load_param_prefixs is not None and load_model is not None:
    # only load params that start with the prefix
    model_dict = model_pretrain.state_dict()
    pretrained_dict = torch.load(model_file)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if any([k.startswith(prefix) for prefix in load_param_prefixs])
    }
    for k, v in pretrained_dict.items():
        logger.info(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model_pretrain.load_state_dict(model_dict)
elif load_model is not None:
    try:
        model_pretrain.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model_pretrain.state_dict()
        pretrained_dict = torch.load(model_file)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model_pretrain.load_state_dict(model_dict)
model_pretrain.to(device)

scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.bias with shape torch.Size([153

TransformerGenerator(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (pert_encoder): Embedding(3, 512, padding_idx=0)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0, i

In [None]:
if load_model is not None:
    model_dir = Path("./save/scGPT_human")
    model_config_file = model_dir / "args.json"
    vocab_file = model_dir / "vocab.json"
    model_dir = Path("./save/scGPT_human_finetuned_adamson")
    model_file = model_dir / "best_model.pt"

    vocab = GeneVocab.from_file(vocab_file)
    for s in special_tokens:
        if s not in vocab:
            vocab.append_token(s)

    pert_data.adata.var["id_in_vocab"] = [
        1 if gene in vocab else -1 for gene in pert_data.adata.var["gene_name"]
    ]
    gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
    logger.info(
        f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
        f"in vocabulary of size {len(vocab)}."
    )
    genes = pert_data.adata.var["gene_name"].tolist()

    # model
    with open(model_config_file, "r") as f:
        model_configs = json.load(f)
    logger.info(
        f"Resume model from {model_file}, the model args will override the "
        f"config {model_config_file}."
    )
    embsize = model_configs["embsize"]
    nhead = model_configs["nheads"]
    d_hid = model_configs["d_hid"]
    nlayers = model_configs["nlayers"]
    n_layers_cls = model_configs["n_layers_cls"]
else:
    genes = pert_data.adata.var["gene_name"].tolist()
    vocab = Vocab(
        VocabPybind(genes + special_tokens, None)
    )  # bidirectional lookup [gene <-> int]
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)



scGPT - INFO - match 4399/5060 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from save/scGPT_human_finetuned_adamson/best_model.pt, the model args will override the config save/scGPT_human/args.json.


In [None]:
ntokens = len(vocab)  # size of vocabulary
model_finetune = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=use_fast_transformer,
)
if load_param_prefixs is not None and load_model is not None:
    # only load params that start with the prefix
    model_dict = model_finetune.state_dict()
    pretrained_dict = torch.load(model_file)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if any([k.startswith(prefix) for prefix in load_param_prefixs])
    }
    for k, v in pretrained_dict.items():
        logger.info(f"Loading params {k} with shape {v.shape}")
    model_dict.update(pretrained_dict)
    model_finetune.load_state_dict(model_dict)
elif load_model is not None:
    try:
        model_finetune.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model_finetune.state_dict()
        pretrained_dict = torch.load(model_file)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model_finetune.load_state_dict(model_dict)
model_finetune.to(device)

scGPT - INFO - Loading params encoder.embedding.weight with shape torch.Size([60697, 512])
scGPT - INFO - Loading params encoder.enc_norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params encoder.enc_norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear1.weight with shape torch.Size([512, 1])
scGPT - INFO - Loading params value_encoder.linear1.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.linear2.weight with shape torch.Size([512, 512])
scGPT - INFO - Loading params value_encoder.linear2.bias with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.weight with shape torch.Size([512])
scGPT - INFO - Loading params value_encoder.norm.bias with shape torch.Size([512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
scGPT - INFO - Loading params transformer_encoder.layers.0.self_attn.Wqkv.bias with shape torch.Size([153

TransformerGenerator(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (pert_encoder): Embedding(3, 512, padding_idx=0)
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0, i

## Extract perturbation embedding

In [None]:
import numpy as np
import torch
from typing import List, Dict, Optional

def generate_perturbation_embeddings(
    model: "TransformerGenerator",
    pert_genes: List[str],
    ctrl_adata,
    gene_list: List[str],
    pool_size: Optional[int] = None,
    gene_ids: Dict[str, int] = None,  # mapping used by map_raw_id_to_vocab_id
    amp: bool = True,
) -> Dict[str, np.ndarray]:
    """
    Generate embeddings that capture the "essence" of each perturbation type.
    """
    model.eval()
    device = next(model.parameters()).device
    use_amp = bool(amp and device.type == "cuda")

    if pool_size is None:
        pool_size = len(ctrl_adata.obs)

    # gene name -> index in gene_list
    gene_name_to_idx = {g: i for i, g in enumerate(gene_list)}

    perturbation_embeddings: Dict[str, np.ndarray] = {}

    # helper: get a 1D float32 numpy array for a control cell row
    def _row_to_np(idx: int) -> np.ndarray:
        row = ctrl_adata.X[idx]
        if hasattr(row, "toarray"):  # sparse
            return row.toarray().ravel().astype(np.float32)
        return np.asarray(row, dtype=np.float32).ravel()

    # helper: ensure mapped ids are a LongTensor on device
    def _to_long_dev(x):
        if isinstance(x, torch.Tensor):
            return x.to(device=device, dtype=torch.long)
        return torch.as_tensor(x, dtype=torch.long, device=device)

    with torch.no_grad():
        n_genes = len(gene_list)
        # raw gene indices [0..n_genes-1] on device
        input_gene_ids_raw = torch.arange(n_genes, device=device, dtype=torch.long)

        # map to vocab ids (whatever your function returns)
        mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids_raw, gene_ids)
        mapped_input_gene_ids = _to_long_dev(mapped_input_gene_ids).unsqueeze(0)  # [1, n_genes]

        for pert_gene in pert_genes:
            print(f"Generating embedding for perturbation: {pert_gene}")

            if pert_gene not in gene_name_to_idx:
                print(f"Warning: {pert_gene} not in gene list, skipping…")
                continue

            pert_gene_idx = gene_name_to_idx[pert_gene]
            all_embeddings = []

            # sample control indices
            num_ctrl = len(ctrl_adata.obs)
            sampled = np.random.choice(num_ctrl, size=min(pool_size, num_ctrl), replace=False)

            for idx in sampled:
                # values: [n_genes] -> [1, n_genes] float32 on device
                vals_np = _row_to_np(idx)
                input_values = torch.from_numpy(vals_np).to(device=device, dtype=torch.float32).unsqueeze(0)

                # pert flags: [1, n_genes] long on device (0 everywhere, 1 at target gene)
                input_pert_flags = torch.zeros(n_genes, dtype=torch.long, device=device).unsqueeze(0)
                input_pert_flags[0, pert_gene_idx] = 1

                # padding mask: no padding -> all False, shape [1, n_genes], bool on device
                src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device)

                # forward
                if use_amp:
                    with torch.cuda.amp.autocast():
                        tr_out = model._encode(
                            mapped_input_gene_ids,    # [1, n_genes] long
                            input_values,             # [1, n_genes] float
                            input_pert_flags,         # [1, n_genes] long
                            src_key_padding_mask,     # [1, n_genes] bool
                        )
                        cell_emb = model._get_cell_emb_from_layer(tr_out, input_values)
                else:
                    tr_out = model._encode(
                        mapped_input_gene_ids,
                        input_values,
                        input_pert_flags,
                        src_key_padding_mask,
                    )
                    cell_emb = model._get_cell_emb_from_layer(tr_out, input_values)

                all_embeddings.append(cell_emb.detach().to(torch.float32).cpu().numpy())

            if all_embeddings:
                avg = np.mean(all_embeddings, axis=0)
                perturbation_embeddings[pert_gene] = avg
                print(f"Generated embedding for {pert_gene}: shape {avg.shape}")
            else:
                print(f"No embeddings generated for {pert_gene}")
    return perturbation_embeddings


In [None]:
with open("pert_gene_list.json") as f:
    pert_gene_list = json.load(f)

In [None]:
adata = pert_data.adata
ctrl_adata = adata[adata.obs["condition"] == "ctrl"]
gene_list = pert_data.gene_names.values.tolist()

# Define perturbation genes
perturbation_genes = pert_gene_list  # Replace with actual gene names

# Generate perturbation embeddings
pert_embeddings_pretrain = generate_perturbation_embeddings(
    model=model_pretrain,
    pert_genes=perturbation_genes,
    ctrl_adata=ctrl_adata,
    gene_list=gene_list,
    pool_size=100,  # Use 100 control cells for averaging
    gene_ids=gene_ids,
    amp=amp
)

# Print results
print("\nPerturbation Embeddings Generated:")
for gene, embedding in pert_embeddings_pretrain.items():
    print(f"{gene}: {embedding.shape}")

In [None]:
pert_embeddings_finetune = generate_perturbation_embeddings(
    model=model_finetune,
    pert_genes=perturbation_genes,
    ctrl_adata=ctrl_adata,
    gene_list=gene_list,
    pool_size=100,  # Use 100 control cells for averaging
    gene_ids=gene_ids,
    amp=amp
)

# Print results
print("\nPerturbation Embeddings Generated:")
for gene, embedding in pert_embeddings_finetune.items():
    print(f"{gene}: {embedding.shape}")

Generating embedding for perturbation: UBL5
Generating embedding for perturbation: TIMM9
Generating embedding for perturbation: SMG5
Generating embedding for perturbation: MED9
Generating embedding for perturbation: MED1
Generating embedding for perturbation: HNRNPH1
Generated embedding for HNRNPH1: shape (1, 512)
Generating embedding for perturbation: RPL10A
Generating embedding for perturbation: LIN54
Generating embedding for perturbation: EIF1AX
Generating embedding for perturbation: NUP98
Generating embedding for perturbation: N6AMT1
Generating embedding for perturbation: RPS20
Generating embedding for perturbation: RIOK2
Generated embedding for RIOK2: shape (1, 512)
Generating embedding for perturbation: POP7
Generating embedding for perturbation: EIF3E
Generated embedding for EIF3E: shape (1, 512)
Generating embedding for perturbation: MRPS14
Generating embedding for perturbation: ctrl
Generating embedding for perturbation: MRPL2
Generating embedding for perturbation: MED10
Gener

In [None]:
import pandas as pd
pert_embeddings_pretrain_1d = {
    k: np.ravel(v) for k, v in pert_embeddings_pretrain.items()
}
pert_embeddings_pretrain_1d = pd.DataFrame(pert_embeddings_pretrain_1d)
pert_embeddings_pretrain_1d.to_csv("./pert_embeddings_pretrain.csv")

pert_embeddings_finetune_1d = {
    k: np.ravel(v) for k, v in pert_embeddings_finetune.items()
}
pert_embeddings_finetune_1d = pd.DataFrame(pert_embeddings_finetune_1d)
pert_embeddings_finetune_1d.to_csv("./pert_embeddings_finetune.csv")