# scGPT Saliency Map Analysis

This notebook analyzes the attention patterns and saliency maps of scGPT to understand how the model uses gene interactions for perturbation prediction.

## Research Question
Given a perturbation of gene A, if gene B has a strong interaction with gene A, does the saliency map show that gene A carries a large weight in predicting the expression of gene B?

## Approach
1. Load pre-trained scGPT model
2. Compute saliency maps using gradient-based methods
3. Analyze attention weights for gene-gene interactions
4. Compare saliency patterns with known gene interaction networks


In [None]:
# Setup and imports
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.nn import functional as F
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# scGPT imports
import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id

# Data loading
from gears import PertData
from torch_geometric.loader import DataLoader

# Choose the compute device (GPU when one is available, CPU otherwise) and
# fix the seed through scGPT's set_seed helper so runs are repeatable
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
set_seed(42)
print(f"Using device: {device}")


## 1. Load Model and Data


In [None]:
# Locations of the pre-trained checkpoint artifacts
model_dir = Path("./save/scGPT_human")
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"

# Build the gene vocabulary and append any special token it does not yet contain
vocab = GeneVocab.from_file(vocab_file)
special_tokens = ["<pad>", "<cls>", "<eoc>"]
for s in special_tokens:
    if s in vocab:
        continue
    vocab.append_token(s)

print(f"Vocabulary size: {len(vocab)}")

# Read the hyper-parameters stored next to the checkpoint
with model_config_file.open("r") as f:
    model_configs = json.load(f)

# Architecture sizes come from the saved config (note the config key is
# "nheads" while the local name is nhead); the remaining values are fixed here
ntokens = len(vocab)
embsize, nhead, d_hid, nlayers, n_layers_cls = [
    model_configs[key]
    for key in ("embsize", "nheads", "d_hid", "nlayers", "n_layers_cls")
]
dropout = 0
pad_token = "<pad>"
pad_value = 0
pert_pad_id = 0
use_fast_transformer = True

print(f"Model config: {embsize} dim, {nlayers} layers, {nhead} heads")


In [None]:
# Load the "adamson" dataset through GEARS' PertData and build the simulation split
pert_data = PertData("./data")
pert_data.load(data_name="adamson")
pert_data.prepare_split(split="simulation", seed=1)

# Flag each dataset gene as present (1) or absent (-1) in the model vocabulary
genes = pert_data.adata.var["gene_name"].tolist()
pert_data.adata.var["id_in_vocab"] = [1 if gene in vocab else -1 for gene in genes]
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
print(f"Matched {(gene_ids_in_vocab >= 0).sum()}/{gene_ids_in_vocab.size} genes in vocabulary")

# Vocabulary token id per gene; genes missing from the vocabulary map to <pad>
gene_ids = np.array(
    [vocab["<pad>"] if gene not in vocab else vocab[gene] for gene in genes],
    dtype=int,
)
n_genes = len(genes)

# Control cells serve as the unperturbed baseline
ctrl_data = pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
print(f"Control cells: {ctrl_data.shape[0]}")

# Every distinct condition except the control label counts as a perturbation
pert_conditions = [
    c for c in pert_data.adata.obs["condition"].unique() if c != "ctrl"
]
print(f"Available perturbations: {len(pert_conditions)}")
print(f"Sample perturbations: {pert_conditions[:5]}")


In [None]:
# Initialize model with the hyper-parameters defined in the configuration cell
model = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=use_fast_transformer,
)

# Load pre-trained weights
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# obtained from a trusted source.
model_dict = model.state_dict()
pretrained_dict = torch.load(model_file, map_location=device)
load_param_prefixs = ["encoder", "value_encoder", "transformer_encoder"]

# NOTE(review): only tensors whose key starts with one of these prefixes are
# restored; every other sub-module keeps its fresh initialization. Confirm
# that this covers everything the saliency analysis relies on (e.g. the head
# producing "mlm_output") before interpreting results.
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k.startswith(tuple(load_param_prefixs))
}

# Copy tensors whose key and shape match the freshly built model, and keep
# track of the ones that cannot be copied instead of dropping them silently.
loaded_keys = []
skipped_keys = []
for k, v in pretrained_dict.items():
    if k in model_dict and v.shape == model_dict[k].shape:
        model_dict[k] = v
        loaded_keys.append(k)
    else:
        skipped_keys.append(k)

