In [1]:
# @title Cell 1: CASME II Multi-Frame Sequence PoolFormer Infrastructure Configuration

# File: 07_03_PoolFormer_CASME2_MFS_Cell1_FIXED.py
# Location: experiments/07_03_PoolFormer_CASME2-MFS-PREP.ipynb
# Purpose: PoolFormer for CASME II micro-expression recognition with multi-frame sequence strategy and face-aware preprocessing
# Fix: Corrected forward function to handle PoolFormer 4D output format

from google.colab import drive
print("=" * 60)
print("CASME II MULTI-FRAME SEQUENCE POOLFORMER WITH FACE-AWARE PREPROCESSING")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Project paths configuration
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/preprocessed_v9"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/07_03_poolformer_casme2_mfs_prep"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/07_03_poolformer_casme2_mfs_prep"

# Load CASME II v9 preprocessing metadata
PREPROCESSING_SUMMARY = f"{DATASET_ROOT}/preprocessing_summary.json"

print("CASME II Multi-Frame Sequence PoolFormer - Face-Aware Preprocessing Infrastructure")
print("=" * 60)

# Validate preprocessing metadata exists
if not os.path.exists(PREPROCESSING_SUMMARY):
    raise FileNotFoundError(f"v9 preprocessing summary not found: {PREPROCESSING_SUMMARY}")

# Load v9 preprocessing metadata
print("Loading CASME II v9 preprocessing metadata...")
with open(PREPROCESSING_SUMMARY, 'r') as f:
    preprocessing_info = json.load(f)

print(f"Dataset variant: {preprocessing_info['variant']}")
print(f"Processing date: {preprocessing_info['processing_date']}")
print(f"Preprocessing method: {preprocessing_info['preprocessing_method']}")
print(f"Total images processed: {preprocessing_info['total_processed']}")
print(f"Face detection rate: {preprocessing_info['face_detection_stats']['detection_rate']:.2%}")

# Extract preprocessing parameters
preproc_params = preprocessing_info['preprocessing_parameters']
print(f"Target size: {preproc_params['target_size']}x{preproc_params['target_size']}px")
print(f"BBox expansion: {preproc_params['bbox_expansion']}px (all directions)")

# Display split information
print(f"\nDataset split information:")
print(f"  Train samples: {preprocessing_info['splits']['train']['total_images']}")
print(f"  Validation samples: {preprocessing_info['splits']['val']['total_images']}")
print(f"  Test samples: {preprocessing_info['splits']['test']['total_images']}")

# EXPERIMENT CONFIGURATION - Multi-Frame Sequence with Face-Aware Preprocessing
# This configuration supports 4 experiment scenarios:
# 1. PoolFormer-M36 + CrossEntropy Loss
# 2. PoolFormer-M36 + Focal Loss
# 3. PoolFormer-M48 + CrossEntropy Loss
# 4. PoolFormer-M48 + Focal Loss
#
# Toggle POOLFORMER_MODEL_VARIANT for model selection: 'm36' or 'm48'
# Toggle USE_FOCAL_LOSS for loss function: False (CrossEntropy) or True (Focal)

# FOCAL LOSS CONFIGURATION - Toggle for experimentation
USE_FOCAL_LOSS = True
FOCAL_LOSS_GAMMA = 2.0

# OPTIMIZED CLASS WEIGHTS CONFIGURATION
# CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
# v9 Train distribution: [1027, 650, 325, 273, 260, 65, 13]

# CrossEntropy Loss - Optimized inverse square root frequency weights
CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.26, 1.78, 1.94, 1.99, 3.98, 8.90]

# Focal Loss - Normalized per-class alpha values
FOCAL_LOSS_ALPHA_WEIGHTS = [0.048, 0.060, 0.085, 0.093, 0.095, 0.191, 0.427]

# POOLFORMER MODEL CONFIGURATION
POOLFORMER_MODEL_VARIANT = 'm48'

# Dynamic PoolFormer model selection based on variant
if POOLFORMER_MODEL_VARIANT == 'm36':
    POOLFORMER_MODEL_NAME = 'sail/poolformer_m36'
    EXPECTED_FEATURE_DIM = 768
    MODEL_PARAMS = '56M'
    print("Using PoolFormer-M36 for efficient token mixing analysis (56M parameters)")
elif POOLFORMER_MODEL_VARIANT == 'm48':
    POOLFORMER_MODEL_NAME = 'sail/poolformer_m48'
    EXPECTED_FEATURE_DIM = 768
    MODEL_PARAMS = '73M'
    print("Using PoolFormer-M48 for enhanced micro-expression recognition (73M parameters)")
else:
    raise ValueError(f"Unsupported POOLFORMER_MODEL_VARIANT: {POOLFORMER_MODEL_VARIANT}")

# Display experiment configuration
print("\n" + "=" * 50)
print("EXPERIMENT CONFIGURATION - MULTI-FRAME SEQUENCE FACE-AWARE")
print("=" * 50)
print(f"Dataset: v9 Multi-Frame Sequence with Face-Aware Preprocessing")
print(f"Frame strategy: Multiple frames per video (dense sampling)")
print(f"Training approach: Frame-level independent learning")
print(f"Loss Function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
if USE_FOCAL_LOSS:
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Alpha Weights (per-class): {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum Validation: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Class Weights (inverse sqrt freq): {CROSSENTROPY_CLASS_WEIGHTS}")
print(f"PoolFormer Model Variant: {POOLFORMER_MODEL_VARIANT.upper()}")
print(f"  Model: {POOLFORMER_MODEL_NAME}")
print(f"  Parameters: {MODEL_PARAMS}")
print(f"  Feature Dimension: {EXPECTED_FEATURE_DIM}")
print(f"Input Resolution: 224x224px (native from v9 preprocessing)")
print(f"Image Format: Grayscale converted to RGB (3-channel)")
print("=" * 50)

# Enhanced GPU configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

# Fixed batch size configuration for large dataset
BATCH_SIZE = 16
NUM_WORKERS = 4

if 'A100' in gpu_name or 'L4' in gpu_name:
    torch.backends.cudnn.benchmark = True
    print(f"GPU optimization enabled for {gpu_name}")

print(f"Large dataset configuration: Batch size {BATCH_SIZE} (optimal for 2613 samples at 224px)")
print(f"Iterations per epoch: {2613 // BATCH_SIZE} (~163 iterations per epoch)")

# RAM preloading workers
RAM_PRELOAD_WORKERS = 32
print(f"RAM preload workers: {RAM_PRELOAD_WORKERS} (parallel image loading)")

# CASME II class mapping and analysis
CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

# Extract class distribution from v9 preprocessing metadata
print("\nLoading v9 class distribution...")
train_dist = preprocessing_info['splits']['train']['emotion_distribution']
val_dist = preprocessing_info['splits']['val']['emotion_distribution']
test_dist = preprocessing_info['splits']['test']['emotion_distribution']

# Convert to ordered list matching CASME2_CLASSES
def emotion_dist_to_list(emotion_dict, class_names):
    """Convert emotion distribution dict to ordered list"""
    return [emotion_dict.get(cls, 0) for cls in class_names]

train_dist_list = emotion_dist_to_list(train_dist, CASME2_CLASSES)
val_dist_list = emotion_dist_to_list(val_dist, CASME2_CLASSES)
test_dist_list = emotion_dist_to_list(test_dist, CASME2_CLASSES)

print(f"\nv9 Train distribution: {train_dist_list}")
print(f"v9 Val distribution: {val_dist_list}")
print(f"v9 Test distribution: {test_dist_list}")

# Apply optimized class weights based on loss function selection
if USE_FOCAL_LOSS:
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

# CASME II PoolFormer Configuration
CASME2_POOLFORMER_CONFIG = {
    # Architecture configuration
    'poolformer_model': POOLFORMER_MODEL_NAME,
    'model_variant': POOLFORMER_MODEL_VARIANT,
    'model_params': MODEL_PARAMS,
    'input_size': 224,
    'num_classes': 7,
    'dropout_rate': 0.3,
    'expected_feature_dim': EXPECTED_FEATURE_DIM,

    # Training configuration
    'learning_rate': 2e-5,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,

    # Scheduler configuration
    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_monitor': 'validation_f1',
    'scheduler_factor': 0.5,
    'scheduler_patience': 3,
    'scheduler_min_lr': 1e-7,

    # Dataset configuration
    'dataset_version': 'v9',
    'preprocessing_method': 'face_aware_bbox_expansion',
    'frame_strategy': 'multi_frame_sequence',
    'training_approach': 'frame_level_independent',
    'inference_strategy': 'late_fusion_aggregation',

    # Loss configuration
    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS
}

# Optimized Focal Loss implementation
class OptimizedFocalLoss(nn.Module):
    """Optimized Focal Loss with per-class alpha weights"""

    def __init__(self, alpha=None, gamma=2.0):
        super(OptimizedFocalLoss, self).__init__()
        self.gamma = gamma
        if alpha is not None:
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        else:
            self.alpha = None

    def forward(self, inputs, targets):
        if self.alpha is not None and self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)

        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_loss

        return focal_loss.mean()

# PoolFormer CASME II Baseline Model - FIXED VERSION
class PoolFormerCASME2Baseline(nn.Module):
    """PoolFormer baseline model for CASME II micro-expression recognition with corrected forward pass"""

    def __init__(self, num_classes=7, dropout_rate=0.3):
        super(PoolFormerCASME2Baseline, self).__init__()

        from transformers import PoolFormerModel

        self.poolformer = PoolFormerModel.from_pretrained(
            CASME2_POOLFORMER_CONFIG['poolformer_model']
        )

        hidden_size = CASME2_POOLFORMER_CONFIG['expected_feature_dim']

        # Classifier with 2D adaptive pooling for PoolFormer 4D output
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_classes)
        )

        print(f"PoolFormer feature dimension: {hidden_size}")
        print(f"Classification head: {hidden_size} -> GAP2D -> 512 -> 128 -> {num_classes}")
        print(f"Dropout rate: {dropout_rate} (balanced for large dataset)")

    def forward(self, x):
        """
        Forward pass with proper handling of PoolFormer output format

        PoolFormer returns last_hidden_state which can be:
        - 3D: [batch, seq_len, hidden_dim]
        - 4D: [batch, hidden_dim, height, width]

        This implementation handles both formats
        """
        outputs = self.poolformer(pixel_values=x)
        pooled_output = outputs.last_hidden_state

        # Handle different output formats from PoolFormer
        if pooled_output.dim() == 4:
            # 4D output: [batch, channels, H, W]
            # Apply 2D adaptive pooling
            pooled_output = self.adaptive_pool(pooled_output)
        elif pooled_output.dim() == 3:
            # 3D output: [batch, seq_len, hidden_dim]
            # Reshape to 4D for consistent processing
            batch_size, seq_len, hidden_dim = pooled_output.shape
            # Assume square spatial dimensions
            spatial_size = int(np.sqrt(seq_len))
            if spatial_size * spatial_size == seq_len:
                pooled_output = pooled_output.permute(0, 2, 1).reshape(batch_size, hidden_dim, spatial_size, spatial_size)
                pooled_output = self.adaptive_pool(pooled_output)
            else:
                # Fallback: global average pooling over sequence dimension
                pooled_output = pooled_output.mean(dim=1, keepdim=True).unsqueeze(-1)
        else:
            raise ValueError(f"Unexpected PoolFormer output shape: {pooled_output.shape}")

        # Pass through classifier
        logits = self.classifier(pooled_output)

        return logits

# Enhanced optimizer and scheduler factory
def create_optimizer_scheduler_casme2(model, config):
    """Create optimizer and scheduler for CASME II PoolFormer training"""

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )

    if config['scheduler_type'] == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=config['scheduler_mode'],
            factor=config['scheduler_factor'],
            patience=config['scheduler_patience'],
            min_lr=config['scheduler_min_lr']
        )
        print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    else:
        scheduler = None

    return optimizer, scheduler

# PoolFormer Image Processor setup for 224px input
from transformers import PoolFormerImageProcessor

print("\nSetting up PoolFormer Image Processor for 224px input...")

poolformer_processor = PoolFormerImageProcessor.from_pretrained(
    CASME2_POOLFORMER_CONFIG['poolformer_model'],
    do_resize=True,
    size={'height': 224, 'width': 224},
    do_normalize=True,
    do_rescale=True,
    do_center_crop=False
)

# Transform functions for PoolFormer with proper error handling
def poolformer_transform_train(image):
    """Training transform with PoolFormer Image Processor"""
    try:
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected PIL Image, got {type(image)}")

        inputs = poolformer_processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values']

        # Remove batch dimension: [1, C, H, W] -> [C, H, W]
        if pixel_values.dim() == 4 and pixel_values.size(0) == 1:
            pixel_values = pixel_values.squeeze(0)

        return pixel_values

    except Exception as e:
        print(f"Error in poolformer_transform_train: {e}")
        print(f"Image type: {type(image)}")
        if isinstance(image, Image.Image):
            print(f"Image size: {image.size}")
        raise

def poolformer_transform_val(image):
    """Validation transform with PoolFormer Image Processor"""
    try:
        if not isinstance(image, Image.Image):
            raise TypeError(f"Expected PIL Image, got {type(image)}")

        inputs = poolformer_processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values']

        # Remove batch dimension: [1, C, H, W] -> [C, H, W]
        if pixel_values.dim() == 4 and pixel_values.size(0) == 1:
            pixel_values = pixel_values.squeeze(0)

        return pixel_values

    except Exception as e:
        print(f"Error in poolformer_transform_val: {e}")
        print(f"Image type: {type(image)}")
        if isinstance(image, Image.Image):
            print(f"Image size: {image.size}")
        raise

