In [29]:


import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import pickle
from tqdm import tqdm
import os
from sklearn.model_selection import KFold
from sklearn.metrics import precision_recall_fscore_support
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import roc_curve, auc
import os

def ensure_output_dir(output_dir):
    """Create output directory if it doesn't exist"""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

def plot_roc_curves(y_true, y_pred, labels, output_dir):
    """Plot ROC curves for each class"""
    plt.figure(figsize=(10, 8))
    
    for i, label in enumerate(labels):
        fpr, tpr, _ = roc_curve(y_true[:, i], y_pred[:, i])
        roc_auc = auc(fpr, tpr)
        
        plt.plot(
            fpr, 
            tpr, 
            label=f'{label} (AUC = {roc_auc:.2f})'
        )
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curves')
    plt.legend(loc="lower right")
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'roc_curves.png'))
    plt.close()
    
class MetricTracker:
    def __init__(self):
        self.metrics = defaultdict(list)
    
    def update(self, metrics_dict):
        for key, value in metrics_dict.items():
            self.metrics[key].append(value)
    
    def get_metric(self, metric_name):
        return self.metrics[metric_name]

def plot_training_history(tracker, fold, output_dir):
    """Plot training and validation metrics"""
    plt.figure(figsize=(15, 10))
    
    # Plot losses
    plt.subplot(2, 2, 1)
    plt.plot(tracker.get_metric('train_loss'), label='Train Loss')
    plt.plot(tracker.get_metric('val_loss'), label='Validation Loss')
    plt.title(f'Loss History - Fold {fold}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # Plot accuracies
    plt.subplot(2, 2, 2)
    plt.plot(tracker.get_metric('exact_match_accuracy'), label='Exact Match')
    plt.plot(tracker.get_metric('hamming_accuracy'), label='Hamming')
    plt.title(f'Accuracy History - Fold {fold}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    # Plot F1, Precision, Recall
    plt.subplot(2, 2, 3)
    plt.plot(tracker.get_metric('f1'), label='F1')
    plt.plot(tracker.get_metric('precision'), label='Precision')
    plt.plot(tracker.get_metric('recall'), label='Recall')
    plt.title(f'Metrics History - Fold {fold}')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'training_history_fold_{fold}.png'))
    plt.close()
    
def plot_confusion_matrices(y_true, y_pred, labels, output_dir):
    """Plot confusion matrix for each class"""
    y_pred_binary = (y_pred > 0.5).astype(int)
    
    n_classes = len(labels)
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    axes = axes.ravel()
    
    for idx, label in enumerate(labels):
        cm = confusion_matrix(y_true[:, idx], y_pred_binary[:, idx])
        sns.heatmap(cm, annot=True, fmt='d', ax=axes[idx])
        axes[idx].set_title(f'Confusion Matrix - {label}')
        axes[idx].set_xlabel('Predicted')
        axes[idx].set_ylabel('True')
    
    if len(labels) < len(axes):
        for idx in range(len(labels), len(axes)):
            fig.delaxes(axes[idx])
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'confusion_matrices.png'))
    plt.close()

def plot_label_distribution(train_labels, test_labels, labels, output_dir):
    """Plot label distribution in train and test sets"""
    train_dist = train_labels.sum(axis=0)
    test_dist = test_labels.sum(axis=0)
    
    plt.figure(figsize=(12, 6))
    x = np.arange(len(labels))
    width = 0.35
    
    plt.bar(x - width/2, train_dist, width, label='Train')
    plt.bar(x + width/2, test_dist, width, label='Test')
    
    plt.xlabel('Labels')
    plt.ylabel('Count')
    plt.title('Label Distribution in Train and Test Sets')
    plt.xticks(x, labels, rotation=45)
    plt.legend()
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'label_distribution.png'))
    plt.close()

def create_performance_tables(y_true, y_pred, labels, output_dir):
    """Create and save detailed performance tables"""
    y_pred_binary = (y_pred > 0.5).astype(int)
    
    metrics_dict = {
        'Precision': [],
        'Recall': [],
        'F1-Score': [],
        'Support': []
    }
    
    for i in range(len(labels)):
        precision, recall, f1, support = precision_recall_fscore_support(
            y_true[:, i], y_pred_binary[:, i], average='binary'
        )
        metrics_dict['Precision'].append(precision)
        metrics_dict['Recall'].append(recall)
        metrics_dict['F1-Score'].append(f1)
        metrics_dict['Support'].append(support)
    
    df_metrics = pd.DataFrame(metrics_dict, index=labels)
    df_metrics.to_csv(os.path.join(output_dir, 'class_performance_metrics.csv'))
    
    corr_matrix = np.corrcoef(y_pred_binary.T)
    df_corr = pd.DataFrame(corr_matrix, index=labels, columns=labels)
    df_corr.to_csv(os.path.join(output_dir, 'prediction_correlations.csv'))
    
    return df_metrics, df_corr

