In [1]:
# @title Cell 1: ViT Transfer Learning Infrastructure Configuration (Fixed & Optimized)

# File: 05_01_ViT_RAF-DB_CASME2-AF.ipynb - Cell 1
# Location: experiments/05_01_ViT_RAF-DB_CASME2-AF.ipynb
# Purpose: Transfer learning from RAF-DB macro-expressions to CASME2 micro-expressions (Apex Frame)
#          with optimized loss functions and class weights

from google.colab import drive
print("=" * 70)
print("VIT TRANSFER LEARNING: RAF-DB → CASME2 APEX FRAME (OPTIMIZED)")
print("=" * 70)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("    Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import timm
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Project paths configuration
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"

print("\n" + "=" * 70)
print("MODEL VARIANT CONFIGURATION")
print("=" * 70)

# ViT Model Variant Selection
VIT_MODEL_VARIANT = 'patch32'

if VIT_MODEL_VARIANT == 'patch16':
    VIT_MODEL_NAME = 'google/vit-base-patch16-224-in21k'
    PATCH_SIZE = 16
    print("\nSelected: ViT-Base Patch16")
    print("  Characteristics: Fine-grained feature extraction")
    print("  Tokens at 384px: 576 tokens (24×24 grid)")
    print("  Best for: Subtle micro-expression details")
elif VIT_MODEL_VARIANT == 'patch32':
    VIT_MODEL_NAME = 'google/vit-base-patch32-224-in21k'
    PATCH_SIZE = 32
    print("\nSelected: ViT-Base Patch32")
    print("  Characteristics: Efficient feature extraction")
    print("  Tokens at 384px: 144 tokens (12×12 grid)")
    print("  Best for: Balanced performance and speed")
else:
    raise ValueError(f"Invalid VIT_MODEL_VARIANT: {VIT_MODEL_VARIANT}. Use 'patch16' or 'patch32'")

print("\n" + "=" * 70)
print("TRANSFER LEARNING STAGE CONFIGURATION")
print("=" * 70)

# Training stage toggle
TRAINING_STAGE = 'finetune'

print(f"\nCurrent training stage: {TRAINING_STAGE.upper()}")
print("  'pretrain' = RAF-DB macro-expression feature extraction")
print("  'finetune' = CASME2 micro-expression specialization")

# FIXED CLASS ORDERING FOR TRANSFER LEARNING (BY FREQUENCY)
# Ordered from most to least frequent in CASME2 training set
TRANSFER_CLASSES = ['disgust', 'happy', 'surprise', 'sad', 'fear']
NUM_CLASSES = 5

print(f"\nFixed transfer learning classes: {NUM_CLASSES}")
print(f"  Ordering: {TRANSFER_CLASSES}")
print(f"  Order rationale: Frequency-based (most to least common)")
print(f"  Index mapping:")
for idx, cls in enumerate(TRANSFER_CLASSES):
    print(f"    {idx} = {cls}")

# CASME2 to RAF-DB naming convention mapping
CASME2_TO_RAF_MAPPING = {
    'disgust': 'disgust',
    'fear': 'fear',
    'happiness': 'happy',
    'sadness': 'sad',
    'surprise': 'surprise',
    'repression': None,
    'others': None
}

print(f"\nCASME2 naming convention mapping:")
print(f"  happiness → happy")
print(f"  sadness → sad")
print(f"  repression → excluded")
print(f"  others → excluded")

# OPTIMIZED CLASS WEIGHTS CONFIGURATION
# CASME2 train distribution after filtering: [50, 25, 20, 5, 1] for [disgust, happy, surprise, sad, fear]
print(f"\n" + "=" * 70)
print("CLASS IMBALANCE HANDLING CONFIGURATION")
print("=" * 70)

print(f"\nCASME2 Training Distribution (5 classes):")
print(f"  disgust:  50 samples (49.5%)")
print(f"  happy:    25 samples (24.8%)")
print(f"  surprise: 20 samples (19.8%)")
print(f"  sad:       5 samples (5.0%)")
print(f"  fear:      1 sample  (1.0%) - CRITICAL: Very minor!")

# CrossEntropy Loss - Inverse Square Root Frequency Weights
# Formula: weight[i] = 1/sqrt(freq[i]) normalized to smallest class = 1.0
CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.42, 1.59, 3.17, 7.09]

print(f"\nCrossEntropy Class Weights (inverse sqrt frequency):")
print(f"  Values: {CROSSENTROPY_CLASS_WEIGHTS}")
print(f"  Rationale: Sqrt smoothing prevents extreme weights")
print(f"  Fear weight (7.09): Strong but not excessive")

# Focal Loss - Smoothed Normalized Alpha Weights (sum = 1.0)
# Formula: alpha[i] = 1/sqrt(proportion[i]) normalized to sum=1.0
FOCAL_LOSS_ALPHA_WEIGHTS = [0.071, 0.100, 0.112, 0.222, 0.497]
FOCAL_LOSS_GAMMA = 2.0

print(f"\nFocal Loss Configuration:")
print(f"  Alpha weights: {FOCAL_LOSS_ALPHA_WEIGHTS}")
print(f"  Alpha sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f} (must be 1.0)")
print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
print(f"  Rationale: Smoothed alpha prevents over-emphasis on minorities")

# LOSS FUNCTION SELECTION TOGGLE
USE_FOCAL_LOSS = False

print(f"\n" + "=" * 70)
print("LOSS FUNCTION CONFIGURATION")
print("=" * 70)

if USE_FOCAL_LOSS:
    print(f"\nSelected: Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Alpha weights: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Strategy: Focus on hard examples + minority class weighting")
    print(f"  Best for: Extreme imbalance (1-50 ratio)")
else:
    print(f"\nSelected: CrossEntropy Loss")
    print(f"  Class weights: {CROSSENTROPY_CLASS_WEIGHTS}")
    print(f"  Strategy: Direct minority class weighting")
    print(f"  Best for: Moderate imbalance")

# Unified output structure
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/05_01_transfer_learning"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/05_01_transfer_learning"

# Image resolution configuration
INPUT_SIZE = 384
print(f"\n" + "=" * 70)
print("IMAGE PROCESSING CONFIGURATION")
print("=" * 70)
print(f"\nImage resolution: {INPUT_SIZE}×{INPUT_SIZE}px")
print(f"  RAF-DB images: Upscaled from 100×100 to {INPUT_SIZE}×{INPUT_SIZE}")
print(f"  CASME2 images: Native resolution maintained at {INPUT_SIZE}×{INPUT_SIZE}")
print(f"  Resize method: LANCZOS (high quality)")