print("PoolFormer Image Processor configured for 224px with token mixing")
print("Transform functions with enhanced error handling and validation")

# Custom Dataset class for CASME II
class CASME2Dataset(Dataset):
    """Custom dataset class for CASME II with flexible file loading"""

    def __init__(self, dataset_root, split, transform=None):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []
        self.filenames = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading CASME II {split} dataset...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        if len(all_files) > 0:
            print(f"Sample filename: {all_files[0]}")

        loaded_count = 0
        skipped_count = 0

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                loaded_count += 1
            else:
                skipped_count += 1
                if skipped_count <= 3:
                    print(f"  Skipped (no emotion found): {filename}")

        print(f"Loaded {len(self.images)} samples for {split} split")
        if skipped_count > 0:
            print(f"  Skipped {skipped_count} files (no recognizable emotion)")

        if len(self.images) == 0:
            print(f"ERROR: No samples loaded! Check filename format and emotion labels.")
            print(f"Expected emotions in filenames: {CASME2_CLASSES}")

        self._print_distribution()

    def _print_distribution(self):
        """Print class distribution"""
        if len(self.labels) == 0:
            print("  No samples to display distribution")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

# Create directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

print(f"\nDataset paths:")
print(f"Train: {DATASET_ROOT}/train")
print(f"Validation: {DATASET_ROOT}/val")
print(f"Test: {DATASET_ROOT}/test")

# Architecture validation with enhanced error handling
print("\nPoolFormer CASME II architecture validation...")

try:
    test_model = PoolFormerCASME2Baseline(num_classes=7, dropout_rate=0.3).to(device)
    test_input = torch.randn(1, 3, 224, 224).to(device)

    print(f"Test input shape: {test_input.shape}")
    test_output = test_model(test_input)
    print(f"Test output shape: {test_output.shape}")

    print(f"\nValidation successful: Output shape {test_output.shape}")
    print(f"PoolFormer {POOLFORMER_MODEL_VARIANT.upper()} with {MODEL_PARAMS} parameters")
    print(f"Token mixing with pooling operations")
    print(f"PoolFormer {POOLFORMER_MODEL_VARIANT.upper()} architecture validated successfully")

    del test_model, test_input, test_output
    torch.cuda.empty_cache()

except Exception as e:
    print(f"Validation failed: {e}")
    import traceback
    traceback.print_exc()

# Loss function factory
def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    """
    Factory function to create loss criterion based on configuration

    Args:
        weights: Class weights for CrossEntropy
        use_focal_loss: Whether to use Focal Loss or CrossEntropy
        alpha_weights: Per-class alpha weights for Focal Loss
        gamma: Focal loss gamma parameter

    Returns:
        Loss function
    """
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        if alpha_weights:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {sum(alpha_weights):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)
    else:
        print(f"Using CrossEntropy Loss with optimized class weights")
        print(f"Class weights: {weights.cpu().numpy()}")
        return nn.CrossEntropyLoss(weight=weights)

# Global configuration for training pipeline
GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    'transform_train': poolformer_transform_train,
    'transform_val': poolformer_transform_val,
    'poolformer_config': CASME2_POOLFORMER_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'dataset_root': DATASET_ROOT,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

# Configuration validation and summary
print("\n" + "=" * 60)
print("CASME II MULTI-FRAME SEQUENCE POOLFORMER CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: PoolFormer")
print(f"  Variant: {POOLFORMER_MODEL_VARIANT.upper()}")
print(f"  Model: {POOLFORMER_MODEL_NAME}")
print(f"  Parameters: {MODEL_PARAMS}")
print(f"  Input Resolution: 224px (native from v9 preprocessing)")
print(f"  Feature Dimension: {EXPECTED_FEATURE_DIM}")
print(f"  Token Mixing: Pooling operations (attention-free)")
print(f"  Classification Head: {EXPECTED_FEATURE_DIM} -> GAP2D -> 512 -> 128 -> 7")

print(f"\nDataset Configuration:")
print(f"  Version: {CASME2_POOLFORMER_CONFIG['dataset_version']}")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Frame strategy: {CASME2_POOLFORMER_CONFIG['frame_strategy']}")
print(f"  Training approach: {CASME2_POOLFORMER_CONFIG['training_approach']}")
print(f"  Inference strategy: {CASME2_POOLFORMER_CONFIG['inference_strategy']}")
print(f"  Weight Optimization: {'Per-class Alpha' if USE_FOCAL_LOSS else 'Inverse Sqrt Frequency'}")

print(f"\nTraining Configuration:")
print(f"  Train samples: {preprocessing_info['splits']['train']['total_images']} frames")
print(f"  Validation samples: {preprocessing_info['splits']['val']['total_images']} frames")
print(f"  Test samples: {preprocessing_info['splits']['test']['total_images']} frames")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {CASME2_POOLFORMER_CONFIG['learning_rate']}")
print(f"  Dropout rate: {CASME2_POOLFORMER_CONFIG['dropout_rate']}")

print("\nNext: Cell 2 - Dataset Loading and Multi-Frame PoolFormer Training Pipeline")

CASME II MULTI-FRAME SEQUENCE POOLFORMER WITH FACE-AWARE PREPROCESSING

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II Multi-Frame Sequence PoolFormer - Face-Aware Preprocessing Infrastructure
Loading CASME II v9 preprocessing metadata...
Dataset variant: MFS
Processing date: 2025-10-19T08:20:12.098301
Preprocessing method: face_bbox_expansion_all_directions
Total images processed: 2774
Face detection rate: 100.00%
Target size: 224x224px
BBox expansion: 20px (all directions)

Dataset split information:
  Train samples: 2613
  Validation samples: 78
  Test samples: 83
Using PoolFormer-M48 for enhanced micro-expression recognition (73M parameters)

EXPERIMENT CONFIGURATION - MULTI-FRAME SEQUENCE FACE-AWARE
Dataset: v9 Multi-Frame Sequence with Face-Aware Preprocessing
Frame strategy: Multiple frames per video (dense sampling)
Training approach: Frame-level independent learning
Loss Function: Focal Los

preprocessor_config.json:   0%|          | 0.00/283 [00:00<?, ?B/s]

PoolFormer Image Processor configured for 224px with token mixing
Transform functions with enhanced error handling and validation

Dataset paths:
Train: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/train
Validation: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/val
Test: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/test

PoolFormer CASME II architecture validation...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/294M [00:00<?, ?B/s]

PoolFormer feature dimension: 768
Classification head: 768 -> GAP2D -> 512 -> 128 -> 7
Dropout rate: 0.3 (balanced for large dataset)
Test input shape: torch.Size([1, 3, 224, 224])
Test output shape: torch.Size([1, 7])

Validation successful: Output shape torch.Size([1, 7])
PoolFormer M48 with 73M parameters
Token mixing with pooling operations
PoolFormer M48 architecture validated successfully

CASME II MULTI-FRAME SEQUENCE POOLFORMER CONFIGURATION COMPLETE
Loss Configuration:
  Function: Optimized Focal Loss
  Gamma: 2.0
  Per-class Alpha: [0.048, 0.06, 0.085, 0.093, 0.095, 0.191, 0.427]
  Alpha Sum: 0.999

Model Configuration:
  Architecture: PoolFormer
  Variant: M48
  Model: sail/poolformer_m48
  Parameters: 73M
  Input Resolution: 224px (native from v9 preprocessing)
  Feature Dimension: 768
  Token Mixing: Pooling operations (attention-free)
  Classification Head: 768 -> GAP2D -> 512 -> 128 -> 7

Dataset Configuration:
  Version: v9
  Classes: 7
  Frame strategy: multi_frame_seq

In [2]:
# @title Cell 2: CASME II Multi-Frame Sequence PoolFormer Training Pipeline

# File: 07_03_PoolFormer_CASME2_MFS_Cell2_FIXED.py
# Location: experiments/07_03_PoolFormer_CASME2-MFS-PREP.ipynb
# Purpose: Enhanced training pipeline for CASME II Multi-Frame Sequence PoolFormer with optimized RAM caching
# Fix: Added custom collate function for PoolFormer processor compatibility

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import shutil
import tempfile

print("CASME II Multi-Frame Sequence PoolFormer Training Pipeline")
print("=" * 70)
print(f"Loss Function: {'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy Loss'}")
if CASME2_POOLFORMER_CONFIG['use_focal_loss']:
    print(f"Focal Loss Parameters:")
    print(f"  Gamma: {CASME2_POOLFORMER_CONFIG['focal_loss_gamma']}")
    print(f"  Per-class Alpha: {CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']}")
    print(f"  Alpha Sum: {sum(CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']):.3f}")
else:
    print(f"CrossEntropy Parameters:")
    print(f"  Optimized Class Weights: {CASME2_POOLFORMER_CONFIG['crossentropy_class_weights']}")
print(f"Dataset version: {CASME2_POOLFORMER_CONFIG['dataset_version']}")
print(f"Frame strategy: {CASME2_POOLFORMER_CONFIG['frame_strategy']}")
print(f"Training approach: {CASME2_POOLFORMER_CONFIG['training_approach']}")
print(f"PoolFormer variant: {CASME2_POOLFORMER_CONFIG['model_variant'].upper()}")
print(f"Model parameters: {CASME2_POOLFORMER_CONFIG['model_params']}")
print(f"Training epochs: {CASME2_POOLFORMER_CONFIG['num_epochs']}")
print(f"Batch size: {CASME2_POOLFORMER_CONFIG['batch_size']}")
print(f"Scheduler patience: {CASME2_POOLFORMER_CONFIG['scheduler_patience']}")

