# Inference

In [108]:
import scanpy as sc
import pandas as pd
import numpy as np
import scgpt 
from scgpt.preprocess import Preprocessor
from scgpt.model import TransformerModel
from scgpt.tokenizer.gene_tokenizer import GeneVocab
import torch
import json

from scgpt.tokenizer import tokenize_and_pad_batch
import numpy as np

import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import scvi

## Step 1: Data loading and preprocessing

In [109]:
czi_reference = sc.read_h5ad("../data/datasets/czi_covid_pbmc_5pct.h5ad") # Obtain reference CZI data for classification labels
id2type = dict(enumerate(czi_reference.obs["cell_type"].astype("category").cat.categories)) # create dict to match labels

In [110]:
dengue_datapath = "../data/inference/dengue/natural-2/dneg5/"
dengue_data =  sc.read_10x_mtx(dengue_datapath, prefix="GSM4670219_Natural2_Dneg5_")

In [111]:
dengue_data

AnnData object with n_obs × n_vars = 3651 × 33694
    var: 'gene_ids'

In [112]:
## Remove unused values in columns

obs_cols = list(dengue_data.obs.columns)
print(obs_cols)

for colname in obs_cols:
    if pd.api.types.is_categorical_dtype(dengue_data.obs[colname]): 
        dengue_data.obs[colname] = dengue_data.obs[colname].cat.remove_unused_categories()

[]


In [113]:
# Run scGPT preprocessor
preprocessor = Preprocessor(
    use_key="X",
    normalize_total=1e4, # 1. Normalization protocol - sum to 10000 
    result_normed_key="X_normed",
    log1p=False,  # 2. or True, depending on original training - original training is false
    binning=51,
    subset_hvg=False,
    hvg_flavor="seurat_v3",
    hvg_use_key="X_normed",
    result_binned_key="X_binned" # 3. Layer for scGPT to work with
)
preprocessor(dengue_data)

scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Binning data ...


In [114]:
## Compute PCA -> kNN -> UMAP

# Compute PCA
sc.pp.pca(dengue_data)
# Compute neighbors and UMAP using PCA embedding
sc.pp.neighbors(dengue_data, use_rep="X_pca")
sc.tl.umap(dengue_data)
# Store UMAP coordinates from PCA in a new layer
dengue_data.obsm["X_umap_pca"] = dengue_data.obsm["X_umap"]

In [115]:
# Load the model config
with open("../models/scgpt_human/args.json") as f:
    args = json.load(f)

# Load vocab
vocab = GeneVocab.from_file("save/finetuning_czi_5pct_2pct/vocab.json")
vocab.set_default_index(vocab["<pad>"])

# Number of cell types in NEW dataset (not original training set)
num_cell_types = 97  # update this

# Initialize model
model = TransformerModel(
    ntoken=len(vocab),
    d_model=args["embsize"],
    nhead=args["nheads"],
    d_hid=args["d_hid"],
    nlayers=args["nlayers"],
    nlayers_cls=args.get("n_layers_cls", 3),
    n_cls=num_cell_types,
    vocab=vocab,
    dropout=args.get("dropout", 0.2),
    pad_token="<pad>",
    pad_value=-2,
    input_emb_style="continuous",
    n_input_bins=args.get("n_input_bins", 51),
    cell_emb_style="cls",
    use_fast_transformer=args.get("fast_transformer", False),
    fast_transformer_backend=args.get("fast_transformer_backend", "flash"),
    pre_norm=args.get("pre_norm", False),
)



In [116]:
state_dict = torch.load("save/finetuning_czi_5pct_2pct/best_model.pt", map_location="cpu")

if all(k.startswith("module.") for k in state_dict.keys()):
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

In [117]:
model.load_state_dict(state_dict)
model.eval()

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [118]:
gene_ids = np.array([vocab[gene] for gene in dengue_data.var_names], dtype=int)

counts = dengue_data.layers["X_binned"]  # must match input_layer_key from training

tokenized = tokenize_and_pad_batch(
    counts,
    gene_ids,
    max_len=3001,
    vocab=vocab,
    pad_token="<pad>",
    pad_value=-2,
    append_cls=True,
    include_zero_gene=False
)

In [119]:
# Define inference dataset
class InferenceDataset(Dataset):
    def __init__(self, tokenized):
        self.data = tokenized

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}

    def __len__(self):
        return self.data["genes"].shape[0]

