# Parameter Efficient Finetuning (PEFT) using Low Rank Adapters (LoRA) techniques
1. Integrate HuggingFace's PEFT into scGPT to perform finetuning
2. Implementation will use HuggingFace's scGPT implementation from Therapeutic Commons - https://huggingface.co/tdc/scGPT
3. Test dataset - M.S. dataset (since there is a benchmark)

Requirements from HuggingFace
- transformers 
- accelerate 
- evaluate
- datasets 
- peft
- loralib
- PyTDC



In [1]:
### Multiple Sclerosis Data

# filtered_ms_adata.h5ad
# # !gdown 1casFhq4InuBNhJLMnGebzkRXM2UTTeQG 

# c_data.h5ad
# !gdown 1bV1SHKVZgkcL-RmmuN51_IIUJTSJbXOi 

In [4]:
!pip install --upgrade transformers accelerate peft datasets PyTDC

Collecting transformers
  Downloading transformers-4.51.2-py3-none-any.whl.metadata (38 kB)
Collecting accelerate
  Using cached accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting datasets
  Using cached datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting transformers
  Using cached transformers-4.50.3-py3-none-any.whl.metadata (39 kB)
Using cached transformers-4.50.3-py3-none-any.whl (10.2 MB)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.51.0.dev0
    Uninstalling transformers-4.51.0.dev0:
      Successfully uninstalled transformers-4.51.0.dev0
Successfully installed transformers-4.50.3


In [13]:
# HF imports 
import transformers
import accelerate
import peft
import datasets

import scanpy as sc

# TDC Imports
from tdc.multi_pred.anndata_dataset import DataLoader
from tdc import tdc_hf_interface
from tdc.model_server.tokenizers.scgpt import scGPTTokenizer
from tdc.model_server.models import scgpt

import torch
import numpy as np


print(f"Transformers version: {transformers.__version__}")
print(f"Accelerate version: {accelerate.__version__}")
print(f"PEFT version: {peft.__version__}")
print(f"Datasets version: {datasets.__version__}")
# print(f"TDC version: {tdc.__version__")

Transformers version: 4.50.3
Accelerate version: 0.33.0
PEFT version: 0.15.1
Datasets version: 2.19.2


# Step 1: Load data

In [75]:
data_path = "../data/peft_test/"

adata = sc.read_h5ad(data_path+"filtered_ms_adata.h5ad")
# del adata.uns["log1p"]

# Normalize and log1p — don't do HVG or batch correction
## ALREADY DONE FOR MULTIPLE SCLEROSIS
# sc.pp.normalize_total(adata, target_sum=1e4)
# sc.pp.log1p(adata)

# # # Ensure dense matrix
adata.X = adata.X.toarray()

# # # Get gene names and expression matrix
gene_names = adata.var["gene_name"].to_numpy() 
X = adata.X

In [76]:
# Load model
from tdc.model_server.tokenizers.scgpt import scGPTTokenizer

tokenizer = scGPTTokenizer()
tokenized_data = tokenizer.tokenize_cell_vectors(X, gene_names)  # Returns list of (gene_ids, values)

Found local copy...


In [80]:
from tdc import tdc_hf_interface
from peft import get_peft_model, LoraConfig, TaskType

# Load pretrained scGPT model
scgpt = tdc_hf_interface("scGPT")
base_model = scgpt.load()

# Wrap with LoRA (only Q and V projections usually)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["self_attn"],  # You may need to adapt based on scGPT architecture
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

trainable params: 294,912 || all params: 51,099,137 || trainable%: 0.5771


In [85]:
print(base_model.config)

ScGPTConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "ScGPTModel"
  ],
  "cell_emb_style": "cls",
  "d_hid": 512,
  "dropout": 0.0,
  "embsize": 512,
  "explicit_zero_prob": false,
  "input_emb_style": "continuous",
  "max_seq_len": 1536,
  "model_type": "scgpt",
  "nhead": 8,
  "nlayers": 12,
  "norm_scheme": "post",
  "pad_token_id": 0,
  "torch_dtype": "float32",
  "transformers_version": "4.50.3",
  "use_fast_transformer": true,
  "use_flash_attention": false,
  "vocab_size": 60697
}



In [86]:
import torch.nn as nn

class CellTypeClassifier(nn.Module):
    def __init__(self, base_model, hidden_dim, num_classes):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, gene_ids, expr_vals, attention_mask=None):
        output = self.base_model(
            input_ids=gene_ids,
            values=expr_vals,
            attention_mask=attention_mask
        )
        pooled = output.mean(dim=1)  # Mean pooling over genes
        return self.classifier(pooled)

from sklearn.preprocessing import LabelEncoder
import torch

# Encode your cell types
label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(adata.obs["celltype"])
labels = torch.tensor(labels)

clf_model = CellTypeClassifier(model, hidden_dim=base_model.config.d_hid, num_classes=len(label_encoder.classes_))

In [87]:
import torch.optim as optim
import torch.nn as nn

optimizer = optim.Adam(clf_model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

clf_model.train()
for epoch in range(5):
    total_loss = 0
    for i, (gene_ids, expr_vals) in enumerate(tokenized_data):
        mask = expr_vals != 0
        label = labels[i].unsqueeze(0)

        optimizer.zero_grad()
        pred = clf_model(
            gene_ids.unsqueeze(0),
            expr_vals.unsqueeze(0),
            attention_mask=mask.unsqueeze(0)
        )
        loss = loss_fn(pred, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} loss: {total_loss:.4f}")

TypeError: forward() got an unexpected keyword argument 'inputs_embeds'

In [88]:
print(base_model.forward.__doc__)


        Args:
            input_ids: Tensor of gene indices, shape [batch_size, seq_len]
            values: Tensor of expression values, shape [batch_size, seq_len]
            attention_mask: Optional mask tensor, shape [batch_size, seq_len]
            output_cell_emb: Whether to output cell embeddings

        Returns:
            Dictionary containing:
                - 'pred': Predicted expression values
                - 'cell_emb': Cell embeddings (if output_cell_emb=True)
                - 'zero_probs': Zero probabilities (if config.explicit_zero_prob=True)
        


In [79]:
print(base_model)

ScGPTModel(
  (gene_encoder): ModuleDict(
    (embedding): Embedding(60697, 512, padding_idx=0)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): ModuleDict(
    (linear1): Linear(in_features=1, out_features=512, bias=True)
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((