In [None]:
#!/usr/bin/env python3
"""
Gene-level Analysis of Perturbation Embedding Impact

This script implements a comprehensive gene-level analysis to evaluate
prediction improvement when including perturbation embeddings (P) in
linear models for gene expression prediction.

The analysis follows these steps:
1. Set up baseline linear models with and without perturbation embeddings
2. Train both models on the same training data
3. Evaluate gene-wise prediction performance on test data
4. Compute improvement Δ(g) = Error_noP(g) - Error_withP(g) for each gene
5. Analyze which genes benefit most from perturbation embeddings
6. Provide controls with random embeddings
"""

import argparse
import json
import pickle
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
import tempfile
import warnings

import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse
from scipy.stats import pearsonr
from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns

# Install a catch-all "ignore" filter so console output stays uncluttered.
# This silences every warning category, including numerical RuntimeWarnings.
warnings.simplefilter('ignore')

class LinearPerturbationModel:
    """
    Linear model for perturbation prediction: Y ≈ G · W · P[:, idx] + b

    G is the gene embedding (n_genes x d_gene), P the perturbation embedding
    (d_pert x n_perturbations) indexed by each sample's perturbation, W the
    learned (d_gene x d_pert) weights, and b the per-gene training mean
    (n_genes x 1).
    """

    def __init__(self, gene_embedding: np.ndarray, pert_embedding: Optional[np.ndarray] = None,
                 ridge_penalty: float = 0.1, use_perturbations: bool = True):
        """
        Initialize the linear perturbation model.

        Args:
            gene_embedding: Gene embedding matrix (n_genes x embedding_dim)
            pert_embedding: Perturbation embedding matrix (embedding_dim x n_perturbations)
            ridge_penalty: Ridge regularization penalty λ (applied to both factors)
            use_perturbations: Whether to use perturbation embeddings or gene-only model
        """
        self.gene_embedding = gene_embedding
        self.pert_embedding = pert_embedding
        self.ridge_penalty = ridge_penalty
        self.use_perturbations = use_perturbations
        self.weights = None
        self.intercept = None  # kept for interface compatibility; not used
        self.baseline = None

    def fit(self, Y_train: np.ndarray, perturbation_indices: np.ndarray) -> None:
        """
        Fit W in closed form by ridge regression on per-gene mean-centred data.

        With perturbations: W = (GᵀG + λI)⁻¹ Gᵀ Y Pᵀ (PPᵀ + λI)⁻¹, where
        P = pert_embedding[:, perturbation_indices]. Without them, the per-gene
        average of the centred data is regressed on G.

        Args:
            Y_train: Training expression data (n_genes x n_train_samples)
            perturbation_indices: Column index into ``pert_embedding`` for each
                training sample (length n_train_samples)

        Raises:
            numpy.linalg.LinAlgError: if a regularised Gram matrix is singular
                (e.g. ridge_penalty == 0 with rank-deficient embeddings).
        """
        # Per-gene training mean serves as the intercept b.
        self.baseline = np.mean(Y_train, axis=1, keepdims=True)
        Y_centered = Y_train - self.baseline

        G = self.gene_embedding
        G_reg = G.T @ G + self.ridge_penalty * np.eye(G.shape[1])

        if self.use_perturbations and self.pert_embedding is not None:
            P = self.pert_embedding[:, perturbation_indices]
            P_reg = P @ P.T + self.ridge_penalty * np.eye(P.shape[0])

            # Left factor (GᵀG + λI)⁻¹ Gᵀ Y Pᵀ via solve() instead of an explicit inverse.
            left = np.linalg.solve(G_reg, G.T @ Y_centered @ P.T)
            # Right-multiply by (PPᵀ + λI)⁻¹. P_reg is symmetric, so
            # M @ P_reg⁻¹ == solve(P_reg, Mᵀ)ᵀ.
            self.weights = np.linalg.solve(P_reg, left.T).T
        else:
            # Gene-only model. Y_centered is already centred per gene, so Y_avg is
            # (numerically) zero and predictions reduce to self.baseline.
            Y_avg = np.mean(Y_centered, axis=1, keepdims=True)
            self.weights = np.linalg.solve(G_reg, G.T @ Y_avg)

    def predict(self, perturbation_indices: np.ndarray) -> np.ndarray:
        """
        Predict gene expression for given perturbations.

        Args:
            perturbation_indices: Column indices into ``pert_embedding``, one per
                sample to predict

        Returns:
            Predicted expression matrix (n_genes x n_samples)

        Raises:
            ValueError: if called before ``fit``.
        """
        if self.weights is None:
            raise ValueError("Model must be fitted before prediction")

        if self.use_perturbations and self.pert_embedding is not None:
            P_test = self.pert_embedding[:, perturbation_indices]
            pred = self.gene_embedding @ self.weights @ P_test + self.baseline
        else:
            # Gene-only model: identical prediction column for every sample.
            pred = self.gene_embedding @ self.weights + self.baseline
            pred = np.repeat(pred, len(perturbation_indices), axis=1)

        return pred