# Create dataset and loader
dataset = InferenceDataset(tokenized)
loader = DataLoader(
    dataset, 
    batch_size=12,     # You can experiment with 512, 1024, or 2048 on A30s
    num_workers=8,      # Reduce for GPU-bound tasks
    pin_memory=True,    # Enable for GPU
    shuffle=False
)

# Move model to GPU and set eval mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(60697, 512, padding_idx=60694)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, el

In [120]:
all_preds = []
all_embeddings = []

with torch.no_grad():
    for batch in loader:
        input_ids = batch["genes"].to(device, non_blocking=True) # add to device
        values = batch["values"].to(device, non_blocking=True) # add to device
        mask = input_ids.eq(vocab["<pad>"])

        # ---------- Step 1: Full prediction (from classification head) ----------
        out = model(input_ids, values, src_key_padding_mask=mask, CLS=True)
        logits = out["cls_output"]  # Already passed through cls_decoder
        preds = logits.argmax(1)
        all_preds.extend(preds.cpu().numpy())

        # ---------- Step 2: Encoder-only embedding ----------
        # 2.1 Embed genes
        gene_embed = model.encoder.embedding(input_ids)  # [B, L, D]
        gene_embed = model.encoder.enc_norm(gene_embed)

        # 2.2 Embed values
        value_embed = model.value_encoder(values)  # [B, L, D]

        # 2.3 Add them together
        x = gene_embed + value_embed

        # 2.4 Pass through TransformerEncoder
        x = model.transformer_encoder(x, src_key_padding_mask=mask)  # [B, L, D]

        # 2.5 Get CLS token (assumes CLS is first token)
        cls_embed = x[:, 0, :]  # [B, D]
        all_embeddings.append(cls_embed.cpu())

In [121]:
# Save predictions
dengue_data.obs["predictions"] = pd.Categorical(all_preds)

# Save encoder embeddings
dengue_data.obsm["X_scgpt"] = torch.cat(all_embeddings, dim=0).numpy()

In [122]:
dengue_data

AnnData object with n_obs × n_vars = 3651 × 33694
    obs: 'predictions'
    var: 'gene_ids'
    uns: 'pca', 'neighbors', 'umap'
    obsm: 'bin_edges', 'X_pca', 'X_umap', 'X_umap_pca', 'X_scgpt'
    varm: 'PCs'
    layers: 'X_normed', 'X_binned'
    obsp: 'distances', 'connectivities'

In [123]:
## UMAP from X_scgpt
sc.pp.neighbors(dengue_data, use_rep="X_scgpt")
sc.tl.umap(dengue_data)
dengue_data.obsm["X_umap_scgpt"] = dengue_data.obsm["X_umap"]  # copy before it gets overwritten

In [124]:
# Ensure predictions are integers (in case they're categorical)
pred_ids = dengue_data.obs["predictions"].astype(int)

# Map to cell type names using id2type
dengue_data.obs["predicted_celltype"] = pred_ids.map(id2type)

dengue_data.obs.predicted_celltype.unique()

array(['CD14-positive monocyte', 'mature NK T cell',
       'plasmacytoid dendritic cell', 'classical monocyte',
       'T follicular helper cell',
       'CD16-positive, CD56-dim natural killer cell, human',
       'non-classical monocyte', 'class switched memory B cell',
       'monocyte', 'B cell', 'gamma-delta T cell',
       'effector CD8-positive, alpha-beta T cell',
       'CD16-negative, CD56-bright natural killer cell, human',
       'CD8-positive, alpha-beta T cell',
       'CD4-positive, alpha-beta T cell', 'naive B cell',
       'naive thymus-derived CD4-positive, alpha-beta T cell',
       'naive thymus-derived CD8-positive, alpha-beta T cell', 'platelet',
       'central memory CD4-positive, alpha-beta T cell',
       'regulatory T cell', 'plasmablast', 'T-helper 22 cell',
       'CD14-low, CD16-positive monocyte', 'IgG plasma cell',
       'IgA plasma cell', 'natural killer cell',
       'effector memory CD8-positive, alpha-beta T cell',
       'mucosal invariant T cell'

In [125]:
dengue_data.write_h5ad(dengue_datapath+"n2dneg5_preds.h5ad")