# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import transformers
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             roc_auc_score, confusion_matrix, ConfusionMatrixDisplay,
                             cohen_kappa_score)
import matplotlib.pyplot as plt
import copy
import gc
import time

def _pick_device():
    """Choose the compute device, preferring MPS, then CUDA, then CPU.

    Prints which backend was selected and returns the matching
    ``torch.device``. CUDA availability is only queried when MPS is not
    available.
    """
    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        print("Using MPS (Apple Silicon GPU)")
        return chosen
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        print("Using CUDA")
        return chosen
    chosen = torch.device("cpu")
    print("Using CPU")
    return chosen


# Module-level device used by every model and batch transfer below.
device = _pick_device()

# Hugging Face checkpoint names of the ESM-2 backbones that are benchmarked,
# listed from the smallest to the largest model.
ESM2_MODEL_NAMES = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t30_150M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]


def drop_incomplete_rows(frame, required_columns=('VH', 'VL', 'psr')):
    """Return a re-indexed copy of ``frame`` without incomplete rows.

    A row is removed when any of ``required_columns`` is missing (NaN/None).
    The input frame is not modified. The number of removed rows is printed
    when it is non-zero.
    """
    cleaned = frame.dropna(subset=list(required_columns)).reset_index(drop=True)
    n_dropped = len(frame) - len(cleaned)
    if n_dropped:
        print(f"Dropped {n_dropped} row(s) with missing values in {list(required_columns)}.")
    return cleaned


try:
    df1 = pd.read_excel('pnas.1616408114.sd01.xlsx')
    df2 = pd.read_excel('pnas.1616408114.sd02.xlsx')
    df3 = pd.read_excel('pnas.1616408114.sd03.xlsx')
    # Outer joins keep every antibody name found in any sheet, so rows that
    # are absent from one sheet come back with NaN in that sheet's columns.
    merged_df = df1.merge(df2, on='Name', how='outer').merge(df3, on='Name', how='outer')
    df = merged_df[['VH', 'VL', 'Poly-Specificity Reagent (PSR) SMP Score (0-1)']].copy()
    df = df.rename(columns={'Poly-Specificity Reagent (PSR) SMP Score (0-1)': 'psr'})
except FileNotFoundError:
    print("Warning: One or more Excel files not found. Using a small synthetic placeholder DataFrame.")
    print("Please ensure 'pnas.1616408114.sd01.xlsx', 'pnas.1616408114.sd02.xlsx', 'pnas.1616408114.sd03.xlsx' are available.")
    data = {'VH': ['SEQVHONE', 'SEQVHTWO', 'SEQVHTHREE'] * 10,
            'VL': ['SEQVLONE', 'SEQVLTWO', 'SEQVLTHREE'] * 10,
            'psr': np.random.rand(30) * 0.5
           }
    df = pd.DataFrame(data)

# Without this, a missing sequence would be stringified to the literal 'nan'
# by the dataset, and a missing PSR score would fail both threshold
# comparisons and silently land in the highest class.
df = drop_incomplete_rows(df)


# Number of ordinal PSR classes produced by the score-to-label mapping.
NUM_CLASSES = 3

def psr_to_label(psr_value):
    """Bucket a PSR score into an ordinal class index.

    Returns 0 for scores below 0.10, 1 for scores from 0.10 through 0.33
    (both inclusive), and 2 for everything else. A value that fails both
    comparisons, such as NaN, therefore also maps to 2.
    """
    if psr_value < 0.10:
        return 0
    if psr_value <= 0.33:
        return 1
    return 2

# Derive the class label for every row from its PSR score and preview the data.
df['label'] = df['psr'].apply(psr_to_label)
print(df.head())

# Per-class sample counts, ordered by class index.
label_counts = df['label'].value_counts().sort_index()
print(f"\nLabel distribution:\n{label_counts}")

# Loss weights: 1 / sqrt(count) for every class that occurs in the data and
# 1.0 for classes that do not occur (or for all classes when there is no data).
if label_counts.empty:
    print("Warning: Label counts are empty. Using default weights of 1.0 for all classes.")
    weights_values = np.ones(NUM_CLASSES)
else:
    weights_values = np.zeros(NUM_CLASSES)
    for class_id in range(NUM_CLASSES):
        if class_id not in label_counts.index:
            weights_values[class_id] = 1.0
            print(f"Warning: Class {class_id} not found in data for weight calculation. Using default weight 1.0.")
            continue
        weights_values[class_id] = 1.0 / np.sqrt(label_counts[class_id])