model.load_state_dict(model_dict)
model.to(device)
model.eval()

print(f"Restored {len(loaded_keys)}/{len(pretrained_dict)} pre-trained tensors")
if skipped_keys:
    print(
        f"Skipped {len(skipped_keys)} tensors (missing key or shape mismatch), "
        f"e.g. {skipped_keys[:5]}"
    )
if not loaded_keys:
    print("WARNING: no pre-trained tensors were restored; the model is randomly initialized")

print("Model loaded successfully")


## 2. Saliency Map Computation Functions


In [None]:
def compute_gradient_saliency(model, input_ids, input_values, target_gene_idx,
                             include_zero_gene="all", max_seq_len=1536):
    """
    Compute gradient-based saliency map for a specific target gene.

    The saliency of input position j is the absolute value of the gradient of
    ``mlm_output[0, target_gene_idx]`` with respect to ``input_values[j]``.
    The forward pass uses a batch of one, all-zero perturbation flags and an
    all-False padding mask.

    Args:
        model: scGPT model; it is called as
            ``model(ids, values, pert_flags, src_key_padding_mask=..., CLS=False,
            CCE=False, MVC=False, ECS=False)`` and must return a dict holding
            an ``"mlm_output"`` tensor indexed as ``[batch, position]``
        input_ids: gene token IDs [seq_len]
        input_values: gene expression values [seq_len] (floating point, since
            gradients are taken with respect to it)
        target_gene_idx: index of target gene in the sequence
        include_zero_gene: currently unused; kept for interface compatibility
        max_seq_len: currently unused; kept for interface compatibility

    Returns:
        saliency_map: numpy array [seq_len] with the gradient magnitude for
            each input gene
    """
    model.eval()

    # Run on the device that holds the model's parameters instead of relying
    # on a notebook-level global; fall back to the input's device for a
    # parameter-free model.
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        run_device = input_values.device

    input_ids = input_ids.unsqueeze(0).to(run_device)  # [1, seq_len]
    input_values = input_values.unsqueeze(0).to(run_device)  # [1, seq_len]

    # Create perturbation flags (all zeros for now) and a mask with no padding
    pert_flags = torch.zeros_like(input_ids).long()
    src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool)

    # Enable gradients for input values
    input_values.requires_grad_(True)

    # Forward pass with all auxiliary objectives switched off
    output_dict = model(
        input_ids,
        input_values,
        pert_flags,
        src_key_padding_mask=src_key_padding_mask,
        CLS=False, CCE=False, MVC=False, ECS=False,
    )

    # Scalar prediction for the target gene
    target_output = output_dict["mlm_output"][0, target_gene_idx]

    # Gradient of that scalar with respect to every input expression value;
    # autograd.grad leaves the model parameters' .grad untouched
    gradients = torch.autograd.grad(
        outputs=target_output,
        inputs=input_values,
        create_graph=False,
        retain_graph=False
    )[0]

    # Return saliency (gradient magnitude) without the batch dimension
    saliency = torch.abs(gradients).squeeze(0).detach().cpu().numpy()

    return saliency


In [None]:
def compute_attention_weights(model, input_ids, input_values, layer_idx=None):
    """
    Extract attention weights from the transformer model using forward hooks.

    A forward hook is attached to every sub-module whose qualified name
    contains both 'self_attn' and 'attention'. After such a module has run,
    its ``attention_weights`` attribute (when it has one) is detached, moved
    to CPU and collected. Modules without that attribute contribute nothing,
    so the collected list can be empty.

    NOTE(review): whether the scGPT attention modules match this name filter
    and expose ``attention_weights`` cannot be verified from this notebook;
    confirm before relying on the result.

    Args:
        model: scGPT model
        input_ids: gene token IDs [1, seq_len]
        input_values: gene expression values [1, seq_len]
        layer_idx: which entry of the collected weights to return, counted in
            the order the hooked modules ran (None for all layers)

    Returns:
        attention_weights: list of collected tensors when ``layer_idx`` is
            None, otherwise the single tensor at position ``layer_idx``

    Raises:
        IndexError: if ``layer_idx`` is given but no collected entry exists at
            that position
    """
    model.eval()

    attention_weights = []

    def attention_hook(module, input, output):
        # Extract attention weights from the module if it exposes them
        if hasattr(module, 'attention_weights'):
            attention_weights.append(module.attention_weights.detach().cpu())

    # Register hooks on attention layers
    hooks = [
        module.register_forward_hook(attention_hook)
        for name, module in model.named_modules()
        if 'self_attn' in name and 'attention' in name
    ]

    try:
        with torch.no_grad():
            # Create perturbation flags and padding mask
            pert_flags = torch.zeros_like(input_ids).long()
            src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool)

            # Forward pass; only the hook side effects are of interest
            model(
                input_ids,
                input_values,
                pert_flags,
                src_key_padding_mask=src_key_padding_mask,
                CLS=False, CCE=False, MVC=False, ECS=False,
            )
    finally:
        # Always detach the hooks, even when the forward pass raises, so that
        # repeated calls do not stack hooks on the model
        for hook in hooks:
            hook.remove()

    if layer_idx is None:
        return attention_weights

    if not -len(attention_weights) <= layer_idx < len(attention_weights):
        raise IndexError(
            f"layer_idx {layer_idx} is out of range: "
            f"{len(attention_weights)} attention tensor(s) were collected"
        )
    return attention_weights[layer_idx]


