In [None]:
import wandb
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import KFold, StratifiedKFold, StratifiedShuffleSplit, StratifiedGroupKFold
from sklearn.ensemble import RandomForestRegressor
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import wandb
from itertools import product
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from gemelli.preprocessing import matrix_rclr
from sklearn.metrics import mean_absolute_error, r2_score
import math
from functools import partial
from biom import load_table
from scipy import stats

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Variant in which each PCA vector is projected into multiple "views" in a higher-dimensional space.
class NormalizedTransformerBlock(nn.Module):
    """Transformer block that keeps token vectors L2-normalized.

    Each sub-layer (self-attention, then MLP) produces an update that is itself
    L2-normalized; the hidden state moves toward that update by a learnable step
    size (``alphaA`` for attention, ``alphaM`` for the MLP) and is re-normalized
    after every step.

    Args:
        input_dim: token embedding size; must be divisible by ``num_heads``
            (``nn.MultiheadAttention`` requirement).
        hidden_dim: width of the MLP's hidden layer.
        num_heads: number of attention heads (default 4, the previously
            hard-coded value).
        dropout: attention dropout probability (default 0.0, as before).
    """

    def __init__(self, input_dim, hidden_dim, num_heads=4, dropout=0.0):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
        )
        # Learnable step sizes for the attention / MLP updates (start at 1.0).
        self.alphaA = nn.Parameter(torch.tensor(1.0))
        self.alphaM = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """Apply one normalized attention step followed by one normalized MLP step.

        Args:
            x: tensor laid out as (batch, seq, input_dim) -- the attention layer
                is built with ``batch_first=True``.

        Returns:
            Tensor of the same shape, L2-normalized along the last dimension.
        """
        x = F.normalize(x, p=2, dim=-1)

        # Attention step: x + alpha * (h - x) interpolates from x toward the
        # normalized update h (alpha == 1 jumps straight to h).
        hA, _ = self.attention(x, x, x)
        hA = F.normalize(hA, p=2, dim=-1)
        x = F.normalize(x + self.alphaA * (hA - x), p=2, dim=-1)

        # MLP step, same interpolate-then-renormalize scheme.
        hM = F.normalize(self.mlp(x), p=2, dim=-1)
        x = F.normalize(x + self.alphaM * (hM - x), p=2, dim=-1)

        return x

class NormalizedTransformer(nn.Module):
    """Regressor that turns one PCA vector into several normalized "views" and
    runs them through a stack of ``NormalizedTransformerBlock`` layers.

    Pipeline: linear projection to ``hidden_dim`` -> L2 normalize -> expand into
    ``projection_dim`` views (Linear + LayerNorm) -> L2 normalize each view ->
    transformer blocks -> mean over views -> linear regression head.

    Args:
        input_dim: size of the incoming PCA vector.
        hidden_dim: per-view embedding size (also the blocks' attention width).
        num_layers: number of transformer blocks.
        output_dim: number of regression targets.
        projection_dim: number of views generated per sample.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim, projection_dim=4):
        super().__init__()
        self.projection_dim = projection_dim
        views_width = projection_dim * hidden_dim

        # Registration order is significant (init RNG consumption, state_dict order).
        self.pca_projection = nn.Linear(input_dim, hidden_dim)
        self.view_generator = nn.Sequential(
            nn.Linear(hidden_dim, views_width),
            nn.LayerNorm(views_width),
        )
        self.transformer_blocks = nn.ModuleList(
            NormalizedTransformerBlock(hidden_dim, hidden_dim * 2) for _ in range(num_layers)
        )
        self.regression_head = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to ``{'regression_output': (batch, output_dim)}``."""
        n_samples = x.size(0)

        embedded = F.normalize(self.pca_projection(x), p=2, dim=-1)

        # (batch, projection_dim * hidden_dim) -> (batch, projection_dim, hidden_dim)
        views = self.view_generator(embedded).view(n_samples, self.projection_dim, -1)
        tokens = F.normalize(views, p=2, dim=-1)

        for layer in self.transformer_blocks:
            tokens = layer(tokens)

        # Average the views back into a single vector per sample.
        pooled = tokens.mean(dim=1)
        return {'regression_output': self.regression_head(pooled)}


