In [None]:
# Part 1: Imports and Model Architecture
import wandb
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
from sklearn.metrics import mean_absolute_error, r2_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedGroupKFold, StratifiedShuffleSplit
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from collections import Counter
import torch.nn.functional as F
from typing import Dict, Tuple, List
from gemelli.preprocessing import matrix_rclr
import json
import os
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.decomposition import PCA

class NormalizedTransformerBlock(nn.Module):
    """Transformer block that keeps token representations L2-normalised.

    The input is projected onto the unit sphere along the feature axis. Each
    sub-layer (self-attention, then an MLP) produces a normalised candidate
    ``h``. The state then moves towards ``h`` by a learnable step and is
    re-normalised: ``x <- normalize(x + alpha * (h - x))``. ``alphaA`` and
    ``alphaM`` are the step sizes for the attention and MLP sub-layers; both
    start at 1.0.

    Args:
        input_dim: token feature size (attention embed dim); must be divisible
            by ``num_heads``.
        hidden_dim: width of the MLP hidden layer.
        num_heads: number of attention heads (default 4).
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_heads: int = 4):
        super(NormalizedTransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, dropout=0, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            # torch.nn has no ``GeLU``; the activation module is ``nn.GELU``.
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
        )
        # Learnable interpolation steps towards the attention / MLP outputs.
        self.alphaA = nn.Parameter(torch.tensor(1.0))
        self.alphaM = nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block.

        Args:
            x: tensor of shape (batch, seq_len, input_dim); the attention layer
                is created with ``batch_first=True``.

        Returns:
            Tensor of the same shape whose last-axis vectors have unit L2 norm.
        """
        # Normalize input
        x = F.normalize(x, p=2, dim=-1)

        # Attention block: normalised self-attention output, then a step towards it.
        hA, _ = self.attention(x, x, x)
        hA = F.normalize(hA, p=2, dim=-1)
        x = F.normalize(x + self.alphaA * (hA - x), p=2, dim=-1)

        # MLP block: same interpolate-and-renormalise update.
        hM = self.mlp(x)
        hM = F.normalize(hM, p=2, dim=-1)
        x = F.normalize(x + self.alphaM * (hM - x), p=2, dim=-1)

        return x

class MTLNormalizedTransformer(nn.Module):
    """Shared normalised-transformer backbone with age-regression and country-classification heads.

    Input features are linearly projected to ``hidden_dim`` and L2-normalised.
    They are then expanded into ``projection_dim`` token "views" (Linear +
    LayerNorm), and each view is L2-normalised. The views pass through
    ``num_layers`` ``NormalizedTransformerBlock`` modules and are mean-pooled.
    The pooled vector feeds both task heads.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: backbone width (also the token feature size).
        num_layers: number of transformer blocks.
        num_countries: number of classification logits.
        projection_dim: number of token views generated per sample.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int, num_countries: int, projection_dim: int = 1):
        super().__init__()
        self.projection_dim = projection_dim
        expanded_width = projection_dim * hidden_dim

        # Shared backbone: input projection, then expansion into the token views.
        self.pca_projection = nn.Linear(input_dim, hidden_dim)
        self.view_generator = nn.Sequential(
            nn.Linear(hidden_dim, expanded_width),
            nn.LayerNorm(expanded_width),
        )

        # Stack of normalised transformer blocks (MLP width = 2 * hidden_dim).
        stacked_blocks = [NormalizedTransformerBlock(hidden_dim, 2 * hidden_dim) for _ in range(num_layers)]
        self.transformer_blocks = nn.ModuleList(stacked_blocks)

        # Task heads: a scalar age regressor and a small MLP producing country logits.
        self.regression_head = nn.Linear(hidden_dim, 1)
        self.classification_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_countries),
        )

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the backbone and both heads.

        Args:
            x: input batch; ``x.shape[0]`` is used as the batch size when the
                expanded features are reshaped to (batch, projection_dim, -1).

        Returns:
            Dict with 'regression_output' of shape (batch, 1) and
            'classification_output' (unnormalised logits, shape (batch, num_countries)).
        """
        n_samples = x.shape[0]

        embedded = F.normalize(self.pca_projection(x), p=2, dim=-1)
        views = self.view_generator(embedded).view(n_samples, self.projection_dim, -1)
        tokens = F.normalize(views, p=2, dim=-1)

        for layer in self.transformer_blocks:
            tokens = layer(tokens)

        # Average over the view axis to get one vector per sample.
        pooled = tokens.mean(dim=1)

        return {
            'regression_output': self.regression_head(pooled),
            'classification_output': self.classification_head(pooled),
        }

def calculate_sparsity(model: nn.Module, threshold: float = 1e-5) -> float:
    """Return the average per-tensor fraction of near-zero weights.

    For every parameter with at least one dimension, compute the fraction of
    entries with ``|w| < threshold``. Return the unweighted mean of those
    fractions, so each tensor counts equally regardless of its size. 0-dim
    (scalar) parameters are skipped. If no parameter qualifies, return 0.0.
    """
    running_total = 0.0
    tensor_count = 0
    for weight in model.parameters():
        if weight.dim() == 0:
            continue
        running_total += (weight.abs() < threshold).float().mean().item()
        tensor_count += 1

    if tensor_count == 0:
        return 0.0
    return running_total / tensor_count

def calculate_weight_entropy(model: nn.Module, epsilon: float = 1e-10) -> Tuple[float, Dict[str, float]]:
    """Compute the Shannon entropy of each parameter's absolute-weight distribution.

    Each non-scalar parameter is flattened. Its absolute values are normalised
    to sum to ~1 (``epsilon`` guards the division and the log), and
    ``-sum(p * log(p + epsilon))`` is taken. 0-dim parameters are skipped.

    Returns:
        ``(total, per_layer)`` where ``per_layer`` maps parameter names to their
        entropy and ``total`` is the sum of those values.
    """
    per_layer: Dict[str, float] = {}
    for param_name, weights in model.named_parameters():
        if weights.dim() == 0:
            continue
        magnitudes = weights.flatten().abs()
        probs = magnitudes / (magnitudes.sum() + epsilon)
        per_layer[param_name] = -(probs * (probs + epsilon).log()).sum().item()

    return sum(per_layer.values(), 0.0), per_layer

class UncertaintyLoss(nn.Module):
    """Combine several task losses using learnable per-task log-variances.

    With one learnable log-variance ``s_i`` per task, the combined loss is
    ``sum_i exp(-s_i) * L_i + sum_i s_i``. Every ``s_i`` starts at 0, so the
    initial weighting is a plain sum. The ``s_i`` only adapt if this module
    persists across steps and its parameters are given to the optimizer.

    Args:
        num_tasks: number of losses ``forward`` expects.
    """

    def __init__(self, num_tasks=2):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        """Return the uncertainty-weighted sum of ``losses``.

        Args:
            losses: sequence of scalar losses (tensors or floats), one per task,
                in the same order as ``log_vars``.

        Raises:
            ValueError: if ``len(losses)`` differs from the number of tasks.
        """
        num_tasks = self.log_vars.numel()
        if len(losses) != num_tasks:
            raise ValueError(f"Expected {num_tasks} losses, got {len(losses)}")
        # Same accumulation order as precision_0*L_0 + precision_1*L_1 + ... .
        weighted = sum(torch.exp(-self.log_vars[i]) * loss for i, loss in enumerate(losses))
        return weighted + self.log_vars.sum()