class GeneWiseAnalyzer:
    """
    Compare a perturbation-aware linear model against a gene-only baseline,
    gene by gene, on a fixed train/test split of the samples.
    """

    def __init__(self, expression_data: np.ndarray, gene_names: List[str],
                 perturbation_names: List[str], train_mask: np.ndarray):
        """
        Initialize the analyzer.

        Args:
            expression_data: Full expression matrix (n_genes x n_samples)
            gene_names: List of gene names (length n_genes)
            perturbation_names: Perturbation name for each sample (length n_samples)
            train_mask: Boolean mask over samples; True marks training samples
        """
        self.expression_data = expression_data
        self.gene_names = np.array(gene_names)
        self.perturbation_names = np.array(perturbation_names)
        # Coerce to bool so ``~`` is a logical NOT even when a list is passed.
        self.train_mask = np.asarray(train_mask, dtype=bool)
        self.test_mask = ~self.train_mask

        # Map each unique perturbation (sorted order) to an integer column index
        # and record that index for every sample.
        unique_perts, inverse = np.unique(self.perturbation_names, return_inverse=True)
        self.pert_to_idx = {pert: i for i, pert in enumerate(unique_perts)}
        self.perturbation_indices = inverse.ravel()

    def create_embeddings(self, gene_embedding_type: str = "pca", pert_embedding_type: str = "pca",
                          embedding_dim: int = 50, random_seed: int = 42) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create gene and perturbation embeddings.

        Args:
            gene_embedding_type: Type of gene embedding ("pca", "random", "identity")
            pert_embedding_type: Type of perturbation embedding ("pca", "random", "identity")
            embedding_dim: Requested dimensionality; PCA embeddings are capped at the
                rank limit of the matrix they are fit on
            random_seed: Random seed for reproducibility

        Returns:
            Tuple of (gene_embedding [n_genes x d], pert_embedding [d x n_perturbations])

        Raises:
            ValueError: for an unknown embedding type.
        """
        # Local legacy RNG: yields exactly the draws np.random.seed(random_seed)
        # followed by np.random.randn would, without mutating global state.
        rng = np.random.RandomState(random_seed)
        n_genes = len(self.gene_names)
        n_unique_perts = len(self.pert_to_idx)

        # Gene embeddings
        if gene_embedding_type == "pca":
            # Genes are the PCA observations and training samples the features,
            # so the result is (n_genes x n_components) as the model expects.
            train_expr = self.expression_data[:, self.train_mask]
            n_components = min(embedding_dim, *train_expr.shape)
            pca_gene = PCA(n_components=n_components, random_state=random_seed)
            gene_embedding = pca_gene.fit_transform(train_expr)
        elif gene_embedding_type == "random":
            gene_embedding = rng.randn(n_genes, embedding_dim)
        elif gene_embedding_type == "identity":
            gene_embedding = np.eye(n_genes)[:, :embedding_dim]
        else:
            raise ValueError(f"Unknown gene embedding type: {gene_embedding_type}")

        # Perturbation embeddings
        if pert_embedding_type == "pca":
            # Mean training expression per perturbation; perturbations without
            # training samples keep an all-zero column.
            pert_means = np.zeros((n_genes, n_unique_perts))
            for i in range(n_unique_perts):
                pert_mask = (self.perturbation_indices == i) & self.train_mask
                if np.any(pert_mask):
                    pert_means[:, i] = np.mean(self.expression_data[:, pert_mask], axis=1)
            n_components = min(embedding_dim, n_unique_perts, n_genes)
            pca_pert = PCA(n_components=n_components, random_state=random_seed)
            pert_embedding = pca_pert.fit_transform(pert_means.T).T
        elif pert_embedding_type == "random":
            pert_embedding = rng.randn(embedding_dim, n_unique_perts)
        elif pert_embedding_type == "identity":
            pert_embedding = np.eye(n_unique_perts)[:embedding_dim, :]
        else:
            raise ValueError(f"Unknown perturbation embedding type: {pert_embedding_type}")

        return gene_embedding, pert_embedding

    @staticmethod
    def _rowwise_r2(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
        """Coefficient of determination for each row, following sklearn's
        ``r2_score`` force_finite convention for constant targets (1.0 if the
        prediction is exact, else 0.0)."""
        ss_res = np.sum((y_true - y_pred) ** 2, axis=1)
        ss_tot = np.sum((y_true - y_true.mean(axis=1, keepdims=True)) ** 2, axis=1)
        with np.errstate(divide="ignore", invalid="ignore"):
            r2 = 1.0 - ss_res / ss_tot
        return np.where(ss_tot != 0, r2, np.where(ss_res != 0, 0.0, 1.0))

    @staticmethod
    def _rowwise_pearson(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Pearson correlation between matching rows of ``a`` and ``b``; NaN where
        either row is constant (as ``scipy.stats.pearsonr`` reports)."""
        a_c = a - a.mean(axis=1, keepdims=True)
        b_c = b - b.mean(axis=1, keepdims=True)
        num = np.sum(a_c * b_c, axis=1)
        den = np.sqrt(np.sum(a_c ** 2, axis=1) * np.sum(b_c ** 2, axis=1))
        # Exact-equality test, because a constant row's centred values can carry
        # rounding noise that would otherwise yield a spurious finite r.
        constant = np.all(a == a[:, :1], axis=1) | np.all(b == b[:, :1], axis=1)
        with np.errstate(divide="ignore", invalid="ignore"):
            r = np.clip(num / den, -1.0, 1.0)
        r[constant] = np.nan
        return r

    def compute_gene_wise_errors(self, model_with_P: 'LinearPerturbationModel',
                                 model_without_P: 'LinearPerturbationModel') -> pd.DataFrame:
        """
        Compute gene-wise prediction errors and improvements on the test samples.

        Note: the gene-only model predicts the same value for every sample, so
        its correlation (and hence ``corr_improvement``) is NaN for every gene.

        Args:
            model_with_P: Fitted model using perturbation embeddings (any object
                with a ``predict(indices) -> (n_genes x n_test)`` method)
            model_without_P: Fitted model without perturbation embeddings

        Returns:
            DataFrame with one row per gene; ``error_improvement`` is
            Δ(g) = MSE_without_P(g) − MSE_with_P(g).
        """
        Y_test = np.asarray(self.expression_data[:, self.test_mask], dtype=np.float64)
        test_pert_indices = self.perturbation_indices[self.test_mask]

        pred_with_P = np.asarray(model_with_P.predict(test_pert_indices), dtype=np.float64)
        pred_without_P = np.asarray(model_without_P.predict(test_pert_indices), dtype=np.float64)

        # All metrics are computed for every gene (row) at once.
        error_with_P = np.mean((Y_test - pred_with_P) ** 2, axis=1)
        error_without_P = np.mean((Y_test - pred_without_P) ** 2, axis=1)
        r2_with_P = self._rowwise_r2(Y_test, pred_with_P)
        r2_without_P = self._rowwise_r2(Y_test, pred_without_P)
        corr_with_P = self._rowwise_pearson(Y_test, pred_with_P)
        corr_without_P = self._rowwise_pearson(Y_test, pred_without_P)

        error_improvement = error_without_P - error_with_P  # Δ(g)
        with np.errstate(divide="ignore", invalid="ignore"):
            relative_error_improvement = np.where(
                error_without_P > 0, error_improvement / error_without_P, 0.0
            )

        return pd.DataFrame({
            'gene': self.gene_names,
            'gene_idx': np.arange(len(self.gene_names)),
            'error_with_P': error_with_P,
            'error_without_P': error_without_P,
            'error_improvement': error_improvement,  # Δ(g)
            'r2_with_P': r2_with_P,
            'r2_without_P': r2_without_P,
            'r2_improvement': r2_with_P - r2_without_P,
            'corr_with_P': corr_with_P,
            'corr_without_P': corr_without_P,
            'corr_improvement': corr_with_P - corr_without_P,
            'relative_error_improvement': relative_error_improvement,
        })

    def run_comprehensive_analysis(self, gene_embedding_types: Optional[List[str]] = None,
                                   pert_embedding_types: Optional[List[str]] = None,
                                   embedding_dim: int = 50, ridge_penalty: float = 0.1,
                                   random_seed: int = 42) -> Dict:
        """
        Run comprehensive gene-wise analysis with multiple embedding types.

        Args:
            gene_embedding_types: Gene embedding types to test (default ["pca"])
            pert_embedding_types: Perturbation embedding types to test
                (default ["pca", "random", "identity"])
            embedding_dim: Embedding dimensionality
            ridge_penalty: Ridge regularization penalty
            random_seed: Random seed

        Returns:
            Dict keyed by "<gene_type>_gene_<pert_type>_pert", each value holding
            the gene-wise results, both fitted models, and both embeddings.
        """
        # None defaults avoid sharing mutable list objects between calls.
        if gene_embedding_types is None:
            gene_embedding_types = ["pca"]
        if pert_embedding_types is None:
            pert_embedding_types = ["pca", "random", "identity"]

        # Training data does not depend on the embedding choice; slice it once.
        Y_train = self.expression_data[:, self.train_mask]
        train_pert_indices = self.perturbation_indices[self.train_mask]

        results = {}
        for gene_emb_type in gene_embedding_types:
            for pert_emb_type in pert_embedding_types:
                print(f"Analyzing: gene_emb={gene_emb_type}, pert_emb={pert_emb_type}")

                gene_emb, pert_emb = self.create_embeddings(
                    gene_embedding_type=gene_emb_type,
                    pert_embedding_type=pert_emb_type,
                    embedding_dim=embedding_dim,
                    random_seed=random_seed
                )

                model_with_P = LinearPerturbationModel(
                    gene_emb, pert_emb, ridge_penalty, use_perturbations=True
                )
                model_without_P = LinearPerturbationModel(
                    gene_emb, None, ridge_penalty, use_perturbations=False
                )
                model_with_P.fit(Y_train, train_pert_indices)
                model_without_P.fit(Y_train, train_pert_indices)

                gene_results = self.compute_gene_wise_errors(model_with_P, model_without_P)

                key = f"{gene_emb_type}_gene_{pert_emb_type}_pert"
                results[key] = {
                    'gene_results': gene_results,
                    'model_with_P': model_with_P,
                    'model_without_P': model_without_P,
                    'gene_embedding': gene_emb,
                    'pert_embedding': pert_emb
                }

        return results

    def summarize_improvements(self, results: Dict) -> pd.DataFrame:
        """
        Summarize improvements across different embedding types.

        Args:
            results: Results from run_comprehensive_analysis

        Returns:
            Summary DataFrame with one row per embedding combination
        """
        summary_data = []

        for key, result in results.items():
            gene_results = result['gene_results']
            delta = gene_results['error_improvement']

            summary_data.append({
                'embedding_type': key,
                'mean_error_improvement': delta.mean(),
                'median_error_improvement': delta.median(),
                'std_error_improvement': delta.std(),
                'mean_r2_improvement': gene_results['r2_improvement'].mean(),
                'median_r2_improvement': gene_results['r2_improvement'].median(),
                'mean_corr_improvement': gene_results['corr_improvement'].mean(),
                'median_corr_improvement': gene_results['corr_improvement'].median(),
                'n_genes_improved': (delta > 0).sum(),
                'pct_genes_improved': (delta > 0).mean() * 100,
                'n_genes_worsened': (delta < 0).sum(),
                'pct_genes_worsened': (delta < 0).mean() * 100
            })

        return pd.DataFrame(summary_data)

    def plot_improvement_distributions(self, results: Dict, save_path: Optional[str] = None,
                                       max_plots: int = 4):
        """
        Plot histograms of gene-wise Δ(g), one panel per embedding combination.

        Args:
            results: Results from run_comprehensive_analysis
            save_path: Optional file path (str or path-like) to save the figure to
            max_plots: Maximum number of combinations to draw (in dict order)
        """
        items = list(results.items())[:max_plots]
        n_rows = max(1, (len(items) + 1) // 2)
        fig, axes = plt.subplots(n_rows, 2, figsize=(15, 6 * n_rows), squeeze=False)
        flat_axes = axes.ravel()

        for ax, (key, result) in zip(flat_axes, items):
            gene_results = result['gene_results']
            ax.hist(gene_results['error_improvement'], bins=50, alpha=0.7,
                    edgecolor='black', linewidth=0.5)
            ax.axvline(0, color='red', linestyle='--', alpha=0.8)
            ax.set_xlabel('Error Improvement Δ(g)')
            ax.set_ylabel('Number of Genes')
            ax.set_title(f'{key}\nMean Δ(g) = {gene_results["error_improvement"].mean():.4f}')
            ax.grid(True, alpha=0.3)

        # Hide panels left empty when fewer combinations than grid cells exist.
        for ax in flat_axes[len(items):]:
            ax.set_visible(False)

        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def identify_top_improved_genes(self, results: Dict, top_k: int = 20) -> Dict:
        """
        Identify genes that benefit most (and least) from perturbation embeddings.

        Args:
            results: Results from comprehensive analysis
            top_k: Number of top genes to return in each direction

        Returns:
            Dict mapping each embedding key to {'top_improved', 'top_worsened'} DataFrames
        """
        columns = ['gene', 'error_improvement', 'r2_improvement', 'corr_improvement']
        top_genes = {}

        for key, result in results.items():
            gene_results = result['gene_results']

            # Sort by error improvement (Δ(g))
            top_improved = gene_results.nlargest(top_k, 'error_improvement')
            top_worsened = gene_results.nsmallest(top_k, 'error_improvement')

            top_genes[key] = {
                'top_improved': top_improved[columns],
                'top_worsened': top_worsened[columns]
            }

        return top_genes

def load_perturbation_data(data_path: str, dataset_name: str,
                           train_fraction: float = 0.8,
                           seed: int = 42) -> Tuple[np.ndarray, List[str], List[str], np.ndarray]:
    """
    Load perturbation data from the benchmark layout and make a random
    sample-level (cell-level) train/test split.

    Reads ``<data_path>/gears_pert_data/<dataset_name>/perturb_processed.h5ad``.

    Args:
        data_path: Path to data directory
        dataset_name: Name of dataset (e.g., "norman", "adamson")
        train_fraction: Fraction of samples assigned to training (default 0.8)
        seed: Seed for the random split (default 42)

    Returns:
        Tuple of (expression_data [n_genes x n_samples], gene_names,
        perturbation_names [from obs['condition'], one per sample], train_mask)

    Raises:
        FileNotFoundError: if the expected .h5ad file does not exist
        ValueError: if train_fraction is not in (0, 1]
    """
    if not 0.0 < train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in (0, 1], got {train_fraction}")

    data_file = Path(data_path) / "gears_pert_data" / dataset_name / "perturb_processed.h5ad"
    if not data_file.is_file():
        raise FileNotFoundError(f"Processed dataset not found: {data_file}")
    adata = sc.read_h5ad(data_file)

    # Densify if needed and transpose so rows are genes and columns are samples.
    X = adata.X
    expression_data = X.toarray().T if sparse.issparse(X) else np.asarray(X).T

    # Use the 'gene_name' column when present, otherwise fall back to var_names.
    if 'gene_name' in adata.var.columns:
        gene_names = adata.var['gene_name'].tolist()
    else:
        gene_names = adata.var_names.tolist()
    perturbation_names = adata.obs['condition'].tolist()

    # A local RandomState produces exactly the draws that np.random.seed(seed)
    # followed by np.random.choice would, without touching NumPy's global RNG.
    rng = np.random.RandomState(seed)
    n_samples = len(perturbation_names)
    train_indices = rng.choice(n_samples, size=int(train_fraction * n_samples), replace=False)
    train_mask = np.zeros(n_samples, dtype=bool)
    train_mask[train_indices] = True

    return expression_data, gene_names, perturbation_names, train_mask

def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point.

    Loads a dataset, runs the gene-wise comparison for every requested
    embedding combination, and writes CSV, JSON, PNG and TXT outputs to
    ``--output_dir``.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. Passing an explicit
            list lets ``main`` be called from a notebook, where ``sys.argv``
            holds the kernel's own arguments.
    """
    parser = argparse.ArgumentParser(description='Gene-level perturbation embedding analysis')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to data directory containing gears_pert_data')
    parser.add_argument('--dataset_name', type=str, default='norman',
                        help='Dataset name (norman, adamson, etc.)')
    parser.add_argument('--output_dir', type=str, default='./gene_analysis_results',
                        help='Output directory for results')
    parser.add_argument('--embedding_dim', type=int, default=50,
                        help='Embedding dimensionality')
    parser.add_argument('--ridge_penalty', type=float, default=0.1,
                        help='Ridge regularization penalty')
    parser.add_argument('--random_seed', type=int, default=42,
                        help='Random seed for reproducibility')
    parser.add_argument('--gene_embedding_types', nargs='+', default=['pca'],
                        choices=['pca', 'random', 'identity'],
                        help='Types of gene embeddings to test')
    parser.add_argument('--pert_embedding_types', nargs='+', default=['pca', 'random', 'identity'],
                        choices=['pca', 'random', 'identity'],
                        help='Types of perturbation embeddings to test')

    args = parser.parse_args(argv)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading data from {args.data_path}/{args.dataset_name}...")

    # Load data
    expression_data, gene_names, perturbation_names, train_mask = load_perturbation_data(
        args.data_path, args.dataset_name
    )

    print(f"Loaded data: {len(gene_names)} genes, {len(perturbation_names)} samples")
    print(f"Training samples: {train_mask.sum()}, Test samples: {(~train_mask).sum()}")
    print(f"Unique perturbations: {len(set(perturbation_names))}")

    # Initialize analyzer
    analyzer = GeneWiseAnalyzer(expression_data, gene_names, perturbation_names, train_mask)

    # Run comprehensive analysis
    print("\nRunning comprehensive gene-wise analysis...")
    results = analyzer.run_comprehensive_analysis(
        gene_embedding_types=args.gene_embedding_types,
        pert_embedding_types=args.pert_embedding_types,
        embedding_dim=args.embedding_dim,
        ridge_penalty=args.ridge_penalty,
        random_seed=args.random_seed
    )

    # Summarize improvements
    print("\nSummarizing improvements...")
    summary = analyzer.summarize_improvements(results)
    print("\nSummary of Improvements:")
    print(summary.to_string(index=False))

    # Save summary
    summary.to_csv(output_dir / 'improvement_summary.csv', index=False)

    # Identify top improved genes
    print("\nIdentifying top improved genes...")
    top_genes = analyzer.identify_top_improved_genes(results, top_k=20)

    # Save detailed results
    for key, result in results.items():
        result['gene_results'].to_csv(output_dir / f'gene_results_{key}.csv', index=False)

    def _json_records(df: pd.DataFrame) -> List[Dict]:
        """Convert a DataFrame to JSON-ready records, mapping NaN to None (JSON null).

        Undefined metrics (e.g. correlations against a constant prediction) are
        NaN, which json.dump would otherwise emit as a non-standard ``NaN`` token.
        """
        return [
            {k: (None if pd.isna(v) else v) for k, v in record.items()}
            for record in df.to_dict('records')
        ]

    # Save top genes (build the payload first so a conversion error cannot
    # leave a half-written file).
    top_genes_serializable = {
        key: {
            'top_improved': _json_records(value['top_improved']),
            'top_worsened': _json_records(value['top_worsened'])
        }
        for key, value in top_genes.items()
    }
    with open(output_dir / 'top_genes.json', 'w', encoding='utf-8') as f:
        json.dump(top_genes_serializable, f, indent=2)

    # Plot improvement distributions
    print("\nGenerating plots...")
    analyzer.plot_improvement_distributions(results, save_path=output_dir / 'improvement_distributions.png')

    # Generate detailed report
    report_path = output_dir / 'analysis_report.txt'
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write("Gene-level Perturbation Embedding Analysis Report\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Dataset: {args.dataset_name}\n")
        f.write(f"Number of genes: {len(gene_names)}\n")
        f.write(f"Number of samples: {len(perturbation_names)}\n")
        f.write(f"Training samples: {train_mask.sum()}\n")
        f.write(f"Test samples: {(~train_mask).sum()}\n")
        f.write(f"Unique perturbations: {len(set(perturbation_names))}\n")
        f.write(f"Embedding dimension: {args.embedding_dim}\n")
        f.write(f"Ridge penalty: {args.ridge_penalty}\n\n")

        f.write("Summary of Improvements:\n")
        f.write("-" * 30 + "\n")
        f.write(summary.to_string(index=False))
        f.write("\n\n")

        f.write("Top 10 Most Improved Genes (by error reduction):\n")
        f.write("-" * 50 + "\n")
        for key, genes_data in top_genes.items():
            f.write(f"\n{key}:\n")
            f.write(genes_data['top_improved'].head(10).to_string(index=False))
            f.write("\n")

    print(f"\nAnalysis complete! Results saved to {output_dir}")
    print(f"Summary: {summary.to_string(index=False)}")

    # Print key insights; random perturbation embeddings are reported as the control.
    print("\nKey Insights:")
    print("-" * 20)
    for _, row in summary.iterrows():
        emb_type = row['embedding_type']
        mean_improvement = row['mean_error_improvement']
        pct_improved = row['pct_genes_improved']

        if 'random_pert' in emb_type:
            print(f"🔍 {emb_type}: {pct_improved:.1f}% genes improved (control)")
        elif mean_improvement > 0:
            print(f"✅ {emb_type}: {pct_improved:.1f}% genes improved, mean Δ(g)={mean_improvement:.4f}")
        else:
            print(f"❌ {emb_type}: {pct_improved:.1f}% genes improved, mean Δ(g)={mean_improvement:.4f}")

if __name__ == "__main__":
    # Script entry point.
    # NOTE(review): inside a Jupyter/Colab kernel __name__ is also "__main__", so
    # this cell will call main() with the kernel's own sys.argv; argparse will
    # likely reject those arguments (and the missing required --data_path) and
    # raise SystemExit — confirm how this notebook cell is meant to be run.
    main()

In [None]:
# Colab setup cell: mount Google Drive, switch into the repository's Notebooks
# directory, and install the packages listed in ../requirements.txt.
# NOTE(review): this cell comes after the analysis cell in the notebook;
# presumably it is meant to be run first so the requirements are installed — confirm.
from google.colab import drive
import os

drive.mount('/content/drive')
# The relative '../requirements.txt' below is resolved against this directory.
os.chdir('/content/drive/MyDrive/GitHub/Biological-Foundation-Model/Notebooks')

# IPython shell escape: notebook-only syntax, not valid in a plain .py script.
!pip install -r ../requirements.txt