In [None]:
def calculate_sparsity(model, threshold=1e-5):
    """
    Average, over parameter tensors, of the fraction of entries whose magnitude
    is below a threshold:

        S = (1/D) * sum_{i=1}^D ( (1/n_i) * sum_{j=1}^{n_i} I(|a_{i,j}| < τ) )

    where:
    - D is the number of non-scalar parameter tensors (each weight and each bias
      counts separately; 0-d parameters such as learnable scalars are skipped)
    - n_i is the number of entries in tensor i
    - a_{i,j} is the j-th entry of tensor i
    - τ (tau) is ``threshold``
    - I() is the indicator function

    Args:
        model: PyTorch model
        threshold: float, magnitude below which an entry is counted as sparse

    Returns:
        float: sparsity score between 0 and 1 (0.0 if the model has no
        non-scalar parameters)
    """
    # Pure measurement: no autograd graph is needed.
    with torch.no_grad():
        per_tensor = [
            (param.abs() < threshold).float().mean().item()
            for param in model.parameters()
            if param.dim() > 0
        ]

    if not per_tensor:
        return 0.0
    return sum(per_tensor) / len(per_tensor)

def calculate_weight_entropy(model, epsilon=1e-10):
    """
    Shannon entropy of the normalized absolute weights, per parameter tensor.

    For every non-scalar parameter tensor W the absolute values are turned into
    a probability distribution p_j = |w_j| / (sum_k |w_k| + ε) and

        H(W) = -sum_j p_j * log(p_j + ε)

    The total is the sum of H(W) over all non-scalar parameter tensors
    (0-d parameters are skipped).

    Args:
        model: PyTorch model
        epsilon: small constant that avoids division by zero and log(0)

    Returns:
        float: total weight entropy (0.0 if there are no non-scalar parameters)
        dict: parameter name -> entropy of that tensor
    """
    def compute_entropy(tensor):
        """Entropy of one tensor's normalized |weights|."""
        # detach(): this is a measurement, so don't build an autograd graph.
        abs_weights = tensor.detach().abs().flatten()
        probs = abs_weights / (abs_weights.sum() + epsilon)
        return -(probs * torch.log(probs + epsilon)).sum().item()

    layer_entropies = {
        name: compute_entropy(param)
        for name, param in model.named_parameters()
        if param.dim() > 0
    }
    total_entropy = sum(layer_entropies.values(), 0.0)

    return total_entropy, layer_entropies

In [None]:
def _evaluate_loader(model, loader, criterion, device):
    """Run ``model`` over ``loader`` without gradients.

    Returns:
        (float, np.ndarray, np.ndarray): summed per-batch loss, stacked targets,
        stacked predictions (all on the model's training scale).
    """
    summed_loss = 0.0
    targets, predictions = [], []
    with torch.no_grad():
        for x_batch, y_batch in loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            prediction = model(x_batch)['regression_output']
            summed_loss += criterion(prediction, y_batch).item()
            targets.append(y_batch.cpu().numpy())
            predictions.append(prediction.cpu().numpy())
    return summed_loss, np.concatenate(targets), np.concatenate(predictions)