def load_and_preprocess_data(params: Dict) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Read the abundance table and the sample metadata, keeping only metadata rows
    that have a numeric age and belong to the requested body site.

    Args:
        params: configuration mapping; only ``params['body_site']`` is read.

    Returns:
        ``(table, age_metadata)``: the unfiltered table from 'data/control.csv'
        and the filtered metadata from 'data/sampleMetadata.csv' (indexed by
        'sample_id').

    Raises:
        ValueError: if fewer than 10 metadata rows survive the filters.
    """
    table = pd.read_csv('data/control.csv', index_col=0, low_memory=False)
    metadata = pd.read_csv('data/sampleMetadata.csv', index_col='sample_id', low_memory=False)

    # Non-numeric ages become NaN so the mask below drops them.
    metadata['age'] = pd.to_numeric(metadata['age'], errors='coerce')
    site = params['body_site']
    keep = metadata['age'].notna() & (metadata['body_site'] == site)
    metadata = metadata[keep]

    if len(metadata) < 10:
        raise ValueError(f"Not enough samples for body site {site}")

    ages = metadata['age']
    print(f"Number of samples after filtering: {len(metadata)}")
    print(f"Age range: {ages.min():.1f} - {ages.max():.1f}")
    print(f"Number of countries: {metadata['country'].nunique()}")

    return table, metadata

def prepare_mtl_datasets(table: pd.DataFrame, age_metadata: pd.DataFrame,
                        params: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray, LabelEncoder, np.ndarray, np.ndarray]:
    """Prepare datasets for multi-task learning with stratification by age category and study.

    Steps, in order:
      1. keep rows with a non-NaN age greater than 18 and bin ages into categories;
      2. drop rows without a country and keep the 5 most frequent countries;
      3. align abundance rows (``table`` minus the metadata columns) with the
         metadata by index, and drop samples with any non-numeric/NaN abundance;
      4. undersample every country to the size of the smallest one;
      5. build a "<study_name>_<age_category>_<country>" stratification label.

    Args:
        table: abundance table indexed by sample id; the columns listed in
            ``metadata_cols`` below are excluded from the features.
        age_metadata: per-sample metadata with at least 'age', 'country',
            'study_name' and 'subject_id' columns.
        params: accepted for interface compatibility; not read in this function.

    Returns:
        ``(X, y_age, y_country, label_encoder, stratify_labels, subject_ids)``:
        X is float32 (n_samples, n_features); y_age is float32 (n_samples, 1);
        y_country holds the LabelEncoder codes of the countries;
        stratify_labels and subject_ids are per-sample arrays.

    Raises:
        ValueError: if the 'country' column is missing, no samples are shared
            between ``table`` and ``age_metadata``, or the result is not balanced.
        Any other exception is printed with shape context and re-raised.
    """
    try:
        print("\nInitial data shapes:")
        print(f"Table shape: {table.shape}")
        print(f"Metadata shape: {age_metadata.shape}")

        # Handle age processing
        # NOTE: this also restricts to age > 18, although the message below only mentions NaN removal.
        age_metadata = age_metadata[age_metadata['age'].notna() & (age_metadata.age > 18)].copy()
        print(f"\nSamples after removing NaN ages: {len(age_metadata)}")
        print(f"Age range: {age_metadata['age'].min():.1f} - {age_metadata['age'].max():.1f}")

        # Create age categories
        # right=False gives left-closed bins [18,30), [30,45), [45,60), [60,75), [75,100).
        # NOTE(review): the labels do not match the bin edges (e.g. '31-45' contains age 30),
        # and ages >= 100 fall outside every bin, becoming NaN ('nan' in stratify_label).
        age_metadata['age_category'] = pd.cut(age_metadata['age'],
                                             bins=[18, 30, 45, 60, 75, 100],
                                             labels=['18-30', '31-45', '46-60', '61-75', '76+'],
                                             right=False)

        # Handle country data
        if 'country' not in age_metadata.columns:
            raise ValueError("'country' column not found in metadata")

        # Remove rows with missing country data
        age_metadata = age_metadata[age_metadata['country'].notna()]

        # Get country distribution and select top 5
        country_counts = Counter(age_metadata['country'])
        top_5_countries = [country for country, _ in country_counts.most_common(5)]
        print(f"\nSelected top 5 countries: {', '.join(top_5_countries)}")

        # Filter for top 5 countries
        mask = age_metadata['country'].isin(top_5_countries)
        age_metadata = age_metadata[mask]

        # Get abundance data
        # Non-feature columns that may appear in the abundance table.
        metadata_cols = ['study_name', 'study_condition', 'subject_id', 'body_site']
        abundance_cols = [col for col in table.columns if col not in metadata_cols]
        abundance_data = table[abundance_cols].copy()

        # Find shared indices
        shared_indices = abundance_data.index.intersection(age_metadata.index)
        print(f"\nShared samples before filtering: {len(shared_indices)}")

        if len(shared_indices) == 0:
            raise ValueError("No shared samples between abundance data and metadata")

        # Filter data using shared indices
        abundance_data = abundance_data.loc[shared_indices]
        age_metadata = age_metadata.loc[shared_indices]

        # Handle abundance data
        # Non-numeric cells become NaN; any sample with a NaN feature is dropped from both frames.
        abundance_data = abundance_data.apply(pd.to_numeric, errors='coerce')
        valid_abundance = ~abundance_data.isna().any(axis=1)
        abundance_data = abundance_data[valid_abundance]
        age_metadata = age_metadata.loc[abundance_data.index]

        print(f"Samples after removing NaN abundances: {len(abundance_data)}")

        # Get balanced sample sizes per country
        country_counts = Counter(age_metadata['country'])
        min_samples = min(country_counts.values())
        print(f"\nBalancing to {min_samples} samples per country")

        # Perform balanced sampling
        balanced_indices = []
        # NOTE: reseeds NumPy's *global* RNG, a side effect visible to any later np.random use.
        np.random.seed(42)

        for country in top_5_countries:
            country_indices = age_metadata[age_metadata['country'] == country].index
            # Only false for a top-5 country left with no rows after the filters above;
            # such a country is skipped.
            if len(country_indices) >= min_samples:
                sampled_indices = np.random.choice(country_indices, min_samples, replace=False)
                balanced_indices.extend(sampled_indices)

        # Apply balanced sampling
        age_metadata = age_metadata.loc[balanced_indices]
        abundance_data = abundance_data.loc[balanced_indices]

        # Create stratification labels
        age_metadata['stratify_label'] = (age_metadata['study_name'].astype(str) +
                                        "_" +
                                        age_metadata['age_category'].astype(str) +
                                        "_" +
                                        age_metadata['country'].astype(str))

        # Prepare final datasets
        X = abundance_data.values
        y_age = age_metadata['age'].values.reshape(-1, 1)
        subject_ids = age_metadata['subject_id'].values

        # Create label encoder for countries
        label_encoder = LabelEncoder()
        y_country = label_encoder.fit_transform(age_metadata['country'])

        # Print final dataset statistics
        print("\nFinal balanced dataset statistics:")
        print(f"Number of features: {X.shape[1]}")
        print(f"Number of samples: {X.shape[0]}")
        print(f"Number of countries: {len(label_encoder.classes_)}")
        print("Countries and sample counts:")
        for country in label_encoder.classes_:
            count = (age_metadata['country'] == country).sum()
            print(f"{country}: {count} samples")
        print(f"Age range: {y_age.min():.1f} - {y_age.max():.1f}")
        print(f"Number of stratification classes: {len(np.unique(age_metadata['stratify_label']))}")

        # Verify balance
        # Every encoded country must appear the same number of times.
        unique_counts = np.unique(y_country, return_counts=True)[1]
        if not np.all(unique_counts == unique_counts[0]):
            raise ValueError("Dataset is not properly balanced")

        return (X.astype(np.float32),
                y_age.astype(np.float32),
                y_country,
                label_encoder,
                age_metadata['stratify_label'].values,
                subject_ids)

    except Exception as e:
        # Print context for debugging, then propagate the original exception.
        print(f"Error in prepare_mtl_datasets: {str(e)}")
        print(f"Table shape: {table.shape}")
        print(f"Age metadata shape: {age_metadata.shape}")
        raise

def prepare_country_data(age_metadata: pd.DataFrame, min_samples: int = 100,
                         max_countries: int = 5,
                         random_state: int = 42) -> Tuple[pd.DataFrame, LabelEncoder, np.ndarray]:
    """Keep the most frequent countries and undersample them to equal size.

    Args:
        age_metadata: metadata with a 'country' column.
        min_samples: a country needs at least this many rows to be considered.
        max_countries: at most this many countries (most frequent first) are kept.
        random_state: seed passed to ``DataFrame.sample`` for reproducible undersampling.

    Returns:
        ``(balanced_metadata, label_encoder, country_labels)``. ``balanced_metadata``
        holds ``min_count`` rows per kept country with original index labels preserved.
        ``country_labels`` is the int64 LabelEncoder coding of its 'country' column.

    Raises:
        ValueError: if no country has at least ``min_samples`` rows.
    """
    country_counts = Counter(age_metadata['country'])
    top_countries = [country for country, count in country_counts.most_common()
                     if count >= min_samples][:max_countries]
    if not top_countries:
        # Without this check the min() below fails with an opaque "empty sequence" error.
        raise ValueError(f"No country has at least {min_samples} samples")

    filtered_metadata = age_metadata[age_metadata['country'].isin(top_countries)].copy()

    # Undersample every kept country to the size of the smallest one.
    min_count = int(filtered_metadata['country'].value_counts().min())
    print(f"\nBalancing datasets to {min_count} samples per country")

    balanced_metadata = pd.concat([
        filtered_metadata[filtered_metadata['country'] == country].sample(n=min_count, random_state=random_state)
        for country in top_countries
    ])

    label_encoder = LabelEncoder()
    country_labels = label_encoder.fit_transform(balanced_metadata['country'])

    return balanced_metadata, label_encoder, country_labels.astype(np.int64)

def compute_mtl_loss(outputs: Dict[str, torch.Tensor],
                    age_batch: torch.Tensor,
                    country_batch: torch.Tensor,
                    regression_criterion: nn.Module,
                    classification_criterion: nn.Module,
                    reg_weight: float = 1.0,
                    cls_weight: float = 1.0,
                    uncertainty_loss: nn.Module = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the combined loss for multi-task learning.

    Args:
        outputs: model outputs with 'regression_output' and 'classification_output'.
        age_batch: regression targets.
        country_batch: classification targets.
        regression_criterion: loss applied to the regression output.
        classification_criterion: loss applied to the classification logits.
        reg_weight: multiplier applied to the regression loss before combining.
        cls_weight: multiplier applied to the classification loss before combining.
        uncertainty_loss: optional callable (e.g. a persistent ``UncertaintyLoss``
            whose parameters are in the optimizer) that receives the list of the
            two weighted losses and returns the total. When ``None`` the total is
            the plain sum of the weighted losses. A freshly constructed
            ``UncertaintyLoss`` has zero log-variances and gives exactly this sum,
            because its parameters never get trained.

    Returns:
        ``(total_loss, regression_loss, classification_loss)``; the last two are unweighted.
    """
    regression_loss = regression_criterion(outputs['regression_output'], age_batch)
    classification_loss = classification_criterion(outputs['classification_output'], country_batch)

    weighted_losses = [reg_weight * regression_loss, cls_weight * classification_loss]
    if uncertainty_loss is None:
        total_loss = weighted_losses[0] + weighted_losses[1]
    else:
        total_loss = uncertainty_loss(weighted_losses)

    return total_loss, regression_loss, classification_loss

