## Minimal scGPT Pipeline

In [1]:
#############################################################
import sys
import copy
import os
import time
import logging
from pathlib import Path
import pickle
import shutil

#############################################################
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchtext.vocab import Vocab
from torchtext._torchtext import Vocab as VocabPybind

#############################################################
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

#############################################################
import numpy as np
import scanpy as sc
from scipy.sparse import issparse

#############################################################
import scgpt as scg
from scgpt.model import TransformerModel
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.preprocess import Preprocessor

# Put the parent directory at the front of the module search path.
# NOTE(review): this executes after the `scgpt` imports above, so it cannot
# change which scgpt package was imported; move it above them if that is the
# intent.
sys.path.insert(0, "../")

# Configure root logging: INFO level, timestamped records, emitted through a
# single StreamHandler (stderr by default). basicConfig is a no-op if the root
# logger already has handlers.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-level logger for this notebook

# Compute device: the default CUDA device when one is visible, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def set_seed(seed):
    """Seed every global RNG the pipeline may draw from, for reproducible runs.

    Covers Python's ``random`` module, NumPy's global RNG, and torch's CPU
    generator plus all CUDA generators (safe to call without a GPU).

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # stdlib RNG; imported locally to keep this helper self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

  backends.update(_get_backends("networkx.backends"))


In [3]:
# Hyperparameters for cell type classification
seed = 0
dataset_name = "ms"
mask_ratio = 0.0  # not read by any function in this notebook section
epochs = 10
n_bins = 51  # passed to the Preprocessor as `binning`
lr = 1e-4
batch_size = 32
eval_batch_size = 32
layer_size = 128  # used as both embsize and d_hid of the TransformerModel
nlayers = 4
nhead = 4
dropout = 0.2
schedule_ratio = 0.9  # StepLR gamma: learning rate is multiplied by this each epoch
fast_transformer = True
pre_norm = False
amp = True  # Automatic Mixed Precision
include_zero_gene = False
freeze = False  # Whether to freeze pre-trained layers

# Settings for input and preprocessing.
# The special tokens must be real, distinct strings that exist in the
# pretrained scGPT vocabulary. Empty strings cannot match the checkpoint's pad
# token, and three identical entries collapse into a single token.
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
max_seq_len = 3001
input_style = "binned"
input_emb_style = "continuous"
cell_emb_style = "cls"

# Set values based on input style: "category" embeddings reserve two extra bin
# ids for mask/pad; continuous embeddings use negative sentinel values instead.
if input_emb_style == "category":
    mask_value = n_bins + 1
    pad_value = n_bins
    n_input_bins = n_bins + 2
else:
    mask_value = -1
    pad_value = -2
    n_input_bins = n_bins

# Create a timestamped save directory for this run's outputs
save_dir = Path(f"./save/dev_{dataset_name}-{time.strftime('%b%d-%H-%M')}/")
save_dir.mkdir(parents=True, exist_ok=True)

In [None]:
def load_data(data_dir="../data/sample_ms"):
    """Load the reference (train) and query (test) AnnData files and merge them.

    Args:
        data_dir: Directory (str or Path) containing ``c_data.h5ad`` (train)
            and ``filtered_ms_adata.h5ad`` (test).

    Returns:
        A 4-tuple ``(adata, adata_test_raw, num_types, id2type)``:
        - adata: train and test concatenated. ``obs["str_batch"]`` is "0" for
          train rows and "1" for test rows, and ``obs["batch_id"]`` holds its
          integer codes. ``obs["celltype_id"]`` holds integer cell-type codes.
          ``var["gene_name"]`` mirrors the gene-symbol index.
        - adata_test_raw: an untouched copy of the test AnnData, taken before
          concatenation.
        - num_types: number of distinct cell-type codes.
        - id2type: mapping from cell-type code to cell-type label.
    """
    data_dir = Path(data_dir)
    adata = sc.read(data_dir / "c_data.h5ad")
    adata_test = sc.read(data_dir / "filtered_ms_adata.h5ad")

    # Setup celltype information from the authors' annotation column
    label_key = "Factor Value[inferred cell type - authors labels]"
    adata.obs["celltype"] = adata.obs[label_key].astype("category")
    adata_test.obs["celltype"] = adata_test.obs[label_key].astype("category")

    # Setup batch information ("str_batch" is overwritten by concatenate below)
    adata.obs["batch_id"] = adata.obs["str_batch"] = "0"
    adata_test.obs["batch_id"] = adata_test.obs["str_batch"] = "1"

    # Index genes by symbol. Each object must use its OWN gene_name column:
    # set_index with a Series is positional, so borrowing the other object's
    # names would silently mislabel genes (or fail on a length mismatch)
    # whenever the two var tables differ.
    adata.var.set_index(adata.var["gene_name"], inplace=True)
    adata_test.var.set_index(adata_test.var["gene_name"], inplace=True)

    # Keep an untouched copy of the test set for writing predictions later
    adata_test_raw = adata_test.copy()
    adata = adata.concatenate(adata_test, batch_key="str_batch")

    # Integer batch codes
    adata.obs["batch_id"] = adata.obs["str_batch"].astype("category").cat.codes.values

    # Integer cell-type codes and the reverse lookup table
    celltype_cat = adata.obs["celltype"].astype("category")
    celltype_id_labels = celltype_cat.cat.codes.values
    adata.obs["celltype_id"] = celltype_id_labels
    num_types = len(np.unique(celltype_id_labels))
    id2type = dict(enumerate(celltype_cat.cat.categories))

    adata.var["gene_name"] = adata.var.index.tolist()

    return adata, adata_test_raw, num_types, id2type

In [5]:
def preprocess_data(adata, data_is_raw=False, filter_gene_by_counts=False):
    """Split the merged AnnData by batch and run the scGPT Preprocessor on each part.

    Rows with ``obs["str_batch"] == "0"`` become the training set and rows with
    ``"1"`` the test set. Each part is copied out of ``adata`` and handed to the
    same Preprocessor, which is configured to:
    - normalise totals to 1e4;
    - apply log1p only when ``data_is_raw``;
    - bin into ``n_bins`` levels.

    Results go under the ``X_normed`` / ``X_log1p`` / ``X_binned`` keys. The
    Preprocessor's return value is not used, so the parts are processed in place.

    Args:
        adata: Merged AnnData as produced by ``load_data``.
        data_is_raw: Whether ``X`` holds raw counts. This also selects the HVG
            flavor, although HVG subsetting itself is disabled.
        filter_gene_by_counts: Forwarded to the Preprocessor.

    Returns:
        ``(adata_train, adata_test)``, both new AnnData objects.
    """
    settings = dict(
        use_key="X",
        filter_gene_by_counts=filter_gene_by_counts,
        filter_cell_by_counts=False,
        normalize_total=1e4,
        result_normed_key="X_normed",
        log1p=data_is_raw,
        result_log1p_key="X_log1p",
        subset_hvg=False,
        hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
        binning=n_bins,
        result_binned_key="X_binned",
    )
    preprocessor = Preprocessor(**settings)

    batch_labels = adata.obs["str_batch"]
    adata_test = adata[batch_labels == "1"].copy()
    adata_train = adata[batch_labels == "0"].copy()

    # Train first, then test; each call works on its own copy.
    for part in (adata_train, adata_test):
        preprocessor(part, batch_key=None)

    return adata_train, adata_test

In [None]:
def load_vocab(adata, model_dir):
    """Load the pretrained gene vocabulary and restrict ``adata`` to its genes.

    Side effects:
    - copies ``vocab.json`` into ``save_dir``;
    - appends any missing ``special_tokens`` to the vocabulary;
    - writes an ``id_in_vocab`` column into ``adata.var``
      (1 = gene present in the vocab, -1 = absent).

    Args:
        adata: AnnData whose ``var["gene_name"]`` holds gene symbols.
        model_dir: Directory (str or Path) containing ``vocab.json``.

    Returns:
        ``(vocab, filtered_adata)``, where filtered_adata is ``adata`` sliced to
        the in-vocabulary genes. The slice is an AnnData view, not a copy.
    """
    vocab_path = Path(model_dir) / "vocab.json"
    vocab = GeneVocab.from_file(vocab_path)

    # Keep a copy of the vocabulary next to this run's outputs.
    shutil.copy(vocab_path, save_dir / "vocab.json")

    # Membership is re-checked per token, so a duplicated entry in
    # special_tokens is appended only once.
    for token in special_tokens:
        if token in vocab:
            continue
        vocab.append_token(token)

    hits = [gene in vocab for gene in adata.var["gene_name"]]
    adata.var["id_in_vocab"] = [1 if hit else -1 for hit in hits]
    logger.info(
        f"Match {sum(hits)}/{len(hits)} genes "
        f"in vocabulary of size {len(vocab)}."
    )

    keep = adata.var["id_in_vocab"] >= 0
    return vocab, adata[:, keep]

In [7]:
class GeneExpressionDataset(Dataset):
    """Map-style dataset over a dict of row-indexable arrays/tensors.

    Item ``idx`` is a dict with the same keys as ``data``, holding row ``idx``
    of every entry. The dataset length is taken from ``data["gene_ids"]``.
    NOTE(review): assumes every entry has the same number of rows as
    ``gene_ids``; this is not checked.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        gene_ids = self.data["gene_ids"]
        return gene_ids.shape[0]

    def __getitem__(self, idx):
        item = {}
        for key, column in self.data.items():
            item[key] = column[idx]
        return item