# Stage-dependent configuration
if TRAINING_STAGE == 'pretrain':
    print("\n" + "=" * 70)
    print("STAGE 1: RAF-DB PRE-TRAINING CONFIGURATION")
    print("=" * 70)

    DATASET_NAME = 'RAF-DB Balanced (Transfer Set)'
    DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_raf"
    METADATA_PATH = f"{PROJECT_ROOT}/datasets/metadata/rafdb_metadata.csv"

    BATCH_SIZE = 256
    NUM_EPOCHS = 30
    LEARNING_RATE = 1e-5
    WEIGHT_DECAY = 1e-5
    DROPOUT_RATE = 0.2
    GRADIENT_CLIP = 1.0

    AUGMENTATION_STRENGTH = 'moderate'
    ROTATION_RANGE = 10
    BRIGHTNESS_FACTOR = 0.1
    CONTRAST_FACTOR = 0.0
    HORIZONTAL_FLIP = False

    PRETRAINED_CHECKPOINT = None
    FREEZE_ENCODER = False

    CHECKPOINT_FILENAME = 'raf_pretrain_best_f1.pth'
    LOGS_SUBDIR = 'pretrain_logs'

    print(f"\nDataset: {DATASET_NAME}")
    print(f"  Root: {DATASET_ROOT}")
    print(f"  Classes: {NUM_CLASSES} (fixed ordering by frequency)")
    print(f"  Expected samples: ~29,880 images (5 classes × ~5,956 per class)")
    print(f"  Original resolution: 100×100px (aligned faces)")
    print(f"  Training resolution: {INPUT_SIZE}×{INPUT_SIZE}px (upscaled)")

    print(f"\nTraining configuration:")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Epochs: {NUM_EPOCHS}")
    print(f"  Learning rate: {LEARNING_RATE}")
    print(f"  Weight decay: {WEIGHT_DECAY}")
    print(f"  Dropout: {DROPOUT_RATE}")
    print(f"  Augmentation: {AUGMENTATION_STRENGTH}")

    print(f"\nLoss function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
    if USE_FOCAL_LOSS:
        print(f"  Note: Focal Loss less critical for balanced RAF-DB")
        print(f"  Gamma: {FOCAL_LOSS_GAMMA}")

elif TRAINING_STAGE == 'finetune':
    print("\n" + "=" * 70)
    print("STAGE 2: CASME2 FINE-TUNING CONFIGURATION (OPTIMIZED)")
    print("=" * 70)

    DATASET_NAME = 'CASME2 Apex Frame (Phase 1)'
    DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/data_split"
    METADATA_PATH = f"{DATASET_ROOT}/split_metadata.json"

    # OPTIMIZED HYPERPARAMETERS FOR SMALL IMBALANCED DATASET
    BATCH_SIZE = 8          # Smaller batch for limited data (was 16)
    NUM_EPOCHS = 50
    LEARNING_RATE = 3e-6    # Slightly higher than before (was 1e-6)
    WEIGHT_DECAY = 1e-4
    DROPOUT_RATE = 0.3      # Lower dropout (was 0.5)
    GRADIENT_CLIP = 0.5

    AUGMENTATION_STRENGTH = 'enhanced'
    ROTATION_RANGE = 15
    BRIGHTNESS_FACTOR = 0.2
    CONTRAST_FACTOR = 0.2
    HORIZONTAL_FLIP = True

    CHECKPOINT_FILENAME = 'casme2_finetune_best_f1.pth'
    LOGS_SUBDIR = 'finetune_logs'
    PRETRAINED_CHECKPOINT = f"{CHECKPOINT_ROOT}/raf_pretrain_best_f1.pth"

    if not os.path.exists(PRETRAINED_CHECKPOINT):
        raise FileNotFoundError(
            f"Pre-trained checkpoint not found: {PRETRAINED_CHECKPOINT}\n"
            f"Please run Stage 1 (pretrain) first before fine-tuning."
        )

    # OPTIMIZED TRANSFER STRATEGY
    FINETUNE_STRATEGY = 'frozen_encoder'  # Conservative approach for limited data
    FREEZE_ENCODER = (FINETUNE_STRATEGY == 'frozen_encoder')

    print(f"\nDataset: {DATASET_NAME}")
    print(f"  Root: {DATASET_ROOT}")
    print(f"  Classes: {NUM_CLASSES} (fixed ordering, fear=1 sample)")
    print(f"  Expected samples: ~150-160 apex images")
    print(f"  Native resolution: {INPUT_SIZE}×{INPUT_SIZE}px")

    print(f"\nTransfer learning configuration:")
    print(f"  Pre-trained checkpoint: {os.path.basename(PRETRAINED_CHECKPOINT)}")
    print(f"  Checkpoint verified: Found")
    print(f"  Strategy: {FINETUNE_STRATEGY}")
    print(f"  Encoder frozen: {FREEZE_ENCODER}")
    print(f"  Rationale: Preserve pretrained features, only adapt classifier")

    print(f"\nOptimized fine-tuning hyperparameters:")
    print(f"  Batch size: {BATCH_SIZE} (reduced from 16 for better gradient)")
    print(f"  Epochs: {NUM_EPOCHS}")
    print(f"  Learning rate: {LEARNING_RATE} (increased from 1e-6)")
    print(f"  Weight decay: {WEIGHT_DECAY}")
    print(f"  Dropout: {DROPOUT_RATE} (reduced from 0.5 to prevent over-regularization)")
    print(f"  Augmentation: {AUGMENTATION_STRENGTH}")

    print(f"\nLoss function: {'Focal Loss (CRITICAL)' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
    if USE_FOCAL_LOSS:
        print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
        print(f"  Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
        print(f"  Rationale: Handle extreme 1-50 class imbalance")

else:
    raise ValueError(f"Invalid TRAINING_STAGE: {TRAINING_STAGE}. Use 'pretrain' or 'finetune'")

# GPU configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\n[3] Hardware configuration")
print(f"    Device: {device}")
print(f"    GPU: {gpu_name} ({gpu_memory:.1f} GB)")

# Hardware-optimized worker configuration
if 'A100' in gpu_name:
    NUM_WORKERS = 32
    torch.backends.cudnn.benchmark = True
    print("    Optimization: A100 detected, enabled cudnn benchmark")
elif 'L4' in gpu_name:
    NUM_WORKERS = 16
    torch.backends.cudnn.benchmark = True
    print("    Optimization: L4 detected, enabled cudnn benchmark")
else:
    NUM_WORKERS = 8
    print("    Optimization: Default GPU configuration")

# RAM preloading workers
RAM_PRELOAD_WORKERS = 128
print(f"    DataLoader workers: {NUM_WORKERS} (batch preparation)")
print(f"    RAM preload workers: {RAM_PRELOAD_WORKERS} (parallel image loading)")

print(f"\n[4] Loss function configuration")
print(f"    Function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
if USE_FOCAL_LOSS:
    print(f"    Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"    Alpha weights: {FOCAL_LOSS_ALPHA_WEIGHTS}")
else:
    print(f"    Class weights: {CROSSENTROPY_CLASS_WEIGHTS}")

# ViT Architecture for Transfer Learning
class ViTTransferLearning(nn.Module):
    """ViT architecture with transfer learning support for macro to micro expression recognition"""

    def __init__(self, num_classes, dropout_rate=0.2, pretrained_checkpoint=None, freeze_encoder=False):
        super(ViTTransferLearning, self).__init__()

        from transformers import ViTModel

        self.vit = ViTModel.from_pretrained(
            VIT_MODEL_NAME,
            add_pooling_layer=False
        )

        if freeze_encoder:
            for param in self.vit.parameters():
                param.requires_grad = False
            print("ViT encoder frozen for fine-tuning")
        else:
            for param in self.vit.parameters():
                param.requires_grad = True
            print("ViT encoder trainable")

        self.vit_feature_dim = self.vit.config.hidden_size

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.vit_feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
        )

        self.classifier = nn.Linear(128, num_classes)

        print(f"ViT Transfer Learning: {self.vit_feature_dim} -> 512 -> 128 -> {num_classes}")

        if pretrained_checkpoint and os.path.exists(pretrained_checkpoint):
            self.load_pretrained_weights(pretrained_checkpoint)

    def load_pretrained_weights(self, checkpoint_path):
        """Load pre-trained weights from RAF-DB pre-training stage"""
        print(f"\nLoading pre-trained weights from: {os.path.basename(checkpoint_path)}")

        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        vit_state = {k.replace('vit.', ''): v
                    for k, v in checkpoint['model_state_dict'].items()
                    if k.startswith('vit.')}

        missing_keys, unexpected_keys = self.vit.load_state_dict(vit_state, strict=False)

        print(f"  ViT encoder weights loaded")
        print(f"  Missing keys: {len(missing_keys)}")
        print(f"  Unexpected keys: {len(unexpected_keys)}")

        classifier_state = {k.replace('classifier_layers.', ''): v
                          for k, v in checkpoint['model_state_dict'].items()
                          if k.startswith('classifier_layers.')}

        if classifier_state:
            try:
                self.classifier_layers.load_state_dict(classifier_state, strict=False)
                print(f"  Classifier layers loaded")
            except:
                print(f"  Classifier layers not loaded (dimension mismatch expected)")

        print(f"  Transfer learning initialization complete")

    def forward(self, pixel_values):
        vit_outputs = self.vit(
            pixel_values=pixel_values,
            interpolate_pos_encoding=True
        )

        vit_features = vit_outputs.last_hidden_state[:, 0]

        processed_features = self.classifier_layers(vit_features)
        output = self.classifier(processed_features)

        return output

# Optimized Focal Loss Implementation
class OptimizedFocalLoss(nn.Module):
    """
    Advanced Focal Loss with per-class alpha support
    Paper: Focal Loss for Dense Object Detection (Lin et al., 2017)

    Formula: FL(p_t) = -alpha_t * (1-p_t)^gamma * log(p_t)
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()

        if alpha is not None:
            if isinstance(alpha, list):
                self.alpha = torch.tensor(alpha, dtype=torch.float32)
            else:
                self.alpha = alpha

            alpha_sum = self.alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")
        else:
            self.alpha = None

        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Loss function factory
def create_criterion(use_focal_loss=False, gamma=2.0, alpha_weights=None, class_weights=None):
    """Create loss function based on configuration"""
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        if alpha_weights:
            print(f"  Alpha weights: {alpha_weights}")
            print(f"  Alpha sum: {sum(alpha_weights):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)
    else:
        print(f"Using CrossEntropy Loss with class weights")
        if class_weights is not None:
            class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
            print(f"  Class weights: {class_weights}")
            return nn.CrossEntropyLoss(weight=class_weights_tensor)
        else:
            return nn.CrossEntropyLoss()

# Optimizer and scheduler factory
def create_optimizer_scheduler(model, learning_rate, weight_decay):
    """Create optimizer and scheduler for training"""

    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay,
        betas=(0.9, 0.999)
    )

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=5,
        min_lr=1e-7
    )

    print(f"Optimizer: AdamW (lr={learning_rate}, wd={weight_decay})")
    print(f"Scheduler: ReduceLROnPlateau (monitor=val_f1, patience=5)")

    return optimizer, scheduler

# ViT Image Processor setup
from transformers import ViTImageProcessor

print(f"\n[5] Setting up ViT Image Processor for {INPUT_SIZE}px input...")

vit_processor = ViTImageProcessor.from_pretrained(
    VIT_MODEL_NAME,
    do_resize=True,
    size={'height': INPUT_SIZE, 'width': INPUT_SIZE},
    do_normalize=True,
    do_rescale=True,
    do_center_crop=False
)

print(f"    ViT Image Processor configured for {INPUT_SIZE}px with interpolation")
print(f"    Variant: {VIT_MODEL_VARIANT.upper()}")
print(f"    Expected tokens: {(INPUT_SIZE // PATCH_SIZE) ** 2}")

# Transform functions with stage-dependent augmentation
def get_transforms(stage):
    """Get transforms based on training stage and augmentation strength"""

    if stage == 'train':
        transform_list = []

        if ROTATION_RANGE > 0:
            transform_list.append(
                transforms.RandomRotation(degrees=ROTATION_RANGE)
            )

        if BRIGHTNESS_FACTOR > 0 or CONTRAST_FACTOR > 0:
            transform_list.append(
                transforms.ColorJitter(
                    brightness=BRIGHTNESS_FACTOR,
                    contrast=CONTRAST_FACTOR
                )
            )

        if HORIZONTAL_FLIP:
            transform_list.append(
                transforms.RandomHorizontalFlip(p=0.5)
            )

        if transform_list:
            augmentation = transforms.Compose(transform_list)

            def train_transform(image):
                image = augmentation(image)
                inputs = vit_processor(image, return_tensors="pt")
                return inputs['pixel_values'].squeeze(0)

            return train_transform
        else:
            def train_transform(image):
                inputs = vit_processor(image, return_tensors="pt")
                return inputs['pixel_values'].squeeze(0)

            return train_transform

    else:
        def val_transform(image):
            inputs = vit_processor(image, return_tensors="pt")
            return inputs['pixel_values'].squeeze(0)

        return val_transform

# Custom Dataset class for RAF-DB with fixed class ordering
class RAFDBDataset(Dataset):
    """Dataset class for RAF-DB with fixed transfer learning class ordering"""

    def __init__(self, metadata_df, dataset_root, transform=None, transfer_classes=None):
        if transfer_classes:
            self.metadata = metadata_df[metadata_df['emotion_label'].isin(transfer_classes)].copy()
            print(f"Filtered to {len(self.metadata)} samples from {len(metadata_df)} total")
        else:
            self.metadata = metadata_df.copy()

        self.dataset_root = dataset_root
        self.transform = transform

        self.class_to_idx = {cls: idx for idx, cls in enumerate(transfer_classes)}

        print(f"Loaded {len(self.metadata)} samples")
        print(f"Fixed class ordering: {list(self.class_to_idx.keys())}")

        unique_labels = self.metadata['emotion_label'].unique()
        for label in unique_labels:
            if label not in self.class_to_idx:
                raise ValueError(f"Label '{label}' not found in transfer_classes: {transfer_classes}")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]

        image_path = os.path.join(self.dataset_root, row['filepath'].replace('datasets/processed_raf/', ''))
        image = Image.open(image_path).convert('RGB')

        if image.size != (INPUT_SIZE, INPUT_SIZE):
            image = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.LANCZOS)

        if self.transform:
            image = self.transform(image)

        label = self.class_to_idx[row['emotion_label']]

        return image, label

# Custom Dataset class for CASME2 with naming convention mapping
class CASME2Dataset(Dataset):
    """Dataset class for CASME2 with CASME2->RAF naming mapping and fixed class ordering"""

    def __init__(self, split_metadata, dataset_root, transform=None, split='train',
                 transfer_classes=None, casme2_mapping=None):

        samples = split_metadata[split]['samples']

        if casme2_mapping:
            mapped_samples = []
            for s in samples:
                casme2_emotion = s['emotion']
                raf_emotion = casme2_mapping.get(casme2_emotion)

                if raf_emotion in transfer_classes:
                    mapped_sample = s.copy()
                    mapped_sample['emotion'] = raf_emotion
                    mapped_samples.append(mapped_sample)

            print(f"Mapped {len(mapped_samples)} samples from {len(samples)} total")
            self.samples = mapped_samples
        else:
            self.samples = samples

        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split

        self.class_to_idx = {cls: idx for idx, cls in enumerate(transfer_classes)}

        print(f"Loaded {len(self.samples)} samples for {split} split")
        print(f"Fixed class ordering: {list(self.class_to_idx.keys())}")

        actual_classes = {}
        for s in self.samples:
            emotion = s['emotion']
            actual_classes[emotion] = actual_classes.get(emotion, 0) + 1

        print(f"Class distribution:")
        for cls in transfer_classes:
            count = actual_classes.get(cls, 0)
            idx = self.class_to_idx[cls]
            status = "NO SAMPLES" if count == 0 else f"{count} samples"
            print(f"  [{idx}] {cls}: {status}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        image_path = os.path.join(self.dataset_root, self.split, sample['image_filename'])
        image = Image.open(image_path).convert('RGB')

        if image.size != (INPUT_SIZE, INPUT_SIZE):
            image = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.LANCZOS)

        if self.transform:
            image = self.transform(image)

        label = self.class_to_idx[sample['emotion']]

        return image, label, sample['sample_id']

# Create unified directory structure
print(f"\n[6] Creating unified output directory structure...")

os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(RESULTS_ROOT, exist_ok=True)

logs_path = f"{RESULTS_ROOT}/{LOGS_SUBDIR}"
os.makedirs(logs_path, exist_ok=True)
os.makedirs(f"{logs_path}/training_logs", exist_ok=True)
os.makedirs(f"{logs_path}/evaluation_results", exist_ok=True)

print(f"    Unified structure created:")
print(f"    Checkpoints: {CHECKPOINT_ROOT}/")
print(f"      - {CHECKPOINT_FILENAME}")
print(f"    Results: {RESULTS_ROOT}/")
print(f"      - {LOGS_SUBDIR}/")

# Architecture validation
print(f"\n[7] Validating ViT Transfer Learning architecture...")

try:
    test_model = ViTTransferLearning(
        num_classes=NUM_CLASSES,
        dropout_rate=DROPOUT_RATE,
        pretrained_checkpoint=PRETRAINED_CHECKPOINT if TRAINING_STAGE == 'finetune' else None,
        freeze_encoder=FREEZE_ENCODER
    ).to(device)

    test_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE).to(device)
    test_output = test_model(test_input)

    expected_tokens = (INPUT_SIZE // PATCH_SIZE) ** 2

    print(f"    Validation successful")
    print(f"    Output shape: {test_output.shape}")
    print(f"    Expected tokens: {expected_tokens}")
    print(f"    Variant: {VIT_MODEL_VARIANT} ({PATCH_SIZE}×{PATCH_SIZE} patches)")

    del test_model, test_input, test_output
    torch.cuda.empty_cache()

except Exception as e:
    print(f"    Validation failed: {e}")
    raise

# Global configuration
GLOBAL_CONFIG = {
    'training_stage': TRAINING_STAGE,
    'dataset_name': DATASET_NAME,
    'dataset_root': DATASET_ROOT,
    'metadata_path': METADATA_PATH,
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'ram_preload_workers': RAM_PRELOAD_WORKERS,
    'num_classes': NUM_CLASSES,
    'transfer_classes': TRANSFER_CLASSES,
    'class_to_idx': {cls: idx for idx, cls in enumerate(TRANSFER_CLASSES)},
    'casme2_mapping': CASME2_TO_RAF_MAPPING,
    'transform_train': get_transforms('train'),
    'transform_val': get_transforms('val'),
    'vit_model': VIT_MODEL_NAME,
    'vit_variant': VIT_MODEL_VARIANT,
    'patch_size': PATCH_SIZE,
    'input_size': INPUT_SIZE,
    'num_epochs': NUM_EPOCHS,
    'learning_rate': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'dropout_rate': DROPOUT_RATE,
    'gradient_clip': GRADIENT_CLIP,
    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'pretrained_checkpoint': PRETRAINED_CHECKPOINT,
    'freeze_encoder': FREEZE_ENCODER,
    'finetune_strategy': FINETUNE_STRATEGY if TRAINING_STAGE == 'finetune' else None,
    'augmentation_strength': AUGMENTATION_STRENGTH,
    'checkpoint_root': CHECKPOINT_ROOT,
    'checkpoint_filename': CHECKPOINT_FILENAME,
    'checkpoint_path': f"{CHECKPOINT_ROOT}/{CHECKPOINT_FILENAME}",
    'results_root': RESULTS_ROOT,
    'logs_subdir': LOGS_SUBDIR,
    'logs_path': logs_path,
    'criterion_factory': create_criterion,
    'optimizer_scheduler_factory': create_optimizer_scheduler
}

# Configuration summary
print("\n" + "=" * 70)
print("TRANSFER LEARNING INFRASTRUCTURE CONFIGURATION COMPLETE")
print("=" * 70)

print(f"\nModel Configuration:")
print(f"  Variant: {VIT_MODEL_VARIANT.upper()} (ViT-Base-{VIT_MODEL_VARIANT.capitalize()})")
print(f"  Patch size: {PATCH_SIZE}×{PATCH_SIZE}px")
print(f"  Input resolution: {INPUT_SIZE}×{INPUT_SIZE}px")
print(f"  Expected tokens: {(INPUT_SIZE // PATCH_SIZE) ** 2}")
print(f"  Dropout: {DROPOUT_RATE}")

print(f"\nFixed Transfer Learning Classes:")
print(f"  Total: {NUM_CLASSES} classes")
print(f"  Ordering: {TRANSFER_CLASSES} (by frequency)")
print(f"  Class indices:")
for idx, cls in enumerate(TRANSFER_CLASSES):
    print(f"    [{idx}] = {cls}")

print(f"\nLoss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
else:
    print(f"  Function: CrossEntropy")
    print(f"  Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nTraining Stage: {TRAINING_STAGE.upper()}")
print(f"  Dataset: {DATASET_NAME}")

print(f"\nUnified Output Structure:")
print(f"  Checkpoint: {CHECKPOINT_ROOT}/{CHECKPOINT_FILENAME}")
print(f"  Results: {RESULTS_ROOT}/{LOGS_SUBDIR}/")

if TRAINING_STAGE == 'pretrain':
    print(f"\nPre-training Details:")
    print(f"  Expected samples: ~29,880 images")
    print(f"  Strategy: Extract macro-expression features")
    print(f"  Image upscaling: 100×100 → {INPUT_SIZE}×{INPUT_SIZE}")
    print(f"  Output checkpoint: {CHECKPOINT_FILENAME}")
elif TRAINING_STAGE == 'finetune':
    print(f"\nFine-tuning Details:")
    print(f"  Expected samples: ~150-160 apex images")
    print(f"  Strategy: {FINETUNE_STRATEGY}")
    print(f"  Load from: raf_pretrain_best_f1.pth")
    print(f"  Output checkpoint: {CHECKPOINT_FILENAME}")
    print(f"  Encoder frozen: {FREEZE_ENCODER}")
    print(f"  Critical: Fear class has ONLY 1 sample")

print(f"\nTraining Hyperparameters:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"  Augmentation: {AUGMENTATION_STRENGTH}")

print(f"\nHardware Optimization:")
print(f"  DataLoader workers: {NUM_WORKERS}")
print(f"  RAM preload workers: {RAM_PRELOAD_WORKERS}")
print(f"  GPU optimization: {'Enabled' if torch.backends.cudnn.benchmark else 'Default'}")

print(f"\n[8] Optimization Summary")
print(f"    Class reordering: By frequency (disgust→happy→surprise→sad→fear)")
print(f"    Loss function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy'} with optimized weights")
print(f"    Transfer strategy: {'Frozen encoder' if TRAINING_STAGE == 'finetune' and FREEZE_ENCODER else 'Full training'}")
print(f"    Status: Ready for optimized transfer learning")

print(f"\nNext: Cell 2 - Dataset Loading and Training Pipeline")
print("=" * 70)

VIT TRANSFER LEARNING: RAF-DB → CASME2 APEX FRAME (OPTIMIZED)

[1] Mounting Google Drive...
Mounted at /content/drive
    Google Drive mounted successfully

[2] Importing required libraries...

MODEL VARIANT CONFIGURATION

Selected: ViT-Base Patch32
  Characteristics: Efficient feature extraction
  Tokens at 384px: 144 tokens (12×12 grid)
  Best for: Balanced performance and speed

TRANSFER LEARNING STAGE CONFIGURATION

Current training stage: FINETUNE
  'pretrain' = RAF-DB macro-expression feature extraction
  'finetune' = CASME2 micro-expression specialization

Fixed transfer learning classes: 5
  Ordering: ['disgust', 'happy', 'surprise', 'sad', 'fear']
  Order rationale: Frequency-based (most to least common)
  Index mapping:
    0 = disgust
    1 = happy
    2 = surprise
    3 = sad
    4 = fear

CASME2 naming convention mapping:
  happiness → happy
  sadness → sad
  repression → excluded
  others → excluded

CLASS IMBALANCE HANDLING CONFIGURATION

CASME2 Training Distribution (5 

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

    ViT Image Processor configured for 384px with interpolation
    Variant: PATCH32
    Expected tokens: 144

[6] Creating unified output directory structure...
    Unified structure created:
    Checkpoints: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/05_01_transfer_learning/
      - casme2_finetune_best_f1.pth
    Results: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/05_01_transfer_learning/
      - finetune_logs/

[7] Validating ViT Transfer Learning architecture...


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/352M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch32-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ViT encoder frozen for fine-tuning
ViT Transfer Learning: 768 -> 512 -> 128 -> 5

Loading pre-trained weights from: raf_pretrain_best_f1.pth
  ViT encoder weights loaded
  Missing keys: 0
  Unexpected keys: 0
  Classifier layers loaded
  Transfer learning initialization complete
    Validation successful
    Output shape: torch.Size([1, 5])
    Expected tokens: 144
    Variant: patch32 (32×32 patches)

TRANSFER LEARNING INFRASTRUCTURE CONFIGURATION COMPLETE

Model Configuration:
  Variant: PATCH32 (ViT-Base-Patch32)
  Patch size: 32×32px
  Input resolution: 384×384px
  Expected tokens: 144
  Dropout: 0.3

Fixed Transfer Learning Classes:
  Total: 5 classes
  Ordering: ['disgust', 'happy', 'surprise', 'sad', 'fear'] (by frequency)
  Class indices:
    [0] = disgust
    [1] = happy
    [2] = surprise
    [3] = sad
    [4] = fear

Loss Configuration:
  Function: CrossEntropy
  Weights: [1.0, 1.42, 1.59, 3.17, 7.09]

Training Stage: FINETUNE
  Dataset: CASME2 Apex Frame (Phase 1)

Unified 

In [2]:
# @title Cell 2: ViT Transfer Learning Training Pipeline

# File: 05_01_ViT_RAF-DB_CASME2-AF.ipynb - Cell 2
# Location: experiments/05_01_ViT_RAF-DB_CASME2-AF.ipynb
# Purpose: Dual-stage training pipeline for RAF-DB pre-training and CASME2 fine-tuning

from google.colab import drive
drive.mount('/content/drive')

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import shutil
import tempfile

print("=" * 70)
print(f"VIT TRANSFER LEARNING TRAINING PIPELINE - STAGE: {GLOBAL_CONFIG['training_stage'].upper()}")
print("=" * 70)

# Display stage-specific configuration
if GLOBAL_CONFIG['training_stage'] == 'pretrain':
    print(f"\nPre-training Configuration:")
    print(f"  Dataset: {GLOBAL_CONFIG['dataset_name']}")
    print(f"  Classes: {GLOBAL_CONFIG['num_classes']} (fixed ordering)")
    print(f"  Expected samples: ~29,880 images")
    print(f"  Batch size: {GLOBAL_CONFIG['batch_size']}")
    print(f"  Epochs: {GLOBAL_CONFIG['num_epochs']}")
    print(f"  Learning rate: {GLOBAL_CONFIG['learning_rate']}")
    print(f"  Augmentation: {GLOBAL_CONFIG['augmentation_strength']}")
    print(f"  Output: {GLOBAL_CONFIG['checkpoint_filename']}")
else:
    print(f"\nFine-tuning Configuration:")
    print(f"  Dataset: {GLOBAL_CONFIG['dataset_name']}")
    print(f"  Classes: {GLOBAL_CONFIG['num_classes']} (fixed ordering, fear=placeholder)")
    print(f"  Expected samples: ~150-160 apex images")
    print(f"  Pre-trained checkpoint: {os.path.basename(GLOBAL_CONFIG['pretrained_checkpoint'])}")
    print(f"  Strategy: {GLOBAL_CONFIG['finetune_strategy']}")
    print(f"  Encoder frozen: {GLOBAL_CONFIG['freeze_encoder']}")
    print(f"  Batch size: {GLOBAL_CONFIG['batch_size']}")
    print(f"  Epochs: {GLOBAL_CONFIG['num_epochs']}")
    print(f"  Learning rate: {GLOBAL_CONFIG['learning_rate']}")
    print(f"  Augmentation: {GLOBAL_CONFIG['augmentation_strength']}")
    print(f"  Output: {GLOBAL_CONFIG['checkpoint_filename']}")

print(f"\nLoss function: {'Focal Loss' if GLOBAL_CONFIG['use_focal_loss'] else 'CrossEntropy Loss'}")
if GLOBAL_CONFIG['use_focal_loss']:
    print(f"  Gamma: {GLOBAL_CONFIG['focal_loss_gamma']}")

# RAM preloading configuration
RAM_PRELOAD_WORKERS = GLOBAL_CONFIG['ram_preload_workers']
print(f"\nRAM preload configuration:")
print(f"  Workers: {RAM_PRELOAD_WORKERS} (parallel image loading)")
print(f"  Method: ThreadPoolExecutor with concurrent futures")
print(f"  Target: Load all images to RAM before training starts")

# Enhanced Dataset class with RAM caching for RAF-DB
class RAFDBDatasetTraining(Dataset):
    """Enhanced RAF-DB dataset with RAM caching for efficient training"""

    def __init__(self, metadata_df, dataset_root, transform=None, transfer_classes=None, use_ram_cache=True):
        if transfer_classes:
            self.metadata = metadata_df[metadata_df['emotion_label'].isin(transfer_classes)].copy()
        else:
            self.metadata = metadata_df.copy()

        self.dataset_root = dataset_root
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.cached_images = []

        # Use FIXED class ordering from transfer_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(transfer_classes)}

        print(f"Loading RAF-DB dataset for training...")
        print(f"Fixed class ordering: {list(self.class_to_idx.keys())}")

        # Process metadata
        for _, row in self.metadata.iterrows():
            image_path = os.path.join(dataset_root, row['filepath'].replace('datasets/processed_raf/', ''))
            self.images.append(image_path)
            self.labels.append(self.class_to_idx[row['emotion_label']])

        print(f"Loaded {len(self.images)} RAF-DB samples")
        self._print_distribution()

        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print class distribution with indices"""
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        class_names = {v: k for k, v in self.class_to_idx.items()}

        print(f"Class distribution:")
        for idx in sorted(label_counts.keys()):
            class_name = class_names[idx]
            count = label_counts[idx]
            percentage = (count / len(self.labels)) * 100
            print(f"  [{idx}] {class_name}: {count} samples ({percentage:.1f}%)")

    def _preload_to_ram(self):
        """RAM preloading with parallel loading using ThreadPoolExecutor"""
        print(f"Preloading {len(self.images)} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            """Load single image with error handling"""
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (INPUT_SIZE, INPUT_SIZE):
                    image = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('RGB', (INPUT_SIZE, INPUT_SIZE), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc="Loading to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * INPUT_SIZE * INPUT_SIZE * 3 * 4 / 1e9
        print(f"RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (INPUT_SIZE, INPUT_SIZE):
                    image = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (INPUT_SIZE, INPUT_SIZE), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx]

# Enhanced Dataset class with RAM caching for CASME2
class CASME2DatasetTraining(Dataset):
    """Enhanced CASME2 dataset with RAM caching for efficient fine-tuning"""

    def __init__(self, split_metadata, dataset_root, transform=None, split='train',
                 transfer_classes=None, casme2_mapping=None, use_ram_cache=True):
        samples = split_metadata[split]['samples']

        # Apply CASME2 to RAF-DB naming convention mapping first
        if casme2_mapping:
            mapped_samples = []
            for s in samples:
                casme2_emotion = s['emotion']
                raf_emotion = casme2_mapping.get(casme2_emotion)

                if raf_emotion in transfer_classes:
                    mapped_sample = s.copy()
                    mapped_sample['emotion'] = raf_emotion
                    mapped_samples.append(mapped_sample)

            print(f"Mapped {len(mapped_samples)} samples from {len(samples)} total")
            self.samples = mapped_samples
        else:
            self.samples = samples

        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.cached_images = []

        # Use FIXED class ordering from transfer_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(transfer_classes)}

        print(f"Loading CASME2 {split} dataset for fine-tuning...")
        print(f"Fixed class ordering: {list(self.class_to_idx.keys())}")

        # Process samples
        for sample in self.samples:
            image_path = os.path.join(dataset_root, split, sample['image_filename'])
            self.images.append(image_path)
            self.labels.append(self.class_to_idx[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])

        print(f"Loaded {len(self.images)} CASME2 {split} samples")
        self._print_distribution()

        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print class distribution with indices"""
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        class_names = {v: k for k, v in self.class_to_idx.items()}

        print(f"Class distribution:")
        for idx, class_name in enumerate(self.class_to_idx.keys()):
            count = label_counts.get(idx, 0)
            if count > 0:
                percentage = (count / len(self.labels)) * 100
                print(f"  [{idx}] {class_name}: {count} samples ({percentage:.1f}%)")
            else:
                print(f"  [{idx}] {class_name}: NO SAMPLES (placeholder)")

    def _preload_to_ram(self):
        """RAM preloading with parallel loading using ThreadPoolExecutor"""
        print(f"Preloading {len(self.images)} {self.split} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            """Load single image with error handling"""
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (INPUT_SIZE, INPUT_SIZE):
                    image = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('RGB', (INPUT_SIZE, INPUT_SIZE), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc=f"Loading {self.split} to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * INPUT_SIZE * INPUT_SIZE * 3 * 4 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (INPUT_SIZE, INPUT_SIZE):
                    image = image.resize((INPUT_SIZE, INPUT_SIZE), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (INPUT_SIZE, INPUT_SIZE), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.sample_ids[idx]

# Enhanced metrics calculation
def calculate_metrics_robust(outputs, labels, num_classes, average='macro'):
    """Calculate metrics with enhanced error handling and missing class support"""
    try:
        if outputs.size(0) != labels.size(0):
            raise ValueError(f"Batch size mismatch: outputs {outputs.size(0)} vs labels {labels.size(0)}")

        if isinstance(outputs, torch.Tensor):
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        else:
            predictions = np.array(outputs)

        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()
        else:
            labels = np.array(labels)

        # Check which classes actually present in labels
        present_classes = np.unique(labels)

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0,
            labels=list(range(num_classes))
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1),
            'present_classes': present_classes.tolist()
        }
    except Exception as e:
        print(f"Warning: Metrics calculation error: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0,
            'present_classes': []
        }

# Training epoch function
def train_epoch(model, dataloader, criterion, optimizer, device, epoch, total_epochs, num_classes):
    """Training epoch with robust error handling"""
    model.train()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}/{total_epochs}")

    for batch_idx, batch_data in enumerate(progress_bar):
        if len(batch_data) == 3:
            images, labels, _ = batch_data
        else:
            images, labels = batch_data

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)

        if outputs.dim() != 2 or outputs.size(1) != num_classes:
            raise ValueError(f"Invalid output shape: {outputs.shape}, expected [batch_size, {num_classes}]")

        loss = criterion(outputs, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), GLOBAL_CONFIG['gradient_clip'])

        optimizer.step()
        running_loss += loss.item()

        all_outputs.append(outputs.detach().cpu())
        all_labels.append(labels.detach().cpu())

        if batch_idx % 5 == 0:
            avg_loss = running_loss / (batch_idx + 1)
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'Loss': f'{avg_loss:.4f}',
                'LR': f'{current_lr:.2e}'
            })

    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_robust(epoch_outputs, epoch_labels, num_classes, average='macro')
    except Exception as e:
        print(f"Warning: Training metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0, 'present_classes': []}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics

# Validation epoch function
def validate_epoch(model, dataloader, criterion, device, epoch, total_epochs, num_classes):
    """Validation epoch with robust error handling"""
    model.eval()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc=f"Validation Epoch {epoch+1}/{total_epochs}")

        for batch_idx, batch_data in enumerate(progress_bar):
            if len(batch_data) == 3:
                images, labels, _ = batch_data
            else:
                images, labels = batch_data

            images, labels = images.to(device), labels.to(device)

            outputs = model(images)

            loss = criterion(outputs, labels)
            running_loss += loss.item()

            all_outputs.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())

            if batch_idx % 3 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                progress_bar.set_postfix({'Val Loss': f'{avg_loss:.4f}'})

    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_robust(epoch_outputs, epoch_labels, num_classes, average='macro')
    except Exception as e:
        print(f"Warning: Validation metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0, 'present_classes': []}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics

# Atomic checkpoint saving with validation
def save_checkpoint_robust(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                          best_metrics, config, max_retries=3):
    """Hardened checkpoint saving with atomic write and validation"""

    def make_serializable(obj):
        if isinstance(obj, torch.Tensor):
            cpu_obj = obj.detach().cpu()
            return cpu_obj.item() if cpu_obj.numel() == 1 else cpu_obj.tolist()
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, (np.integer, np.int32, np.int64)):
            return int(obj)
        elif isinstance(obj, (np.floating, np.float32, np.float64)):
            return float(obj)
        elif isinstance(obj, dict):
            return {k: make_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_serializable(item) for item in obj]
        else:
            return obj

    def filter_serializable_config(config):
        """Filter config to only include serializable items"""
        excluded_keys = [
            'transform_train',
            'transform_val',
            'criterion_factory',
            'optimizer_scheduler_factory',
            'device'
        ]
        return {k: v for k, v in config.items() if k not in excluded_keys}

    serializable_config = filter_serializable_config(config)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_metrics': make_serializable(train_metrics),
        'val_metrics': make_serializable(val_metrics),
        'config': make_serializable(serializable_config),
        'best_f1': float(best_metrics['f1']),
        'best_loss': float(best_metrics['loss']),
        'best_acc': float(best_metrics['accuracy']),
        'transfer_classes': config['transfer_classes'],
        'num_classes': config['num_classes'],
        'training_stage': config['training_stage']
    }

    final_path = config['checkpoint_path']
    checkpoint_dir = config['checkpoint_root']

    for attempt in range(max_retries):
        try:
            temp_fd, temp_path = tempfile.mkstemp(dir=checkpoint_dir, suffix='.pth.tmp')
            os.close(temp_fd)

            torch.save(checkpoint, temp_path)

            validation_checkpoint = torch.load(temp_path, map_location='cpu')

            required_keys = ['model_state_dict', 'epoch', 'best_f1', 'num_classes']
            for key in required_keys:
                if key not in validation_checkpoint:
                    raise ValueError(f"Checkpoint validation failed: missing key '{key}'")

            if validation_checkpoint['epoch'] != epoch:
                raise ValueError(f"Checkpoint epoch mismatch: saved {epoch}, loaded {validation_checkpoint['epoch']}")

            shutil.move(temp_path, final_path)

            return final_path

        except Exception as e:
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except:
                    pass

            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                time.sleep(wait_time)
            else:
                print(f"All {max_retries} checkpoint save attempts failed: {e}")
                return None

    return None

# Safe JSON serialization
def safe_json_serialize(obj):
    """Convert objects to JSON-serializable format"""
    if isinstance(obj, torch.Tensor):
        return obj.cpu().item() if obj.numel() == 1 else obj.cpu().numpy().tolist()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.int32, np.int64)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float32, np.float64)):
        return float(obj)
    elif isinstance(obj, dict):
        return {k: safe_json_serialize(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [safe_json_serialize(item) for item in obj]
    else:
        try:
            return float(obj) if isinstance(obj, (int, float)) else str(obj)
        except:
            return str(obj)

# Create datasets based on training stage
print("\n" + "=" * 70)
print("DATASET LOADING")
print("=" * 70)

if GLOBAL_CONFIG['training_stage'] == 'pretrain':
    print("\nLoading RAF-DB datasets for pre-training...")

    metadata_df = pd.read_csv(GLOBAL_CONFIG['metadata_path'])

    train_metadata = metadata_df[metadata_df['split'] == 'train']
    val_metadata = metadata_df[metadata_df['split'] == 'val']

    train_dataset = RAFDBDatasetTraining(
        metadata_df=train_metadata,
        dataset_root=GLOBAL_CONFIG['dataset_root'],
        transform=GLOBAL_CONFIG['transform_train'],
        transfer_classes=GLOBAL_CONFIG['transfer_classes'],
        use_ram_cache=True
    )

    val_dataset = RAFDBDatasetTraining(
        metadata_df=val_metadata,
        dataset_root=GLOBAL_CONFIG['dataset_root'],
        transform=GLOBAL_CONFIG['transform_val'],
        transfer_classes=GLOBAL_CONFIG['transfer_classes'],
        use_ram_cache=True
    )

else:
    print("\nLoading CASME2 datasets for fine-tuning...")

    with open(GLOBAL_CONFIG['metadata_path'], 'r') as f:
        split_metadata = json.load(f)

    train_dataset = CASME2DatasetTraining(
        split_metadata=split_metadata,
        dataset_root=GLOBAL_CONFIG['dataset_root'],
        transform=GLOBAL_CONFIG['transform_train'],
        split='train',
        transfer_classes=GLOBAL_CONFIG['transfer_classes'],
        casme2_mapping=GLOBAL_CONFIG['casme2_mapping'],
        use_ram_cache=True
    )

    val_dataset = CASME2DatasetTraining(
        split_metadata=split_metadata,
        dataset_root=GLOBAL_CONFIG['dataset_root'],
        transform=GLOBAL_CONFIG['transform_val'],
        split='val',
        transfer_classes=GLOBAL_CONFIG['transfer_classes'],
        casme2_mapping=GLOBAL_CONFIG['casme2_mapping'],
        use_ram_cache=True
    )

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=GLOBAL_CONFIG['batch_size'],
    shuffle=True,
    num_workers=GLOBAL_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=GLOBAL_CONFIG['batch_size'],
    shuffle=False,
    num_workers=GLOBAL_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

print(f"\nDataLoader configuration:")
print(f"  Training batches: {len(train_loader)} (samples: {len(train_dataset)})")
print(f"  Validation batches: {len(val_loader)} (samples: {len(val_dataset)})")
print(f"  Batch size: {GLOBAL_CONFIG['batch_size']}")
print(f"  Num workers: {GLOBAL_CONFIG['num_workers']}")

# Initialize model
print("\n" + "=" * 70)
print("MODEL INITIALIZATION")
print("=" * 70)

model = ViTTransferLearning(
    num_classes=GLOBAL_CONFIG['num_classes'],
    dropout_rate=GLOBAL_CONFIG['dropout_rate'],
    pretrained_checkpoint=GLOBAL_CONFIG['pretrained_checkpoint'] if GLOBAL_CONFIG['training_stage'] == 'finetune' else None,
    freeze_encoder=GLOBAL_CONFIG['freeze_encoder']
).to(GLOBAL_CONFIG['device'])

# Create criterion and optimizer
criterion = GLOBAL_CONFIG['criterion_factory'](
    use_focal_loss=GLOBAL_CONFIG['use_focal_loss'],
    gamma=GLOBAL_CONFIG['focal_loss_gamma']
)

optimizer, scheduler = GLOBAL_CONFIG['optimizer_scheduler_factory'](
    model,
    GLOBAL_CONFIG['learning_rate'],
    GLOBAL_CONFIG['weight_decay']
)

print(f"\nModel: ViT Transfer Learning ({GLOBAL_CONFIG['training_stage']} stage)")
print(f"Optimizer: AdamW (LR={GLOBAL_CONFIG['learning_rate']})")
print(f"Scheduler: ReduceLROnPlateau (patience=5)")
print(f"Criterion: {'Focal Loss' if GLOBAL_CONFIG['use_focal_loss'] else 'CrossEntropy Loss'}")

# Training history tracking
training_history = {
    'train_loss': [],
    'val_loss': [],
    'train_f1': [],
    'val_f1': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rate': [],
    'epoch_time': []
}

best_metrics = {
    'f1': 0.0,
    'loss': float('inf'),
    'accuracy': 0.0,
    'epoch': 0
}

# Main training loop
print("\n" + "=" * 70)
print("TRAINING")
print("=" * 70)
print(f"Training configuration: {GLOBAL_CONFIG['num_epochs']} epochs")

start_time = time.time()

for epoch in range(GLOBAL_CONFIG['num_epochs']):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{GLOBAL_CONFIG['num_epochs']}")

    train_loss, train_metrics = train_epoch(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG['device'], epoch, GLOBAL_CONFIG['num_epochs'],
        GLOBAL_CONFIG['num_classes']
    )

    val_loss, val_metrics = validate_epoch(
        model, val_loader, criterion,
        GLOBAL_CONFIG['device'], epoch, GLOBAL_CONFIG['num_epochs'],
        GLOBAL_CONFIG['num_classes']
    )

    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        best_metrics['f1'] = val_metrics['f1_score']
        best_metrics['loss'] = val_loss
        best_metrics['accuracy'] = val_metrics['accuracy']
        best_metrics['epoch'] = epoch + 1

        best_model_path = save_checkpoint_robust(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics,
            best_metrics, GLOBAL_CONFIG
        )

        if best_model_path:
            print(f"Checkpoint saved: {improvement_reason} - F1: {best_metrics['f1']:.4f}")

    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * GLOBAL_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / GLOBAL_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min")

# Training completion
total_time = time.time() - start_time

print("\n" + "=" * 70)
print(f"TRAINING COMPLETED - {GLOBAL_CONFIG['training_stage'].upper()} STAGE")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {GLOBAL_CONFIG['num_epochs']}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print(f"Final train F1: {training_history['train_f1'][-1]:.4f}")
print(f"Final validation F1: {training_history['val_f1'][-1]:.4f}")
print(f"Checkpoint saved: {GLOBAL_CONFIG['checkpoint_filename']}")

# Export training documentation
print("\n" + "=" * 70)
print("EXPORTING TRAINING DOCUMENTATION")
print("=" * 70)

training_history_path = f"{GLOBAL_CONFIG['logs_path']}/training_logs/training_history.json"

try:
    def filter_json_config(config):
        """Filter config to only include JSON-serializable items"""
        excluded_keys = [
            'transform_train',
            'transform_val',
            'criterion_factory',
            'optimizer_scheduler_factory',
            'device'
        ]
        filtered = {k: v for k, v in config.items() if k not in excluded_keys}
        if 'device' in config:
            filtered['device_name'] = str(config['device'])
        return filtered

    training_summary = {
        'experiment_type': f'ViT_Transfer_Learning_{GLOBAL_CONFIG["training_stage"].upper()}',
        'training_stage': GLOBAL_CONFIG['training_stage'],
        'dataset_name': GLOBAL_CONFIG['dataset_name'],
        'training_history': safe_json_serialize(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(GLOBAL_CONFIG['num_epochs']),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        'checkpoint_filename': GLOBAL_CONFIG['checkpoint_filename'],
        'dataset_info': {
            'name': GLOBAL_CONFIG['dataset_name'],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': GLOBAL_CONFIG['num_classes'],
            'transfer_classes': GLOBAL_CONFIG['transfer_classes']
        },
        'architecture_info': {
            'model_type': 'ViTTransferLearning',
            'backbone': GLOBAL_CONFIG['vit_model'],
            'variant': GLOBAL_CONFIG['vit_variant'],
            'patch_size': GLOBAL_CONFIG['patch_size'],
            'input_size': f"{GLOBAL_CONFIG['input_size']}x{GLOBAL_CONFIG['input_size']}",
            'dropout': GLOBAL_CONFIG['dropout_rate'],
            'classification_head': '768->512->128->5'
        },
        'training_config': {
            'batch_size': GLOBAL_CONFIG['batch_size'],
            'learning_rate': GLOBAL_CONFIG['learning_rate'],
            'weight_decay': GLOBAL_CONFIG['weight_decay'],
            'gradient_clip': GLOBAL_CONFIG['gradient_clip'],
            'augmentation': GLOBAL_CONFIG['augmentation_strength'],
            'loss_function': 'Focal Loss' if GLOBAL_CONFIG['use_focal_loss'] else 'CrossEntropy',
            'ram_preload_workers': GLOBAL_CONFIG['ram_preload_workers']
        }
    }

    if GLOBAL_CONFIG['training_stage'] == 'finetune':
        training_summary['transfer_learning_config'] = {
            'pretrained_checkpoint': os.path.basename(GLOBAL_CONFIG['pretrained_checkpoint']),
            'finetune_strategy': GLOBAL_CONFIG['finetune_strategy'],
            'encoder_frozen': GLOBAL_CONFIG['freeze_encoder']
        }

    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Training documentation saved: training_history.json")
    print(f"Location: {GLOBAL_CONFIG['logs_subdir']}/training_logs/")

except Exception as e:
    print(f"Warning: Could not save training documentation: {e}")

# Memory cleanup
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\n" + "=" * 70)
print(f"Next: Cell 3 - {GLOBAL_CONFIG['training_stage'].upper()} Stage Evaluation")
print("=" * 70)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
VIT TRANSFER LEARNING TRAINING PIPELINE - STAGE: FINETUNE

Fine-tuning Configuration:
  Dataset: CASME2 Apex Frame (Phase 1)
  Classes: 5 (fixed ordering, fear=placeholder)
  Expected samples: ~150-160 apex images
  Pre-trained checkpoint: raf_pretrain_best_f1.pth
  Strategy: frozen_encoder
  Encoder frozen: True
  Batch size: 8
  Epochs: 50
  Learning rate: 3e-06
  Augmentation: enhanced
  Output: casme2_finetune_best_f1.pth

Loss function: CrossEntropy Loss

RAM preload configuration:
  Workers: 128 (parallel image loading)
  Method: ThreadPoolExecutor with concurrent futures
  Target: Load all images to RAM before training starts

DATASET LOADING

Loading CASME2 datasets for fine-tuning...
Mapped 101 samples from 201 total
Loading CASME2 train dataset for fine-tuning...
Fixed class ordering: ['disgust', 'happy', 'surprise', 'sad', 'fear']
Loaded 101 CASME2

Loading train to RAM: 100%|██████████| 101/101 [00:04<00:00, 21.05it/s]


TRAIN RAM caching completed: 101/101 images, ~0.18GB
Mapped 13 samples from 26 total
Loading CASME2 val dataset for fine-tuning...
Fixed class ordering: ['disgust', 'happy', 'surprise', 'sad', 'fear']
Loaded 13 CASME2 val samples
Class distribution:
  [0] disgust: 6 samples (46.2%)
  [1] happy: 3 samples (23.1%)
  [2] surprise: 2 samples (15.4%)
  [3] sad: 1 samples (7.7%)
  [4] fear: 1 samples (7.7%)
Preloading 13 val images to RAM with 128 workers...


Loading val to RAM: 100%|██████████| 13/13 [00:01<00:00,  8.32it/s]


VAL RAM caching completed: 13/13 images, ~0.02GB

DataLoader configuration:
  Training batches: 13 (samples: 101)
  Validation batches: 2 (samples: 13)
  Batch size: 8
  Num workers: 8

MODEL INITIALIZATION


Some weights of the model checkpoint at google/vit-base-patch32-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ViT encoder frozen for fine-tuning
ViT Transfer Learning: 768 -> 512 -> 128 -> 5

Loading pre-trained weights from: raf_pretrain_best_f1.pth
  ViT encoder weights loaded
  Missing keys: 0
  Unexpected keys: 0
  Classifier layers loaded
  Transfer learning initialization complete
Using CrossEntropy Loss with class weights
Optimizer: AdamW (lr=3e-06, wd=0.0001)
Scheduler: ReduceLROnPlateau (monitor=val_f1, patience=5)

Model: ViT Transfer Learning (finetune stage)
Optimizer: AdamW (LR=3e-06)
Scheduler: ReduceLROnPlateau (patience=5)
Criterion: CrossEntropy Loss

TRAINING
Training configuration: 50 epochs

Epoch 1/50


Training Epoch 1/50: 100%|██████████| 13/13 [00:02<00:00,  4.96it/s, Loss=1.5900, LR=3.00e-06]
Validation Epoch 1/50: 100%|██████████| 2/2 [00:00<00:00,  3.91it/s, Val Loss=1.3261]


Train - Loss: 1.5775, F1: 0.1696, Acc: 0.3069
Val   - Loss: 1.4998, F1: 0.1176, Acc: 0.3846
Time  - Epoch: 3.2s, LR: 3.00e-06
Checkpoint saved: Higher F1 - F1: 0.1176
Progress: 2.0% | Best F1: 0.1176 | ETA: 3.8min

Epoch 2/50


Training Epoch 2/50: 100%|██████████| 13/13 [00:02<00:00,  6.06it/s, Loss=1.5910, LR=3.00e-06]
Validation Epoch 2/50: 100%|██████████| 2/2 [00:00<00:00,  2.33it/s, Val Loss=1.2839]


Train - Loss: 1.5634, F1: 0.1655, Acc: 0.2871
Val   - Loss: 1.4824, F1: 0.1111, Acc: 0.3846
Time  - Epoch: 3.0s, LR: 3.00e-06
Progress: 4.0% | Best F1: 0.1176 | ETA: 3.1min

Epoch 3/50


Training Epoch 3/50: 100%|██████████| 13/13 [00:02<00:00,  4.58it/s, Loss=1.5774, LR=3.00e-06]
Validation Epoch 3/50: 100%|██████████| 2/2 [00:00<00:00,  3.48it/s, Val Loss=1.2501]


Train - Loss: 1.5796, F1: 0.1338, Acc: 0.3366
Val   - Loss: 1.4678, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.4s, LR: 3.00e-06
Checkpoint saved: Higher F1 - F1: 0.1263
Progress: 6.0% | Best F1: 0.1263 | ETA: 3.3min

Epoch 4/50


Training Epoch 4/50: 100%|██████████| 13/13 [00:02<00:00,  6.09it/s, Loss=1.5128, LR=3.00e-06]
Validation Epoch 4/50: 100%|██████████| 2/2 [00:00<00:00,  3.52it/s, Val Loss=1.2183]


Train - Loss: 1.5057, F1: 0.2000, Acc: 0.4158
Val   - Loss: 1.4552, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 3.00e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 8.0% | Best F1: 0.1263 | ETA: 3.3min

Epoch 5/50


Training Epoch 5/50: 100%|██████████| 13/13 [00:01<00:00,  6.53it/s, Loss=1.5390, LR=3.00e-06]
Validation Epoch 5/50: 100%|██████████| 2/2 [00:00<00:00,  3.61it/s, Val Loss=1.1897]


Train - Loss: 1.5031, F1: 0.2041, Acc: 0.4257
Val   - Loss: 1.4450, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.6s, LR: 3.00e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 10.0% | Best F1: 0.1263 | ETA: 3.3min

Epoch 6/50


Training Epoch 6/50: 100%|██████████| 13/13 [00:02<00:00,  4.66it/s, Loss=1.4904, LR=3.00e-06]
Validation Epoch 6/50: 100%|██████████| 2/2 [00:00<00:00,  3.55it/s, Val Loss=1.1608]


Train - Loss: 1.4862, F1: 0.2132, Acc: 0.4455
Val   - Loss: 1.4353, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.4s, LR: 3.00e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 12.0% | Best F1: 0.1263 | ETA: 3.2min

Epoch 7/50


Training Epoch 7/50: 100%|██████████| 13/13 [00:02<00:00,  6.04it/s, Loss=1.4980, LR=3.00e-06]
Validation Epoch 7/50: 100%|██████████| 2/2 [00:00<00:00,  3.54it/s, Val Loss=1.1345]


Train - Loss: 1.5013, F1: 0.1806, Acc: 0.4257
Val   - Loss: 1.4279, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 3.00e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 14.0% | Best F1: 0.1263 | ETA: 3.1min

Epoch 8/50


Training Epoch 8/50: 100%|██████████| 13/13 [00:02<00:00,  6.14it/s, Loss=1.4849, LR=3.00e-06]
Validation Epoch 8/50: 100%|██████████| 2/2 [00:00<00:00,  3.62it/s, Val Loss=1.1115]


Train - Loss: 1.5109, F1: 0.1710, Acc: 0.3762
Val   - Loss: 1.4220, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 3.00e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 16.0% | Best F1: 0.1263 | ETA: 3.2min

Epoch 9/50


Training Epoch 9/50: 100%|██████████| 13/13 [00:02<00:00,  6.33it/s, Loss=1.4697, LR=3.00e-06]
Validation Epoch 9/50: 100%|██████████| 2/2 [00:00<00:00,  3.57it/s, Val Loss=1.0904]


Train - Loss: 1.4586, F1: 0.1665, Acc: 0.4158
Val   - Loss: 1.4165, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.6s, LR: 1.50e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 18.0% | Best F1: 0.1263 | ETA: 3.3min

Epoch 10/50


Training Epoch 10/50: 100%|██████████| 13/13 [00:02<00:00,  5.80it/s, Loss=1.3735, LR=1.50e-06]
Validation Epoch 10/50: 100%|██████████| 2/2 [00:00<00:00,  3.06it/s, Val Loss=1.0816]


Train - Loss: 1.3763, F1: 0.1871, Acc: 0.4653
Val   - Loss: 1.4143, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 1.50e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 20.0% | Best F1: 0.1263 | ETA: 3.2min

Epoch 11/50


Training Epoch 11/50: 100%|██████████| 13/13 [00:03<00:00,  3.80it/s, Loss=1.4036, LR=1.50e-06]
Validation Epoch 11/50: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s, Val Loss=1.0734]


Train - Loss: 1.4289, F1: 0.1792, Acc: 0.4554
Val   - Loss: 1.4120, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 4.5s, LR: 1.50e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 22.0% | Best F1: 0.1263 | ETA: 3.2min

Epoch 12/50


Training Epoch 12/50: 100%|██████████| 13/13 [00:02<00:00,  6.38it/s, Loss=1.4261, LR=1.50e-06]
Validation Epoch 12/50: 100%|██████████| 2/2 [00:00<00:00,  2.69it/s, Val Loss=1.0643]


Train - Loss: 1.4135, F1: 0.2273, Acc: 0.4752
Val   - Loss: 1.4103, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.8s, LR: 1.50e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 24.0% | Best F1: 0.1263 | ETA: 3.1min

Epoch 13/50


Training Epoch 13/50: 100%|██████████| 13/13 [00:02<00:00,  6.35it/s, Loss=1.3943, LR=1.50e-06]
Validation Epoch 13/50: 100%|██████████| 2/2 [00:00<00:00,  3.50it/s, Val Loss=1.0561]


Train - Loss: 1.3910, F1: 0.1913, Acc: 0.4851
Val   - Loss: 1.4090, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.6s, LR: 1.50e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 26.0% | Best F1: 0.1263 | ETA: 3.0min

Epoch 14/50


Training Epoch 14/50: 100%|██████████| 13/13 [00:02<00:00,  4.46it/s, Loss=1.4109, LR=1.50e-06]
Validation Epoch 14/50: 100%|██████████| 2/2 [00:00<00:00,  3.44it/s, Val Loss=1.0475]


Train - Loss: 1.4086, F1: 0.1906, Acc: 0.4752
Val   - Loss: 1.4078, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.5s, LR: 1.50e-06
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 28.0% | Best F1: 0.1263 | ETA: 3.0min

Epoch 15/50


Training Epoch 15/50: 100%|██████████| 13/13 [00:02<00:00,  6.16it/s, Loss=1.4521, LR=1.50e-06]
Validation Epoch 15/50: 100%|██████████| 2/2 [00:00<00:00,  2.84it/s, Val Loss=1.0385]


Train - Loss: 1.4016, F1: 0.2329, Acc: 0.4851
Val   - Loss: 1.4061, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.8s, LR: 7.50e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 30.0% | Best F1: 0.1263 | ETA: 3.0min

Epoch 16/50


Training Epoch 16/50: 100%|██████████| 13/13 [00:03<00:00,  4.24it/s, Loss=1.3971, LR=7.50e-07]
Validation Epoch 16/50: 100%|██████████| 2/2 [00:00<00:00,  3.44it/s, Val Loss=1.0342]


Train - Loss: 1.4027, F1: 0.1802, Acc: 0.4554
Val   - Loss: 1.4057, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.7s, LR: 7.50e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 32.0% | Best F1: 0.1263 | ETA: 2.9min

Epoch 17/50


Training Epoch 17/50: 100%|██████████| 13/13 [00:02<00:00,  6.15it/s, Loss=1.3913, LR=7.50e-07]
Validation Epoch 17/50: 100%|██████████| 2/2 [00:00<00:00,  3.31it/s, Val Loss=1.0306]


Train - Loss: 1.3866, F1: 0.1776, Acc: 0.4653
Val   - Loss: 1.4055, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 7.50e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 34.0% | Best F1: 0.1263 | ETA: 2.8min

Epoch 18/50


Training Epoch 18/50: 100%|██████████| 13/13 [00:02<00:00,  6.17it/s, Loss=1.4171, LR=7.50e-07]
Validation Epoch 18/50: 100%|██████████| 2/2 [00:00<00:00,  3.19it/s, Val Loss=1.0270]


Train - Loss: 1.4004, F1: 0.1540, Acc: 0.4257
Val   - Loss: 1.4049, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.8s, LR: 7.50e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 36.0% | Best F1: 0.1263 | ETA: 2.7min

Epoch 19/50


Training Epoch 19/50: 100%|██████████| 13/13 [00:02<00:00,  4.93it/s, Loss=1.4449, LR=7.50e-07]
Validation Epoch 19/50: 100%|██████████| 2/2 [00:00<00:00,  2.55it/s, Val Loss=1.0239]


Train - Loss: 1.4338, F1: 0.2335, Acc: 0.4653
Val   - Loss: 1.4043, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.4s, LR: 7.50e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 38.0% | Best F1: 0.1263 | ETA: 2.7min

Epoch 20/50


Training Epoch 20/50: 100%|██████████| 13/13 [00:02<00:00,  6.32it/s, Loss=1.3733, LR=7.50e-07]
Validation Epoch 20/50: 100%|██████████| 2/2 [00:00<00:00,  3.41it/s, Val Loss=1.0206]


Train - Loss: 1.3657, F1: 0.1872, Acc: 0.4257
Val   - Loss: 1.4038, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 7.50e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 40.0% | Best F1: 0.1263 | ETA: 2.6min

Epoch 21/50


Training Epoch 21/50: 100%|██████████| 13/13 [00:02<00:00,  4.39it/s, Loss=1.3651, LR=7.50e-07]
Validation Epoch 21/50: 100%|██████████| 2/2 [00:00<00:00,  2.47it/s, Val Loss=1.0174]


Train - Loss: 1.3834, F1: 0.1597, Acc: 0.4257
Val   - Loss: 1.4035, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.8s, LR: 3.75e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 42.0% | Best F1: 0.1263 | ETA: 2.5min

Epoch 22/50


Training Epoch 22/50: 100%|██████████| 13/13 [00:02<00:00,  5.63it/s, Loss=1.4022, LR=3.75e-07]
Validation Epoch 22/50: 100%|██████████| 2/2 [00:00<00:00,  3.33it/s, Val Loss=1.0160]


Train - Loss: 1.3939, F1: 0.1419, Acc: 0.4257
Val   - Loss: 1.4033, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 3.75e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 44.0% | Best F1: 0.1263 | ETA: 2.4min

Epoch 23/50


Training Epoch 23/50: 100%|██████████| 13/13 [00:02<00:00,  5.47it/s, Loss=1.3841, LR=3.75e-07]
Validation Epoch 23/50: 100%|██████████| 2/2 [00:00<00:00,  3.40it/s, Val Loss=1.0142]


Train - Loss: 1.3605, F1: 0.2094, Acc: 0.4851
Val   - Loss: 1.4031, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.0s, LR: 3.75e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 46.0% | Best F1: 0.1263 | ETA: 2.3min

Epoch 24/50


Training Epoch 24/50: 100%|██████████| 13/13 [00:02<00:00,  4.46it/s, Loss=1.3983, LR=3.75e-07]
Validation Epoch 24/50: 100%|██████████| 2/2 [00:00<00:00,  3.30it/s, Val Loss=1.0123]


Train - Loss: 1.3885, F1: 0.1611, Acc: 0.4356
Val   - Loss: 1.4030, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.5s, LR: 3.75e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 48.0% | Best F1: 0.1263 | ETA: 2.2min

Epoch 25/50


Training Epoch 25/50: 100%|██████████| 13/13 [00:02<00:00,  5.81it/s, Loss=1.3446, LR=3.75e-07]
Validation Epoch 25/50: 100%|██████████| 2/2 [00:00<00:00,  3.26it/s, Val Loss=1.0105]


Train - Loss: 1.3522, F1: 0.2153, Acc: 0.4554
Val   - Loss: 1.4028, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 3.75e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 50.0% | Best F1: 0.1263 | ETA: 2.2min

Epoch 26/50


Training Epoch 26/50: 100%|██████████| 13/13 [00:02<00:00,  5.12it/s, Loss=1.3673, LR=3.75e-07]
Validation Epoch 26/50: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s, Val Loss=1.0089]


Train - Loss: 1.3750, F1: 0.1654, Acc: 0.4554
Val   - Loss: 1.4026, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.7s, LR: 3.75e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 52.0% | Best F1: 0.1263 | ETA: 2.1min

Epoch 27/50


Training Epoch 27/50: 100%|██████████| 13/13 [00:02<00:00,  5.82it/s, Loss=1.3463, LR=3.75e-07]
Validation Epoch 27/50: 100%|██████████| 2/2 [00:00<00:00,  3.31it/s, Val Loss=1.0071]


Train - Loss: 1.3533, F1: 0.1877, Acc: 0.4752
Val   - Loss: 1.4026, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 1.88e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 54.0% | Best F1: 0.1263 | ETA: 2.0min

Epoch 28/50


Training Epoch 28/50: 100%|██████████| 13/13 [00:02<00:00,  5.81it/s, Loss=1.3664, LR=1.88e-07]
Validation Epoch 28/50: 100%|██████████| 2/2 [00:00<00:00,  2.95it/s, Val Loss=1.0062]


Train - Loss: 1.3504, F1: 0.2009, Acc: 0.4851
Val   - Loss: 1.4025, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 1.88e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 56.0% | Best F1: 0.1263 | ETA: 1.9min

Epoch 29/50


Training Epoch 29/50: 100%|██████████| 13/13 [00:02<00:00,  5.03it/s, Loss=1.3849, LR=1.88e-07]
Validation Epoch 29/50: 100%|██████████| 2/2 [00:00<00:00,  3.19it/s, Val Loss=1.0055]


Train - Loss: 1.3917, F1: 0.1798, Acc: 0.4554
Val   - Loss: 1.4024, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.2s, LR: 1.88e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 58.0% | Best F1: 0.1263 | ETA: 1.8min

Epoch 30/50


Training Epoch 30/50: 100%|██████████| 13/13 [00:02<00:00,  6.09it/s, Loss=1.4424, LR=1.88e-07]
Validation Epoch 30/50: 100%|██████████| 2/2 [00:00<00:00,  3.24it/s, Val Loss=1.0046]


Train - Loss: 1.4248, F1: 0.1331, Acc: 0.3960
Val   - Loss: 1.4023, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.8s, LR: 1.88e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 60.0% | Best F1: 0.1263 | ETA: 1.7min

Epoch 31/50


Training Epoch 31/50: 100%|██████████| 13/13 [00:02<00:00,  5.71it/s, Loss=1.3619, LR=1.88e-07]
Validation Epoch 31/50: 100%|██████████| 2/2 [00:00<00:00,  2.12it/s, Val Loss=1.0038]


Train - Loss: 1.3693, F1: 0.2257, Acc: 0.4950
Val   - Loss: 1.4023, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.2s, LR: 1.88e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 62.0% | Best F1: 0.1263 | ETA: 1.6min

Epoch 32/50


Training Epoch 32/50: 100%|██████████| 13/13 [00:02<00:00,  6.15it/s, Loss=1.3894, LR=1.88e-07]
Validation Epoch 32/50: 100%|██████████| 2/2 [00:00<00:00,  2.81it/s, Val Loss=1.0030]


Train - Loss: 1.3563, F1: 0.1999, Acc: 0.4653
Val   - Loss: 1.4022, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 1.88e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 64.0% | Best F1: 0.1263 | ETA: 1.5min

Epoch 33/50


Training Epoch 33/50: 100%|██████████| 13/13 [00:02<00:00,  5.80it/s, Loss=1.3718, LR=1.88e-07]
Validation Epoch 33/50: 100%|██████████| 2/2 [00:00<00:00,  3.30it/s, Val Loss=1.0021]


Train - Loss: 1.3420, F1: 0.1964, Acc: 0.4752
Val   - Loss: 1.4022, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 66.0% | Best F1: 0.1263 | ETA: 1.5min

Epoch 34/50


Training Epoch 34/50: 100%|██████████| 13/13 [00:02<00:00,  4.66it/s, Loss=1.3258, LR=1.00e-07]
Validation Epoch 34/50: 100%|██████████| 2/2 [00:00<00:00,  2.20it/s, Val Loss=1.0017]


Train - Loss: 1.3371, F1: 0.1940, Acc: 0.4950
Val   - Loss: 1.4021, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.7s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 68.0% | Best F1: 0.1263 | ETA: 1.4min

Epoch 35/50


Training Epoch 35/50: 100%|██████████| 13/13 [00:02<00:00,  6.16it/s, Loss=1.3525, LR=1.00e-07]
Validation Epoch 35/50: 100%|██████████| 2/2 [00:00<00:00,  3.37it/s, Val Loss=1.0012]


Train - Loss: 1.3620, F1: 0.1889, Acc: 0.4950
Val   - Loss: 1.4021, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 70.0% | Best F1: 0.1263 | ETA: 1.3min

Epoch 36/50


Training Epoch 36/50: 100%|██████████| 13/13 [00:02<00:00,  6.20it/s, Loss=1.3600, LR=1.00e-07]
Validation Epoch 36/50: 100%|██████████| 2/2 [00:00<00:00,  3.42it/s, Val Loss=1.0008]


Train - Loss: 1.3631, F1: 0.1694, Acc: 0.4455
Val   - Loss: 1.4021, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 72.0% | Best F1: 0.1263 | ETA: 1.2min

Epoch 37/50


Training Epoch 37/50: 100%|██████████| 13/13 [00:02<00:00,  5.17it/s, Loss=1.3388, LR=1.00e-07]
Validation Epoch 37/50: 100%|██████████| 2/2 [00:00<00:00,  2.34it/s, Val Loss=1.0004]


Train - Loss: 1.3537, F1: 0.2267, Acc: 0.5050
Val   - Loss: 1.4020, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.4s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 74.0% | Best F1: 0.1263 | ETA: 1.1min

Epoch 38/50


Training Epoch 38/50: 100%|██████████| 13/13 [00:02<00:00,  5.96it/s, Loss=1.3654, LR=1.00e-07]
Validation Epoch 38/50: 100%|██████████| 2/2 [00:00<00:00,  2.66it/s, Val Loss=1.0000]


Train - Loss: 1.3715, F1: 0.1879, Acc: 0.4554
Val   - Loss: 1.4020, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.0s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 76.0% | Best F1: 0.1263 | ETA: 1.0min

Epoch 39/50


Training Epoch 39/50: 100%|██████████| 13/13 [00:02<00:00,  5.81it/s, Loss=1.4438, LR=1.00e-07]
Validation Epoch 39/50: 100%|██████████| 2/2 [00:00<00:00,  3.40it/s, Val Loss=0.9996]


Train - Loss: 1.4132, F1: 0.1813, Acc: 0.4752
Val   - Loss: 1.4020, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.8s, LR: 1.00e-07
Progress: 78.0% | Best F1: 0.1263 | ETA: 0.9min

Epoch 40/50


Training Epoch 40/50: 100%|██████████| 13/13 [00:02<00:00,  5.94it/s, Loss=1.3973, LR=1.00e-07]
Validation Epoch 40/50: 100%|██████████| 2/2 [00:00<00:00,  2.31it/s, Val Loss=0.9991]


Train - Loss: 1.3853, F1: 0.2018, Acc: 0.4554
Val   - Loss: 1.4020, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.1s, LR: 1.00e-07
Progress: 80.0% | Best F1: 0.1263 | ETA: 0.8min

Epoch 41/50


Training Epoch 41/50: 100%|██████████| 13/13 [00:02<00:00,  4.59it/s, Loss=1.3662, LR=1.00e-07]
Validation Epoch 41/50: 100%|██████████| 2/2 [00:00<00:00,  2.97it/s, Val Loss=0.9986]


Train - Loss: 1.3964, F1: 0.1708, Acc: 0.4752
Val   - Loss: 1.4020, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.5s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 82.0% | Best F1: 0.1263 | ETA: 0.7min

Epoch 42/50


Training Epoch 42/50: 100%|██████████| 13/13 [00:02<00:00,  6.01it/s, Loss=1.3392, LR=1.00e-07]
Validation Epoch 42/50: 100%|██████████| 2/2 [00:00<00:00,  2.81it/s, Val Loss=0.9983]


Train - Loss: 1.3250, F1: 0.1686, Acc: 0.4653
Val   - Loss: 1.4020, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.9s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 84.0% | Best F1: 0.1263 | ETA: 0.7min

Epoch 43/50


Training Epoch 43/50: 100%|██████████| 13/13 [00:02<00:00,  5.74it/s, Loss=1.4068, LR=1.00e-07]
Validation Epoch 43/50: 100%|██████████| 2/2 [00:00<00:00,  2.42it/s, Val Loss=0.9978]


Train - Loss: 1.3662, F1: 0.1716, Acc: 0.4752
Val   - Loss: 1.4019, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.1s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 86.0% | Best F1: 0.1263 | ETA: 0.6min

Epoch 44/50


Training Epoch 44/50: 100%|██████████| 13/13 [00:02<00:00,  6.33it/s, Loss=1.3107, LR=1.00e-07]
Validation Epoch 44/50: 100%|██████████| 2/2 [00:00<00:00,  3.45it/s, Val Loss=0.9974]


Train - Loss: 1.3240, F1: 0.2080, Acc: 0.5149
Val   - Loss: 1.4019, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 88.0% | Best F1: 0.1263 | ETA: 0.5min

Epoch 45/50


Training Epoch 45/50: 100%|██████████| 13/13 [00:02<00:00,  5.89it/s, Loss=1.3581, LR=1.00e-07]
Validation Epoch 45/50: 100%|██████████| 2/2 [00:00<00:00,  3.48it/s, Val Loss=0.9970]


Train - Loss: 1.3490, F1: 0.2111, Acc: 0.4851
Val   - Loss: 1.4019, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.8s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 90.0% | Best F1: 0.1263 | ETA: 0.4min

Epoch 46/50


Training Epoch 46/50: 100%|██████████| 13/13 [00:02<00:00,  5.87it/s, Loss=1.3613, LR=1.00e-07]
Validation Epoch 46/50: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s, Val Loss=0.9965]


Train - Loss: 1.3565, F1: 0.2184, Acc: 0.5050
Val   - Loss: 1.4018, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.1s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 92.0% | Best F1: 0.1263 | ETA: 0.3min

Epoch 47/50


Training Epoch 47/50: 100%|██████████| 13/13 [00:02<00:00,  6.02it/s, Loss=1.3714, LR=1.00e-07]
Validation Epoch 47/50: 100%|██████████| 2/2 [00:00<00:00,  3.42it/s, Val Loss=0.9962]


Train - Loss: 1.3410, F1: 0.1333, Acc: 0.4752
Val   - Loss: 1.4018, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.8s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 94.0% | Best F1: 0.1263 | ETA: 0.2min

Epoch 48/50


Training Epoch 48/50: 100%|██████████| 13/13 [00:02<00:00,  6.20it/s, Loss=1.2837, LR=1.00e-07]
Validation Epoch 48/50: 100%|██████████| 2/2 [00:00<00:00,  3.38it/s, Val Loss=0.9958]


Train - Loss: 1.3205, F1: 0.1702, Acc: 0.4950
Val   - Loss: 1.4018, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 2.7s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 96.0% | Best F1: 0.1263 | ETA: 0.2min

Epoch 49/50


Training Epoch 49/50: 100%|██████████| 13/13 [00:02<00:00,  4.55it/s, Loss=1.3946, LR=1.00e-07]
Validation Epoch 49/50: 100%|██████████| 2/2 [00:00<00:00,  3.31it/s, Val Loss=0.9955]


Train - Loss: 1.3653, F1: 0.1317, Acc: 0.4158
Val   - Loss: 1.4017, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.5s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 98.0% | Best F1: 0.1263 | ETA: 0.1min

Epoch 50/50


Training Epoch 50/50: 100%|██████████| 13/13 [00:02<00:00,  5.87it/s, Loss=1.2953, LR=1.00e-07]
Validation Epoch 50/50: 100%|██████████| 2/2 [00:00<00:00,  2.71it/s, Val Loss=0.9951]


Train - Loss: 1.3490, F1: 0.1720, Acc: 0.4752
Val   - Loss: 1.4017, F1: 0.1263, Acc: 0.4615
Time  - Epoch: 3.0s, LR: 1.00e-07
Checkpoint saved: Same F1, Lower Loss - F1: 0.1263
Progress: 100.0% | Best F1: 0.1263 | ETA: 0.0min

TRAINING COMPLETED - FINETUNE STAGE
Training time: 4.2 minutes
Epochs completed: 50
Best validation F1: 0.1263 (epoch 50)
Final train F1: 0.1720
Final validation F1: 0.1263
Checkpoint saved: casme2_finetune_best_f1.pth

EXPORTING TRAINING DOCUMENTATION
Training documentation saved: training_history.json
Location: finetune_logs/training_logs/

Next: Cell 3 - FINETUNE Stage Evaluation


In [3]:
# @title Cell 3: ViT Transfer Learning Evaluation

# File: 05_01_ViT_RAF-DB_CASME2-AF.ipynb - Cell 3
# Location: experiments/05_01_ViT_RAF-DB_CASME2-AF.ipynb
# Purpose: Stage-aware evaluation for pre-training and fine-tuning with fixed class ordering

from google.colab import drive
drive.mount('/content/drive')

import os
import json
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.metrics import (
    f1_score,
    precision_recall_fscore_support,
    accuracy_score,
    confusion_matrix,
    classification_report,
    balanced_accuracy_score
)

print("=" * 70)
print(f"VIT TRANSFER LEARNING EVALUATION - STAGE: {GLOBAL_CONFIG['training_stage'].upper()}")
print("=" * 70)

# Stage-aware evaluation configuration
if GLOBAL_CONFIG['training_stage'] == 'pretrain':
    print("\nPre-training Evaluation Configuration:")
    print(f"  Dataset: RAF-DB Test Set")
    print(f"  Classes: {GLOBAL_CONFIG['num_classes']} (fixed ordering)")
    print(f"  Expected samples: ~3,280 test images")
    print(f"  Metrics: Accuracy, F1-macro, Precision, Recall")
    print(f"  Checkpoint: {GLOBAL_CONFIG['checkpoint_filename']}")
else:
    print("\nFine-tuning Evaluation Configuration:")
    print(f"  Dataset: CASME2 Apex Test Set")
    print(f"  Classes: {GLOBAL_CONFIG['num_classes']} (fixed ordering, fear=placeholder)")
    print(f"  Expected samples: ~22 test images")
    print(f"  Metrics: UAR, UF1, Accuracy, F1-macro, Balanced Accuracy")
    print(f"  Checkpoint: {GLOBAL_CONFIG['checkpoint_filename']}")
    print(f"  Note: Fear class has NO samples in test set")

# Load best checkpoint
checkpoint_path = GLOBAL_CONFIG['checkpoint_path']

print(f"\n[1] Loading checkpoint: {os.path.basename(checkpoint_path)}")

if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

checkpoint = torch.load(checkpoint_path, map_location='cpu')

print(f"    Checkpoint loaded successfully")
print(f"    Trained epochs: {checkpoint['epoch'] + 1}")
print(f"    Best validation F1: {checkpoint['best_f1']:.4f}")
print(f"    Training stage: {checkpoint.get('training_stage', 'unknown')}")

# Initialize model with same architecture
print(f"\n[2] Initializing model for evaluation...")

model = ViTTransferLearning(
    num_classes=GLOBAL_CONFIG['num_classes'],
    dropout_rate=GLOBAL_CONFIG['dropout_rate'],
    pretrained_checkpoint=None,
    freeze_encoder=False
).to(GLOBAL_CONFIG['device'])

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"    Model loaded and set to evaluation mode")
print(f"    Architecture: ViT-{GLOBAL_CONFIG['vit_variant']} Transfer Learning")
print(f"    Patch size: {GLOBAL_CONFIG['patch_size']}px")

# Load test dataset based on stage
print(f"\n[3] Loading test dataset...")

if GLOBAL_CONFIG['training_stage'] == 'pretrain':
    # Load RAF-DB test set
    metadata_df = pd.read_csv(GLOBAL_CONFIG['metadata_path'])
    test_metadata = metadata_df[metadata_df['split'] == 'test']

    test_dataset = RAFDBDataset(
        metadata_df=test_metadata,
        dataset_root=GLOBAL_CONFIG['dataset_root'],
        transform=GLOBAL_CONFIG['transform_val'],
        transfer_classes=GLOBAL_CONFIG['transfer_classes']
    )

    print(f"    RAF-DB test set loaded: {len(test_dataset)} samples")

else:
    # Load CASME2 test set
    with open(GLOBAL_CONFIG['metadata_path'], 'r') as f:
        split_metadata = json.load(f)

    test_dataset = CASME2Dataset(
        split_metadata=split_metadata,
        dataset_root=GLOBAL_CONFIG['dataset_root'],
        transform=GLOBAL_CONFIG['transform_val'],
        split='test',
        transfer_classes=GLOBAL_CONFIG['transfer_classes'],
        casme2_mapping=GLOBAL_CONFIG['casme2_mapping']
    )

    print(f"    CASME2 test set loaded: {len(test_dataset)} samples")

# Create test dataloader
test_loader = DataLoader(
    test_dataset,
    batch_size=GLOBAL_CONFIG['batch_size'],
    shuffle=False,
    num_workers=GLOBAL_CONFIG['num_workers'],
    pin_memory=True
)

print(f"    Test batches: {len(test_loader)}")

# Evaluation function with missing class handling
def evaluate_model(model, dataloader, device, transfer_classes):
    """Comprehensive model evaluation with missing class support"""
    model.eval()

    all_predictions = []
    all_labels = []
    all_sample_ids = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating")

        for batch_data in progress_bar:
            if len(batch_data) == 3:
                images, labels, sample_ids = batch_data
                all_sample_ids.extend(sample_ids)
            else:
                images, labels = batch_data

            images = images.to(device)
            labels = labels.cpu().numpy()

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()

            all_predictions.extend(predictions)
            all_labels.extend(labels)

    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    # Check which classes actually present in test set
    present_classes = np.unique(all_labels)
    missing_classes = [i for i in range(len(transfer_classes)) if i not in present_classes]

    print(f"\n    Test set class presence:")
    for idx, cls in enumerate(transfer_classes):
        if idx in present_classes:
            count = np.sum(all_labels == idx)
            print(f"      [{idx}] {cls}: {count} samples")
        else:
            print(f"      [{idx}] {cls}: NO SAMPLES (missing)")

    # Calculate overall metrics
    accuracy = accuracy_score(all_labels, all_predictions)

    # Balanced accuracy (handles imbalanced classes)
    balanced_acc = balanced_accuracy_score(all_labels, all_predictions)

    # Per-class metrics (only for present classes)
    precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
        all_labels, all_predictions,
        labels=list(range(len(transfer_classes))),
        average=None,
        zero_division=0
    )

    # Macro-averaged metrics (average across all classes including missing)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        all_labels, all_predictions,
        labels=list(range(len(transfer_classes))),
        average='macro',
        zero_division=0
    )

    # Weighted-averaged metrics (weighted by support)
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        all_labels, all_predictions,
        labels=list(range(len(transfer_classes))),
        average='weighted',
        zero_division=0
    )

    # UAR (Unweighted Average Recall) - same as recall_macro
    uar = recall_macro

    # UF1 (Unweighted F1) - same as f1_macro
    uf1 = f1_macro

    # Confusion matrix (all classes)
    cm = confusion_matrix(all_labels, all_predictions, labels=list(range(len(transfer_classes))))

    # Per-class results
    per_class_results = []
    for i, class_name in enumerate(transfer_classes):
        per_class_results.append({
            'class': class_name,
            'class_index': i,
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i]),
            'support': int(support_per_class[i]),
            'present_in_test': i in present_classes
        })

    results = {
        'accuracy': float(accuracy),
        'balanced_accuracy': float(balanced_acc),
        'precision_macro': float(precision_macro),
        'recall_macro': float(recall_macro),
        'f1_macro': float(f1_macro),
        'precision_weighted': float(precision_weighted),
        'recall_weighted': float(recall_weighted),
        'f1_weighted': float(f1_weighted),
        'uar': float(uar),
        'uf1': float(uf1),
        'per_class_results': per_class_results,
        'confusion_matrix': cm.tolist(),
        'predictions': all_predictions.tolist(),
        'labels': all_labels.tolist(),
        'sample_ids': all_sample_ids if all_sample_ids else None,
        'present_classes': present_classes.tolist(),
        'missing_classes': missing_classes
    }

    return results

# Run evaluation
print(f"\n[4] Running evaluation on test set...")

eval_results = evaluate_model(
    model, test_loader, GLOBAL_CONFIG['device'],
    GLOBAL_CONFIG['transfer_classes']
)

# Display results
print("\n" + "=" * 70)
print("EVALUATION RESULTS")
print("=" * 70)

print(f"\nOverall Metrics:")
print(f"  Accuracy: {eval_results['accuracy']:.4f}")
print(f"  Balanced Accuracy: {eval_results['balanced_accuracy']:.4f}")
print(f"  Precision (macro): {eval_results['precision_macro']:.4f}")
print(f"  Recall (macro): {eval_results['recall_macro']:.4f}")
print(f"  F1-Score (macro): {eval_results['f1_macro']:.4f}")
print(f"  Precision (weighted): {eval_results['precision_weighted']:.4f}")
print(f"  Recall (weighted): {eval_results['recall_weighted']:.4f}")
print(f"  F1-Score (weighted): {eval_results['f1_weighted']:.4f}")

if GLOBAL_CONFIG['training_stage'] == 'finetune':
    print(f"\nMicro-Expression Metrics:")
    print(f"  UAR (Unweighted Average Recall): {eval_results['uar']:.4f}")
    print(f"  UF1 (Unweighted F1-Score): {eval_results['uf1']:.4f}")

print(f"\nPer-Class Performance:")
print("-" * 70)
print(f"{'Class':<15} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
print("-" * 70)

for result in eval_results['per_class_results']:
    if result['present_in_test']:
        print(f"{result['class']:<15} "
              f"{result['precision']:<12.4f} "
              f"{result['recall']:<12.4f} "
              f"{result['f1_score']:<12.4f} "
              f"{result['support']:<10}")
    else:
        print(f"{result['class']:<15} "
              f"{'N/A':<12} "
              f"{'N/A':<12} "
              f"{'N/A':<12} "
              f"{'0':<10} (missing)")

print("-" * 70)

# Confusion matrix display
print(f"\nConfusion Matrix:")
cm = np.array(eval_results['confusion_matrix'])
class_names = GLOBAL_CONFIG['transfer_classes']

print("\nPredicted →")
print(f"{'True ↓':<15}", end="")
for name in class_names:
    print(f"{name[:8]:<10}", end="")
print()

for i, true_class in enumerate(class_names):
    if i in eval_results['present_classes']:
        print(f"{true_class:<15}", end="")
        for j in range(len(class_names)):
            print(f"{cm[i][j]:<10}", end="")
        print()
    else:
        print(f"{true_class:<15}", end="")
        for j in range(len(class_names)):
            print(f"{'N/A':<10}", end="")
        print(" (missing)")

# Transfer learning gain analysis (finetune only)
if GLOBAL_CONFIG['training_stage'] == 'finetune':
    print(f"\n[5] Transfer Learning Analysis...")

    # Try to load baseline results for comparison
    baseline_path = f"{PROJECT_ROOT}/results/04_01_vit_casme2_baseline/evaluation_results/test_evaluation.json"

    if os.path.exists(baseline_path):
        print(f"    Loading baseline results from: 04_01_vit_casme2_baseline")

        with open(baseline_path, 'r') as f:
            baseline_results = json.load(f)

        baseline_f1 = baseline_results.get('f1_macro', baseline_results.get('uf1', 0.0))
        baseline_uar = baseline_results.get('uar', baseline_results.get('recall_macro', 0.0))
        baseline_acc = baseline_results.get('accuracy', 0.0)

        transfer_f1 = eval_results['f1_macro']
        transfer_uar = eval_results['uar']
        transfer_acc = eval_results['accuracy']

        f1_gain = transfer_f1 - baseline_f1
        uar_gain = transfer_uar - baseline_uar
        acc_gain = transfer_acc - baseline_acc

        f1_improvement_pct = (f1_gain / baseline_f1 * 100) if baseline_f1 > 0 else 0
        uar_improvement_pct = (uar_gain / baseline_uar * 100) if baseline_uar > 0 else 0
        acc_improvement_pct = (acc_gain / baseline_acc * 100) if baseline_acc > 0 else 0

        print(f"\n" + "=" * 70)
        print("TRANSFER LEARNING GAIN ANALYSIS")
        print("=" * 70)

        print(f"\nF1-Score Comparison:")
        print(f"  Baseline (no transfer): {baseline_f1:.4f}")
        print(f"  Transfer learning: {transfer_f1:.4f}")
        print(f"  Absolute gain: {f1_gain:+.4f}")
        print(f"  Relative improvement: {f1_improvement_pct:+.2f}%")

        print(f"\nUAR Comparison:")
        print(f"  Baseline (no transfer): {baseline_uar:.4f}")
        print(f"  Transfer learning: {transfer_uar:.4f}")
        print(f"  Absolute gain: {uar_gain:+.4f}")
        print(f"  Relative improvement: {uar_improvement_pct:+.2f}%")

        print(f"\nAccuracy Comparison:")
        print(f"  Baseline (no transfer): {baseline_acc:.4f}")
        print(f"  Transfer learning: {transfer_acc:.4f}")
        print(f"  Absolute gain: {acc_gain:+.4f}")
        print(f"  Relative improvement: {acc_improvement_pct:+.2f}%")

        transfer_analysis = {
            'baseline_f1': float(baseline_f1),
            'transfer_f1': float(transfer_f1),
            'f1_gain': float(f1_gain),
            'f1_improvement_pct': float(f1_improvement_pct),
            'baseline_uar': float(baseline_uar),
            'transfer_uar': float(transfer_uar),
            'uar_gain': float(uar_gain),
            'uar_improvement_pct': float(uar_improvement_pct),
            'baseline_acc': float(baseline_acc),
            'transfer_acc': float(transfer_acc),
            'acc_gain': float(acc_gain),
            'acc_improvement_pct': float(acc_improvement_pct),
            'baseline_source': '04_01_vit_casme2_baseline'
        }

        eval_results['transfer_learning_analysis'] = transfer_analysis

        if f1_gain > 0 and uar_gain > 0:
            print(f"\nConclusion: Transfer learning provides positive gain")
        elif f1_gain > 0 or uar_gain > 0:
            print(f"\nConclusion: Transfer learning shows mixed results")
        else:
            print(f"\nConclusion: Transfer learning shows no improvement over baseline")

    else:
        print(f"    Baseline results not found at: {baseline_path}")
        print(f"    Skipping transfer learning gain analysis")
        eval_results['transfer_learning_analysis'] = None

# Save evaluation results
print(f"\n[6] Saving evaluation results...")

eval_results_path = f"{GLOBAL_CONFIG['logs_path']}/evaluation_results/test_evaluation.json"

# Prepare comprehensive evaluation report
evaluation_report = {
    'experiment_type': f'ViT_Transfer_Learning_{GLOBAL_CONFIG["training_stage"].upper()}',
    'training_stage': GLOBAL_CONFIG['training_stage'],
    'dataset_name': GLOBAL_CONFIG['dataset_name'],
    'test_samples': len(test_dataset),
    'checkpoint_used': GLOBAL_CONFIG['checkpoint_filename'],
    'checkpoint_epoch': checkpoint['epoch'] + 1,
    'checkpoint_val_f1': float(checkpoint['best_f1']),
    'model_info': {
        'backbone': GLOBAL_CONFIG['vit_model'],
        'variant': GLOBAL_CONFIG['vit_variant'],
        'patch_size': GLOBAL_CONFIG['patch_size'],
        'input_size': GLOBAL_CONFIG['input_size'],
        'num_classes': GLOBAL_CONFIG['num_classes'],
        'transfer_classes': GLOBAL_CONFIG['transfer_classes']
    },
    'test_results': {
        'accuracy': eval_results['accuracy'],
        'balanced_accuracy': eval_results['balanced_accuracy'],
        'precision_macro': eval_results['precision_macro'],
        'recall_macro': eval_results['recall_macro'],
        'f1_macro': eval_results['f1_macro'],
        'precision_weighted': eval_results['precision_weighted'],
        'recall_weighted': eval_results['recall_weighted'],
        'f1_weighted': eval_results['f1_weighted'],
        'uar': eval_results['uar'],
        'uf1': eval_results['uf1'],
        'per_class_results': eval_results['per_class_results'],
        'confusion_matrix': eval_results['confusion_matrix'],
        'present_classes': eval_results['present_classes'],
        'missing_classes': eval_results['missing_classes']
    }
}

if GLOBAL_CONFIG['training_stage'] == 'finetune':
    evaluation_report['transfer_learning_analysis'] = eval_results.get('transfer_learning_analysis')

# Save to JSON
with open(eval_results_path, 'w') as f:
    json.dump(evaluation_report, f, indent=2)

print(f"    Evaluation results saved: test_evaluation.json")
print(f"    Location: {GLOBAL_CONFIG['logs_subdir']}/evaluation_results/")

# Summary statistics
print("\n" + "=" * 70)
print(f"EVALUATION COMPLETE - {GLOBAL_CONFIG['training_stage'].upper()} STAGE")
print("=" * 70)

print(f"\nTest Set Performance:")
print(f"  Samples evaluated: {len(test_dataset)}")
print(f"  Overall accuracy: {eval_results['accuracy']:.4f}")
print(f"  Balanced accuracy: {eval_results['balanced_accuracy']:.4f}")
print(f"  Macro F1-Score: {eval_results['f1_macro']:.4f}")
print(f"  Weighted F1-Score: {eval_results['f1_weighted']:.4f}")

if GLOBAL_CONFIG['training_stage'] == 'finetune':
    print(f"  UAR: {eval_results['uar']:.4f}")
    print(f"  UF1: {eval_results['uf1']:.4f}")

    if eval_results.get('transfer_learning_analysis'):
        tl_analysis = eval_results['transfer_learning_analysis']
        print(f"\nTransfer Learning Impact:")
        print(f"  F1 improvement: {tl_analysis['f1_improvement_pct']:+.2f}%")
        print(f"  UAR improvement: {tl_analysis['uar_improvement_pct']:+.2f}%")
        print(f"  Accuracy improvement: {tl_analysis['acc_improvement_pct']:+.2f}%")

print(f"\nMissing Classes in Test Set:")
if eval_results['missing_classes']:
    for idx in eval_results['missing_classes']:
        class_name = GLOBAL_CONFIG['transfer_classes'][idx]
        print(f"  [{idx}] {class_name}: NO SAMPLES")
else:
    print(f"  None (all classes present)")

print(f"\nResults saved to: {GLOBAL_CONFIG['logs_subdir']}/evaluation_results/")

if GLOBAL_CONFIG['training_stage'] == 'pretrain':
    print("\nNext: Run fine-tuning stage (set TRAINING_STAGE='finetune' in Cell 1)")
else:
    print("\nNext: Cell 4 - Confusion Matrix Visualization")

print("=" * 70)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
VIT TRANSFER LEARNING EVALUATION - STAGE: FINETUNE

Fine-tuning Evaluation Configuration:
  Dataset: CASME2 Apex Test Set
  Classes: 5 (fixed ordering, fear=placeholder)
  Expected samples: ~22 test images
  Metrics: UAR, UF1, Accuracy, F1-macro, Balanced Accuracy
  Checkpoint: casme2_finetune_best_f1.pth
  Note: Fear class has NO samples in test set

[1] Loading checkpoint: casme2_finetune_best_f1.pth
    Checkpoint loaded successfully
    Trained epochs: 50
    Best validation F1: 0.1263
    Training stage: finetune

[2] Initializing model for evaluation...


Some weights of the model checkpoint at google/vit-base-patch32-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ViT encoder trainable
ViT Transfer Learning: 768 -> 512 -> 128 -> 5
    Model loaded and set to evaluation mode
    Architecture: ViT-patch32 Transfer Learning
    Patch size: 32px

[3] Loading test dataset...
Mapped 15 samples from 28 total
Loaded 15 samples for test split
Fixed class ordering: ['disgust', 'happy', 'surprise', 'sad', 'fear']
Class distribution:
  [0] disgust: 7 samples
  [1] happy: 4 samples
  [2] surprise: 3 samples
  [3] sad: 1 samples
  [4] fear: NO SAMPLES
    CASME2 test set loaded: 15 samples
    Test batches: 2

[4] Running evaluation on test set...


Evaluating: 100%|██████████| 2/2 [00:09<00:00,  4.76s/it]


    Test set class presence:
      [0] disgust: 7 samples
      [1] happy: 4 samples
      [2] surprise: 3 samples
      [3] sad: 1 samples
      [4] fear: NO SAMPLES (missing)

EVALUATION RESULTS

Overall Metrics:
  Accuracy: 0.4667
  Balanced Accuracy: 0.2500
  Precision (macro): 0.0933
  Recall (macro): 0.2000
  F1-Score (macro): 0.1273
  Precision (weighted): 0.2178
  Recall (weighted): 0.4667
  F1-Score (weighted): 0.2970

Micro-Expression Metrics:
  UAR (Unweighted Average Recall): 0.2000
  UF1 (Unweighted F1-Score): 0.1273

Per-Class Performance:
----------------------------------------------------------------------
Class           Precision    Recall       F1-Score     Support   
----------------------------------------------------------------------
disgust         0.4667       1.0000       0.6364       7         
happy           0.0000       0.0000       0.0000       4         
surprise        0.0000       0.0000       0.0000       3         
sad             0.0000       0.00




In [4]:
# @title Cell 4: ViT Transfer Learning Confusion Matrix Generation

# File: 05_01_ViT_RAF-DB_CASME2-AF.ipynb - Cell 4
# Location: experiments/05_01_ViT_RAF-DB_CASME2-AF.ipynb
# Purpose: Generate confusion matrix for pretrain and finetune stages with missing class handling

from google.colab import drive
drive.mount('/content/drive')

import json
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime

print("=" * 70)
print("VIT TRANSFER LEARNING CONFUSION MATRIX GENERATION")
print("=" * 70)

# Unified output directory
OUTPUT_DIR = f"{PROJECT_ROOT}/results/05_01_transfer_learning/visualization"
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

print(f"\nOutput directory: results/05_01_transfer_learning/visualization/")

# Helper functions
def calculate_weighted_f1(per_class_results):
    """Calculate weighted F1 score"""
    total_support = sum([result['support'] for result in per_class_results])

    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0
    for result in per_class_results:
        weight = result['support'] / total_support
        weighted_f1 += result['f1_score'] * weight

    return weighted_f1

def calculate_balanced_accuracy(confusion_matrix, present_classes):
    """Calculate balanced accuracy for present classes only"""
    cm = np.array(confusion_matrix)
    n_classes = cm.shape[0]

    per_class_acc = []

    for i in present_classes:
        if cm[i, :].sum() > 0:
            tp = cm[i, i]
            fn = cm[i, :].sum() - tp
            fp = cm[:, i].sum() - tp
            tn = cm.sum() - tp - fn - fp

            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

            class_balanced_acc = (sensitivity + specificity) / 2
            per_class_acc.append(class_balanced_acc)

    balanced_acc = np.mean(per_class_acc) if per_class_acc else 0.0

    return balanced_acc

def determine_text_color(color_value, threshold=0.5):
    """Determine optimal text color based on background intensity"""
    return 'white' if color_value > threshold else 'black'

def create_confusion_matrix(eval_data, output_path, stage):
    """Create confusion matrix visualization with missing class handling"""

    print(f"\n[{stage.upper()}] Creating confusion matrix...")

    # Extract data
    transfer_classes = eval_data['model_info']['transfer_classes']
    cm = np.array(eval_data['test_results']['confusion_matrix'], dtype=int)
    present_classes = eval_data['test_results']['present_classes']
    missing_classes = eval_data['test_results']['missing_classes']

    # Calculate 4 key metrics
    accuracy = eval_data['test_results']['accuracy']
    macro_f1 = eval_data['test_results']['f1_macro']
    weighted_f1 = eval_data['test_results']['f1_weighted']
    balanced_acc = eval_data['test_results']['balanced_accuracy']

    print(f"    Classes: {transfer_classes}")
    print(f"    Matrix shape: {cm.shape}")
    print(f"    Present classes: {present_classes}")
    print(f"    Missing classes: {missing_classes}")
    print(f"    Accuracy: {accuracy:.4f}")
    print(f"    Macro F1: {macro_f1:.4f}")
    print(f"    Weighted F1: {weighted_f1:.4f}")
    print(f"    Balanced Acc: {balanced_acc:.4f}")

    # Create figure
    fig, ax = plt.subplots(figsize=(10, 8))

    # Create base heatmap for present classes only
    cm_display = np.full_like(cm, np.nan, dtype=float)

    for i in present_classes:
        row_sum = cm[i, :].sum()
        if row_sum > 0:
            cm_display[i, :] = cm[i, :] / row_sum
        else:
            cm_display[i, :] = 0.0

    # Create heatmap with masked values for missing classes
    masked_cm = np.ma.masked_where(np.isnan(cm_display), cm_display)
    im = ax.imshow(masked_cm, interpolation='nearest', cmap='Blues', vmin=0.0, vmax=0.8)

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    # Annotate cells
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if i in present_classes:
                count = cm[i, j]
                row_sum = cm[i, :].sum()

                if row_sum > 0:
                    percentage = (count / row_sum) * 100
                    text = f"{count}\n{percentage:.1f}%"
                else:
                    text = f"{count}\nN/A"

                cell_value = cm_display[i, j] if not np.isnan(cm_display[i, j]) else 0.0
                text_color = determine_text_color(cell_value, threshold=0.4)

                ax.text(j, i, text, ha="center", va="center",
                       color=text_color, fontsize=9, fontweight='bold')
            else:
                # Missing class row
                ax.text(j, i, "N/A", ha="center", va="center",
                       color='gray', fontsize=9, fontweight='bold', style='italic')

    # Mark missing class rows with gray background
    for i in missing_classes:
        ax.axhspan(i - 0.5, i + 0.5, facecolor='lightgray', alpha=0.3, zorder=0)

    # Configure axes
    ax.set_xticks(np.arange(len(transfer_classes)))
    ax.set_yticks(np.arange(len(transfer_classes)))
    ax.set_xticklabels(transfer_classes, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(transfer_classes, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    # Add note for missing classes
    if missing_classes:
        missing_note = "Note: Missing classes in test set: " + ", ".join([transfer_classes[i] for i in missing_classes])
        ax.text(0.02, 0.02, missing_note, transform=ax.transAxes, fontsize=8,
                verticalalignment='bottom', bbox=dict(boxstyle='round',
                facecolor='lightyellow', alpha=0.8))

    # Add stage information
    dataset_label = "RAF-DB Macro" if stage == 'pretrain' else "CASME2 Micro"
    stage_text = f"Stage: {stage.upper()} ({dataset_label})"

    ax.text(0.02, 0.98, stage_text, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', bbox=dict(boxstyle='round',
            facecolor='lightyellow', alpha=0.8))

    # Title with 4 key metrics
    title = f"ViT Transfer Learning Confusion Matrix - {stage.upper()}\n"
    title += f"Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  "
    title += f"Acc: {accuracy:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=20, fontweight='bold')

    # Save
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"    Saved: {os.path.basename(output_path)}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'balanced_acc': balanced_acc
    }

# Process both stages
print("\n[1] Loading evaluation results...")

stages_to_process = []
results_summary = {}

# Check pretrain results
pretrain_json = f"{PROJECT_ROOT}/results/05_01_transfer_learning/pretrain_logs/evaluation_results/test_evaluation.json"
if os.path.exists(pretrain_json):
    stages_to_process.append(('pretrain', pretrain_json))
    print(f"    Found pretrain results: test_evaluation.json")
else:
    print(f"    Pretrain results not found (skip if not yet trained)")

# Check finetune results
finetune_json = f"{PROJECT_ROOT}/results/05_01_transfer_learning/finetune_logs/evaluation_results/test_evaluation.json"
if os.path.exists(finetune_json):
    stages_to_process.append(('finetune', finetune_json))
    print(f"    Found finetune results: test_evaluation.json")
else:
    print(f"    Finetune results not found (skip if not yet trained)")

if not stages_to_process:
    print("\nERROR: No evaluation results found!")
    print("Please run Cell 3 (evaluation) first for at least one stage.")
else:
    print(f"\n[2] Generating confusion matrices for {len(stages_to_process)} stage(s)...")

    generated_files = []

    for stage, json_path in stages_to_process:
        # Load evaluation data
        with open(json_path, 'r') as f:
            eval_data = json.load(f)

        # Generate confusion matrix
        output_path = os.path.join(OUTPUT_DIR, f"confusion_matrix_{stage}.png")
        metrics = create_confusion_matrix(eval_data, output_path, stage)

        results_summary[stage] = metrics
        generated_files.append(output_path)

    # Final summary
    print("\n" + "=" * 70)
    print("CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 70)

    print(f"\nGenerated Files:")
    for file_path in generated_files:
        print(f"  {os.path.basename(file_path)}")

    print(f"\nOutput Location:")
    print(f"  results/05_01_transfer_learning/visualization/")

    print(f"\nPerformance Summary:")
    for stage, metrics in results_summary.items():
        dataset_type = "RAF-DB Macro" if stage == 'pretrain' else "CASME2 Micro"
        print(f"\n{stage.upper()} ({dataset_type}):")
        print(f"  Macro F1: {metrics['macro_f1']:.4f}")
        print(f"  Weighted F1: {metrics['weighted_f1']:.4f}")
        print(f"  Accuracy: {metrics['accuracy']:.4f}")
        print(f"  Balanced Acc: {metrics['balanced_acc']:.4f}")

    # Performance analysis
    if 'pretrain' in results_summary and 'finetune' in results_summary:
        print(f"\n" + "=" * 70)
        print("TRANSFER LEARNING PERFORMANCE ANALYSIS")
        print("=" * 70)

        pretrain_f1 = results_summary['pretrain']['macro_f1']
        finetune_f1 = results_summary['finetune']['macro_f1']

        pretrain_acc = results_summary['pretrain']['accuracy']
        finetune_acc = results_summary['finetune']['accuracy']

        f1_drop = pretrain_f1 - finetune_f1
        f1_drop_pct = (f1_drop / pretrain_f1 * 100) if pretrain_f1 > 0 else 0

        acc_drop = pretrain_acc - finetune_acc
        acc_drop_pct = (acc_drop / pretrain_acc * 100) if pretrain_acc > 0 else 0

        print(f"\nPretrain (Macro) Performance:")
        print(f"  Macro F1: {pretrain_f1:.4f}")
        print(f"  Accuracy: {pretrain_acc:.4f}")

        print(f"\nFinetune (Micro) Performance:")
        print(f"  Macro F1: {finetune_f1:.4f}")
        print(f"  Accuracy: {finetune_acc:.4f}")

        print(f"\nPerformance Gap Analysis:")
        print(f"  F1 drop: {f1_drop:.4f} ({f1_drop_pct:.1f}%)")
        print(f"  Accuracy drop: {acc_drop:.4f} ({acc_drop_pct:.1f}%)")

        print(f"\nInterpretation:")
        if f1_drop_pct > 50:
            print(f"  Significant performance drop from macro to micro expressions")
            print(f"  This is expected due to:")
            print(f"    - Limited micro-expression training samples (~150 vs ~30k)")
            print(f"    - Higher difficulty of micro-expression recognition")
            print(f"    - Class imbalance in CASME2 dataset")
        else:
            print(f"  Transfer learning maintained reasonable performance")

    print(f"\nGeneration completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Next steps guidance
    if 'pretrain' in results_summary and 'finetune' not in results_summary:
        print("\nNext: Run fine-tuning stage (set TRAINING_STAGE='finetune' in Cell 1)")
    elif 'pretrain' in results_summary and 'finetune' in results_summary:
        print("\nTransfer Learning Experiment Complete!")
        print("Both stages finished: Pre-training and Fine-tuning")
        print("\nKey Findings:")
        print(f"  1. Pretrain Macro F1: {results_summary['pretrain']['macro_f1']:.4f}")
        print(f"  2. Finetune Macro F1: {results_summary['finetune']['macro_f1']:.4f}")
        print(f"  3. Performance preserved: {(results_summary['finetune']['macro_f1']/results_summary['pretrain']['macro_f1']*100):.1f}%")

    print("=" * 70)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
VIT TRANSFER LEARNING CONFUSION MATRIX GENERATION

Output directory: results/05_01_transfer_learning/visualization/

[1] Loading evaluation results...
    Found pretrain results: test_evaluation.json
    Found finetune results: test_evaluation.json

[2] Generating confusion matrices for 2 stage(s)...

[PRETRAIN] Creating confusion matrix...
    Classes: ['disgust', 'happy', 'surprise', 'sad', 'fear']
    Matrix shape: (5, 5)
    Present classes: [0, 1, 2, 3, 4]
    Missing classes: []
    Accuracy: 0.9640
    Macro F1: 0.9640
    Weighted F1: 0.9640
    Balanced Acc: 0.9640
    Saved: confusion_matrix_pretrain.png

[FINETUNE] Creating confusion matrix...
    Classes: ['disgust', 'happy', 'surprise', 'sad', 'fear']
    Matrix shape: (5, 5)
    Present classes: [0, 1, 2, 3]
    Missing classes: [4]
    Accuracy: 0.4667
    Macro F1: 0.1273
    Weighted F1: 0.29