# Part 3: Visualization Functions
def create_confusion_matrix_plot(true_labels: np.ndarray,
                               pred_labels: np.ndarray,
                               label_encoder: LabelEncoder) -> wandb.Image:
    """Draw the country confusion matrix, save it to 'confusion_matrix.png' and wrap it for W&B.

    NOTE: ``create_confusion_matrix_plot`` is defined again later in this file;
    once the module has run, that later definition replaces this one.

    Args:
        true_labels: true labels, assumed to be the integer codes produced by
            ``label_encoder`` (0..n_classes-1).
        pred_labels: predicted labels in the same encoding.
        label_encoder: fitted encoder whose ``classes_`` give the tick labels.

    Returns:
        ``wandb.Image`` of the saved PNG.
    """
    n_classes = len(label_encoder.classes_)
    fig = plt.figure(figsize=(10, 8))
    try:
        # Fix the label set so every encoder class gets a row/column even if it is
        # absent from this split; otherwise the matrix shrinks and ticks misalign.
        cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_classes))

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=label_encoder.classes_,
                    yticklabels=label_encoder.classes_)

        plt.title('Confusion Matrix for Country Classification')
        plt.xlabel('Predicted Labels')
        plt.ylabel('True Labels')

        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)

        # Adjust layout to prevent label cutoff
        plt.tight_layout()
        plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return wandb.Image('confusion_matrix.png')

