# Parameter-Efficient Fine-Tuning (PEFT) with Low-Rank Adaptation (LoRA) for cell-type classification

## Goals
1. Write a PyTorch training loop that implements LoRA via the PEFT library
2. Tokenize the multiple sclerosis (MS) data first
3. Parallelize LoRA training across all available GPUs, if possible

## Steps
1. Integrate Hugging Face's PEFT library into scGPT to perform fine-tuning
2. Use the scGPT implementation that Therapeutics Data Commons (TDC) publishes on Hugging Face - https://huggingface.co/tdc/scGPT
3. Test dataset: the multiple sclerosis (MS) dataset, chosen because a benchmark exists for it

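Requirements (Hugging Face ecosystem, TDC, and supporting libraries)
- transformers
- accelerate
- evaluate
- datasets
- peft
- loralib
- PyTDC
- torch, numpy, scanpy, scikit-learn (imported directly by this notebook)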

In [None]:
# HF imports
import transformers
import accelerate
import datasets
import torch
import numpy as np
import scanpy as sc
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os

# TDC imports
from tdc import tdc_hf_interface
from tdc.model_server.tokenizers.scgpt import scGPTTokenizer


# Version check (optional): print the installed versions of the key
# Hugging Face libraries so runs can be reproduced later.
for version_line in (
    f"Transformers version: {transformers.__version__}",
    f"Accelerate version: {accelerate.__version__}",
    f"Datasets version: {datasets.__version__}",
):
    print(version_line)

# Step 1: Load data

1. Load the raw counts from the training and test datasets
2. Follow the steps for normalization, tokenization, and embedding



In [None]:
# Load the pretrained scGPT model through TDC's Hugging Face interface,
# together with its tokenizer.
scgpt = tdc_hf_interface("scGPT")
base_model = scgpt.load()
tokenizer = scGPTTokenizer()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# With several visible GPUs, replicate the backbone via DataParallel.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print(f"Using {n_gpus} GPUs!")
    base_model = nn.DataParallel(base_model)

base_model = base_model.to(device)

In [None]:
# Load data: the labelled training set (its obs carries "celltype") and the
# MS set that will receive predictions.
data_path = "../data/sample_ms/"
# os.path.join works whether or not data_path ends with a separator,
# unlike plain string concatenation.
adata = sc.read_h5ad(os.path.join(data_path, "c_data.h5ad"))
adata_test = sc.read_h5ad(os.path.join(data_path, "filtered_ms_adata.h5ad"))

In [None]:
# Gene identifiers for the columns of the training expression matrix; they are
# passed to the tokenizer alongside the matrix (presumably matched to columns
# by position — TODO confirm against scGPTTokenizer).
gene_names = adata.var["gene_name"].to_numpy()

# AnnData may hold X as a scipy sparse matrix or as a dense ndarray; only the
# sparse form has .toarray(), so densify conditionally instead of crashing.
tokenized_data = tokenizer.tokenize_cell_vectors(
    adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X),
    gene_names,
)

In [None]:

# Encode string cell-type labels as integer class ids; `le.classes_` keeps the
# reverse mapping used later to decode predictions.
le = LabelEncoder()
adata.obs["cell_type_encoded"] = le.fit_transform(adata.obs["celltype"])

# Tokenized training cells: a list of (gene-token tensor, value tensor) tuples,
# per the tokenizer's output format. The same matrix and gene names were
# already tokenized into `tokenized_data`, so reuse that result instead of
# running the (expensive) tokenizer a second time.
train_tokens = tokenized_data
train_labels = adata.obs["cell_type_encoded"].to_numpy()


In [None]:
from torch.nn.utils.rnn import pad_sequence

# Split each tokenized cell into two parallel lists: gene-token tensors
# (element 0) and expression-value tensors (element 1).
gene_tokens = []
value_tokens = []
for token_pair in train_tokens:
    gene_tokens.append(token_pair[0])
    value_tokens.append(token_pair[1])

In [None]:
# Author's note: feature length 1421 (presumably the longest tokenized cell).
# Pad every training cell to the dataset-wide maximum length: gene tokens with
# 60694 (the <pad> id used by collate_fn) and values with 0.0.
# NOTE(review): the DataLoader pads per batch via collate_fn, so these padded
# tensors are not what training consumes.
padded_gene_tokens = pad_sequence(
    gene_tokens,
    batch_first=True,
    padding_value=60694,
)
padded_value_tokens = pad_sequence(
    value_tokens,
    batch_first=True,
    padding_value=0.0,
)

In [None]:

# Per-class loss weights from sklearn's "balanced" heuristic
# (n_samples / (n_classes * class_count)), so rarer cell types weigh more in the
# cross-entropy loss; stored as a float tensor on `device`.
observed_classes = np.unique(train_labels)
balanced_weights = compute_class_weight(
    class_weight="balanced",
    classes=observed_classes,
    y=train_labels,
)
class_weights = torch.tensor(balanced_weights, dtype=torch.float).to(device)

In [None]:
class scRNADataset(Dataset):
    """Map-style dataset over tokenized single cells.

    Args:
        tokenized: Indexable sequence whose items unpack as ``(tokens, counts)``:
            the gene-token ids and the matching expression values of one cell
            (tensors, arrays or lists).
        labels: Optional per-cell integer class ids, indexable like ``tokenized``.
            When omitted, samples carry no ``"labels"`` key (inference mode).
    """

    def __init__(self, tokenized, labels=None):
        self.tokenized = tokenized
        self.labels = labels

    def __len__(self):
        return len(self.tokenized)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids`` (long), ``attention_mask`` (long,
        1 where the value is positive), ``values`` (float) and, if labels were
        given, ``labels`` (0-d long tensor)."""
        tokens, counts = self.tokenized[idx]
        # torch.as_tensor accepts tensors without the copy-construct warning that
        # torch.tensor(tensor) emits, and the values are converted only once.
        values = torch.as_tensor(counts, dtype=torch.float)
        sample = {
            "input_ids": torch.as_tensor(tokens, dtype=torch.long),
            # Only genes with a positive expression value are marked as attendable.
            "attention_mask": (values > 0).long(),
            "values": values,
        }
        if self.labels is not None:
            sample["labels"] = torch.as_tensor(self.labels[idx], dtype=torch.long)
        return sample
    
from torch.nn.utils.rnn import pad_sequence


def collate_fn(batch, pad_token_id=60694):
    """Pad a list of scRNADataset samples into batch tensors.

    Args:
        batch: Non-empty list of sample dicts holding 1-D ``"input_ids"``,
            ``"attention_mask"`` and ``"values"`` tensors (lengths may differ) and,
            optionally, a scalar ``"labels"`` entry.
        pad_token_id: Id used to pad ``"input_ids"``; 60694 is the ``<pad>`` token
            id used throughout this notebook. Keyword with a default so the
            function can still be passed directly as a DataLoader ``collate_fn``.

    Returns:
        Dict with ``"input_ids"`` (B x L), ``"attention_mask"`` (bool, B x L;
        False on padding), ``"values"`` (B x L, 0.0 on padding) and, when the
        samples carry labels, ``"labels"`` (long, B).
    """
    input_ids = pad_sequence(
        [b["input_ids"] for b in batch], batch_first=True, padding_value=pad_token_id
    )
    attention_mask = pad_sequence(
        [b["attention_mask"] for b in batch], batch_first=True, padding_value=0
    ).bool()
    values = pad_sequence([b["values"] for b in batch], batch_first=True, padding_value=0.0)

    result = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "values": values,
    }

    if "labels" in batch[0]:
        # Stack the per-sample scalars directly instead of rebuilding a tensor
        # from a list of tensors; as_tensor also accepts plain ints.
        result["labels"] = torch.stack(
            [torch.as_tensor(b["labels"]) for b in batch]
        ).long()

    return result

# Wrap base_model with classifier head
class scGPTClassifier(nn.Module):
    """scGPT backbone followed by a linear cell-type classification head.

    Args:
        base_model: Backbone called as ``base_model(input_ids=..., attention_mask=...,
            values=...)``; its output is indexed with ``"cell_emb"``.
        hidden_dim: Width of the ``"cell_emb"`` features fed to the head.
        num_classes: Number of cell-type classes.
        class_weights: Optional per-class weight tensor for the cross-entropy loss.
            When omitted, a module-level ``class_weights`` is used if one is defined
            (the previous behavior); otherwise the loss is unweighted.
    """

    def __init__(self, base_model, hidden_dim, num_classes, class_weights=None):
        super().__init__()
        self.base = base_model
        self.classifier = nn.Linear(hidden_dim, num_classes)
        # Plain attribute (not a registered buffer); moved to the logits' device
        # at use time so it need not follow model.to(...).
        self.class_weights = class_weights

    def forward(self, input_ids, attention_mask, values, labels=None):
        """Return ``{"loss": loss, "logits": logits}``.

        ``loss`` is None when ``labels`` is None: inference callers omit labels,
        and cross-entropy cannot be computed without targets.
        """
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask, values=values)

        # assumes "cell_emb" is a pooled per-cell embedding of shape
        # (batch, hidden_dim) — TODO confirm against the scGPT backbone
        cell_emb = outputs["cell_emb"]
        logits = self.classifier(cell_emb)

        loss = None
        if labels is not None:
            weight = self.class_weights
            if weight is None:
                # Backward compatibility: fall back to a module-level
                # `class_weights` (as the original head did) when one exists.
                weight = globals().get("class_weights")
            if weight is not None:
                weight = weight.to(logits.device)
            loss = F.cross_entropy(logits, labels, weight=weight)
        return {"loss": loss, "logits": logits}

In [None]:
# Build model: scGPT backbone plus a linear cell-type head.
num_classes = len(le.classes_)
# NOTE(review): hidden_dim must equal the width of the backbone's "cell_emb"
# output or the Linear head will fail; verify against the loaded scGPT config.
hidden_dim = 512  # common default for scGPT
model = scGPTClassifier(base_model, hidden_dim, num_classes).to(device)

# Prepare DataLoader (collate_fn pads each batch to its own longest cell).
train_dataset = scRNADataset(train_tokens, train_labels)
train_loader = DataLoader(train_dataset, batch_size=96, shuffle=True, collate_fn=collate_fn)

# Optimizer: hand AdamW only the trainable parameters, so that once the
# backbone is frozen for LoRA/PEFT the frozen weights are not tracked.
# NOTE(review): no PEFT/LoRA adapters are attached in this cell, so at the
# moment every parameter is trainable. AdamW raises if the list is empty,
# which surfaces an accidentally fully-frozen model early.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4)

In [None]:
# Report each parameter's trainability, to check which weights the optimizer
# will actually update (relevant once LoRA freezes the backbone).
for param_name, tensor in model.named_parameters():
    status_line = (
        f"✅ TRAINABLE: {param_name}" if tensor.requires_grad else f"⛔️ FROZEN: {param_name}"
    )
    print(status_line)

In [None]:
# Supervised fine-tuning loop; reports the mean batch loss and the training
# accuracy after each epoch.
model.train()
epochs = 10

for epoch in range(epochs):
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        values = batch["values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        out = model(input_ids=input_ids, attention_mask=attention_mask, values=values, labels=labels)
        loss = out["loss"]
        logits = out["logits"]

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Accuracy calculation; detached because it needs no autograd graph.
        preds = torch.argmax(logits.detach(), dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # An empty loader would otherwise raise ZeroDivisionError; report NaN instead.
    n_batches = len(train_loader)
    accuracy = correct / total if total else float("nan")
    avg_loss = total_loss / n_batches if n_batches else float("nan")
    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

In [None]:
# Predict on test data.
# Tokenize the test matrix with the test set's OWN gene names: the names are
# passed alongside the matrix (presumably matched to columns by position), so
# reusing the training gene names is only correct when both AnnData objects
# share an identical var ordering. Fall back to the training names only when
# the test file has no "gene_name" column, and check the lengths agree.
if "gene_name" in adata_test.var.columns:
    test_gene_names = adata_test.var["gene_name"].to_numpy()
else:
    test_gene_names = gene_names
if len(test_gene_names) != adata_test.n_vars:
    raise ValueError(
        f"Test data has {adata_test.n_vars} genes but {len(test_gene_names)} gene names; "
        "cannot align test columns with gene tokens."
    )

# X may be sparse or dense; only the sparse form has .toarray().
X_test = adata_test.X
test_tokens = tokenizer.tokenize_cell_vectors(
    X_test.toarray() if hasattr(X_test, "toarray") else np.asarray(X_test),
    test_gene_names,
)
test_dataset = scRNADataset(test_tokens)
test_loader = DataLoader(test_dataset, batch_size=96, collate_fn=collate_fn)

model.eval()
preds = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        values = batch["values"].to(device)
        # NOTE(review): the model is called without labels, so its forward must
        # not require them to compute a loss.
        logits = model(input_ids=input_ids, attention_mask=attention_mask, values=values)["logits"]
        batch_preds = torch.argmax(logits, dim=-1).cpu().numpy()
        preds.extend(batch_preds)

# Map integer predictions back to cell-type names and save alongside the input.
adata_test.obs["predicted_cell_type"] = le.inverse_transform(preds)
output_path = os.path.join(data_path, "filtered_ms_adata_with_predictions.h5ad")
adata_test.write(output_path)
print(f"✅ Predictions saved to: {output_path}")

In [None]:
# Reload the annotated test set written by the prediction cell.
predictions_file = os.path.join(data_path, "filtered_ms_adata_with_predictions.h5ad")
predicted_data = sc.read_h5ad(predictions_file)

In [None]:
# Distinct cell types the classifier assigned to the test cells (displayed as the cell's output).
predicted_data.obs["predicted_cell_type"].unique()