# scGPT Saliency Map vs Gene Network Analysis

This notebook analyzes whether scGPT's saliency maps represent actual gene-gene interactions from the STRING database.

## Research Question
Given a perturbation of gene A, if gene B has a strong interaction with gene A in the STRING database, does the saliency map assign gene A a large weight when predicting the expression of gene B?

## Approach
1. Load pre-trained scGPT model
2. Load STRING gene interaction network
3. Compute saliency maps for gene-gene interactions
4. Compare saliency patterns with known gene interactions from STRING
5. Evaluate correlation and predictive power


## Setup for Google Colab


In [None]:
# Mount Google Drive
from google.colab import drive
import os

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/notebooks/scGPT_finetune')

# Install required packages
%pip install -q scgpt
%pip install -q gears
%pip install -q scanpy
%pip install -q torch-geometric
%pip install -q scikit-learn


In [None]:
# Import libraries
import json
import os
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.nn import functional as F

# scGPT imports
import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id

# Data loading
from gears import PertData

# ML imports
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_recall_curve,
)

# Fix the random seed, then pick the compute device (GPU when one is visible)
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


## 1. Load scGPT Model and Data


In [None]:
# Model configuration - UPDATE THESE PATHS FOR YOUR SETUP
# All three files are expected inside the same checkpoint directory.
model_dir = Path("./save/scGPT_human")
model_config_file, model_file, vocab_file = (
    model_dir / file_name for file_name in ("args.json", "best_model.pt", "vocab.json")
)

# Load vocabulary and append any special token it does not contain yet
vocab = GeneVocab.from_file(vocab_file)
special_tokens = ["<pad>", "<cls>", "<eoc>"]
missing_tokens = [token for token in special_tokens if token not in vocab]
for token in missing_tokens:
    vocab.append_token(token)

print(f"Vocabulary size: {len(vocab)}")

# Load model configuration (the JSON saved next to the checkpoint)
model_configs = json.loads(model_config_file.read_text())

# Model parameters: architecture sizes are read from the saved configuration,
# the remaining settings are fixed in this notebook.
ntokens = len(vocab)
embsize, nhead, d_hid, nlayers, n_layers_cls = (
    model_configs[key]
    for key in ("embsize", "nheads", "d_hid", "nlayers", "n_layers_cls")
)
dropout = 0
pad_token = "<pad>"
pad_value = 0
pert_pad_id = 0
use_fast_transformer = True

print(f"Model config: {embsize} dim, {nlayers} layers, {nhead} heads")


In [None]:
# Load the Adamson perturbation dataset through GEARS and create its split
pert_data = PertData("../../data")  # Adjust path as needed
pert_data.load(data_name="adamson")
pert_data.prepare_split(split="simulation", seed=1)

# Gene symbols, in the column order of the expression matrix
genes = pert_data.adata.var["gene_name"].tolist()

# Flag every gene: 1 when its symbol is in the scGPT vocabulary, -1 otherwise
vocab_flags = [1 if symbol in vocab else -1 for symbol in genes]
pert_data.adata.var["id_in_vocab"] = vocab_flags
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
n_in_vocab = np.sum(gene_ids_in_vocab >= 0)
print(f"Matched {n_in_vocab}/{len(gene_ids_in_vocab)} genes in vocabulary")

# Vocabulary token id per gene; out-of-vocabulary genes fall back to the
# "<pad>" token id.
pad_token_id = vocab["<pad>"]
gene_ids = np.array(
    [vocab[symbol] if symbol in vocab else pad_token_id for symbol in genes],
    dtype=int,
)
n_genes = len(genes)
print(f"Total genes: {n_genes}")


In [None]:
# Initialize model
model = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    pert_pad_id=pert_pad_id,
    use_fast_transformer=use_fast_transformer,
)

# Load pre-trained weights.
# Only checkpoint tensors whose name starts with one of these prefixes are
# copied; every other parameter of `model` keeps the value it was initialised
# with above. NOTE(review): confirm that is intended for the layers producing
# "mlm_output", since the saliency analysis differentiates that output.
model_dict = model.state_dict()
pretrained_dict = torch.load(model_file, map_location=device)
load_param_prefixs = ["encoder", "value_encoder", "transformer_encoder"]

pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k.startswith(tuple(load_param_prefixs))
}

# Copy tensors that exist in the model with an identical shape, and remember
# which ones were skipped so that a name/shape mismatch is visible.
loaded_keys, skipped_keys = [], []
for k, v in pretrained_dict.items():
    if k in model_dict and v.shape == model_dict[k].shape:
        model_dict[k] = v
        loaded_keys.append(k)
    else:
        skipped_keys.append(k)

if not loaded_keys:
    raise RuntimeError(
        f"No tensor from {model_file} matched the model (by name and shape); "
        "refusing to continue with a model that has no pre-trained weights."
    )

model.load_state_dict(model_dict)
model.to(device)
model.eval()

print(f"Loaded {len(loaded_keys)}/{len(pretrained_dict)} selected checkpoint tensors")
if skipped_keys:
    print(f"Skipped {len(skipped_keys)} tensors (missing in model or shape mismatch), e.g. {skipped_keys[:5]}")
print("Model loaded successfully")


## 2. Load STRING Gene Interaction Network


In [None]:
# The STRING files must be downloaded beforehand from https://string-db.org/
# For human: 9606.protein.links.v12.0.txt and 9606.protein.aliases.v12.0.txt

# UPDATE THESE PATHS FOR YOUR SETUP
STRING_LINKS_FILE = "/content/drive/MyDrive/data/STRING/9606.protein.links.v12.0.txt"
STRING_ALIAS_FILE = "/content/drive/MyDrive/data/STRING/9606.protein.aliases.v12.0.txt"

print("Loading STRING database...")

# Protein-protein links; drop the "9606." prefix from both protein columns
STRING_homosapien = pd.read_csv(STRING_LINKS_FILE, sep=" ")
for protein_col in ("protein1", "protein2"):
    STRING_homosapien[protein_col] = STRING_homosapien[protein_col].str.replace("9606.", "", regex=False)

print(f"Loaded {len(STRING_homosapien)} protein-protein interactions")

# Alias table: keep only rows whose alias (column 1) starts with "ENSG", then
# build a protein-id -> gene-id lookup (a later row wins on duplicate proteins)
alias_table = pd.read_csv(STRING_ALIAS_FILE, sep="\t", header=None)
is_ensembl_gene = alias_table[1].str.startswith("ENSG", na=False)
STRING_homosapien_alias = alias_table.loc[is_ensembl_gene, [0, 1]].drop_duplicates()
STRING_homosapien_alias[0] = STRING_homosapien_alias[0].str.replace("9606.", "", regex=False)
ENSP_to_ENSG = {
    protein_id: gene_id
    for protein_id, gene_id in zip(STRING_homosapien_alias[0], STRING_homosapien_alias[1])
}

print(f"Loaded {len(ENSP_to_ENSG)} protein-gene mappings")

# Translate both protein columns to gene ids and drop rows that cannot be mapped
for protein_col, gene_col in (("protein1", "gene1"), ("protein2", "gene2")):
    STRING_homosapien[gene_col] = STRING_homosapien[protein_col].map(ENSP_to_ENSG)
STRING_gene_interaction = STRING_homosapien.dropna(subset=["gene1", "gene2", "combined_score"])

print(f"Total gene-gene interactions: {len(STRING_gene_interaction)}")

# Keep only high-confidence interactions (score > 700)
is_high_conf = STRING_gene_interaction["combined_score"] > 700
STRING_gene_interaction_high_conf = STRING_gene_interaction.loc[is_high_conf]

print(f"High-confidence interactions (score > 700): {len(STRING_gene_interaction_high_conf)}")
network_genes = set(STRING_gene_interaction_high_conf['gene1']).union(STRING_gene_interaction_high_conf['gene2'])
print(f"Unique genes in network: {len(network_genes)}")


In [None]:
# Show a few identifiers from each source side by side. No mapping is done
# here: this is only a visual check of whether the identifier formats of the
# perturbation data and of the STRING gene columns agree, or whether an ID
# conversion is needed before they can be joined.