# Part 3 Continued: Complete Regression Plot Function
def create_regression_plot(true_ages: np.ndarray,
                         pred_ages: np.ndarray,
                         mae: float,
                         r2: float,
                         mae_std: float = None) -> wandb.Image:
    """Scatter predicted vs. true ages, save to 'regression_plot.png' and wrap it for W&B.

    NOTE: ``create_regression_plot`` is defined again later in this file; once
    the module has run, that later definition replaces this one.

    Args:
        true_ages: true ages; any shape, flattened (e.g. (N, 1) column vectors).
        pred_ages: predicted ages; any shape, flattened.
        mae: MAE shown in the title.
        r2: R² shown in the best-fit legend entry. This is the caller's value,
            not the r-value of the fitted line.
        mae_std: optional spread appended to the title as "± mae_std".

    Returns:
        ``wandb.Image`` of the saved PNG.
    """
    # linregress needs 1-D input; accept lists and column vectors as well.
    true_ages = np.asarray(true_ages).ravel()
    pred_ages = np.asarray(pred_ages).ravel()
    age_min, age_max = true_ages.min(), true_ages.max()

    fig = plt.figure(figsize=(8, 8))
    try:
        plt.scatter(true_ages, pred_ages, alpha=0.3, color='#4169E1',
                    edgecolor='none', s=60, label='Test Predictions')

        # Least-squares line through the predictions.
        fit = stats.linregress(true_ages, pred_ages)
        line_x = np.linspace(age_min, age_max, 100)
        plt.plot(line_x, fit.slope * line_x + fit.intercept, color='#C4161C', linestyle='--',
                 label=f'Best Fit (R² = {r2:.3f})')

        # Identity line: where perfect predictions would lie.
        plt.plot([age_min, age_max],
                 [age_min, age_max],
                 color='black', linestyle='-', alpha=0.3,
                 label='Perfect Prediction')

        plt.xlabel('True Age (years)', fontsize=12, fontweight='bold')
        plt.ylabel('Predicted Age (years)', fontsize=12, fontweight='bold')
        if mae_std is not None:
            title = f'MAE = {mae:.2f} ± {mae_std:.2f} years'
        else:
            title = f'MAE = {mae:.2f} years'
        plt.title(title, fontsize=14, fontweight='bold', pad=15)

        plt.grid(True, linestyle='--', alpha=0.3)
        plt.legend(frameon=True, facecolor='white', framealpha=1,
                   edgecolor='none', loc='upper left')
        plt.axis('equal')

        for spine in plt.gca().spines.values():
            spine.set_linewidth(1.5)

        plt.tight_layout()
        plt.savefig('regression_plot.png', dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    return wandb.Image('regression_plot.png')

def create_training_plots(history: Dict[str, List[float]]) -> Tuple[wandb.Image, wandb.Image]:
    """Plot per-epoch training curves and wrap the saved PNGs for W&B.

    ``history`` must provide the lists 'train_reg_loss', 'val_reg_loss',
    'train_cls_loss', 'val_cls_loss', 'val_mae' and 'val_accuracy'. Losses go
    to 'loss_plot.png' and validation metrics to 'metrics_plot.png'.

    Returns:
        ``(loss_image, metrics_image)`` as ``wandb.Image`` objects.
    """
    def _save_curves(curves, ylabel, title, filename):
        # One figure per file; each (key, label) pair becomes one line.
        plt.figure(figsize=(10, 5))
        for history_key, legend_label in curves:
            plt.plot(history[history_key], label=legend_label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.3)
        plt.tight_layout()
        plt.savefig(filename, dpi=300)
        plt.close()

    _save_curves(
        [('train_reg_loss', 'Train Regression Loss'),
         ('val_reg_loss', 'Val Regression Loss'),
         ('train_cls_loss', 'Train Classification Loss'),
         ('val_cls_loss', 'Val Classification Loss')],
        'Loss', 'Training and Validation Losses', 'loss_plot.png',
    )
    _save_curves(
        [('val_mae', 'Validation MAE'),
         ('val_accuracy', 'Validation Accuracy')],
        'Metric Value', 'Validation Metrics', 'metrics_plot.png',
    )

    return wandb.Image('loss_plot.png'), wandb.Image('metrics_plot.png')

def create_combined_results_figure(confusion_matrix_img: wandb.Image,
                                 regression_plot_img: wandb.Image,
                                 save_path: str = 'combined_results.png') -> wandb.Image:
    """Place the confusion-matrix and regression images side by side and save the result.

    Each image argument is used as a file path when it is a ``str`` or
    ``os.PathLike``. Otherwise (e.g. a ``wandb.Image``), the default files
    'confusion_matrix.png' / 'regression_plot.png' are read instead.

    Args:
        confusion_matrix_img: path (or image object) for the left panel.
        regression_plot_img: path (or image object) for the right panel.
        save_path: where the combined PNG is written.

    Returns:
        ``wandb.Image`` of ``save_path``.
    """
    def _as_path(image, default_path):
        return image if isinstance(image, (str, os.PathLike)) else default_path

    panel_paths = (_as_path(confusion_matrix_img, 'confusion_matrix.png'),
                   _as_path(regression_plot_img, 'regression_plot.png'))

    fig = plt.figure(figsize=(16, 8))
    try:
        for position, image_path in enumerate(panel_paths, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(plt.imread(image_path))
            plt.axis('off')

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if reading or saving failed.
        plt.close(fig)

    return wandb.Image(save_path)

# Part 4: Training Functions
from typing import Dict, Tuple, List, Any

def train_mtl_model(model: nn.Module,
                   dataloaders: Tuple[DataLoader, DataLoader, DataLoader],
                   criteria: Tuple[nn.Module, nn.Module],
                   optimizer: optim.Optimizer,
                   run: Any,
                   num_epochs: int = 20,
                   device: str = 'cuda',
                   scaler_y: StandardScaler = None,
                   label_encoder: LabelEncoder = None,
                   early_stopping_patience: int = 5000) -> Tuple[nn.Module, Dict[str, List[float]], Dict]:
    """Train the multi-task model and restore the weights with the lowest validation loss.

    Each epoch runs one pass over the training loader and one over the
    validation loader. It then steps a CosineAnnealingWarmRestarts scheduler
    (once per epoch) and logs metrics via ``run.log``. A snapshot is taken
    whenever the summed validation loss (regression + classification, summed
    over batches) improves. That snapshot is loaded back before returning.

    Args:
        model: network whose forward returns a dict with 'regression_output'
            and 'classification_output'.
        dataloaders: ``(train, val, test)`` loaders yielding
            ``(x, age, country)`` batches; the test loader is not used here.
        criteria: ``(regression_criterion, classification_criterion)``.
        optimizer: optimizer over the model parameters.
        run: object exposing ``log(dict)`` (e.g. a W&B run).
        num_epochs: maximum number of epochs.
        device: device every batch is moved to.
        scaler_y: if given, ages are inverse-transformed before computing MAE.
        label_encoder: accepted for API compatibility; not used.
        early_stopping_patience: epochs without improvement before stopping.

    Returns:
        ``(model, history, best_metrics)``: the model with the best snapshot
        loaded (if any), the per-epoch history lists, and validation metrics at
        the best epoch.
    """
    import copy  # used only to snapshot the best state_dict

    train_loader, val_loader, _test_loader = dataloaders
    regression_criterion, classification_criterion = criteria

    # Initialize best model tracking
    best_val_loss = float('inf')
    best_model_state = None
    best_epoch = 0
    patience_counter = 0
    best_metrics = {}

    # Initialize history tracking
    history = {
        'train_reg_loss': [], 'train_cls_loss': [],
        'val_reg_loss': [], 'val_cls_loss': [],
        'val_mae': [], 'val_accuracy': []
    }

    # Stepped once per epoch below, so T_0 is measured in epochs.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=100,  # Number of epochs before first restart
        T_mult=2,  # Multiply T_0 by this factor after each restart
        eta_min=1e-6  # Minimum learning rate
    )

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_reg_loss = 0.0
        train_cls_loss = 0.0

        for x_batch, y_age_batch, y_country_batch in train_loader:
            x_batch = x_batch.to(device)
            y_age_batch = y_age_batch.to(device)
            y_country_batch = y_country_batch.to(device)

            optimizer.zero_grad()
            outputs = model(x_batch)

            total_loss, reg_loss, cls_loss = compute_mtl_loss(
                outputs, y_age_batch, y_country_batch,
                regression_criterion, classification_criterion
            )

            total_loss.backward()
            optimizer.step()

            train_reg_loss += reg_loss.item()
            train_cls_loss += cls_loss.item()

        # ---- Validation phase ----
        model.eval()
        val_reg_loss = 0.0
        val_cls_loss = 0.0
        val_preds_age = []
        val_true_age = []
        val_preds_country = []
        val_true_country = []

        with torch.no_grad():
            for x_batch, y_age_batch, y_country_batch in val_loader:
                x_batch = x_batch.to(device)
                y_age_batch = y_age_batch.to(device)
                y_country_batch = y_country_batch.to(device)

                outputs = model(x_batch)

                _, reg_loss, cls_loss = compute_mtl_loss(
                    outputs, y_age_batch, y_country_batch,
                    regression_criterion, classification_criterion
                )

                val_reg_loss += reg_loss.item()
                val_cls_loss += cls_loss.item()

                val_preds_age.extend(outputs['regression_output'].cpu().numpy())
                val_true_age.extend(y_age_batch.cpu().numpy())
                val_preds_country.extend(outputs['classification_output'].argmax(1).cpu().numpy())
                val_true_country.extend(y_country_batch.cpu().numpy())

        # Validation metrics, in original age units when a scaler is supplied.
        val_preds_age = np.array(val_preds_age)
        val_true_age = np.array(val_true_age)
        if scaler_y is not None:
            val_preds_age = scaler_y.inverse_transform(val_preds_age)
            val_true_age = scaler_y.inverse_transform(val_true_age)

        val_mae = mean_absolute_error(val_true_age, val_preds_age)
        val_accuracy = (np.array(val_preds_country) == np.array(val_true_country)).mean()

        # Update history (per-batch averages).
        val_total_loss = val_reg_loss + val_cls_loss
        history['train_reg_loss'].append(train_reg_loss / len(train_loader))
        history['train_cls_loss'].append(train_cls_loss / len(train_loader))
        history['val_reg_loss'].append(val_reg_loss / len(val_loader))
        history['val_cls_loss'].append(val_cls_loss / len(val_loader))
        history['val_mae'].append(val_mae)
        history['val_accuracy'].append(val_accuracy)

        scheduler.step()
        run.log({
            'epoch': epoch,
            'train_regression_loss': history['train_reg_loss'][-1],
            'train_classification_loss': history['train_cls_loss'][-1],
            'val_regression_loss': history['val_reg_loss'][-1],
            'val_classification_loss': history['val_cls_loss'][-1],
            'val_mae': val_mae,
            'val_accuracy': val_accuracy
        })

        # Save best model
        if val_total_loss < best_val_loss:
            best_val_loss = val_total_loss
            # Deep copy: state_dict() holds references to the live parameter tensors,
            # so a shallow dict copy would be overwritten by later optimizer steps.
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch
            patience_counter = 0

            best_metrics = {
                'val_mae': val_mae,
                'val_accuracy': val_accuracy,
                'val_reg_loss': val_reg_loss / len(val_loader),
                'val_cls_loss': val_cls_loss / len(val_loader),
                'epoch': epoch
            }
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            break

    if best_model_state is None:
        # No epoch ran, or the validation loss never compared below +inf (e.g. NaN);
        # keep the current weights instead of crashing in load_state_dict(None).
        print("\nNo best model recorded; keeping the final model weights")
    else:
        print(f"\nLoading best model from epoch {best_epoch}")
        model.load_state_dict(best_model_state)

    return model, history, best_metrics

def evaluate_model(model: nn.Module,
                  test_loader: DataLoader,
                  regression_criterion: nn.Module,
                  classification_criterion: nn.Module,
                  device: str,
                  scaler_y: StandardScaler = None) -> Dict[str, float]:
    """Score the model on a held-out loader without tracking gradients.

    Args:
        model: network returning 'regression_output' and 'classification_output'.
        test_loader: loader yielding ``(x, age, country)`` batches.
        regression_criterion / classification_criterion: per-task losses.
        device: device every batch is moved to.
        scaler_y: if given, ages are inverse-transformed before computing metrics.

    Returns:
        Dict with scalar metrics 'test_mae', 'test_r2', 'test_accuracy',
        'test_reg_loss', 'test_cls_loss' (losses averaged per batch). It also
        holds the raw 'test_preds_age' / 'test_true_age' arrays and the
        'test_preds_country' / 'test_true_country' lists.
    """
    model.eval()
    reg_loss_sum = 0.0
    cls_loss_sum = 0.0
    age_pred_rows, age_true_rows = [], []
    country_pred_list, country_true_list = [], []

    with torch.no_grad():
        for features, age_target, country_target in test_loader:
            features = features.to(device)
            age_target = age_target.to(device)
            country_target = country_target.to(device)

            predictions = model(features)

            _, batch_reg_loss, batch_cls_loss = compute_mtl_loss(
                predictions, age_target, country_target,
                regression_criterion, classification_criterion
            )
            reg_loss_sum += batch_reg_loss.item()
            cls_loss_sum += batch_cls_loss.item()

            age_pred_rows.extend(predictions['regression_output'].cpu().numpy())
            age_true_rows.extend(age_target.cpu().numpy())
            country_pred_list.extend(predictions['classification_output'].argmax(1).cpu().numpy())
            country_true_list.extend(country_target.cpu().numpy())

    age_pred = np.array(age_pred_rows)
    age_true = np.array(age_true_rows)
    if scaler_y is not None:
        # Report errors in original age units.
        age_pred = scaler_y.inverse_transform(age_pred)
        age_true = scaler_y.inverse_transform(age_true)

    n_batches = len(test_loader)
    return {
        'test_mae': mean_absolute_error(age_true, age_pred),
        'test_r2': r2_score(age_true, age_pred),
        'test_accuracy': (np.array(country_pred_list) == np.array(country_true_list)).mean(),
        'test_reg_loss': reg_loss_sum / n_batches,
        'test_cls_loss': cls_loss_sum / n_batches,
        'test_preds_age': age_pred,
        'test_true_age': age_true,
        'test_preds_country': country_pred_list,
        'test_true_country': country_true_list
    }

def calculate_overall_metrics(true_ages, pred_ages, true_countries, pred_countries, fold_results):
    """Compute pooled metrics across folds plus their fold-to-fold spread.

    Args:
        true_ages, pred_ages: true / predicted ages (array-like, any shape; flattened).
        true_countries, pred_countries: true / predicted country codes
            (array-like, any shape; flattened).
        fold_results: sequence of per-fold dicts with 'test_mae', 'test_r2'
            and 'test_accuracy'.

    Returns:
        Dict with 'mae', 'r2', 'accuracy' computed on the pooled predictions, and
        'mae_std', 'r2_std', 'accuracy_std': the ``np.std`` (default ddof=0) of
        the per-fold values.
    """
    try:
        # np.asarray lets plain lists through; ravel flattens (N, 1) columns.
        ages_true = np.asarray(true_ages).ravel()
        ages_pred = np.asarray(pred_ages).ravel()
        countries_true = np.asarray(true_countries).ravel()
        countries_pred = np.asarray(pred_countries).ravel()

        overall_mae = mean_absolute_error(ages_true, ages_pred)
        overall_r2 = r2_score(ages_true, ages_pred)
        overall_accuracy = np.mean(countries_true == countries_pred)

        mae_std = np.std([fold['test_mae'] for fold in fold_results])
        r2_std = np.std([fold['test_r2'] for fold in fold_results])
        accuracy_std = np.std([fold['test_accuracy'] for fold in fold_results])

        return {
            'mae': overall_mae,
            'mae_std': mae_std,
            'r2': overall_r2,
            'r2_std': r2_std,
            'accuracy': overall_accuracy,
            'accuracy_std': accuracy_std
        }
    except Exception as e:
        def _describe(value):
            # Avoid masking the real error when an input has no ``.shape``.
            return getattr(value, 'shape', type(value).__name__)

        print(f"Error in calculate_overall_metrics: {str(e)}")
        print(f"Original shapes:")
        print(f"true_ages: {_describe(true_ages)}, pred_ages: {_describe(pred_ages)}")
        print(f"true_countries: {_describe(true_countries)}, pred_countries: {_describe(pred_countries)}")
        print("Number of folds:", len(fold_results))
        raise

def create_confusion_matrix_plot(true_labels: np.ndarray,
                               pred_labels: np.ndarray,
                               label_encoder: LabelEncoder) -> str:
    """Draw the country confusion matrix with enlarged fonts and save it to disk.

    Args:
        true_labels: true labels, assumed to be the integer codes produced by
            ``label_encoder`` (0..n_classes-1).
        pred_labels: predicted labels in the same encoding.
        label_encoder: fitted encoder whose ``classes_`` give the tick labels.

    Returns:
        The path 'confusion_matrix.png' of the saved figure.
    """
    try:
        n_classes = len(label_encoder.classes_)
        # Scope the larger base font to this figure; rc_context restores the
        # previous rcParams even if plotting fails.
        with plt.rc_context({'font.size': 14}):
            fig = plt.figure(figsize=(12, 10))
            try:
                # Fix the label set so every encoder class gets a row/column even if it
                # is absent from this split; otherwise the matrix shrinks and ticks misalign.
                cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(n_classes))

                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                            xticklabels=label_encoder.classes_,
                            yticklabels=label_encoder.classes_,
                            annot_kws={'size': 16})  # Larger numbers in cells

                plt.title('Confusion Matrix for Country Classification',
                          fontsize=20, pad=20)
                plt.xlabel('Predicted Labels', fontsize=16, labelpad=10)
                plt.ylabel('True Labels', fontsize=16, labelpad=10)

                plt.xticks(rotation=45, ha='right', fontsize=14)
                plt.yticks(rotation=0, fontsize=14)

                # Adjust layout to prevent label cutoff
                plt.tight_layout()
                plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)

        return 'confusion_matrix.png'
    except Exception as e:
        print(f"Error in create_confusion_matrix_plot: {str(e)}")
        print(f"Shapes - true_labels: {np.shape(true_labels)}, pred_labels: {np.shape(pred_labels)}")
        raise