In [None]:
def prepare_sample_data(pert_data, condition, n_samples=10):
    """
    Prepare sample data for saliency analysis.

    Randomly draws up to ``n_samples`` cells of the requested condition
    (without replacement, via ``np.random.choice``) and averages their
    expression into a single profile. Relies on the notebook-level ``gene_ids``
    and ``genes`` variables for the token IDs and gene names.

    Args:
        pert_data: PertData object whose ``adata.obs["condition"]`` labels cells
        condition: perturbation condition ("ctrl" selects the control cells)
        n_samples: maximum number of cells to draw (must be at least 1)

    Returns:
        sample_data: dictionary with
            'input_ids': LongTensor [n_genes] of vocabulary token IDs
            'input_values': FloatTensor [n_genes], mean expression over the drawn cells
            'gene_names': list of gene names
            'condition': the requested condition

    Raises:
        ValueError: if ``n_samples`` is smaller than 1 or no cell carries the
            requested condition (either case would otherwise produce an
            all-NaN profile)
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    # One mask expression covers the control and the perturbed conditions alike
    condition_mask = pert_data.adata.obs["condition"] == condition
    condition_data = pert_data.adata[condition_mask]

    n_available = condition_data.shape[0]
    if n_available == 0:
        raise ValueError(f"No cells found for condition {condition!r}")

    # Sample cells
    n_cells = min(n_samples, n_available)
    sample_indices = np.random.choice(n_available, n_cells, replace=False)
    sample_data = condition_data[sample_indices, :]

    # Get expression values; X may be a sparse matrix (has .toarray) or
    # already a dense array
    raw_expr = sample_data.X
    if hasattr(raw_expr, "toarray"):
        expr_values = raw_expr.toarray()  # [n_cells, n_genes]
    else:
        expr_values = np.asarray(raw_expr)  # [n_cells, n_genes]

    # Get gene IDs
    input_ids = torch.tensor(gene_ids, dtype=torch.long)  # [n_genes]

    # Take mean across samples
    mean_expr = np.mean(expr_values, axis=0)  # [n_genes]
    input_values = torch.tensor(mean_expr, dtype=torch.float32)

    return {
        'input_ids': input_ids,
        'input_values': input_values,
        'gene_names': genes,
        'condition': condition
    }


## 3. Saliency Analysis for Gene Interactions


In [None]:
# Select a perturbation condition to analyze
pert_condition = pert_conditions[0]  # Use first available perturbation
print(f"Analyzing perturbation: {pert_condition}")

# Parse the perturbation to identify the perturbed gene(s); combined
# conditions are "+"-separated and may contain a "ctrl" placeholder
if "+" in pert_condition:
    pert_genes = pert_condition.split("+")
    pert_genes = [g for g in pert_genes if g != "ctrl"]
else:
    pert_genes = [pert_condition]

print(f"Perturbed genes: {pert_genes}")

# Find indices of perturbed genes. Only genes that are actually found are
# kept, so pert_genes[i] always corresponds to pert_gene_indices[i]; the
# downstream cells pair the two lists by position.
found_pert_genes = []
pert_gene_indices = []
for gene in pert_genes:
    if gene in genes:
        idx = genes.index(gene)
        found_pert_genes.append(gene)
        pert_gene_indices.append(idx)
        print(f"Gene {gene} at index {idx}")
    else:
        print(f"Warning: Gene {gene} not found in gene list")
pert_genes = found_pert_genes

if not pert_gene_indices:
    print("No valid perturbed genes found. Using first gene as example.")
    pert_gene_indices = [0]
    pert_genes = [genes[0]]


In [None]:
# Build one averaged expression profile for control cells and one for the
# selected perturbation (up to 20 randomly drawn cells each)
ctrl_sample = prepare_sample_data(pert_data, "ctrl", n_samples=20)
pert_sample = prepare_sample_data(pert_data, pert_condition, n_samples=20)

print(f"Control sample shape: {ctrl_sample['input_values'].shape}")
print(f"Perturbed sample shape: {pert_sample['input_values'].shape}")

# For every perturbed gene, compute the saliency of all inputs with respect to
# that gene's predicted expression, first on the control profile and then on
# the perturbed one
saliency_results = {}

for pert_gene_name, pert_gene_idx in zip(pert_genes, pert_gene_indices):
    print(f"\nComputing saliency for gene {pert_gene_name} (index {pert_gene_idx})")

    ctrl_saliency, pert_saliency = [
        compute_gradient_saliency(
            model,
            sample['input_ids'],
            sample['input_values'],
            pert_gene_idx,
        )
        for sample in (ctrl_sample, pert_sample)
    ]

    saliency_results[pert_gene_name] = {
        'ctrl_saliency': ctrl_saliency,
        'pert_saliency': pert_saliency,
        'gene_idx': pert_gene_idx
    }

    print(f"Control saliency range: [{ctrl_saliency.min():.6f}, {ctrl_saliency.max():.6f}]")
    print(f"Perturbed saliency range: [{pert_saliency.min():.6f}, {pert_saliency.max():.6f}]")


## 4. Visualization and Analysis


In [None]:
def _plot_top_saliency(ax, saliency, gene_names, pert_gene, top_n, condition_label):
    """Draw one horizontal-bar panel with the highest-saliency genes on ``ax``."""
    # Never ask for more bars than there are genes (and never fewer than zero)
    n_bars = max(0, min(top_n, len(saliency)))

    # Indices of the n_bars largest scores, largest first
    top_indices = np.argsort(saliency)[::-1][:n_bars]
    top_genes = [gene_names[idx] for idx in top_indices]
    top_values = saliency[top_indices]

    bars = ax.barh(range(n_bars), top_values)
    ax.set_yticks(range(n_bars))
    ax.set_yticklabels(top_genes)
    ax.set_xlabel('Saliency Score')
    ax.set_title(f'{condition_label}: Top {n_bars} genes for {pert_gene}')
    ax.invert_yaxis()  # largest score at the top

    # Highlight the perturbed gene if it's in the top genes
    if pert_gene in top_genes:
        bars[top_genes.index(pert_gene)].set_color('red')


def plot_saliency_comparison(saliency_results, gene_names, top_n=20):
    """
    Plot comparison of saliency maps between control and perturbed conditions.

    One figure row is drawn per entry of ``saliency_results``: the left panel
    shows the top genes by control saliency, the right panel the top genes by
    perturbed saliency. The bar of the perturbed gene itself is colored red
    when it appears among the top genes.

    Args:
        saliency_results: dict mapping a perturbed gene name to a dict with
            'ctrl_saliency' and 'pert_saliency' arrays
        gene_names: sequence of gene names indexed like the saliency arrays
        top_n: number of genes to show per panel; clamped to the number of
            available genes

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is drawn
        when ``saliency_results`` is empty.
    """
    if not saliency_results:
        # plt.subplots cannot build a grid with zero rows
        print("No saliency results to plot.")
        return

    n_rows = len(saliency_results)
    # squeeze=False keeps axes two-dimensional even for a single row
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, 4 * n_rows), squeeze=False)

    for row, (pert_gene, results) in enumerate(saliency_results.items()):
        _plot_top_saliency(
            axes[row, 0], results['ctrl_saliency'], gene_names, pert_gene, top_n, 'Control'
        )
        _plot_top_saliency(
            axes[row, 1], results['pert_saliency'], gene_names, pert_gene, top_n, 'Perturbed'
        )

    plt.tight_layout()
    plt.show()


In [None]:
def analyze_gene_interactions(saliency_results, gene_names, pert_gene_indices):
    """
    Summarize how salient each perturbed gene's own input is for its own prediction.

    For every entry of ``saliency_results`` the saliency score at the gene's
    own index is looked up in the control and perturbed maps, ranked against
    all other scores (rank 1 = highest; ties share the best rank) and turned
    into a percentile. A short report is printed per gene.

    Args:
        saliency_results: dict mapping a perturbed gene name to a dict with
            'ctrl_saliency', 'pert_saliency' and 'gene_idx'
        gene_names: not used by this function
        pert_gene_indices: not used by this function; the index is read from
            each entry's 'gene_idx'

    Returns:
        pandas DataFrame with one row per perturbed gene and the columns
        pert_gene, ctrl_self_saliency, pert_self_saliency, ctrl_rank,
        pert_rank, ctrl_percentile, pert_percentile, saliency_change
    """
    def _rank_and_percentile(scores, own_score):
        # Rank counts the strictly larger scores; percentile is 100 for rank 1
        rank = np.sum(scores > own_score) + 1
        percentile = (1 - (rank - 1) / len(scores)) * 100
        return rank, percentile

    rows = []

    for pert_gene, entry in saliency_results.items():
        own_idx = entry['gene_idx']
        ctrl_scores = entry['ctrl_saliency']
        pert_scores = entry['pert_saliency']

        # Saliency of the perturbed gene's own input in each condition
        ctrl_own = ctrl_scores[own_idx]
        pert_own = pert_scores[own_idx]

        ctrl_rank, ctrl_pct = _rank_and_percentile(ctrl_scores, ctrl_own)
        pert_rank, pert_pct = _rank_and_percentile(pert_scores, pert_own)

        rows.append({
            'pert_gene': pert_gene,
            'ctrl_self_saliency': ctrl_own,
            'pert_self_saliency': pert_own,
            'ctrl_rank': ctrl_rank,
            'pert_rank': pert_rank,
            'ctrl_percentile': ctrl_pct,
            'pert_percentile': pert_pct,
            'saliency_change': pert_own - ctrl_own
        })

        print(f"\nGene: {pert_gene}")
        print(f"  Control self-saliency: {ctrl_own:.6f} (rank {ctrl_rank}/{len(ctrl_scores)}, {ctrl_pct:.1f}th percentile)")
        print(f"  Perturbed self-saliency: {pert_own:.6f} (rank {pert_rank}/{len(pert_scores)}, {pert_pct:.1f}th percentile)")
        print(f"  Change in self-saliency: {pert_own - ctrl_own:.6f}")

    return pd.DataFrame(rows)


In [None]:
# Draw the control-vs-perturbed saliency panels, 15 genes per panel
print("Generating saliency comparison plots...")
plot_saliency_comparison(
    saliency_results,
    gene_names=genes,
    top_n=15,
)


In [None]:
# Tabulate self-saliency, rank and percentile for every perturbed gene
print("Analyzing gene interactions...")
summary_df = analyze_gene_interactions(
    saliency_results,
    gene_names=genes,
    pert_gene_indices=pert_gene_indices,
)

print("\nSummary DataFrame:")
print(summary_df)


In [None]:
# Summary figure: three side-by-side panels built from summary_df
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Panels 1 and 2 overlay the control (blue) and perturbed (red) bars of one
# metric each: self-saliency first, then rank
for ax, metric, ylabel, title in [
    (axes[0], 'self_saliency', 'Self-Saliency Score', 'Self-Saliency: Control vs Perturbed'),
    (axes[1], 'rank', 'Rank (lower is better)', 'Rank: Control vs Perturbed'),
]:
    ax.bar(summary_df['pert_gene'], summary_df[f'ctrl_{metric}'],
           alpha=0.7, label='Control', color='blue')
    ax.bar(summary_df['pert_gene'], summary_df[f'pert_{metric}'],
           alpha=0.7, label='Perturbed', color='red')
    ax.set_xlabel('Perturbed Gene')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.tick_params(axis='x', rotation=45)

# Panel 3: perturbed minus control self-saliency, with a zero reference line
change_ax = axes[2]
change_ax.bar(summary_df['pert_gene'], summary_df['saliency_change'],
              alpha=0.7, color='green')
change_ax.set_xlabel('Perturbed Gene')
change_ax.set_ylabel('Change in Self-Saliency')
change_ax.set_title('Change in Self-Saliency (Perturbed - Control)')
change_ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
change_ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()