perturbation_preview = genes[:10]
string_preview = list(STRING_gene_interaction_high_conf['gene1'].unique()[:10])

print(f"Sample genes from perturbation data: {perturbation_preview}")
print(f"Sample genes from STRING: {string_preview}")


## 3. Saliency Map Computation Functions


In [None]:
def compute_gradient_saliency(model, input_ids, input_values, target_gene_idx):
    """
    Compute a gradient-based saliency map for one target gene.

    The saliency of input position i is |d output[target] / d input_values[i]|,
    where `output` is the "mlm_output" entry returned by the model for a single
    (batch size 1) sequence with all perturbation flags set to zero and no
    padding mask.

    Args:
        model: scGPT model; called as model(ids, values, pert_flags,
            src_key_padding_mask=..., CLS=False, CCE=False, MVC=False, ECS=False)
            and expected to return a dict containing "mlm_output"
        input_ids: gene token IDs [seq_len]
        input_values: gene expression values [seq_len] (floating point)
        target_gene_idx: index of the target gene in the sequence

    Returns:
        saliency_map: numpy array [seq_len] of gradient magnitudes, one per
        input gene

    Raises:
        IndexError: if target_gene_idx is outside the sequence
    """
    model.eval()

    # Run on the device the model lives on instead of relying on a notebook
    # global; a model without parameters falls back to the input's device.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else input_values.device

    seq_len = input_values.shape[-1]
    target_gene_idx = int(target_gene_idx)
    if not -seq_len <= target_gene_idx < seq_len:
        raise IndexError(
            f"target_gene_idx {target_gene_idx} is out of range for a sequence of length {seq_len}"
        )

    input_ids = input_ids.unsqueeze(0).to(run_device)  # [1, seq_len]
    # Detach first so the differentiated tensor is always a fresh leaf and the
    # caller's tensor is never switched to requires_grad.
    input_values = input_values.detach().unsqueeze(0).to(run_device)  # [1, seq_len]
    input_values.requires_grad_(True)

    # Create perturbation flags (all zeros: no gene is marked as perturbed)
    pert_flags = torch.zeros_like(input_ids).long()
    src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool)

    # enable_grad makes the function work even when called under torch.no_grad()
    with torch.enable_grad():
        output_dict = model(
            input_ids,
            input_values,
            pert_flags,
            src_key_padding_mask=src_key_padding_mask,
            CLS=False, CCE=False, MVC=False, ECS=False,
        )

        # Scalar prediction for the target gene
        target_output = output_dict["mlm_output"][0, target_gene_idx]

        gradients = torch.autograd.grad(
            outputs=target_output,
            inputs=input_values,
            create_graph=False,
            retain_graph=False
        )[0]

    # Return saliency (gradient magnitude)
    saliency = torch.abs(gradients).squeeze(0).detach().cpu().numpy()

    return saliency


In [None]:
def prepare_sample_data(pert_data, condition, n_samples=20):
    """
    Prepare an averaged expression profile for saliency analysis.

    Up to `n_samples` cells of the requested condition are drawn at random
    (without replacement, using NumPy's global random state) and their
    expression values are averaged gene-wise.

    Relies on the notebook-level variables `gene_ids` (token id per gene) and
    `genes` (gene names), which must already be defined.

    Args:
        pert_data: PertData object
        condition: perturbation condition ("ctrl" for control)
        n_samples: maximum number of cells to average

    Returns:
        dictionary with input_ids, input_values, gene_names and condition

    Raises:
        ValueError: if no cell has the requested condition
    """
    # Get data for the condition
    condition_data = pert_data.adata[pert_data.adata.obs["condition"] == condition]
    if condition_data.shape[0] == 0:
        # Averaging zero cells would silently produce an all-NaN profile.
        raise ValueError(f"No cells found for condition {condition!r}")

    # Sample cells
    n_cells = min(n_samples, condition_data.shape[0])
    sample_indices = np.random.choice(condition_data.shape[0], n_cells, replace=False)
    sample_data = condition_data[sample_indices, :]

    # Get expression values; X may be a sparse matrix or already dense
    expr_matrix = sample_data.X
    if hasattr(expr_matrix, "toarray"):
        expr_values = expr_matrix.toarray()  # [n_cells, n_genes]
    else:
        expr_values = np.asarray(expr_matrix)  # [n_cells, n_genes]

    # Get gene IDs
    input_ids = torch.tensor(gene_ids, dtype=torch.long)  # [n_genes]

    # Take mean across samples
    mean_expr = np.mean(expr_values, axis=0)  # [n_genes]
    input_values = torch.tensor(mean_expr, dtype=torch.float32)

    return {
        'input_ids': input_ids,
        'input_values': input_values,
        'gene_names': genes,
        'condition': condition
    }


