## HF scGPT - inference

In [None]:
from tdc import tdc_hf_interface
from tdc.model_server.tokenizers.scgpt import scGPTTokenizer
import torch
import numpy as np

# Load the scGPT model through the TDC Hugging Face interface and put it in
# inference mode.
scgpt = tdc_hf_interface("scGPT")
# Pick the GPU when one is available and fall back to CPU otherwise, so the
# cell no longer fails on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = scgpt.load().eval().to(device)
tokenizer = scGPTTokenizer()

In [None]:
import scanpy as sc

# Read the sample dataset and tokenize every cell for scGPT.
adata = sc.read_h5ad("../data/sample_ms/c_data.h5ad")

# Gene names to match tokenizer vocab
gene_names = adata.var["gene_name"].to_numpy()  # or var_names depending on the file
# adata.X may be a sparse matrix or already a dense array; only sparse
# matrices expose .toarray(), so densify conditionally.
raw_matrix = adata.X
expr = raw_matrix.toarray() if hasattr(raw_matrix, "toarray") else np.asarray(raw_matrix)
tokenized_data = tokenizer.tokenize_cell_vectors(expr, gene_names)

In [None]:
# Embed the cells one at a time (batch size 1) and collect the per-cell
# embedding returned under the "cell_emb" key of the model output.
embeddings = []

with torch.no_grad():
    for gene_tokens, expr_counts in tokenized_data:
        model_inputs = {
            "input_ids": torch.tensor(gene_tokens).unsqueeze(0).to(model.device),
            "values": torch.tensor(expr_counts).unsqueeze(0).to(model.device),
        }
        # Positions whose value is zero are masked out.
        model_inputs["attention_mask"] = (model_inputs["values"] != 0).bool()

        cell_output = model(**model_inputs)
        embeddings.append(cell_output["cell_emb"].squeeze().cpu().numpy())

# One row per cell, stored alongside the AnnData object.
adata.obsm["X_scgpt"] = np.stack(embeddings)

In [None]:
# Neighbor graph and UMAP computed from the scGPT embedding stored in
# adata.obsm["X_scgpt"], then plotted colored by the "celltype" annotation.
embedding_key = "X_scgpt"
sc.pp.neighbors(adata, use_rep=embedding_key)
sc.tl.umap(adata)
umap_plot_options = {"color": "celltype", "title": "UMAP from scGPT Embeddings"}
sc.pl.umap(adata, **umap_plot_options)

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt


# Baseline pipeline without scGPT: PCA -> neighbors -> UMAP, each step run
# in place on `adata` with scanpy's default settings.
# NOTE(review): this presumably replaces the neighbor graph/UMAP computed
# from "X_scgpt" earlier in the notebook — confirm that is intended.
for pipeline_step in (sc.pp.pca, sc.pp.neighbors, sc.tl.umap):
    pipeline_step(adata)

In [None]:
# Plot the current UMAP coordinates, colored by the "celltype" annotation.
sc.pl.umap(adata, color="celltype")

## Native scGPT - inference

In [None]:
from scgpt.model import TransformerModel
from scgpt.tokenizer.gene_tokenizer import GeneVocab
# import torch

# Load vocab
vocab = GeneVocab.from_file("../models/hf-scgpt/scgpt_vocab.json")