def train_with_test_loss(model, dataloaders, criterion, optimizer, run, num_epochs=20, device='cuda', scaler_y=None):
    """Train ``model`` and log train/val/test metrics to ``run`` after every epoch.

    The weights with the lowest validation MAE are kept and loaded back into
    ``model`` when training ends. Test metrics are logged for monitoring only;
    they never influence which weights are kept.

    Args:
        model: module whose forward returns ``{'regression_output': tensor}``.
        dataloaders: ``(train_loader, val_loader, test_loader)``.
        criterion: loss function applied to (prediction, target).
        optimizer: optimizer over ``model``'s parameters.
        run: wandb run used for ``run.log``.
        num_epochs: number of training epochs.
        device: device the batches are moved to.
        scaler_y: optional fitted target scaler; if given, MAE/R² are computed
            after ``inverse_transform`` (original units).

    Returns:
        float: best validation MAE seen (``inf`` if it never improved).

    Raises:
        RuntimeError: if a training batch produces a NaN/non-finite loss
            (the message contains "NaN").
    """
    train_loader, val_loader, test_loader = dataloaders

    # Cosine annealing with warm restarts, stepped once per *epoch* (below):
    # T_0=500 -> first restart after 500 epochs; T_mult=1 -> every cycle lasts
    # 500 epochs; the learning rate never drops below eta_min.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=500,
        T_mult=1,
        eta_min=0.0005
    )

    best_val_mae = float('inf')
    best_model_state = None

    for epoch in range(num_epochs):
        # --- Training phase ---
        model.train()
        train_loss = 0.0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(x_batch)
            loss = criterion(outputs['regression_output'], y_batch)
            if not torch.isfinite(loss):
                # Fail fast instead of silently training on NaN weights.
                raise RuntimeError(f"NaN or non-finite training loss at epoch {epoch}")
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Epoch-level scheduler step.
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        # --- Validation and test evaluation ---
        model.eval()
        val_loss, y_true_val, y_pred_val = _evaluate_loader(model, val_loader, criterion, device)
        test_loss, y_true_test, y_pred_test = _evaluate_loader(model, test_loader, criterion, device)

        if scaler_y is not None:
            y_true_val = scaler_y.inverse_transform(y_true_val)
            y_pred_val = scaler_y.inverse_transform(y_pred_val)
            y_true_test = scaler_y.inverse_transform(y_true_test)
            y_pred_test = scaler_y.inverse_transform(y_pred_test)

        val_mae = mean_absolute_error(y_true_val, y_pred_val)
        val_r2 = r2_score(y_true_val, y_pred_val)
        test_mae = mean_absolute_error(y_true_test, y_pred_test)
        test_r2 = r2_score(y_true_test, y_pred_test)

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            # state_dict() holds references to the live parameter tensors; clone
            # them so later optimizer steps cannot overwrite the snapshot.
            best_model_state = {
                name: tensor.detach().clone() for name, tensor in model.state_dict().items()
            }

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        test_loss /= len(test_loader)

        sparsity = calculate_sparsity(model)
        abs_weight_entropy, _ = calculate_weight_entropy(model)

        run.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'test_loss': test_loss,
            'val_mae_original_scale': val_mae,
            'val_r2_original_scale': val_r2,
            'test_mae_original_scale': test_mae,
            'test_r2_original_scale': test_r2,
            'sparsity': sparsity,
            'absolute_weight_entropy': abs_weight_entropy,
            'learning_rate': current_lr
        })

    # Restore the best-validation weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return best_val_mae

In [None]:
def _load_wgs_age_data(body_site):
    """Load the WGS abundance table and age metadata for one body site.

    Reads ``control.csv`` (abundances plus study_name / study_condition /
    subject_id columns) and ``sampleMetadata.csv``; keeps samples with a known
    age at ``body_site``, keeps the first sample of each subject, drops all-zero
    features and scales abundances to integer pseudo-counts (x1e7, rounded).

    Returns:
        (pd.DataFrame, pd.DataFrame): integer count table and row-aligned metadata.
    """
    table = pd.read_csv('control.csv', index_col=0)
    age_metadata = pd.read_csv('sampleMetadata.csv', index_col='sample_id', dtype={'age': float})
    age_metadata = age_metadata.loc[(age_metadata.age.notna()) & (age_metadata.body_site == body_site)]
    table = table.loc[table.index.isin(age_metadata.index)]
    table = table.drop_duplicates(subset='subject_id', keep='first')
    shared_index = table.index.intersection(age_metadata.index)
    table = table.loc[shared_index].drop(columns=['study_name', 'study_condition', 'subject_id'])
    age_metadata = age_metadata.loc[shared_index]

    # Remove features that are zero in every sample.
    table = table.loc[:, ~(table == 0).all(axis=0)]
    counts = (table * 1e7).round().astype(int)
    return counts, age_metadata


def _make_cv_strata(age_metadata, n_bins=5):
    """Per-sample stratification label: age quantile bin + study_name + country."""
    age_bins = pd.qcut(
        age_metadata.age, q=n_bins, labels=[f'age_bin_{i}' for i in range(n_bins)]
    ).astype(str)
    return age_bins + age_metadata.study_name.astype(str) + age_metadata.country.astype(str)