def plot_metric_comparison(fold_metrics):
    """Plot comparison of metrics across folds"""
    metrics = ['exact_match_accuracy', 'hamming_accuracy', 'f1', 'precision', 'recall']
    
    plt.figure(figsize=(12, 6))
    data = []
    for metric in metrics:
        data.append([metrics_dict[metric] for metrics_dict in fold_metrics])
    
    plt.boxplot(data, labels=metrics)
    plt.title('Metric Distribution Across Folds')
    plt.ylabel('Score')
    plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.savefig('fold_metrics_comparison.png')
    plt.close()

class DocumentProcessor:
    def __init__(self, model_name):
        self.model = SentenceTransformer(model_name)
        self.labels = ['Treatment', 'Prevention', 'Diagnosis', 'Mechanism', 
                      'Transmission', 'Epidemic Forecasting', 'Case Report']
        
    def clean_text(self, text):
        if pd.isna(text):
            return ""
        text = str(text).lower()
        text = ' '.join(text.split())
        return text
    
    def process_labels(self, label_text):
        label_list = label_text.split(';')
        label_array = np.zeros(len(self.labels))
        for label in label_list:
            if label in self.labels:
                label_array[self.labels.index(label)] = 1
        return label_array
    
    def generate_embeddings(self, texts, batch_size=32, cache_file=None):
        if cache_file and os.path.exists(cache_file):
            print(f"Loading cached embeddings from {cache_file}")
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        
        print("Generating new embeddings...")
        embeddings = []
        for i in tqdm(range(0, len(texts), batch_size)):
            batch_texts = texts[i:i+batch_size]
            batch_embeddings = self.model.encode(batch_texts)
            embeddings.extend(batch_embeddings)
        
        embeddings = np.array(embeddings)
        
        if cache_file:
            print(f"Caching embeddings to {cache_file}")
            with open(cache_file, 'wb') as f:
                pickle.dump(embeddings, f)
        
        return embeddings

class COVIDDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = torch.FloatTensor(embeddings)
        self.labels = torch.FloatTensor(labels)
    
    def __len__(self):
        return len(self.embeddings)
    
    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

class TopicClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim1=512, hidden_dim2=256, num_classes=7):
        super(TopicClassifier, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim1)
        self.layer2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.layer3 = nn.Linear(hidden_dim2, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.dropout(x)
        x = self.relu(self.layer2(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.layer3(x))
        return x

# Modify train_fold function to use MetricTracker
def train_fold(model, train_loader, val_loader, criterion, optimizer, device, fold, output_dir):
    model = model.to(device)
    best_val_loss = float('inf')
    patience = 3
    patience_counter = 0
    
    tracker = MetricTracker()
    
    for epoch in range(20):
        # Training phase
        model.train()
        train_loss = 0
        for batch_embeddings, batch_labels in tqdm(train_loader, desc=f'Fold {fold}, Epoch {epoch+1} - Training'):
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(batch_embeddings)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation phase
        val_loss, val_metrics = evaluate_fold(model, val_loader, criterion, device, fold)
        
        avg_train_loss = train_loss/len(train_loader)
        
        # Track metrics
        tracker.update({
            'train_loss': avg_train_loss,
            'val_loss': val_loss,
            **val_metrics
        })
        
        print(f'Fold {fold}, Epoch {epoch+1}')
        print(f'Training Loss: {avg_train_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print('Validation Metrics:')
        print(f'  Exact Match Accuracy: {val_metrics["exact_match_accuracy"]:.4f}')
        print(f'  Hamming Accuracy: {val_metrics["hamming_accuracy"]:.4f}')
        print(f'  F1: {val_metrics["f1"]:.4f}')
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(output_dir, f'best_model_fold_{fold}.pt'))

            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break
    
    # Plot training history for this fold
    plot_training_history(tracker, fold, output_dir)
    return best_val_loss, val_metrics


def evaluate_model(model, test_loader, criterion, device, labels):
    model.eval()
    test_loss = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for batch_embeddings, batch_labels in tqdm(test_loader, desc='Testing'):
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device)
            
            outputs = model(batch_embeddings)
            loss = criterion(outputs, batch_labels)
            test_loss += loss.item()
            
            predictions = outputs.cpu().numpy()
            all_predictions.extend(predictions)
            all_labels.extend(batch_labels.cpu().numpy())
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    # Calculate and print all metrics
    overall_metrics, per_category_metrics = calculate_overall_metrics(
        all_labels, all_predictions, labels
    )
    
    return test_loss/len(test_loader), all_predictions, all_labels, overall_metrics, per_category_metrics


def plot_metrics_heatmap(metrics_dict, labels, output_dir):
    """Create a heatmap of metrics for each category"""
    metrics_df = pd.DataFrame({
        'Precision': metrics_dict['Precision'],
        'Recall': metrics_dict['Recall'],
        'F1-Score': metrics_dict['F1-Score']
    }, index=labels)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(metrics_df, annot=True, cmap='YlOrRd', fmt='.3f')
    plt.title('Performance Metrics by Category')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'metrics_heatmap.png'))
    plt.close()
    