def prepare_dataloader(data_pt, batch_size, shuffle=False):
    """Wrap a dict of per-sample arrays/tensors in a DataLoader.

    Args:
        data_pt: Mapping of field name to row-indexable tensor/array. It must
            contain ``"gene_ids"``, which sets the dataset length.
        batch_size: Samples per batch.
        shuffle: Reshuffle the data at every epoch (use for training).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts with the same keys as
        ``data_pt``.
    """
    dataset = GeneExpressionDataset(data_pt)

    # Worker count is bounded by the CPU count and by half the batch size.
    # batch_size < 2 gives 0, i.e. loading happens in the main process.
    num_workers = min(os.cpu_count() or 1, batch_size // 2)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies. Without CUDA it
        # brings no benefit, and PyTorch may warn about it.
        pin_memory=torch.cuda.is_available(),
    )

In [8]:
def setup_model(vocab, num_types, model_dir, explicit_zero_prob=False):
    """Build a scGPT TransformerModel for cell-type classification and load weights.

    Architecture settings come from module-level globals (layer_size, nhead,
    nlayers, dropout, pad_token, pad_value, input_emb_style, n_input_bins,
    cell_emb_style, fast_transformer, pre_norm, freeze, device).

    Args:
        vocab: Gene vocabulary; its length sets the token embedding size.
        num_types: Number of cell-type classes for the classification head.
        model_dir: Directory (str or Path) containing ``best_model.pt``.
        explicit_zero_prob: Forwarded to ``TransformerModel``.

    Returns:
        The model, moved to ``device``.

    Raises:
        FileNotFoundError: if ``model_dir / "best_model.pt"`` does not exist.
    """
    model_dir = Path(model_dir)  # callers may pass a plain string
    ntokens = len(vocab)

    model = TransformerModel(
        ntokens,
        embsize=layer_size,
        nhead=nhead,
        d_hid=layer_size,
        nlayers=nlayers,
        nlayers_cls=3,
        n_cls=num_types,
        vocab=vocab,
        dropout=dropout,
        pad_token=pad_token,
        pad_value=pad_value,
        # Disable unnecessary features
        do_mvc=False,
        do_dab=False,
        use_batch_labels=False,
        num_batch_labels=1,
        domain_spec_batchnorm=False,
        # Keep necessary settings
        input_emb_style=input_emb_style,
        n_input_bins=n_input_bins,
        cell_emb_style=cell_emb_style,
        mvc_decoder_style="inner product",
        ecs_threshold=0.0,
        explicit_zero_prob=explicit_zero_prob,
        use_fast_transformer=fast_transformer,
        fast_transformer_backend="flash",
        pre_norm=pre_norm,
    )

    # Load pre-trained weights. Loading onto the CPU first lets a checkpoint
    # saved on a GPU load on a CPU-only machine; the model is moved to
    # `device` at the end.
    model_file = model_dir / "best_model.pt"
    pretrained_dict = torch.load(model_file, map_location="cpu")
    try:
        model.load_state_dict(pretrained_dict)
        logger.info(f"Loading all model params from {model_file}")
    except RuntimeError:
        # Strict loading failed (missing/unexpected keys or a shape mismatch).
        # Fall back to copying only the tensors whose name AND shape match.
        model_dict = model.state_dict()
        matched = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }

        for k, v in matched.items():
            logger.info(f"Loading params {k} with shape {v.shape}")

        skipped = len(pretrained_dict) - len(matched)
        if skipped:
            logger.warning(
                f"Skipped {skipped} checkpoint tensors with no matching name/shape; "
                "check that the architecture hyperparameters match the checkpoint."
            )

        model_dict.update(matched)
        model.load_state_dict(model_dict)

    # Freeze parameters if requested: embedding-side "encoder" weights only,
    # leaving the transformer encoder trainable.
    if freeze:
        for name, para in model.named_parameters():
            if "encoder" in name and "transformer_encoder" not in name:
                logger.info(f"Freezing weights for: {name}")
                para.requires_grad = False

    model.to(device)
    return model