def _fit_fold_features(arr, y, train_rows, val_rows, test_rows, params):
    """Fit PCA and the optional X/y scalers on the training rows only, then transform every split.

    Fitting on training rows keeps validation/test samples out of the learned
    preprocessing, so the CV estimate is not inflated by leakage.

    Returns:
        tuple: (X_train, X_val, X_test, y_train, y_val, y_test, scaler_y, pca);
        ``scaler_y`` is None when ``params['normalize_y']`` is False.
    """
    # PCA cannot have more components than training samples or features.
    n_components = min(params.get('pca_components', 256), len(train_rows), arr.shape[1])
    pca = PCA(n_components=n_components)
    X_train = pca.fit_transform(arr[train_rows])
    X_val = pca.transform(arr[val_rows])
    X_test = pca.transform(arr[test_rows])

    if params.get('normalize_X', True):
        scaler_X = StandardScaler()
        X_train = scaler_X.fit_transform(X_train)
        X_val = scaler_X.transform(X_val)
        X_test = scaler_X.transform(X_test)

    y_train, y_val, y_test = y[train_rows], y[val_rows], y[test_rows]
    scaler_y = None
    if params.get('normalize_y', True):
        scaler_y = MinMaxScaler()
        y_train = scaler_y.fit_transform(y_train)
        y_val = scaler_y.transform(y_val)
        y_test = scaler_y.transform(y_test)

    return X_train, X_val, X_test, y_train, y_val, y_test, scaler_y, pca


def _xavier_init_linear(module):
    """Xavier-uniform weights and zero bias for every nn.Linear (use via ``model.apply``)."""
    if isinstance(module, nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)


def _predict_on_loader(model, loader, scaler_y, device):
    """Predict every batch of ``loader`` (in loader order).

    Returns:
        (np.ndarray, np.ndarray): flat true and predicted values on the original
        target scale (inverse-transformed when ``scaler_y`` is given).

    Raises:
        RuntimeError: if the model emits NaN predictions.
    """
    model.eval()
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for x_batch, y_batch in loader:
            reg_outputs = model(x_batch.to(device))['regression_output'].cpu().numpy()
            if np.any(np.isnan(reg_outputs)):
                raise RuntimeError("NaN values detected in model predictions")
            y_true_batch = y_batch.numpy()
            if scaler_y is not None:
                y_true_batch = scaler_y.inverse_transform(y_true_batch)
                reg_outputs = scaler_y.inverse_transform(reg_outputs)
            true_chunks.append(y_true_batch.ravel())
            pred_chunks.append(reg_outputs.ravel())
    return np.concatenate(true_chunks), np.concatenate(pred_chunks)


def _log_regression_plot(run, predictions_df, overall_mae, mae_std, path="final_regression_plot.png"):
    """Save a true-vs-predicted age scatter (best-fit and y=x lines) to ``path`` and log it to ``run``."""
    true_age = predictions_df['true_age']
    pred_age = predictions_df['predicted_age']
    lo, hi = min(true_age), max(true_age)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(true_age, pred_age, alpha=0.3, color='#4169E1',
               edgecolor='none', s=60, label='Test Predictions')

    fit = stats.linregress(true_age, pred_age)
    line_x = np.linspace(lo, hi, 100)
    ax.plot(line_x, fit.slope * line_x + fit.intercept, color='#C4161C', linestyle='--',
            label=f'Best Fit (R² = {fit.rvalue**2:.3f})')
    ax.plot([lo, hi], [lo, hi], color='black', linestyle='-', alpha=0.3, label='Perfect Prediction')

    ax.set_xlabel("True Age (years)", fontsize=12, fontweight='bold')
    ax.set_ylabel("Predicted Age (years)", fontsize=12, fontweight='bold')
    ax.set_title(f"MAE = {overall_mae:.2f} ± {mae_std:.2f} years",
                 fontsize=14, fontweight='bold', pad=15)
    ax.grid(True, linestyle='--', alpha=0.3)
    ax.legend(frameon=True, facecolor='white', framealpha=1,
              edgecolor='none', loc='upper left')
    ax.axis('equal')
    fig.tight_layout()
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)

    fig.savefig(path, dpi=300, bbox_inches='tight')
    run.log({"final_regression_plot": wandb.Image(path)})
    plt.close(fig)


