# Model Performance Comparison: Pre-trained vs Fine-tuned scGPT

This notebook provides a comprehensive comparison of the pre-trained and fine-tuned scGPT models' performance on both training and test data to assess the effectiveness of fine-tuning.

## Overview
- **Goal**: Compare performance between pre-trained and fine-tuned models on training and test data
- **Dataset**: Adamson perturbation data with simulation split
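- **Analysis**: Gene expression reconstruction metrics (MSE, MAE, R², mean per-cell Pearson correlation), computed with all perturbation flags set to zero, plus paired significance tests on per-cell MSE. Perturbation-prediction accuracy and downstream-task performance are not evaluated in this notebook.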
- **Context**: A previous analysis found out-of-distribution (OOD) differences between the train and test splits; this notebook quantifies how well fine-tuning addresses them.


In [None]:
# Import libraries
import json
import os
import sys
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.decomposition import PCA
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import Vocab as VocabPybind
from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction

sys.path.insert(0, "../")

import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.loss import (
    masked_mse_loss,
    criterion_neg_log_bernoulli,
    masked_relative_error,
)
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id, compute_perturbation_metrics

# Plot styling: matplotlib defaults with seaborn's "husl" colour cycle.
plt.style.use("default")
sns.set_palette(palette="husl")
# NOTE: this silences *all* warnings for the rest of the session (library
# deprecations, numerical warnings, ...) — narrow it if results look suspicious.
warnings.filterwarnings(action="ignore")

# scGPT's seeding helper (presumably seeds Python/NumPy/torch RNGs — see scgpt.utils).
set_seed(42)
print("Libraries imported successfully!")


In [None]:
# Load and prepare data
print("Loading perturbation data...")

# Tokenisation / padding settings shared with the model cells below
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value = 0
pert_pad_id = 0
include_zero_gene = "all"
max_seq_len = 1536

# Dataset + device selection
data_name = "adamson"
split = "simulation"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# GEARS: load the dataset, build the split with a fixed seed, create dataloaders
pert_data = PertData("./data")
pert_data.load(data_name=data_name)
pert_data.prepare_split(split=split, seed=1)
pert_data.get_dataloader(batch_size=64, test_batch_size=64)

print("\n".join([
    "Data loaded successfully!",
    f"Dataset: {data_name}",
    f"Split: {split}",
    f"Device: {device}",
]))

# Basic dataset overview
adata = pert_data.adata
print("\n".join([
    "\nDataset info:",
    f"Total cells: {adata.n_obs}",
    f"Total genes: {adata.n_vars}",
    f"Conditions: {len(adata.obs['condition'].unique())} unique conditions",
]))

# Extract train/test splits
def extract_split_data_by_conditions(adata, set2conditions, split_name):
    """Return a copied subset of ``adata`` holding only the cells of one split.

    A cell belongs to the split when its ``obs['condition']`` value is listed in
    ``set2conditions[split_name]``.

    Raises:
        ValueError: if ``split_name`` is not a key of ``set2conditions``.
    """
    if split_name not in set2conditions:
        raise ValueError(f"Unknown split: {split_name}")

    wanted_conditions = set2conditions[split_name]
    in_split = adata.obs["condition"].isin(wanted_conditions)
    subset_view = adata[in_split]
    return subset_view.copy()

# Materialise one AnnData copy per GEARS split, in train -> test -> val order
train_adata, test_adata, val_adata = (
    extract_split_data_by_conditions(adata, pert_data.set2conditions, split_key)
    for split_key in ("train", "test", "val")
)

print("\n".join([
    "\nSplit sizes:",
    f"Train: {train_adata.n_obs} cells",
    f"Test: {test_adata.n_obs} cells",
    f"Val: {val_adata.n_obs} cells",
]))


In [None]:
# Load pretrained and finetuned models
print("Loading models...")

# Model settings
load_model = "./save/scGPT_human"
# Parameter-name prefixes copied from the pretrained checkpoint when the model is built
load_param_prefixs = [
    "encoder",
    "value_encoder",
    "transformer_encoder",
]

# Load model configuration. Every checkpoint path is derived from `load_model`
# so the pretrained checkpoint location is configured in exactly one place.
model_dir = Path(load_model)
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"

vocab = GeneVocab.from_file(vocab_file)
# Ensure the special tokens used for padding / classification exist in the vocab
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)

# Flag each dataset gene: 1 if the pretrained vocabulary contains it, -1 otherwise
pert_data.adata.var["id_in_vocab"] = [
    1 if gene in vocab else -1 for gene in pert_data.adata.var["gene_name"]
]
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
genes = pert_data.adata.var["gene_name"].tolist()

# Architecture hyperparameters of the pretrained checkpoint
with open(model_config_file, "r") as f:
    model_configs = json.load(f)

embsize = model_configs["embsize"]
nhead = model_configs["nheads"]
d_hid = model_configs["d_hid"]
nlayers = model_configs["nlayers"]
n_layers_cls = model_configs["n_layers_cls"]

# Genes missing from the vocabulary are mapped to the <pad> token id
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[gene] if gene in vocab else vocab["<pad>"] for gene in genes], dtype=int
)
n_genes = len(genes)
ntokens = len(vocab)

print(f"Model configuration loaded:")
print(f"  Vocabulary size: {ntokens}")
print(f"  Embedding size: {embsize}")
print(f"  Number of layers: {nlayers}")
print(f"  Genes in vocab: {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)}")


In [None]:
# Create and load pretrained model
def _build_transformer_generator():
    """Build a fresh TransformerGenerator with the architecture read from args.json.

    Both the pretrained and the finetuned model use identical hyperparameters so
    their state dicts are interchangeable.
    """
    return TransformerGenerator(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        nlayers_cls=n_layers_cls,
        n_cls=1,
        vocab=vocab,
        dropout=0,
        pad_token=pad_token,
        pad_value=pad_value,
        pert_pad_id=pert_pad_id,
        use_fast_transformer=True,
    )


print("Loading pretrained model...")
model_pretrain = _build_transformer_generator()

# Load pretrained weights. Only parameters whose names start with one of
# `load_param_prefixs` are transferred; every other parameter keeps its random
# initialisation (NOTE(review): presumably this includes the expression decoder,
# so "pretrained" reconstruction scores reflect an untrained output head).
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine;
# the model is moved to `device` afterwards.
model_dict = model_pretrain.state_dict()
pretrained_dict = torch.load(model_file, map_location="cpu")
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if any(k.startswith(prefix) for prefix in load_param_prefixs)
}
for k, v in pretrained_dict.items():
    print(f"Loading pretrained param {k} with shape {v.shape}")
model_dict.update(pretrained_dict)
model_pretrain.load_state_dict(model_dict)
model_pretrain.to(device)
model_pretrain.eval()

print("Pretrained model loaded successfully!")

# Load finetuned model
print("Loading finetuned model...")
model_finetune = _build_transformer_generator()

# Try to load finetuned weights
finetuned_model_dir = Path("./save/scGPT_human_finetuned_adamson")
finetuned_model_file = finetuned_model_dir / "best_model.pt"

finetune_fallback_reason = None
if finetuned_model_file.exists():
    try:
        model_finetune.load_state_dict(torch.load(finetuned_model_file, map_location="cpu"))
        print("Finetuned model loaded successfully!")
    except Exception as e:
        finetune_fallback_reason = f"Error loading finetuned model: {e}"
else:
    finetune_fallback_reason = "Finetuned model not found."

if finetune_fallback_reason is not None:
    # Fall back to a copy of the pretrained model so the notebook still runs, but
    # make it loud: every pretrained-vs-finetuned comparison becomes trivially equal.
    print(finetune_fallback_reason)
    print("Using pretrained model for both comparisons...")
    print("WARNING: 'finetuned' results below are a copy of the pretrained model; "
          "all comparisons will show no difference.")
    model_finetune = copy.deepcopy(model_pretrain)

model_finetune.to(device)
model_finetune.eval()

print("Models ready for evaluation!")


In [None]:
# Define evaluation functions
def evaluate_model_performance(model, adata, split_name, max_cells=1000, batch_size=32, seed=0):
    """
    Evaluate expression-reconstruction performance of a model on one dataset split.

    Each cell's full expression vector is fed to the model (all genes, every
    perturbation flag 0, no padding) and the model's expression prediction is
    compared with that same input vector.

    Args:
        model: The scGPT model to evaluate (a ``torch.nn.Module``).
        adata: AnnData object with cell data; ``.X`` may be dense or sparse.
        split_name: Name of the split for logging.
        max_cells: Maximum number of cells to evaluate (for memory efficiency);
            larger splits are randomly subsampled without replacement.
        batch_size: Batch size for evaluation.
        seed: Seed for the subsample. With a fixed seed, repeated calls on the same
            split (e.g. pretrained vs finetuned model) evaluate the *same* cells in
            the same order, which the paired t-test / Wilcoxon test downstream
            requires. Pass ``None`` for a fresh random draw.

    Returns:
        Dictionary with aggregate metrics ('mse', 'mae', 'r2', per-cell MSE/MAE
        summaries, 'mean_correlation', 'std_correlation') plus the raw
        'predictions', 'targets' and per-cell 'losses' arrays; ``{}`` if no batch
        produced a prediction.
    """
    model.eval()
    device = next(model.parameters()).device

    # Sample cells if needed (seeded, so every model sees the same cells)
    if adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        indices = np.sort(rng.choice(adata.n_obs, max_cells, replace=False))
        adata_sample = adata[indices].copy()
    else:
        adata_sample = adata.copy()

    print(f"Evaluating {split_name} with {adata_sample.n_obs} cells...")

    # Convert to dense if sparse
    X = adata_sample.X.toarray() if hasattr(adata_sample.X, 'toarray') else adata_sample.X

    # Gene-token ids are the same for every batch: map raw gene indices to vocab ids once.
    input_gene_ids = torch.arange(n_genes, device=device, dtype=torch.long)
    mapped_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)

    all_predictions = []
    all_targets = []
    all_losses = []
    n_failed_batches = 0

    with torch.no_grad():
        for start_idx in range(0, len(X), batch_size):
            batch_idx = start_idx // batch_size
            batch_X = X[start_idx:start_idx + batch_size]
            n_batch = len(batch_X)

            mapped_input_gene_ids = mapped_gene_ids.unsqueeze(0).repeat(n_batch, 1)
            input_values = torch.from_numpy(batch_X).to(device=device, dtype=torch.float32)
            # No perturbation information is supplied: every perturbation flag is 0.
            input_pert_flags = torch.zeros(n_batch, n_genes, dtype=torch.long, device=device)
            # Every position is a real gene, so nothing is masked as padding.
            src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device)

            try:
                output = model(
                    mapped_input_gene_ids,
                    input_values,
                    input_pert_flags,
                    src_key_padding_mask=src_key_padding_mask,
                )
                # scGPT's TransformerGenerator.forward returns a dict of outputs with the
                # expression prediction under "mlm_output" (as in the scGPT perturbation
                # tutorial); passing the dict itself to F.mse_loss fails on every batch.
                pred = output["mlm_output"] if isinstance(output, dict) else output

                # Per-cell reconstruction MSE (average over genes)
                loss = F.mse_loss(pred, input_values, reduction='none').mean(dim=1)

                all_predictions.append(pred.cpu().numpy())
                all_targets.append(batch_X)
                all_losses.append(loss.cpu().numpy())

            except Exception as e:
                n_failed_batches += 1
                print(f"Error in batch {batch_idx}: {e}")
                continue

    if n_failed_batches:
        # Skipped batches also break the cell-for-cell pairing used by paired tests.
        print(f"  WARNING: {n_failed_batches} batch(es) failed for {split_name}; "
              f"metrics cover only the successful batches")

    if not all_predictions:
        print(f"No valid predictions for {split_name}")
        return {}

    # Concatenate results
    predictions = np.vstack(all_predictions)
    targets = np.vstack(all_targets)
    losses = np.concatenate(all_losses)

    # Calculate metrics over all (cell, gene) entries
    mse = mean_squared_error(targets.flatten(), predictions.flatten())
    mae = mean_absolute_error(targets.flatten(), predictions.flatten())
    r2 = r2_score(targets.flatten(), predictions.flatten())

    # Per-cell metrics
    cell_mse = np.mean((predictions - targets) ** 2, axis=1)
    cell_mae = np.mean(np.abs(predictions - targets), axis=1)

    # Per-cell Pearson correlation; constant rows give NaN and are dropped
    cell_correlations = []
    for i in range(len(predictions)):
        corr = np.corrcoef(predictions[i], targets[i])[0, 1]
        if not np.isnan(corr):
            cell_correlations.append(corr)

    results = {
        'split_name': split_name,
        'n_cells': len(predictions),
        'n_genes': predictions.shape[1],
        'mse': mse,
        'mae': mae,
        'r2': r2,
        'mean_cell_mse': np.mean(cell_mse),
        'std_cell_mse': np.std(cell_mse),
        'mean_cell_mae': np.mean(cell_mae),
        'std_cell_mae': np.std(cell_mae),
        'mean_correlation': np.mean(cell_correlations) if cell_correlations else 0,
        'std_correlation': np.std(cell_correlations) if cell_correlations else 0,
        'predictions': predictions,
        'targets': targets,
        'losses': losses
    }

    print(f"  MSE: {mse:.6f}")
    print(f"  MAE: {mae:.6f}")
    print(f"  R²: {r2:.6f}")
    print(f"  Mean cell correlation: {results['mean_correlation']:.4f}")

    return results

def compare_model_performance(model_pretrain, model_finetune, train_adata, test_adata, val_adata):
    """Evaluate both models on every split.

    Returns:
        ``{'pretrain': {...}, 'finetune': {...}}`` where each inner dict maps
        'train' / 'test' / 'val' to the metric dict produced by
        ``evaluate_model_performance`` (``{}`` when it produced no predictions).
    """
    print("=== MODEL PERFORMANCE COMPARISON ===")

    split_data = (("train", train_adata), ("test", test_adata), ("val", val_adata))
    model_specs = (
        ("pretrain", "PRETRAINED", model_pretrain),
        ("finetune", "FINETUNED", model_finetune),
    )

    results = {}
    for result_key, banner, model in model_specs:
        print(f"\n--- {banner} MODEL EVALUATION ---")
        results[result_key] = {
            name: evaluate_model_performance(model, data, name)
            for name, data in split_data
        }
    return results

print("Evaluation functions defined!")


In [None]:
# Run performance comparison
print("Starting model performance evaluation...")
results = compare_model_performance(model_pretrain, model_finetune, train_adata, test_adata, val_adata)

print("\n=== EVALUATION COMPLETE ===")
print("Results summary:")
# Loop variable is `split_key`, not `split`, so this cell does not overwrite the
# global `split = "simulation"` setting from the data-loading cell.
for model_type in ['pretrain', 'finetune']:
    print(f"\n{model_type.upper()} MODEL:")
    for split_key in ['train', 'test', 'val']:
        r = results[model_type].get(split_key)
        if r:
            print(f"  {split_key}: MSE={r['mse']:.6f}, R²={r['r2']:.6f}, Corr={r['mean_correlation']:.4f}")
        else:
            print(f"  {split_key}: no valid predictions")


In [None]:
# Create comprehensive performance comparison plots
def create_performance_comparison_plots(results):
    """
    Visualise pretrained-vs-finetuned performance on the train/test/val splits.

    Shows two figures:
      1. A 2x2 grid of grouped bar charts (MSE, MAE, R², mean per-cell
         correlation) with one pretrained/finetuned bar pair per split.
      2. One panel per split with the relative improvement (%) of the finetuned
         model over the pretrained one for MSE, R² and mean correlation
         (positive = finetuned is better).

    Args:
        results: Nested dict ``results[model_type][split]`` with model_type in
            {'pretrain', 'finetune'} and split in {'train', 'test', 'val'}; each
            value is a metric dict or ``{}`` if that evaluation failed.

    Splits without results are drawn as gaps (NaN) so bars stay aligned with
    their split label, and an improvement whose pretrained baseline is 0 or
    non-finite is shown as "n/a" instead of raising ZeroDivisionError.
    """
    split_keys = ['train', 'test', 'val']
    split_labels = ['Train', 'Test', 'Val']
    metrics = ['mse', 'mae', 'r2', 'mean_correlation']
    metric_labels = ['MSE (Lower is Better)', 'MAE (Lower is Better)', 'R² (Higher is Better)', 'Mean Correlation (Higher is Better)']

    def _metric_values(model_type, metric):
        """Metric value per split for one model; NaN where that split has no results."""
        return [
            (results[model_type].get(split) or {}).get(metric, np.nan)
            for split in split_keys
        ]

    def _relative_change(new, old, lower_is_better=False):
        """Percent improvement of `new` over `old`; NaN when `old` is 0 or non-finite."""
        if not np.isfinite(old) or old == 0:
            return np.nan
        delta = old - new if lower_is_better else new - old
        return delta / abs(old) * 100

    # Figure 1: absolute metrics per split
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Model Performance Comparison: Pre-trained vs Fine-tuned', fontsize=16)

    x = np.arange(len(split_keys))
    width = 0.35

    # axes.flat walks the 2x2 grid row by row
    for ax, metric, label in zip(axes.flat, metrics, metric_labels):
        bars1 = ax.bar(x - width/2, _metric_values('pretrain', metric), width, label='Pretrained', alpha=0.8)
        bars2 = ax.bar(x + width/2, _metric_values('finetune', metric), width, label='Finetuned', alpha=0.8)

        ax.set_xlabel('Dataset Split')
        ax.set_ylabel(label)
        ax.set_title(label)
        ax.set_xticks(x)
        ax.set_xticklabels(split_labels)
        ax.legend()

        # Add value labels on bars (skip gaps for missing splits)
        for bars in (bars1, bars2):
            for bar in bars:
                height = bar.get_height()
                if not np.isfinite(height):
                    continue
                ax.annotate(f'{height:.3f}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3),  # 3 points vertical offset
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    plt.show()

    # Figure 2: relative improvement per split
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle('Fine-tuning Performance Improvement Analysis', fontsize=16)
    labels = ['MSE Improvement (%)', 'R² Improvement (%)', 'Correlation Improvement (%)']

    for ax, split in zip(axes, split_keys):
        pretrain_r = results['pretrain'].get(split)
        finetune_r = results['finetune'].get(split)
        if not (pretrain_r and finetune_r):
            ax.set_title(f'{split.title()} Split (no data)')
            continue

        improvements = [
            _relative_change(finetune_r['mse'], pretrain_r['mse'], lower_is_better=True),
            _relative_change(finetune_r['r2'], pretrain_r['r2']),
            _relative_change(finetune_r['mean_correlation'], pretrain_r['mean_correlation']),
        ]
        colors = ['gray' if not np.isfinite(v) else 'red' if v < 0 else 'green' for v in improvements]
        # Undefined improvements are drawn as zero-height gray bars labelled "n/a"
        heights = [v if np.isfinite(v) else 0.0 for v in improvements]

        bars = ax.bar(labels, heights, color=colors, alpha=0.7)
        ax.set_title(f'{split.title()} Split')
        ax.set_ylabel('Improvement (%)')
        ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)

        # Add value labels
        for bar, value in zip(bars, improvements):
            height = bar.get_height()
            text = f'{value:.1f}%' if np.isfinite(value) else 'n/a'
            ax.annotate(text,
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3 if height >= 0 else -15),
                        textcoords="offset points",
                        ha='center', va='bottom' if height >= 0 else 'top', fontsize=10)

    plt.tight_layout()
    plt.show()

# Render the metric-comparison and improvement figures
create_performance_comparison_plots(results=results)


In [None]:
# Detailed statistical analysis
def perform_statistical_analysis(results, alpha=0.05):
    """
    Paired significance tests of per-cell reconstruction MSE, pretrained vs finetuned.

    For every split evaluated for both models, the per-cell ``losses`` arrays are
    compared with a paired t-test and a Wilcoxon signed-rank test. Pairing assumes
    both arrays describe the same cells in the same order; arrays of different
    length are rejected (that split gets ``{}``).

    Args:
        results: Nested dict ``results[model_type][split]`` of metric dicts that
            contain a per-cell ``'losses'`` array (``{}`` for a failed split).
        alpha: Significance level for the paired t-test (default 0.05).

    Returns:
        ``split -> {'t_stat', 't_pvalue', 'w_stat', 'w_pvalue',
        'mse_improvement_pct', 'significant_improvement'}`` or ``{}`` when the
        tests could not be run. ``significant_improvement`` is a plain ``bool``
        that is True only when p < alpha AND the finetuned mean MSE is lower —
        a significant *increase* in MSE is not an improvement.
    """
    from scipy.stats import ttest_rel, wilcoxon

    print("=== DETAILED STATISTICAL ANALYSIS ===")

    analysis_results = {}

    for split in ['train', 'test', 'val']:
        pretrain_r = results['pretrain'].get(split)
        finetune_r = results['finetune'].get(split)
        if not (pretrain_r and finetune_r):
            continue

        print(f"\n--- {split.upper()} SPLIT ANALYSIS ---")

        # Cell-level MSE arrays for paired testing
        pretrain_cell_mse = np.asarray(pretrain_r['losses'])
        finetune_cell_mse = np.asarray(finetune_r['losses'])

        try:
            if pretrain_cell_mse.shape != finetune_cell_mse.shape:
                raise ValueError(
                    f"per-cell loss arrays differ in length ({len(pretrain_cell_mse)} vs "
                    f"{len(finetune_cell_mse)}); paired tests need the same cells for both models"
                )

            t_stat, t_pvalue = ttest_rel(pretrain_cell_mse, finetune_cell_mse)
            w_stat, w_pvalue = wilcoxon(pretrain_cell_mse, finetune_cell_mse)

            pretrain_mean = np.mean(pretrain_cell_mse)
            finetune_mean = np.mean(finetune_cell_mse)
            mse_improvement_pct = (pretrain_mean - finetune_mean) / pretrain_mean * 100
            # Two-sided test: significance alone says nothing about direction
            is_significant = bool(t_pvalue < alpha)
            significant_improvement = bool(is_significant and finetune_mean < pretrain_mean)

            print(f"Cell-level MSE comparison:")
            print(f"  Pretrained mean MSE: {pretrain_mean:.6f} ± {np.std(pretrain_cell_mse):.6f}")
            print(f"  Finetuned mean MSE:  {finetune_mean:.6f} ± {np.std(finetune_cell_mse):.6f}")
            print(f"  Paired t-test: t={t_stat:.4f}, p={t_pvalue:.4f}")
            print(f"  Wilcoxon test: W={w_stat:.4f}, p={w_pvalue:.4f}")

            if significant_improvement:
                print(f"  *** SIGNIFICANT IMPROVEMENT: {mse_improvement_pct:.2f}% reduction in MSE ***")
            elif is_significant:
                print(f"  *** SIGNIFICANT DEGRADATION: {-mse_improvement_pct:.2f}% increase in MSE ***")
            else:
                print(f"  No significant difference in MSE")

            # Plain Python scalars keep the result JSON-serialisable
            analysis_results[split] = {
                't_stat': float(t_stat),
                't_pvalue': float(t_pvalue),
                'w_stat': float(w_stat),
                'w_pvalue': float(w_pvalue),
                'mse_improvement_pct': float(mse_improvement_pct),
                'significant_improvement': significant_improvement,
            }

        except Exception as e:
            print(f"Error in statistical analysis: {e}")
            analysis_results[split] = {}

    return analysis_results

# Paired per-cell significance tests for every evaluated split
statistical_results = perform_statistical_analysis(results=results)


In [None]:
# Generate comprehensive summary report
def generate_performance_summary(results, statistical_results):
    """
    Print a multi-section report on fine-tuning effectiveness and return its key numbers.

    Sections: (1) per-split metrics and relative improvements, (2) statistical
    significance, (3) train-vs-test generalization gap, (4) improvement counts,
    (5) final verdict, (6) context quoted from an earlier distribution analysis.

    Args:
        results: Nested dict ``results[model_type][split]`` of metric dicts
            (``{}`` for a split whose evaluation failed).
        statistical_results: ``split -> {'t_pvalue', 'mse_improvement_pct',
            'significant_improvement', ...}`` (``{}`` if the tests failed).

    Returns:
        dict with
          - 'improvement_rate': % of evaluated splits where finetuned MSE < pretrained MSE,
          - 'significant_improvements': splits flagged as significant improvements,
          - 'generalization_gap_improvement': % reduction of |test MSE - train MSE|
            (0 when it cannot be computed),
          - 'overall_verdict': 'effective' | 'moderate' | 'limited'.

    Relative changes whose baseline is 0 are reported as nan instead of raising
    ZeroDivisionError (e.g. mean_correlation defaults to 0 when no cell
    correlation could be computed).
    """
    split_keys = ['train', 'test', 'val']

    def _relative_change(new, old, lower_is_better=False):
        """Percent improvement of `new` over `old`; nan when `old` is 0 or non-finite."""
        if not np.isfinite(old) or old == 0:
            return float('nan')
        delta = old - new if lower_is_better else new - old
        return delta / abs(old) * 100

    print("=" * 80)
    print("COMPREHENSIVE FINE-TUNING PERFORMANCE SUMMARY")
    print("=" * 80)

    # Splits for which both models produced results
    evaluated_splits = [
        split for split in split_keys
        if results['pretrain'].get(split) and results['finetune'].get(split)
    ]

    # 1. Overall performance comparison
    print("\n1. OVERALL PERFORMANCE COMPARISON:")
    print("-" * 50)

    for split in evaluated_splits:
        pretrain_r = results['pretrain'][split]
        finetune_r = results['finetune'][split]

        print(f"\n{split.upper()} SPLIT:")
        print(f"  Pretrained - MSE: {pretrain_r['mse']:.6f}, R²: {pretrain_r['r2']:.4f}, Corr: {pretrain_r['mean_correlation']:.4f}")
        print(f"  Finetuned  - MSE: {finetune_r['mse']:.6f}, R²: {finetune_r['r2']:.4f}, Corr: {finetune_r['mean_correlation']:.4f}")

        mse_improvement = _relative_change(finetune_r['mse'], pretrain_r['mse'], lower_is_better=True)
        r2_improvement = _relative_change(finetune_r['r2'], pretrain_r['r2'])
        corr_improvement = _relative_change(finetune_r['mean_correlation'], pretrain_r['mean_correlation'])

        print(f"  Improvements - MSE: {mse_improvement:+.2f}%, R²: {r2_improvement:+.2f}%, Corr: {corr_improvement:+.2f}%")

    # 2. Statistical significance analysis
    print("\n2. STATISTICAL SIGNIFICANCE ANALYSIS:")
    print("-" * 50)

    significant_improvements = []
    for split in split_keys:
        # Named `split_stats` so the scipy `stats` module name is not shadowed
        split_stats = statistical_results.get(split)
        if not split_stats:
            continue
        p_value = split_stats.get('t_pvalue', float('nan'))
        if split_stats.get('significant_improvement', False):
            significant_improvements.append(split)
            print(f"  {split.upper()}: SIGNIFICANT improvement (p={p_value:.4f}, {split_stats['mse_improvement_pct']:.2f}% MSE reduction)")
        else:
            print(f"  {split.upper()}: No significant improvement (p={p_value:.4f})")

    # 3. Train vs test performance analysis
    print("\n3. TRAIN vs TEST PERFORMANCE ANALYSIS:")
    print("-" * 50)

    gap_improvement = 0
    have_gap_inputs = all(
        results[model_type].get(split)
        for model_type in ('pretrain', 'finetune')
        for split in ('train', 'test')
    )
    if have_gap_inputs:
        pretrain_gap = results['pretrain']['test']['mse'] - results['pretrain']['train']['mse']
        finetune_gap = results['finetune']['test']['mse'] - results['finetune']['train']['mse']

        print(f"  Pretrained generalization gap (Test MSE - Train MSE): {pretrain_gap:.6f}")
        print(f"  Finetuned generalization gap (Test MSE - Train MSE):  {finetune_gap:.6f}")

        # Compare gap magnitudes: with signed gaps a negative pretrained gap
        # (test better than train) would flip the sign of the percentage.
        gap_improvement = _relative_change(abs(finetune_gap), abs(pretrain_gap), lower_is_better=True)
        if not np.isfinite(gap_improvement):
            gap_improvement = 0
        print(f"  Generalization gap improvement: {gap_improvement:+.2f}%")

        if gap_improvement > 0:
            print(f"  *** FINE-TUNING REDUCED GENERALIZATION GAP ***")
        elif gap_improvement < 0:
            print(f"  Fine-tuning increased generalization gap")
        else:
            print(f"  Generalization gap unchanged")

    # 4. Overall assessment
    print("\n4. OVERALL ASSESSMENT:")
    print("-" * 50)

    total_splits = len(evaluated_splits)
    improvement_count = sum(
        1 for split in evaluated_splits
        if results['finetune'][split]['mse'] < results['pretrain'][split]['mse']
    )
    improvement_rate = improvement_count / total_splits * 100 if total_splits > 0 else 0

    print(f"  Splits with MSE improvement: {improvement_count}/{total_splits} ({improvement_rate:.1f}%)")
    print(f"  Splits with significant improvement: {len(significant_improvements)}/{total_splits}")

    # 5. Final verdict (computed once, reused for the return value)
    if len(significant_improvements) >= 2:
        verdict = 'effective'
    elif improvement_rate >= 66:
        verdict = 'moderate'
    else:
        verdict = 'limited'

    print(f"\n5. FINAL VERDICT:")
    print("-" * 50)

    if verdict == 'effective':
        print(f"  *** FINE-TUNING IS EFFECTIVE ***")
        print(f"  - Significant improvements on {len(significant_improvements)} splits")
        print(f"  - Training loss curve issues may be due to other factors")
    elif verdict == 'moderate':
        print(f"  *** FINE-TUNING SHOWS MODERATE EFFECTIVENESS ***")
        print(f"  - MSE improved on {improvement_rate:.1f}% of splits")
        print(f"  - May need longer training or different hyperparameters")
    else:
        print(f"  *** FINE-TUNING EFFECTIVENESS IS LIMITED ***")
        print(f"  - Only {improvement_rate:.1f}% of splits show improvement")
        print(f"  - Training issues may indicate fundamental problems")

    # 6. Context quoted from a previous distribution analysis (not recomputed here)
    print(f"\n6. CONTEXT FROM DISTRIBUTION ANALYSIS:")
    print("-" * 50)
    print(f"  - Previous analysis showed significant OOD issues between train/test")
    print(f"  - 22 novel perturbations in test set not seen during training")
    print(f"  - Gene expression distribution differences between splits")
    print(f"  - These OOD issues may limit fine-tuning effectiveness")

    return {
        'improvement_rate': improvement_rate,
        'significant_improvements': significant_improvements,
        'generalization_gap_improvement': gap_improvement,
        'overall_verdict': verdict,
    }

# Print the textual report and keep its key numbers for saving
summary = generate_performance_summary(results=results, statistical_results=statistical_results)


In [None]:
# Save comprehensive results
print("\n=== SAVING RESULTS ===")

# Create results directory (parents=True so a nested path also works on a fresh checkout)
results_dir = Path("./performance_comparison_results")
results_dir.mkdir(parents=True, exist_ok=True)

# Save detailed results
# (`json` is already imported in the imports cell; it is not re-imported here.)

# Convert numpy values (arrays, scalars, bools) to built-in Python types for JSON serialization
def convert_for_json(obj):
    """Recursively convert numpy containers/scalars into JSON-serialisable Python objects.

    ndarray -> list; numpy integer / floating / bool scalars -> int / float / bool.
    Dicts, lists and tuples are converted element-wise (tuples become lists, since
    JSON has no tuple type). Any other object is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        # np.bool_ is neither np.integer nor np.floating, and json.dump rejects it
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, dict):
        return {k: convert_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_for_json(item) for item in obj]
    return obj

# Save performance results (scalar metrics only; per-cell arrays are left out)
def _json_default(obj):
    """json.dump fallback for numpy scalars / arrays that were not converted beforehand."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


scalar_metric_keys = [
    'split_name', 'n_cells', 'n_genes', 'mse', 'mae', 'r2',
    'mean_cell_mse', 'std_cell_mse', 'mean_cell_mae', 'std_cell_mae',
    'mean_correlation', 'std_correlation',
]
# `split_key` (not `split`) avoids clobbering the global split setting
performance_summary = {
    model_type: {
        split_key: {key: results[model_type][split_key][key] for key in scalar_metric_keys}
        for split_key in ['train', 'test', 'val']
        if results[model_type].get(split_key)
    }
    for model_type in ['pretrain', 'finetune']
}

with open(results_dir / "performance_comparison.json", "w") as f:
    json.dump(convert_for_json(performance_summary), f, indent=2, default=_json_default)

# Save statistical results
with open(results_dir / "statistical_analysis.json", "w") as f:
    json.dump(convert_for_json(statistical_results), f, indent=2, default=_json_default)

# Save summary
with open(results_dir / "summary_report.json", "w") as f:
    json.dump(convert_for_json(summary), f, indent=2, default=_json_default)

# Save plots. The figures shown earlier were already flushed by plt.show(), so a
# bare plt.savefig() here would write a blank canvas; redraw the metric comparison
# from the saved scalar metrics instead.
metric_rows = [
    {"model": model_type, "split": split_key, **metrics}
    for model_type, by_split in performance_summary.items()
    for split_key, metrics in by_split.items()
]
figure_saved = bool(metric_rows)
if figure_saved:
    metric_df = pd.DataFrame(metric_rows)
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Model Performance Comparison: Pre-trained vs Fine-tuned', fontsize=16)
    plot_specs = [
        ('mse', 'MSE (Lower is Better)'),
        ('mae', 'MAE (Lower is Better)'),
        ('r2', 'R² (Higher is Better)'),
        ('mean_correlation', 'Mean Correlation (Higher is Better)'),
    ]
    for ax, (metric, label) in zip(axes.flat, plot_specs):
        (
            metric_df.pivot(index="split", columns="model", values=metric)
            .reindex(index=["train", "test", "val"], columns=["pretrain", "finetune"])
            .rename(index=str.title, columns={"pretrain": "Pretrained", "finetune": "Finetuned"})
            .plot.bar(ax=ax, rot=0, alpha=0.8)
        )
        ax.set_title(label)
        ax.set_xlabel('Dataset Split')
        ax.set_ylabel(label)
    fig.tight_layout()
    fig.savefig(results_dir / "performance_comparison.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

print(f"Results saved to {results_dir}/")
print("Files created:")
print("  - performance_comparison.json")
print("  - statistical_analysis.json")
print("  - summary_report.json")
if figure_saved:
    print("  - performance_comparison.png")
else:
    print("  (performance_comparison.png skipped: no metrics to plot)")

print(f"\n=== ANALYSIS COMPLETE ===")
print(f"Fine-tuning effectiveness verdict: {summary['overall_verdict'].upper()}")
print(f"Improvement rate: {summary['improvement_rate']:.1f}%")
print(f"Significant improvements: {len(summary['significant_improvements'])} splits")

print(f"\nThis comprehensive analysis provides quantitative evidence of fine-tuning effectiveness")
print(f"and helps understand whether training loss curve issues indicate actual performance problems.")