# Args for model
ntokens = len(vocab)
embsize = 128  # layer_size
nhead = 4  # four attention heads
d_hid = 128  # layer_size
nlayers = 4  # number of layers
nlayers_cls = 3
# NOTE(review): `num_types` was referenced but never defined in this notebook.
# Assumed here to be the number of distinct labels in adata.obs["celltype"]
# (the column used for plotting above) — confirm against the checkpoint.
num_types = adata.obs["celltype"].nunique()
n_cls = num_types
dropout = 0.2  # dropout (previously written `0.2,`, which bound the tuple (0.2,))
pad_token = "<pad>"
pad_value = -2
do_mvc = False
do_dab = False
use_batch_labels = False
# NOTE(review): `num_batch_types` was also undefined. Batch labels are turned
# off (use_batch_labels=False), so it is set to None here; presumably the
# model ignores it in that case — confirm.
num_batch_types = None
num_batch_labels = num_batch_types
domain_spec_batchnorm = False
input_emb_style = "continuous"  # "category" or "continuous" or "scaling"
n_input_bins = 51
cell_emb_style = "cls"
mvc_decoder_style = "inner product"
ecs_threshold = 0.0
# explicit_zero_prob=explicit_zero_prob
use_fast_transformer = True
fast_transformer_backend = "flash"
pre_norm = False

# The stray bare `x` expression that used to sit here raised NameError and
# has been removed.

# `model` at this point is still the model bound by an earlier cell.
model.eval().cuda()  # or .cpu()

In [None]:
import scgpt as scg
import scanpy as sc
import numpy as np 
import pandas as pd
import sklearn

# Names of the annotation columns passed to the scGPT helpers below.
cell_type_key = "celltype"
gene_col = "index"

# Checkpoint directory and a fresh copy of the sample dataset.
model_dir = "../models/scgpt_human"
adata = sc.read_h5ad("../data/sample_ms/c_data.h5ad")

In [None]:
# Reference view of the freshly loaded data: PCA -> neighbors -> UMAP with
# scanpy defaults, then a frameless plot colored by cell type.
for preprocessing_step in (sc.pp.pca, sc.pp.neighbors, sc.tl.umap):
    preprocessing_step(adata)
sc.pl.umap(adata, frameon=False, color=cell_type_key)

In [None]:
# Embed the cells with the native scGPT helper. return_new_adata=False is
# passed, so the result is rebound to `adata`.
embed_options = {
    "gene_col": gene_col,
    "obs_to_save": cell_type_key,  # optional arg, only for saving metainfo
    "batch_size": 64,
    "return_new_adata": False,
}
adata = scg.tasks.embed_data(adata, model_dir, **embed_options)

In [None]:
# Display the AnnData summary as the cell output.
adata

In [None]:
# Neighbor graph and UMAP computed from the "X_scGPT" representation
# (presumably written by embed_data above — confirm), plotted without a frame.
native_embedding_key = "X_scGPT"
sc.pp.neighbors(adata, use_rep=native_embedding_key)
sc.tl.umap(adata)
sc.pl.umap(adata, frameon=False, color=cell_type_key)

In [None]:
# Build the native scGPT TransformerModel from the hyperparameters defined in
# the "Args for model" cell. The previous version referenced names that are
# not defined anywhere in this notebook (CLS, MVC, DAB, INPUT_BATCH_LABELS,
# config.DSBN, config.pre_norm, explicit_zero_prob); they are replaced by the
# corresponding variables from that cell.
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=nlayers_cls,
    n_cls=n_cls,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=do_mvc,
    do_dab=do_dab,
    use_batch_labels=use_batch_labels,
    num_batch_labels=num_batch_labels,
    domain_spec_batchnorm=domain_spec_batchnorm,
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    ecs_threshold=ecs_threshold,
    # explicit_zero_prob is left at the constructor default: the args cell
    # has its assignment commented out, so no value exists to pass.
    use_fast_transformer=use_fast_transformer,
    fast_transformer_backend=fast_transformer_backend,
    pre_norm=pre_norm,
)

In [None]:
# Show the installed torch / torchtext package metadata. This line is a
# notebook shell-style command, not plain Python.
pip show torch torchtext

# Inference example

In [None]:
from tdc import tdc_hf_interface
from tdc.model_server.tokenizers.scgpt import scGPTTokenizer
import torch
import scanpy as sc
import numpy as np
import umap
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Fetch the scGPT model through the TDC Hugging Face interface. The model is
# left on whatever device load() returns it on.
scgpt = tdc_hf_interface("scGPT")
model = scgpt.load()
# GPU / inference-mode alternative:
# model = scgpt.load().eval().cuda()