## 4. Compute Saliency Matrix for Gene Pairs


In [None]:
# Build the averaged control ("ctrl") expression profile from up to 50 cells
print("Preparing control sample data...")
ctrl_sample = prepare_sample_data(pert_data, condition="ctrl", n_samples=50)

ctrl_values_shape = ctrl_sample['input_values'].shape
ctrl_gene_count = len(ctrl_sample['gene_names'])
print(f"Control sample shape: {ctrl_values_shape}")
print(f"Number of genes: {ctrl_gene_count}")


In [None]:
# Compute saliency matrix for a random subset of target genes
# For each target gene j: what is the saliency of each source gene i?
# Saliency[i, j] = how much does gene i influence the prediction of gene j
# Columns of targets that are not computed (or that fail) stay all-zero.

print("Computing saliency matrix...")
print("This may take a while...")

n_genes = len(genes)
saliency_matrix = np.zeros((n_genes, n_genes))  # [source_gene, target_gene]

# For computational efficiency, we can sample a subset of genes
# Set to n_genes to compute for all genes (may be slow)
n_genes_to_compute = min(500, n_genes)  # Adjust based on your computational resources
gene_indices_to_compute = np.random.choice(n_genes, n_genes_to_compute, replace=False)

# Targets whose saliency could not be computed, so failures are summarised
# instead of only scrolling by in the log.
failed_target_indices = []

for i, target_idx in enumerate(gene_indices_to_compute):
    if i % 50 == 0:
        print(f"Progress: {i}/{n_genes_to_compute} genes computed")

    try:
        saliency = compute_gradient_saliency(
            model,
            ctrl_sample['input_ids'],
            ctrl_sample['input_values'],
            target_idx
        )
        saliency_matrix[:, target_idx] = saliency
    except Exception as e:
        print(f"Error computing saliency for gene {target_idx}: {e}")
        failed_target_indices.append(int(target_idx))
        continue

n_failed = len(failed_target_indices)
if n_genes_to_compute > 0 and n_failed == n_genes_to_compute:
    # An all-zero matrix would make every downstream metric meaningless.
    raise RuntimeError("Saliency computation failed for every target gene; see the errors above.")
if n_failed:
    print(f"WARNING: saliency failed for {n_failed}/{n_genes_to_compute} target genes")

print(f"Saliency matrix shape: {saliency_matrix.shape}")
print(f"Saliency range: [{saliency_matrix.min():.6f}, {saliency_matrix.max():.6f}]")


## 5. Build Gene Interaction Network from STRING


In [None]:
def build_interaction_matrix(string_data, gene_list, gene_id_col1='gene1', gene_id_col2='gene2'):
    """
    Build a symmetric binary interaction matrix from STRING data.

    A row of `string_data` contributes an edge only when both of its gene
    identifiers occur in `gene_list`; identifiers must therefore use the same
    naming scheme in both inputs. If an identifier occurs more than once in
    `gene_list`, its last position is used.

    Args:
        string_data: DataFrame with gene interactions
        gene_list: List of gene names/IDs in the same order as saliency matrix
        gene_id_col1: Column name for first gene
        gene_id_col2: Column name for second gene

    Returns:
        interaction_matrix: Binary matrix [n_genes, n_genes]
    """
    n = len(gene_list)
    interaction_matrix = np.zeros((n, n))

    # Create gene to index mapping
    gene_to_idx = {gene: i for i, gene in enumerate(gene_list)}

    # Vectorised lookup: identifiers missing from gene_list map to NaN
    idx1 = string_data[gene_id_col1].map(gene_to_idx)
    idx2 = string_data[gene_id_col2].map(gene_to_idx)
    both_known = idx1.notna() & idx2.notna()

    rows = idx1[both_known].to_numpy(dtype=int)
    cols = idx2[both_known].to_numpy(dtype=int)
    interaction_matrix[rows, cols] = 1
    interaction_matrix[cols, rows] = 1  # Symmetric

    # Counts matching rows of string_data (a pair listed in both directions
    # is counted twice).
    n_matched = int(both_known.sum())
    density = interaction_matrix.sum() / (n * n) if n else 0.0

    print(f"Matched {n_matched} interactions to gene list")
    print(f"Interaction density: {density:.4f}")

    return interaction_matrix