def calculate_metrics(y_true, y_pred):
    """Calculate various metrics for multi-label classification"""
    # Convert predictions to binary
    y_pred_binary = (y_pred > 0.5).astype(int)
    
    # Exact match accuracy (all labels must match)
    exact_match_accuracy = np.mean(np.all(y_pred_binary == y_true, axis=1))
    
    # Per-class accuracy
    per_class_accuracy = np.mean(y_pred_binary == y_true, axis=0)
    
    # Hamming accuracy (proportion of correct predictions)
    hamming_accuracy = np.mean(y_pred_binary == y_true)
    
    # Calculate precision, recall, f1
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred_binary, average='samples'
    )
    
    return {
        'exact_match_accuracy': exact_match_accuracy,
        'hamming_accuracy': hamming_accuracy,
        'per_class_accuracy': per_class_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def calculate_overall_metrics(y_true, y_pred, labels):
    """
    Calculate both overall and per-category metrics
    """
    # Convert predictions to binary
    y_pred_binary = (y_pred > 0.5).astype(int)
    
    # Per-category metrics
    per_category_metrics = {
        'Precision': [],
        'Recall': [],
        'F1-Score': [],
        'Support': []
    }
    
    print("\nPer-category Metrics:")
    print("--------------------")
    for i, label in enumerate(labels):
        precision, recall, f1, support = precision_recall_fscore_support(
            y_true[:, i], y_pred_binary[:, i], average='binary'
        )
        per_category_metrics['Precision'].append(precision)
        per_category_metrics['Recall'].append(recall)
        per_category_metrics['F1-Score'].append(f1)
        per_category_metrics['Support'].append(support)
        
        print(f"\n{label}:")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1-Score: {f1:.4f}")
        print(f"Support: {support}")
    
    # Overall metrics (micro average)
    micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred_binary, average='micro'
    )
    
    # Overall metrics (macro average)
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred_binary, average='macro'
    )
    
    # Overall metrics (weighted average)
    weighted_precision, weighted_recall, weighted_f1, _ = precision_recall_fscore_support(
        y_true, y_pred_binary, average='weighted'
    )
    
    # Exact match ratio (perfect predictions across all categories)
    exact_match = np.mean(np.all(y_pred_binary == y_true, axis=1))
    
    # Hamming accuracy (percentage of correct labels)
    hamming_accuracy = np.mean(y_pred_binary == y_true)
    
    # Create summary dictionary
    overall_metrics = {
        'Micro-average': {
            'Precision': micro_precision,
            'Recall': micro_recall,
            'F1-Score': micro_f1
        },
        'Macro-average': {
            'Precision': macro_precision,
            'Recall': macro_recall,
            'F1-Score': macro_f1
        },
        'Weighted-average': {
            'Precision': weighted_precision,
            'Recall': weighted_recall,
            'F1-Score': weighted_f1
        },
        'Exact Match Ratio': exact_match,
        'Hamming Accuracy': hamming_accuracy
    }
    
    # Create and display summary DataFrame
    df_overall = pd.DataFrame({
        'Metric': ['Precision', 'Recall', 'F1-Score'],
        'Micro-avg': [micro_precision, micro_recall, micro_f1],
        'Macro-avg': [macro_precision, macro_recall, macro_f1],
        'Weighted-avg': [weighted_precision, weighted_recall, weighted_f1]
    }).set_index('Metric')
    
    print("\nOverall Metrics:")
    print("--------------")
    print(f"\nExact Match Ratio: {exact_match:.4f}")
    print(f"Hamming Accuracy: {hamming_accuracy:.4f}")
    print("\nAveraged Metrics:")
    print(df_overall)
    
    # Save metrics to CSV
    df_overall.to_csv('overall_metrics.csv')
    df_categories = pd.DataFrame(per_category_metrics, index=labels)
    df_categories.to_csv('per_category_metrics.csv')
    
    return overall_metrics, per_category_metrics