In [None]:
from transformers import PreTrainedModel
import inspect
# Print the signature of the loaded model's forward() to check which
# arguments it accepts.
print(inspect.signature(model.forward))


In [None]:
# Tokenizer used below to turn expression vectors into model inputs.
tokenizer = scGPTTokenizer()

In [None]:
# Location of the sample dataset; the trailing slash is part of the prefix.
data_dir = "../data/sample_ms/"
h5ad_path = f"{data_dir}c_data.h5ad"
adata = sc.read_h5ad(h5ad_path)

In [None]:
# Tokenize every cell. adata.X may be sparse or already dense; only sparse
# matrices expose .toarray(), so densify conditionally instead of assuming.
raw_matrix = adata.X
expr = raw_matrix.toarray() if hasattr(raw_matrix, "toarray") else np.asarray(raw_matrix)
gene_names = adata.var["gene_name"].to_numpy()
tokenized_data = tokenizer.tokenize_cell_vectors(expr, gene_names)

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class scGPTDataset(Dataset):
    """Map-style dataset over pre-tokenized cells.

    Each element of ``tokenized_data`` is unpacked as a ``(tokens, values)``
    pair; indexing returns both as tensors in a dict.
    """

    def __init__(self, tokenized_data):
        self.data = tokenized_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        gene_tokens, expr_values = self.data[idx]
        input_ids = torch.tensor(gene_tokens, dtype=torch.long)
        values = torch.tensor(expr_values, dtype=torch.float)
        return dict(input_ids=input_ids, values=values)

def collate_fn(batch, pad_token_id=60694):
    """Pad a list of per-cell samples into one rectangular batch.

    This function used to be defined twice, byte-for-byte identical; the
    duplicate has been removed.

    Args:
        batch: list of dicts with 1-D tensors under "input_ids" and "values".
        pad_token_id: id used to pad "input_ids". Defaults to 60694, the value
            previously hard-coded here (presumably the "<pad>" id of the
            scGPT vocabulary — confirm).

    Returns:
        Dict with "input_ids" and "values" padded to the longest sample in
        the batch, plus a boolean "attention_mask" that is True wherever the
        padded value is non-zero.
    """
    input_ids = [sample["input_ids"] for sample in batch]
    values = [sample["values"] for sample in batch]

    input_ids_padded = pad_sequence(
        input_ids, batch_first=True, padding_value=pad_token_id
    )
    values_padded = pad_sequence(values, batch_first=True, padding_value=0.0)
    # Values are padded with 0.0, so this mask excludes padding as well as
    # genuine zero values.
    attention_mask = values_padded != 0

    return {
        "input_ids": input_ids_padded,
        "values": values_padded,
        "attention_mask": attention_mask,
    }


In [None]:
# Batched inference: wrap the tokenized cells in a DataLoader, run the model
# without gradients and collect the "cell_emb" output of every batch.
dataset = scGPTDataset(tokenized_data)
loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

model.eval()
embeddings = []
with torch.no_grad():
    for batch in loader:
        # The batch holds exactly "input_ids", "values" and "attention_mask";
        # move each tensor to the model's device and pass them by keyword.
        device_batch = {
            "input_ids": batch["input_ids"].to(model.device),
            "values": batch["values"].to(model.device),
            "attention_mask": batch["attention_mask"].to(model.device),
        }
        batch_output = model(**device_batch)
        embeddings.append(batch_output["cell_emb"].cpu().numpy())

# Concatenate the per-batch arrays along the first axis.
adata.obsm["X_scgpt"] = np.concatenate(embeddings)

In [None]:
# Step 6: Run UMAP on scGPT Embeddings
# The previous code called sc.pp.pca on the raw embedding array and discarded
# the result, so no UMAP was ever computed. Build the neighbor graph from the
# stored embedding instead, then compute and plot the UMAP.
sc.pp.neighbors(adata, use_rep="X_scgpt")
sc.tl.umap(adata)
sc.pl.umap(adata, color="celltype", title="UMAP from scGPT Embeddings")