def run_cv_experiment(params, n_splits=5, device='cuda'):
    """
    Run cross-validation experiment with transformer model and save indexed predictions.

    Uses StratifiedGroupKFold (strata = age bin + study + country, groups =
    subject_id). Inside each fold a stratified validation split is carved out
    of the training part, and PCA / scalers are fit on the training rows only.
    Per-fold and overall metrics, a predictions table (also written to
    ``predictions.csv``) and a regression plot are logged to wandb.

    Args:
        params (dict): Model and training parameters. Required keys:
            body_site, test_split, batch_size, num_layers, hidden_dim,
            optimizer (class), learning_rate, weight_decay, num_epochs.
            Optional: normalize_X (default True), normalize_y (default True),
            pca_components (default 256), min_stratum_size (default 10).
        n_splits (int): Number of CV splits
        device (str): Computing device ('cuda' or 'cpu')

    Returns:
        dict: ``overall_mae``, ``overall_r2`` and ``predictions_df``; each is
        None if no fold completed or the experiment failed before computing it.
    """
    run = wandb.init(
        project=f"wgs_single_reviewer_{params['body_site']}",
        config=params,
        reinit=True
    )

    overall_mae = None
    overall_r2 = None
    predictions_df = None

    try:
        df, age_metadata = _load_wgs_age_data(params['body_site'])

        # NOTE(review): rclr is applied to all samples before splitting; this
        # assumes matrix_rclr transforms each sample independently -- confirm,
        # otherwise move it inside the fold loop as well.
        arr = np.nan_to_num(matrix_rclr(df.values), nan=0.0)
        if arr.ndim > 2:
            arr = arr.reshape(arr.shape[0], -1)
        print(f"Original dimensions: {arr.shape}")

        y = age_metadata.age.values.reshape(-1, 1).astype(np.float32)

        strata = _make_cv_strata(age_metadata)
        groups = age_metadata.subject_id.astype(str)

        # Drop samples from rare strata; a plain numpy mask indexes arrays,
        # Series and DataFrames alike.
        min_count = params.get('min_stratum_size', 10)
        strata_counts = strata.value_counts()
        valid_strata = strata_counts[strata_counts >= min_count].index
        mask = strata.isin(valid_strata).to_numpy()
        n_before = len(strata)

        arr, y = arr[mask], y[mask]
        age_metadata, strata, groups = age_metadata[mask], strata[mask], groups[mask]

        print(f"Original samples: {n_before}")
        print(f"Samples after filtering strata with <{min_count} occurrences: {len(strata)}")
        print(f"Removed {n_before - len(strata)} samples")

        kf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)

        predictions_dict = {
            'sample_id': [],
            'true_age': [],
            'predicted_age': [],
            'fold': []
        }
        fold_results = []

        for fold, (train_index, test_index) in enumerate(kf.split(arr, y=strata, groups=groups), 1):
            try:
                # Validation split inside the training part (positions -> global rows).
                sss = StratifiedShuffleSplit(n_splits=1, test_size=params['test_split'], random_state=42)
                train_sub, val_sub = next(sss.split(np.zeros((len(train_index), 1)), y=strata.iloc[train_index]))
                train_rows, val_rows = train_index[train_sub], train_index[val_sub]

                (X_train, X_val, X_test,
                 y_train, y_val, y_test, scaler_y, pca) = _fit_fold_features(
                    arr, y, train_rows, val_rows, test_index, params
                )
                print(f"Fold {fold}: reduced dimensions {X_train.shape}, "
                      f"explained variance ratio {pca.explained_variance_ratio_.sum():.3f}")

                train_data = TensorDataset(torch.tensor(X_train).float(), torch.tensor(y_train).float())
                val_data = TensorDataset(torch.tensor(X_val).float(), torch.tensor(y_val).float())
                test_data = TensorDataset(torch.tensor(X_test).float(), torch.tensor(y_test).float())

                train_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
                val_loader = DataLoader(val_data, batch_size=params['batch_size'])
                # Not shuffled: prediction order must match test_index.
                test_loader = DataLoader(test_data, batch_size=params['batch_size'])

                model = NormalizedTransformer(
                    input_dim=X_train.shape[1],
                    num_layers=params['num_layers'],
                    hidden_dim=params['hidden_dim'],
                    output_dim=1,
                ).to(device)
                model.apply(_xavier_init_linear)

                optimizer = params['optimizer'](
                    model.parameters(),
                    lr=params['learning_rate'],
                    weight_decay=params['weight_decay']
                )
                criterion = nn.MSELoss()

                # Training phase: NaN failures are recorded and the fold skipped;
                # any other error propagates to the fold-level handler below.
                try:
                    train_with_test_loss(
                        model,
                        (train_loader, val_loader, test_loader),
                        criterion,
                        optimizer,
                        run,
                        num_epochs=params['num_epochs'],
                        device=device,
                        scaler_y=scaler_y
                    )
                except (RuntimeError, ValueError) as e:
                    if "nan" not in str(e).lower():
                        raise
                    run.log({
                        f'fold_{fold}_error': f'NaN loss detected during training: {str(e)}',
                        f'fold_{fold}_status': 'failed_nan_loss'
                    })
                    print(f"Fold {fold} failed due to NaN loss. Skipping to next fold.")
                    continue

                # Evaluation phase.
                try:
                    fold_true, fold_pred = _predict_on_loader(model, test_loader, scaler_y, device)
                    mae = mean_absolute_error(fold_true, fold_pred)
                    r2 = r2_score(fold_true, fold_pred)

                    # Record predictions only after the whole fold succeeded, so
                    # a failed evaluation leaves no partial rows behind.
                    predictions_dict['sample_id'].extend(age_metadata.index[test_index])
                    predictions_dict['true_age'].extend(fold_true)
                    predictions_dict['predicted_age'].extend(fold_pred)
                    predictions_dict['fold'].extend([fold] * len(test_index))

                    fold_results.append({
                        'fold': fold,
                        'mae': mae,
                        'r2': r2
                    })
                    run.log({
                        f'fold_{fold}_mae': mae,
                        f'fold_{fold}_r2': r2,
                        f'fold_{fold}_status': 'completed'
                    })

                except Exception as e:
                    run.log({
                        f'fold_{fold}_error': f'Error during evaluation: {str(e)}',
                        f'fold_{fold}_status': 'failed_evaluation'
                    })
                    print(f"Error during evaluation of fold {fold}: {str(e)}")

            except Exception as e:
                run.log({
                    f'fold_{fold}_error': f'Fold processing error: {str(e)}',
                    f'fold_{fold}_status': 'failed_processing'
                })
                print(f"Error processing fold {fold}: {str(e)}")

        if fold_results:
            fold_maes = [r['mae'] for r in fold_results]
            fold_r2s = [r['r2'] for r in fold_results]
            overall_mae = np.mean(fold_maes)
            overall_r2 = np.mean(fold_r2s)
            mae_std = np.std(fold_maes)
            r2_std = np.std(fold_r2s)

            predictions_df = pd.DataFrame(predictions_dict)
            predictions_df.to_csv('predictions.csv')
            run.log({
                "predictions_table": wandb.Table(dataframe=predictions_df),
                'overall_mae': overall_mae,
                'overall_r2': overall_r2,
                'mae_std': mae_std,
                'r2_std': r2_std
            })

            _log_regression_plot(run, predictions_df, overall_mae, mae_std)

    except Exception as e:
        run.log({
            'experiment_error': str(e),
            'experiment_status': 'failed'
        })
        print(f"Fatal error in experiment: {str(e)}")

    finally:
        run.finish()

    return {
        'overall_mae': overall_mae,
        'overall_r2': overall_r2,
        'predictions_df': predictions_df
    }