class_weights_tensor = torch.tensor(weights_values, dtype=torch.float)
print(f"Calculated class weights: {class_weights_tensor.tolist()}")

# --- Hyperparameters ---
MAX_LENGTH = 256             # tokenizer max_length (padding / truncation target)
BATCH_SIZE = 8               # samples per DataLoader batch
NUM_FOLDS = 5                # number of cross-validation splits
SEED = 42                    # random_state of the K-fold shuffle
LEARNING_RATE = 1e-4         # AdamW learning rate
WEIGHT_DECAY = 0.01          # AdamW weight decay
EPOCHS = 10                  # maximum epochs per fold
EARLY_STOPPING_PATIENCE = 3  # epochs without F1 improvement before stopping
LR_SCHEDULER_PATIENCE = 1    # patience passed to ReduceLROnPlateau

class AntibodyPsrDataset(Dataset):
    """Dataset of paired antibody chains with an integer class target.

    Each item joins the heavy and light chain as ``VH + 'X' + VL``, tokenizes
    the result to a fixed length, and returns the token ids, the attention
    mask and the label.

    Args:
        vh_sequences: Indexable collection of heavy-chain sequences.
        vl_sequences: Indexable collection of light-chain sequences.
        targets: Indexable collection of class labels (converted with ``int``).
        tokenizer: Callable tokenizer accepting the Hugging Face keyword
            arguments used in ``__getitem__``.
        max_len: Length every encoded sequence is padded or truncated to.

    Raises:
        ValueError: If the three input collections differ in length.
    """

    def __init__(self, vh_sequences, vl_sequences, targets, tokenizer, max_len):
        # Fail early: a length mismatch would otherwise surface as an
        # IndexError mid-epoch or silently ignore trailing entries.
        if not (len(vh_sequences) == len(vl_sequences) == len(targets)):
            raise ValueError(
                "vh_sequences, vl_sequences and targets must have the same length "
                f"(got {len(vh_sequences)}, {len(vl_sequences)}, {len(targets)})."
            )
        self.vh_sequences = vh_sequences
        self.vl_sequences = vl_sequences
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.vh_sequences)

    def __getitem__(self, idx):
        vh_seq = str(self.vh_sequences[idx])
        vl_seq = str(self.vl_sequences[idx])
        target = int(self.targets[idx])

        # The literal 'X' separates the two chains inside one input string.
        combined_sequence = vh_seq + 'X' + vl_seq

        # Call the tokenizer directly; ``encode_plus`` is deprecated in
        # Transformers in favour of ``__call__`` with the same arguments.
        encoding = self.tokenizer(
            combined_sequence, add_special_tokens=True, max_length=self.max_len,
            return_token_type_ids=False, padding='max_length', truncation=True,
            return_attention_mask=True, return_tensors='pt',
        )
        return {
            # return_tensors='pt' yields a leading batch axis; flatten removes it.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(target, dtype=torch.long)
        }

def create_simple_classifier(input_size, num_classes):
    """Build the lightest head: a single linear layer producing class logits.

    Args:
        input_size: Width of the incoming feature vector.
        num_classes: Number of output logits.

    Returns:
        An ``nn.Linear`` mapping ``input_size`` features to ``num_classes``.
    """
    linear_head = nn.Linear(in_features=input_size, out_features=num_classes)
    return linear_head

def create_medium_mlp_classifier(input_size, num_classes, hidden_dim=256):
    """Build a one-hidden-layer MLP head.

    The stack is Linear -> GELU -> BatchNorm1d -> Dropout(0.30) -> Linear.

    Args:
        input_size: Width of the incoming feature vector.
        num_classes: Number of output logits.
        hidden_dim: Width of the hidden layer.

    Returns:
        An ``nn.Sequential`` containing the five layers in that order.
    """
    stages = [
        nn.Linear(input_size, hidden_dim),
        nn.GELU(),
        nn.BatchNorm1d(hidden_dim),
        nn.Dropout(p=0.30),
        nn.Linear(hidden_dim, num_classes),
    ]
    return nn.Sequential(*stages)