In [None]:
# The STRING gene columns hold Ensembl gene IDs (aliases starting with
# "ENSG"), whereas `genes` holds the "gene_name" column of the perturbation
# data. Matching the two directly cannot find any interaction unless both use
# the same identifier scheme.

print("Building interaction matrix from STRING data...")

# Use the AnnData variable index when it carries Ensembl gene IDs (any version
# suffix such as ".12" is stripped). It has the same order as `genes`, so the
# interaction matrix stays aligned with the saliency matrix.
# NOTE(review): assumes adata.var_names are Ensembl IDs for this dataset —
# confirm; otherwise supply a symbol -> Ensembl mapping (e.g. via mygene).
var_ids = [str(var_name).split(".")[0] for var_name in pert_data.adata.var_names]
if any(var_id.startswith("ENSG") for var_id in var_ids):
    network_gene_ids = var_ids
else:
    print("WARNING: adata.var_names do not look like Ensembl IDs; falling back to gene symbols")
    network_gene_ids = genes

interaction_matrix = build_interaction_matrix(
    STRING_gene_interaction_high_conf,
    network_gene_ids,
    gene_id_col1='gene1',
    gene_id_col2='gene2'
)

if interaction_matrix.sum() == 0:
    print("WARNING: no STRING interaction matched the gene list; "
          "check that both sides use the same gene identifier format")

print(f"Interaction matrix shape: {interaction_matrix.shape}")


## 6. Analyze Correlation Between Saliency and Gene Interactions