def evaluate_fold(model, val_loader, criterion, device, fold):
    """Evaluate model on validation set during training"""
    model.eval()
    val_loss = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for batch_embeddings, batch_labels in val_loader:
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device)
            
            outputs = model(batch_embeddings)
            loss = criterion(outputs, batch_labels)
            val_loss += loss.item()
            
            predictions = outputs.cpu().numpy()
            all_predictions.extend(predictions)
            all_labels.extend(batch_labels.cpu().numpy())
    
    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    # Calculate metrics with binary predictions
    metrics = calculate_metrics(all_labels, all_predictions)
    
    return val_loss/len(val_loader), metrics

def run(model_name='all-mpnet-base-v2'):
    output_dir = 'output/sbert'
    ensure_output_dir(output_dir)
    fold_metrics = []
    
    # Load datasets
    print("Loading datasets...")
    train_df = pd.read_csv('./dataset/BC7-LitCovid-Train.csv')
    test_df = pd.read_csv('./dataset/BC7-LitCovid-Dev.csv')  # Using dev set as test set
    
    # Initialize processor
    processor = DocumentProcessor(model_name)
    
    # Process all data
    print("Processing training data...")
    train_abstracts = train_df['abstract'].apply(processor.clean_text).values
    train_labels = np.array([processor.process_labels(label) for label in train_df['label']])
    train_embeddings = processor.generate_embeddings(train_abstracts, cache_file='train_embeddings_cache.pkl')
    
    # Get the actual embedding dimension
    embedding_dim = train_embeddings.shape[1]
    print(f"Training embeddings shape: {train_embeddings.shape}")
    print(f"Detected embedding dimension: {embedding_dim}")
    
    print("Processing test data...")
    test_abstracts = test_df['abstract'].apply(processor.clean_text).values
    test_labels = np.array([processor.process_labels(label) for label in test_df['label']])
    test_embeddings = processor.generate_embeddings(test_abstracts, cache_file='test_embeddings_cache.pkl')
    print(f"Test embeddings shape: {test_embeddings.shape}")
    
    # Verify dimensions match
    if train_embeddings.shape[1] != test_embeddings.shape[1]:
        raise ValueError(f"Embedding dimensions don't match! Train: {train_embeddings.shape[1]}, Test: {test_embeddings.shape[1]}")
    
    # Create full datasets
    train_dataset = COVIDDataset(train_embeddings, train_labels)
    test_dataset = COVIDDataset(test_embeddings, test_labels)
    
    # Let's verify the dimensions of our dataset outputs
    sample_batch = next(iter(DataLoader(train_dataset, batch_size=1)))
    print(f"Sample batch embedding shape: {sample_batch[0].shape}")
    
    test_loader = DataLoader(test_dataset, batch_size=32)
    
    # Setup for k-fold cross validation
    k_folds = 5
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    
    # For storing fold results
    fold_results = []  # Store validation losses
    fold_metrics = []  # Store other metrics
    best_val_loss = float('inf')
    best_fold = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # K-fold Cross Validation
    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'\nFOLD {fold+1}/{k_folds}')
        
        train_subsampler = SubsetRandomSampler(train_ids)
        val_subsampler = SubsetRandomSampler(val_ids)
        
        train_loader = DataLoader(train_dataset, batch_size=32, sampler=train_subsampler)
        val_loader = DataLoader(train_dataset, batch_size=32, sampler=val_subsampler)
        
        # Initialize model with correct input dimension
        model = TopicClassifier(input_dim=embedding_dim)
        if fold == 0:
            model_config = {
                'input_dim': embedding_dim,
                'state_dict_keys': list(model.state_dict().keys())
            }
            print(f"Model configuration: {model_config}")
        
        criterion = nn.BCELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        
        val_loss, final_metrics = train_fold(model, train_loader, val_loader, criterion, optimizer, device, fold+1, output_dir)
        fold_results.append(val_loss)
        fold_metrics.append(final_metrics)
        
        # Track best fold
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_fold = fold

    print("\nEvaluating on test set...")
    model = TopicClassifier(input_dim=embedding_dim)
    # Use the tracked best_fold instead of computing it
    model.load_state_dict(torch.load(f'best_model_fold_{best_fold+1}.pt'))
    model = model.to(device)
    
    
    criterion = nn.BCELoss()
    test_loss, test_predictions, test_labels, overall_metrics, per_category_metrics = evaluate_model(
        model, test_loader, criterion, device, processor.labels
    )
    
    # Create visualizations
    plot_confusion_matrices(test_labels, test_predictions, processor.labels, output_dir)
    plot_roc_curves(test_labels, test_predictions, processor.labels, output_dir)
    plot_metrics_heatmap(per_category_metrics, processor.labels, output_dir)
    
    # Save all metrics to a summary file
    with open(os.path.join(output_dir, 'metrics_summary.txt'), 'w') as f:
        f.write("Overall Metrics:\n")
        f.write("---------------\n")
        for metric_type, metrics in overall_metrics.items():
            f.write(f"\n{metric_type}:\n")
            if isinstance(metrics, dict):
                for name, value in metrics.items():
                    f.write(f"{name}: {value:.4f}\n")
            else:
                f.write(f"{metrics:.4f}\n")



