In [None]:
# !pip install mofapy2 
# !pip install mofax

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

from mofapy2.run.entry_point import entry_point
import mofax as mfx
from typing import Dict, List, Tuple, Optional
import warnings

## Phase 4 — Evaluation

1. MOFA+ metrics (already computed)
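2. Stability analysis: multi-seed refits (called "LOOCV" in the code, although every run is trained on all samples, so it measures sensitivity to the random seed rather than to leaving a sample out) → evaluate variance explained, factor replication rate and weight stability across runs  
3. Sample-size sensitivity on simulated data → variance explained as a function of sample size (factor clustering is not computed yet)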

In [None]:
# ==================== Helper tools ====================
class MOFAValidator:
    """
    Validation utilities for a three-view MOFA+ model (by default proteomics,
    lipidomics and metabolomics): multi-seed stability and sample-size
    sensitivity.

    Inputs are features × samples DataFrames; they are stored transposed
    (samples × features), the orientation handed to MOFA+. All views are
    aligned to the sample order of the first view.

    NOTE: ``loocv_multiseed`` keeps its historical name, but every run is fit
    on ALL samples and only the random seed changes, so it measures stability
    with respect to the seed rather than leave-one-out generalisation.
    """

    _DEFAULT_VIEWS = ('proteomics', 'lipidomics', 'metabolomics')
    _SIM_METHODS = ('covariance', 'independent')

    def __init__(
        self,
        pro: pd.DataFrame,
        lipid: pd.DataFrame,
        meta: pd.DataFrame,
        views_names: Optional[List[str]] = None,
    ):
        """
        Args:
            pro, lipid, meta: features × samples matrices for the three views
                (column labels are sample names).
            views_names: names of the three views, in the order above.
                Defaults to ['proteomics', 'lipidomics', 'metabolomics'].

        Raises:
            ValueError: if ``views_names`` does not have exactly three entries,
                sample names are duplicated, or the views do not contain the
                same set of samples.
        """
        views_names = list(self._DEFAULT_VIEWS) if views_names is None else list(views_names)
        if len(views_names) != 3:
            raise ValueError(f"views_names must contain exactly 3 names, got {len(views_names)}")
        self.views_names = views_names

        # Transpose: features × samples → samples × features.
        # MOFA+ gets a single samples_names list shared by all views, so rows
        # must correspond positionally: reorder the other views to match.
        self.pro = pro.T
        self.lipid = self._align_to(self.pro.index, lipid.T, views_names[1])
        self.meta = self._align_to(self.pro.index, meta.T, views_names[2])

        # Sample bookkeeping
        self.sample_names = self.pro.index.tolist()
        self.n_samples = len(self.sample_names)

        # View prefixes keep feature names unique across views
        self.pro_features = [f"prot_{str(f)}" for f in self.pro.columns.tolist()]
        self.lipid_features = [f"lipid_{str(f)}" for f in self.lipid.columns.tolist()]
        self.meta_features = [f"meta_{str(f)}" for f in self.meta.columns.tolist()]

        print(f"Initialized validator:")
        print(f"  Samples: {self.n_samples}")
        print(f"  Proteomics features: {len(self.pro_features)}")
        print(f"  Lipidomics features: {len(self.lipid_features)}")
        print(f"  Metabolomics features: {len(self.meta_features)}")

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _align_to(reference: pd.Index, view: pd.DataFrame, view_name: str) -> pd.DataFrame:
        """Reorder the rows (samples) of ``view`` to ``reference``; raise if the sample sets differ."""
        if not reference.is_unique or not view.index.is_unique:
            raise ValueError(f"Duplicate sample names found while aligning '{view_name}'")
        ref_set, view_set = set(reference), set(view.index)
        if ref_set != view_set:
            raise ValueError(
                f"Samples in '{view_name}' do not match the first view: "
                f"{len(ref_set - view_set)} missing, {len(view_set - ref_set)} extra"
            )
        return view.loc[reference]

    def _features_names(self) -> List[List[str]]:
        """Prefixed feature names per view, in view order."""
        return [self.pro_features, self.lipid_features, self.meta_features]

    def _fit_mofa(
        self,
        data: List[np.ndarray],
        samples_names: List,
        factors: int,
        iterations: int,
        seed: int,
        convergence_mode: str,
    ):
        """Build and train one single-group MOFA+ model on three samples × features matrices."""
        ent = entry_point()
        ent.set_data_options(scale_views=True, scale_groups=False)
        ent.set_data_matrix(
            data=[[matrix] for matrix in data],  # nested as views → groups
            views_names=self.views_names,
            groups_names=['group1'],
            samples_names=[samples_names],
            features_names=self._features_names(),
        )
        ent.set_model_options(
            factors=factors,
            spikeslab_weights=True,
            ard_factors=True,
            ard_weights=True
        )
        ent.set_train_options(
            iter=iterations,
            convergence_mode=convergence_mode,
            dropR2=0.001,
            seed=seed,
            verbose=False
        )
        ent.build()
        ent.run()
        return ent

    def _total_r2_per_view(self, ent) -> Dict[str, float]:
        """Sum of per-factor R² for each view (single group, index 0)."""
        # Assumed shape of the group-0 array: (n_views, n_factors). Summing
        # per-factor R² approximates total R² when factors are not orthogonal.
        variance_array = ent.model.calculate_variance_explained()[0]
        return {
            view: float(variance_array[view_idx, :].sum())
            for view_idx, view in enumerate(self.views_names)
        }

    @staticmethod
    def _max_abs_corr(ref: np.ndarray, other: np.ndarray) -> np.ndarray:
        """
        For every column of ``ref``, the best |Pearson r| with any column of ``other``.

        Both arrays share the row axis. Undefined correlations (constant
        columns) are ignored; a ``ref`` column with no defined correlation gets NaN.
        """
        n_ref = ref.shape[1]
        with np.errstate(invalid='ignore', divide='ignore'):
            corr = np.corrcoef(ref.T, other.T)[:n_ref, n_ref:]
        abs_corr = np.abs(corr)
        defined = ~np.isnan(abs_corr)
        best = np.where(defined, abs_corr, -np.inf).max(axis=1)
        best[~defined.any(axis=1)] = np.nan
        return best

    # ------------------------------------------------------- multi-seed runs
    def loocv_multiseed(
        self,
        n_runs: int = 10,
        base_seed: int = 2026,
        factors: int = 10,
        iterations: int = 1000,
        save_dir: Optional[str] = None,
        verbose: bool = False
    ) -> Dict:
        """
        Fit MOFA+ ``n_runs`` times on the full dataset with seeds
        ``base_seed, base_seed + 1, ...`` and summarise run-to-run stability.

        Despite the name, no sample is held out in any run.

        Args:
            n_runs: number of fits (>= 1).
            base_seed: seed of the first run.
            factors: initial number of factors.
            iterations: maximum training iterations per run.
            save_dir: if given, each model is saved there as ``loocv_run_XX.hdf5``.
            verbose: print per-run R².

        Returns:
            dict with keys 'models', 'variance_explained' (per-view list of R²),
            'variance_summary' (per-view mean/std/cv%), 'factors' (Z per run),
            'weights' (per-view W per run), 'replication_rate' (% per factor of
            run 1) and 'weight_stability' (per-view mean best |r| per factor).

        Raises:
            ValueError: if ``n_runs`` < 1.
        """
        if n_runs < 1:
            raise ValueError(f"n_runs must be >= 1, got {n_runs}")

        models = []
        variance_results = {view: [] for view in self.views_names}
        all_factors = []
        all_weights = {view: [] for view in self.views_names}

        if save_dir:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)

        # Every run uses the same full data; build it once.
        data = [self.pro.to_numpy(), self.lipid.to_numpy(), self.meta.to_numpy()]

        for run_idx in range(n_runs):
            seed = base_seed + run_idx

            if verbose:
                print(f"\n{'='*60}")
                print(f"LOOCV Run {run_idx + 1}/{n_runs} (seed={seed})")
                print(f"{'='*60}")

            ent = self._fit_mofa(data, self.sample_names, factors, iterations, seed, 'medium')

            if save_dir:
                save_path = save_dir / f"loocv_run_{run_idx + 1:02d}.hdf5"
                ent.save(str(save_path))

            r2_per_view = self._total_r2_per_view(ent)
            for view in self.views_names:
                variance_results[view].append(r2_per_view[view])
                if verbose:
                    print(f"  {view}: {r2_per_view[view]:.2%}")

            # Z: factor values; W: one weight matrix per view
            all_factors.append(ent.model.nodes['Z'].getExpectation())
            weights_run = ent.model.nodes['W'].getExpectation()
            for v_idx, view in enumerate(self.views_names):
                all_weights[view].append(weights_run[v_idx])

            models.append(ent)

        replication_rate = self._calculate_factor_replication(all_factors)
        weight_stability = self._calculate_weight_stability(all_weights)

        variance_summary = {}
        for view, values in variance_results.items():
            mean = float(np.mean(values))
            std = float(np.std(values))
            variance_summary[view] = {
                'mean': mean,
                'std': std,
                # coefficient of variation (%), undefined for a zero mean
                'cv': std / mean * 100 if mean != 0 else np.nan,
            }

        return {
            'models': models,
            'variance_explained': variance_results,
            'variance_summary': variance_summary,
            'factors': all_factors,
            'weights': all_weights,
            'replication_rate': replication_rate,
            'weight_stability': weight_stability
        }

    def _calculate_factor_replication(self, all_factors: List[np.ndarray], threshold: float = 0.7) -> np.ndarray:
        """
        Percentage of later runs in which each factor of the first run has a
        counterpart with |r| > ``threshold``. NaN when there is only one run.
        """
        ref_factors = np.asarray(all_factors[0])
        others = all_factors[1:]
        if not others:
            return np.full(ref_factors.shape[1], np.nan)

        replication_counts = np.zeros(ref_factors.shape[1])
        for factors in others:
            # NaN (undefined) best correlations compare False, i.e. not replicated
            replication_counts += self._max_abs_corr(ref_factors, np.asarray(factors)) > threshold
        return replication_counts / len(others) * 100

    def _calculate_weight_stability(self, all_weights: Dict[str, List[np.ndarray]]) -> Dict[str, np.ndarray]:
        """
        Per view: for each factor of the first run, the mean (over later runs)
        of its best |r| with any factor's weights. NaN when there is only one run.
        """
        stability = {}
        for view, weights_list in all_weights.items():
            ref_weights = np.asarray(weights_list[0])
            per_run = [self._max_abs_corr(ref_weights, np.asarray(w)) for w in weights_list[1:]]
            if not per_run:
                stability[view] = np.full(ref_weights.shape[1], np.nan)
                continue
            with warnings.catch_warnings():
                # all-NaN columns (dead reference factor) legitimately give NaN
                warnings.simplefilter('ignore', RuntimeWarning)
                stability[view] = np.nanmean(np.vstack(per_run), axis=0)
        return stability

    # -------------------------------------------------- sample-size simulation
    def _simulation_params(self, method: str) -> Dict:
        """Estimate, once, the per-view quantities the simulator samples from."""
        if method not in self._SIM_METHODS:
            raise ValueError(f"method must be one of {self._SIM_METHODS}, got {method!r}")
        views = [self.pro, self.lipid, self.meta]
        means = [v.mean(axis=0).to_numpy(dtype=float) for v in views]
        params = {'method': method, 'means': means}
        if method == 'covariance':
            if self.n_samples < 2:
                raise ValueError("covariance simulation needs at least 2 real samples")
            # Centred real data; NaNs become 0 after centring (mean imputation).
            params['centred'] = [
                np.nan_to_num(v.to_numpy(dtype=float) - m, nan=0.0)
                for v, m in zip(views, means)
            ]
            params['scale'] = 1.0 / np.sqrt(self.n_samples - 1)
        else:
            # Guard against NaN std (e.g. all-missing features).
            params['stds'] = [
                np.nan_to_num(v.std(axis=0).to_numpy(dtype=float), nan=0.0) for v in views
            ]
        return params

    @staticmethod
    def _simulate_views(params: Dict, n_size: int, rng: np.random.Generator) -> List[np.ndarray]:
        """Draw one synthetic dataset (a samples × features array per view)."""
        means = params['means']
        if params['method'] == 'covariance':
            # x = mu + g @ Xc / sqrt(n - 1), g ~ N(0, I): reproduces the empirical
            # covariance. Sharing g across views keeps cross-view covariance too.
            n_real = params['centred'][0].shape[0]
            g = rng.standard_normal((n_size, n_real)) * params['scale']
            return [m + g @ c for m, c in zip(means, params['centred'])]
        return [rng.normal(m, s, size=(n_size, m.size)) for m, s in zip(means, params['stds'])]

    def sample_size_sensitivity_simulated(
        self,
        target_sizes: Optional[List[int]] = None,
        n_replicates: int = 10,
        base_seed: int = 2026,
        factors: int = 10,
        iterations: int = 1000,
        verbose: bool = False,
        method: str = 'covariance',
    ) -> Dict:
        """
        Fit MOFA+ on synthetic datasets of increasing size drawn from
        distributions estimated on the real (already normalised/transformed) data.

        Args:
            target_sizes: synthetic sample sizes (default [20, 30, 40, 50]).
            n_replicates: datasets simulated per size.
            base_seed: base for data-generation and training seeds.
            factors, iterations: MOFA+ settings.
            verbose: print progress and fit errors.
            method: 'covariance' (default) samples from a Gaussian with the
                empirical mean and joint covariance of all views, so shared
                latent structure is preserved; 'independent' samples each
                feature independently (no correlation structure; earlier behaviour).

        Returns:
            dict with 'sample_sizes', 'method', 'variance_explained'
            (view → size → list of R²), 'n_failed' (size → failed fits) and
            'summary' (view → size → mean/std, NaN when no fit succeeded).

        Raises:
            ValueError: for an unknown ``method`` (checked before any fitting).
        """
        target_sizes = [20, 30, 40, 50] if target_sizes is None else list(target_sizes)
        params = self._simulation_params(method)

        ss_results = {
            'sample_sizes': target_sizes,
            'method': method,
            'variance_explained': {view: {n: [] for n in target_sizes} for view in self.views_names}
        }
        n_failed = {n: 0 for n in target_sizes}

        for n_size in target_sizes:
            if verbose:
                print(f"\nSimulating datasets with n={n_size}")

            for rep in range(n_replicates):
                rng = np.random.default_rng(base_seed + n_size * 100 + rep)
                synth = self._simulate_views(params, n_size, rng)
                synth_sample_names = [f"synth_{n_size}_{rep}_{i}" for i in range(n_size)]

                try:
                    ent = self._fit_mofa(
                        synth, synth_sample_names, factors, iterations, base_seed + rep, 'fast'
                    )
                    r2_per_view = self._total_r2_per_view(ent)
                except Exception as e:
                    # Keep going, but never silently: report and count the failure.
                    n_failed[n_size] += 1
                    warnings.warn(
                        f"MOFA+ fit failed for n={n_size}, replicate {rep}: {e}", RuntimeWarning
                    )
                    if verbose:
                        print(f"  Error: {e}")
                    continue

                for view in self.views_names:
                    ss_results['variance_explained'][view][n_size].append(r2_per_view[view])

        ss_results['n_failed'] = n_failed
        ss_results['summary'] = {
            view: {
                n: {
                    'mean': np.mean(vals) if vals else np.nan,
                    'std': np.std(vals) if vals else np.nan
                }
                for n, vals in ss_results['variance_explained'][view].items()
            }
            for view in self.views_names
        }
        return ss_results

class MOFAPlotter:
    """
    Plotting utilities for MOFAValidator results.

    Each public method draws one figure, saves it at 300 dpi when
    ``save_path`` is given, and calls ``plt.show()``. Any number of views is
    supported; palette colours cycle when there are more views than entries.
    """

    # Fixed colours for the default views (LOOCV line plot).
    _VIEW_COLORS = {
        'proteomics': '#FFD700',
        'lipidomics': '#440154',
        'metabolomics': '#21918c'
    }
    # Positional palette (bar and sample-size plots).
    _PALETTE = ['#1f77b4', '#ff7f0e', '#2ca02c']

    @staticmethod
    def _palette_color(idx: int) -> str:
        """Colour of the idx-th series, cycling through the palette."""
        return MOFAPlotter._PALETTE[idx % len(MOFAPlotter._PALETTE)]

    @staticmethod
    def _bar_layout(n_groups: int, max_width: float = 0.25, total_width: float = 0.8):
        """
        Width and centred x-offsets for ``n_groups`` side-by-side bars per tick.

        Bars are at most ``max_width`` wide and together span at most
        ``total_width``, so they never spill into neighbouring ticks.
        """
        width = min(max_width, total_width / max(n_groups, 1))
        offsets = (np.arange(n_groups) - (n_groups - 1) / 2) * width
        return width, offsets

    @staticmethod
    def _finish(fig, save_path: Optional[str]):
        """Tighten layout, optionally save, and show ``fig``."""
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    @staticmethod
    def plot_loocv_var_stability(loocv_results: Dict, save_path: Optional[str] = None):
        """Plot per-view variance explained across runs, with mean line and ±1 std band."""
        variance = loocv_results['variance_explained']
        views = list(variance.keys())
        n_runs = len(next(iter(variance.values())))

        fig, ax = plt.subplots(figsize=(8, 6))
        x = np.arange(1, n_runs + 1)

        for view in views:
            y = variance[view]
            line, = ax.plot(
                x,
                y,
                marker='o',
                linewidth=2,
                label=view,
                color=MOFAPlotter._VIEW_COLORS.get(view)
            )
            # Reuse the line's actual colour so mean/band match it even for
            # views without a fixed colour (which otherwise advance the cycle).
            color = line.get_color()
            mean, std = np.mean(y), np.std(y)
            ax.axhline(mean, linestyle='--', alpha=0.6, color=color)
            ax.fill_between(x, mean - std, mean + std, alpha=0.15, color=color)

        ax.set_xlabel('LOOCV Run', fontsize=12)
        ax.set_ylabel('Variance Explained (R²)', fontsize=12)
        ax.set_title('LOOCV Stability of Variance Explained per Omics View',
                     fontsize=14, fontweight='bold')
        ax.legend(title='View')
        ax.grid(alpha=0.3)

        MOFAPlotter._finish(fig, save_path)

    @staticmethod
    def plot_loocv_factorrep(loocv_factorep_results: Dict, save_path: Optional[str] = None,
                             threshold_pct: float = 70.0):
        """
        Two panels: replication rate (%) per factor with a ``threshold_pct``
        reference line, and weight stability per factor grouped by view.
        """
        replication_rate = loocv_factorep_results['replication_rate']
        weight_stability = loocv_factorep_results['weight_stability']
        n_factors = len(replication_rate)
        factor_labels = [f'F{i}' for i in range(1, n_factors + 1)]

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        # Factor replication rate
        positions = range(1, n_factors + 1)
        axes[0].bar(positions, replication_rate, alpha=0.8, edgecolor='black')
        axes[0].axhline(threshold_pct, color='red', linestyle='--',
                        label=f'{threshold_pct:g}% threshold')
        axes[0].set_xlabel('Factor', fontsize=12)
        axes[0].set_ylabel('Replication Rate (%)', fontsize=12)
        axes[0].set_title('LOOCV Multi-seed Factor Replication', fontsize=13, fontweight='bold')
        axes[0].set_xticks(positions)
        axes[0].set_xticklabels(factor_labels)
        axes[0].legend()
        axes[0].grid(alpha=0.3, axis='y')

        # Weight stability per view: one centred bar group per factor
        views = list(weight_stability.keys())
        x = np.arange(n_factors)
        width, offsets = MOFAPlotter._bar_layout(len(views))
        for idx, (view, offset) in enumerate(zip(views, offsets)):
            axes[1].bar(x + offset, weight_stability[view], width,
                        label=view, alpha=0.8, color=MOFAPlotter._palette_color(idx))

        axes[1].set_xlabel('Factor', fontsize=12)
        axes[1].set_ylabel('Weight Stability (correlation)', fontsize=12)
        axes[1].set_title('Weight Stability per View', fontsize=13, fontweight='bold')
        axes[1].set_xticks(x)
        axes[1].set_xticklabels(factor_labels)
        axes[1].legend()
        axes[1].grid(alpha=0.3, axis='y')

        MOFAPlotter._finish(fig, save_path)

    @staticmethod
    def plot_sample_size_sensitivity(sensitivity_results: Dict, save_path: Optional[str] = None):
        """Plot mean ± std variance explained per view against simulated sample size."""
        sample_sizes = sensitivity_results['sample_sizes']
        summary = sensitivity_results['summary']

        fig, ax = plt.subplots(1, 1, figsize=(10, 6))

        for idx, view in enumerate(summary):
            means = np.array([summary[view][n]['mean'] for n in sample_sizes])
            stds = np.array([summary[view][n]['std'] for n in sample_sizes])
            color = MOFAPlotter._palette_color(idx)

            ax.plot(sample_sizes, means, marker='o', linewidth=2, label=view, color=color)
            ax.fill_between(sample_sizes, means - stds, means + stds, alpha=0.2, color=color)

        ax.set_xlabel('Sample Size', fontsize=12)
        ax.set_ylabel('Variance Explained (fraction)', fontsize=12)
        ax.set_title('Variance Explained vs Sample Size', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(alpha=0.3)

        MOFAPlotter._finish(fig, save_path)

In [None]:
# Prep the data
DATA_DIR = Path('data')
N_TOP_PROTEINS = 2000  # keep only the most variable proteins


def split_label_row(df: pd.DataFrame):
    """
    Split a raw omics table into its 'label' row and a numeric matrix.

    The 'sample' column holds row identifiers plus one row whose value is
    'label'. That row's remaining cells are returned as the labels; this
    assumes 'sample' is the first column, since it is skipped positionally.

    Returns:
        (labels, values): labels as a numpy array, and all other rows indexed
        by 'sample' with every cell coerced to numeric (unparseable → NaN).
    """
    is_label = df['sample'] == 'label'
    labels = df[is_label].iloc[0, 1:].values
    values = (
        df[~is_label]
        .reset_index(drop=True)
        .set_index('sample')
        .apply(pd.to_numeric, errors='coerce')
    )
    return labels, values


# Load datasets
proteins = pd.read_csv(DATA_DIR / 'proteins.csv')
lipids = pd.read_csv(DATA_DIR / 'lipids.csv')
metabolites = pd.read_csv(DATA_DIR / 'metabolites.csv')

# Separate labels from the numeric feature matrices
pro_labels, pro_nolabel = split_label_row(proteins)
lipid_labels, lipid_nolabel = split_label_row(lipids)
meta_labels, meta_nolabel = split_label_row(metabolites)

# Filter proteins: keep the rows with the highest variance across columns
pro_var = pro_nolabel.var(axis=1)
top_proteins = pro_var.nlargest(min(N_TOP_PROTEINS, len(pro_var))).index
pro_nolabel_filtered = pro_nolabel.loc[top_proteins]

In [None]:
# Stability analysis: multi-seed runs (the validator's "LOOCV"; every run
# refits on all samples with a different seed)

validator = MOFAValidator(pro_nolabel_filtered, lipid_nolabel, meta_nolabel)

loocv_settings = {
    'n_runs': 10,
    'base_seed': 2026,
    'factors': 10,
    'iterations': 1000,
    'verbose': False,
}

print("\n--- LOOCV Multi-Seed Analysis ---")
loocv_results = validator.loocv_multiseed(**loocv_settings)

In [None]:
# LOOCV figures: variance stability, then factor replication + weight stability
output_dir = Path('plots/evaluation')
output_dir.mkdir(parents=True, exist_ok=True)

loocv_figures = [
    (MOFAPlotter.plot_loocv_var_stability, 'loocv_var_stability.png'),
    (MOFAPlotter.plot_loocv_factorrep, 'loocv_factorrep_and_wstability.png'),
]
for plot_fn, file_name in loocv_figures:
    plot_fn(loocv_results, save_path=str(output_dir / file_name))

In [None]:
# Sample-size sensitivity on simulated datasets
sensitivity_settings = dict(
    target_sizes=[20, 30, 40, 50],
    n_replicates=10,
    base_seed=2026,
    factors=10,
    iterations=1000,
    verbose=True,
)
sensitivity_results = validator.sample_size_sensitivity_simulated(**sensitivity_settings)

sensitivity_figure = output_dir / 'sample_size_sensitivity.png'
MOFAPlotter.plot_sample_size_sensitivity(sensitivity_results, save_path=str(sensitivity_figure))