In [None]:
def evaluate_saliency_vs_network(saliency_matrix, interaction_matrix):
    """
    Evaluate how well saliency scores predict gene interactions.

    Diagonal entries (self-pairs) are ignored. A saliency of exactly 0 is
    treated as "not computed" and the pair is dropped, as are non-finite
    saliency values. When a metric is undefined for the remaining pairs
    (no pairs, a single interaction class, or constant saliency) it is
    reported as NaN instead of raising.

    Args:
        saliency_matrix: Matrix of saliency scores [n_genes, n_genes]
        interaction_matrix: Binary matrix of gene interactions [n_genes, n_genes]

    Returns:
        Tuple (results, saliency_flat, interaction_flat): the dictionary of
        evaluation metrics and the two aligned 1-D arrays that were evaluated

    Raises:
        ValueError: if the matrices are not square or differ in shape
    """
    saliency_matrix = np.asarray(saliency_matrix)
    interaction_matrix = np.asarray(interaction_matrix)
    if saliency_matrix.ndim != 2 or saliency_matrix.shape[0] != saliency_matrix.shape[1]:
        raise ValueError(f"saliency_matrix must be square, got shape {saliency_matrix.shape}")
    if interaction_matrix.shape != saliency_matrix.shape:
        raise ValueError(
            f"Shape mismatch: saliency {saliency_matrix.shape} vs interaction {interaction_matrix.shape}"
        )

    # Remove diagonal (self-interactions)
    n = saliency_matrix.shape[0]
    mask = ~np.eye(n, dtype=bool)

    saliency_flat = saliency_matrix[mask]
    interaction_flat = interaction_matrix[mask]

    # Keep only pairs for which a usable saliency value exists
    valid_mask = (saliency_flat != 0) & np.isfinite(saliency_flat)
    saliency_flat = saliency_flat[valid_mask]
    interaction_flat = interaction_flat[valid_mask]

    n_pairs = len(saliency_flat)
    print(f"Evaluating {n_pairs} gene pairs")
    print(f"Positive pairs (interacting): {interaction_flat.sum()}")
    print(f"Negative pairs (non-interacting): {(interaction_flat == 0).sum()}")

    has_both_classes = np.unique(interaction_flat).size > 1
    saliency_varies = n_pairs >= 2 and np.ptp(saliency_flat) > 0

    # Correlations are undefined when either variable is constant
    if has_both_classes and saliency_varies:
        spearman_corr, spearman_pval = spearmanr(saliency_flat, interaction_flat)
        pearson_corr, pearson_pval = pearsonr(saliency_flat, interaction_flat)
    else:
        spearman_corr = spearman_pval = pearson_corr = pearson_pval = np.nan

    # Classification metrics: saliency is used as the score predicting interaction
    if has_both_classes:
        auc_roc = roc_auc_score(interaction_flat, saliency_flat)
        auc_pr = average_precision_score(interaction_flat, saliency_flat)

        # Binary prediction using the median saliency as threshold
        threshold = np.median(saliency_flat)
        predictions = (saliency_flat > threshold).astype(int)

        accuracy = accuracy_score(interaction_flat, predictions)
        precision = precision_score(interaction_flat, predictions, zero_division=0)
        recall = recall_score(interaction_flat, predictions, zero_division=0)
        f1 = f1_score(interaction_flat, predictions, zero_division=0)
    else:
        auc_roc = auc_pr = accuracy = precision = recall = f1 = np.nan

    results = {
        'spearman_corr': spearman_corr,
        'spearman_pval': spearman_pval,
        'pearson_corr': pearson_corr,
        'pearson_pval': pearson_pval,
        'auc_roc': auc_roc,
        'auc_pr': auc_pr,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'n_pairs': n_pairs,
        'n_positive': int(interaction_flat.sum()),
        'n_negative': int((interaction_flat == 0).sum())
    }

    return results, saliency_flat, interaction_flat


In [None]:
# Evaluate saliency vs network and print a summary report
print("Evaluating saliency maps against gene interaction network...")

results, saliency_flat, interaction_flat = evaluate_saliency_vs_network(
    saliency_matrix,
    interaction_matrix
)

# Fraction of evaluated pairs that interact; undefined when no pair was evaluated
n_pairs_evaluated = results['n_pairs']
class_balance = results['n_positive'] / n_pairs_evaluated if n_pairs_evaluated else float('nan')

print("\n" + "="*60)
print("RESULTS: Saliency Map vs Gene Interaction Network")
print("="*60)
print(f"\nCorrelation Metrics:")
print(f"  Spearman correlation: {results['spearman_corr']:.4f} (p={results['spearman_pval']:.2e})")
print(f"  Pearson correlation:  {results['pearson_corr']:.4f} (p={results['pearson_pval']:.2e})")
print(f"\nClassification Metrics:")
print(f"  AUC-ROC:  {results['auc_roc']:.4f}")
print(f"  AUC-PR:   {results['auc_pr']:.4f}")
print(f"  Accuracy: {results['accuracy']:.4f}")
print(f"  Precision: {results['precision']:.4f}")
print(f"  Recall:    {results['recall']:.4f}")
print(f"  F1 Score:  {results['f1_score']:.4f}")
print(f"\nDataset Statistics:")
print(f"  Total gene pairs evaluated: {results['n_pairs']}")
print(f"  Interacting pairs: {results['n_positive']}")
print(f"  Non-interacting pairs: {results['n_negative']}")
print(f"  Class balance: {class_balance:.4f}")
print("="*60)


## 7. Visualization


In [None]:
# Plot saliency distribution for interacting vs non-interacting gene pairs
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

interacting_saliency = saliency_flat[interaction_flat == 1]
non_interacting_saliency = saliency_flat[interaction_flat == 0]

# Histogram (an empty group is skipped: its density is undefined)
ax = axes[0]
for group_values, group_label, group_color in (
    (non_interacting_saliency, 'Non-interacting', 'blue'),
    (interacting_saliency, 'Interacting (STRING)', 'red'),
):
    if group_values.size:
        ax.hist(group_values, bins=50, alpha=0.6, label=group_label, density=True, color=group_color)