def create_deep_mlp_classifier(input_size, num_classes, hidden_dims=(512, 256, 128)):
    """Build a multi-layer MLP head.

    Every entry of ``hidden_dims`` contributes one
    Linear -> GELU -> BatchNorm1d -> Dropout(0.30) stage; a final Linear layer
    maps the last hidden width to ``num_classes``.

    Args:
        input_size: Width of the incoming feature vector.
        num_classes: Number of output logits.
        hidden_dims: Iterable of hidden-layer widths, applied in order. The
            default is an immutable tuple so that it cannot be shared and
            mutated across calls.

    Returns:
        An ``nn.Sequential`` with the stacked layers.
    """
    layers = []
    current_dim = input_size
    for h_dim in hidden_dims:
        layers.extend([
            nn.Linear(current_dim, h_dim),
            nn.GELU(),
            nn.BatchNorm1d(h_dim),
            nn.Dropout(0.30),
        ])
        current_dim = h_dim
    layers.append(nn.Linear(current_dim, num_classes))
    return nn.Sequential(*layers)

class AttentionPooling(nn.Module):
    """Collapse a sequence of hidden states into one vector with learned weights.

    A single linear layer scores every position; the scores are normalised
    with a softmax over the sequence axis and used for a weighted sum of the
    hidden states.

    Args:
        hidden_size: Size of the last axis of the hidden states.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention_scorer = nn.Linear(hidden_size, 1)

    def forward(self, last_hidden_state, attention_mask):
        """Pool ``last_hidden_state`` over dim 1.

        Args:
            last_hidden_state: Hidden states; the softmax and the weighted sum
                both run over dim 1.
            attention_mask: Optional mask broadcastable to the per-position
                scores; positions equal to 0 are excluded. ``None`` pools
                over every position.

        Returns:
            The weighted sum over dim 1. A row whose mask is entirely 0
            yields a zero vector instead of NaN.
        """
        scores = self.attention_scorer(last_hidden_state).squeeze(-1)

        if attention_mask is None:
            weights = torch.softmax(scores, dim=1)
        else:
            excluded = attention_mask == 0
            # Use the most negative finite value rather than -inf: a row with
            # every position excluded would otherwise be softmax(-inf, ...),
            # which is NaN and poisons the loss.
            scores = scores.masked_fill(excluded, torch.finfo(scores.dtype).min)
            weights = torch.softmax(scores, dim=1)
            # Excluded positions already underflow to 0 in ordinary rows; this
            # also zeroes the uniform weights of a fully excluded row.
            weights = weights.masked_fill(excluded, 0.0)

        pooled_output = torch.sum(weights.unsqueeze(-1) * last_hidden_state, dim=1)
        return pooled_output

class EsmForAntibodyPsrClassification(nn.Module):
    """Frozen ESM backbone followed by attention pooling and a classifier head.

    Args:
        model_name: Checkpoint identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of logits produced by the head.
        head_type: ``'simple'``, ``'medium'`` or ``'deep'``; selects which
            factory builds the classification head.

    Raises:
        ValueError: If ``head_type`` is none of the three supported values.
    """

    def __init__(self, model_name, num_classes, head_type='medium'):
        super().__init__()
        print(f"Loading base ESM model: {model_name}")
        self.esm_model = AutoModel.from_pretrained(model_name)
        hidden_size = self.esm_model.config.hidden_size

        print("Initializing Attention Pooling layer...")
        self.pooler = AttentionPooling(hidden_size)

        # Only the pooler and the head are meant to receive gradients.
        print("Freezing base ESM model parameters...")
        self.esm_model.requires_grad_(False)

        classifier_input_size = hidden_size
        print(f"Creating classification head of type: {head_type} with input size {classifier_input_size}")
        head_factories = (
            ('simple', create_simple_classifier),
            ('medium', create_medium_mlp_classifier),
            ('deep', create_deep_mlp_classifier),
        )
        for known_type, build_head in head_factories:
            if head_type == known_type:
                self.classifier = build_head(classifier_input_size, num_classes)
                break
        else:
            raise ValueError("Invalid head_type. Choose 'simple', 'medium', or 'deep'.")
        print("Model initialization complete.")

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of tokenized sequences."""
        encoder_output = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Pool the per-token states into one vector per sequence, then classify.
        pooled = self.pooler(encoder_output.last_hidden_state, attention_mask)
        return self.classifier(pooled)