In [None]:
# Display the AnnData summary as the cell output.
adata

In [None]:
from transformers import PretrainedConfig

class scGPTConfig(PretrainedConfig):
    """Hugging Face configuration holding the scGPT hyperparameters.

    ``embsize`` is stored as ``hidden_size``; every other argument is stored
    under its own name. Extra keyword arguments go to ``PretrainedConfig``.
    """

    model_type = "scgpt"

    def __init__(
        self,
        vocab_size=60697,
        embsize=512,
        nhead=8,
        nlayers=12,
        d_hid=512,
        dropout=0.0,
        pad_token_id=0,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Assigned after the base initializer, in the same order as before.
        own_fields = (
            ("vocab_size", vocab_size),
            ("hidden_size", embsize),
            ("nhead", nhead),
            ("nlayers", nlayers),
            ("d_hid", d_hid),
            ("dropout", dropout),
            ("pad_token_id", pad_token_id),
        )
        for field_name, field_value in own_fields:
            setattr(self, field_name, field_value)


In [None]:
from transformers import PreTrainedModel
from scgpt.model import TransformerModel  # from the native repo
import torch.nn as nn
import torch.nn.functional as F
import torch

class scGPTWrapped(PreTrainedModel):
    """Hugging Face wrapper around the native scGPT ``TransformerModel``.

    A linear classification head is applied on top of the cell embedding.

    NOTE(review): the constructor and forward call below follow the native
    ``scgpt.model.TransformerModel`` API as I understand it (``vocab`` is
    indexed with ``pad_token`` to get the padding index; ``forward`` takes
    ``src_key_padding_mask`` and returns a dict containing "cell_emb").
    Confirm against the installed scgpt version.
    """

    config_class = scGPTConfig

    def __init__(self, config: scGPTConfig):
        super().__init__(config)
        # The native model looks the padding index up as vocab[pad_token], so
        # it needs a mapping and a token name rather than a bare integer id.
        pad_token = "<pad>"
        self.model = TransformerModel(
            ntoken=config.vocab_size,
            d_model=config.hidden_size,
            nhead=config.nhead,
            d_hid=config.d_hid,
            nlayers=config.nlayers,
            vocab={pad_token: config.pad_token_id},
            dropout=config.dropout,
            pad_token=pad_token,
        )
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, values, attention_mask=None, labels=None):
        """Return ``{"loss", "logits"}`` for a batch of tokenized cells.

        Args:
            input_ids: gene token ids.
            values: expression values aligned with ``input_ids``.
            attention_mask: optional mask, True/1 at positions to attend to.
            labels: optional class indices; when given, cross-entropy loss
                against the logits is returned, otherwise ``loss`` is None.
        """
        # attention_mask marks positions to keep; the native model expects a
        # padding mask that marks positions to ignore, hence the inversion.
        padding_mask = None if attention_mask is None else ~attention_mask.bool()
        out = self.model(input_ids, values, src_key_padding_mask=padding_mask)
        emb = out["cell_emb"]
        logits = self.classifier(emb)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}


# Hyperparameters for the wrapped model, gathered in one place and expanded
# into the config constructor.
config_options = {
    "vocab_size": 60697,  # or load from vocab
    "embsize": 512,
    "nhead": 8,
    "nlayers": 12,
    "d_hid": 512,
    "dropout": 0.0,
    "pad_token_id": 0,
}
config = scGPTConfig(**config_options)

In [None]:
# Build the wrapper, then load the pretrained weights into the inner native
# model (not into the wrapper itself).
model = scGPTWrapped(config)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
checkpoint_path = "../models/scgpt_human/best_model.pt"
state_dict = torch.load(checkpoint_path, map_location="cpu")
model.model.load_state_dict(state_dict)