In [9]:
def train_epoch(model, loader, optimizer, criterion, scaler, epoch, *, log_interval=100):
    """Train ``model`` for one epoch on the cell-type classification objective.

    Reads the module-level globals ``device``, ``vocab``, ``pad_token`` and
    ``amp``, plus ``scheduler`` (only to log the current learning rate).

    Args:
        model: scGPT TransformerModel; called with ``CLS=True`` and its
            ``"cls_output"`` is used as the logits.
        loader: DataLoader yielding dicts with ``"gene_ids"``, ``"values"`` and
            ``"celltype_labels"``.
        optimizer: Optimizer stepped through ``scaler``.
        criterion: Loss applied to ``(cls_output, celltype_labels)``.
        scaler: ``torch.cuda.amp.GradScaler`` (may be disabled).
        epoch: Epoch number, used only in log lines.
        log_interval: Log averaged loss/error every this many batches.
            Values <= 0 disable logging.

    Returns:
        The same ``model`` object (trained in place).
    """
    model.train()

    pad_id = vocab[pad_token]  # constant for the whole epoch
    num_batches = len(loader)

    # Running sums over the current logging window only.
    window_loss = 0.0
    window_error = 0.0
    window_batches = 0
    start_time = time.time()

    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        celltype_labels = batch_data["celltype_labels"].to(device)

        # Positions holding the pad token are excluded via the padding mask
        src_key_padding_mask = input_gene_ids.eq(pad_id)

        # Forward pass with mixed precision
        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=None,
                CLS=True,
                CCE=False,
                MVC=False,
                ECS=False,
                do_sample=False,
            )

            # Only compute classification loss
            loss = criterion(output_dict["cls_output"], celltype_labels)

        # Fraction of misclassified samples in this batch
        predicted = output_dict["cls_output"].argmax(1)
        error_rate = 1 - (predicted == celltype_labels).sum().item() / celltype_labels.size(0)

        # Backward pass; gradients are unscaled before clipping so the clip
        # threshold applies to the true gradient norm.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        window_loss += loss.item()
        window_error += error_rate
        window_batches += 1

        # Average over the batches actually accumulated in this window
        if window_batches == log_interval:
            current_lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / window_batches
            cur_loss = window_loss / window_batches
            cur_error = window_error / window_batches

            logger.info(
                f"| epoch {epoch:3d} | {batch + 1:3d}/{num_batches:3d} batches | "
                f"lr {current_lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | err {cur_error:5.2f}"
            )

            window_loss = 0.0
            window_error = 0.0
            window_batches = 0
            start_time = time.time()

    return model