ax.set_xlabel('Saliency Score')
ax.set_ylabel('Density')
ax.set_title('Saliency Distribution: Interacting vs Non-Interacting Pairs')
ax.legend()
ax.set_yscale('log')

# Box plot
ax = axes[1]
data_to_plot = [non_interacting_saliency, interacting_saliency]
ax.boxplot(data_to_plot)
# Tick labels are set separately because the boxplot `labels=` keyword is
# deprecated in recent Matplotlib while its replacement does not exist in
# older releases.
ax.set_xticklabels(['Non-interacting', 'Interacting'])
ax.set_ylabel('Saliency Score')
ax.set_title('Saliency Score Comparison')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Group means; NaN marks an empty group
mean_interacting = interacting_saliency.mean() if interacting_saliency.size else float('nan')
mean_non_interacting = non_interacting_saliency.mean() if non_interacting_saliency.size else float('nan')

print(f"\nMean saliency for interacting pairs: {mean_interacting:.6f}")
print(f"Mean saliency for non-interacting pairs: {mean_non_interacting:.6f}")
if np.isfinite(mean_interacting) and np.isfinite(mean_non_interacting) and mean_non_interacting != 0:
    print(f"Ratio: {mean_interacting / mean_non_interacting:.2f}x")
else:
    print("Ratio: n/a (a group is empty or the non-interacting mean is zero)")


In [None]:
# Plot ROC and PR curves (saliency used as the score predicting interaction)
# Both curves need at least one interacting and one non-interacting pair.
if len(np.unique(interaction_flat)) > 1:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # ROC curve
    fpr, tpr, _ = roc_curve(interaction_flat, saliency_flat)
    ax = axes[0]
    ax.plot(fpr, tpr, linewidth=2, label=f'AUC = {results["auc_roc"]:.3f}')
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curve: Saliency Predicting Gene Interactions')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Precision-Recall curve; the baseline is the fraction of positive pairs
    precision_curve, recall_curve, _ = precision_recall_curve(interaction_flat, saliency_flat)
    ax = axes[1]
    ax.plot(recall_curve, precision_curve, linewidth=2, label=f'AP = {results["auc_pr"]:.3f}')
    baseline = results['n_positive'] / results['n_pairs']
    ax.axhline(y=baseline, color='k', linestyle='--', linewidth=1, label=f'Baseline = {baseline:.3f}')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("Skipping ROC/PR curves: the evaluated pairs do not contain both "
          "interacting and non-interacting examples")


In [None]:
# Visualize a subset of the saliency and interaction matrices
# Only the target columns listed in gene_indices_to_compute were filled in, so
# the displayed genes are drawn from those; sampling from all genes would show
# mostly never-computed (all-zero) columns.
computed_targets = np.asarray(gene_indices_to_compute)
n_genes_viz = min(100, len(computed_targets))
subset_indices = np.random.choice(computed_targets, n_genes_viz, replace=False)

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Saliency matrix (rows: source genes, columns: target genes)
ax = axes[0]
im1 = ax.imshow(saliency_matrix[np.ix_(subset_indices, subset_indices)],
                cmap='YlOrRd', aspect='auto', interpolation='nearest')
ax.set_title(f'Saliency Matrix ({n_genes_viz} genes)')
ax.set_xlabel('Target Gene')
ax.set_ylabel('Source Gene')
plt.colorbar(im1, ax=ax, label='Saliency Score')

# Interaction matrix for the same genes, in the same order
ax = axes[1]
im2 = ax.imshow(interaction_matrix[np.ix_(subset_indices, subset_indices)],
                cmap='binary', aspect='auto', interpolation='nearest')
ax.set_title(f'STRING Interaction Matrix ({n_genes_viz} genes)')
ax.set_xlabel('Gene 2')
ax.set_ylabel('Gene 1')
plt.colorbar(im2, ax=ax, label='Interaction (0/1)')

plt.tight_layout()
plt.show()


## 8. Case Study: Specific Gene Interactions