def initialize_model_and_optimizer(model_name, num_classes, head_type, learning_rate, weight_decay, class_weights_tensor):
    """Create the model, optimizer, loss and LR scheduler for one training run.

    Args:
        model_name: ESM checkpoint name forwarded to the model constructor.
        num_classes: Number of output classes.
        head_type: Classification head variant (``'simple'``, ``'medium'``, ``'deep'``).
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        class_weights_tensor: Per-class loss weights, or ``None`` for an
            unweighted loss.

    Returns:
        Tuple ``(model, optimizer, criterion, scheduler)``; the model is
        already moved to the module-level ``device``.
    """
    model = EsmForAntibodyPsrClassification(model_name, num_classes, head_type=head_type)
    model.to(device)

    # Only the attention pooler and the head are optimised.
    trainable_params = list(model.pooler.parameters()) + list(model.classifier.parameters())

    optimizer = optim.AdamW(
        trainable_params,
        lr=learning_rate,
        weight_decay=weight_decay
    )
    print("\nOptimizer initialized. Training parameters (attention pooler + classification head):")
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {num_trainable}")

    loss_weights = class_weights_tensor.to(device) if class_weights_tensor is not None else None
    criterion = nn.CrossEntropyLoss(weight=loss_weights)
    # mode='max' because the scheduler is stepped with a validation F1 score.
    # The `verbose` argument is intentionally not passed: it is deprecated in
    # PyTorch and no longer accepted by recent releases. The current learning
    # rate remains available through optimizer.param_groups.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=LR_SCHEDULER_PATIENCE)

    return model, optimizer, criterion, scheduler

