# Forward Pass Embedding Distribution Analysis for scGPT Fine-tuning

This notebook compares the distributions of forward-pass embeddings for training and test data, to look for representation shifts that could hurt model generalization.

## Overview
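- **Goal**: Determine whether the model's embeddings of train and test data follow different distributions
- **Models**: Compare the pretrained and fine-tuned scGPT models
- **Analysis**: Statistical and visual comparison of embeddings, both from the final model output and from individual transformer layers
- **Focus**: Determine whether fine-tuning changes how differently the model represents train and test data
- **Caveat**: If no fine-tuned checkpoint can be loaded, the pretrained weights are reused as the "fine-tuned" model, so both models are identical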


In [2]:
import os
from google.colab import drive

# Make the project files visible: mount Google Drive, then work from the
# scGPT fine-tuning directory so the relative paths used below (./data, ./save)
# resolve against it.
drive.mount('/content/drive')
os.chdir(
    '/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks/scGPT_finetune'
)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Setup and Installation for Google Colab
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/notebooks/scGPT_finetune')

# Install required packages
%pip install -r ./requirements.txt
%pip install scgpt "flash-attn<1.0.5"
%pip install seaborn plotly scipy umap-learn


In [None]:
# Import libraries
import json
import os
import sys
import time
import copy
from pathlib import Path
from typing import Iterable, List, Tuple, Dict, Union, Optional
import warnings

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances
import umap
from torch import nn
from torch.nn import functional as F
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)
from torch_geometric.loader import DataLoader
from gears import PertData, GEARS
from gears.inference import compute_metrics, deeper_analysis, non_dropout_analysis
from gears.utils import create_cell_graph_dataset_for_prediction

sys.path.insert(0, "../")

import scgpt as scg
from scgpt.model import TransformerGenerator
from scgpt.loss import (
    masked_mse_loss,
    criterion_neg_log_bernoulli,
    masked_relative_error,
)
from scgpt.tokenizer import tokenize_batch, pad_batch, tokenize_and_pad_batch
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed, map_raw_id_to_vocab_id, compute_perturbation_metrics

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
# Use `plt.rcParams` (the same object as `matplotlib.rcParams`): only
# `matplotlib.pyplot` is imported above, so a bare `matplotlib` name is not bound.
plt.rcParams["savefig.transparent"] = False
# NOTE: silences *all* warnings, including numeric RuntimeWarnings
# (e.g. divide-by-zero) raised by the statistics computed later.
warnings.filterwarnings("ignore")

# Global seed via scGPT's helper. Assumed to seed numpy as well - the cell
# subsampling below draws from `np.random` (TODO confirm).
set_seed(42)
print("Libraries imported successfully!")


In [None]:
# Load and prepare data
print("Loading perturbation data...")

# Tokenisation / padding settings (reused when the models are built below)
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
pad_value, pert_pad_id = 0, 0
include_zero_gene = "all"
max_seq_len = 1536

# Dataset, split scheme and compute device
data_name, split = "adamson", "simulation"
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# GEARS perturbation dataset rooted at ./data: load it, build the split and
# the dataloaders.
pert_data = PertData("./data")
pert_data.load(data_name=data_name)
pert_data.prepare_split(split=split, seed=1)
pert_data.get_dataloader(batch_size=64, test_batch_size=64)

print(f"Data loaded successfully!")
print(f"Dataset: {data_name}")
print(f"Split: {split}")
print(f"Device: {device}")

# Quick overview of the loaded AnnData
adata = pert_data.adata
print(f"\nDataset info:")
print(f"Total cells: {adata.n_obs}")
print(f"Total genes: {adata.n_vars}")
print(f"Conditions: {adata.obs['condition'].unique()}")

# Extract train/test splits
def extract_split_data(adata, split_info, split_name):
    """Return a copy of the cells that belong to one data split.

    Args:
        adata: AnnData with an ``obs["condition"]`` column.
        split_info: Mapping describing the split, in either of two layouts:
            * ``{"train_idx": ..., "test_idx": ..., "val_idx": ...}`` - cell
              indices/masks usable directly with ``adata[...]``;
            * ``{"train": [...], "test": [...], "val": [...]}`` - lists of
              condition names; cells are selected via ``adata.obs["condition"]``.
        split_name: ``"train"``, ``"test"`` or ``"val"``.

    Returns:
        A new AnnData (``.copy()``) restricted to the requested split.

    Raises:
        ValueError: ``split_name`` is not one of the three known splits.
        TypeError: ``split_info`` is not a mapping (e.g. a split-type string).
        KeyError: ``split_info`` has no entry for ``split_name``.
    """
    from collections.abc import Mapping

    if split_name not in ("train", "test", "val"):
        raise ValueError(f"Unknown split: {split_name}")
    if not isinstance(split_info, Mapping):
        raise TypeError(
            f"split_info must be a mapping, got {type(split_info).__name__}: {split_info!r}"
        )

    index_key = f"{split_name}_idx"
    if index_key in split_info:
        return adata[split_info[index_key]].copy()
    if split_name in split_info:
        in_split = adata.obs["condition"].isin(list(split_info[split_name])).to_numpy()
        return adata[in_split].copy()
    raise KeyError(
        f"split_info has neither '{index_key}' nor '{split_name}'; keys: {list(split_info)}"
    )


# `pert_data.split` only holds the split *type* passed to prepare_split (here
# "simulation"); the per-split condition lists are kept in `set2conditions`.
# NOTE(review): relies on GEARS' PertData.set2conditions attribute - confirm
# against the installed GEARS version.
split_conditions = pert_data.set2conditions
train_adata = extract_split_data(adata, split_conditions, "train")
test_adata = extract_split_data(adata, split_conditions, "test")
val_adata = extract_split_data(adata, split_conditions, "val")

print(f"\nSplit sizes:")
print(f"Train: {train_adata.n_obs} cells")
print(f"Test: {test_adata.n_obs} cells")
print(f"Val: {val_adata.n_obs} cells")


In [None]:
# Load pretrained and finetuned models
print("Loading models...")

# Model settings: checkpoint location and the parameter groups to transfer
load_model = "./save/scGPT_human"
load_param_prefixs = ["encoder", "value_encoder", "transformer_encoder"]

# Checkpoint files: architecture config, weights and gene vocabulary
model_dir = Path(load_model)
model_config_file = model_dir / "args.json"
model_file = model_dir / "best_model.pt"
vocab_file = model_dir / "vocab.json"

# Gene vocabulary, extended with any special token it does not contain yet
vocab = GeneVocab.from_file(vocab_file)
for token in special_tokens:
    if token not in vocab:
        vocab.append_token(token)

# Mark every dataset gene as present (1) or absent (-1) in the vocabulary
genes = pert_data.adata.var["gene_name"].tolist()
pert_data.adata.var["id_in_vocab"] = [1 if name in vocab else -1 for name in genes]
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])

# Architecture hyper-parameters stored alongside the checkpoint
with open(model_config_file, "r") as config_fh:
    model_configs = json.load(config_fh)

embsize, nhead, d_hid, nlayers, n_layers_cls = (
    model_configs[key]
    for key in ("embsize", "nheads", "d_hid", "nlayers", "n_layers_cls")
)

# Genes missing from the vocabulary map to the <pad> id
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array(
    [vocab[name] if name in vocab else vocab["<pad>"] for name in genes], dtype=int
)
n_genes = len(genes)
ntokens = len(vocab)

print(f"Model configuration loaded:")
print(f"  Vocabulary size: {ntokens}")
print(f"  Embedding size: {embsize}")
print(f"  Number of layers: {nlayers}")
print(f"  Genes in vocab: {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)}")


In [None]:
# Create and load pretrained model
def _build_transformer_generator():
    """Return a fresh TransformerGenerator with the architecture loaded above.

    Dropout is set to 0 because both models are only used for inference here.
    The notebook-level settings (ntokens, embsize, nhead, d_hid, nlayers,
    n_layers_cls, vocab, pad_token, pad_value, pert_pad_id) are read at call time.
    """
    return TransformerGenerator(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        nlayers_cls=n_layers_cls,
        n_cls=1,
        vocab=vocab,
        dropout=0,
        pad_token=pad_token,
        pad_value=pad_value,
        pert_pad_id=pert_pad_id,
        use_fast_transformer=True,
    )


print("Loading pretrained model...")
model_pretrain = _build_transformer_generator()

# Load pretrained weights. Only parameters whose names start with one of
# `load_param_prefixs` come from the checkpoint; all others keep their fresh
# initialisation. The checkpoint is read onto the CPU so a GPU-saved file also
# loads on a CPU-only runtime; the model is moved to `device` afterwards.
model_dict = model_pretrain.state_dict()
pretrained_dict = torch.load(model_file, map_location="cpu")
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if any(k.startswith(prefix) for prefix in load_param_prefixs)
}
for k, v in pretrained_dict.items():
    print(f"Loading pretrained param {k} with shape {v.shape}")
model_dict.update(pretrained_dict)
model_pretrain.load_state_dict(model_dict)
model_pretrain.to(device)
model_pretrain.eval()

print("Pretrained model loaded successfully!")

# Load finetuned model
print("Loading finetuned model...")
model_finetune = _build_transformer_generator()

# Checkpoint of the fine-tuning run for this dataset
# (./save/scGPT_human_finetuned_adamson when data_name == "adamson").
finetuned_model_dir = Path(f"./save/scGPT_human_finetuned_{data_name}")
finetuned_model_file = finetuned_model_dir / "best_model.pt"

# True when the finetuned weights are unavailable and the pretrained model is
# reused instead: every "finetuned" result below then describes the pretrained
# model, so check this before interpreting pretrained-vs-finetuned differences.
finetune_is_pretrain_copy = False
if finetuned_model_file.exists():
    try:
        model_finetune.load_state_dict(torch.load(finetuned_model_file, map_location="cpu"))
        print("Finetuned model loaded successfully!")
    except Exception as e:
        print(f"Error loading finetuned model: {e}")
        print("Using pretrained model for both comparisons...")
        model_finetune = copy.deepcopy(model_pretrain)
        finetune_is_pretrain_copy = True
else:
    print("Finetuned model not found. Using pretrained model for both comparisons...")
    model_finetune = copy.deepcopy(model_pretrain)
    finetune_is_pretrain_copy = True

model_finetune.to(device)
model_finetune.eval()

print("Models ready for embedding extraction!")


In [None]:
# Define embedding extraction functions
def extract_embeddings_from_model(model, adata, split_name, max_cells=1000, layer_indices=None,
                                  batch_size=32, pooling="mean", amp=True, seed=None):
    """
    Extract one embedding per cell from the output of the full transformer stack.

    Every dataset gene is fed to the model, with an all-False padding mask and
    all-zero perturbation flags. The transformer output is pooled over the gene
    axis to give one vector per cell.

    Args:
        model: The scGPT model. It must provide ``_encode``; ``pooling="model"``
            additionally needs ``_get_cell_emb_from_layer``.
        adata: AnnData object with cell data. Its ``.X`` columns must line up
            with the notebook-level ``gene_ids`` / ``n_genes``.
        split_name: Name of the split, used only for logging.
        max_cells: Maximum number of cells to process. A random subset is drawn
            when ``adata`` has more cells.
        layer_indices: Unused; accepted for backward compatibility.
        batch_size: Number of cells per forward pass.
        pooling: ``"mean"`` (default) averages the transformer output over
            genes, matching ``extract_layer_wise_embeddings``. ``"model"`` uses
            ``model._get_cell_emb_from_layer`` instead.
            NOTE(review): in scGPT that helper follows the model's
            ``cell_emb_style``. With the ``"cls"`` style it returns position 0,
            which here is just the first gene because no ``<cls>`` token is
            prepended. Confirm before relying on it.
        amp: Run the forward pass under ``torch.autocast`` when on CUDA; ignored
            on CPU. Assumed to be needed by the flash-attention layers built with
            ``use_fast_transformer=True``, which accept only fp16/bf16 input.
            TODO confirm for the installed flash-attn version.
        seed: Optional seed for the subsampling. ``None`` uses the global
            ``np.random`` state.

    Returns:
        ``{'cell_embeddings': array}`` with one row per successfully processed
        cell, or an empty array if every batch failed.

    Raises:
        ValueError: Unknown ``pooling`` value.
    """
    if pooling not in ("mean", "model"):
        raise ValueError(f"Unknown pooling: {pooling!r}")

    model.eval()
    device = next(model.parameters()).device
    use_amp = amp and device.type == "cuda"

    # Sample cells if needed
    if adata.n_obs > max_cells:
        chooser = np.random if seed is None else np.random.default_rng(seed)
        indices = chooser.choice(adata.n_obs, max_cells, replace=False)
        adata_sample = adata[indices].copy()
    else:
        adata_sample = adata.copy()

    print(f"Extracting embeddings from {adata_sample.n_obs} cells for {split_name}...")

    # Convert to dense if sparse
    X = adata_sample.X.toarray() if hasattr(adata_sample.X, 'toarray') else np.asarray(adata_sample.X)

    # Gene tokens are identical for every cell: map dataset gene indices to
    # vocabulary ids once, outside the batch loop.
    vocab_gene_ids = map_raw_id_to_vocab_id(
        torch.arange(n_genes, device=device, dtype=torch.long), gene_ids
    )

    all_embeddings = []
    n_failed = 0
    with torch.no_grad():
        for batch_idx, start_idx in enumerate(range(0, len(X), batch_size)):
            batch_X = X[start_idx:start_idx + batch_size]
            n_batch = len(batch_X)

            mapped_input_gene_ids = vocab_gene_ids.unsqueeze(0).repeat(n_batch, 1)
            input_values = torch.from_numpy(batch_X).to(device=device, dtype=torch.float32)
            input_pert_flags = torch.zeros(n_batch, n_genes, dtype=torch.long, device=device)
            src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device)

            # Best-effort: a failing batch is reported and skipped.
            try:
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    tr_out = model._encode(
                        mapped_input_gene_ids,
                        input_values,
                        input_pert_flags,
                        src_key_padding_mask,
                    )
                    if pooling == "mean":
                        # Mean over the gene axis (dim 1), as in the layer-wise extractor
                        cell_emb = tr_out.mean(dim=1)
                    else:
                        cell_emb = model._get_cell_emb_from_layer(tr_out, input_values)
                # Cast back to fp32 in case autocast produced half precision
                all_embeddings.append(cell_emb.float().cpu().numpy())
            except Exception as e:
                n_failed += 1
                print(f"Error processing batch {batch_idx}: {e}")

    embeddings = {}
    if all_embeddings:
        embeddings['cell_embeddings'] = np.vstack(all_embeddings)
        print(f"Extracted cell embeddings: {embeddings['cell_embeddings'].shape}")
    else:
        print(f"No embeddings extracted for {split_name}")
        embeddings['cell_embeddings'] = np.array([])
    if n_failed:
        print(f"{n_failed} batch(es) failed for {split_name}")

    return embeddings

def extract_layer_wise_embeddings(model, adata, split_name, max_cells=500, batch_size=16,
                                  max_batches=4, max_layers=6, amp=True, seed=None):
    """
    Extract mean-pooled cell embeddings after each of the leading transformer layers.

    The summed input embedding (gene-token ``encoder`` + ``value_encoder`` +
    ``pert_encoder``) is pushed through ``model.transformer_encoder.layers`` *in
    order*, and each layer's output is mean-pooled over genes. Layer ``k``
    therefore receives the output of layer ``k-1``, not the raw input embedding.
    NOTE(review): this mirrors scGPT's ``TransformerGenerator._encode``, whose
    ``encoder`` takes only token ids. Confirm against the installed scGPT version.

    Args:
        model: The scGPT model.
        adata: AnnData object with cell data. Its ``.X`` columns must line up
            with the notebook-level ``gene_ids`` / ``n_genes``.
        split_name: Name of the split, used only for logging.
        max_cells: Number of cells randomly sampled from ``adata`` before batching.
        batch_size: Cells per forward pass.
        max_batches: Only the first ``max_batches`` batches are processed, so at
            most ``batch_size * max_batches`` cells are used. ``None`` processes
            all sampled cells.
        max_layers: Number of leading layers to record. ``None`` records all.
        amp: Run under ``torch.autocast`` when on CUDA; ignored on CPU. Assumed
            to be needed by flash-attention layers (fp16/bf16 only).
            TODO confirm.
        seed: Optional seed for the subsampling. ``None`` uses the global
            ``np.random`` state.

    Returns:
        ``{'layer_<i>': array(n_cells, hidden_dim)}``. A batch contributes to
        every layer or to none, so all layers have aligned rows. Returns ``{}``
        if every batch failed.
    """
    model.eval()
    device = next(model.parameters()).device
    use_amp = amp and device.type == "cuda"

    # Sample cells if needed
    if adata.n_obs > max_cells:
        chooser = np.random if seed is None else np.random.default_rng(seed)
        indices = chooser.choice(adata.n_obs, max_cells, replace=False)
        adata_sample = adata[indices].copy()
    else:
        adata_sample = adata.copy()

    # Convert to dense if sparse
    X = adata_sample.X.toarray() if hasattr(adata_sample.X, 'toarray') else np.asarray(adata_sample.X)

    n_cells_used = len(X) if max_batches is None else min(len(X), batch_size * max_batches)
    print(f"Extracting layer-wise embeddings from {n_cells_used} cells for {split_name}...")

    n_model_layers = len(model.transformer_encoder.layers)
    n_keep = n_model_layers if max_layers is None else min(max_layers, n_model_layers)

    # Gene tokens are identical for every cell: map them to vocabulary ids once.
    vocab_gene_ids = map_raw_id_to_vocab_id(
        torch.arange(n_genes, device=device, dtype=torch.long), gene_ids
    )

    layer_embeddings = {}
    with torch.no_grad():
        for batch_idx, start_idx in enumerate(range(0, n_cells_used, batch_size)):
            batch_X = X[start_idx:min(start_idx + batch_size, n_cells_used)]
            n_batch = len(batch_X)

            mapped_input_gene_ids = vocab_gene_ids.unsqueeze(0).repeat(n_batch, 1)
            input_values = torch.from_numpy(batch_X).to(device=device, dtype=torch.float32)
            input_pert_flags = torch.zeros(n_batch, n_genes, dtype=torch.long, device=device)
            src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device)

            try:
                batch_layers = []
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    # Input embedding: gene tokens + expression values + perturbation flags
                    hidden = (
                        model.encoder(mapped_input_gene_ids)
                        + model.value_encoder(input_values)
                        + model.pert_encoder(input_pert_flags)
                    )
                    for layer_idx in range(n_keep):
                        # Chain: each layer consumes the previous layer's output
                        hidden = model.transformer_encoder.layers[layer_idx](
                            hidden, src_key_padding_mask=src_key_padding_mask
                        )
                        # Mean pool over genes -> [batch_size, hidden_dim], fp32 on CPU
                        batch_layers.append(hidden.mean(dim=1).float().cpu().numpy())
            except Exception as e:
                print(f"Error processing batch {batch_idx} for layer analysis: {e}")
                continue

            # Commit only after every layer of this batch succeeded
            for layer_idx, cell_emb in enumerate(batch_layers):
                layer_embeddings.setdefault(f'layer_{layer_idx}', []).append(cell_emb)

    # Concatenate all batches for each layer
    for layer_name, chunks in layer_embeddings.items():
        layer_embeddings[layer_name] = np.vstack(chunks)
        print(f"Layer {layer_name} embeddings: {layer_embeddings[layer_name].shape}")

    return layer_embeddings


print("Embedding extraction functions defined!")


In [None]:
# Extract embeddings from pretrained model
print("=== EXTRACTING EMBEDDINGS FROM PRETRAINED MODEL ===")

# Whole-model cell embeddings (up to 1000 cells per split). The generator is
# consumed in order, so extraction runs train -> test -> val, once each.
train_emb_pretrain, test_emb_pretrain, val_emb_pretrain = (
    extract_embeddings_from_model(model_pretrain, split_adata, split_label, max_cells=1000)
    for split_label, split_adata in (
        ("train", train_adata),
        ("test", test_adata),
        ("val", val_adata),
    )
)

# Per-layer embeddings for the train and test splits only
print("\nExtracting layer-wise embeddings from pretrained model...")
train_layers_pretrain, test_layers_pretrain = (
    extract_layer_wise_embeddings(model_pretrain, split_adata, split_label, max_cells=500)
    for split_label, split_adata in (("train", train_adata), ("test", test_adata))
)

print(f"\nPretrained model embeddings extracted:")
print(f"  Train cell embeddings: {train_emb_pretrain['cell_embeddings'].shape}")
print(f"  Test cell embeddings: {test_emb_pretrain['cell_embeddings'].shape}")
print(f"  Val cell embeddings: {val_emb_pretrain['cell_embeddings'].shape}")
print(f"  Train layer embeddings: {len(train_layers_pretrain)} layers")
print(f"  Test layer embeddings: {len(test_layers_pretrain)} layers")


In [None]:
# Extract embeddings from finetuned model
print("=== EXTRACTING EMBEDDINGS FROM FINETUNED MODEL ===")

# Whole-model cell embeddings (up to 1000 cells per split). The generator is
# consumed in order, so extraction runs train -> test -> val, once each.
train_emb_finetune, test_emb_finetune, val_emb_finetune = (
    extract_embeddings_from_model(model_finetune, split_adata, split_label, max_cells=1000)
    for split_label, split_adata in (
        ("train", train_adata),
        ("test", test_adata),
        ("val", val_adata),
    )
)

# Per-layer embeddings for the train and test splits only
print("\nExtracting layer-wise embeddings from finetuned model...")
train_layers_finetune, test_layers_finetune = (
    extract_layer_wise_embeddings(model_finetune, split_adata, split_label, max_cells=500)
    for split_label, split_adata in (("train", train_adata), ("test", test_adata))
)

print(f"\nFinetuned model embeddings extracted:")
print(f"  Train cell embeddings: {train_emb_finetune['cell_embeddings'].shape}")
print(f"  Test cell embeddings: {test_emb_finetune['cell_embeddings'].shape}")
print(f"  Val cell embeddings: {val_emb_finetune['cell_embeddings'].shape}")
print(f"  Train layer embeddings: {len(train_layers_finetune)} layers")
print(f"  Test layer embeddings: {len(test_layers_finetune)} layers")


In [None]:
# Statistical analysis of embedding distributions
print("=== STATISTICAL ANALYSIS OF EMBEDDING DISTRIBUTIONS ===")

def _ratio_or_nan(numerator, denominator):
    """Return ``numerator / denominator``, or ``nan`` when the denominator is 0."""
    return numerator / denominator if denominator != 0 else np.nan


def analyze_embedding_distributions(train_emb, test_emb, model_name, alpha=0.05):
    """Compare the train and test cell-embedding distributions of one model.

    Args:
        train_emb, test_emb: Dicts holding a ``'cell_embeddings'`` array of shape
            ``(n_cells, embsize)``, as returned by ``extract_embeddings_from_model``.
        model_name: Label used in the printed report.
        alpha: Significance level applied to each of the two KS tests
            (no multiple-testing correction).

    Returns:
        Dict of summary statistics, or ``{}`` when either embedding array is
        empty. A ratio whose denominator is zero (e.g. an all-zero train mean)
        is reported as ``nan`` rather than ``inf``.
    """
    if train_emb['cell_embeddings'].size == 0 or test_emb['cell_embeddings'].size == 0:
        print(f"No embeddings available for {model_name}")
        return {}

    train_embeddings = train_emb['cell_embeddings']
    test_embeddings = test_emb['cell_embeddings']

    print(f"\nAnalyzing {model_name} embeddings:")
    print(f"  Train shape: {train_embeddings.shape}")
    print(f"  Test shape: {test_embeddings.shape}")

    # Per-dimension statistics (axis 0 = cells)
    train_mean = np.mean(train_embeddings, axis=0)
    test_mean = np.mean(test_embeddings, axis=0)
    train_std = np.std(train_embeddings, axis=0)
    test_std = np.std(test_embeddings, axis=0)

    # Per-cell statistics (axis 1 = embedding dimensions)
    train_cell_norms = np.linalg.norm(train_embeddings, axis=1)
    test_cell_norms = np.linalg.norm(test_embeddings, axis=1)
    train_cell_means = np.mean(train_embeddings, axis=1)
    test_cell_means = np.mean(test_embeddings, axis=1)

    # 1. Mean |train mean - test mean| relative to mean |train mean|, in percent
    mean_diff_pct = _ratio_or_nan(
        np.mean(np.abs(train_mean - test_mean)), np.mean(np.abs(train_mean))
    ) * 100

    # 2. Spread comparison. Despite the key name this compares standard
    #    deviations, not variances; the key is kept for existing consumers.
    var_diff_pct = _ratio_or_nan(
        np.mean(np.abs(train_std - test_std)), np.mean(train_std)
    ) * 100

    # 3. Two-sample KS tests on per-cell norms and per-cell means
    ks_norm, ks_norm_p = stats.ks_2samp(train_cell_norms, test_cell_norms)
    ks_mean, ks_mean_p = stats.ks_2samp(train_cell_means, test_cell_means)

    # 4. Cosine similarity between the mean embeddings (nan if either is all-zero)
    cosine_sim = _ratio_or_nan(
        np.dot(train_mean, test_mean),
        np.linalg.norm(train_mean) * np.linalg.norm(test_mean),
    )

    # 5. Euclidean distance between centroids
    centroid_distance = np.linalg.norm(train_mean - test_mean)

    results = {
        'mean_diff_pct': mean_diff_pct,
        'var_diff_pct': var_diff_pct,
        'ks_norm_stat': ks_norm,
        'ks_norm_pvalue': ks_norm_p,
        'ks_mean_stat': ks_mean,
        'ks_mean_pvalue': ks_mean_p,
        'cosine_similarity': cosine_sim,
        'centroid_distance': centroid_distance,
        'train_mean_norm': np.mean(train_cell_norms),
        'test_mean_norm': np.mean(test_cell_norms),
        'train_std_norm': np.std(train_cell_norms),
        'test_std_norm': np.std(test_cell_norms)
    }

    print(f"  Mean difference: {mean_diff_pct:.2f}%")
    print(f"  Std-dev difference: {var_diff_pct:.2f}%")
    print(f"  Cosine similarity: {cosine_sim:.4f}")
    print(f"  Centroid distance: {centroid_distance:.4f}")
    print(f"  KS test (norms): stat={ks_norm:.4f}, p={ks_norm_p:.4f}")
    print(f"  KS test (means): stat={ks_mean:.4f}, p={ks_mean_p:.4f}")

    if ks_norm_p < alpha or ks_mean_p < alpha:
        print(f"  *** SIGNIFICANT DIFFERENCE in embedding distributions ***")

    return results

# Train-vs-test comparison of the whole-model cell embeddings, one model at a time
pretrain_results = analyze_embedding_distributions(
    train_emb=train_emb_pretrain, test_emb=test_emb_pretrain, model_name="Pretrained"
)
finetune_results = analyze_embedding_distributions(
    train_emb=train_emb_finetune, test_emb=test_emb_finetune, model_name="Finetuned"
)


In [None]:
# Visualize embedding distributions
print("=== VISUALIZING EMBEDDING DISTRIBUTIONS ===")

def plot_embedding_distributions(train_emb, test_emb, model_name, max_cells=1000):
    """Draw a 2x3 panel comparing one model's train and test cell embeddings.

    Panels: per-cell L2-norm histogram, per-cell mean histogram, PCA scatter,
    norm box plot, UMAP scatter, and cosine-distance histograms (within train,
    within test, train-vs-test). The PCA, UMAP and distance panels are
    best-effort: if one fails, it shows the error message instead.

    Args:
        train_emb, test_emb: Dicts holding a ``'cell_embeddings'`` array of
            shape ``(n_cells, embsize)``.
        model_name: Used in the figure title.
        max_cells: Each split is randomly subsampled to at most this many cells.

    Returns:
        None. The figure is shown and then closed.
    """
    if train_emb['cell_embeddings'].size == 0 or test_emb['cell_embeddings'].size == 0:
        print(f"No embeddings available for {model_name} visualization")
        return

    train_embeddings = train_emb['cell_embeddings']
    test_embeddings = test_emb['cell_embeddings']

    # Sample for visualization if too many cells
    if len(train_embeddings) > max_cells:
        train_indices = np.random.choice(len(train_embeddings), max_cells, replace=False)
        train_embeddings = train_embeddings[train_indices]

    if len(test_embeddings) > max_cells:
        test_indices = np.random.choice(len(test_embeddings), max_cells, replace=False)
        test_embeddings = test_embeddings[test_indices]

    # Stacked once and shared by the PCA and UMAP panels (rows [:n_train] are train),
    # so UMAP no longer depends on the PCA block having succeeded.
    n_train = len(train_embeddings)
    combined_embeddings = np.vstack([train_embeddings, test_embeddings])

    # Calculate per-cell statistics
    train_norms = np.linalg.norm(train_embeddings, axis=1)
    test_norms = np.linalg.norm(test_embeddings, axis=1)
    train_means = np.mean(train_embeddings, axis=1)
    test_means = np.mean(test_embeddings, axis=1)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle(f'Embedding Distribution Analysis: {model_name}', fontsize=16)

    # 1. Embedding norms distribution
    axes[0, 0].hist(train_norms, bins=50, alpha=0.7, label='Train', density=True)
    axes[0, 0].hist(test_norms, bins=50, alpha=0.7, label='Test', density=True)
    axes[0, 0].set_xlabel('Embedding Norm')
    axes[0, 0].set_ylabel('Density')
    axes[0, 0].set_title('Embedding Norm Distribution')
    axes[0, 0].legend()

    # 2. Embedding means distribution
    axes[0, 1].hist(train_means, bins=50, alpha=0.7, label='Train', density=True)
    axes[0, 1].hist(test_means, bins=50, alpha=0.7, label='Test', density=True)
    axes[0, 1].set_xlabel('Mean Embedding Value')
    axes[0, 1].set_ylabel('Density')
    axes[0, 1].set_title('Mean Embedding Distribution')
    axes[0, 1].legend()

    # 3. PCA visualization (fit on train and test together)
    try:
        pca = PCA(n_components=2, random_state=42)
        pca_embeddings = pca.fit_transform(combined_embeddings)
        train_pca, test_pca = pca_embeddings[:n_train], pca_embeddings[n_train:]

        axes[0, 2].scatter(train_pca[:, 0], train_pca[:, 1], alpha=0.6, label='Train', s=20)
        axes[0, 2].scatter(test_pca[:, 0], test_pca[:, 1], alpha=0.6, label='Test', s=20)
        axes[0, 2].set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%})')
        axes[0, 2].set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%})')
        axes[0, 2].set_title('PCA Visualization')
        axes[0, 2].legend()
    except Exception as e:
        axes[0, 2].text(0.5, 0.5, f'PCA Error: {str(e)}', ha='center', va='center')
        axes[0, 2].set_title('PCA Visualization (Error)')

    # 4. Box plot comparison. Tick labels are set separately because boxplot's
    # `labels=` argument was renamed `tick_labels` in Matplotlib 3.9.
    axes[1, 0].boxplot([train_norms, test_norms])
    axes[1, 0].set_xticks([1, 2])
    axes[1, 0].set_xticklabels(['Train', 'Test'])
    axes[1, 0].set_ylabel('Embedding Norm')
    axes[1, 0].set_title('Embedding Norm Comparison')

    # 5. UMAP visualization (uses the module-level `umap` import)
    try:
        reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, min_dist=0.1)
        umap_embeddings = reducer.fit_transform(combined_embeddings)
        train_umap, test_umap = umap_embeddings[:n_train], umap_embeddings[n_train:]

        axes[1, 1].scatter(train_umap[:, 0], train_umap[:, 1], alpha=0.6, label='Train', s=20)
        axes[1, 1].scatter(test_umap[:, 0], test_umap[:, 1], alpha=0.6, label='Test', s=20)
        axes[1, 1].set_xlabel('UMAP 1')
        axes[1, 1].set_ylabel('UMAP 2')
        axes[1, 1].set_title('UMAP Visualization')
        axes[1, 1].legend()
    except Exception as e:
        axes[1, 1].text(0.5, 0.5, f'UMAP Error: {str(e)}', ha='center', va='center')
        axes[1, 1].set_title('UMAP Visualization (Error)')

    # 6. Distance analysis: within-split pairs (upper triangle, no diagonal)
    # versus all train-test pairs
    try:
        train_distances = pairwise_distances(train_embeddings, metric='cosine')
        test_distances = pairwise_distances(test_embeddings, metric='cosine')
        between_distances = pairwise_distances(train_embeddings, test_embeddings, metric='cosine')

        train_within = train_distances[np.triu_indices_from(train_distances, k=1)]
        test_within = test_distances[np.triu_indices_from(test_distances, k=1)]
        between = between_distances.flatten()

        axes[1, 2].hist(train_within, bins=30, alpha=0.7, label='Train within', density=True)
        axes[1, 2].hist(test_within, bins=30, alpha=0.7, label='Test within', density=True)
        axes[1, 2].hist(between, bins=30, alpha=0.7, label='Train-Test between', density=True)
        axes[1, 2].set_xlabel('Cosine Distance')
        axes[1, 2].set_ylabel('Density')
        axes[1, 2].set_title('Distance Distributions')
        axes[1, 2].legend()
    except Exception as e:
        axes[1, 2].text(0.5, 0.5, f'Distance Error: {str(e)}', ha='center', va='center')
        axes[1, 2].set_title('Distance Analysis (Error)')

    plt.tight_layout()
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

# One figure per model: pretrained first, then finetuned
for train_split_emb, test_split_emb, figure_label in (
    (train_emb_pretrain, test_emb_pretrain, "Pretrained Model"),
    (train_emb_finetune, test_emb_finetune, "Finetuned Model"),
):
    plot_embedding_distributions(train_split_emb, test_split_emb, figure_label)


In [None]:
# Layer-wise analysis
print("=== LAYER-WISE EMBEDDING ANALYSIS ===")

def analyze_layer_wise_distributions(train_layers, test_layers, model_name):
    """Compare train and test embeddings separately for every transformer layer.

    Only layers present in both dicts, with non-empty arrays on both sides, are
    analysed, in the key order of ``train_layers``. For each layer, per-cell L2
    norms and per-cell means are compared with two-sample KS tests, and the
    train/test centroids are compared by Euclidean distance and cosine similarity.

    Args:
        train_layers, test_layers: ``{layer_name: array(n_cells, dim)}``.
        model_name: Label used in the printed report.

    Returns:
        ``{layer_name: {statistic_name: value}}``; ``{}`` if either input is empty.
    """
    if not (train_layers and test_layers):
        print(f"No layer-wise embeddings available for {model_name}")
        return {}

    stats_by_layer = {}
    shared_layers = (name for name in train_layers if name in test_layers)

    for layer_name in shared_layers:
        tr = train_layers[layer_name]
        te = test_layers[layer_name]
        if tr.size == 0 or te.size == 0:
            continue

        print(f"\nAnalyzing {layer_name} for {model_name}:")
        print(f"  Train shape: {tr.shape}")
        print(f"  Test shape: {te.shape}")

        # Per-cell summaries (one value per row)
        tr_norms, te_norms = np.linalg.norm(tr, axis=1), np.linalg.norm(te, axis=1)
        tr_means, te_means = tr.mean(axis=1), te.mean(axis=1)

        ks_norm, ks_norm_p = stats.ks_2samp(tr_norms, te_norms)
        ks_mean, ks_mean_p = stats.ks_2samp(tr_means, te_means)

        # Centroid comparison
        tr_centroid, te_centroid = tr.mean(axis=0), te.mean(axis=0)
        centroid_distance = np.linalg.norm(tr_centroid - te_centroid)
        cosine_sim = (tr_centroid @ te_centroid) / (
            np.linalg.norm(tr_centroid) * np.linalg.norm(te_centroid)
        )

        stats_by_layer[layer_name] = dict(
            ks_norm_stat=ks_norm,
            ks_norm_pvalue=ks_norm_p,
            ks_mean_stat=ks_mean,
            ks_mean_pvalue=ks_mean_p,
            centroid_distance=centroid_distance,
            cosine_similarity=cosine_sim,
            train_mean_norm=np.mean(tr_norms),
            test_mean_norm=np.mean(te_norms),
        )

        print(f"  Centroid distance: {centroid_distance:.4f}")
        print(f"  Cosine similarity: {cosine_sim:.4f}")
        print(f"  KS test (norms): stat={ks_norm:.4f}, p={ks_norm_p:.4f}")
        print(f"  KS test (means): stat={ks_mean:.4f}, p={ks_mean_p:.4f}")

        if ks_norm_p < 0.05 or ks_mean_p < 0.05:
            print(f"  *** SIGNIFICANT DIFFERENCE in {layer_name} ***")

    return stats_by_layer

# Per-layer train-vs-test comparison, pretrained model first
pretrain_layer_results = analyze_layer_wise_distributions(
    train_layers=train_layers_pretrain, test_layers=test_layers_pretrain, model_name="Pretrained"
)
finetune_layer_results = analyze_layer_wise_distributions(
    train_layers=train_layers_finetune, test_layers=test_layers_finetune, model_name="Finetuned"
)


In [None]:
# Summary and conclusions: section banner printed before the summary report below
print("=== EMBEDDING DISTRIBUTION ANALYSIS SUMMARY ===")

def _ks_significant(results, alpha=0.05):
    """Return True if either KS p-value in ``results`` is below ``alpha``.

    Missing p-values default to 1, i.e. they are treated as "not significant".
    """
    return (results.get('ks_norm_pvalue', 1) < alpha
            or results.get('ks_mean_pvalue', 1) < alpha)


def _significant_layer_names(layer_results, alpha=0.05):
    """Return the names of layers whose KS tests are significant at ``alpha`` (dict order kept)."""
    return [name for name, res in (layer_results or {}).items() if _ks_significant(res, alpha)]


def _print_model_overview(label, results, alpha=0.05):
    """Print the whole-embedding statistics block for one model (``label`` e.g. 'pretrained')."""
    print(f"\n   {label.upper()} MODEL:")
    print(f"   - Mean difference: {results.get('mean_diff_pct', 0):.2f}%")
    print(f"   - Cosine similarity: {results.get('cosine_similarity', 0):.4f}")
    print(f"   - Centroid distance: {results.get('centroid_distance', 0):.4f}")
    if _ks_significant(results, alpha):
        print(f"   - *** SIGNIFICANT DIFFERENCE in {label} embeddings ***")
    else:
        print(f"   - No significant difference in {label} embeddings")


def _print_layer_significance(label, layer_results, alpha=0.05):
    """Print one SIGNIFICANT / Not significant line per layer for one model."""
    print(f"\n   {label.upper()} MODEL LAYERS:")
    for layer_name, results in layer_results.items():
        if _ks_significant(results, alpha):
            print(f"   - {layer_name}: SIGNIFICANT (p<{alpha})")
        else:
            print(f"   - {layer_name}: Not significant")


def generate_embedding_summary(pretrain_results, finetune_results, pretrain_layer_results, finetune_layer_results,
                               alpha=0.05, rel_change=0.2):
    """Print and return a summary of train/test embedding distribution differences.

    Args:
        pretrain_results, finetune_results: dicts of whole-embedding statistics for each
            model. Keys read: 'mean_diff_pct', 'cosine_similarity', 'centroid_distance',
            'ks_norm_pvalue', 'ks_mean_pvalue'. Empty/None skips that model.
        pretrain_layer_results, finetune_layer_results: mappings of layer name -> stats dict
            containing 'ks_norm_pvalue' / 'ks_mean_pvalue' (as returned by
            analyze_layer_wise_distributions). Empty/None skips that model.
        alpha: significance level applied to every KS p-value. No multiple-testing
            correction is applied, so with many layers (and two tests per layer) some
            layers will be flagged by chance alone.
        rel_change: relative change in |mean_diff_pct| between the models that counts
            as "increased" / "reduced" (default 0.2 -> the 1.2x / 0.8x thresholds).

    Returns:
        dict with:
        - 'embedding_ood_indicators' (list[str]) and 'has_embedding_ood' (bool);
        - 'significant_layers_pretrain' / 'significant_layers_finetune' (list[str]).
    """
    print("EMBEDDING DISTRIBUTION ANALYSIS SUMMARY")
    print("=" * 60)

    # Computed once and reused for both the per-layer listing and the final assessment.
    significant_layers_pretrain = _significant_layer_names(pretrain_layer_results, alpha)
    significant_layers_finetune = _significant_layer_names(finetune_layer_results, alpha)

    # 1. Whole-embedding statistics
    print("\n1. OVERALL EMBEDDING DISTRIBUTION DIFFERENCES:")
    if pretrain_results:
        _print_model_overview("pretrained", pretrain_results, alpha)
    if finetune_results:
        _print_model_overview("finetuned", finetune_results, alpha)

    # 2. Per-layer significance
    print("\n2. LAYER-WISE DISTRIBUTION DIFFERENCES:")
    if pretrain_layer_results:
        _print_layer_significance("pretrained", pretrain_layer_results, alpha)
    if finetune_layer_results:
        _print_layer_significance("finetuned", finetune_layer_results, alpha)

    # 3. Did fine-tuning widen or narrow the train/test gap?
    print("\n3. PRETRAINED vs FINETUNED COMPARISON:")
    if pretrain_results and finetune_results:
        pretrain_diff = pretrain_results.get('mean_diff_pct', 0)
        finetune_diff = finetune_results.get('mean_diff_pct', 0)
        # Compare magnitudes: if mean_diff_pct is signed, comparing raw values would
        # report "reduced" when a negative difference actually grew (e.g. -5% -> -10%).
        pretrain_mag, finetune_mag = abs(pretrain_diff), abs(finetune_diff)

        if finetune_mag > pretrain_mag * (1 + rel_change):
            print("   - *** FINETUNING INCREASED distribution differences ***")
        elif finetune_mag < pretrain_mag * (1 - rel_change):
            print("   - Finetuning reduced distribution differences")
        else:
            print("   - Finetuning had minimal effect on distribution differences")
        print(f"   - Pretrained difference: {pretrain_diff:.2f}%")
        print(f"   - Finetuned difference: {finetune_diff:.2f}%")

    # 4. Overall verdict
    print("\n4. OVERALL ASSESSMENT:")
    embedding_ood_indicators = []
    if pretrain_results and _ks_significant(pretrain_results, alpha):
        embedding_ood_indicators.append("Pretrained model shows embedding distribution differences")
    if finetune_results and _ks_significant(finetune_results, alpha):
        embedding_ood_indicators.append("Finetuned model shows embedding distribution differences")
    if significant_layers_pretrain:
        embedding_ood_indicators.append("Layer-wise differences in pretrained model")
    if significant_layers_finetune:
        embedding_ood_indicators.append("Layer-wise differences in finetuned model")

    if embedding_ood_indicators:
        print("   *** EVIDENCE OF EMBEDDING DISTRIBUTION DIFFERENCES:")
        for indicator in embedding_ood_indicators:
            print(f"      - {indicator}")
        print("\n   *** CONCLUSION: Model representations differ between train/test data ***")
        print("   *** This may contribute to fine-tuning performance issues ***")
    else:
        print("   *** CONCLUSION: Limited evidence of embedding distribution differences ***")
        print("   *** Embedding differences may not be the primary cause of performance issues ***")

    return {
        'embedding_ood_indicators': embedding_ood_indicators,
        'has_embedding_ood': len(embedding_ood_indicators) > 0,
        'significant_layers_pretrain': significant_layers_pretrain,
        'significant_layers_finetune': significant_layers_finetune,
    }

# Print the consolidated report; the returned dict is written to disk in the next cell.
embedding_summary = generate_embedding_summary(
    pretrain_results=pretrain_results,
    finetune_results=finetune_results,
    pretrain_layer_results=pretrain_layer_results,
    finetune_layer_results=finetune_layer_results,
)


In [None]:
# Save results
print("\n=== SAVING RESULTS ===")


def _to_jsonable(obj):
    """Recursively convert numpy values (and containers of them) into plain Python objects.

    ``json`` cannot encode numpy scalars such as ``np.float32`` or ``np.bool_`` (only
    ``np.float64`` happens to subclass ``float``), nor ``np.ndarray``. ``.item()`` maps each
    numpy scalar to the matching Python type (float/int/bool).
    """
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


def _write_json(payload, path):
    """Write ``payload`` as indented JSON to ``path``.

    Converts before opening the file, so a failure in the conversion step does not leave
    an empty file behind.
    """
    serializable = _to_jsonable(payload)
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)


# Create results directory (relative to the notebook's working directory)
results_dir = Path("./embedding_analysis_results")
results_dir.mkdir(parents=True, exist_ok=True)

# `or {}` keeps the save step working if a model's analysis produced no results (None).
_write_json(
    {'pretrain_results': pretrain_results or {}, 'finetune_results': finetune_results or {}},
    results_dir / "embedding_statistical_results.json",
)
_write_json(
    {'pretrain_layers': pretrain_layer_results or {}, 'finetune_layers': finetune_layer_results or {}},
    results_dir / "layer_wise_results.json",
)
_write_json(embedding_summary, results_dir / "embedding_summary.json")

saved_files = [
    "embedding_statistical_results.json",
    "layer_wise_results.json",
    "embedding_summary.json",
]

# Save each cell-embedding matrix as-is (whatever set of cells was embedded upstream);
# empty matrices are skipped and therefore not listed as created.
embedding_sources = {
    "train_embeddings_pretrain.npy": train_emb_pretrain,
    "test_embeddings_pretrain.npy": test_emb_pretrain,
    "train_embeddings_finetune.npy": train_emb_finetune,
    "test_embeddings_finetune.npy": test_emb_finetune,
}
for filename, emb in embedding_sources.items():
    cell_embeddings = emb['cell_embeddings']
    if cell_embeddings.size > 0:
        np.save(results_dir / filename, cell_embeddings)
        saved_files.append(filename)

print(f"Results saved to {results_dir}/")
print("Files created:")
for filename in saved_files:
    print(f"  - {filename}")

print("\nEmbedding distribution analysis complete!")
print("Check the results directory for detailed outputs and further analysis.")
