"""
Multi-Task Deep Gaussian Process - Training and Evaluation Pipeline
====================================================================

This module provides utilities for:
1. Data preprocessing and standardization
2. Model training with early stopping
3. Model evaluation and visualization
4. Multi-model comparison

The pipeline handles partial observations (missing data for some tasks)
and supports prior-guided learning where available.

Usage Example
-------------
# Define input and output variables
input_vars = ['W', 'V', 'Ti', 'Cr', 'Re', 'Fe', 'Ta', 'Zr']
output_vars = ['Property1', 'Property2', ...]

# Prepare and standardize data
input_scaled, output_scaled, yvar_scaled, input_scalers, output_scalers = \
    prepare_and_standardize_data(df, input_vars, output_vars)

# Train model
model_results, models = train_model(
    input_scaled, output_scaled, input_scalers, output_scalers,
    input_vars, output_vars, num_obj=5, reductions=[4, 5, 6]
)

# Visualize results
plot_comparison(model_results)
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from scipy.stats import kendalltau, spearmanr
from gpytorch.mlls import DeepApproximateMLL, VariationalELBO
from typing import Dict, List, Optional, Tuple

# Import model definitions (assumes deep_gp_model_definitions.py is available)
from deep_gp_model_definitions import MultiTaskDeepGP


# ============================================================================
# DATA PREPROCESSING AND STANDARDIZATION
# ============================================================================

def prepare_and_standardize_data(
    df: pd.DataFrame,
    input_vars: List[str],
    output_vars: List[str],
    input_scalers: Optional[Dict] = None,
    yvar_cols: Optional[List[Optional[str]]] = None,
    verbose: bool = True
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], Dict, Dict]:
    """
    Prepare and standardize data with support for partial task observations.
    
    Applies column-wise standardization (zero mean, unit variance) to inputs
    and outputs separately. Handles missing values gracefully and can either
    fit new scalers or apply existing ones.
    
    Parameters
    ----------
    df : pd.DataFrame
        Raw data containing input and output variables
    input_vars : List[str]
        Column names of input features
    output_vars : List[str]
        Column names of output properties/tasks
    input_scalers : Dict, optional
        Pre-fitted input scalers (one per column). If None, fits new scalers.
    yvar_cols : List[Optional[str]], optional
        Column names giving per-sample observation noise for each output_var
        (treated as standard deviations and squared internally to obtain
        variances). Use None for outputs without noise information.
    verbose : bool, default=True
        If True, prints information about valid observations per variable
        
    Returns
    -------
    input_scaled : np.ndarray, shape (n, d)
        Standardized input features
    output_scaled : np.ndarray, shape (n, m)
        Standardized output properties
    yvar_scaled : np.ndarray, shape (n, m) or None
        Scaled observation noise variances (if yvar_cols provided)
    input_scalers : Dict
        Dictionary mapping input variable names to fitted StandardScaler objects
    output_scalers : Dict
        Dictionary mapping output variable names to fitted StandardScaler objects
        
    Notes
    -----
    - Rows with any missing input values are dropped
    - Outputs can have missing values (for partial observations)
    - Each variable is standardized independently
    - Yvar is scaled using the squared scale of the corresponding output
    """
    
    # ========================================================================
    # PROCESS INPUT FEATURES
    # ========================================================================
    
    if input_scalers is None:
        if verbose:
            print("Fitting new input scalers...")
        
        # Work on a copy so the caller's DataFrame is not modified, then coerce
        # inputs to numeric (non-numeric entries become NaN)
        df = df.copy()
        df[input_vars] = df[input_vars].apply(pd.to_numeric, errors='coerce')
        
        # Drop rows with any missing inputs
        df = df.dropna(subset=input_vars).reset_index(drop=True)
        
        # Initialize scaled input array
        input_scaled = np.full_like(df[input_vars].values, np.nan, dtype=np.float64)
        input_scalers = {}

        # Standardize each input column independently
        for j, col in enumerate(input_vars):
            mask = ~np.isnan(df[input_vars].values[:, j])

            if mask.any():
                scaler = StandardScaler()
                input_scaled[mask, j] = scaler.fit_transform(
                    df[input_vars].values[mask, j].reshape(-1, 1)
                ).ravel()
                input_scalers[col] = scaler

                if verbose:
                    print(f"Input {col}: {np.sum(mask)}/{len(mask)} valid observations")
    else:
        if verbose:
            print("Using provided input scalers...")
        
        # Apply existing scalers (entries for missing inputs remain NaN)
        input_scaled = np.full((len(df), len(input_vars)), np.nan, dtype=np.float64)
        for j, col in enumerate(input_vars):
            mask = ~df[col].isna()
            if mask.any():
                scaler = input_scalers[col]
                input_scaled[mask, j] = scaler.transform(
                    df.loc[mask, col].values.reshape(-1, 1)
                ).ravel()
                if verbose:
                    print(f"Input {col}: {np.sum(mask)}/{len(mask)} valid observations")

    # ========================================================================
    # PROCESS OUTPUT PROPERTIES
    # ========================================================================
    
    # Initialize scaled output array
    output_scaled = np.full_like(df[output_vars].values, np.nan, dtype=np.float64)
    output_scalers = {}

    # Standardize each output task independently
    for j, col in enumerate(output_vars):
        mask = ~np.isnan(df[output_vars].values[:, j])

        if mask.any():
            scaler = StandardScaler()
            output_scaled[mask, j] = scaler.fit_transform(
                df[output_vars].values[mask, j].reshape(-1, 1)
            ).ravel()
            output_scalers[col] = scaler

            if verbose:
                print(f"Output {col}: {np.sum(mask)}/{len(mask)} valid observations")

    # ========================================================================
    # PROCESS OBSERVATION NOISE VARIANCES
    # ========================================================================
    
    yvar_scaled = None
    if yvar_cols is not None:
        # Initialize with small default noise
        yvar_scaled = np.full_like(df[output_vars].values, 1e-6, dtype=np.float64)

        for j, (output_col, yvar_col) in enumerate(zip(output_vars, yvar_cols)):
            if yvar_col is not None and yvar_col in df.columns:
                mask = ~np.isnan(df[yvar_col].values)

                if mask.any():
                    # The noise column is treated as a standard deviation: square it to
                    # get a variance, then divide by scale_^2 because standardization
                    # divides the output by scale_ (Var(X / s) = Var(X) / s^2)
                    output_scaler = output_scalers[output_col]
                    yvar_scaled[mask, j] = (df[yvar_col].values[mask] ** 2) / (output_scaler.scale_ ** 2)

                    if verbose:
                        print(f"Yvar for {output_col}: {np.sum(mask)}/{len(mask)} valid observations")

    return input_scaled, output_scaled, yvar_scaled, input_scalers, output_scalers
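

# Hedged usage sketch for prepare_and_standardize_data (never called by the
# pipeline). The column names and values ('W', 'V', 'Hardness', 'Modulus',
# 'Hardness_std') are illustrative placeholders, not from a real dataset; the
# NaNs show how partial observations and per-sample noise columns are passed in.
def _example_prepare_and_standardize():
    toy_df = pd.DataFrame({
        'W': [0.1, 0.2, 0.3, 0.4],
        'V': [0.9, 0.8, 0.7, 0.6],
        'Hardness': [5.1, 6.2, np.nan, 7.3],       # output missing for sample 2
        'Modulus': [210.0, 215.0, 220.0, np.nan],  # output missing for sample 3
        'Hardness_std': [0.1, 0.2, np.nan, 0.1],   # per-sample noise (std dev)
    })
    x_s, y_s, yv_s, x_scalers, y_scalers = prepare_and_standardize_data(
        toy_df,
        input_vars=['W', 'V'],
        output_vars=['Hardness', 'Modulus'],
        yvar_cols=['Hardness_std', None],  # no noise column for 'Modulus'
    )
    # y_s keeps NaN where an output is unobserved; yv_s falls back to 1e-6 there
    return x_s, y_s, yv_s, x_scalers, y_scalers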


# ============================================================================
# TRAINING DATA PREPARATION
# ============================================================================

def prepare_training_pairs(
    input_scaled: np.ndarray,
    output_scaled: np.ndarray,
    yvar_scaled: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    """
    Convert multi-task data to flattened format with task indices.
    
    Transforms data from wide format (one row per sample, multiple output columns)
    to long format (one row per sample-task pair). This format is required for
    multi-task GP models.
    
    Parameters
    ----------
    input_scaled : np.ndarray, shape (n, d)
        Scaled input features
    output_scaled : np.ndarray, shape (n, m)
        Scaled output properties (can contain NaN for missing observations)
    yvar_scaled : np.ndarray, shape (n, m), optional
        Scaled observation noise variances
        
    Returns
    -------
    train_x : np.ndarray, shape (k, d+1)
        Input features with task index appended as last column
        k = number of non-NaN entries in output_scaled
    train_y : np.ndarray, shape (k,)
        Flattened output values
    train_yvar : np.ndarray, shape (k,) or None
        Flattened observation noise variances
        
    Example
    -------
    Input (wide format):
        inputs: [[x1_1, x1_2], [x2_1, x2_2]]
        outputs: [[y1_t0, y1_t1], [NaN, y2_t1]]
    
    Output (long format):
        train_x: [[x1_1, x1_2, 0], [x1_1, x1_2, 1], [x2_1, x2_2, 1]]
        train_y: [y1_t0, y1_t1, y2_t1]
    """
    n_samples, n_tasks = output_scaled.shape
    n_features = input_scaled.shape[1]

    # Count valid (non-NaN) observations
    total_pairs = np.sum(~np.isnan(output_scaled))
    
    # Initialize flattened arrays
    train_x = np.zeros((total_pairs, n_features + 1))
    train_y = np.zeros(total_pairs)
    train_yvar = np.zeros(total_pairs) if yvar_scaled is not None else None

    # Fill arrays with valid observations
    idx = 0
    for i in range(n_samples):
        for task in range(n_tasks):
            if not np.isnan(output_scaled[i, task]):
                train_x[idx, :-1] = input_scaled[i]  # Features
                train_x[idx, -1] = task  # Task index
                train_y[idx] = output_scaled[i, task]
                if yvar_scaled is not None:
                    train_yvar[idx] = yvar_scaled[i, task]
                idx += 1

    return train_x, train_y, train_yvar
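

# Hedged sketch of a vectorized equivalent of prepare_training_pairs, shown only
# to make the wide-to-long mapping explicit. The function name is illustrative
# and nothing in the pipeline calls it; pairs are emitted in the same
# (sample, task) row-major order as the loop above, so the outputs should match.
def _prepare_training_pairs_vectorized(
    input_scaled: np.ndarray,
    output_scaled: np.ndarray,
    yvar_scaled: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
    # Row and column index of every valid (non-NaN) observation
    sample_idx, task_idx = np.where(~np.isnan(output_scaled))

    # Repeat the feature row for each observed task and append the task index
    train_x = np.column_stack([input_scaled[sample_idx], task_idx.astype(float)])
    train_y = output_scaled[sample_idx, task_idx]
    train_yvar = yvar_scaled[sample_idx, task_idx] if yvar_scaled is not None else None
    return train_x, train_y, train_yvar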


def prepare_training_pairs_with_indices(
    input_scaled: np.ndarray,
    output_scaled: np.ndarray,
    yvar_scaled: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], np.ndarray, np.ndarray]:
    """
    Convert multi-task data to flattened format while tracking original indices.
    
    Extended version of prepare_training_pairs that maintains mapping between
    flattened data and original sample/task indices. Useful for evaluation and
    adding back prior values.
    
    Parameters
    ----------
    input_scaled : np.ndarray, shape (n, d)
        Scaled input features
    output_scaled : np.ndarray, shape (n, m)
        Scaled output properties
    yvar_scaled : np.ndarray, shape (n, m), optional
        Scaled observation noise variances
        
    Returns
    -------
    train_x : np.ndarray, shape (k, d+1)
        Input features with task index
    train_y : np.ndarray, shape (k,)
        Flattened output values
    train_yvar : np.ndarray, shape (k,) or None
        Flattened observation noise variances
    original_indices : np.ndarray, shape (k,)
        Index of original sample for each flattened observation
    task_indices : np.ndarray, shape (k,)
        Task index for each flattened observation
    """
    n_samples, n_tasks = output_scaled.shape
    n_features = input_scaled.shape[1]

    total_pairs = np.sum(~np.isnan(output_scaled))

    train_x = np.zeros((total_pairs, n_features + 1))
    train_y = np.zeros(total_pairs)
    train_yvar = np.zeros(total_pairs) if yvar_scaled is not None else None
    original_indices = np.zeros(total_pairs, dtype=int)
    task_indices = np.zeros(total_pairs, dtype=int)

    idx = 0
    for i in range(n_samples):
        for task in range(n_tasks):
            if not np.isnan(output_scaled[i, task]):
                train_x[idx, :-1] = input_scaled[i]
                train_x[idx, -1] = task
                train_y[idx] = output_scaled[i, task]
                if yvar_scaled is not None:
                    train_yvar[idx] = yvar_scaled[i, task]
                original_indices[idx] = i  # Track original sample index
                task_indices[idx] = task  # Track task index
                idx += 1

    return train_x, train_y, train_yvar, original_indices, task_indices
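

# Hedged sketch showing how the index arrays returned by
# prepare_training_pairs_with_indices map flattened rows back to the original
# (sample, task) grid. The toy values are illustrative only; nothing calls this.
def _example_index_tracking():
    inputs = np.array([[0.1, 0.2],
                       [0.3, 0.4]])
    outputs = np.array([[1.0, np.nan],
                        [2.0, 3.0]])
    _, train_y, _, orig_idx, task_idx = prepare_training_pairs_with_indices(inputs, outputs)
    # train_y == [1.0, 2.0, 3.0]; orig_idx == [0, 1, 1]; task_idx == [0, 0, 1],
    # so train_y[k] == outputs[orig_idx[k], task_idx[k]] for every flattened row k
    assert np.allclose(train_y, outputs[orig_idx, task_idx])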


# ============================================================================
# MODEL TRAINING
# ============================================================================

def train_model_wo_validation(
    train_x: np.ndarray,
    train_y: np.ndarray,
    val_x: np.ndarray,
    val_y: np.ndarray,
    num_tasks: int,
    train_yvar: Optional[np.ndarray] = None,
    test_yvar: Optional[np.ndarray] = None,
    num_obj: int = 5,
    num_epochs: int = 3000,
    reduction: int = 8
) -> MultiTaskDeepGP:
    """
    Train a Deep GP model, with optional validation-based early stopping.

    Uses variational inference; early stopping is driven by the validation loss
    when validation data are provided, otherwise by the training loss. Optimizes
    the ELBO (Evidence Lower Bound), which balances data fit against KL
    divergence regularization.
    
    Parameters
    ----------
    train_x : np.ndarray
        Training inputs with task indices
    train_y : np.ndarray
        Training outputs
    val_x : np.ndarray
        Validation inputs (can be empty for no validation)
    val_y : np.ndarray
        Validation outputs (can be empty for no validation)
    num_tasks : int
        Total number of output tasks
    train_yvar : np.ndarray, optional
        Training observation noise variances
    test_yvar : np.ndarray, optional
        Validation observation noise variances (not used in current implementation)
    num_obj : int, default=5
        Number of output tasks to model (can be subset of num_tasks)
    num_epochs : int, default=3000
        Maximum number of training epochs
    reduction : int, default=8
        Hidden layer dimension reduction parameter
        
    Returns
    -------
    MultiTaskDeepGP
        Trained model
        
    Notes
    -----
    Early Stopping Criteria:
    - With validation: stops if validation loss doesn't improve for 50 epochs
    - Without validation: stops if training loss becomes negative or doesn't improve for 50 epochs
    
    Training is performed on GPU if available, otherwise CPU.
    """
    # Determine device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Convert numpy arrays to PyTorch tensors
    train_x = torch.tensor(train_x, dtype=torch.float64).to(device)
    train_y = torch.tensor(train_y, dtype=torch.float64).to(device)
    val_x = torch.tensor(val_x, dtype=torch.float64).to(device)
    val_y = torch.tensor(val_y, dtype=torch.float64).to(device)

    # Convert observation noise if provided
    if train_yvar is not None:
        train_yvar = torch.tensor(train_yvar, dtype=torch.float64).to(device)
    if test_yvar is not None:
        test_yvar = torch.tensor(test_yvar, dtype=torch.float64).to(device)
    
    # Initialize model
    model = MultiTaskDeepGP(
        train_X=train_x,
        train_Y=train_y.unsqueeze(-1),
        task_feature=-1,  # Task index is last column
        output_tasks=list(range(num_obj)),
        train_Yvar=train_yvar.unsqueeze(-1) if train_yvar is not None else None,
        reduction=reduction
    ).to(device)

    model = model.double()

    # Configure optimizer
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},
    ], lr=0.01)

    # Configure marginal log likelihood (loss function)
    # ELBO = log p(y|f) - KL(q(f)||p(f))
    # beta weights the KL term (0.5 = less regularization)
    mll = DeepApproximateMLL(
        VariationalELBO(
            model.likelihood,
            model,
            num_data=train_x.shape[0],
            beta=0.5
        )
    )

    # Training loop variables
    train_losses = []
    val_losses = []
    best_loss = float('inf')
    patience = 50
    patience_counter = 0

    for i in range(num_epochs):
        # ====================================================================
        # TRAINING STEP
        # ====================================================================
        model.train()
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)  # Negative ELBO (we minimize this)
        train_losses.append(loss.item())

        loss.backward()
        optimizer.step()

        # ====================================================================
        # VALIDATION STEP (if validation data provided)
        # ====================================================================
        if val_y.shape[0] != 0:
            model.eval()
            with torch.no_grad():
                val_output = model(val_x)
                val_loss = -mll(val_output, val_y)
                val_losses.append(val_loss.item())

            # Early stopping based on validation loss
            if val_loss.item() < best_loss:
                best_loss = val_loss.item()
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {i+1}")
                break

            if i % 50 == 0:
                print(f'Epoch {i+1}/{num_epochs} - Train Loss: {loss.item():.3f} - Val Loss: {val_loss.item():.3f}')
        else:
            # Early stopping based on training loss (no validation set)
            model.eval()

            if loss.item() < best_loss and loss.item() > 0:
                best_loss = loss.item()
                patience_counter = 0
            elif loss.item() < 0:
                # A negative loss means the ELBO estimate has gone positive; treat it
                # as a signal to stop (guards against overfitting or numerical issues)
                print(f"Stopping due to negative loss at epoch {i+1}")
                break
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {i+1}")
                break

            if i % 50 == 0:
                print(f'Epoch {i+1}/{num_epochs} - Train Loss: {loss.item():.3f}')

    # Optional: attach standard GP for hybrid posterior (currently set to None)
    model.gp_model = None

    return model
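

# Hedged sketch wiring prepare_training_pairs into train_model_wo_validation on
# synthetic, fully observed data. Shapes and hyperparameters (4 features,
# 3 tasks, reduction=1, 200 epochs) are illustrative only; whether the short run
# converges depends on MultiTaskDeepGP in deep_gp_model_definitions. Nothing in
# the pipeline calls this helper.
def _example_train_single_model():
    rng = np.random.default_rng(0)
    input_scaled = rng.standard_normal((30, 4))   # 30 samples, 4 features
    output_scaled = rng.standard_normal((30, 3))  # 3 tasks, fully observed

    train_x, train_y, train_yvar = prepare_training_pairs(input_scaled, output_scaled)

    model = train_model_wo_validation(
        train_x, train_y,
        val_x=np.array([]), val_y=np.array([]),   # no validation monitoring
        num_tasks=output_scaled.shape[1],
        train_yvar=train_yvar,                    # None here (no yvar_scaled)
        num_obj=3,
        num_epochs=200,                           # short run, smoke test only
        reduction=1,
    )
    return model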


# ============================================================================
# MODEL EVALUATION
# ============================================================================

def evaluate_model(
    model: MultiTaskDeepGP,
    train_x: np.ndarray,
    train_y: np.ndarray,
    test_x: np.ndarray,
    test_y: np.ndarray,
    task_names: List[str],
    output_scalers: Dict,
    df_prior: Optional[pd.DataFrame] = None,
    train_orig_idx: Optional[np.ndarray] = None,
    test_orig_idx: Optional[np.ndarray] = None
) -> List[Dict]:
    """
    Evaluate model performance with comprehensive metrics and visualizations.
    
    Computes predictions for train and test sets, descales to original units,
    optionally adds back prior values, and generates parity plots with
    uncertainty quantification.
    
    Parameters
    ----------
    model : MultiTaskDeepGP
        Trained model
    train_x : np.ndarray
        Training inputs with task indices
    train_y : np.ndarray
        Training outputs (scaled)
    test_x : np.ndarray
        Test inputs with task indices (can be empty array if no test set)
    test_y : np.ndarray
        Test outputs (scaled, can be empty array if no test set)
    task_names : List[str]
        Names of output tasks
    output_scalers : Dict
        Scalers for inverse transformation to original units
    df_prior : pd.DataFrame, optional
        Prior values to add back after descaling (for residual learning)
    train_orig_idx : np.ndarray, optional
        Original sample indices for training data (for prior lookup)
    test_orig_idx : np.ndarray, optional
        Original sample indices for test data (for prior lookup)
        
    Returns
    -------
    List[Dict]
        List of metric dictionaries, one per task, containing:
        - task: task name
        - test_mae, test_rmse, test_r2: standard regression metrics
        - test_kendall, test_spearman: rank correlation metrics
        - gmae: geometric mean absolute error
        - smape: symmetric mean absolute percentage error
        - mase: mean absolute scaled error
        - rmspe: root mean squared percentage error
        - n_train_samples, n_test_samples: sample counts
        
    Notes
    -----
    Generates two plots per task:
    1. Parity plot: true vs predicted with 95% confidence intervals
    2. Distribution plot: histograms of true and predicted values
    
    If test_y is empty, uses training data for test metrics (self-evaluation).
    """
    model.eval()
    device = next(model.parameters()).device

    # ========================================================================
    # GET MODEL PREDICTIONS
    # ========================================================================
    with torch.no_grad():
        # Training predictions
        train_posterior = model.posterior(torch.tensor(train_x).to(device))
        train_mean = train_posterior.mean.squeeze().cpu().numpy()
        train_std = train_posterior.variance.sqrt().squeeze().cpu().numpy()
        
        # Test predictions (only if a test set exists)
        if test_y.shape[0] != 0:
            test_posterior = model.posterior(torch.tensor(test_x).to(device))
            test_mean = test_posterior.mean.squeeze().cpu().numpy()
            test_std = test_posterior.variance.sqrt().squeeze().cpu().numpy()
        else:
            print("No test data available - using training data for evaluation")

    # ========================================================================
    # EVALUATE EACH TASK
    # ========================================================================
    metrics = []
    for task_idx, task_name in enumerate(task_names):
        # Get masks for current task
        train_task_mask = (train_x[:, -1] == task_idx)
        if test_y.shape[0] != 0:
            test_task_mask = (test_x[:, -1] == task_idx)

        if np.sum(train_task_mask) > 0:
            # Extract predictions for this task
            train_preds_scaled = train_mean[train_task_mask]
            train_uncertainties_scaled = train_std[train_task_mask]
            train_true_scaled = train_y[train_task_mask]
            
            if test_y.shape[0] != 0:
                test_preds_scaled = test_mean[test_task_mask]
                test_uncertainties_scaled = test_std[test_task_mask]
                test_true_scaled = test_y[test_task_mask]

            # Inverse transform to original scale
            scaler = output_scalers[task_name]
            train_preds = scaler.inverse_transform(train_preds_scaled.reshape(-1, 1)).ravel()
            train_true = scaler.inverse_transform(train_true_scaled.reshape(-1, 1)).ravel()
            train_uncertainties = train_uncertainties_scaled * np.sqrt(scaler.var_)[0]
            
            if test_y.shape[0] != 0:
                test_preds = scaler.inverse_transform(test_preds_scaled.reshape(-1, 1)).ravel()
                test_true = scaler.inverse_transform(test_true_scaled.reshape(-1, 1)).ravel()
                test_uncertainties = test_uncertainties_scaled * np.sqrt(scaler.var_)[0]

            # ================================================================
            # ADD BACK PRIOR VALUES (for residual learning)
            # ================================================================
            if df_prior is not None and task_name in df_prior.columns:
                train_task_indices = np.where(train_task_mask)[0]
                train_data_indices = train_orig_idx[train_task_indices]
                train_prior = df_prior[task_name].iloc[train_data_indices].values

                if test_y.shape[0] != 0:
                    test_task_indices = np.where(test_task_mask)[0]
                    test_data_indices = test_orig_idx[test_task_indices]
                    test_prior = df_prior[task_name].iloc[test_data_indices].values

                print(f"\nTask: {task_name}")
                print(f"Train prior range: {train_prior.min():.2f} to {train_prior.max():.2f}")
                if test_y.shape[0] != 0:
                    print(f"Test prior range: {test_prior.min():.2f} to {test_prior.max():.2f}")

                train_preds += train_prior
                train_true += train_prior

                if test_y.shape[0] != 0:
                    test_preds += test_prior
                    test_true += test_prior

            # ================================================================
            # COMPUTE METRICS
            # ================================================================
            
            # Use test set if available, otherwise use training set
            if test_y.shape[0] != 0:
                eval_preds = test_preds
                eval_true = test_true
                eval_mask = test_task_mask
            else:
                eval_preds = train_preds
                eval_true = train_true
                eval_mask = train_task_mask

            # Standard regression metrics
            test_mae = mean_absolute_error(eval_true, eval_preds)
            test_rmse = np.sqrt(mean_squared_error(eval_true, eval_preds))
            test_r2 = r2_score(eval_true, eval_preds)
            
            # Rank correlation metrics
            test_kendall = kendalltau(eval_true, eval_preds)[0]
            test_spearman = spearmanr(eval_true, eval_preds)[0]

            # Geometric Mean Absolute Error (for multiplicative errors)
            eval_log_safe = np.where(eval_true <= 0, 1e-10, eval_true)
            preds_log_safe = np.where(eval_preds <= 0, 1e-10, eval_preds)
            gmae = np.mean(np.abs(np.log(eval_log_safe) - np.log(preds_log_safe)))

            # Symmetric Mean Absolute Percentage Error
            smape = np.mean(2 * np.abs(eval_preds - eval_true) /
                          (np.abs(eval_preds) + np.abs(eval_true))) * 100

            # Mean Absolute Scaled Error (scaled by naive forecast)
            scaling_factor = np.mean(np.abs(np.diff(eval_true)))
            mase = np.nan if scaling_factor == 0 else test_mae / scaling_factor

            # Root Mean Squared Percentage Error
            rmspe = np.sqrt(np.mean(np.square((eval_true - eval_preds) / eval_true)))

            metrics.append({
                'task': task_name,
                'test_mae': test_mae,
                'test_rmse': test_rmse,
                'test_r2': test_r2,
                'test_kendall': test_kendall,
                'test_spearman': test_spearman,
                'gmae': gmae,
                'smape': smape,
                'mase': mase,
                'rmspe': rmspe,
                'n_train_samples': np.sum(train_task_mask),
                'n_test_samples': np.sum(eval_mask)
            })

            # ================================================================
            # GENERATE VISUALIZATIONS
            # ================================================================
            plt.figure(figsize=(15, 5))

            # Parity plot (True vs Predicted)
            plt.subplot(121)
            
            # Plot training data with error bars (2 sigma = 95% confidence)
            plt.errorbar(train_true, train_preds, yerr=2*train_uncertainties,
                        fmt='o', alpha=0.3, capsize=3, markersize=4,
                        elinewidth=1, label='Training', color='blue')

            # Determine plot range
            if test_y.shape[0] != 0:
                all_true = np.concatenate([train_true, eval_true])
            else:
                all_true = train_true
            
            # Perfect prediction line
            plt.plot([min(all_true), max(all_true)],
                    [min(all_true), max(all_true)],
                    'k--', lw=2, label='Perfect prediction')

            # Plot test data if available
            if test_y.shape[0] != 0:
                plt.errorbar(eval_true, eval_preds, yerr=2*test_uncertainties,
                          fmt='o', alpha=0.3, capsize=3, markersize=4,
                          elinewidth=1, label='Test', color='red')

            plt.xlabel(f'True {task_name}')
            plt.ylabel(f'Predicted {task_name}')
            
            # Add metrics text box
            metrics_text = (f'R² = {test_r2:.3f}\n'
                          f'RMSE = {test_rmse:.3f}\n'
                          f'GMAE = {gmae:.3f}\n'
                          f'SMAPE = {smape:.1f}%\n'
                          f'MASE = {mase:.3f}\n'
                          f'RMSPE = {rmspe:.3f}')

            plt.text(0.98, 0.02, metrics_text,
                    transform=plt.gca().transAxes,
                    fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray'),
                    verticalalignment='bottom',
                    horizontalalignment='right')

            plt.title(f'{task_name}\nTest Spearman: {test_spearman:.3f}')
            plt.legend()

            # Distribution comparison plot
            plt.subplot(122)
            plt.hist(train_true, bins=20, alpha=0.5, label='Train True', color='blue')
            plt.hist(train_preds, bins=20, alpha=0.5, label='Train Predicted', color='lightblue')
            
            if test_y.shape[0] != 0:
                plt.hist(eval_true, bins=20, alpha=0.5, label='Test True', color='red')
                plt.hist(eval_preds, bins=20, alpha=0.5, label='Test Predicted', color='lightcoral')

            plt.xlabel(f'{task_name}')
            plt.ylabel('Frequency')
            plt.legend()

            plt.tight_layout()
            plt.show()

    return metrics
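

# Hedged reference sketch collecting the non-standard error metrics computed
# inline in evaluate_model (GMAE, SMAPE, MASE, RMSPE) into one helper. The name
# and signature are illustrative; the formulas mirror the ones used above.
def _example_error_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    # Geometric mean absolute error: mean |log(true) - log(pred)|, with
    # non-positive values clipped to a small constant before taking logs
    y_true_safe = np.where(y_true <= 0, 1e-10, y_true)
    y_pred_safe = np.where(y_pred <= 0, 1e-10, y_pred)
    gmae = np.mean(np.abs(np.log(y_true_safe) - np.log(y_pred_safe)))

    # Symmetric mean absolute percentage error, in percent
    smape = np.mean(2 * np.abs(y_pred - y_true) /
                    (np.abs(y_pred) + np.abs(y_true))) * 100

    # Mean absolute scaled error: MAE divided by the mean absolute first
    # difference of the true values (NaN if the true values are constant)
    mae = np.mean(np.abs(y_pred - y_true))
    scaling = np.mean(np.abs(np.diff(y_true)))
    mase = np.nan if scaling == 0 else mae / scaling

    # Root mean squared percentage error (undefined if y_true contains zeros)
    rmspe = np.sqrt(np.mean(np.square((y_true - y_pred) / y_true)))

    return {'gmae': gmae, 'smape': smape, 'mase': mase, 'rmspe': rmspe}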


# ============================================================================
# MULTI-MODEL TRAINING AND COMPARISON
# ============================================================================

def train_model(
    input_scaled: np.ndarray,
    output_scaled: np.ndarray,
    input_scalers: Dict,
    output_scalers: Dict,
    input_vars: List[str],
    output_vars: List[str],
    yvar_scaled: Optional[np.ndarray] = None,
    df_prior: Optional[pd.DataFrame] = None,
    num_obj: int = 5,
    reductions: List[int] = [1, 3, 5, 7, 8, 9]
) -> Tuple[Dict, Dict]:
    """
    Train models with different reduction parameters and compare performance.
    
    This function enables hyperparameter search over the reduction parameter,
    which controls the hidden layer dimensionality. Training uses all available
    data (no train-test split).
    
    Parameters
    ----------
    input_scaled : np.ndarray
        Scaled input features
    output_scaled : np.ndarray
        Scaled output properties
    input_scalers : Dict
        Input variable scalers
    output_scalers : Dict
        Output variable scalers
    input_vars : List[str]
        Input variable names
    output_vars : List[str]
        Output variable names
    yvar_scaled : np.ndarray, optional
        Scaled observation noise variances
    df_prior : pd.DataFrame, optional
        Prior values for residual learning
    num_obj : int, default=5
        Number of output objectives to model
    reductions : List[int], default=[1, 3, 5, 7, 8, 9]
        List of reduction parameter values to try
        
    Returns
    -------
    model_results : Dict
        Nested dictionary: {reduction: [metrics_dict_per_task]}
    models : Dict
        Dictionary mapping reduction values to trained models
        
    Notes
    -----
    Hidden layer dimension = num_tasks - reduction
    Larger reduction → smaller hidden layer → faster but less expressive
    Smaller reduction → larger hidden layer → slower but more expressive
    """
    model_results = {}
    models = {}
    
    for reduction in reductions:
        print(f"\n{'='*70}")
        print(f"Training model with reduction={reduction}")
        print(f"Hidden layer dimension: {output_scaled.shape[1] - reduction}")
        print(f"{'='*70}")

        # Prepare training data with index tracking
        train_x, train_y, train_yvar, original_indices, task_indices = \
            prepare_training_pairs_with_indices(input_scaled, output_scaled, yvar_scaled)

        # Use the full dataset for training (no held-out split);
        # empty test arrays disable the validation/test branches downstream
        test_x = np.array([])
        test_y = np.array([])
        
        print(f"Training samples: {train_x.shape[0]}")
        print(f"Number of tasks: {output_scaled.shape[1]}")
        
        # Train model
        model = train_model_wo_validation(
            train_x,
            train_y,
            test_x,
            test_y,
            num_tasks=output_scaled.shape[1],
            train_yvar=train_yvar,
            test_yvar=None,
            num_obj=num_obj,
            num_epochs=3000,
            reduction=reduction
        )

        # Evaluate model performance
        metrics = evaluate_model(
            model,
            train_x,
            train_y,
            test_x,
            test_y,
            output_vars,
            output_scalers,
            df_prior=df_prior,
            train_orig_idx=original_indices,
            test_orig_idx=None
        )

        model_results[reduction] = metrics
        models[reduction] = model

    return model_results, models
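

# Hedged sketch: pick the reduction value whose mean test RMSE across tasks is
# lowest in the model_results dictionary returned by train_model. Illustrative
# helper only; the pipeline does not call it.
def _example_select_best_reduction(model_results: Dict) -> int:
    mean_rmse = {
        reduction: np.mean([m['test_rmse'] for m in task_metrics])
        for reduction, task_metrics in model_results.items()
    }
    return min(mean_rmse, key=mean_rmse.get)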


# ============================================================================
# VISUALIZATION AND REPORTING
# ============================================================================

def plot_comparison(model_results: Dict) -> None:
    """
    Plot comparison of models with different reduction parameters.
    
    Generates two subplots showing mean performance across all tasks:
    1. RMSE vs reduction parameter
    2. Spearman correlation vs reduction parameter
    
    Error bars represent standard deviation across tasks.
    
    Parameters
    ----------
    model_results : Dict
        Nested dictionary from train_model: {reduction: [metrics_dict_per_task]}
    """
    reductions = list(model_results.keys())
    metrics = ['test_rmse', 'test_spearman']

    fig, axes = plt.subplots(2, 1, figsize=(12, 10))

    for idx, metric in enumerate(metrics):
        mean_scores = []
        std_scores = []

        for reduction in reductions:
            # Calculate mean and std across all tasks
            task_scores = [task_metric[metric] for task_metric in model_results[reduction]]
            mean_scores.append(np.mean(task_scores))
            std_scores.append(np.std(task_scores))

        # Plot with error bars
        axes[idx].errorbar(reductions, mean_scores, yerr=std_scores, 
                          fmt='o-', capsize=5, linewidth=2, markersize=8)
        axes[idx].set_xlabel('Reduction Parameter', fontsize=12)
        axes[idx].set_ylabel(f'Mean {metric.upper()}', fontsize=12)
        axes[idx].set_title(f'{metric.upper()} vs Reduction Parameter', fontsize=14)
        axes[idx].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def print_model_comparison(model_results: Dict) -> None:
    """
    Print detailed comparison of models with different reduction parameters.
    
    Displays:
    1. Average metrics across all tasks for each reduction value
    2. Per-task metrics for each reduction value
    
    Parameters
    ----------
    model_results : Dict
        Nested dictionary from train_model: {reduction: [metrics_dict_per_task]}
    """
    print("\n" + "="*80)
    print("DETAILED MODEL COMPARISON")
    print("="*80)

    for reduction, metrics in model_results.items():
        print(f"\n{'-'*80}")
        print(f"REDUCTION PARAMETER: {reduction}")
        print(f"Hidden layer dimension: {metrics[0]['n_train_samples']}")  # Approximate
        print(f"{'-'*80}")

        # Calculate average metrics across all tasks
        avg_test_rmse = np.mean([m['test_rmse'] for m in metrics])
        avg_test_spearman = np.mean([m['test_spearman'] for m in metrics])
        avg_test_r2 = np.mean([m['test_r2'] for m in metrics])

        print(f"\nAverage Performance Across All Tasks:")
        print(f"  RMSE:         {avg_test_rmse:.4f}")
        print(f"  R²:           {avg_test_r2:.4f}")
        print(f"  Spearman ρ:   {avg_test_spearman:.4f}")

        # Print per-task metrics
        print(f"\nPer-Task Performance:")
        for metric in metrics:
            print(f"\n  {metric['task']}:")
            print(f"    RMSE:      {metric['test_rmse']:.4f}")
            print(f"    R²:        {metric['test_r2']:.4f}")
            print(f"    Spearman:  {metric['test_spearman']:.4f}")
            print(f"    MAE:       {metric['test_mae']:.4f}")
            print(f"    Samples:   {metric['n_train_samples']} train, {metric['n_test_samples']} test")

    print("\n" + "="*80)


# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    """
    Example usage demonstrating the complete pipeline.
    
    This example shows how to:
    1. Load and prepare data
    2. Handle prior-guided learning
    3. Train models with different architectures
    4. Compare and visualize results
    """
    
    # Define input and output variables
    input_vars = []

    output_vars = []    
    
    # Example: Load your data
    # df_actual = pd.read_excel('your_data.xlsx')
    
    
    
    # Example: Prepare data with standardization
    # input_scaled, output_scaled, yvar_scaled, input_scalers, output_scalers = \
    #     prepare_and_standardize_data(df_actual, input_vars, output_vars)
    
    # Example: Train models with different reduction parameters
    # model_results, models = train_model(
    #     input_scaled, output_scaled, input_scalers, output_scalers,
    #     input_vars, output_vars, num_obj=10, reductions=[4, 5, 6, 7, 8]
    # )
    
    # Example: Visualize comparison
    # plot_comparison(model_results)
    # print_model_comparison(model_results)
    
    print("Usage example loaded. Uncomment and modify for your specific data.")