def create_regression_plot(true_ages: np.ndarray,
                         pred_ages: np.ndarray,
                         mae: float,
                         r2: float,
                         mae_std: float = None) -> str:
    """Scatter predicted vs. true ages and save the figure to 'regression_plot.png'.

    Args:
        true_ages: true ages; any shape, flattened (e.g. (N, 1) column vectors).
        pred_ages: predicted ages; any shape, flattened.
        mae: MAE shown in the title.
        r2: R² shown in the best-fit legend entry. This is the caller's value,
            not the r-value of the fitted line.
        mae_std: optional spread appended to the title as "± mae_std".

    Returns:
        The path 'regression_plot.png' (not a wandb.Image).
    """
    try:
        # linregress needs 1-D input; accept lists and column vectors as well.
        true_ages = np.asarray(true_ages).ravel()
        pred_ages = np.asarray(pred_ages).ravel()
        age_min, age_max = true_ages.min(), true_ages.max()

        fig = plt.figure(figsize=(8, 8))
        try:
            plt.scatter(true_ages, pred_ages, alpha=0.3, color='#4169E1',
                        edgecolor='none', s=60, label='Test Predictions')

            # Least-squares line through the predictions.
            fit = stats.linregress(true_ages, pred_ages)
            line_x = np.linspace(age_min, age_max, 100)
            plt.plot(line_x, fit.slope * line_x + fit.intercept, color='#C4161C', linestyle='--',
                     label=f'Best Fit (R² = {r2:.3f})')

            # Identity line: where perfect predictions would lie.
            plt.plot([age_min, age_max],
                     [age_min, age_max],
                     color='black', linestyle='-', alpha=0.3,
                     label='Perfect Prediction')

            plt.xlabel('True Age (years)', fontsize=12, fontweight='bold')
            plt.ylabel('Predicted Age (years)', fontsize=12, fontweight='bold')
            if mae_std is not None:
                title = f'MAE = {mae:.2f} ± {mae_std:.2f} years'
            else:
                title = f'MAE = {mae:.2f} years'
            plt.title(title, fontsize=14, fontweight='bold', pad=15)

            plt.grid(True, linestyle='--', alpha=0.3)
            plt.legend(frameon=True, facecolor='white', framealpha=1,
                       edgecolor='none', loc='upper left')
            plt.axis('equal')

            for spine in plt.gca().spines.values():
                spine.set_linewidth(1.5)

            plt.tight_layout()
            plt.savefig('regression_plot.png', dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

        return 'regression_plot.png'
    except Exception as e:
        print(f"Error in create_regression_plot: {str(e)}")
        print(f"Shapes - true_ages: {np.shape(true_ages)}, pred_ages: {np.shape(pred_ages)}")
        raise

def create_combined_results_figure(confusion_matrix_path: str,
                                 regression_plot_path: str,
                                 save_path: str = 'combined_results.png') -> str:
    """Stitch two saved plot images side by side into one figure.

    The confusion-matrix image goes in the left panel and the regression-plot
    image in the right panel; axes are hidden so only the images show.

    Args:
        confusion_matrix_path: Path of the image shown in the left panel.
        regression_plot_path: Path of the image shown in the right panel.
        save_path: Output path for the combined image (saved at 300 dpi).

    Returns:
        ``save_path``, the path the combined figure was written to.

    Raises:
        Re-raises any error from reading the inputs or writing the output,
        after printing a short diagnostic.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(16, 8))

        # Left: confusion matrix, right: regression plot.
        panels = (confusion_matrix_path, regression_plot_path)
        for position, image_path in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 2, position)
            ax.imshow(plt.imread(image_path))
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

        return save_path
    except Exception as e:
        print(f"Error in create_combined_results_figure: {str(e)}")
        raise
    finally:
        # Release the figure on both the success and error paths; pyplot keeps
        # every open figure alive until it is explicitly closed.
        if fig is not None:
            plt.close(fig)

def create_and_log_visualizations(results, true_ages, pred_ages, true_countries, pred_countries, label_encoder, run):
    """Create the result figures and log them, together with overall metrics, to wandb.

    Args:
        results: Experiment results dict; ``results['overall_metrics']`` must
            contain ``mae`` and ``r2``. ``mae_std`` is optional and is shown in
            the regression-plot title when present.
        true_ages: True age values for the regression plot.
        pred_ages: Predicted age values for the regression plot.
        true_countries: True country labels for the confusion matrix.
        pred_countries: Predicted country labels for the confusion matrix.
        label_encoder: Encoder passed to ``create_confusion_matrix_plot``.
        run: Active wandb run used for logging.

    Raises:
        Re-raises any exception after printing the input shapes.
    """
    try:
        metrics = results['overall_metrics']

        # Create confusion matrix
        cm_path = create_confusion_matrix_plot(
            true_countries, pred_countries, label_encoder
        )

        # Create regression plot; mae_std may be absent, in which case the
        # plot falls back to a title without the "±" term.
        reg_path = create_regression_plot(
            true_ages, pred_ages,
            metrics['mae'],
            metrics['r2'],
            metrics.get('mae_std')
        )

        # Create combined plot
        combined_path = create_combined_results_figure(
            cm_path,
            reg_path
        )

        # Log images and all overall metrics in a single wandb step
        run.log({
            'confusion_matrix': wandb.Image(cm_path),
            'regression_plot': wandb.Image(reg_path),
            'combined_results': wandb.Image(combined_path),
            **metrics
        })
    except Exception as e:
        print(f"Error in create_and_log_visualizations: {str(e)}")
        # np.shape accepts lists/tuples as well as arrays, so this diagnostic
        # cannot itself raise and mask the original exception.
        print(f"Shapes - true_ages: {np.shape(true_ages)}, pred_ages: {np.shape(pred_ages)}")
        print(f"Shapes - true_countries: {np.shape(true_countries)}, pred_countries: {np.shape(pred_countries)}")
        raise

import optuna
from optuna.trial import TrialState
from functools import partial

# Modified run_mtl_experiment function with skip capability
def run_mtl_experiment(params: Dict, n_splits: int = 5, device: str = 'cuda',
                      mae_threshold: float = 20.0, accuracy_threshold: float = 0.5) -> Dict:
    """Run the complete MTL experiment with grouped, stratified cross-validation.

    Within each outer fold, 20% of the training portion is held out
    (stratified by country) as a validation set for ``train_mtl_model``. The
    model it returns is evaluated on the fold's test set; per the comments
    below this is intended to be the best-validation model. Out-of-fold
    predictions are collected so overall metrics are computed once over all
    samples.

    Args:
        params: Experiment configuration. Keys read here: ``body_site``,
            ``batch_size``, ``hidden_dim``, ``num_layers``, ``optimizer`` (an
            optimizer class), ``learning_rate``, ``weight_decay``,
            ``num_epochs`` and, optionally, ``normalize_X`` (default False).
            The whole dict is also passed to ``prepare_mtl_datasets`` and
            logged as the wandb config.
        n_splits: Number of outer CV folds.
        device: Device the model is moved to.
        mae_threshold: A fold whose test MAE exceeds this aborts the run.
        accuracy_threshold: A fold whose test accuracy is below this aborts
            the run.

    Returns:
        Dict with keys ``aborted``, ``fold_results``, ``overall_metrics``
        (left empty when aborted), ``best_validation_metrics`` and
        ``predictions`` (arrays aligned with the rows of ``X``).

    Raises:
        Re-raises any exception after logging it to wandb. The wandb run is
        always finished.
    """
    # Set random seeds
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    run = wandb.init(
        project=f"mtl_WGS_{params['body_site']}_final",
        config=params,
        reinit=True
    )

    try:
        # Load and prepare data
        print("Loading data...")
        table = pd.read_csv('data/control.csv', index_col=0, low_memory=False)
        age_metadata = pd.read_csv('data/sampleMetadata.csv', index_col='sample_id', low_memory=False)
        age_metadata = age_metadata.loc[(age_metadata.age.notna()) & (age_metadata.body_site == params['body_site'])]
        table = table.loc[table.index.isin(age_metadata.index)]
        # Keep one sample per subject
        table = table.drop_duplicates(subset='subject_id', keep='first')
        shared_index = table.index.intersection(age_metadata.index)
        age_metadata = age_metadata.loc[shared_index]

        # Get data with stratification labels
        X, y_age, y_country, label_encoder, stratify_labels, subject_ids = prepare_mtl_datasets(table, age_metadata, params)
        X = matrix_rclr(X)

        # Grouped + stratified outer CV: a subject never spans train and test
        cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)

        # Initialize results structure
        results = {
            'aborted': False,
            'fold_results': [],
            'overall_metrics': {},
            'best_validation_metrics': [],
            'predictions': {
                'true_ages': np.zeros(len(X)),
                'pred_ages': np.zeros(len(X)),
                'true_countries': np.zeros(len(X), dtype=int),
                'pred_countries': np.zeros(len(X), dtype=int),
                'fold_indices': np.zeros(len(X), dtype=int)
            }
        }

        for fold, (train_idx, test_idx) in enumerate(cv.split(X, stratify_labels, groups=subject_ids)):
            print(f"\nFold {fold + 1}/{n_splits}")
            # Create validation split from training data. An explicit
            # random_state makes the split reproducible regardless of how much
            # of the global NumPy RNG was consumed earlier.
            sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
            train_idx_final, val_idx = next(sss.split(X[train_idx], y_country[train_idx]))

            # Map indices back to original data
            train_idx_final = train_idx[train_idx_final]
            val_idx = train_idx[val_idx]

            # Split data
            X_train = X[train_idx_final]
            X_val = X[val_idx]
            X_test = X[test_idx]

            y_age_train = y_age[train_idx_final]
            y_age_val = y_age[val_idx]
            y_age_test = y_age[test_idx]

            y_country_train = y_country[train_idx_final]
            y_country_val = y_country[val_idx]
            y_country_test = y_country[test_idx]

            # Optionally standardize features. The scaler is fit on the
            # training portion only, so validation/test statistics never leak
            # into it. Disabled by default, which matches the previous behavior.
            if params.get('normalize_X', False):
                scaler_X = StandardScaler()
                X_train_scaled = scaler_X.fit_transform(X_train)
                X_val_scaled = scaler_X.transform(X_val)
                X_test_scaled = scaler_X.transform(X_test)
            else:
                X_train_scaled, X_val_scaled, X_test_scaled = X_train, X_val, X_test

            # Scale age values (fit on training targets only)
            scaler_y = MinMaxScaler()
            y_age_train_scaled = scaler_y.fit_transform(y_age_train)
            y_age_val_scaled = scaler_y.transform(y_age_val)
            y_age_test_scaled = scaler_y.transform(y_age_test)

            # Create data loaders
            train_data = TensorDataset(
                torch.FloatTensor(X_train_scaled),
                torch.FloatTensor(y_age_train_scaled),
                torch.LongTensor(y_country_train)
            )
            val_data = TensorDataset(
                torch.FloatTensor(X_val_scaled),
                torch.FloatTensor(y_age_val_scaled),
                torch.LongTensor(y_country_val)
            )
            test_data = TensorDataset(
                torch.FloatTensor(X_test_scaled),
                torch.FloatTensor(y_age_test_scaled),
                torch.LongTensor(y_country_test)
            )

            train_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
            val_loader = DataLoader(val_data, batch_size=params['batch_size'])
            test_loader = DataLoader(test_data, batch_size=params['batch_size'])

            # Initialize model
            model = MTLNormalizedTransformer(
                input_dim=X.shape[1],
                hidden_dim=params['hidden_dim'],
                num_layers=params['num_layers'],
                num_countries=len(label_encoder.classes_)
            ).to(device)

            # Train model
            optimizer = params['optimizer'](
                model.parameters(),
                lr=params['learning_rate'],
                weight_decay=params['weight_decay']
            )
            regression_criterion = nn.MSELoss()
            classification_criterion = nn.CrossEntropyLoss()

            # Train with best validation model tracking
            model, history, best_val_metrics = train_mtl_model(
                model=model,
                dataloaders=(train_loader, val_loader, test_loader),
                criteria=(regression_criterion, classification_criterion),
                optimizer=optimizer,
                run=run,
                num_epochs=params['num_epochs'],
                device=device,
                scaler_y=scaler_y,
                label_encoder=label_encoder,
                early_stopping_patience=5000,
            )

            # Evaluate using best validation model
            test_metrics = evaluate_model(
                model=model,
                test_loader=test_loader,
                regression_criterion=regression_criterion,
                classification_criterion=classification_criterion,
                device=device,
                scaler_y=scaler_y
            )
            # Store results
            fold_results = {
                'fold': fold + 1,
                'test_mae': test_metrics['test_mae'],
                'test_r2': test_metrics['test_r2'],
                'test_accuracy': test_metrics['test_accuracy'],
                'best_validation_epoch': best_val_metrics['epoch'],
                'best_validation_mae': best_val_metrics['val_mae'],
                'best_validation_accuracy': best_val_metrics['val_accuracy'],
                'test_indices': test_idx
            }

            results['fold_results'].append(fold_results)
            results['best_validation_metrics'].append(best_val_metrics)

            # Check if current fold exceeds thresholds
            if (test_metrics['test_mae'] > mae_threshold or
                test_metrics['test_accuracy'] < accuracy_threshold):
                print(f"Fold {fold+1} failed thresholds (MAE > {mae_threshold} or "
                      f"Accuracy < {accuracy_threshold}). Aborting trial.")
                results['aborted'] = True
                break  # Exit fold loop early

            # Store predictions in their correct positions
            results['predictions']['true_ages'][test_idx] = test_metrics['test_true_age'].flatten()
            results['predictions']['pred_ages'][test_idx] = test_metrics['test_preds_age'].flatten()
            results['predictions']['true_countries'][test_idx] = test_metrics['test_true_country']
            results['predictions']['pred_countries'][test_idx] = test_metrics['test_preds_country']
            results['predictions']['fold_indices'][test_idx] = fold + 1

            # Log fold results
            run.log({
                f'fold_{fold+1}_test_mae': test_metrics['test_mae'],
                f'fold_{fold+1}_test_r2': test_metrics['test_r2'],
                f'fold_{fold+1}_test_accuracy': test_metrics['test_accuracy']
            })

            print(f"\nFold {fold + 1} Results:")
            print(f"Test MAE: {test_metrics['test_mae']:.3f}")
            print(f"Test R2: {test_metrics['test_r2']:.3f}")
            print(f"Test Accuracy: {test_metrics['test_accuracy']:.3f}")

        print("\nCalculating overall metrics...")
        print(f"Number of samples: {len(X)}")
        print(f"Number of folds: {len(results['fold_results'])}")

        # Report how many samples each fold contributed (fold ids are 1-based;
        # index 0 counts samples never assigned, e.g. after an abort)
        fold_counts = np.bincount(results['predictions']['fold_indices'].astype(int))[1:]
        print("\nSamples per fold:", fold_counts)

        # Only calculate metrics if trial wasn't aborted
        if not results['aborted']:
            try:
                # Calculate overall metrics
                results['overall_metrics'] = calculate_overall_metrics(
                    results['predictions']['true_ages'],
                    results['predictions']['pred_ages'],
                    results['predictions']['true_countries'],
                    results['predictions']['pred_countries'],
                    results['fold_results']
                )

                print("\nCreating visualizations...")
                # Create visualizations
                create_and_log_visualizations(
                    results,
                    results['predictions']['true_ages'],
                    results['predictions']['pred_ages'],
                    results['predictions']['true_countries'],
                    results['predictions']['pred_countries'],
                    label_encoder,
                    run
                )

                # Print final results
                print("\nFinal Results:")
                print(f"MAE: {results['overall_metrics']['mae']:.2f} ± {results['overall_metrics']['mae_std']:.2f}")
                print(f"R²: {results['overall_metrics']['r2']:.3f} ± {results['overall_metrics']['r2_std']:.3f}")
                print(f"Accuracy: {results['overall_metrics']['accuracy']:.3f} ± {results['overall_metrics']['accuracy_std']:.3f}")

            except Exception as e:
                print(f"Error in metrics calculation: {str(e)}")
                print("Prediction shapes:")
                for key, value in results['predictions'].items():
                    print(f"{key}: {value.shape}")
                raise
        else:
            print("Trial aborted due to failed thresholds. Skipping final metrics.")

    except Exception as e:
        error_msg = f"Experiment failed: {str(e)}"
        print(error_msg)
        run.log({'error': error_msg})
        raise

    finally:
        run.finish()
        import gc
        gc.collect()
        torch.cuda.empty_cache()
    return results

# Modified objective function
def objective(trial: optuna.Trial, device: str = 'cuda') -> float:
    """Optuna objective: run a 10-fold MTL experiment and return a score to minimize.

    Args:
        trial: Optuna trial used to sample hyperparameters and store user attrs.
        device: Device forwarded to ``run_mtl_experiment``.

    Returns:
        ``mae * (1 - accuracy)`` over all out-of-fold predictions, or
        ``1000.0`` when a fold fails the MAE/accuracy thresholds and the
        experiment is aborted.

    Raises:
        optuna.exceptions.TrialPruned: if the experiment raises. The original
            exception is chained as ``__cause__`` and recorded in the trial's
            ``failure_reason`` user attribute, so the study can continue.
    """
    params = {
        'body_site': 'stool',
        'hidden_dim': trial.suggest_categorical('hidden_dim', [256]),
        # Registered with Optuna (even with one choice) so that
        # ``study.best_trial.params`` contains ``num_layers`` like the other
        # architecture settings.
        'num_layers': trial.suggest_categorical('num_layers', [1]),
        'batch_size': trial.suggest_categorical('batch_size', [4096]),
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 7e-1, log=True),
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1, log=True),
        'test_split': 0.2,
        'num_epochs': trial.suggest_categorical('num_epochs', [1000]),
        'normalize_X': False,
        'normalize_y': True,
    }

    # Optimizer selection (only SGD is currently in the search space)
    optimizer_name = trial.suggest_categorical('optimizer', ['SGD'])
    params['optimizer'] = optim.AdamW if optimizer_name == 'AdamW' else optim.SGD

    try:
        results = run_mtl_experiment(
            params,
            n_splits=10,
            device=device,
            mae_threshold=12.0,      # Set your MAE threshold
            accuracy_threshold=0.75   # Set your accuracy threshold
        )
    except Exception as e:
        print(f"Trial failed: {str(e)}")
        # Keep the study running, but preserve why this trial failed.
        trial.set_user_attr('failure_reason', repr(e))
        raise optuna.exceptions.TrialPruned() from e

    if results['aborted']:
        # Return large value to mark bad trial
        trial.set_user_attr('aborted_reason', 'Threshold exceeded')
        return 1000.0  # High value for minimization

    # Calculate combined score (lower is better).
    # NOTE(review): with accuracy == 1.0 this is 0 regardless of MAE; confirm
    # that is the intended trade-off.
    combined_score = results['overall_metrics']['mae'] * (1 - results['overall_metrics']['accuracy'])
    # Log additional metrics to Optuna
    trial.set_user_attr('accuracy', results['overall_metrics']['accuracy'])
    trial.set_user_attr('mae', results['overall_metrics']['mae'])
    trial.set_user_attr('r2', results['overall_metrics']['r2'])

    return combined_score


def run_optimization(n_trials: int = 200, device: str = 'cuda') -> optuna.Study:
    """Run a TPE hyperparameter search over ``objective`` and print a summary.

    Args:
        n_trials: Number of Optuna trials to run.
        device: Device forwarded to ``objective``.

    Returns:
        The finished ``optuna.Study``. If no trial completed, the best-trial
        and importance sections of the summary are skipped.
    """
    study = optuna.create_study(
        direction="minimize",
        study_name="microbiome_mtl_optimization",
        sampler=optuna.samplers.TPESampler(seed=42),
        # NOTE(review): MedianPruner only acts on values passed to
        # trial.report(); objective as written never reports intermediate
        # values, so this pruner currently has no effect.
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=10,
            n_warmup_steps=40,
            interval_steps=20
        )
    )

    study.optimize(
        partial(objective, device=device),
        n_trials=n_trials,
        callbacks=[
            lambda study, trial: print(f"\nTrial {trial.number} finished with score: {trial.value}")
        ]
    )

    completed_trials = study.get_trials(states=[TrialState.COMPLETE])

    # Print optimization results
    print("\nStudy statistics: ")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Number of pruned trials: {len(study.get_trials(states=[TrialState.PRUNED]))}")
    print(f"  Number of complete trials: {len(completed_trials)}")

    # study.best_trial raises ValueError when no trial has completed.
    if not completed_trials:
        print("\nNo completed trials; best trial and parameter importances are unavailable.")
        return study

    print("\nBest trial:")
    best_trial = study.best_trial
    print(f"  Best combined score: {best_trial.value}")
    print("  Best parameters:")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")

    # Calculate parameter importances (best-effort: may fail with few trials)
    try:
        importances = optuna.importance.get_param_importances(study)
        print("\nParameter importances:")
        for param, importance in importances.items():
            print(f"    {param}: {importance:.3f}")
    except Exception as e:
        print(f"Could not compute parameter importances: {e}")

    return study

if __name__ == "__main__":
    def _json_default(obj):
        """Convert NumPy values that ``json`` cannot serialize natively."""
        # Metrics may be NumPy scalars/arrays (e.g. np.float32), which the
        # stdlib encoder rejects; anything else is still an error.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Run optimization
    study = run_optimization(n_trials=3, device=device)

    # Save study results
    study.trials_dataframe().to_csv("optuna_results.csv")

    # study.best_trial raises ValueError if no trial completed.
    if not study.get_trials(states=[TrialState.COMPLETE]):
        print("\nNo completed Optuna trials; skipping the final experiment.")
    else:
        # Get best parameters
        best_params = study.best_trial.params
        final_params = {
            'body_site': 'stool',
            'hidden_dim': best_params['hidden_dim'],
            # num_layers may not be part of the search space; fall back to the
            # value objective() uses.
            'num_layers': best_params.get('num_layers', 1),
            'batch_size': best_params['batch_size'],
            'learning_rate': best_params['learning_rate'],
            'weight_decay': best_params['weight_decay'],
            'test_split': 0.2,
            'num_epochs': 1000,
            'normalize_X': False,
            'normalize_y': True,
            'optimizer': optim.AdamW if best_params['optimizer'] == 'AdamW' else optim.SGD
        }

        # Run final experiment with best parameters
        print("\nRunning final experiment with best parameters...")
        final_results = run_mtl_experiment(final_params, n_splits=5, device=device)
        if final_results['aborted']:
            print("Final experiment was aborted by fold thresholds; final_metrics will be empty.")

        # Save final results
        with open('final_results.json', 'w') as f:
            json.dump({
                'best_parameters': best_params,
                'final_metrics': final_results['overall_metrics']
            }, f, indent=2, default=_json_default)

        print("\nResults have been saved to 'final_results.json'")