# Enhanced CASME II Dataset with optimized RAM caching for large dataset
class CASME2DatasetTraining(Dataset):
    """Enhanced CASME II dataset for training with RAM caching optimization for large dataset"""

    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.filenames = []
        self.cached_images = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading CASME II {split} dataset for training...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        if len(all_files) > 0:
            print(f"Sample filename: {all_files[0]}")

        loaded_count = 0
        skipped_count = 0

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                loaded_count += 1
            else:
                skipped_count += 1
                if skipped_count <= 3:
                    print(f"  Skipped (no emotion found): {filename}")

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        if skipped_count > 0:
            print(f"  Skipped {skipped_count} files (no recognizable emotion)")

        if len(self.images) == 0:
            print(f"ERROR: No samples loaded! Check filename format and emotion labels.")
            print(f"Expected emotions in filenames: {CASME2_CLASSES}")

        self._print_distribution()

        if self.use_ram_cache and len(self.images) > 0:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print class distribution"""
        if len(self.labels) == 0:
            print("  No samples to display distribution")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def _preload_to_ram(self):
        """RAM preloading with parallel loading for training efficiency"""
        if len(self.images) == 0:
            print(f"Skipping RAM preload: No images to load")
            return

        print(f"Preloading {len(self.images)} {self.split} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            """Load single image with error handling"""
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('RGB', (224, 224), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc=f"Loading {self.split} to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * 224 * 224 * 3 * 4 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (224, 224), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

# Custom collate function for PoolFormer processor compatibility
def poolformer_collate_fn(batch):
    """
    Custom collate function to handle PoolFormer processor output

    PoolFormer processor may return tensors with unexpected shapes or additional metadata
    This function ensures proper batching of (image, label, filename) tuples

    Args:
        batch: List of tuples from dataset __getitem__

    Returns:
        tuple: (batched_images, batched_labels, filenames_list)
    """
    images = []
    labels = []
    filenames = []

    for item in batch:
        # Ensure we extract exactly 3 elements from each batch item
        if len(item) == 3:
            img, lbl, fname = item
            images.append(img)
            labels.append(lbl)
            filenames.append(fname)
        else:
            # Log unexpected batch structure but continue processing
            print(f"Warning: Unexpected batch item with {len(item)} elements, skipping...")
            continue

    # Validate we have data to batch
    if len(images) == 0:
        raise ValueError("No valid samples in batch after filtering")

    # Stack images into batch tensor
    # PoolFormer expects [batch_size, channels, height, width]
    images = torch.stack(images, dim=0)

    # Convert labels to tensor
    labels = torch.tensor(labels, dtype=torch.long)

    return images, labels, filenames

# Enhanced metrics calculation with comprehensive error handling
def calculate_metrics_safe_robust(predictions, labels, class_names, average='macro'):
    """
    Calculate metrics with enhanced error handling and validation

    Args:
        predictions: Predicted labels
        labels: True labels
        class_names: List of class names
        average: Averaging method for metrics

    Returns:
        dict: Computed metrics
    """
    try:
        predictions = np.array(predictions)
        labels = np.array(labels)

        if len(predictions) == 0 or len(labels) == 0:
            return {
                'accuracy': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'f1_score': 0.0
            }

        if len(predictions) != len(labels):
            print(f"Warning: Prediction and label length mismatch: {len(predictions)} vs {len(labels)}")
            min_len = min(len(predictions), len(labels))
            predictions = predictions[:min_len]
            labels = labels[:min_len]

        accuracy = accuracy_score(labels, predictions)

        unique_labels = np.unique(np.concatenate([labels, predictions]))
        labels_present = [i for i in range(len(class_names)) if i in unique_labels]

        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            labels=labels_present,
            average=average,
            zero_division=0
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }

    except Exception as e:
        print(f"Error in metrics calculation: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

# Enhanced training epoch with comprehensive validation
def train_epoch(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    """Enhanced training epoch with comprehensive validation and progress tracking"""
    model.train()
    running_loss = 0.0
    all_predictions = []
    all_labels = []
    all_filenames = []
    batch_count = 0

    pbar = tqdm(dataloader, desc=f"Train Epoch {epoch+1}/{total_epochs}")

    for batch_data in pbar:
        try:
            # Unpack batch data with validation
            if len(batch_data) != 3:
                print(f"Warning: Expected 3 values from dataloader, got {len(batch_data)}")
                print(f"Batch data types: {[type(x) for x in batch_data]}")
                continue

            images, labels, filenames = batch_data

            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(images)

            if outputs is None or torch.isnan(outputs).any() or torch.isinf(outputs).any():
                print(f"Warning: Invalid model outputs detected at batch {batch_count}")
                continue

            loss = criterion(outputs, labels)

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"Warning: Invalid loss detected at batch {batch_count}")
                continue

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            running_loss += loss.item()
            batch_count += 1

            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_filenames.extend(filenames)

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        except Exception as e:
            print(f"Error in training batch {batch_count}: {e}")
            continue

    avg_loss = running_loss / max(batch_count, 1)

    metrics = calculate_metrics_safe_robust(
        all_predictions, all_labels, CASME2_CLASSES, average='macro'
    )

    return avg_loss, metrics, all_filenames

# Enhanced validation epoch
def validate_epoch(model, dataloader, criterion, device, epoch, total_epochs):
    """Enhanced validation epoch with comprehensive metrics"""
    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_labels = []
    all_filenames = []
    batch_count = 0

    pbar = tqdm(dataloader, desc=f"Val Epoch {epoch+1}/{total_epochs}")

    with torch.no_grad():
        for images, labels, filenames in pbar:
            try:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                outputs = model(images)

                if outputs is None or torch.isnan(outputs).any() or torch.isinf(outputs).any():
                    print(f"Warning: Invalid validation outputs at batch {batch_count}")
                    continue

                loss = criterion(outputs, labels)

                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"Warning: Invalid validation loss at batch {batch_count}")
                    continue

                running_loss += loss.item()
                batch_count += 1

                _, predicted = torch.max(outputs, 1)

                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_filenames.extend(filenames)

                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            except Exception as e:
                print(f"Error in validation batch {batch_count}: {e}")
                continue

    avg_loss = running_loss / max(batch_count, 1)

    metrics = calculate_metrics_safe_robust(
        all_predictions, all_labels, CASME2_CLASSES, average='macro'
    )

    return avg_loss, metrics, all_filenames

# Enhanced atomic checkpoint saving
def save_checkpoint_robust(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                          checkpoint_root, best_metrics, config):
    """Enhanced atomic checkpoint saving with validation"""
    try:
        checkpoint_data = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'train_metrics': train_metrics,
            'val_metrics': val_metrics,
            'best_metrics': best_metrics,
            'config': config
        }

        checkpoint_path = os.path.join(checkpoint_root, 'casme2_poolformer_mfs_best_f1.pth')

        with tempfile.NamedTemporaryFile(mode='wb', delete=False, dir=checkpoint_root, suffix='.tmp') as tmp_file:
            torch.save(checkpoint_data, tmp_file.name)
            tmp_path = tmp_file.name

        if os.path.exists(checkpoint_path):
            backup_path = checkpoint_path + '.backup'
            shutil.copy2(checkpoint_path, backup_path)

        shutil.move(tmp_path, checkpoint_path)

        if os.path.exists(checkpoint_path + '.backup'):
            os.remove(checkpoint_path + '.backup')

        checkpoint_size = os.path.getsize(checkpoint_path) / (1024 * 1024)

        if checkpoint_size < 10:
            print(f"Warning: Checkpoint size unusually small ({checkpoint_size:.1f}MB)")
            return None

        return checkpoint_path

    except Exception as e:
        print(f"ERROR: Checkpoint save failed: {e}")
        import traceback
        traceback.print_exc()

        if 'tmp_path' in locals() and os.path.exists(tmp_path):
            os.remove(tmp_path)

        return None

# JSON serialization helper
def safe_json_serialize(obj):
    """Convert numpy types to Python native types for JSON serialization"""
    if isinstance(obj, dict):
        return {key: safe_json_serialize(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [safe_json_serialize(item) for item in obj]
    elif isinstance(obj, (np.integer, np.int64, np.int32)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float64, np.float32)):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, torch.Tensor):
        return obj.cpu().numpy().tolist()
    else:
        return obj

# Dataset loading with RAM caching
print("\n" + "=" * 70)
print("LOADING CASME II DATASETS WITH RAM CACHING")
print("=" * 70)

train_dataset = CASME2DatasetTraining(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='train',
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    use_ram_cache=True
)

val_dataset = CASME2DatasetTraining(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='val',
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    use_ram_cache=True
)

# DataLoader with custom collate function for PoolFormer compatibility
train_loader = DataLoader(
    train_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=True,
    num_workers=GLOBAL_CONFIG_CASME2['num_workers'],
    pin_memory=True,
    persistent_workers=True if GLOBAL_CONFIG_CASME2['num_workers'] > 0 else False,
    collate_fn=poolformer_collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=False,
    num_workers=GLOBAL_CONFIG_CASME2['num_workers'],
    pin_memory=True,
    persistent_workers=True if GLOBAL_CONFIG_CASME2['num_workers'] > 0 else False,
    collate_fn=poolformer_collate_fn
)

print(f"\nDataset loading completed:")
print(f"Train samples: {len(train_dataset)}, batches: {len(train_loader)}")
print(f"Validation samples: {len(val_dataset)}, batches: {len(val_loader)}")
print(f"Custom collate function applied for PoolFormer processor compatibility")

# Model initialization
print("\n" + "=" * 70)
print("INITIALIZING POOLFORMER MODEL")
print("=" * 70)

model = PoolFormerCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_POOLFORMER_CONFIG['dropout_rate']
).to(GLOBAL_CONFIG_CASME2['device'])

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model: PoolFormerCASME2Baseline")
print(f"Variant: {POOLFORMER_MODEL_VARIANT.upper()}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Feature dimension: {EXPECTED_FEATURE_DIM}")

# Optimizer and criterion setup
optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_POOLFORMER_CONFIG
)

criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
    weights=GLOBAL_CONFIG_CASME2['class_weights'],
    use_focal_loss=CASME2_POOLFORMER_CONFIG['use_focal_loss'],
    alpha_weights=CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights'],
    gamma=CASME2_POOLFORMER_CONFIG['focal_loss_gamma']
)

print(f"Optimizer: {optimizer.__class__.__name__}")
print(f"Learning rate: {CASME2_POOLFORMER_CONFIG['learning_rate']}")
print(f"Weight decay: {CASME2_POOLFORMER_CONFIG['weight_decay']}")
print(f"Scheduler: ReduceLROnPlateau (patience={CASME2_POOLFORMER_CONFIG['scheduler_patience']})")

# Training initialization
print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)

training_history = {
    'train_loss': [],
    'val_loss': [],
    'train_f1': [],
    'val_f1': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rate': [],
    'epoch_time': []
}

best_metrics = {
    'f1': 0.0,
    'loss': float('inf'),
    'accuracy': 0.0,
    'epoch': 0
}

start_time = time.time()

for epoch in range(CASME2_POOLFORMER_CONFIG['num_epochs']):
    epoch_start_time = time.time()

    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1}/{CASME2_POOLFORMER_CONFIG['num_epochs']}")
    print(f"{'='*70}")

    # Training phase
    train_loss, train_metrics, train_filenames = train_epoch(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_POOLFORMER_CONFIG['num_epochs']
    )

    # Validation phase
    val_loss, val_metrics, val_filenames = validate_epoch(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_POOLFORMER_CONFIG['num_epochs']
    )

    # Update scheduler
    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    # Record training history
    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    # Print epoch summary
    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    # Enhanced multi-criteria checkpoint saving logic
    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        best_metrics['f1'] = val_metrics['f1_score']
        best_metrics['loss'] = val_loss
        best_metrics['accuracy'] = val_metrics['accuracy']
        best_metrics['epoch'] = epoch + 1

        best_model_path = save_checkpoint_robust(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            best_metrics, CASME2_POOLFORMER_CONFIG
        )

        if best_model_path:
            print(f"New best model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print(f"Warning: Failed to save checkpoint despite improvement")

    # Progress tracking
    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_POOLFORMER_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_POOLFORMER_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min")

# Training completion
total_time = time.time() - start_time
actual_epochs = CASME2_POOLFORMER_CONFIG['num_epochs']

print("\n" + "=" * 70)
print("CASME II MULTI-FRAME SEQUENCE POOLFORMER TRAINING COMPLETED")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print(f"Final train F1: {training_history['train_f1'][-1]:.4f}")
print(f"Final validation F1: {training_history['val_f1'][-1]:.4f}")

# Enhanced training documentation export
results_dir = GLOBAL_CONFIG_CASME2['results_root']
os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

training_history_path = f"{results_dir}/training_logs/casme2_poolformer_mfs_training_history.json"

print("\nExporting enhanced training documentation...")

try:
    training_summary = {
        'experiment_type': 'CASME2_PoolFormer_MultiFrameSequence',
        'experiment_configuration': {
            'dataset_version': CASME2_POOLFORMER_CONFIG['dataset_version'],
            'frame_strategy': CASME2_POOLFORMER_CONFIG['frame_strategy'],
            'training_approach': CASME2_POOLFORMER_CONFIG['training_approach'],
            'inference_strategy': CASME2_POOLFORMER_CONFIG['inference_strategy'],
            'loss_function': 'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy',
            'weight_approach': 'Per-class Alpha (sum=1.0)' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'Inverse Sqrt Frequency',
            'focal_loss_gamma': CASME2_POOLFORMER_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights'],
            'crossentropy_class_weights': CASME2_POOLFORMER_CONFIG['crossentropy_class_weights'],
            'poolformer_model': CASME2_POOLFORMER_CONFIG['poolformer_model'],
            'model_variant': CASME2_POOLFORMER_CONFIG['model_variant'],
            'model_params': CASME2_POOLFORMER_CONFIG['model_params']
        },
        'training_history': safe_json_serialize(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'config': safe_json_serialize(CASME2_POOLFORMER_CONFIG),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        'model_checkpoint': 'casme2_poolformer_mfs_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            'version': CASME2_POOLFORMER_CONFIG['dataset_version'],
            'frame_strategy': CASME2_POOLFORMER_CONFIG['frame_strategy'],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES
        },
        'architecture_info': {
            'model_type': 'PoolFormerCASME2Baseline',
            'backbone': CASME2_POOLFORMER_CONFIG['poolformer_model'],
            'variant': CASME2_POOLFORMER_CONFIG['model_variant'],
            'input_size': f"{CASME2_POOLFORMER_CONFIG['input_size']}x{CASME2_POOLFORMER_CONFIG['input_size']}",
            'expected_feature_dim': CASME2_POOLFORMER_CONFIG['expected_feature_dim'],
            'classification_head': f"{CASME2_POOLFORMER_CONFIG['expected_feature_dim']}->GAP->512->128->7",
            'token_mixing': 'pooling_operations'
        },
        'enhanced_features': {
            'hardened_checkpoint_system': True,
            'atomic_checkpoint_save': True,
            'checkpoint_validation': True,
            'model_output_validation': True,
            'enhanced_error_handling': True,
            'multi_criteria_checkpoint_logic': True,
            'memory_optimized_training': True,
            'ram_caching': True,
            'attention_free_token_mixing': True,
            'custom_collate_function': True,
            'poolformer_processor_compatibility': True
        }
    }

    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Enhanced training documentation saved successfully: {training_history_path}")
    print(f"Experiment details: {training_summary['experiment_configuration']['loss_function']} loss")
    if CASME2_POOLFORMER_CONFIG['use_focal_loss']:
        print(f"  Gamma: {CASME2_POOLFORMER_CONFIG['focal_loss_gamma']}, Alpha Sum: {sum(CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']):.3f}")
    print(f"Model variant: {CASME2_POOLFORMER_CONFIG['model_variant'].upper()}")
    print(f"Model parameters: {CASME2_POOLFORMER_CONFIG['model_params']}")
    print(f"Dataset version: {CASME2_POOLFORMER_CONFIG['dataset_version']}")

except Exception as e:
    print(f"Warning: Could not save training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

# Enhanced memory cleanup
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\nNext: Cell 3 - CASME II Multi-Frame Sequence PoolFormer Evaluation")
print("Enhanced training pipeline completed successfully!")

CASME II Multi-Frame Sequence PoolFormer Training Pipeline
Loss Function: Optimized Focal Loss
Focal Loss Parameters:
  Gamma: 2.0
  Per-class Alpha: [0.048, 0.06, 0.085, 0.093, 0.095, 0.191, 0.427]
  Alpha Sum: 0.999
Dataset version: v9
Frame strategy: multi_frame_sequence
Training approach: frame_level_independent
PoolFormer variant: M48
Model parameters: 73M
Training epochs: 50
Batch size: 16
Scheduler patience: 3

LOADING CASME II DATASETS WITH RAM CACHING
Loading CASME II train dataset for training...
Found 2613 image files in directory
Sample filename: sub13_EP01_01_apex_p+1_others.jpg
Loaded 2613 CASME II train samples
  others: 1027 samples (39.3%)
  disgust: 650 samples (24.9%)
  happiness: 325 samples (12.4%)
  repression: 273 samples (10.4%)
  surprise: 260 samples (10.0%)
  sadness: 65 samples (2.5%)
  fear: 13 samples (0.5%)
Preloading 2613 train images to RAM with 32 workers...


Loading train to RAM: 100%|██████████| 2613/2613 [00:47<00:00, 54.49it/s]


TRAIN RAM caching completed: 2613/2613 images, ~1.57GB
Loading CASME II val dataset for training...
Found 78 image files in directory
Sample filename: sub01_EP03_02_onset_others.jpg
Loaded 78 CASME II val samples
  others: 30 samples (38.5%)
  disgust: 18 samples (23.1%)
  happiness: 9 samples (11.5%)
  repression: 9 samples (11.5%)
  surprise: 6 samples (7.7%)
  sadness: 3 samples (3.8%)
  fear: 3 samples (3.8%)
Preloading 78 val images to RAM with 32 workers...


Loading val to RAM: 100%|██████████| 78/78 [00:01<00:00, 49.52it/s]


VAL RAM caching completed: 78/78 images, ~0.05GB

Dataset loading completed:
Train samples: 2613, batches: 164
Validation samples: 78, batches: 5
Custom collate function applied for PoolFormer processor compatibility

INITIALIZING POOLFORMER MODEL
PoolFormer feature dimension: 768
Classification head: 768 -> GAP2D -> 512 -> 128 -> 7
Dropout rate: 0.3 (balanced for large dataset)
Model: PoolFormerCASME2Baseline
Variant: M48
Total parameters: 73,163,207
Trainable parameters: 73,163,207
Feature dimension: 768
Scheduler: ReduceLROnPlateau monitoring validation_f1
Using Optimized Focal Loss with gamma=2.0
Per-class alpha weights: [0.048, 0.06, 0.085, 0.093, 0.095, 0.191, 0.427]
Alpha sum: 0.999
Optimizer: AdamW
Learning rate: 1e-05
Weight decay: 1e-05
Scheduler: ReduceLROnPlateau (patience=3)

STARTING TRAINING

Epoch 1/50


Train Epoch 1/50: 100%|██████████| 164/164 [00:43<00:00,  3.81it/s, loss=0.1260]
Val Epoch 1/50: 100%|██████████| 5/5 [00:00<00:00,  6.53it/s, loss=0.1134]


Train - Loss: 0.2384, F1: 0.2362, Acc: 0.3360
Val   - Loss: 0.1445, F1: 0.1843, Acc: 0.2692
Time  - Epoch: 43.9s, LR: 1.00e-05
New best model: Higher F1 - F1: 0.1843
Progress: 2.0% | Best F1: 0.1843 | ETA: 37.7min

Epoch 2/50


Train Epoch 2/50: 100%|██████████| 164/164 [00:41<00:00,  3.97it/s, loss=0.0638]
Val Epoch 2/50: 100%|██████████| 5/5 [00:00<00:00, 11.04it/s, loss=0.1235]


Train - Loss: 0.0628, F1: 0.4612, Acc: 0.5561
Val   - Loss: 0.1425, F1: 0.1549, Acc: 0.2692
Time  - Epoch: 41.7s, LR: 1.00e-05
Progress: 4.0% | Best F1: 0.1843 | ETA: 35.1min

Epoch 3/50


Train Epoch 3/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0005]
Val Epoch 3/50: 100%|██████████| 5/5 [00:00<00:00, 10.67it/s, loss=0.1335]


Train - Loss: 0.0324, F1: 0.6635, Acc: 0.7290
Val   - Loss: 0.1744, F1: 0.1052, Acc: 0.2308
Time  - Epoch: 42.1s, LR: 1.00e-05
Progress: 6.0% | Best F1: 0.1843 | ETA: 33.9min

Epoch 4/50


Train Epoch 4/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0040]
Val Epoch 4/50: 100%|██████████| 5/5 [00:00<00:00, 10.58it/s, loss=0.1668]


Train - Loss: 0.0187, F1: 0.7581, Acc: 0.8025
Val   - Loss: 0.2052, F1: 0.1502, Acc: 0.3077
Time  - Epoch: 41.9s, LR: 1.00e-05
Progress: 8.0% | Best F1: 0.1843 | ETA: 32.9min

Epoch 5/50


Train Epoch 5/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0001]
Val Epoch 5/50: 100%|██████████| 5/5 [00:00<00:00, 11.01it/s, loss=0.1480]


Train - Loss: 0.0131, F1: 0.8395, Acc: 0.8595
Val   - Loss: 0.2050, F1: 0.1493, Acc: 0.2692
Time  - Epoch: 41.9s, LR: 5.00e-06
Progress: 10.0% | Best F1: 0.1843 | ETA: 32.1min

Epoch 6/50


Train Epoch 6/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0010]
Val Epoch 6/50: 100%|██████████| 5/5 [00:00<00:00, 10.85it/s, loss=0.1524]


Train - Loss: 0.0077, F1: 0.9046, Acc: 0.9215
Val   - Loss: 0.2185, F1: 0.1267, Acc: 0.2821
Time  - Epoch: 42.0s, LR: 5.00e-06
Progress: 12.0% | Best F1: 0.1843 | ETA: 31.3min

Epoch 7/50


Train Epoch 7/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 7/50: 100%|██████████| 5/5 [00:00<00:00, 10.99it/s, loss=0.1939]


Train - Loss: 0.0052, F1: 0.9294, Acc: 0.9430
Val   - Loss: 0.2546, F1: 0.1264, Acc: 0.2692
Time  - Epoch: 41.9s, LR: 5.00e-06
Progress: 14.0% | Best F1: 0.1843 | ETA: 30.5min

Epoch 8/50


Train Epoch 8/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0023]
Val Epoch 8/50: 100%|██████████| 5/5 [00:00<00:00, 10.93it/s, loss=0.1993]


Train - Loss: 0.0035, F1: 0.9411, Acc: 0.9545
Val   - Loss: 0.2669, F1: 0.1244, Acc: 0.3077
Time  - Epoch: 42.0s, LR: 5.00e-06
Progress: 16.0% | Best F1: 0.1843 | ETA: 29.7min

Epoch 9/50


Train Epoch 9/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 9/50: 100%|██████████| 5/5 [00:00<00:00, 10.67it/s, loss=0.2356]


Train - Loss: 0.0025, F1: 0.9441, Acc: 0.9694
Val   - Loss: 0.2776, F1: 0.1766, Acc: 0.3077
Time  - Epoch: 41.9s, LR: 2.50e-06
Progress: 18.0% | Best F1: 0.1843 | ETA: 29.0min

Epoch 10/50


Train Epoch 10/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0027]
Val Epoch 10/50: 100%|██████████| 5/5 [00:00<00:00, 10.58it/s, loss=0.2123]


Train - Loss: 0.0019, F1: 0.9533, Acc: 0.9759
Val   - Loss: 0.2745, F1: 0.1603, Acc: 0.3205
Time  - Epoch: 41.9s, LR: 2.50e-06
Progress: 20.0% | Best F1: 0.1843 | ETA: 28.2min

Epoch 11/50


Train Epoch 11/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0016]
Val Epoch 11/50: 100%|██████████| 5/5 [00:00<00:00, 10.87it/s, loss=0.2298]


Train - Loss: 0.0019, F1: 0.9673, Acc: 0.9759
Val   - Loss: 0.2825, F1: 0.1535, Acc: 0.3077
Time  - Epoch: 41.9s, LR: 2.50e-06
Progress: 22.0% | Best F1: 0.1843 | ETA: 27.5min

Epoch 12/50


Train Epoch 12/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0001]
Val Epoch 12/50: 100%|██████████| 5/5 [00:00<00:00, 10.97it/s, loss=0.2307]


Train - Loss: 0.0014, F1: 0.9692, Acc: 0.9801
Val   - Loss: 0.3074, F1: 0.1896, Acc: 0.3590
Time  - Epoch: 41.9s, LR: 2.50e-06
New best model: Higher F1 - F1: 0.1896
Progress: 24.0% | Best F1: 0.1896 | ETA: 27.0min

Epoch 13/50


Train Epoch 13/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0000]
Val Epoch 13/50: 100%|██████████| 5/5 [00:00<00:00, 10.65it/s, loss=0.2223]


Train - Loss: 0.0012, F1: 0.9877, Acc: 0.9843
Val   - Loss: 0.2918, F1: 0.1529, Acc: 0.3077
Time  - Epoch: 42.1s, LR: 2.50e-06
Progress: 26.0% | Best F1: 0.1896 | ETA: 26.3min

Epoch 14/50


Train Epoch 14/50: 100%|██████████| 164/164 [00:41<00:00,  3.93it/s, loss=0.0003]
Val Epoch 14/50: 100%|██████████| 5/5 [00:00<00:00, 10.97it/s, loss=0.2293]


Train - Loss: 0.0011, F1: 0.9868, Acc: 0.9866
Val   - Loss: 0.2998, F1: 0.1616, Acc: 0.3077
Time  - Epoch: 42.2s, LR: 2.50e-06
Progress: 28.0% | Best F1: 0.1896 | ETA: 25.6min

Epoch 15/50


Train Epoch 15/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 15/50: 100%|██████████| 5/5 [00:00<00:00, 10.76it/s, loss=0.2469]


Train - Loss: 0.0012, F1: 0.9744, Acc: 0.9843
Val   - Loss: 0.3095, F1: 0.1572, Acc: 0.3205
Time  - Epoch: 42.1s, LR: 2.50e-06
Progress: 30.0% | Best F1: 0.1896 | ETA: 24.8min

Epoch 16/50


Train Epoch 16/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0001]
Val Epoch 16/50: 100%|██████████| 5/5 [00:00<00:00, 11.21it/s, loss=0.2669]


Train - Loss: 0.0012, F1: 0.9852, Acc: 0.9839
Val   - Loss: 0.3221, F1: 0.1584, Acc: 0.3205
Time  - Epoch: 42.0s, LR: 1.25e-06
Progress: 32.0% | Best F1: 0.1896 | ETA: 24.1min

Epoch 17/50


Train Epoch 17/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0000]
Val Epoch 17/50: 100%|██████████| 5/5 [00:00<00:00, 11.11it/s, loss=0.2691]


Train - Loss: 0.0013, F1: 0.9805, Acc: 0.9851
Val   - Loss: 0.3217, F1: 0.1699, Acc: 0.3205
Time  - Epoch: 42.0s, LR: 1.25e-06
Progress: 34.0% | Best F1: 0.1896 | ETA: 23.4min

Epoch 18/50


Train Epoch 18/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 18/50: 100%|██████████| 5/5 [00:00<00:00, 11.14it/s, loss=0.2532]


Train - Loss: 0.0009, F1: 0.9883, Acc: 0.9878
Val   - Loss: 0.3239, F1: 0.1495, Acc: 0.3205
Time  - Epoch: 42.1s, LR: 1.25e-06
Progress: 36.0% | Best F1: 0.1896 | ETA: 22.7min

Epoch 19/50


Train Epoch 19/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0207]
Val Epoch 19/50: 100%|██████████| 5/5 [00:00<00:00, 11.12it/s, loss=0.2539]


Train - Loss: 0.0009, F1: 0.9928, Acc: 0.9916
Val   - Loss: 0.3291, F1: 0.1213, Acc: 0.2949
Time  - Epoch: 42.1s, LR: 1.25e-06
Progress: 38.0% | Best F1: 0.1896 | ETA: 21.9min

Epoch 20/50


Train Epoch 20/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0014]
Val Epoch 20/50: 100%|██████████| 5/5 [00:00<00:00, 11.25it/s, loss=0.2401]


Train - Loss: 0.0009, F1: 0.9912, Acc: 0.9897
Val   - Loss: 0.3127, F1: 0.1753, Acc: 0.3333
Time  - Epoch: 42.0s, LR: 1.00e-06
Progress: 40.0% | Best F1: 0.1896 | ETA: 21.2min

Epoch 21/50


Train Epoch 21/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 21/50: 100%|██████████| 5/5 [00:00<00:00, 11.06it/s, loss=0.2479]


Train - Loss: 0.0009, F1: 0.9826, Acc: 0.9897
Val   - Loss: 0.3162, F1: 0.1561, Acc: 0.3205
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 42.0% | Best F1: 0.1896 | ETA: 20.5min

Epoch 22/50


Train Epoch 22/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0001]
Val Epoch 22/50: 100%|██████████| 5/5 [00:00<00:00, 11.16it/s, loss=0.2345]


Train - Loss: 0.0009, F1: 0.9817, Acc: 0.9878
Val   - Loss: 0.3153, F1: 0.1825, Acc: 0.3333
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 44.0% | Best F1: 0.1896 | ETA: 19.8min

Epoch 23/50


Train Epoch 23/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 23/50: 100%|██████████| 5/5 [00:00<00:00, 11.09it/s, loss=0.2600]


Train - Loss: 0.0007, F1: 0.9903, Acc: 0.9897
Val   - Loss: 0.3191, F1: 0.1749, Acc: 0.3462
Time  - Epoch: 42.0s, LR: 1.00e-06
Progress: 46.0% | Best F1: 0.1896 | ETA: 19.1min

Epoch 24/50


Train Epoch 24/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0002]
Val Epoch 24/50: 100%|██████████| 5/5 [00:00<00:00, 11.19it/s, loss=0.2562]


Train - Loss: 0.0007, F1: 0.9897, Acc: 0.9916
Val   - Loss: 0.3163, F1: 0.1716, Acc: 0.3333
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 48.0% | Best F1: 0.1896 | ETA: 18.4min

Epoch 25/50


Train Epoch 25/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 25/50: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, loss=0.2634]


Train - Loss: 0.0006, F1: 0.9889, Acc: 0.9908
Val   - Loss: 0.3205, F1: 0.1652, Acc: 0.3205
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 50.0% | Best F1: 0.1896 | ETA: 17.6min

Epoch 26/50


Train Epoch 26/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 26/50: 100%|██████████| 5/5 [00:00<00:00, 11.00it/s, loss=0.2621]


Train - Loss: 0.0010, F1: 0.9793, Acc: 0.9908
Val   - Loss: 0.3180, F1: 0.1652, Acc: 0.3205
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 52.0% | Best F1: 0.1896 | ETA: 16.9min

Epoch 27/50


Train Epoch 27/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 27/50: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, loss=0.2621]


Train - Loss: 0.0006, F1: 0.9875, Acc: 0.9920
Val   - Loss: 0.3223, F1: 0.1442, Acc: 0.3077
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 54.0% | Best F1: 0.1896 | ETA: 16.2min

Epoch 28/50


Train Epoch 28/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 28/50: 100%|██████████| 5/5 [00:00<00:00, 10.62it/s, loss=0.2653]


Train - Loss: 0.0005, F1: 0.9920, Acc: 0.9923
Val   - Loss: 0.3209, F1: 0.1426, Acc: 0.2949
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 56.0% | Best F1: 0.1896 | ETA: 15.5min

Epoch 29/50


Train Epoch 29/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0002]
Val Epoch 29/50: 100%|██████████| 5/5 [00:00<00:00, 10.89it/s, loss=0.2685]


Train - Loss: 0.0006, F1: 0.9870, Acc: 0.9908
Val   - Loss: 0.3174, F1: 0.1548, Acc: 0.3077
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 58.0% | Best F1: 0.1896 | ETA: 14.8min

Epoch 30/50


Train Epoch 30/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 30/50: 100%|██████████| 5/5 [00:00<00:00, 10.80it/s, loss=0.2695]


Train - Loss: 0.0006, F1: 0.9889, Acc: 0.9931
Val   - Loss: 0.3301, F1: 0.1024, Acc: 0.2564
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 60.0% | Best F1: 0.1896 | ETA: 14.1min

Epoch 31/50


Train Epoch 31/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 31/50: 100%|██████████| 5/5 [00:00<00:00, 10.98it/s, loss=0.2712]


Train - Loss: 0.0006, F1: 0.9885, Acc: 0.9920
Val   - Loss: 0.3279, F1: 0.1612, Acc: 0.3077
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 62.0% | Best F1: 0.1896 | ETA: 13.4min

Epoch 32/50


Train Epoch 32/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0006]
Val Epoch 32/50: 100%|██████████| 5/5 [00:00<00:00, 10.18it/s, loss=0.2493]


Train - Loss: 0.0007, F1: 0.9918, Acc: 0.9927
Val   - Loss: 0.3217, F1: 0.1738, Acc: 0.3205
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 64.0% | Best F1: 0.1896 | ETA: 12.7min

Epoch 33/50


Train Epoch 33/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0004]
Val Epoch 33/50: 100%|██████████| 5/5 [00:00<00:00, 10.82it/s, loss=0.2494]


Train - Loss: 0.0006, F1: 0.9850, Acc: 0.9897
Val   - Loss: 0.3204, F1: 0.1565, Acc: 0.2949
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 66.0% | Best F1: 0.1896 | ETA: 12.0min

Epoch 34/50


Train Epoch 34/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 34/50: 100%|██████████| 5/5 [00:00<00:00, 10.83it/s, loss=0.2573]


Train - Loss: 0.0005, F1: 0.9881, Acc: 0.9916
Val   - Loss: 0.3285, F1: 0.1560, Acc: 0.3077
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 68.0% | Best F1: 0.1896 | ETA: 11.3min

Epoch 35/50


Train Epoch 35/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0005]
Val Epoch 35/50: 100%|██████████| 5/5 [00:00<00:00, 10.49it/s, loss=0.2799]


Train - Loss: 0.0003, F1: 0.9952, Acc: 0.9962
Val   - Loss: 0.3340, F1: 0.1581, Acc: 0.3205
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 70.0% | Best F1: 0.1896 | ETA: 10.6min

Epoch 36/50


Train Epoch 36/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0000]
Val Epoch 36/50: 100%|██████████| 5/5 [00:00<00:00, 10.74it/s, loss=0.2821]


Train - Loss: 0.0004, F1: 0.9940, Acc: 0.9939
Val   - Loss: 0.3358, F1: 0.1789, Acc: 0.3462
Time  - Epoch: 42.0s, LR: 1.00e-06
Progress: 72.0% | Best F1: 0.1896 | ETA: 9.9min

Epoch 37/50


Train Epoch 37/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 37/50: 100%|██████████| 5/5 [00:00<00:00, 10.84it/s, loss=0.2932]


Train - Loss: 0.0005, F1: 0.9913, Acc: 0.9943
Val   - Loss: 0.3347, F1: 0.1644, Acc: 0.3333
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 74.0% | Best F1: 0.1896 | ETA: 9.2min

Epoch 38/50


Train Epoch 38/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0002]
Val Epoch 38/50: 100%|██████████| 5/5 [00:00<00:00, 10.99it/s, loss=0.2803]


Train - Loss: 0.0003, F1: 0.9964, Acc: 0.9962
Val   - Loss: 0.3411, F1: 0.1625, Acc: 0.3333
Time  - Epoch: 42.0s, LR: 1.00e-06
Progress: 76.0% | Best F1: 0.1896 | ETA: 8.4min

Epoch 39/50


Train Epoch 39/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 39/50: 100%|██████████| 5/5 [00:00<00:00, 10.75it/s, loss=0.2743]


Train - Loss: 0.0006, F1: 0.9883, Acc: 0.9939
Val   - Loss: 0.3361, F1: 0.1658, Acc: 0.3333
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 78.0% | Best F1: 0.1896 | ETA: 7.7min

Epoch 40/50


Train Epoch 40/50: 100%|██████████| 164/164 [00:41<00:00,  3.92it/s, loss=0.0000]
Val Epoch 40/50: 100%|██████████| 5/5 [00:00<00:00, 10.82it/s, loss=0.2943]


Train - Loss: 0.0003, F1: 0.9964, Acc: 0.9962
Val   - Loss: 0.3483, F1: 0.1643, Acc: 0.3205
Time  - Epoch: 42.3s, LR: 1.00e-06
Progress: 80.0% | Best F1: 0.1896 | ETA: 7.0min

Epoch 41/50


Train Epoch 41/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0001]
Val Epoch 41/50: 100%|██████████| 5/5 [00:00<00:00, 10.58it/s, loss=0.2830]


Train - Loss: 0.0003, F1: 0.9972, Acc: 0.9969
Val   - Loss: 0.3396, F1: 0.1889, Acc: 0.3462
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 82.0% | Best F1: 0.1896 | ETA: 6.3min

Epoch 42/50


Train Epoch 42/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 42/50: 100%|██████████| 5/5 [00:00<00:00, 10.66it/s, loss=0.2893]


Train - Loss: 0.0004, F1: 0.9962, Acc: 0.9958
Val   - Loss: 0.3500, F1: 0.1656, Acc: 0.3333
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 84.0% | Best F1: 0.1896 | ETA: 5.6min

Epoch 43/50


Train Epoch 43/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 43/50: 100%|██████████| 5/5 [00:00<00:00, 10.67it/s, loss=0.2908]


Train - Loss: 0.0002, F1: 0.9955, Acc: 0.9954
Val   - Loss: 0.3473, F1: 0.1594, Acc: 0.3205
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 86.0% | Best F1: 0.1896 | ETA: 4.9min

Epoch 44/50


Train Epoch 44/50: 100%|██████████| 164/164 [00:41<00:00,  3.94it/s, loss=0.0000]
Val Epoch 44/50: 100%|██████████| 5/5 [00:00<00:00, 10.72it/s, loss=0.2938]


Train - Loss: 0.0002, F1: 0.9982, Acc: 0.9973
Val   - Loss: 0.3500, F1: 0.1728, Acc: 0.3462
Time  - Epoch: 42.1s, LR: 1.00e-06
Progress: 88.0% | Best F1: 0.1896 | ETA: 4.2min

Epoch 45/50


Train Epoch 45/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 45/50: 100%|██████████| 5/5 [00:00<00:00, 10.81it/s, loss=0.2702]


Train - Loss: 0.0002, F1: 0.9965, Acc: 0.9969
Val   - Loss: 0.3420, F1: 0.1658, Acc: 0.3333
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 90.0% | Best F1: 0.1896 | ETA: 3.5min

Epoch 46/50


Train Epoch 46/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0001]
Val Epoch 46/50: 100%|██████████| 5/5 [00:00<00:00, 10.57it/s, loss=0.2740]


Train - Loss: 0.0002, F1: 0.9964, Acc: 0.9969
Val   - Loss: 0.3549, F1: 0.1588, Acc: 0.3077
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 92.0% | Best F1: 0.1896 | ETA: 2.8min

Epoch 47/50


Train Epoch 47/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0000]
Val Epoch 47/50: 100%|██████████| 5/5 [00:00<00:00, 10.67it/s, loss=0.2948]


Train - Loss: 0.0003, F1: 0.9968, Acc: 0.9962
Val   - Loss: 0.3574, F1: 0.1445, Acc: 0.3077
Time  - Epoch: 42.0s, LR: 1.00e-06
Progress: 94.0% | Best F1: 0.1896 | ETA: 2.1min

Epoch 48/50


Train Epoch 48/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0000]
Val Epoch 48/50: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, loss=0.2966]


Train - Loss: 0.0002, F1: 0.9976, Acc: 0.9973
Val   - Loss: 0.3526, F1: 0.1652, Acc: 0.3333
Time  - Epoch: 42.0s, LR: 1.00e-06
Progress: 96.0% | Best F1: 0.1896 | ETA: 1.4min

Epoch 49/50


Train Epoch 49/50: 100%|██████████| 164/164 [00:41<00:00,  3.96it/s, loss=0.0000]
Val Epoch 49/50: 100%|██████████| 5/5 [00:00<00:00, 10.88it/s, loss=0.2780]


Train - Loss: 0.0003, F1: 0.9931, Acc: 0.9946
Val   - Loss: 0.3474, F1: 0.1658, Acc: 0.3333
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 98.0% | Best F1: 0.1896 | ETA: 0.7min

Epoch 50/50


Train Epoch 50/50: 100%|██████████| 164/164 [00:41<00:00,  3.95it/s, loss=0.0000]
Val Epoch 50/50: 100%|██████████| 5/5 [00:00<00:00, 10.89it/s, loss=0.2998]


Train - Loss: 0.0002, F1: 0.9970, Acc: 0.9969
Val   - Loss: 0.3670, F1: 0.1191, Acc: 0.2949
Time  - Epoch: 41.9s, LR: 1.00e-06
Progress: 100.0% | Best F1: 0.1896 | ETA: 0.0min

CASME II MULTI-FRAME SEQUENCE POOLFORMER TRAINING COMPLETED
Training time: 35.2 minutes
Epochs completed: 50
Best validation F1: 0.1896 (epoch 12)
Final train F1: 0.9970
Final validation F1: 0.1191

Exporting enhanced training documentation...
Enhanced training documentation saved successfully: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/07_03_poolformer_casme2_mfs_prep/training_logs/casme2_poolformer_mfs_training_history.json
Experiment details: Optimized Focal Loss loss
  Gamma: 2.0, Alpha Sum: 0.999
Model variant: M48
Model parameters: 73M
Dataset version: v9

Next: Cell 3 - CASME II Multi-Frame Sequence PoolFormer Evaluation
Enhanced training pipeline completed successfully!


In [3]:
# @title Cell 3: CASME II PoolFormer Evaluation with Dual Dataset Support

# File: 07_03_PoolFormer_CASME2_MFS_Cell3.py
# Location: experiments/07_03_PoolFormer_CASME2-MFS-PREP.ipynb
# Purpose: Comprehensive evaluation framework with support for AF (v7) and KFS (v8) test datasets

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import warnings
warnings.filterwarnings('ignore')

# =====================================================
# DUAL DATASET EVALUATION CONFIGURATION
# =====================================================
# Configure which test datasets to evaluate:
# 'v7' = Apex Frame preprocessing (28 samples, frame-level evaluation)
# 'v8' = Key Frame Sequence preprocessing (84 frames -> 28 videos with late fusion)

EVALUATE_DATASETS = ['v7', 'v8']  # Can be ['v7'], ['v8'], or ['v7', 'v8']

print("CASME II PoolFormer Evaluation Framework with Dual Dataset Support")
print("=" * 60)
print(f"Datasets to evaluate: {EVALUATE_DATASETS}")
print("=" * 60)

# =====================================================
# DATASET CONFIGURATION FUNCTION
# =====================================================

def get_test_dataset_config(version, project_root):
    """
    Get test dataset configuration based on version

    Args:
        version: 'v7' (AF) or 'v8' (KFS)
        project_root: Project root path

    Returns:
        dict: Configuration for selected test dataset
    """
    if version == 'v7':
        config = {
            'version': 'v7',
            'variant': 'AF',
            'dataset_path': f"{project_root}/datasets/processed_casme2/preprocessed_v7",
            'preprocessing_summary': 'preprocessing_summary.json',
            'description': 'Apex Frame with Face-Aware Preprocessing',
            'expected_samples': 28,
            'frame_strategy': 'apex_frame',
            'evaluation_mode': 'frame_level',
            'aggregation': None
        }
    elif version == 'v8':
        config = {
            'version': 'v8',
            'variant': 'KFS',
            'dataset_path': f"{project_root}/datasets/processed_casme2/preprocessed_v8",
            'preprocessing_summary': 'preprocessing_summary.json',
            'description': 'Key Frame Sequence with Face-Aware Preprocessing',
            'expected_frames': 84,
            'expected_videos': 28,
            'frame_strategy': 'key_frame_sequence',
            'frame_types': ['onset', 'apex', 'offset'],
            'evaluation_mode': 'video_level',
            'aggregation': 'late_fusion'
        }
    else:
        raise ValueError(f"Invalid version: {version}. Must be 'v7' or 'v8'")

    return config

# =====================================================
# VIDEO ID EXTRACTION FOR KFS LATE FUSION
# =====================================================

def extract_video_id_from_filename(filename):
    """
    Extract video ID from KFS filename by removing frame type suffix

    Expected format: sub01_EP02_01f_happiness_onset.jpg
    Video ID: sub01_EP02_01f_happiness

    Args:
        filename: Image filename with frame type

    Returns:
        str: Video ID without frame type
    """
    # Remove file extension
    name_without_ext = filename.rsplit('.', 1)[0]

    # Remove frame type suffix (onset, apex, offset)
    for frame_type in ['onset', 'apex', 'offset']:
        if name_without_ext.endswith(f'_{frame_type}'):
            video_id = name_without_ext.rsplit(f'_{frame_type}', 1)[0]
            return video_id

    # If no frame type found, return as is
    return name_without_ext

# =====================================================
# ENHANCED TEST DATASET CLASS
# =====================================================

class CASME2DatasetEvaluation(Dataset):
    """Enhanced CASME II test dataset with comprehensive evaluation support"""

    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.filenames = []
        self.emotions = []
        self.video_ids = []
        self.cached_images = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading CASME II {split} dataset for evaluation...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        loaded_count = 0

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                video_id = extract_video_id_from_filename(filename)

                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                self.emotions.append(emotion_found)
                self.video_ids.append(video_id)
                loaded_count += 1

        print(f"Loaded {len(self.images)} CASME II {split} samples for evaluation")
        self._print_evaluation_distribution()

        if self.use_ram_cache and len(self.images) > 0:
            self._preload_to_ram_evaluation()

    def _print_evaluation_distribution(self):
        """Print comprehensive distribution for evaluation analysis"""
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        unique_videos = set(self.video_ids)

        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Unique video IDs: {len(unique_videos)}")

        missing_classes = []
        for i, class_name in enumerate(CASME2_CLASSES):
            if i not in label_counts:
                missing_classes.append(class_name)

        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        """RAM preloading optimized for evaluation"""
        if len(self.images) == 0:
            return

        print(f"Preloading {len(self.images)} test images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('RGB', (224, 224), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc="Loading test set to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * 224 * 224 * 3 * 4 / 1e9
        print(f"TEST RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (224, 224), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

# =====================================================
# MODEL LOADING
# =====================================================

def load_trained_model_casme2(checkpoint_path, device):
    """Load trained PoolFormer model from checkpoint with enhanced validation"""
    print(f"\nValidating checkpoint availability...")
    print(f"Expected checkpoint: {os.path.basename(checkpoint_path)}")
    print(f"Full path: {checkpoint_path}")

    # Check if checkpoint directory exists
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if not os.path.exists(checkpoint_dir):
        print(f"\nERROR: Checkpoint directory not found!")
        print(f"Directory: {checkpoint_dir}")
        print("\nTroubleshooting:")
        print("1. Make sure Cell 2 (Training) has been executed successfully")
        print("2. Check if training completed without errors")
        print("3. Verify the checkpoint was saved during training")
        raise FileNotFoundError(f"Checkpoint directory does not exist: {checkpoint_dir}")

    # List available checkpoints in directory
    available_checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]

    if not os.path.exists(checkpoint_path):
        print(f"\nERROR: Checkpoint file not found!")
        print(f"Expected: {os.path.basename(checkpoint_path)}")

        if available_checkpoints:
            print(f"\nAvailable checkpoints in directory:")
            for ckpt in available_checkpoints:
                ckpt_path = os.path.join(checkpoint_dir, ckpt)
                size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
                print(f"  - {ckpt} ({size_mb:.1f} MB)")
            print("\nPossible solutions:")
            print("1. Check if training saved checkpoint with different name")
            print("2. Re-run Cell 2 (Training) to generate checkpoint")
        else:
            print(f"\nNo checkpoints found in directory!")
            print("\nRequired actions:")
            print("1. Execute Cell 2 (Training Pipeline) first")
            print("2. Wait for training to complete (~2.5-3 hours)")
            print("3. Verify checkpoint is saved successfully")
            print("4. Then run this Cell 3 (Evaluation)")

        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Validate checkpoint file size
    checkpoint_size = os.path.getsize(checkpoint_path) / (1024 * 1024)
    print(f"Checkpoint found: {checkpoint_size:.1f} MB")

    if checkpoint_size < 10:
        print(f"WARNING: Checkpoint size is unusually small ({checkpoint_size:.1f} MB)")
        print("This might indicate a corrupted or incomplete checkpoint")

    # Load checkpoint
    print(f"Loading checkpoint from disk...")
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        print(f"ERROR: Failed to load checkpoint: {e}")
        print("\nPossible causes:")
        print("1. Corrupted checkpoint file")
        print("2. Incompatible PyTorch version")
        print("3. Disk I/O error during save")
        raise

    # Initialize model
    print(f"Initializing PoolFormer model...")
    model = PoolFormerCASME2Baseline(
        num_classes=7,
        dropout_rate=CASME2_POOLFORMER_CONFIG['dropout_rate']
    ).to(device)

    # Load model weights
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except Exception as e:
        print(f"ERROR: Failed to load model weights: {e}")
        print("\nPossible causes:")
        print("1. Model architecture mismatch")
        print("2. Checkpoint from different model variant")
        raise

    model.eval()

    print(f"Model loaded successfully from epoch {checkpoint['epoch']}")
    print(f"Best validation F1: {checkpoint['best_metrics']['f1']:.4f}")

    training_info = {
        'best_epoch': checkpoint['epoch'],
        'best_val_f1': checkpoint['best_metrics']['f1'],
        'best_val_loss': checkpoint['best_metrics']['loss'],
        'best_val_accuracy': checkpoint['best_metrics']['accuracy']
    }

    return model, training_info

# =====================================================
# FRAME-LEVEL INFERENCE (for v7 AF)
# =====================================================

def run_frame_level_inference(model, dataloader, device):
    """Run frame-level inference for AF evaluation"""
    model.eval()

    all_predictions = []
    all_labels = []
    all_filenames = []
    all_probabilities = []

    print("Running frame-level inference...")
    start_time = time.time()

    with torch.no_grad():
        for images, labels, filenames in tqdm(dataloader, desc="Frame-level inference"):
            images = images.to(device, non_blocking=True)

            outputs = model(images)
            probabilities = torch.softmax(outputs, dim=1)

            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_filenames.extend(filenames)
            all_probabilities.extend(probabilities.cpu().numpy())

    inference_time = time.time() - start_time

    print(f"Frame-level inference completed in {inference_time:.2f}s")
    print(f"Processed {len(all_predictions)} frames")

    return {
        'predictions': np.array(all_predictions),
        'labels': np.array(all_labels),
        'filenames': all_filenames,
        'probabilities': np.array(all_probabilities),
        'inference_time': inference_time,
        'evaluation_mode': 'frame_level'
    }

# =====================================================
# VIDEO-LEVEL INFERENCE WITH LATE FUSION (for v8 KFS)
# =====================================================

def run_video_level_inference_late_fusion(model, dataloader, device):
    """Run video-level inference with late fusion for KFS evaluation"""
    model.eval()

    frame_predictions = []
    frame_labels = []
    frame_filenames = []
    frame_probabilities = []
    frame_video_ids = []

    print("Running frame-level predictions for late fusion...")
    start_time = time.time()

    with torch.no_grad():
        for images, labels, filenames in tqdm(dataloader, desc="Frame predictions"):
            images = images.to(device, non_blocking=True)

            outputs = model(images)
            probabilities = torch.softmax(outputs, dim=1)

            _, predicted = torch.max(outputs, 1)

            frame_predictions.extend(predicted.cpu().numpy())
            frame_labels.extend(labels.numpy())
            frame_filenames.extend(filenames)
            frame_probabilities.extend(probabilities.cpu().numpy())

            for filename in filenames:
                video_id = extract_video_id_from_filename(filename)
                frame_video_ids.append(video_id)

    print(f"\nAggregating frame predictions to video level...")

    video_data = defaultdict(lambda: {
        'frame_predictions': [],
        'frame_probabilities': [],
        'frame_filenames': [],
        'true_label': None
    })

    for i, video_id in enumerate(frame_video_ids):
        video_data[video_id]['frame_predictions'].append(frame_predictions[i])
        video_data[video_id]['frame_probabilities'].append(frame_probabilities[i])
        video_data[video_id]['frame_filenames'].append(frame_filenames[i])
        if video_data[video_id]['true_label'] is None:
            video_data[video_id]['true_label'] = frame_labels[i]

    video_predictions = []
    video_labels = []
    video_ids_list = []

    for video_id, data in video_data.items():
        avg_probabilities = np.mean(data['frame_probabilities'], axis=0)
        video_prediction = np.argmax(avg_probabilities)

        video_predictions.append(video_prediction)
        video_labels.append(data['true_label'])
        video_ids_list.append(video_id)

    inference_time = time.time() - start_time

    print(f"Late fusion completed in {inference_time:.2f}s")
    print(f"Aggregated {len(frame_predictions)} frames into {len(video_predictions)} videos")

    return {
        'predictions': np.array(video_predictions),
        'labels': np.array(video_labels),
        'filenames': video_ids_list,
        'probabilities': None,
        'inference_time': inference_time,
        'evaluation_mode': 'video_level',
        'kfs_late_fusion_info': {
            'total_frames': len(frame_predictions),
            'total_videos': len(video_predictions),
            'aggregation_method': 'average_probabilities'
        }
    }

# =====================================================
# COMPREHENSIVE METRICS CALCULATION
# =====================================================

def calculate_comprehensive_metrics(inference_results):
    """Calculate comprehensive evaluation metrics"""
    predictions = inference_results['predictions']
    labels = inference_results['labels']
    filenames = inference_results['filenames']

    unique_labels = np.unique(labels)
    missing_classes = [i for i in range(len(CASME2_CLASSES)) if i not in unique_labels]
    available_classes = [i for i in range(len(CASME2_CLASSES)) if i in unique_labels]

    accuracy = accuracy_score(labels, predictions)

    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions,
        labels=available_classes,
        average='macro',
        zero_division=0
    )

    cm = confusion_matrix(labels, predictions, labels=list(range(len(CASME2_CLASSES))))

    per_class_metrics = {}

    for i, class_name in enumerate(CASME2_CLASSES):
        if i in available_classes:
            class_mask = (labels == i)

            if np.sum(class_mask) > 0:
                # Create binary classification: current class vs all others
                binary_labels = (labels == i).astype(int)
                binary_predictions = (predictions == i).astype(int)

                class_precision, class_recall, class_f1, _ = precision_recall_fscore_support(
                    binary_labels, binary_predictions,
                    average='binary',
                    pos_label=1,
                    zero_division=0
                )

                labels_binary = label_binarize(labels, classes=list(range(len(CASME2_CLASSES))))
                if inference_results['probabilities'] is not None:
                    probs = inference_results['probabilities'][:, i]
                    try:
                        fpr, tpr, _ = roc_curve(labels_binary[:, i], probs)
                        class_auc = auc(fpr, tpr)
                    except:
                        class_auc = 0.0
                else:
                    class_auc = 0.0

                per_class_metrics[class_name] = {
                    'precision': float(class_precision),
                    'recall': float(class_recall),
                    'f1_score': float(class_f1),
                    'auc': float(class_auc),
                    'support': int(np.sum(class_mask)),
                    'in_test_set': True
                }
            else:
                per_class_metrics[class_name] = {
                    'precision': 0.0,
                    'recall': 0.0,
                    'f1_score': 0.0,
                    'auc': 0.0,
                    'support': 0,
                    'in_test_set': True
                }
        else:
            per_class_metrics[class_name] = {
                'precision': 0.0,
                'recall': 0.0,
                'f1_score': 0.0,
                'auc': 0.0,
                'support': 0,
                'in_test_set': False
            }

    macro_auc = np.mean([m['auc'] for m in per_class_metrics.values() if m['in_test_set']])

    results = {
        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': float(macro_auc)
        },
        'per_class_performance': per_class_metrics,
        'confusion_matrix': cm.tolist(),
        'evaluation_metadata': {
            'dataset': 'CASME_II',
            'model_type': 'PoolFormerCASME2Baseline',
            'test_samples': len(predictions),
            'class_names': CASME2_CLASSES,
            'missing_classes': [CASME2_CLASSES[i] for i in missing_classes],
            'available_classes': [CASME2_CLASSES[i] for i in available_classes],
            'evaluation_timestamp': datetime.now().isoformat(),
            'evaluation_mode': inference_results['evaluation_mode']
        },
        'inference_performance': {
            'total_time_seconds': float(inference_results['inference_time']),
            'average_time_ms_per_sample': float(inference_results['inference_time'] * 1000 / len(predictions))
        }
    }

    if 'kfs_late_fusion_info' in inference_results:
        results['kfs_late_fusion_info'] = inference_results['kfs_late_fusion_info']

    return results

# =====================================================
# WRONG PREDICTIONS ANALYSIS
# =====================================================

def analyze_wrong_predictions(inference_results):
    """Analyze wrong predictions for error pattern identification"""
    predictions = inference_results['predictions']
    labels = inference_results['labels']
    filenames = inference_results['filenames']

    wrong_predictions = []
    wrong_by_class = defaultdict(int)
    confusion_patterns = defaultdict(int)

    for i in range(len(predictions)):
        if predictions[i] != labels[i]:
            true_class = CASME2_CLASSES[labels[i]]
            pred_class = CASME2_CLASSES[predictions[i]]

            wrong_predictions.append({
                'filename': filenames[i],
                'true_label': int(labels[i]),
                'true_class': true_class,
                'predicted_label': int(predictions[i]),
                'predicted_class': pred_class
            })

            wrong_by_class[true_class] += 1
            confusion_patterns[f"{true_class} -> {pred_class}"] += 1

    error_summary = {}
    for class_name in CASME2_CLASSES:
        total_samples = np.sum(labels == CLASS_TO_IDX[class_name])
        errors = wrong_by_class.get(class_name, 0)
        error_summary[class_name] = {
            'total_samples': int(total_samples),
            'wrong_predictions': int(errors),
            'error_rate': float(errors / total_samples * 100) if total_samples > 0 else 0.0
        }

    results = {
        'analysis_metadata': {
            'total_samples': len(predictions),
            'total_wrong_predictions': len(wrong_predictions),
            'overall_error_rate': (len(wrong_predictions) / len(predictions) * 100) if len(predictions) > 0 else 0.0
        },
        'wrong_predictions': wrong_predictions,
        'wrong_predictions_by_class': dict(wrong_by_class),
        'error_summary': error_summary,
        'confusion_patterns': dict(confusion_patterns)
    }

    return results

# =====================================================
# SAVE EVALUATION RESULTS
# =====================================================

def save_evaluation_results(evaluation_results, wrong_predictions_results, results_dir, test_version):
    """Save comprehensive evaluation results"""
    os.makedirs(results_dir, exist_ok=True)

    results_file = f"{results_dir}/casme2_poolformer_evaluation_results_{test_version}.json"
    with open(results_file, 'w') as f:
        json.dump(evaluation_results, f, indent=2, default=str)

    wrong_predictions_file = f"{results_dir}/casme2_poolformer_wrong_predictions_{test_version}.json"
    with open(wrong_predictions_file, 'w') as f:
        json.dump(wrong_predictions_results, f, indent=2, default=str)

    print(f"Evaluation results saved:")
    print(f"  Main results: {os.path.basename(results_file)}")
    print(f"  Wrong predictions: {os.path.basename(wrong_predictions_file)}")

    return results_file, wrong_predictions_file

# =====================================================
# MAIN EVALUATION EXECUTION
# =====================================================

all_evaluation_results = {}

for dataset_version in EVALUATE_DATASETS:
    print("\n" + "=" * 70)
    print(f"EVALUATING DATASET: {dataset_version.upper()}")
    print("=" * 70)

    try:
        # Get dataset configuration
        test_config = get_test_dataset_config(dataset_version, PROJECT_ROOT)

        print(f"\nTest Dataset Configuration:")
        print(f"  Version: {test_config['version']}")
        print(f"  Variant: {test_config['variant']}")
        print(f"  Description: {test_config['description']}")
        print(f"  Frame strategy: {test_config['frame_strategy']}")
        print(f"  Evaluation mode: {test_config['evaluation_mode']}")
        if 'aggregation' in test_config and test_config['aggregation']:
            print(f"  Aggregation: {test_config['aggregation']}")
        print(f"  Dataset path: {test_config['dataset_path']}")

        # Create test dataset
        print(f"\nCreating CASME II test dataset from {test_config['variant']}...")
        test_dataset = CASME2DatasetEvaluation(
            dataset_root=test_config['dataset_path'],
            split='test',
            transform=GLOBAL_CONFIG_CASME2['transform_val'],
            use_ram_cache=True
        )

        if len(test_dataset) == 0:
            raise ValueError(f"No test samples found for {dataset_version}!")

        test_loader = DataLoader(
            test_dataset,
            batch_size=CASME2_POOLFORMER_CONFIG['batch_size'],
            shuffle=False,
            num_workers=CASME2_POOLFORMER_CONFIG['num_workers'],
            pin_memory=True
        )

        # Load trained model
        checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/casme2_poolformer_mfs_best_f1.pth"
        model, training_info = load_trained_model_casme2(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

        # Run inference based on evaluation mode
        if test_config['evaluation_mode'] == 'frame_level':
            print(f"\nRunning frame-level evaluation for {test_config['variant']}...")
            inference_results = run_frame_level_inference(model, test_loader, GLOBAL_CONFIG_CASME2['device'])

        elif test_config['evaluation_mode'] == 'video_level':
            print(f"\nRunning video-level evaluation with late fusion for {test_config['variant']}...")
            inference_results = run_video_level_inference_late_fusion(model, test_loader, GLOBAL_CONFIG_CASME2['device'])

        else:
            raise ValueError(f"Unknown evaluation mode: {test_config['evaluation_mode']}")

        # Calculate comprehensive metrics
        evaluation_results = calculate_comprehensive_metrics(inference_results)

        # Analyze wrong predictions
        wrong_predictions_results = analyze_wrong_predictions(inference_results)

        # Add training information
        evaluation_results['training_information'] = training_info
        evaluation_results['test_configuration'] = test_config

        # Save results
        results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"
        save_evaluation_results(
            evaluation_results, wrong_predictions_results, results_dir, test_config['version']
        )

        # Store for comparison
        all_evaluation_results[dataset_version] = {
            'evaluation': evaluation_results,
            'wrong_predictions': wrong_predictions_results,
            'config': test_config
        }

        # Display results
        print("\n" + "=" * 60)
        print(f"EVALUATION RESULTS - {test_config['variant']} ({dataset_version})")
        print("=" * 60)

        overall = evaluation_results['overall_performance']
        print(f"\nOverall Performance:")
        print(f"  Accuracy:  {overall['accuracy']:.4f}")
        print(f"  Precision: {overall['macro_precision']:.4f}")
        print(f"  Recall:    {overall['macro_recall']:.4f}")
        print(f"  F1 Score:  {overall['macro_f1']:.4f}")
        print(f"  AUC:       {overall['macro_auc']:.4f}")

        if 'kfs_late_fusion_info' in evaluation_results:
            fusion_info = evaluation_results['kfs_late_fusion_info']
            print(f"\nLate Fusion Info:")
            print(f"  Total frames processed: {fusion_info['total_frames']}")
            print(f"  Video-level predictions: {fusion_info['total_videos']}")
            print(f"  Aggregation method: {fusion_info['aggregation_method']}")

        print(f"\nPer-Class Performance:")
        for class_name, metrics in evaluation_results['per_class_performance'].items():
            in_test = "Present" if metrics['in_test_set'] else "Missing"
            print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
                  f"Support={metrics['support']}")

        wrong_meta = wrong_predictions_results['analysis_metadata']
        print(f"\nWrong Predictions Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        print(f"\nInference Performance:")
        print(f"  Total time: {inference_results['inference_time']:.2f}s")
        print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    except Exception as e:
        print(f"Evaluation failed for {dataset_version}: {e}")
        import traceback
        traceback.print_exc()
        continue

# =====================================================
# COMPARATIVE ANALYSIS (if both datasets evaluated)
# =====================================================

if len(all_evaluation_results) == 2 and 'v7' in all_evaluation_results and 'v8' in all_evaluation_results:
    print("\n" + "=" * 70)
    print("COMPARATIVE ANALYSIS: AF (v7) vs KFS (v8)")
    print("=" * 70)

    v7_results = all_evaluation_results['v7']['evaluation']
    v8_results = all_evaluation_results['v8']['evaluation']

    print("\nOverall Performance Comparison:")
    print(f"{'Metric':<20} {'AF (v7)':<15} {'KFS (v8)':<15} {'Difference':<15}")
    print("-" * 65)

    metrics_to_compare = ['accuracy', 'macro_precision', 'macro_recall', 'macro_f1', 'macro_auc']

    for metric in metrics_to_compare:
        v7_val = v7_results['overall_performance'][metric]
        v8_val = v8_results['overall_performance'][metric]
        diff = v8_val - v7_val

        print(f"{metric:<20} {v7_val:<15.4f} {v8_val:<15.4f} {diff:+.4f}")

    print(f"\nEvaluation Modes:")
    print(f"  AF (v7): {v7_results['evaluation_metadata']['evaluation_mode']}")
    print(f"  KFS (v8): {v8_results['evaluation_metadata']['evaluation_mode']}")

    if 'kfs_late_fusion_info' in v8_results:
        print(f"\nKFS Late Fusion Strategy:")
        print(f"  Frames used: {v8_results['kfs_late_fusion_info']['total_frames']}")
        print(f"  Video predictions: {v8_results['kfs_late_fusion_info']['total_videos']}")
        print(f"  Aggregation: {v8_results['kfs_late_fusion_info']['aggregation_method']}")

# Memory cleanup
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\n" + "=" * 70)
print("CASME II POOLFORMER EVALUATION COMPLETED")
print("=" * 70)
print(f"Evaluated datasets: {EVALUATE_DATASETS}")
print("Next: Cell 4 - Generate confusion matrices and visualization")

CASME II PoolFormer Evaluation Framework with Dual Dataset Support
Datasets to evaluate: ['v7', 'v8']

EVALUATING DATASET: V7

Test Dataset Configuration:
  Version: v7
  Variant: AF
  Description: Apex Frame with Face-Aware Preprocessing
  Frame strategy: apex_frame
  Evaluation mode: frame_level
  Dataset path: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v7

Creating CASME II test dataset from AF...
Loading CASME II test dataset for evaluation...
Found 28 image files in directory
Loaded 28 CASME II test samples for evaluation
Test set class distribution:
  others: 10 samples (35.7%)
  disgust: 7 samples (25.0%)
  happiness: 4 samples (14.3%)
  repression: 3 samples (10.7%)
  surprise: 3 samples (10.7%)
  sadness: 1 samples (3.6%)
Unique video IDs: 28
Missing classes in test set: ['fear']
Preloading 28 test images to RAM with 32 workers...


Loading test set to RAM: 100%|██████████| 28/28 [00:00<00:00, 28.54it/s]


TEST RAM caching completed: 28/28 images, ~0.02GB

Validating checkpoint availability...
Expected checkpoint: casme2_poolformer_mfs_best_f1.pth
Full path: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/07_03_poolformer_casme2_mfs_prep/casme2_poolformer_mfs_best_f1.pth
Checkpoint found: 837.9 MB
Loading checkpoint from disk...
Initializing PoolFormer model...
PoolFormer feature dimension: 768
Classification head: 768 -> GAP2D -> 512 -> 128 -> 7
Dropout rate: 0.3 (balanced for large dataset)
Model loaded successfully from epoch 12
Best validation F1: 0.1896

Running frame-level evaluation for AF...
Running frame-level inference...


Frame-level inference: 100%|██████████| 2/2 [00:00<00:00,  2.53it/s]


Frame-level inference completed in 0.79s
Processed 28 frames
Evaluation results saved:
  Main results: casme2_poolformer_evaluation_results_v7.json
  Wrong predictions: casme2_poolformer_wrong_predictions_v7.json

EVALUATION RESULTS - AF (v7)

Overall Performance:
  Accuracy:  0.4643
  Precision: 0.4653
  Recall:    0.3552
  F1 Score:  0.3785
  AUC:       0.7619

Per-Class Performance:
  others [Present]: F1=0.4545, Support=10
  disgust [Present]: F1=0.6667, Support=7
  happiness [Present]: F1=0.2500, Support=4
  repression [Present]: F1=0.4000, Support=3
  surprise [Present]: F1=0.5000, Support=3
  sadness [Present]: F1=0.0000, Support=1
  fear [Missing]: F1=0.0000, Support=0

Wrong Predictions Analysis:
  Total errors: 15 / 28
  Error rate: 53.57%

Inference Performance:
  Total time: 0.79s
  Speed: 28.3 ms/sample

EVALUATING DATASET: V8

Test Dataset Configuration:
  Version: v8
  Variant: KFS
  Description: Key Frame Sequence with Face-Aware Preprocessing
  Frame strategy: key_fram

Loading test set to RAM: 100%|██████████| 84/84 [00:01<00:00, 49.85it/s]


TEST RAM caching completed: 84/84 images, ~0.05GB

Validating checkpoint availability...
Expected checkpoint: casme2_poolformer_mfs_best_f1.pth
Full path: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/07_03_poolformer_casme2_mfs_prep/casme2_poolformer_mfs_best_f1.pth
Checkpoint found: 837.9 MB
Loading checkpoint from disk...
Initializing PoolFormer model...
PoolFormer feature dimension: 768
Classification head: 768 -> GAP2D -> 512 -> 128 -> 7
Dropout rate: 0.3 (balanced for large dataset)
Model loaded successfully from epoch 12
Best validation F1: 0.1896

Running video-level evaluation with late fusion for KFS...
Running frame-level predictions for late fusion...


Frame predictions: 100%|██████████| 6/6 [00:00<00:00,  6.28it/s]


Aggregating frame predictions to video level...
Late fusion completed in 0.96s
Aggregated 84 frames into 84 videos
Evaluation results saved:
  Main results: casme2_poolformer_evaluation_results_v8.json
  Wrong predictions: casme2_poolformer_wrong_predictions_v8.json

EVALUATION RESULTS - KFS (v8)

Overall Performance:
  Accuracy:  0.4405
  Precision: 0.3626
  Recall:    0.3181
  F1 Score:  0.3261
  AUC:       0.0000

Late Fusion Info:
  Total frames processed: 84
  Video-level predictions: 84
  Aggregation method: average_probabilities

Per-Class Performance:
  others [Present]: F1=0.4918, Support=30
  disgust [Present]: F1=0.6522, Support=21
  happiness [Present]: F1=0.2400, Support=12
  repression [Present]: F1=0.1111, Support=9
  surprise [Present]: F1=0.4615, Support=9
  sadness [Present]: F1=0.0000, Support=3
  fear [Missing]: F1=0.0000, Support=0

Wrong Predictions Analysis:
  Total errors: 47 / 84
  Error rate: 55.95%

Inference Performance:
  Total time: 0.96s
  Speed: 11.4 ms




In [4]:
# @title Cell 4: CASME II PoolFormer Confusion Matrix Generation

# File: 07_03_PoolFormer_CASME2_MFS_Cell4.py
# Location: experiments/07_03_PoolFormer_CASME2-MFS-PREP.ipynb
# Purpose: Generate professional confusion matrix visualization for AF and KFS evaluations

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

print("CASME II PoolFormer Confusion Matrix Generation")
print("=" * 60)

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/07_03_poolformer_casme2_mfs_prep"

def find_evaluation_json_files(results_path):
    """Find evaluation JSON files with version detection"""
    json_files = {}
    eval_dir = f"{results_path}/evaluation_results"

    if os.path.exists(eval_dir):
        for version in ['v7', 'v8']:
            eval_pattern = f"{eval_dir}/casme2_poolformer_evaluation_results_{version}.json"
            eval_files = glob.glob(eval_pattern)

            if eval_files:
                json_files[f'main_{version}'] = eval_files[0]
                print(f"Found {version.upper()} evaluation file: {os.path.basename(eval_files[0])}")

            wrong_pattern = f"{eval_dir}/casme2_poolformer_wrong_predictions_{version}.json"
            wrong_files = glob.glob(wrong_pattern)

            if wrong_files:
                json_files[f'wrong_{version}'] = wrong_files[0]
                print(f"Found {version.upper()} wrong predictions: {os.path.basename(wrong_files[0])}")

        if not json_files:
            print(f"WARNING: No evaluation results found in {eval_dir}")
            print("Make sure Cell 3 (evaluation) has been executed first!")
    else:
        print(f"ERROR: Evaluation directory not found: {eval_dir}")

    return json_files

def load_evaluation_results(json_path):
    """Load and parse evaluation results JSON"""
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        print(f"Successfully loaded: {os.path.basename(json_path)}")
        return data
    except Exception as e:
        print(f"ERROR loading {json_path}: {str(e)}")
        return None

def calculate_weighted_f1(per_class_performance):
    """Calculate weighted F1 score"""
    total_support = sum([class_data['support'] for class_data in per_class_performance.values()
                        if class_data['support'] > 0])

    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0

    for class_name, class_data in per_class_performance.items():
        if class_data['support'] > 0:
            weight = class_data['support'] / total_support
            weighted_f1 += class_data['f1_score'] * weight

    return weighted_f1

def calculate_balanced_accuracy(confusion_matrix):
    """Calculate balanced accuracy handling classes with zero support"""
    cm = np.array(confusion_matrix)
    n_classes = cm.shape[0]

    per_class_balanced_acc = []
    classes_with_samples = []

    for i in range(n_classes):
        if cm[i, :].sum() > 0:
            classes_with_samples.append(i)

    for i in classes_with_samples:
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        tn = cm.sum() - tp - fn - fp

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        class_balanced_acc = (sensitivity + specificity) / 2
        per_class_balanced_acc.append(class_balanced_acc)

    balanced_acc = np.mean(per_class_balanced_acc) if per_class_balanced_acc else 0.0

    return balanced_acc

def determine_text_color(color_value, threshold=0.5):
    """Determine optimal text color based on background intensity"""
    return 'white' if color_value > threshold else 'black'

def create_confusion_matrix_plot(data, output_path, test_version):
    """Create professional confusion matrix visualization for CASME II micro-expression recognition"""

    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']

    test_config = data.get('test_configuration', {})
    variant = test_config.get('variant', test_version.upper())
    description = test_config.get('description', f'{test_version} preprocessing')
    eval_mode = meta.get('evaluation_mode', 'frame_level')

    print(f"Processing confusion matrix for {variant} ({test_version})")
    print(f"Dataset: {description}")
    print(f"Evaluation mode: {eval_mode}")
    print(f"Confusion matrix shape: {cm.shape}")

    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1(per_class)
    balanced_acc = calculate_balanced_accuracy(cm)

    print(f"Calculated metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Balanced Acc: {balanced_acc:.4f}, Accuracy: {accuracy:.4f}")

    row_sums = cm.sum(axis=1, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_pct = np.divide(cm, row_sums, where=(row_sums!=0))
        cm_pct = np.nan_to_num(cm_pct)

    fig, ax = plt.subplots(figsize=(12, 10))

    cmap = 'Blues'
    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"

            cell_value = cm_pct[i, j]
            text_color = determine_text_color(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    preprocessing_note = f"Preprocessing: {description}\n"
    preprocessing_note += f"Dataset: {test_version}\n"
    preprocessing_note += f"Evaluation: {eval_mode.replace('_', ' ').title()}"

    if 'kfs_late_fusion_info' in data:
        fusion_info = data['kfs_late_fusion_info']
        preprocessing_note += f"\nFrames: {fusion_info['total_frames']}, Videos: {fusion_info['total_videos']}"

    missing_classes = meta.get('missing_classes', [])
    if missing_classes:
        preprocessing_note += f"\nMissing: {', '.join(missing_classes)}"

    ax.text(0.02, 0.98, preprocessing_note, transform=ax.transAxes, fontsize=8,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    title = f"CASME II {variant} Micro-Expression Recognition - PoolFormer\n"
    title += f"Acc: {accuracy:.4f}  |  Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes
    }

def generate_performance_summary(evaluation_data, wrong_predictions_data=None):
    """Generate comprehensive performance summary"""

    print("\n" + "=" * 60)
    print("CASME II MICRO-EXPRESSION RECOGNITION PERFORMANCE SUMMARY")
    print("=" * 60)

    overall = evaluation_data['overall_performance']
    meta = evaluation_data['evaluation_metadata']
    test_config = evaluation_data.get('test_configuration', {})

    variant = test_config.get('variant', 'N/A')

    print(f"Dataset: {meta['dataset']}")
    print(f"Variant: {variant}")
    print(f"Dataset version: {test_config.get('version', 'N/A')}")
    print(f"Preprocessing: {test_config.get('description', 'N/A')}")
    print(f"Test samples: {meta['test_samples']}")
    print(f"Model: {meta['model_type']}")
    print(f"Evaluation date: {meta['evaluation_timestamp']}")

    if 'kfs_late_fusion_info' in evaluation_data:
        fusion_info = evaluation_data['kfs_late_fusion_info']
        print(f"\nLate Fusion Information:")
        print(f"  Total frames: {fusion_info['total_frames']}")
        print(f"  Video predictions: {fusion_info['total_videos']}")
        print(f"  Aggregation: {fusion_info['aggregation_method']}")

    print(f"\nOverall Performance:")
    print(f"  Accuracy:         {overall['accuracy']:.4f}")
    print(f"  Macro Precision:  {overall['macro_precision']:.4f}")
    print(f"  Macro Recall:     {overall['macro_recall']:.4f}")
    print(f"  Macro F1:         {overall['macro_f1']:.4f}")
    print(f"  Macro AUC:        {overall['macro_auc']:.4f}")

    print(f"\nPer-Class Performance:")
    per_class = evaluation_data['per_class_performance']

    print(f"{'Class':<12} {'F1':<8} {'Precision':<10} {'Recall':<8} {'AUC':<8} {'Support':<8} {'In Test'}")
    print("-" * 65)

    for class_name, metrics in per_class.items():
        in_test = "Yes" if metrics['in_test_set'] else "No"
        print(f"{class_name:<12} {metrics['f1_score']:<8.4f} {metrics['precision']:<10.4f} "
              f"{metrics['recall']:<8.4f} {metrics['auc']:<8.4f} {metrics['support']:<8} {in_test}")

    if 'training_information' in evaluation_data:
        training = evaluation_data['training_information']
        print(f"\nTraining vs Test Comparison:")
        print(f"  Training Val F1:  {training['best_val_f1']:.4f}")
        print(f"  Test F1:          {overall['macro_f1']:.4f}")
        print(f"  Performance Gap:  {training['best_val_f1'] - overall['macro_f1']:+.4f}")
        print(f"  Best Epoch:       {training['best_epoch']}")

    missing_classes = meta.get('missing_classes', [])
    available_classes = meta.get('available_classes', [])

    print(f"\nClass Availability Analysis:")
    print(f"  Available classes: {len(available_classes)}/7")
    print(f"  Missing classes: {missing_classes if missing_classes else 'None'}")

    if wrong_predictions_data:
        wrong_meta = wrong_predictions_data['analysis_metadata']
        print(f"\nError Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']}/{wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        patterns = wrong_predictions_data.get('confusion_patterns', {})
        if patterns:
            print(f"\nTop Confusion Patterns:")
            sorted_patterns = sorted(patterns.items(), key=lambda x: x[1], reverse=True)[:3]
            for pattern, count in sorted_patterns:
                print(f"  {pattern}: {count} cases")

    print(f"\nInference Performance:")
    inference = evaluation_data['inference_performance']
    print(f"  Total time: {inference['total_time_seconds']:.2f}s")
    print(f"  Speed: {inference['average_time_ms_per_sample']:.1f} ms/sample")

json_files = find_evaluation_json_files(RESULTS_ROOT)

if not json_files:
    print(f"ERROR: No evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (evaluation) has been executed first!")
else:
    print(f"\nFound {len([k for k in json_files.keys() if k.startswith('main_')])} evaluation result(s)")

output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

results_summary = {}
generated_files = []

for version in ['v7', 'v8']:
    main_key = f'main_{version}'
    wrong_key = f'wrong_{version}'

    if main_key in json_files:
        print(f"\n{'='*60}")
        print(f"Processing {version.upper()} Evaluation Results")
        print(f"{'='*60}")

        eval_data = load_evaluation_results(json_files[main_key])

        wrong_data = None
        if wrong_key in json_files:
            wrong_data = load_evaluation_results(json_files[wrong_key])

        if eval_data is not None:
            try:
                cm_output_path = os.path.join(output_dir, f"confusion_matrix_CASME2_PoolFormer_{version}.png")
                metrics = create_confusion_matrix_plot(eval_data, cm_output_path, version)
                generated_files.append(cm_output_path)

                results_summary[version] = metrics

                print(f"\nSUCCESS: {version.upper()} confusion matrix generated successfully")
                print(f"Output file: {os.path.basename(cm_output_path)}")

                print(f"\nPerformance Metrics Summary:")
                print(f"  Accuracy:        {metrics['accuracy']:.4f}")
                print(f"  Macro F1:        {metrics['macro_f1']:.4f}")
                print(f"  Weighted F1:     {metrics['weighted_f1']:.4f}")
                print(f"  Balanced Acc:    {metrics['balanced_accuracy']:.4f}")

                if metrics['missing_classes']:
                    print(f"  Missing classes: {metrics['missing_classes']}")

            except Exception as e:
                print(f"ERROR: Failed to generate {version.upper()} confusion matrix: {str(e)}")
                import traceback
                traceback.print_exc()

            generate_performance_summary(eval_data, wrong_data)
        else:
            print(f"ERROR: Could not load {version.upper()} evaluation data")

if generated_files:
    print(f"\n" + "=" * 60)
    print("CASME II CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 60)

    print(f"\nGenerated visualization files:")
    for file_path in generated_files:
        filename = os.path.basename(file_path)
        print(f"  {filename}")

    for version in ['v7', 'v8']:
        if version in results_summary:
            variant = 'AF' if version == 'v7' else 'KFS'
            print(f"\n{variant} ({version.upper()}) Performance Summary:")
            metrics = results_summary[version]
            print(f"  Accuracy:       {metrics['accuracy']:.4f}")
            print(f"  Macro F1:       {metrics['macro_f1']:.4f}")
            print(f"  Weighted F1:    {metrics['weighted_f1']:.4f}")
            print(f"  Balanced Acc:   {metrics['balanced_accuracy']:.4f}")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print(f"\nERROR: No visualizations were generated")
    print("Please check:")
    print("1. Cell 3 evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

print("\nCell 4 completed - CASME II confusion matrix analysis generated")
print("\n" + "=" * 60)
print("ALL EXPERIMENTS COMPLETED - READY FOR CONFERENCE PAPER")
print("=" * 60)

CASME II PoolFormer Confusion Matrix Generation
Found V7 evaluation file: casme2_poolformer_evaluation_results_v7.json
Found V7 wrong predictions: casme2_poolformer_wrong_predictions_v7.json
Found V8 evaluation file: casme2_poolformer_evaluation_results_v8.json
Found V8 wrong predictions: casme2_poolformer_wrong_predictions_v8.json

Found 2 evaluation result(s)

Processing V7 Evaluation Results
Successfully loaded: casme2_poolformer_evaluation_results_v7.json
Successfully loaded: casme2_poolformer_wrong_predictions_v7.json
Processing confusion matrix for AF (v7)
Dataset: Apex Frame with Face-Aware Preprocessing
Evaluation mode: frame_level
Confusion matrix shape: (7, 7)
Calculated metrics - Macro F1: 0.3785, Weighted F1: 0.4611, Balanced Acc: 0.6164, Accuracy: 0.4643
Confusion matrix saved to: confusion_matrix_CASME2_PoolFormer_v7.png

SUCCESS: V7 confusion matrix generated successfully
Output file: confusion_matrix_CASME2_PoolFormer_v7.png

Performance Metrics Summary:
  Accuracy:    