In [None]:
if __name__ == "__main__":
    # Hyper-parameter grid; every combination is run as its own CV experiment.
    SEED = 42
    body_sites = ['skin']
    num_layers = [1]
    hidden_dims = [512]
    batch_sizes = [4096]
    learning_rates = [0.001]
    weight_decays = [0.001]
    test_splits = [0.2]
    optimizers = [optim.AdamW]
    n_splits = 10

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # itertools.product yields combinations in the same order as the equivalent
    # nested for-loops (last list varies fastest).
    grid = product(body_sites, num_layers, hidden_dims, batch_sizes,
                   learning_rates, weight_decays, test_splits, optimizers)
    for body_site, num_layer, hidden_dim, batch_size, lr, wd, test_split, optum in grid:
        print(f"\nTrying parameters: hidden_dim={hidden_dim}, batch_size={batch_size}, "
              f"learning_rate={lr}, weight_decay={wd}")

        current_params = {
            'hidden_dim': hidden_dim,
            'batch_size': batch_size,
            'num_layers': num_layer,
            'learning_rate': lr,
            'weight_decay': wd,
            'num_epochs': 1000,
            'optimizer': optum,
            'body_site': body_site,
            'test_split': test_split,
            'normalize_X': False,
            'normalize_y': False,
            'seed': SEED,
        }

        # Re-seed before every configuration so weight init and DataLoader
        # shuffling are reproducible per run, independent of earlier runs.
        np.random.seed(SEED)
        torch.manual_seed(SEED)

        run_cv_experiment(current_params, n_splits=n_splits, device=device)