# Model Performance Comparison: Pre-trained vs Fine-tuned scGPT

This notebook compares the perturbation-prediction performance of the pre-trained and fine-tuned scGPT models on training and test data to assess how effective fine-tuning is.

## Overview
- **Goal**: Compare performance between pre-trained and fine-tuned models on training and test data
- **Dataset**: Adamson perturbation data with simulation split
- **Analysis**: Pearson correlation metrics (`pearson`, `pearson_de`, `pearson_delta`, `pearson_de_delta`) between predicted and observed post-perturbation expression, plus gradient-based saliency maps
- **Context**: A previous analysis showed out-of-distribution (OOD) issues; this notebook quantifies how well fine-tuning addresses them


In [1]:
from google.colab import drive
import os

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks/scGPT_finetune')

!pip install -r ./requirements.txt
!pip install scgpt "flash-attn<1.0.5"

Mounted at /content/drive
Collecting torch==2.3.1 (from -r ./requirements.txt (line 6))
  Downloading torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl.metadata (26 kB)
Collecting torchvision==0.18.1 (from -r ./requirements.txt (line 7))
  Downloading torchvision-0.18.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting torchaudio==2.3.1 (from -r ./requirements.txt (line 8))
  Downloading torchaudio-2.3.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting torchtext==0.18.0 (from -r ./requirements.txt (line 9))
  Downloading torchtext-0.18.0-cp312-cp312-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting scanpy<2.0,>=1.9.1 (from -r ./requirements.txt (line 12))
  Downloading scanpy-1.11.4-py3-none-any.whl.metadata (9.2 kB)
Collecting scvi-tools<1.0,>=0.16.0 (from -r ./requirements.txt (line 13))
  Downloading scvi_tools-0.20.3-py3-none-any.whl.metadata (9.8 kB)
Collecting torch_geometric (from -r ./requirements.txt (line 16))
  Downloading torch_geometric-2.6.1-py3-no
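Because `requirements.txt` pins exact versions (for example `torch==2.3.1` and `torchtext==0.18.0`), it can help to confirm what actually ended up installed after a Colab runtime restart. The sketch below uses `importlib.metadata` from the standard library. The pin dictionary is copied from the install log above and is an assumption about which pins matter; the remaining lines of `requirements.txt` would have to be added by hand. Local build tags such as `+cu121` are stripped before comparing.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pins(pins):
    """Return {package: (expected, installed)} for every pin that is not met.

    `installed` is None when the distribution is not installed at all.
    """
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package).split("+")[0]  # drop local tags like "+cu121"
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches


# Pins visible in the install log above
pins = {"torch": "2.3.1", "torchvision": "0.18.1", "torchaudio": "2.3.1", "torchtext": "0.18.0"}
print(check_pins(pins) or "all listed pins satisfied")
```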

In [2]:
# Import libraries
import json
import os
import sys
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.decomposition import PCA
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import Vocab as VocabPybind
from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction

sys.path.insert(0, "../")

import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.loss import (
    masked_mse_loss,
    criterion_neg_log_bernoulli,
    masked_relative_error,
)
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id, compute_perturbation_metrics

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
warnings.filterwarnings("ignore")

set_seed(42)
print("Libraries imported successfully!")




Libraries imported successfully!


In [3]:
# Load and prepare data
print("Loading perturbation data...")

# Settings for data processing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value = 0
pert_pad_id = 0
include_zero_gene = "all"
max_seq_len = 1536

# Dataset settings
data_name = "adamson"
split = "simulation"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load perturbation data
pert_data = PertData("./data")
pert_data.load(data_name=data_name)
pert_data.prepare_split(split=split, seed=1)
pert_data.get_dataloader(batch_size=64, test_batch_size=64)

# print(f"Data loaded successfully!")
# print(f"Dataset: {data_name}")
# print(f"Split: {split}")
# print(f"Device: {device}")

# # Get basic info about the dataset
# adata = pert_data.adata
# print(f"\nDataset info:")
# print(f"Total cells: {adata.n_obs}")
# print(f"Total genes: {adata.n_vars}")
# print(f"Conditions: {len(adata.obs['condition'].unique())} unique conditions")

# # Extract train/test splits
# def extract_split_data_by_conditions(adata, set2conditions, split_name):
#     """Extract data for a specific split based on conditions"""
#     if split_name not in set2conditions:
#         raise ValueError(f"Unknown split: {split_name}")

#     # Get conditions for this split
#     split_conditions = set2conditions[split_name]

#     # Create boolean mask for cells in this split
#     split_mask = adata.obs['condition'].isin(split_conditions)

#     return adata[split_mask].copy()

# train_adata = extract_split_data_by_conditions(adata, pert_data.set2conditions, "train")
# test_adata = extract_split_data_by_conditions(adata, pert_data.set2conditions, "test")
# val_adata = extract_split_data_by_conditions(adata, pert_data.set2conditions, "val")

# print(f"\nSplit sizes:")
# print(f"Train: {train_adata.n_obs} cells")
# print(f"Test: {test_adata.n_obs} cells")
# print(f"Val: {val_adata.n_obs} cells")


Found local copy...


Loading perturbation data...


Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:0
combo_seen1:0
combo_seen2:0
unseen_single:22
Done!
Creating dataloaders....
Done!
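The split log reports 22 `unseen_single` test perturbations: the GEARS simulation split holds out whole perturbation conditions, which is what makes the test set out-of-distribution. A quick sanity check is that no condition other than `ctrl` appears in more than one entry of `pert_data.set2conditions`. The sketch below runs on a toy dictionary and assumes `ctrl` is the only condition meant to be shared between splits; in the notebook it would be called as `check_condition_split(pert_data.set2conditions)`.

```python
from collections import Counter


def check_condition_split(set2conditions, shared=("ctrl",)):
    """Return conditions that appear in more than one split, ignoring `shared` ones.

    A condition-level split should place each perturbation in exactly one of
    train/val/test, so an empty list means the split is clean.
    """
    counts = Counter(
        cond
        for conds in set2conditions.values()
        for cond in set(conds)  # count each condition at most once per split
        if cond not in shared
    )
    return sorted(cond for cond, n in counts.items() if n > 1)


toy = {"train": ["ctrl", "A+ctrl", "B+ctrl"], "val": ["C+ctrl"], "test": ["D+ctrl", "A+ctrl"]}
print(check_condition_split(toy))  # ['A+ctrl']
```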


In [4]:
# Load the gene vocabulary and pretrained model configuration (the models are built in the next cell)
print("Loading models...")

# Model settings
load_model = "./save/scGPT_human"
load_param_prefixs = [
    "encoder",
    "value_encoder",
    "transformer_encoder",
]

# Paths to the pretrained checkpoint, its configuration, and its vocabulary
model_dir = Path("./save/scGPT_human")
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"

vocab = GeneVocab.from_file(vocab_file)
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

pert_data.adata.var["id_in_vocab"] = [
    1 if gene in vocab else -1 for gene in pert_data.adata.var["gene_name"]
]
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
genes = pert_data.adata.var["gene_name"].tolist()

# Load model configuration
with open(model_config_file, "r") as f:
    model_configs = json.load(f)

embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]

vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)
ntokens = len(vocab)

print(f"Model configuration loaded:")
print(f"  Vocabulary size: {ntokens}")
print(f"  Embedding size: {embsize}")
print(f"  Number of layers: {nlayers}")
print(f"  Genes in vocab: {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)}")


Loading models...
Model configuration loaded:
  Vocabulary size: 60697
  Embedding size: 512
  Number of layers: 12
  Genes in vocab: 4399/5060
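The cell above maps every dataset gene to a vocabulary id and falls back to the `<pad>` id for genes the pretrained vocabulary lacks. With 4399 of 5060 genes in the vocabulary, the remaining 661 genes all receive the same `<pad>` token, so the model cannot tell them apart by gene identity. The sketch below reproduces that mapping with a plain dict standing in for `GeneVocab` (an assumption for illustration) and also returns the fallback genes so they can be inspected.

```python
def map_genes_to_vocab(genes, vocab, pad_token="<pad>"):
    """Map gene names to vocabulary ids, falling back to the pad id.

    `vocab` is any dict-like name -> id mapping (a plain dict here).
    Returns the id list and the names that fell back to the pad token.
    """
    pad_id = vocab[pad_token]
    ids, missing = [], []
    for gene in genes:
        if gene in vocab:
            ids.append(vocab[gene])
        else:
            ids.append(pad_id)  # gene identity is lost for these genes
            missing.append(gene)
    return ids, missing


toy_vocab = {"<pad>": 0, "FOS": 1, "JUN": 2}
ids, missing = map_genes_to_vocab(["FOS", "BDNF", "JUN"], toy_vocab)
print(ids, missing)  # [1, 0, 2] ['BDNF']
```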


In [5]:
# Create and load pretrained model
print("Loading pretrained model...")
model_pretrain = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=0,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=True,
)

# Load pretrained weights
model_dict = model_pretrain.state_dict()
pretrained_dict = torch.load(model_file, map_location=device)  # map_location lets a GPU checkpoint load on CPU
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if any([k.startswith(prefix) for prefix in load_param_prefixs])
}
for k, v in pretrained_dict.items():
    print(f"Loading pretrained param {k} with shape {v.shape}")
model_dict.update(pretrained_dict)
model_pretrain.load_state_dict(model_dict)
model_pretrain.to(device)
model_pretrain.eval()

print("Pretrained model loaded successfully!")

# Load finetuned model
print("Loading finetuned model...")
model_finetune = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=0,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=True,
)

# Try to load finetuned weights
finetuned_model_dir = Path("./save/scGPT_human_finetuned_adamson")
finetuned_model_file = finetuned_model_dir / "best_model.pt"

if finetuned_model_file.exists():
    try:
        model_finetune.load_state_dict(torch.load(finetuned_model_file, map_location=device))
        print("Finetuned model loaded successfully!")
    except Exception as e:
        print(f"Error loading finetuned model: {e}")
        print("Using pretrained model for both comparisons...")
        model_finetune = copy.deepcopy(model_pretrain)
else:
    print("Finetuned model not found. Using pretrained model for both comparisons...")
    model_finetune = copy.deepcopy(model_pretrain)

model_finetune.to(device)
model_finetune.eval()

print("Loading finetuned model 1...")
model_finetune_1 = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=0,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=True,
)

# Try to load finetuned weights (checkpoint 1)
finetuned_model_dir = Path("./save/scGPT_human_finetuned_adamson")
finetuned_model_file = finetuned_model_dir / "model_1.pt"

if finetuned_model_file.exists():
    model_finetune_1.load_state_dict(torch.load(finetuned_model_file, map_location=device))
    print("Finetuned model 1 loaded successfully!")
else:
    print(f"WARNING: {finetuned_model_file} not found; model_finetune_1 keeps its random initialization")

model_finetune_1.to(device)
model_finetune_1.eval()

print("Loading finetuned model 2...")
model_finetune_2 = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=0,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=True,
)

# Try to load finetuned weights (checkpoint 2)
finetuned_model_dir = Path("./save/scGPT_human_finetuned_adamson")
finetuned_model_file = finetuned_model_dir / "model_2.pt"

if finetuned_model_file.exists():
    model_finetune_2.load_state_dict(torch.load(finetuned_model_file, map_location=device))
    print("Finetuned model 2 loaded successfully!")
else:
    print(f"WARNING: {finetuned_model_file} not found; model_finetune_2 keeps its random initialization")

model_finetune_2.to(device)
model_finetune_2.eval()

print("Loading finetuned model 3...")
model_finetune_3 = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=0,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=True,
)

# Try to load finetuned weights (checkpoint 3)
finetuned_model_dir = Path("./save/scGPT_human_finetuned_adamson")
finetuned_model_file = finetuned_model_dir / "model_3.pt"

if finetuned_model_file.exists():
    model_finetune_3.load_state_dict(torch.load(finetuned_model_file, map_location=device))
    print("Finetuned model 3 loaded successfully!")
else:
    print(f"WARNING: {finetuned_model_file} not found; model_finetune_3 keeps its random initialization")

model_finetune_3.to(device)
model_finetune_3.eval()

print("Models ready for evaluation!")


Loading pretrained model...
Loading pretrained param encoder.embedding.weight with shape torch.Size([60697, 512])
Loading pretrained param encoder.enc_norm.weight with shape torch.Size([512])
Loading pretrained param encoder.enc_norm.bias with shape torch.Size([512])
Loading pretrained param value_encoder.linear1.weight with shape torch.Size([512, 1])
Loading pretrained param value_encoder.linear1.bias with shape torch.Size([512])
Loading pretrained param value_encoder.linear2.weight with shape torch.Size([512, 512])
Loading pretrained param value_encoder.linear2.bias with shape torch.Size([512])
Loading pretrained param value_encoder.norm.weight with shape torch.Size([512])
Loading pretrained param value_encoder.norm.bias with shape torch.Size([512])
Loading pretrained param transformer_encoder.layers.0.self_attn.Wqkv.weight with shape torch.Size([1536, 512])
Loading pretrained param transformer_encoder.layers.0.self_attn.Wqkv.bias with shape torch.Size([1536])
Loading pretrained para
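The model-loading cell repeats the same constructor and loading code four times, and `model_finetune_1` to `model_finetune_3` keep their random initialization if a checkpoint file is missing. A loop over checkpoint paths avoids both problems. The sketch below is generic: `make_model` and `load_weights` are hypothetical callables that you would bind to `TransformerGenerator(...)` and `load_state_dict(torch.load(...))`, as the comments suggest; the toy run substitutes plain dicts so it works without scGPT.

```python
import tempfile
import warnings
from pathlib import Path


def load_checkpoints(make_model, checkpoint_paths, load_weights):
    """Build one fresh model per checkpoint and load its weights.

    make_model():              returns a new, untrained model instance
    load_weights(model, path): loads the file at `path` into `model`
    Missing files trigger a warning and are skipped instead of silently leaving
    random weights. Moving models to a device and calling eval() is left to the caller.
    Returns {name: model} only for checkpoints that were actually loaded.
    """
    models = {}
    for name, path in checkpoint_paths.items():
        path = Path(path)
        if not path.exists():
            warnings.warn(f"{name}: checkpoint {path} not found; skipping", stacklevel=2)
            continue
        model = make_model()
        load_weights(model, path)
        models[name] = model
    return models


# In the notebook (not executed here) the two callables could be, for example:
#   make_model   = lambda: TransformerGenerator(ntokens, embsize, nhead, d_hid, nlayers,
#                      nlayers_cls=n_layers_cls, n_cls=1, vocab=vocab, dropout=0,
#                      pad_token=pad_token, pad_value=pad_value,
#                      pert_pad_id=pert_pad_id, use_fast_transformer=True)
#   load_weights = lambda m, p: m.load_state_dict(torch.load(p, map_location=device))

# Toy run: plain dicts stand in for models, and only one of two files exists
with tempfile.TemporaryDirectory() as tmp:
    present = Path(tmp) / "model_1.pt"
    present.write_text("placeholder")
    loaded = load_checkpoints(
        make_model=dict,
        checkpoint_paths={"ft_1": present, "ft_2": Path(tmp) / "model_2.pt"},
        load_weights=lambda model, path: model.update(source=path.name),
    )
print(loaded)  # {'ft_1': {'source': 'model_1.pt'}}
```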

In [6]:
def eval_perturb(
    loader: DataLoader, model: TransformerGenerator, device: torch.device
) -> Dict:
    """
    Run model in inference mode using a given data loader
    """

    model.eval()
    model.to(device)
    pert_cat = []
    pred = []
    truth = []
    pred_de = []
    truth_de = []
    results = {}

    for itr, batch in enumerate(loader):
        batch.to(device)
        pert_cat.extend(batch.pert)

        with torch.no_grad():
            p = model.pred_perturb(
                batch,
                include_zero_gene=include_zero_gene,
                gene_ids=gene_ids,
            )
            t = batch.y
            pred.extend(p.cpu())
            truth.extend(t.cpu())

            # Differentially expressed genes (index `i` avoids shadowing the outer loop's `itr`)
            for i, de_idx in enumerate(batch.de_idx):
                pred_de.append(p[i, de_idx])
                truth_de.append(t[i, de_idx])

    # all genes
    results["pert_cat"] = np.array(pert_cat)
    pred = torch.stack(pred)
    truth = torch.stack(truth)
    results["pred"] = pred.detach().cpu().numpy().astype(np.float64)
    results["truth"] = truth.detach().cpu().numpy().astype(np.float64)

    pred_de = torch.stack(pred_de)
    truth_de = torch.stack(truth_de)
    results["pred_de"] = pred_de.detach().cpu().numpy().astype(np.float64)
    results["truth_de"] = truth_de.detach().cpu().numpy().astype(np.float64)

    return results

train_loader = pert_data.dataloader["train_loader"]
valid_loader = pert_data.dataloader["val_loader"]
test_loader = pert_data.dataloader["test_loader"]

# train_res_pt = eval_perturb(train_loader, model_pretrain, device)
# train_metrics_pt = compute_perturbation_metrics(
#     train_res_pt, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
# )
# val_res_pt = eval_perturb(valid_loader, model_pretrain, device)
# val_metrics_pt = compute_perturbation_metrics(
#     val_res_pt, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
# )
# test_res_pt = eval_perturb(test_loader, model_pretrain, device)
# test_metrics_pt = compute_perturbation_metrics(
#     test_res_pt, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
# )

# print("pretrain_done")

# # train_res_ft = eval_perturb(train_loader, model_finetune, device)
# # train_metrics_ft = compute_perturbation_metrics(
# #     train_res_ft, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
# # )
# val_res_ft = eval_perturb(valid_loader, model_finetune, device)
# val_metrics_ft = compute_perturbation_metrics(
#     val_res_ft, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
# )
# test_res_ft = eval_perturb(test_loader, model_finetune, device)
# test_metrics_ft = compute_perturbation_metrics(
#     test_res_ft, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
# )

val_res_ft_1 = eval_perturb(valid_loader, model_finetune_1, device)
val_metrics_ft_1 = compute_perturbation_metrics(
    val_res_ft_1, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
)
test_res_ft_1 = eval_perturb(test_loader, model_finetune_1, device)
test_metrics_ft_1 = compute_perturbation_metrics(
    test_res_ft_1, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
)
val_res_ft_2 = eval_perturb(valid_loader, model_finetune_2, device)
val_metrics_ft_2 = compute_perturbation_metrics(
    val_res_ft_2, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
)
test_res_ft_2 = eval_perturb(test_loader, model_finetune_2, device)
test_metrics_ft_2 = compute_perturbation_metrics(
    test_res_ft_2, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
)
val_res_ft_3 = eval_perturb(valid_loader, model_finetune_3, device)
val_metrics_ft_3 = compute_perturbation_metrics(
    val_res_ft_3, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
)
test_res_ft_3 = eval_perturb(test_loader, model_finetune_3, device)
test_metrics_ft_3 = compute_perturbation_metrics(
    test_res_ft_3, pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
)
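The four numbers returned by `compute_perturbation_metrics` are Pearson correlations between predicted and observed expression. As we understand them, `pearson` compares mean expression per perturbation, `pearson_delta` compares the change relative to the mean control profile, and the `_de` variants restrict both to the differentially expressed genes in `batch.de_idx`. The sketch below is a simplified NumPy reimplementation of the first two, not scGPT's code, and its aggregation across conditions may differ from the library's. The toy example shows why `pearson` (about 0.99 in the next cell's output) can sit far above `pearson_delta` (about 0.61 to 0.68): raw correlation is dominated by the baseline expression profile, which even a predictor that ignores the perturbation gets right.

```python
import numpy as np


def pearson(a, b):
    """Pearson correlation between two 1-D arrays."""
    return float(np.corrcoef(a, b)[0, 1])


def condition_pearsons(pred, truth, pert_cat, ctrl_mean):
    """Average per-condition Pearson on mean expression and on change from control.

    pred, truth: [n_cells, n_genes] predicted / observed post-perturbation expression
    pert_cat:    [n_cells] perturbation label of each cell
    ctrl_mean:   [n_genes] mean expression of control cells
    """
    raw, delta = [], []
    for cond in np.unique(pert_cat):
        rows = pert_cat == cond
        pred_mean = pred[rows].mean(axis=0)
        true_mean = truth[rows].mean(axis=0)
        raw.append(pearson(pred_mean, true_mean))
        # Subtracting the control profile removes the shared baseline expression
        delta.append(pearson(pred_mean - ctrl_mean, true_mean - ctrl_mean))
    return {"pearson": float(np.mean(raw)), "pearson_delta": float(np.mean(delta))}


# Toy data: a baseline profile plus a perturbation-specific shift
rng = np.random.default_rng(0)
n_cells, n_genes = 50, 200
ctrl_mean = rng.gamma(2.0, 1.0, n_genes)
true_shift = rng.normal(0.0, 0.3, n_genes)
truth = ctrl_mean + true_shift + rng.normal(0.0, 0.05, (n_cells, n_genes))
pert_cat = np.array(["A+ctrl"] * n_cells)

# A "no-op" predictor that ignores the perturbation and returns the control profile
pred = ctrl_mean + rng.normal(0.0, 0.05, (n_cells, n_genes))
result = condition_pearsons(pred, truth, pert_cat, ctrl_mean)
print(result)  # expect a high "pearson" and a "pearson_delta" near zero
```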

In [7]:
# print("Pretrained model evaluation:")
# # print(train_metrics_pt)
# print(val_metrics_pt)
# print(test_metrics_pt)
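print("Finetuned model evaluation:")
# print(train_metrics_ft)
# Checkpoint 1: validation, then test
print(val_metrics_ft_1)
print(test_metrics_ft_1)
# Checkpoint 2: validation, then test
print(val_metrics_ft_2)
print(test_metrics_ft_2)
# Checkpoint 3: validation, then test
print(val_metrics_ft_3)
print(test_metrics_ft_3)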

Finetuned model evaluation:
{'pearson': np.float64(0.9896269943999882), 'pearson_de': np.float64(0.9521256939986291), 'pearson_delta': np.float64(0.6753658892400664), 'pearson_de_delta': np.float64(0.8529479473134424)}
{'pearson': np.float64(0.9896935866254931), 'pearson_de': np.float64(0.9761265028642859), 'pearson_delta': np.float64(0.6144568934466045), 'pearson_de_delta': np.float64(0.7843722436192774)}
{'pearson': np.float64(0.9885673434077875), 'pearson_de': np.float64(0.951038533174156), 'pearson_delta': np.float64(0.6779508353636224), 'pearson_de_delta': np.float64(0.8316920311951698)}
{'pearson': np.float64(0.9879044377188347), 'pearson_de': np.float64(0.9777859879099342), 'pearson_delta': np.float64(0.6086645358469214), 'pearson_de_delta': np.float64(0.788025265996234)}
{'pearson': np.float64(0.9890670983461928), 'pearson_de': np.float64(0.9441912344049294), 'pearson_delta': np.float64(0.6562050765776121), 'pearson_de_delta': np.float64(0.8109449165672137)}
{'pearson': np.floa
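The evaluation cell prints six unlabeled dictionaries. Assuming `model_1.pt` to `model_3.pt` are repeated fine-tuning runs of the same configuration (the notebook does not say how they differ), a mean and standard deviation per metric is a more compact summary. A minimal sketch using only the standard library; the numbers in the usage lines are placeholders, not results from this notebook.

```python
from statistics import mean, stdev


def summarize_runs(metrics_by_run):
    """Mean and sample standard deviation of each metric across runs.

    metrics_by_run: list of metric dicts with identical keys, e.g.
    [test_metrics_ft_1, test_metrics_ft_2, test_metrics_ft_3] in the notebook.
    Needs at least two runs for the standard deviation.
    """
    summary = {}
    for key in metrics_by_run[0]:
        values = [float(m[key]) for m in metrics_by_run]  # float() also unwraps np.float64
        summary[key] = (mean(values), stdev(values))
    return summary


# Placeholder values for illustration only
runs = [
    {"pearson": 0.98, "pearson_delta": 0.60},
    {"pearson": 0.99, "pearson_delta": 0.62},
    {"pearson": 0.97, "pearson_delta": 0.64},
]
summary = summarize_runs(runs)
for name, (mu, sd) in summary.items():
    print(f"{name:<18} {mu:.4f} ± {sd:.4f}")
```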

In [10]:
val_metrics

{'pearson': np.float64(0.9905728379185182),
 'pearson_de': np.float64(0.9530412024051579),
 'pearson_delta': np.float64(0.7030412970266849),
 'pearson_de_delta': np.float64(0.8525051671989553)}

# Saliency Map
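The saliency computed in the cells below is the absolute gradient of one predicted gene's expression with respect to each input expression value, $\lvert \partial \hat{y}_t / \partial x_i \rvert$, obtained with autograd and averaged over cells (and, when no target gene is given, over all predicted genes). Central finite differences give an independent check of such gradients on a small scale. The sketch below swaps in a toy tanh layer for scGPT so that the analytic gradient is known. Each target needs 2·n_genes forward passes, so for 5060 genes this is only practical as a spot-check; autograd gets all input gradients from one backward pass per target, which is also why the all-genes branch (one backward pass per predicted gene) is expensive.

```python
import numpy as np


def finite_difference_saliency(f, x, target, eps=1e-4):
    """|d f(x)[target] / d x_i| for every input i, by central differences.

    f:      maps a 1-D input vector to a 1-D output vector
    x:      input vector (e.g. one cell's expression profile)
    target: index of the output (predicted gene) to explain
    """
    x = np.asarray(x, dtype=np.float64)
    sal = np.empty_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        sal[i] = abs(f(x + step)[target] - f(x - step)[target]) / (2 * eps)
    return sal


# Toy stand-in for a model: one tanh layer mapping 6 inputs to 4 outputs
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 6))
x0 = rng.normal(size=6)


def toy_model(v):
    return np.tanh(W @ v)


numeric = finite_difference_saliency(toy_model, x0, target=2)
# Analytic gradient of tanh(W @ x)[t] with respect to x is (1 - y_t**2) * W[t]
analytic = np.abs((1 - toy_model(x0)[2] ** 2) * W[2])
print(np.max(np.abs(numeric - analytic)))  # small: only the central-difference error remains
```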


In [11]:
pert_data.adata.obs["condition"].unique()

['CREB1+ctrl', 'ctrl', 'ZNF326+ctrl', 'BHLHE40+ctrl', 'DDIT3+ctrl', ..., 'CARS+ctrl', 'TMED2+ctrl', 'P4HB+ctrl', 'SPCS3+ctrl', 'SPCS2+ctrl']
Length: 87
Categories (88, object): ['AARS+ctrl', 'AMIGO3+ctrl', 'ARHGAP22+ctrl', 'ASCC3+ctrl', ...,
                          'XRN1+ctrl', 'YIPF5+ctrl', 'ZNF326+ctrl', 'ctrl']
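The saliency functions below build `batch.x` by hand as an `[n_genes, 2]` array whose first column holds expression and whose second column holds a perturbation flag. Based on the shape comments in their `PerturbBatch` class, we assume `pred_perturb` reads it back by reshaping each column to `[batch_size, n_genes]`. The NumPy sketch below mirrors that layout with hypothetical helpers `build_x` and `unpack_x`, and shows how the perturbed gene would be marked. The notebook leaves every flag at zero, so the model is not told which gene was perturbed; the equivalent change there would be `pert_flags[genes.index("CREB1")] = 1`.

```python
import numpy as np


def build_x(expression, pert_gene_idx=None):
    """Stack one cell's expression with a perturbation-flag column -> [n_genes, 2]."""
    expression = np.asarray(expression, dtype=np.float32)
    flags = np.zeros_like(expression)  # 0 = gene not perturbed
    if pert_gene_idx is not None:
        flags[pert_gene_idx] = 1.0      # mark the perturbed gene
    return np.stack([expression, flags], axis=1)


def unpack_x(x, batch_size):
    """Split a flat [batch_size * n_genes, 2] array into per-cell matrices."""
    values = x[:, 0].reshape(batch_size, -1)                   # [batch_size, n_genes]
    flags = x[:, 1].astype(np.int64).reshape(batch_size, -1)   # [batch_size, n_genes]
    return values, flags


# Two toy cells with three genes each; the first has gene 1 perturbed
cell_a = build_x([0.0, 1.5, 2.0], pert_gene_idx=1)
cell_b = build_x([0.3, 0.0, 1.1])
x = np.concatenate([cell_a, cell_b], axis=0)  # cell-major rows, like a flattened batch
values, flags = unpack_x(x, batch_size=2)
print(flags)
```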

In [20]:
def compute_saliency_map_for_perturbation_prediction(model, pert_data, pert_condition, target_gene_idx=None, n_cells=5):
    """
    Compute saliency map for perturbation expression prediction using model.pred_perturb().
    """
    model.eval()

    # Get cells from the specific perturbation condition
    pert_cells = pert_data.adata[pert_data.adata.obs["condition"] == pert_condition]

    if pert_cells.shape[0] == 0:
        print(f"No cells found for perturbation: {pert_condition}")
        return None

    # Sample cells from this perturbation
    n_available = min(n_cells, pert_cells.shape[0])
    sample_indices = np.random.choice(pert_cells.shape[0], n_available, replace=False)
    sampled_cells = pert_cells[sample_indices, :]

    print(f"Analyzing {n_available} cells from perturbation: {pert_condition}")

    saliency_results = {
        'pert_condition': pert_condition,
        'n_cells': n_available,
        'cell_saliencies': [],
        'average_saliency': None,
        'gene_names': genes
    }

    # Process each cell
    for cell_idx in range(n_available):
        print(f"Processing cell {cell_idx + 1}/{n_available}")

        try:
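            # Expression of one perturbed cell. GEARS batches instead pair a control cell's
            # expression (x) with a perturbed cell's expression (y), so this input differs
            # from the setup the model was fine-tuned on.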
            cell_expr = sampled_cells.X[cell_idx].toarray().flatten()  # [n_genes]

            # Create input tensors - use float32 to avoid Flash Attention issues
            input_values = torch.tensor(cell_expr, dtype=torch.float32, device=device, requires_grad=True)

            # Create a batch-like structure for pred_perturb
            # We need to create the proper batch format that pred_perturb expects
            batch_size = 1
            n_genes = len(input_values)

            # Create the batch structure similar to what GEARS creates
            # pred_perturb expects a batch with specific attributes
            class PerturbBatch:
                """Minimal stand-in for a GEARS batch with the attributes pred_perturb reads."""

                def __init__(self, x, y, pert, de_idx):
                    self.x = x  # [batch_size * n_genes, 2] - expression values and perturbation flags
                    self.y = y  # [batch_size, n_genes] - target values
                    self.pert = pert  # perturbation names
                    self.de_idx = de_idx  # differentially expressed gene indices

                def to(self, device):
                    self.x = self.x.to(device)
                    self.y = self.y.to(device)
                    return self

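            # Perturbation flags: all zeros, so the model is NOT told which gene was perturbed.
            # To condition on the perturbation, set the perturbed gene's flag to 1,
            # e.g. pert_flags[genes.index("CREB1")] = 1
            pert_flags = torch.zeros_like(input_values)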

            # Stack expression values and perturbation flags
            x_data = torch.stack([input_values, pert_flags], dim=1)  # [n_genes, 2]
            x_data = x_data.unsqueeze(0)  # [1, n_genes, 2]
            x_data = x_data.view(-1, 2)  # [n_genes, 2] - this is what pred_perturb expects

            # Create dummy target (we don't need it for saliency)
            y_data = torch.zeros(batch_size, n_genes, device=device)

            # Create batch
            batch = PerturbBatch(x_data, y_data, [pert_condition], [[]])
            batch = batch.to(device)

            # Forward pass using pred_perturb
            predicted_expression = model.pred_perturb(
                batch,
                include_zero_gene="all",  # feed every gene, including zero-expressed ones
                gene_ids=gene_ids,
            )  # [batch_size, n_genes]

            if target_gene_idx is not None:
                # Compute saliency for specific target gene
                target_output = predicted_expression[0, target_gene_idx]
                target_output.backward(retain_graph=True)

                gradients = input_values.grad  # [n_genes]
                saliency = torch.abs(gradients).detach().cpu().numpy()

            else:
                # Compute saliency averaged across all predicted genes
                saliency_maps = []
                for gene_idx in range(n_genes):
                    # Zero gradients
                    if input_values.grad is not None:
                        input_values.grad.zero_()

                    gene_output = predicted_expression[0, gene_idx]
                    gene_output.backward(retain_graph=True)

                    gradients = input_values.grad
                    saliency = torch.abs(gradients).detach().cpu().numpy()
                    saliency_maps.append(saliency)

                # Average across all genes
                saliency = np.mean(saliency_maps, axis=0)

            saliency_results['cell_saliencies'].append({
                'cell_idx': cell_idx,
                'saliency': saliency,
                'expression': cell_expr,
                'predicted_expression': predicted_expression[0].detach().cpu().numpy()
            })

        except Exception as e:
            print(f"Error processing cell {cell_idx + 1}: {e}")
            import traceback
            traceback.print_exc()
            continue

    if not saliency_results['cell_saliencies']:
        print("No cells processed successfully")
        return None

    # Compute average saliency across cells
    all_saliencies = np.array([cell['saliency'] for cell in saliency_results['cell_saliencies']])
    average_saliency = np.mean(all_saliencies, axis=0)
    saliency_results['average_saliency'] = average_saliency

    return saliency_results

# Usage example:
pert_condition = 'CREB1+ctrl'
print(f"Computing saliency maps for perturbation prediction: {pert_condition}")

saliency_results = compute_saliency_map_for_perturbation_prediction(
    model_finetune,  # fine-tuned model loaded in cell [5]
    pert_data,
    pert_condition,
    target_gene_idx=None,
    n_cells=1
)

if saliency_results:
    # Analyze results
    average_saliency = saliency_results['average_saliency']
    gene_names = saliency_results['gene_names']

    # Get top contributing genes
    top_indices = np.argsort(average_saliency)[-15:][::-1]
    top_genes = [gene_names[idx] for idx in top_indices]
    top_saliency = average_saliency[top_indices]

    print(f"\nTop 15 genes by average saliency for {pert_condition}:")
    for gene, sal_val in zip(top_genes, top_saliency):
        print(f"  {gene}: {sal_val:.6f}")

    # Check if the perturbed gene itself has high saliency
    # Parse perturbation to get actual gene name
    if "+" in pert_condition:
        pert_genes = [g for g in pert_condition.split("+") if g != "ctrl"]
    else:
        pert_genes = [pert_condition]

    for pert_gene in pert_genes:
        if pert_gene in gene_names:
            pert_gene_idx = gene_names.index(pert_gene)
            pert_gene_saliency = average_saliency[pert_gene_idx]
            pert_rank = np.sum(average_saliency > pert_gene_saliency) + 1
            pert_percentile = (1 - (pert_rank - 1) / len(average_saliency)) * 100

            print(f"\nPerturbed gene {pert_gene}:")
            print(f"  Saliency score: {pert_gene_saliency:.6f}")
            print(f"  Rank: {pert_rank}/{len(average_saliency)} ({pert_percentile:.1f}th percentile)")

Computing saliency maps for perturbation prediction: CREB1+ctrl
Analyzing 1 cells from perturbation: CREB1+ctrl
Processing cell 1/1

Top 15 genes by average saliency for CREB1+ctrl:
  SRP68: 0.001279
  H2AFZ: 0.001221
  CTSC: 0.000703
  SNHG8: 0.000626
  H3F3B: 0.000591
  GMPR: 0.000578
  RPS29: 0.000560
  TCEAL8: 0.000505
  THAP7-AS1: 0.000460
  ATP5B: 0.000447
  CCDC85B: 0.000440
  ARGLU1: 0.000430
  HIST1H4C: 0.000423
  APOE: 0.000392
  ATP6V0B: 0.000379

Perturbed gene CREB1:
  Saliency score: 0.000107
  Rank: 5055/5060 (0.1th percentile)
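The rank printed above counts the genes with strictly larger saliency, so genes with identical scores all share the best rank among them. This matters when a score prints as `0.000000`, as for CREB1 when predicting FOS in the next cell's output (rank 4269/5060): the rounded print cannot show whether the value is exactly zero and tied with many other genes. The sketch below applies the same rank and percentile formula as the notebook and additionally reports how many genes share the exact score; `rank_and_percentile` is a hypothetical helper name.

```python
def rank_and_percentile(scores, idx):
    """Rank of scores[idx] (1 = largest) and its percentile, as in the cells above.

    Ties are resolved optimistically: every gene sharing the score gets the best
    (smallest) rank. Also returns how many genes share that exact score.
    """
    s = scores[idx]
    rank = sum(1 for v in scores if v > s) + 1
    percentile = (1 - (rank - 1) / len(scores)) * 100
    n_tied = sum(1 for v in scores if v == s)
    return rank, percentile, n_tied


scores = [0.5, 0.2, 0.0, 0.0, 0.0]
print(rank_and_percentile(scores, 3))
```

If a tie-neutral convention is preferred, `scipy.stats.rankdata` with `method="average"` applied to the negated scores gives mid-ranks instead.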


In [22]:
def compute_saliency_for_specific_target_gene(model, pert_data, pert_condition, target_gene_name, n_cells=5):
    """
    Compute saliency map for predicting a specific target gene's expression.
    """
    model.eval()

    # Get cells from the specific perturbation condition
    pert_cells = pert_data.adata[pert_data.adata.obs["condition"] == pert_condition]

    if pert_cells.shape[0] == 0:
        print(f"No cells found for perturbation: {pert_condition}")
        return None

    # Find target gene index
    if target_gene_name not in genes:
        print(f"Target gene {target_gene_name} not found in gene list")
        return None

    target_gene_idx = genes.index(target_gene_name)
    print(f"Target gene {target_gene_name} at index {target_gene_idx}")

    # Sample cells from this perturbation
    n_available = min(n_cells, pert_cells.shape[0])
    sample_indices = np.random.choice(pert_cells.shape[0], n_available, replace=False)
    sampled_cells = pert_cells[sample_indices, :]

    print(f"Analyzing {n_available} cells for predicting {target_gene_name} expression")

    saliency_results = {
        'pert_condition': pert_condition,
        'target_gene': target_gene_name,
        'target_gene_idx': target_gene_idx,
        'n_cells': n_available,
        'cell_saliencies': [],
        'average_saliency': None,
        'gene_names': genes
    }

    # Process each cell
    for cell_idx in range(n_available):
        print(f"Processing cell {cell_idx + 1}/{n_available}")

        try:
            # Get expression values for this cell
            cell_expr = sampled_cells.X[cell_idx].toarray().flatten()  # [n_genes]

            # Create input tensors
            input_values = torch.tensor(cell_expr, dtype=torch.float32, device=device, requires_grad=True)

            # Create batch structure for pred_perturb
            class PerturbBatch:
                """Minimal stand-in for a GEARS batch with the attributes pred_perturb reads."""

                def __init__(self, x, y, pert, de_idx):
                    self.x = x
                    self.y = y
                    self.pert = pert
                    self.de_idx = de_idx

                def to(self, device):
                    self.x = self.x.to(device)
                    self.y = self.y.to(device)
                    return self

            # All-zero perturbation flags: the model is not told which gene was perturbed
            pert_flags = torch.zeros_like(input_values)

            # Stack expression values and perturbation flags
            x_data = torch.stack([input_values, pert_flags], dim=1)  # [n_genes, 2]
            x_data = x_data.unsqueeze(0)  # [1, n_genes, 2]
            x_data = x_data.view(-1, 2)  # [n_genes, 2]

            # Create dummy target
            y_data = torch.zeros(1, len(input_values), device=device)

            # Create batch
            batch = PerturbBatch(x_data, y_data, [pert_condition], [[]])
            batch = batch.to(device)

            # Forward pass using pred_perturb
            predicted_expression = model.pred_perturb(
                batch,
                include_zero_gene="all",
                gene_ids=gene_ids,
            )  # [1, n_genes]

            # Compute saliency for the specific target gene
            target_output = predicted_expression[0, target_gene_idx]
            target_output.backward(retain_graph=True)

            gradients = input_values.grad  # [n_genes]
            saliency = torch.abs(gradients).detach().cpu().numpy()

            saliency_results['cell_saliencies'].append({
                'cell_idx': cell_idx,
                'saliency': saliency,
                'expression': cell_expr,
                'target_prediction': target_output.detach().cpu().item()
            })

        except Exception as e:
            print(f"Error processing cell {cell_idx + 1}: {e}")
            continue

    if not saliency_results['cell_saliencies']:
        print("No cells processed successfully")
        return None

    # Compute average saliency across cells
    all_saliencies = np.array([cell['saliency'] for cell in saliency_results['cell_saliencies']])
    average_saliency = np.mean(all_saliencies, axis=0)
    saliency_results['average_saliency'] = average_saliency

    return saliency_results

def analyze_creb1_interactions(model, pert_data, pert_condition="CREB1+ctrl"):
    """
    Analyze CREB1's saliency when predicting genes that interact with CREB1.
    """
    # Illustrative CREB1-related genes: CREB1 target genes and related bZIP
    # transcription factors (AP-1 and ATF family). Hand-picked, not curated; a
    # protein-protein interaction or regulatory database would give a more principled set.
    creb1_interacting_genes = [
        "FOS",      # Immediate early gene, activated by CREB1
        "JUN",      # Part of AP-1 complex with FOS
        "EGR1",     # Early growth response protein
        "BDNF",     # Brain-derived neurotrophic factor
        "NR4A1",    # Nuclear receptor
        "ATF3",     # Activating transcription factor 3
        "FOSB",     # FosB proto-oncogene
        "JUNB",     # JunB proto-oncogene
        "JUND",     # JunD proto-oncogene
        "ATF4",     # Activating transcription factor 4
    ]

    # Filter to genes that exist in our dataset
    available_interacting_genes = [gene for gene in creb1_interacting_genes if gene in genes]
    print(f"Found {len(available_interacting_genes)} CREB1-interacting genes in dataset: {available_interacting_genes}")

    results = {}

    for target_gene in available_interacting_genes:
        print(f"\n{'='*60}")
        print(f"Analyzing saliency for predicting {target_gene} expression")
        print(f"{'='*60}")

        saliency_results = compute_saliency_for_specific_target_gene(
            model, pert_data, pert_condition, target_gene, n_cells=5
        )

        if saliency_results:
            average_saliency = saliency_results['average_saliency']
            gene_names = saliency_results['gene_names']

            # Get CREB1's saliency score
            creb1_idx = genes.index("CREB1")
            creb1_saliency = average_saliency[creb1_idx]

            # Get top contributing genes
            top_indices = np.argsort(average_saliency)[-10:][::-1]
            top_genes = [gene_names[idx] for idx in top_indices]
            top_saliency = average_saliency[top_indices]

            # Check CREB1's rank
            creb1_rank = np.sum(average_saliency > creb1_saliency) + 1
            creb1_percentile = (1 - (creb1_rank - 1) / len(average_saliency)) * 100

            print(f"\nTop 10 genes by saliency for predicting {target_gene}:")
            for i, (gene, sal_val) in enumerate(zip(top_genes, top_saliency)):
                marker = "***" if gene == "CREB1" else ""
                print(f"  {i+1:2d}. {gene:<10} {sal_val:.6f} {marker}")

            print(f"\nCREB1 saliency for predicting {target_gene}:")
            print(f"  Score: {creb1_saliency:.6f}")
            print(f"  Rank: {creb1_rank}/{len(average_saliency)} ({creb1_percentile:.1f}th percentile)")

            if creb1_rank <= 10:
                print(f"  *** CREB1 is in top 10 most influential genes! ***")

            results[target_gene] = {
                'creb1_saliency': creb1_saliency,
                'creb1_rank': creb1_rank,
                'creb1_percentile': creb1_percentile,
                'top_genes': top_genes,
                'top_saliency': top_saliency
            }

    return results

# Run the analysis
print("Analyzing CREB1 interactions with target genes...")
interaction_results = analyze_creb1_interactions(model_finetune, pert_data, "CREB1+ctrl")

# Summary
print(f"\n{'='*80}")
print("SUMMARY: CREB1's influence on predicting interacting genes")
print(f"{'='*80}")

for target_gene, result in interaction_results.items():
    rank = result['creb1_rank']
    percentile = result['creb1_percentile']
    saliency = result['creb1_saliency']

    status = "HIGH INFLUENCE" if rank <= 10 else "LOW INFLUENCE"
    print(f"{target_gene:<10}: Rank {rank:4d} ({percentile:5.1f}th percentile) - {status}")
    print(f"            Saliency: {saliency:.6f}")

Analyzing CREB1 interactions with target genes...
Found 9 CREB1-interacting genes in dataset: ['FOS', 'JUN', 'EGR1', 'NR4A1', 'ATF3', 'FOSB', 'JUNB', 'JUND', 'ATF4']

Analyzing saliency for predicting FOS expression
Target gene FOS at index 3623
Analyzing 5 cells for predicting FOS expression
Processing cell 1/5
Processing cell 2/5
Processing cell 3/5
Processing cell 4/5
Processing cell 5/5

Top 10 genes by saliency for predicting FOS:
   1. FOS        0.555212 
   2. RP11-115C21.2 0.001509 
   3. STMN4      0.000702 
   4. S100A1     0.000612 
   5. H2AFZ      0.000433 
   6. DDIT3      0.000419 
   7. CDK1       0.000415 
   8. ATP5E      0.000353 
   9. RGS14      0.000322 
  10. HIST1H4C   0.000269 

CREB1 saliency for predicting FOS:
  Score: 0.000000
  Rank: 4269/5060 (15.7th percentile)

Analyzing saliency for predicting JUN expression
Target gene JUN at index 200
Analyzing 5 cells for predicting JUN expression
Processing cell 1/5
Processing cell 2/5
Processing cell 3/5
Processi