def train_epoch(model, data_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    Puts the model in training mode, performs a forward/backward/step cycle
    per batch, prints the elapsed time and mean loss, and returns that mean.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning logits.
        data_loader: Iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors; must support ``len``.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer updating the model's trainable parameters.
        device: Device every batch tensor is moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(data_loader)``.
    """
    model.train()
    epoch_started = time.time()
    batch_losses = []

    for batch in data_loader:
        token_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(token_ids, mask), targets)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())

    avg_loss = sum(batch_losses) / len(data_loader)
    elapsed_time = time.time() - epoch_started
    print(f"  Train Epoch completed in {elapsed_time:.2f}s, Avg. Loss: {avg_loss:.4f}")
    return avg_loss

def evaluate_model(model, data_loader, criterion, device, num_classes):
    """Evaluate ``model`` on ``data_loader`` and compute classification metrics.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning logits.
        data_loader: Iterable of dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors; must support ``len``.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device every batch tensor is moved to.
        num_classes: Total number of classes the model distinguishes.

    Returns:
        Tuple ``(avg_loss, accuracy, precision_weighted, recall_weighted,
        f1_weighted, f1_macro, cohen_kappa, roc_auc, labels, predictions)``.
        ``cohen_kappa`` is NaN when fewer than two classes occur in the
        labels; ``roc_auc`` is NaN unless every class occurs in the labels
        and the computation succeeds. ``labels`` and ``predictions`` are
        NumPy arrays.
    """
    model.eval()
    total_loss = 0.0
    all_labels, all_predictions, all_probabilities = [], [], []
    start_time = time.time()

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            total_loss += loss.item()

            probabilities = torch.softmax(logits, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    all_labels, all_predictions, all_probabilities = np.array(all_labels), np.array(all_predictions), np.array(all_probabilities)

    accuracy = accuracy_score(all_labels, all_predictions)
    precision_w, recall_w, f1_w, _ = precision_recall_fscore_support(all_labels, all_predictions, average='weighted', zero_division=0)
    _, _, f1_macro, _ = precision_recall_fscore_support(all_labels, all_predictions, average='macro', zero_division=0)

    # Computed once and shared by the kappa and ROC AUC guards below.
    unique_labels_in_fold = np.unique(all_labels)

    cohen_k = float('nan')
    if len(unique_labels_in_fold) >= 2:
        cohen_k = cohen_kappa_score(all_labels, all_predictions, weights='quadratic')
    else:
        print("  Info: Cohen's Kappa N/A (fewer than 2 classes in this fold's validation labels).")

    roc_auc = float('nan')
    if len(unique_labels_in_fold) == num_classes and num_classes > 1:
        try:
            if num_classes == 2:
                 roc_auc = roc_auc_score(all_labels, all_probabilities[:, 1])
            else:
                 roc_auc = roc_auc_score(all_labels, all_probabilities, multi_class='ovr', average='weighted', labels=list(range(num_classes)))
        except ValueError as e:
            # Every class is present in this branch, so the failure has some
            # other cause; report it and keep roc_auc as NaN.
            print(f"  Warning: ROC AUC calculation failed: {e}")
    elif num_classes == 1:
        print(" Info: ROC AUC N/A (only 1 class defined).")
    else:
        print(f"  Info: ROC AUC N/A (only {len(unique_labels_in_fold)}/{num_classes} classes present in this fold's validation set).")

    elapsed_time = time.time() - start_time
    print(f"  Evaluation completed in {elapsed_time:.2f}s")
    return avg_loss, accuracy, precision_w, recall_w, f1_w, f1_macro, cohen_k, roc_auc, all_labels, all_predictions

def _snapshot_state_to_cpu(model_obj):
    """Return an independent CPU copy of ``model_obj.state_dict()``."""
    live_state = model_obj.state_dict()
    snapshot = type(live_state)(
        (name, tensor.detach().to('cpu', copy=True))
        for name, tensor in live_state.items()
    )
    # Keep the versioning metadata that load_state_dict looks up on the mapping.
    metadata = getattr(live_state, '_metadata', None)
    if metadata is not None:
        snapshot._metadata = copy.deepcopy(metadata)
    return snapshot


def run_training_fold(current_esm_model_name, model_obj, optimizer, criterion, scheduler, train_loader, val_loader, epochs, device, num_classes, fold_num):
    """Train one cross-validation fold with LR scheduling and early stopping.

    After each epoch the model is evaluated on ``val_loader``; the weighted
    validation F1 drives the scheduler, the best-model snapshot and the
    early-stopping counter (limit: ``EARLY_STOPPING_PATIENCE``).

    Args:
        current_esm_model_name: Backbone name, used only in log messages.
        model_obj: Model to train (updated in place).
        optimizer: Optimizer for the trainable parameters.
        criterion: Loss function.
        scheduler: Scheduler stepped with the weighted validation F1.
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        epochs: Maximum number of epochs.
        device: Device batches are moved to.
        num_classes: Number of classes, forwarded to the evaluation.
        fold_num: Fold number, used only in log messages.

    Returns:
        Tuple ``(best_model_state, train_losses, val_losses, metrics_history,
        best_val_f1_weighted)``. ``best_model_state`` is a state dict held on
        the CPU (``None`` if no epoch improved on the initial -1.0);
        ``metrics_history`` has one dict per completed epoch.
    """
    best_val_f1_weighted = -1.0
    epochs_no_improve = 0
    best_model_state = None
    train_losses, val_losses, metrics_history = [], [], []

    print(f"Starting training for Fold {fold_num} (ESM Model: {current_esm_model_name}), max {epochs} epochs...")
    total_fold_start_time = time.time()

    for epoch in range(epochs):
        epoch_start_time = time.time()
        print(f"\n-- Epoch {epoch+1}/{epochs} -- Current LR: {optimizer.param_groups[0]['lr']:.2e}")

        avg_train_loss = train_epoch(model_obj, train_loader, criterion, optimizer, device)
        train_losses.append(avg_train_loss)

        avg_val_loss, val_accuracy, _, _, val_f1_w, val_f1_macro, val_cohen_k, val_roc_auc, _, _ = evaluate_model(
            model_obj, val_loader, criterion, device, num_classes
        )
        val_losses.append(avg_val_loss)

        metrics_history.append({
            'epoch': epoch + 1, 'train_loss': avg_train_loss, 'val_loss': avg_val_loss,
            'val_accuracy': val_accuracy, 'val_f1_weighted': val_f1_w,
            'val_f1_macro': val_f1_macro, 'val_cohen_kappa': val_cohen_k, 'val_roc_auc': val_roc_auc
        })

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1} Summary | Time: {epoch_time:.2f}s")
        print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f} | Val F1 (w): {val_f1_w:.4f} | Val F1 (m): {val_f1_macro:.4f} | Val Kappa: {val_cohen_k:.4f} | Val AUC: {val_roc_auc:.4f}")

        scheduler.step(val_f1_w)

        if val_f1_w > best_val_f1_weighted:
            best_val_f1_weighted = val_f1_w
            # Snapshot on the CPU: deep-copying the state dict in place would
            # duplicate every weight, including the backbone, in accelerator
            # memory each time the score improves.
            best_model_state = _snapshot_state_to_cpu(model_obj)
            epochs_no_improve = 0
            print(f"  -> New best validation F1 (weighted): {best_val_f1_weighted:.4f}. Saving model state.")
        else:
            epochs_no_improve += 1
            print(f"  Validation F1 (weighted) did not improve for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

        # Release cached accelerator memory between epochs on Apple Silicon.
        if device == torch.device("mps"): torch.mps.empty_cache()
        gc.collect()

    total_fold_time = time.time() - total_fold_start_time
    print(f"\nTraining finished for Fold {fold_num}. Total time: {total_fold_time:.2f}s. Best F1 (w): {best_val_f1_weighted:.4f}")
    return best_model_state, train_losses, val_losses, metrics_history, best_val_f1_weighted

# ---------------------------------------------------------------------------
# Cross-validation driver: for every ESM-2 backbone, train and evaluate one
# model per fold, plot the per-fold curves and confusion matrix, then print
# per-backbone and overall summaries.
# ---------------------------------------------------------------------------
if df.empty or len(df) < NUM_FOLDS:
    print("DataFrame is empty or has insufficient data for K-Fold cross-validation. Exiting.")
else:
    all_vh_sequences = df['VH'].tolist()
    all_vl_sequences = df['VL'].tolist()
    all_labels_list = df['label'].tolist()
    overall_results_per_esm_model = []

    # Pinned host memory only helps CUDA transfers; requesting it on a
    # CPU-only host merely triggers a DataLoader warning.
    use_pinned_memory = device.type == 'cuda'

    for esm_model_name_hf in ESM2_MODEL_NAMES:
        print(f"\n\n===== Processing ESM Model: {esm_model_name_hf} =====")

        # Head capacity is chosen from the backbone size encoded in its name.
        current_head_type = 'medium'
        if "8M" in esm_model_name_hf: current_head_type = 'simple'
        elif "35M" in esm_model_name_hf or "150M" in esm_model_name_hf: current_head_type = 'medium'
        elif "650M" in esm_model_name_hf: current_head_type = 'deep'
        print(f"Using head type: {current_head_type} for {esm_model_name_hf}")

        try:
            tokenizer = AutoTokenizer.from_pretrained(esm_model_name_hf)
        except Exception as e:
            print(f"Could not load tokenizer for {esm_model_name_hf}, trying base. Error: {e}")
            try: tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
            except Exception as e_base:
                print(f"Could not load base ESM tokenizer. Error: {e_base}. Skipping this model.")
                overall_results_per_esm_model.append({'ESM Model': esm_model_name_hf, 'Error': f'Tokenizer load failed: {e_base}'})
                continue

        full_dataset = AntibodyPsrDataset(all_vh_sequences, all_vl_sequences, all_labels_list, tokenizer, MAX_LENGTH)
        if len(full_dataset) < NUM_FOLDS:
            print(f"Skipping {esm_model_name_hf}: Not enough samples ({len(full_dataset)}) for {NUM_FOLDS}-fold cross-validation.")
            overall_results_per_esm_model.append({'ESM Model': esm_model_name_hf, 'Error': 'Insufficient samples for CV'})
            continue

        kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
        fold_results_current_esm = []

        print(f"\n--- Starting {NUM_FOLDS}-Fold Cross-Validation for {esm_model_name_hf} ---")
        cv_start_time = time.time()

        for fold, (train_idx, val_idx) in enumerate(kf.split(all_vh_sequences)):
            fold_num = fold + 1
            print(f"\n==================== Fold {fold_num}/{NUM_FOLDS} ({esm_model_name_hf}) ====================")

            if len(val_idx) == 0 or len(train_idx) == 0:
                print(f"Skipping Fold {fold_num} due to empty train/validation set from KFold split.")
                continue

            train_sampler = SubsetRandomSampler(train_idx)
            val_sampler = SubsetRandomSampler(val_idx)

            # The 'medium' and 'deep' heads contain BatchNorm1d, which raises
            # in training mode on a batch holding a single sample. Drop the
            # trailing batch only when it would be such a singleton.
            drop_trailing_singleton = (
                current_head_type != 'simple'
                and len(train_idx) > BATCH_SIZE
                and len(train_idx) % BATCH_SIZE == 1
            )

            train_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=0,
                                      pin_memory=use_pinned_memory, drop_last=drop_trailing_singleton)
            val_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, sampler=val_sampler, num_workers=0,
                                    pin_memory=use_pinned_memory)
            print(f"Fold {fold_num}: Train samples = {len(train_idx)}, Validation samples = {len(val_idx)}")

            current_batch_size_to_check = BATCH_SIZE
            if "650M" in esm_model_name_hf and current_batch_size_to_check > 1 and device == torch.device("mps"):
                print(f"Warning: BATCH_SIZE {current_batch_size_to_check} might be too large for {esm_model_name_hf} (650M) on MPS. Consider reducing to 1.")
            elif "150M" in esm_model_name_hf and current_batch_size_to_check > 2 and device == torch.device("mps"):
                print(f"Warning: BATCH_SIZE {current_batch_size_to_check} might be too large for {esm_model_name_hf} (150M) on MPS. Consider reducing to 1-2.")
            elif "35M" in esm_model_name_hf and current_batch_size_to_check > 4 and device == torch.device("mps"):
                print(f"Warning: BATCH_SIZE {current_batch_size_to_check} might be too large for {esm_model_name_hf} (35M) on MPS. Consider reducing.")

            # Seed torch per fold so that head initialisation and batch
            # shuffling are reproducible; SEED otherwise only fixes the split.
            torch.manual_seed(SEED + fold_num)

            model, optimizer, criterion, scheduler = initialize_model_and_optimizer(
                esm_model_name_hf, NUM_CLASSES, current_head_type, LEARNING_RATE, WEIGHT_DECAY, class_weights_tensor
            )
            best_model_state, train_losses, val_losses, metrics_history, fold_best_f1_w = run_training_fold(
                esm_model_name_hf, model, optimizer, criterion, scheduler, train_loader, val_loader, EPOCHS, device, NUM_CLASSES, fold_num
            )
            if best_model_state is not None:
                model.load_state_dict(best_model_state)
                print("\nLoaded best model state from training for final evaluation on this fold.")
            else:
                print("\nWarning: No best model state saved (e.g., training too short or no improvement). Evaluating last state.")

            print(f"Performing final evaluation for Fold {fold_num} ({esm_model_name_hf}) using best F1 model...")
            final_val_loss, final_accuracy, final_prec_w, final_recall_w, final_f1_w, final_f1_macro, final_cohen_k, final_roc_auc, fold_labels, fold_preds = evaluate_model(
                model, val_loader, criterion, device, NUM_CLASSES
            )
            print(f"\nFold {fold_num} ({esm_model_name_hf}) Final Validation Metrics (Best F1 Model):")
            print(f"  Loss:           {final_val_loss:.4f}")
            print(f"  Accuracy:       {final_accuracy:.4f}")
            print(f"  Precision (w):  {final_prec_w:.4f}")
            print(f"  Recall (w):     {final_recall_w:.4f}")
            print(f"  F1 Score (w):   {final_f1_w:.4f}")
            print(f"  F1 Score (m):   {final_f1_macro:.4f}")
            print(f"  Cohen's Kappa:  {final_cohen_k:.4f}")
            print(f"  ROC AUC (w ovr):{final_roc_auc:.4f}")

            fold_results_current_esm.append({
                'fold': fold_num, 'accuracy': final_accuracy, 'precision_w': final_prec_w,
                'recall_w': final_recall_w, 'f1_w': final_f1_w, 'f1_macro': final_f1_macro,
                'cohen_kappa': final_cohen_k, 'roc_auc': final_roc_auc, 'best_val_metric_f1w': fold_best_f1_w,
            })

            # Learning curves: losses on the left, validation metrics on the right.
            if metrics_history:
                plt.figure(figsize=(12, 6))
                plt.subplot(1, 2, 1)
                plt.plot([m['epoch'] for m in metrics_history], [m['train_loss'] for m in metrics_history], label='Training Loss', marker='o')
                plt.plot([m['epoch'] for m in metrics_history], [m['val_loss'] for m in metrics_history], label='Validation Loss', marker='x')
                plt.xlabel('Epoch'); plt.ylabel('Loss (CrossEntropy)'); plt.title(f'Fold {fold_num} ({esm_model_name_hf}) - Loss')
                plt.legend(); plt.grid(True); plt.ylim(bottom=0)

                plt.subplot(1, 2, 2)
                plt.plot([m['epoch'] for m in metrics_history], [m['val_f1_weighted'] for m in metrics_history], label='Val F1 (w)', marker='s')
                plt.plot([m['epoch'] for m in metrics_history], [m['val_f1_macro'] for m in metrics_history], label='Val F1 (m)', marker='p')
                # ROC AUC is NaN for epochs where it could not be computed; plot only the valid points.
                valid_roc_auc = [m['val_roc_auc'] for m in metrics_history if not np.isnan(m['val_roc_auc'])]
                if valid_roc_auc:
                    plt.plot([m['epoch'] for m in metrics_history if not np.isnan(m['val_roc_auc'])], valid_roc_auc, label='Val ROC AUC', marker='^')
                plt.plot([m['epoch'] for m in metrics_history], [m['val_accuracy'] for m in metrics_history], label='Val Accuracy', marker='.')
                plt.xlabel('Epoch'); plt.ylabel('Metric Value'); plt.title(f'Fold {fold_num} ({esm_model_name_hf}) - Metrics')
                plt.legend(); plt.grid(True); plt.ylim(0, 1.05)
                plt.tight_layout(); plt.show()
            else:
                print(f"No metrics history recorded for Fold {fold_num} to plot.")


            if len(fold_labels) > 0 and len(fold_preds) > 0:
                try:
                    cm = confusion_matrix(fold_labels, fold_preds, labels=list(range(NUM_CLASSES)))
                    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[f'Class {i}' for i in range(NUM_CLASSES)])
                    disp.plot(cmap=plt.cm.Blues)
                    plt.title(f'Fold {fold_num} ({esm_model_name_hf}) - Confusion Matrix')
                    plt.show()
                except Exception as e:
                    print(f"Could not display confusion matrix for Fold {fold_num}: {e}")
            else:
                print(f"Not enough data to generate confusion matrix for Fold {fold_num}.")


            # Free the fold's model and loaders before the next fold is built.
            print(f"Cleaning up Fold {fold_num} resources...")
            del model, optimizer, criterion, scheduler, train_loader, val_loader, train_sampler, val_sampler, best_model_state
            if device == torch.device("mps"): torch.mps.empty_cache()
            gc.collect()

        cv_end_time = time.time()
        print(f"\n--- Cross-Validation Finished for {esm_model_name_hf} --- Total Time: {cv_end_time - cv_start_time:.2f}s ---")

        results_df_current_esm = pd.DataFrame(fold_results_current_esm)
        avg_metrics_dict = {'ESM Model': esm_model_name_hf, 'Error': None}
        if not results_df_current_esm.empty:
            print(f"\n--- Cross-Validation Summary for {esm_model_name_hf} ---")
            metric_cols = ['fold', 'accuracy', 'f1_w', 'f1_macro', 'cohen_kappa', 'roc_auc', 'best_val_metric_f1w']
            print(results_df_current_esm[metric_cols].round(4).to_string(index=False))

            for metric in ['accuracy', 'f1_w', 'f1_macro', 'cohen_kappa', 'roc_auc']:
                avg_metrics_dict[f'Avg {metric}'] = results_df_current_esm[metric].mean()
                avg_metrics_dict[f'Std {metric}'] = results_df_current_esm[metric].std()
        else:
            avg_metrics_dict['Error'] = 'No fold results to aggregate (e.g. all folds skipped)'

        overall_results_per_esm_model.append(avg_metrics_dict)
        print("\nAverage Metrics Across Folds:")
        for key, val in avg_metrics_dict.items():
            if key not in ['ESM Model', 'Error'] and "Std" not in key:
                std_key = key.replace("Avg", "Std")
                if isinstance(val, (int, float)) and not np.isnan(val):
                     print(f"  {key:<20}: {val:.4f} +/- {avg_metrics_dict.get(std_key, 0.0):.4f}")
                else:
                     print(f"  {key:<20}: N/A")

            elif key == 'Error' and val is not None:
                print(f"  Error: {val}")

    print("\n\n===== Overall Summary Across All ESM Models =====")
    summary_df = pd.DataFrame(overall_results_per_esm_model)
    cols_to_display = ['ESM Model', 'Error']
    if any(r.get('Error') is None for r in overall_results_per_esm_model):
        first_valid_result = next((r for r in overall_results_per_esm_model if r.get('Error') is None), None)
        if first_valid_result:
            metric_avg_std_cols = [k for k in first_valid_result.keys() if k not in ['ESM Model', 'Error']]
            cols_to_display.extend(metric_avg_std_cols)

    # Work on an explicit copy: rounding columns of a slice of summary_df
    # would raise pandas' SettingWithCopyWarning.
    summary_df_display = summary_df[[col for col in cols_to_display if col in summary_df.columns]].copy()
    numeric_cols = summary_df_display.select_dtypes(include=np.number).columns
    summary_df_display[numeric_cols] = summary_df_display[numeric_cols].round(4)
    print(summary_df_display.to_string(index=False))