In [11]:
def evaluate(model, loader):
    """Evaluate classification loss/error of ``model`` over ``loader``.

    Reads the module-level globals ``device``, ``vocab``, ``pad_token`` and
    ``amp``. Runs under ``torch.no_grad()`` with the model in eval mode.

    Args:
        model: scGPT TransformerModel; its ``"cls_output"`` is used as logits.
        loader: DataLoader yielding dicts with ``"gene_ids"``, ``"values"`` and
            ``"celltype_labels"``.

    Returns:
        A 3-tuple ``(mean_loss, error_rate, predictions)``:
        - mean_loss: cross-entropy averaged per sample;
        - error_rate: fraction of misclassified samples;
        - predictions: NumPy array of predicted class ids, in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    pad_id = vocab[pad_token]

    total_loss = 0.0
    total_wrong = 0
    total_num = 0
    predictions = []

    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            celltype_labels = batch_data["celltype_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(pad_id)

            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=None,
                    CLS=True,
                    CCE=False,
                    MVC=False,
                    ECS=False,
                    do_sample=False,
                )

                output_values = output_dict["cls_output"]
                loss = criterion(output_values, celltype_labels)

            n_rows = len(input_gene_ids)
            preds = output_values.argmax(1)

            # criterion returns a batch mean; re-weight by batch size so the
            # final average is per sample even when the last batch is short.
            total_loss += loss.item() * n_rows
            total_wrong += (preds != celltype_labels).sum().item()
            total_num += n_rows
            predictions.append(preds.cpu().numpy())

    if total_num == 0:
        raise ValueError("evaluate() received an empty loader; nothing to evaluate")

    return total_loss / total_num, total_wrong / total_num, np.concatenate(predictions, axis=0)

In [12]:
def main():
    """Run the fine-tuning pipeline: load, preprocess, train, predict, save.

    Writes ``test_with_predictions.h5ad``, ``model.pt`` and ``results.pkl``
    into ``save_dir``.

    NOTE(review): ``prepare_training_data`` and ``inference`` are called here
    but are not defined in this notebook section. Confirm that they are
    defined in another cell before running.
    """
    # train_epoch() and evaluate() read `vocab` (and train_epoch also
    # `scheduler`) as module-level globals, so bind them at module scope.
    global vocab, scheduler

    # Set seed for reproducibility
    set_seed(seed)

    # Pretrained checkpoint directory, as a Path so it can be joined with "/".
    model_dir = Path("../save/scGPT_human")

    # Load data
    logger.info("Loading data...")
    adata, adata_test_raw, num_types, id2type = load_data()

    # Load vocabulary
    logger.info("Loading vocabulary...")
    vocab, adata = load_vocab(adata, model_dir=model_dir)

    # Preprocess data
    logger.info("Preprocessing data...")
    adata_train, adata_test = preprocess_data(adata)

    # Determine input layer key based on input style
    input_layer_key = {
        "normed_raw": "X_normed",
        "log1p": "X_normed",
        "binned": "X_binned",
    }[input_style]

    # Prepare tokenized training data
    train_data_pt, valid_data_pt, gene_ids = prepare_training_data(adata_train, vocab, input_layer_key)

    # Create dataloaders
    train_loader = prepare_dataloader(train_data_pt, batch_size=batch_size, shuffle=True)
    valid_loader = prepare_dataloader(valid_data_pt, batch_size=eval_batch_size)

    # Setup model
    model = setup_model(vocab, num_types, model_dir=model_dir)

    # Setup optimizer, per-epoch LR decay, loss and AMP scaler
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=schedule_ratio)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    # Training loop, keeping a deep copy of the model with the lowest validation loss
    best_val_loss = float("inf")
    best_model = None
    best_model_epoch = None

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        # Train one epoch
        model = train_epoch(model, train_loader, optimizer, criterion, scaler, epoch)

        # Evaluate
        val_loss, val_err, _ = evaluate(model, valid_loader)

        # Log results
        elapsed = time.time() - epoch_start_time
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss {val_loss:5.4f} | err {val_err:5.4f}"
        )
        logger.info("-" * 89)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model_epoch = epoch
            logger.info(f"Best model with score {best_val_loss:5.4f}")

        scheduler.step()

    # No epoch beat +inf: either epochs == 0 or the validation loss was NaN
    # every time. Fall back to the current model instead of crashing on None.
    if best_model is None:
        logger.warning("No validation improvement recorded; using the last-epoch model.")
        best_model = model
    else:
        logger.info(f"Using best model from epoch {best_model_epoch}")

    # Inference on test data
    predictions, labels, results = inference(best_model, adata_test, vocab, gene_ids, input_layer_key)

    # Save predictions to test data h5ad object
    adata_test_raw.obs["predictions"] = [id2type[p] for p in predictions]
    adata_test_raw.write(save_dir / "test_with_predictions.h5ad")

    # Save model and results
    torch.save(best_model.state_dict(), save_dir / "model.pt")

    with open(save_dir / "results.pkl", "wb") as f:
        pickle.dump({
            "predictions": predictions,
            "labels": labels,
            "results": results,
            "id_maps": id2type
        }, f)

    logger.info(f"Final metrics: {results}")
    logger.info("Pipeline completed successfully!")

In [13]:
main()

2025-04-17 14:10:43,091 - __main__ - INFO - Loading data...


FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '../data/ms/c_data.h5ad', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)