In [30]:
run('all-mpnet-base-v2')

Loading datasets...




Processing training data...
Generating new embeddings...


100%|██████████| 780/780 [01:28<00:00,  8.82it/s]


Caching embeddings to train_embeddings_cache.pkl
Training embeddings shape: (24960, 768)
Detected embedding dimension: 768
Processing test data...
Generating new embeddings...


100%|██████████| 195/195 [00:20<00:00,  9.74it/s]


Caching embeddings to test_embeddings_cache.pkl
Test embeddings shape: (6239, 768)
Sample batch embedding shape: torch.Size([1, 768])
Using device: cuda

FOLD 1/5
Model configuration: {'input_dim': 768, 'state_dict_keys': ['layer1.weight', 'layer1.bias', 'layer2.weight', 'layer2.bias', 'layer3.weight', 'layer3.bias']}


Fold 1, Epoch 1 - Training: 100%|██████████| 624/624 [00:05<00:00, 121.46it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 1
Training Loss: 0.2020
Validation Loss: 0.1491
Validation Metrics:
  Exact Match Accuracy: 0.7196
  Hamming Accuracy: 0.9424
  F1: 0.8587


Fold 1, Epoch 2 - Training: 100%|██████████| 624/624 [00:00<00:00, 785.58it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 2
Training Loss: 0.1448
Validation Loss: 0.1389
Validation Metrics:
  Exact Match Accuracy: 0.7388
  Hamming Accuracy: 0.9469
  F1: 0.8677


Fold 1, Epoch 3 - Training: 100%|██████████| 624/624 [00:00<00:00, 678.59it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 3
Training Loss: 0.1343
Validation Loss: 0.1298
Validation Metrics:
  Exact Match Accuracy: 0.7512
  Hamming Accuracy: 0.9500
  F1: 0.8763


Fold 1, Epoch 4 - Training: 100%|██████████| 624/624 [00:00<00:00, 735.34it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 4
Training Loss: 0.1274
Validation Loss: 0.1282
Validation Metrics:
  Exact Match Accuracy: 0.7584
  Hamming Accuracy: 0.9506
  F1: 0.8770


Fold 1, Epoch 5 - Training: 100%|██████████| 624/624 [00:00<00:00, 865.09it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 5
Training Loss: 0.1223
Validation Loss: 0.1292
Validation Metrics:
  Exact Match Accuracy: 0.7528
  Hamming Accuracy: 0.9508
  F1: 0.8803


Fold 1, Epoch 6 - Training: 100%|██████████| 624/624 [00:00<00:00, 1025.51it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 6
Training Loss: 0.1187
Validation Loss: 0.1262
Validation Metrics:
  Exact Match Accuracy: 0.7576
  Hamming Accuracy: 0.9515
  F1: 0.8820


Fold 1, Epoch 7 - Training: 100%|██████████| 624/624 [00:00<00:00, 869.67it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 7
Training Loss: 0.1133
Validation Loss: 0.1267
Validation Metrics:
  Exact Match Accuracy: 0.7560
  Hamming Accuracy: 0.9514
  F1: 0.8825


Fold 1, Epoch 8 - Training: 100%|██████████| 624/624 [00:00<00:00, 741.40it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 8
Training Loss: 0.1091
Validation Loss: 0.1289
Validation Metrics:
  Exact Match Accuracy: 0.7486
  Hamming Accuracy: 0.9493
  F1: 0.8804


Fold 1, Epoch 9 - Training: 100%|██████████| 624/624 [00:00<00:00, 960.18it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 1, Epoch 9
Training Loss: 0.1052
Validation Loss: 0.1310
Validation Metrics:
  Exact Match Accuracy: 0.7552
  Hamming Accuracy: 0.9508
  F1: 0.8811
Early stopping triggered at epoch 9

FOLD 2/5


Fold 2, Epoch 1 - Training: 100%|██████████| 624/624 [00:00<00:00, 853.44it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 1
Training Loss: 0.2024
Validation Loss: 0.1471
Validation Metrics:
  Exact Match Accuracy: 0.7188
  Hamming Accuracy: 0.9416
  F1: 0.8502


Fold 2, Epoch 2 - Training: 100%|██████████| 624/624 [00:00<00:00, 937.10it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 2
Training Loss: 0.1448
Validation Loss: 0.1366
Validation Metrics:
  Exact Match Accuracy: 0.7334
  Hamming Accuracy: 0.9449
  F1: 0.8643


Fold 2, Epoch 3 - Training: 100%|██████████| 624/624 [00:00<00:00, 820.74it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 3
Training Loss: 0.1342
Validation Loss: 0.1338
Validation Metrics:
  Exact Match Accuracy: 0.7372
  Hamming Accuracy: 0.9460
  F1: 0.8670


Fold 2, Epoch 4 - Training: 100%|██████████| 624/624 [00:00<00:00, 1015.67it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 4
Training Loss: 0.1270
Validation Loss: 0.1321
Validation Metrics:
  Exact Match Accuracy: 0.7374
  Hamming Accuracy: 0.9464
  F1: 0.8668


Fold 2, Epoch 5 - Training: 100%|██████████| 624/624 [00:00<00:00, 862.80it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 5
Training Loss: 0.1214
Validation Loss: 0.1305
Validation Metrics:
  Exact Match Accuracy: 0.7492
  Hamming Accuracy: 0.9479
  F1: 0.8706


Fold 2, Epoch 6 - Training: 100%|██████████| 624/624 [00:00<00:00, 822.67it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 6
Training Loss: 0.1170
Validation Loss: 0.1299
Validation Metrics:
  Exact Match Accuracy: 0.7436
  Hamming Accuracy: 0.9475
  F1: 0.8716


Fold 2, Epoch 7 - Training: 100%|██████████| 624/624 [00:00<00:00, 661.64it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 7
Training Loss: 0.1125
Validation Loss: 0.1303
Validation Metrics:
  Exact Match Accuracy: 0.7486
  Hamming Accuracy: 0.9485
  F1: 0.8761


Fold 2, Epoch 8 - Training: 100%|██████████| 624/624 [00:00<00:00, 992.31it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 8
Training Loss: 0.1088
Validation Loss: 0.1322
Validation Metrics:
  Exact Match Accuracy: 0.7452
  Hamming Accuracy: 0.9475
  F1: 0.8734


Fold 2, Epoch 9 - Training: 100%|██████████| 624/624 [00:00<00:00, 1010.11it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 2, Epoch 9
Training Loss: 0.1055
Validation Loss: 0.1312
Validation Metrics:
  Exact Match Accuracy: 0.7466
  Hamming Accuracy: 0.9488
  F1: 0.8773
Early stopping triggered at epoch 9

FOLD 3/5


Fold 3, Epoch 1 - Training: 100%|██████████| 624/624 [00:00<00:00, 854.52it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 1
Training Loss: 0.1992
Validation Loss: 0.1476
Validation Metrics:
  Exact Match Accuracy: 0.7159
  Hamming Accuracy: 0.9410
  F1: 0.8533


Fold 3, Epoch 2 - Training: 100%|██████████| 624/624 [00:00<00:00, 904.38it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 2
Training Loss: 0.1435
Validation Loss: 0.1401
Validation Metrics:
  Exact Match Accuracy: 0.7266
  Hamming Accuracy: 0.9433
  F1: 0.8605


Fold 3, Epoch 3 - Training: 100%|██████████| 624/624 [00:00<00:00, 873.28it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 3
Training Loss: 0.1339
Validation Loss: 0.1353
Validation Metrics:
  Exact Match Accuracy: 0.7428
  Hamming Accuracy: 0.9464
  F1: 0.8668


Fold 3, Epoch 4 - Training: 100%|██████████| 624/624 [00:00<00:00, 906.86it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 4
Training Loss: 0.1269
Validation Loss: 0.1355
Validation Metrics:
  Exact Match Accuracy: 0.7404
  Hamming Accuracy: 0.9459
  F1: 0.8656


Fold 3, Epoch 5 - Training: 100%|██████████| 624/624 [00:00<00:00, 822.04it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 5
Training Loss: 0.1220
Validation Loss: 0.1326
Validation Metrics:
  Exact Match Accuracy: 0.7400
  Hamming Accuracy: 0.9471
  F1: 0.8710


Fold 3, Epoch 6 - Training: 100%|██████████| 624/624 [00:00<00:00, 1015.36it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 6
Training Loss: 0.1164
Validation Loss: 0.1319
Validation Metrics:
  Exact Match Accuracy: 0.7430
  Hamming Accuracy: 0.9481
  F1: 0.8732


Fold 3, Epoch 7 - Training: 100%|██████████| 624/624 [00:00<00:00, 805.02it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 7
Training Loss: 0.1133
Validation Loss: 0.1304
Validation Metrics:
  Exact Match Accuracy: 0.7466
  Hamming Accuracy: 0.9486
  F1: 0.8761


Fold 3, Epoch 8 - Training: 100%|██████████| 624/624 [00:00<00:00, 1005.65it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 8
Training Loss: 0.1093
Validation Loss: 0.1305
Validation Metrics:
  Exact Match Accuracy: 0.7460
  Hamming Accuracy: 0.9487
  F1: 0.8752


Fold 3, Epoch 9 - Training: 100%|██████████| 624/624 [00:00<00:00, 936.13it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 9
Training Loss: 0.1058
Validation Loss: 0.1326
Validation Metrics:
  Exact Match Accuracy: 0.7416
  Hamming Accuracy: 0.9479
  F1: 0.8766


Fold 3, Epoch 10 - Training: 100%|██████████| 624/624 [00:00<00:00, 777.43it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 3, Epoch 10
Training Loss: 0.1020
Validation Loss: 0.1392
Validation Metrics:
  Exact Match Accuracy: 0.7396
  Hamming Accuracy: 0.9468
  F1: 0.8684
Early stopping triggered at epoch 10

FOLD 4/5


Fold 4, Epoch 1 - Training: 100%|██████████| 624/624 [00:00<00:00, 1033.00it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 1
Training Loss: 0.2030
Validation Loss: 0.1471
Validation Metrics:
  Exact Match Accuracy: 0.7115
  Hamming Accuracy: 0.9406
  F1: 0.8507


Fold 4, Epoch 2 - Training: 100%|██████████| 624/624 [00:00<00:00, 862.80it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 2
Training Loss: 0.1448
Validation Loss: 0.1372
Validation Metrics:
  Exact Match Accuracy: 0.7336
  Hamming Accuracy: 0.9452
  F1: 0.8643


Fold 4, Epoch 3 - Training: 100%|██████████| 624/624 [00:00<00:00, 668.28it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 3
Training Loss: 0.1334
Validation Loss: 0.1324
Validation Metrics:
  Exact Match Accuracy: 0.7438
  Hamming Accuracy: 0.9475
  F1: 0.8688


Fold 4, Epoch 4 - Training: 100%|██████████| 624/624 [00:00<00:00, 889.83it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 4
Training Loss: 0.1267
Validation Loss: 0.1313
Validation Metrics:
  Exact Match Accuracy: 0.7508
  Hamming Accuracy: 0.9486
  F1: 0.8744


Fold 4, Epoch 5 - Training: 100%|██████████| 624/624 [00:00<00:00, 835.42it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 5
Training Loss: 0.1220
Validation Loss: 0.1354
Validation Metrics:
  Exact Match Accuracy: 0.7468
  Hamming Accuracy: 0.9471
  F1: 0.8710


Fold 4, Epoch 6 - Training: 100%|██████████| 624/624 [00:00<00:00, 953.56it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 6
Training Loss: 0.1169
Validation Loss: 0.1285
Validation Metrics:
  Exact Match Accuracy: 0.7528
  Hamming Accuracy: 0.9500
  F1: 0.8773


Fold 4, Epoch 7 - Training: 100%|██████████| 624/624 [00:00<00:00, 800.87it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 7
Training Loss: 0.1129
Validation Loss: 0.1291
Validation Metrics:
  Exact Match Accuracy: 0.7510
  Hamming Accuracy: 0.9496
  F1: 0.8767


Fold 4, Epoch 8 - Training: 100%|██████████| 624/624 [00:00<00:00, 968.78it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 8
Training Loss: 0.1087
Validation Loss: 0.1296
Validation Metrics:
  Exact Match Accuracy: 0.7464
  Hamming Accuracy: 0.9494
  F1: 0.8789


Fold 4, Epoch 9 - Training: 100%|██████████| 624/624 [00:00<00:00, 1024.16it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 4, Epoch 9
Training Loss: 0.1054
Validation Loss: 0.1316
Validation Metrics:
  Exact Match Accuracy: 0.7524
  Hamming Accuracy: 0.9499
  F1: 0.8783
Early stopping triggered at epoch 9

FOLD 5/5


Fold 5, Epoch 1 - Training: 100%|██████████| 624/624 [00:00<00:00, 881.66it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 1
Training Loss: 0.2006
Validation Loss: 0.1481
Validation Metrics:
  Exact Match Accuracy: 0.7179
  Hamming Accuracy: 0.9404
  F1: 0.8504


Fold 5, Epoch 2 - Training: 100%|██████████| 624/624 [00:00<00:00, 920.19it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 2
Training Loss: 0.1436
Validation Loss: 0.1414
Validation Metrics:
  Exact Match Accuracy: 0.7310
  Hamming Accuracy: 0.9440
  F1: 0.8614


Fold 5, Epoch 3 - Training: 100%|██████████| 624/624 [00:00<00:00, 837.02it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 3
Training Loss: 0.1331
Validation Loss: 0.1439
Validation Metrics:
  Exact Match Accuracy: 0.7272
  Hamming Accuracy: 0.9425
  F1: 0.8583


Fold 5, Epoch 4 - Training: 100%|██████████| 624/624 [00:00<00:00, 977.07it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 4
Training Loss: 0.1264
Validation Loss: 0.1343
Validation Metrics:
  Exact Match Accuracy: 0.7334
  Hamming Accuracy: 0.9454
  F1: 0.8635


Fold 5, Epoch 5 - Training: 100%|██████████| 624/624 [00:00<00:00, 753.33it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 5
Training Loss: 0.1212
Validation Loss: 0.1338
Validation Metrics:
  Exact Match Accuracy: 0.7430
  Hamming Accuracy: 0.9472
  F1: 0.8700


Fold 5, Epoch 6 - Training: 100%|██████████| 624/624 [00:00<00:00, 865.91it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 6
Training Loss: 0.1170
Validation Loss: 0.1322
Validation Metrics:
  Exact Match Accuracy: 0.7448
  Hamming Accuracy: 0.9480
  F1: 0.8746


Fold 5, Epoch 7 - Training: 100%|██████████| 624/624 [00:00<00:00, 941.45it/s] 
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 7
Training Loss: 0.1122
Validation Loss: 0.1316
Validation Metrics:
  Exact Match Accuracy: 0.7456
  Hamming Accuracy: 0.9473
  F1: 0.8702


Fold 5, Epoch 8 - Training: 100%|██████████| 624/624 [00:00<00:00, 780.49it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 8
Training Loss: 0.1089
Validation Loss: 0.1324
Validation Metrics:
  Exact Match Accuracy: 0.7436
  Hamming Accuracy: 0.9475
  F1: 0.8731


Fold 5, Epoch 9 - Training: 100%|██████████| 624/624 [00:00<00:00, 631.60it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 9
Training Loss: 0.1048
Validation Loss: 0.1381
Validation Metrics:
  Exact Match Accuracy: 0.7402
  Hamming Accuracy: 0.9470
  F1: 0.8740


Fold 5, Epoch 10 - Training: 100%|██████████| 624/624 [00:00<00:00, 808.77it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Fold 5, Epoch 10
Training Loss: 0.1016
Validation Loss: 0.1351
Validation Metrics:
  Exact Match Accuracy: 0.7454
  Hamming Accuracy: 0.9483
  F1: 0.8760
Early stopping triggered at epoch 10

Evaluating on test set...


  model.load_state_dict(torch.load(f'best_model_fold_{best_fold+1}.pt'))


FileNotFoundError: [Errno 2] No such file or directory: 'best_model_fold_1.pt'