# scGPT Saliency vs Gene Interactions - Simple Analysis

This notebook performs a straightforward analysis to answer:
**Do genes with known interactions (from STRING database) have higher saliency scores in scGPT predictions?**

## Approach
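1. Load the pretrained scGPT model and compute gradient-based saliency maps (the absolute gradient of a target gene's predicted expression with respect to each source gene's input expression)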
2. Load STRING gene interaction network
3. Divide gene pairs into two groups:
   - Group 1: Gene pairs WITH interaction (from STRING)
   - Group 2: Gene pairs WITHOUT interaction
4. Compare the two saliency distributions (boxplots, violin plots and histograms)


## Setup for Google Colab


In [None]:
# Mount Google Drive, then switch into this notebook's folder: later cells
# resolve relative paths ("./save/scGPT_human", "../../data") from here.
import os
from google.colab import drive

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/notebooks/scGPT_finetune')


In [None]:
# Install required packages
%pip install -q scgpt
%pip install -q gears
%pip install -q scanpy
%pip install -q scikit-learn


In [None]:
# Import libraries
import json
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy.stats import mannwhitneyu

warnings.filterwarnings('ignore')

# scGPT imports
from scgpt.model import TransformerGenerator
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed

# Data loading
from gears import PertData

# Run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed via scGPT's helper so the random sampling below is repeatable
# (assumed to also seed numpy's global RNG, which np.random.choice uses - confirm)
set_seed(42)
print(f"Using device: {device}")


## 1. Load scGPT Model


In [None]:
# Model configuration - UPDATE PATHS AS NEEDED
model_dir = Path("./save/scGPT_human")
model_config_file, model_file, vocab_file = (
    model_dir / fname for fname in ("args.json", "best_model.pt", "vocab.json")
)

# Load the gene vocabulary and append the special tokens if it lacks them
vocab = GeneVocab.from_file(vocab_file)
for special_token in ("<pad>", "<cls>", "<eoc>"):
    if special_token in vocab:
        continue
    vocab.append_token(special_token)

ntokens = len(vocab)
print(f"Vocabulary size: {ntokens}")

# Architecture hyper-parameters stored next to the pretrained checkpoint
with open(model_config_file, "r") as config_fh:
    model_configs = json.load(config_fh)

embsize, nhead, d_hid, nlayers, n_layers_cls = (
    model_configs[key]
    for key in ("embsize", "nheads", "d_hid", "nlayers", "n_layers_cls")
)

print(f"Model: {embsize}D, {nlayers} layers, {nhead} heads")


In [None]:
# Load the Adamson perturbation dataset through GEARS (the split is prepared
# here but is not used by the rest of this notebook)
pert_data = PertData("../../data")
pert_data.load(data_name="adamson")
pert_data.prepare_split(split="simulation", seed=1)

# scGPT token id for every dataset gene; genes unknown to the vocab fall back to <pad>
genes = pert_data.adata.var["gene_name"].tolist()
known_flags = [gene in vocab for gene in genes]
gene_ids = np.array(
    [vocab[gene] if known else vocab["<pad>"] for gene, known in zip(genes, known_flags)],
    dtype=int,
)

print(f"Total genes: {len(genes)}")
print(f"Genes in vocab: {sum(known_flags)}")


In [None]:
# Initialize and load model
model = TransformerGenerator(
    ntokens, embsize, nhead, d_hid, nlayers,
    nlayers_cls=n_layers_cls, n_cls=1, vocab=vocab,
    dropout=0, pad_token="<pad>", pad_value=0, pert_pad_id=0,
    use_fast_transformer=True
)

# Load pretrained weights for the embedding / value-encoder / transformer layers
# only; every other parameter keeps its fresh initialization.
pretrained_dict = torch.load(model_file, map_location=device)
load_prefixes = ["encoder", "value_encoder", "transformer_encoder"]
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if any(k.startswith(p) for p in load_prefixes)}

# NOTE(review): without flash-attn, scGPT presumably falls back to the PyTorch
# encoder (any fallback warning is hidden by warnings.filterwarnings('ignore')).
# That encoder names the fused QKV projection `in_proj_*`, whereas the
# checkpoint stores `Wqkv.*`; rename so the attention weights are not silently
# dropped by the name/shape check below. This mirrors the renaming in scGPT's
# `load_pretrained` utility - verify against the installed scGPT version.
if not getattr(model, "use_fast_transformer", True):
    pretrained_dict = {k.replace("Wqkv.", "in_proj_"): v for k, v in pretrained_dict.items()}

model_dict = model.state_dict()
loaded_keys, skipped_keys = [], []
for k, v in pretrained_dict.items():
    if k in model_dict and v.shape == model_dict[k].shape:
        model_dict[k] = v
        loaded_keys.append(k)
    else:
        skipped_keys.append(k)

model.load_state_dict(model_dict)
model.to(device)
model.eval()

# Make partial loading visible: skipped tensors keep their random initialization,
# which would make the saliency below meaningless.
print(f"  Loaded {len(loaded_keys)} pretrained tensors, skipped {len(skipped_keys)}")
if skipped_keys:
    print(f"  Examples of skipped keys: {skipped_keys[:5]}")
print("✓ Model loaded successfully")


## 2. Compute Saliency Map


In [None]:
def compute_saliency(model, input_ids, input_values, target_idx):
    """Gradient saliency of one target gene's prediction w.r.t. every input value.

    A batch dimension is added, one forward pass is run with no perturbation
    flags and no padding mask, and the model's ``mlm_output`` at position
    ``target_idx`` is back-propagated to the input expression values.

    Args:
        model: scGPT generator whose forward returns a dict containing
            "mlm_output" indexed as [batch, position].
        input_ids: gene token ids without a batch dimension (presumably 1-D).
        input_values: expression values aligned with ``input_ids``.
        target_idx: position of the gene whose prediction is explained.

    Returns:
        numpy array of |d mlm_output[target_idx] / d input_values[j]| over j.
    """
    model.eval()
    ids = input_ids.unsqueeze(0).to(device)
    values = input_values.unsqueeze(0).to(device).requires_grad_(True)

    no_perturbation = torch.zeros_like(ids).long()
    no_padding = torch.zeros_like(values, dtype=torch.bool)

    outputs = model(ids, values, no_perturbation,
                    src_key_padding_mask=no_padding,
                    CLS=False, CCE=False, MVC=False, ECS=False)
    prediction = outputs["mlm_output"][0, target_idx]

    # Only the gradient w.r.t. the input values is needed; no graph is kept
    (grad_values,) = torch.autograd.grad(prediction, values,
                                         create_graph=False, retain_graph=False)
    return grad_values.abs().squeeze(0).detach().cpu().numpy()


In [None]:
# Prepare control cells data: the model input is one "average control cell",
# i.e. the mean expression over a random sample of unperturbed (ctrl) cells.
ctrl_cells = pert_data.adata[pert_data.adata.obs["condition"] == "ctrl"]
assert ctrl_cells.shape[0] > 0, 'No cells with condition == "ctrl" found in the dataset'
n_samples = min(50, ctrl_cells.shape[0])  # average at most 50 control cells
sample_idx = np.random.choice(ctrl_cells.shape[0], n_samples, replace=False)

# Average expression across samples; .X may be scipy-sparse or a dense array
sampled_X = ctrl_cells[sample_idx, :].X
expr_data = sampled_X.toarray() if hasattr(sampled_X, "toarray") else np.asarray(sampled_X)
mean_expr = np.mean(expr_data, axis=0)

input_ids = torch.tensor(gene_ids, dtype=torch.long)
input_values = torch.tensor(mean_expr, dtype=torch.float32)

print(f"Input shape: {input_values.shape}")
print(f"Using {n_samples} control cells for mean expression")


In [None]:
# Compute saliency matrix
# For efficiency, compute for a subset of target genes (columns of the matrix)
n_genes = len(genes)

# Genes missing from the scGPT vocabulary all share the <pad> token id, so the
# prediction at those positions is not gene-specific: only sample targets the
# model actually knows.
candidate_targets = np.flatnonzero(gene_ids != vocab["<pad>"])
n_compute = min(300, len(candidate_targets))  # Adjust based on computational resources

print(f"Computing saliency for {n_compute} target genes...")
print("This may take a few minutes...")

# [source, target]; columns of targets that are not computed stay all-zero
saliency_matrix = np.zeros((n_genes, n_genes))
target_indices = np.random.choice(candidate_targets, n_compute, replace=False)

for i, target_idx in enumerate(target_indices):
    if (i + 1) % 50 == 0:
        print(f"  Progress: {i+1}/{n_compute}")
    saliency_matrix[:, target_idx] = compute_saliency(model, input_ids, input_values, target_idx)

print(f"\n✓ Saliency matrix computed: {saliency_matrix.shape}")
print(f"  Range: [{saliency_matrix.min():.6f}, {saliency_matrix.max():.6f}]")


## 3. Load STRING Gene Interactions


In [None]:
# UPDATE THESE PATHS TO YOUR STRING DATABASE FILES
STRING_LINKS = "/content/drive/MyDrive/data/STRING/9606.protein.links.v12.0.txt"
STRING_ALIAS = "/content/drive/MyDrive/data/STRING/9606.protein.aliases.v12.0.txt"
# STRING's "high confidence" level is a combined score of at least 0.700 (700 in the files)
STRING_MIN_SCORE = 700

# Load STRING data (human, taxon 9606) and strip the "9606." prefix from protein IDs
print("Loading STRING database...")
string_links = pd.read_csv(STRING_LINKS, sep=" ")
string_links["protein1"] = string_links["protein1"].str.replace("9606.", "", regex=False)
string_links["protein2"] = string_links["protein2"].str.replace("9606.", "", regex=False)

# Load protein-to-gene mapping: keep only aliases that are Ensembl gene IDs (ENSG...).
# A header row, if present, is dropped by this filter; if a protein has several
# ENSG aliases, the last one wins in the dict.
string_alias = pd.read_csv(STRING_ALIAS, sep="\t", header=None)
string_alias = string_alias[string_alias[1].str.startswith("ENSG", na=False)][[0, 1]].drop_duplicates()
string_alias[0] = string_alias[0].str.replace("9606.", "", regex=False)
protein_to_gene = dict(zip(string_alias[0], string_alias[1]))

# Filter for high-confidence interactions first, so only those rows get mapped to genes
high_conf = string_links[string_links["combined_score"] >= STRING_MIN_SCORE].copy()
high_conf["gene1"] = high_conf["protein1"].map(protein_to_gene)
high_conf["gene2"] = high_conf["protein2"].map(protein_to_gene)
high_conf = high_conf.dropna(subset=["gene1", "gene2"])

# The links file is expected to list each link in both directions (A-B and B-A),
# so also report the number of unique undirected gene pairs.
n_unique_pairs = len(set(map(frozenset, zip(high_conf["gene1"], high_conf["gene2"]))))
print(f"✓ Loaded {len(high_conf)} high-confidence links ({n_unique_pairs} unique gene pairs)")
print(f"  Unique genes: {len(set(high_conf['gene1']) | set(high_conf['gene2']))}")


In [None]:
# Build interaction matrix
def build_interaction_matrix(string_data, gene_list):
    """Build a symmetric binary interaction matrix from STRING links.

    Args:
        string_data: DataFrame with "gene1" and "gene2" columns of gene identifiers.
        gene_list: identifiers defining the row/column order of the matrix. They
            must be the same identifier type as string_data (e.g. Ensembl gene
            IDs), otherwise nothing matches.

    Returns:
        (n, n) float array with 1 where the two genes are linked (in either
        direction) and 0 elsewhere. Links whose genes are not both in
        gene_list are ignored.
    """
    n = len(gene_list)
    matrix = np.zeros((n, n))
    gene_to_idx = {gene: i for i, gene in enumerate(gene_list)}

    # Vectorized lookup instead of iterrows (STRING has hundreds of thousands of links)
    idx1 = string_data["gene1"].map(gene_to_idx)
    idx2 = string_data["gene2"].map(gene_to_idx)
    matched = idx1.notna() & idx2.notna()
    i1 = idx1[matched].to_numpy(dtype=int)
    i2 = idx2[matched].to_numpy(dtype=int)
    matrix[i1, i2] = 1
    matrix[i2, i1] = 1

    n_matched = int(matched.sum())
    # Each undirected pair is encoded once as lo * n + hi
    n_pairs = len(np.unique(np.minimum(i1, i2) * n + np.maximum(i1, i2)))
    density = matrix.sum() / (n * n) if n else 0.0

    print(f"  Matched {n_matched} interactions to gene list ({n_pairs} unique gene pairs)")
    print(f"  Density: {density:.4f}")
    if n_matched == 0:
        print("  WARNING: no interaction matched gene_list - check that STRING and "
              "gene_list use the same gene identifier type")
    return matrix

# STRING genes were mapped to Ensembl gene IDs (ENSG...), whereas `genes` holds
# the `gene_name` column (gene symbols), so matching on `genes` finds nothing.
# Match against the AnnData var_names when they are Ensembl IDs (assumed for
# GEARS datasets - verify). var_names share the row order of `genes`, so matrix
# indices stay aligned with the saliency matrix.
var_ids = pd.Index(pert_data.adata.var_names).astype(str)
if var_ids.str.startswith("ENSG").all():
    # Drop Ensembl version suffixes ("ENSG....N") to match STRING's unversioned IDs
    interaction_keys = [v.split(".")[0] for v in var_ids]
else:
    interaction_keys = genes
    print("  NOTE: var_names are not Ensembl IDs; matching STRING genes against gene symbols")
assert len(interaction_keys) == len(genes)

interaction_matrix = build_interaction_matrix(high_conf, interaction_keys)
print(f"✓ Interaction matrix: {interaction_matrix.shape}")


## 4. Divide Saliency into Two Groups


In [None]:
# Flatten the [source, target] saliency and interaction matrices into gene pairs
n = saliency_matrix.shape[0]
mask = ~np.eye(n, dtype=bool)  # Exclude self-interactions

# Only keep pairs whose target column was actually computed. Selecting by
# target_indices (instead of saliency != 0) keeps genuine zero gradients.
computed_target = np.zeros(n, dtype=bool)
computed_target[target_indices] = True
mask &= computed_target[np.newaxis, :]

# Drop pairs involving genes outside the scGPT vocabulary: they all share the
# <pad> token id, so their saliency is not specific to the gene.
gene_in_vocab = gene_ids != vocab["<pad>"]
mask &= gene_in_vocab[:, np.newaxis] & gene_in_vocab[np.newaxis, :]

saliency_flat = saliency_matrix[mask]
interaction_flat = interaction_matrix[mask]

# Divide into two groups
group_with_interaction = saliency_flat[interaction_flat == 1]
group_without_interaction = saliency_flat[interaction_flat == 0]

# Both groups must be non-empty, otherwise the statistics and plots are meaningless
assert len(group_with_interaction) > 0, (
    "No evaluated gene pair has a STRING interaction: check that STRING and the "
    "dataset use the same gene identifiers, or increase n_compute."
)
assert len(group_without_interaction) > 0, "No evaluated gene pair lacks a STRING interaction."

# One-sided Mann-Whitney U test (H1: interacting pairs have larger saliency).
# Pairs share source/target genes and are not independent, so treat the
# p-value as descriptive rather than exact.
mw_stat, mw_pvalue = mannwhitneyu(group_with_interaction, group_without_interaction,
                                  alternative="greater")

print("\n" + "="*60)
print("SALIENCY GROUPS")
print("="*60)
print(f"Total gene pairs evaluated: {len(saliency_flat):,}")
print(f"\nGroup 1 - WITH interaction (STRING):")
print(f"  Count: {len(group_with_interaction):,}")
print(f"  Mean saliency: {group_with_interaction.mean():.6f}")
print(f"  Median saliency: {np.median(group_with_interaction):.6f}")
print(f"  Std dev: {group_with_interaction.std():.6f}")
print(f"\nGroup 2 - WITHOUT interaction:")
print(f"  Count: {len(group_without_interaction):,}")
print(f"  Mean saliency: {group_without_interaction.mean():.6f}")
print(f"  Median saliency: {np.median(group_without_interaction):.6f}")
print(f"  Std dev: {group_without_interaction.std():.6f}")
print(f"\nRatio (With / Without):")
print(f"  Mean ratio: {group_with_interaction.mean() / group_without_interaction.mean():.3f}x")
print(f"  Median ratio: {np.median(group_with_interaction) / np.median(group_without_interaction):.3f}x")
print(f"\nMann-Whitney U (one-sided, with > without):")
print(f"  U = {mw_stat:,.0f}, p-value = {mw_pvalue:.3e}")
print("="*60)


## 5. Visualization: Boxplot, Violin Plot and Histogram


In [None]:
# Create the three-panel comparison figure
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
group_labels = ['Without\nInteraction', 'With\nInteraction']
group_colors = ['lightblue', 'lightcoral']
data_to_plot = [group_without_interaction, group_with_interaction]

# Plot 1: Boxplot. Tick labels are set on the axis instead of via boxplot's
# `labels=` keyword, which Matplotlib 3.9 renamed to `tick_labels` (deprecated
# alias), so this works across Matplotlib versions.
ax = axes[0]
bp = ax.boxplot(data_to_plot, patch_artist=True, showfliers=False)
for box, color in zip(bp['boxes'], group_colors):
    box.set_facecolor(color)
ax.set_xticks([1, 2])
ax.set_xticklabels(group_labels)
ax.set_ylabel('Saliency Score', fontsize=12)
ax.set_title('Saliency Comparison: Interacting vs Non-Interacting Gene Pairs', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Add sample sizes near the top of the axis
for pos, group in zip([1, 2], data_to_plot):
    ax.text(pos, ax.get_ylim()[1]*0.95, f'n={len(group):,}',
            ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Plot 2: Violin plot (same colors as the boxplot)
ax = axes[1]
parts = ax.violinplot(data_to_plot, positions=[1, 2], showmeans=True, showmedians=True)
for body, color in zip(parts['bodies'], group_colors):
    body.set_facecolor(color)
ax.set_xticks([1, 2])
ax.set_xticklabels(group_labels)
ax.set_ylabel('Saliency Score', fontsize=12)
ax.set_title('Saliency Distribution (Violin Plot)', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Plot 3: Histogram overlay. Both groups use the same bin edges; passing
# bins=50 to each call separately would give each group its own edges and make
# the overlaid densities hard to compare.
ax = axes[2]
bin_edges = np.histogram_bin_edges(np.concatenate(data_to_plot), bins=50)
ax.hist(group_without_interaction, bins=bin_edges, alpha=0.6, label='Without Interaction',
        density=True, color='blue', edgecolor='black')
ax.hist(group_with_interaction, bins=bin_edges, alpha=0.6, label='With Interaction',
        density=True, color='red', edgecolor='black')
ax.set_xlabel('Saliency Score', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.set_title('Saliency Distribution (Histogram)', fontsize=13, fontweight='bold')
ax.legend(loc='upper right', fontsize=10)
ax.grid(True, alpha=0.3, axis='y')
ax.set_yscale('log')

plt.tight_layout()
plt.savefig('saliency_vs_interaction_boxplots.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Plots saved as 'saliency_vs_interaction_boxplots.png'")