In [None]:
def analyze_specific_gene(gene_name, saliency_matrix, interaction_matrix, gene_list, top_n=20):
    """
    Analyze saliency vs interactions for a specific target gene.

    Ranks all other genes by their saliency for predicting `gene_name`
    (column `gene_name` of the saliency matrix), prints the top ones with a
    mark for known interactors, and draws a bar chart. The target gene itself
    is excluded from the ranking.

    Args:
        gene_name: name of the target gene
        saliency_matrix: saliency scores [source_gene, target_gene]
        interaction_matrix: binary interaction matrix [n_genes, n_genes]
        gene_list: gene names in the row/column order of both matrices
        top_n: number of top source genes to report (clamped to the number of
            other genes)

    Returns:
        Dictionary with 'top_genes', 'top_scores', 'overlap' and
        'n_true_interactions', or None when the gene is not in gene_list
    """
    gene_list = list(gene_list)
    if gene_name not in gene_list:
        print(f"Gene {gene_name} not found in gene list")
        return

    gene_idx = gene_list.index(gene_name)

    # Get saliency for this gene as target
    gene_saliency = np.asarray(saliency_matrix)[:, gene_idx]
    gene_interactions = np.asarray(interaction_matrix)[:, gene_idx]

    if not np.any(gene_saliency):
        print(f"Note: saliency column for {gene_name} is all zero (probably never computed)")

    # Rank every gene except the target itself, highest saliency first
    candidate_indices = np.delete(np.arange(len(gene_list)), gene_idx)
    top_n = max(0, min(int(top_n), len(candidate_indices)))
    ranking = np.argsort(gene_saliency[candidate_indices])[::-1][:top_n]
    top_saliency_indices = candidate_indices[ranking]
    top_saliency_genes = [gene_list[i] for i in top_saliency_indices]
    top_saliency_scores = gene_saliency[top_saliency_indices]
    top_saliency_interactions = gene_interactions[top_saliency_indices]

    # Get true interacting genes
    true_interacting_indices = np.where(gene_interactions == 1)[0]
    true_interacting_genes = [gene_list[i] for i in true_interacting_indices]

    print(f"\nAnalysis for gene: {gene_name}")
    print(f"Number of known interactions (STRING): {len(true_interacting_genes)}")
    print(f"\nTop {top_n} genes by saliency:")

    overlap = 0
    for i, (gene, score, is_interacting) in enumerate(zip(top_saliency_genes, top_saliency_scores, top_saliency_interactions)):
        marker = "✓" if is_interacting else " "
        print(f"{i+1:2d}. {gene:15s} | Saliency: {score:.6f} | Interacts: {marker}")
        if is_interacting:
            overlap += 1

    if top_n:
        print(f"\nOverlap: {overlap}/{top_n} ({100*overlap/top_n:.1f}%) of top saliency genes are known interactors")

        # Plot
        fig, ax = plt.subplots(figsize=(10, 6))
        colors = ['red' if i else 'blue' for i in top_saliency_interactions]
        ax.barh(range(top_n), top_saliency_scores, color=colors, alpha=0.7)
        ax.set_yticks(range(top_n))
        ax.set_yticklabels(top_saliency_genes)
        ax.set_xlabel('Saliency Score')
        ax.set_title(f'Top {top_n} Genes by Saliency for {gene_name}\n(Red = Known Interactor from STRING)')
        ax.invert_yaxis()
        plt.tight_layout()
        plt.show()
    else:
        print("\nNo other genes to rank")

    return {
        'top_genes': top_saliency_genes,
        'top_scores': top_saliency_scores,
        'overlap': overlap,
        'n_true_interactions': len(true_interacting_genes)
    }


In [None]:
# Analyze specific genes
# Saliency was only computed for the target genes in gene_indices_to_compute,
# so the examples are taken from those; other genes have an all-zero column.
# Replace with genes of interest (they must be among the computed targets).

example_genes = [genes[idx] for idx in gene_indices_to_compute[:5]]
print(f"Analyzing example genes: {example_genes}")

for gene in example_genes:
    try:
        analyze_specific_gene(gene, saliency_matrix, interaction_matrix, genes, top_n=15)
    except Exception as e:
        print(f"Error analyzing {gene}: {e}")
