In [None]:
# @title Cell 1: CASME II Key-Frame PoolFormer Infrastructure Configuration

# File: 03_03_PoolFormer_CASME2_KFS_Cell1.py
# Location: experiments/03_03_PoolFormer_CASME2-KFS.ipynb
# Purpose: PoolFormer for CASME II micro-expression recognition with key-frame strategy (onset, apex, offset)

# Mount Google Drive so the project folders under /content/drive are readable.
# NOTE(review): `google.colab` is only importable inside a Colab runtime.
from google.colab import drive
print("=" * 60)
print("CASME II KEY-FRAME POOLFORMER INFRASTRUCTURE")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
# Suppress every Python warning for the rest of the session. This also hides
# library deprecation notices, so re-enable warnings when debugging.
warnings.filterwarnings('ignore')

# Project paths configuration - Phase 2 structure.
# Everything lives under the Drive folder mounted above. DATASET_ROOT holds the
# per-split image folders plus the two JSON files loaded below.
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/data_split_v2"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/03_03_poolformer_casme2_kfs"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/03_03_poolformer_casme2_kfs"

# Load CASME II Phase 2 dataset metadata.
# NOTE(review): despite its name, METADATA_TRAIN is the split-metadata file for
# every split; this cell also reads its 'val' and 'test' entries.
METADATA_TRAIN = f"{DATASET_ROOT}/split_metadata_v2.json"
PROCESSING_SUMMARY = f"{DATASET_ROOT}/processing_summary_v2.json"

print("CASME II Key-Frame PoolFormer - Infrastructure Configuration")
print("=" * 60)

# Fail fast with a clear message if either metadata file is missing.
if not os.path.exists(METADATA_TRAIN):
    raise FileNotFoundError(f"Phase 2 metadata not found: {METADATA_TRAIN}")
if not os.path.exists(PROCESSING_SUMMARY):
    raise FileNotFoundError(f"Phase 2 processing summary not found: {PROCESSING_SUMMARY}")

# Load Phase 2 dataset metadata.
# casme2_metadata: per-split sample metadata (presumably the dataset classes'
# `split_metadata` input). processing_info: dataset-level summary, including
# per-split class counts under 'class_preservation'.
print("Loading CASME II Phase 2 dataset metadata...")
with open(METADATA_TRAIN, 'r') as f:
    casme2_metadata = json.load(f)

with open(PROCESSING_SUMMARY, 'r') as f:
    processing_info = json.load(f)

# These summary keys are required; a missing one raises KeyError here.
print(f"Dataset: {processing_info['dataset']}")
print(f"Phase: {processing_info['phase']}")
print(f"Total images: {processing_info['total_images_copied']}")
print(f"Frame types: {processing_info['frame_types']}")
print(f"Expansion strategy: {processing_info['expansion_strategy']}")

# =====================================================
# ADVANCED EXPERIMENT CONFIGURATION - Optimized Parameters
# =====================================================

# FOCAL LOSS CONFIGURATION - Toggle and Advanced Parameters
USE_FOCAL_LOSS = True  # True -> OptimizedFocalLoss, False -> weighted CrossEntropy
FOCAL_LOSS_GAMMA = 2.0  # Focusing parameter; gamma=0 reduces focal loss to alpha-weighted CE

# OPTIMIZED CLASS WEIGHTS CONFIGURATION - Inverse Square Root Frequency Approach
# CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
# Train distribution: [99, 63, 32, 27, 25, 7, 2] - inverse sqrt frequency approach
# NOTE(review): these counts sum to 255. The processing summary reports 765
# images over 3 frame types (255 x 3). Confirm they are train-split counts
# rather than whole-dataset counts.

# CrossEntropy Loss - inverse square root frequency weights:
# weight_c = sqrt(count_others / count_c), rounded to 2 decimals.
CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.25, 1.76, 1.91, 1.99, 3.76, 7.04]

# Focal Loss - per-class alpha values. These are the CrossEntropy weights above
# divided by their sum (18.71) and rounded to 3 decimals, so they sum to ~0.999.
FOCAL_LOSS_ALPHA_WEIGHTS = [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]

# POOLFORMER MODEL CONFIGURATION - Support m36 and m48 variants
POOLFORMER_MODEL_VARIANT = 'm48'  # Options: 'm36' or 'm48'

# Map the variant to its Hugging Face checkpoint id. MODEL_PARAMS is a display
# label only; it is not computed from the loaded model.
if POOLFORMER_MODEL_VARIANT == 'm36':
    POOLFORMER_MODEL_NAME = 'sail/poolformer_m36'
    MODEL_PARAMS = '56M'
    print("Using PoolFormer-M36 for micro-expression analysis (56M parameters)")
elif POOLFORMER_MODEL_VARIANT == 'm48':
    POOLFORMER_MODEL_NAME = 'sail/poolformer_m48'
    MODEL_PARAMS = '73M'
    print("Using PoolFormer-M48 for enhanced micro-expression recognition (73M parameters)")
else:
    raise ValueError(f"Unsupported POOLFORMER_MODEL_VARIANT: {POOLFORMER_MODEL_VARIANT}")

# Display experiment configuration (print-only).
print("\n" + "=" * 50)
print("OPTIMIZED EXPERIMENT CONFIGURATION")
print("=" * 50)
print(f"Loss Function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
if USE_FOCAL_LOSS:
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Alpha Weights (per-class): {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum Validation: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Class Weights (inverse sqrt freq): {CROSSENTROPY_CLASS_WEIGHTS}")
print(f"PoolFormer Model: {POOLFORMER_MODEL_NAME}")
print(f"Model Parameters: {MODEL_PARAMS}")
print("=" * 50)

# Enhanced GPU configuration.
# Without CUDA, gpu_name is "CPU" and gpu_memory is 0, which selects the
# conservative branch below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

# Hardware-optimized batch size for 384px input.
# GPUs are matched by substring of the device name. cudnn.benchmark (conv
# kernel autotuning) is only enabled for the two recognized GPUs.
if 'A100' in gpu_name:
    BATCH_SIZE = 20
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("A100: Optimized batch size for PoolFormer 384px")
elif 'L4' in gpu_name:
    BATCH_SIZE = 16
    NUM_WORKERS = 6
    torch.backends.cudnn.benchmark = True
    print("L4: Balanced performance configuration")
else:
    BATCH_SIZE = 12
    NUM_WORKERS = 4
    print("Default GPU: Conservative settings")

# RAM preloading workers (separate from DataLoader workers).
# Thread count for parallel, I/O-bound image loading from Google Drive.
RAM_PRELOAD_WORKERS = 12
print(f"RAM preload workers: {RAM_PRELOAD_WORKERS} (parallel image loading)")

# CASME II class mapping and analysis.
# The list order defines each class's integer label.
CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

# Smart class distribution loading - handles both v1 and v2 JSON formats.
# In this cell the distributions are only printed. The class weights below come
# from the hard-coded constants above, not from these counts.
print("\nLoading class distribution...")
try:
    # v1: split_metadata carries a per-split 'class_distribution'
    if 'train' in casme2_metadata and 'class_distribution' in casme2_metadata['train']:
        train_dist = casme2_metadata['train']['class_distribution']
        val_dist = casme2_metadata['val']['class_distribution']
        test_dist = casme2_metadata['test']['class_distribution']
        print("Using class distribution from split_metadata (v1 format)")
    else:
        # v2: per-split counts live in processing_summary['class_preservation']
        train_dist = processing_info['class_preservation']['train']
        val_dist = processing_info['class_preservation']['val']
        test_dist = processing_info['class_preservation']['test']
        print("Using class distribution from processing_summary (v2 format)")
except KeyError as e:
    raise KeyError(f"Could not load class distribution from metadata. Missing key: {e}")

print(f"\nTrain distribution: {train_dist}")
print(f"Validation distribution: {val_dist}")
print(f"Test distribution: {test_dist}")

# Apply optimized class weights based on loss function selection.
# NOTE(review): class_weights is stored in the config dicts. However,
# create_criterion_casme2 applies its `weights` argument only to CrossEntropy;
# the focal alpha is passed separately via `alpha_weights`.
if USE_FOCAL_LOSS:
    # For Focal Loss - use normalized alpha weights (per-class importance)
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    # For CrossEntropy - use inverse sqrt frequency weights
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

# CASME II PoolFormer Configuration.
# One dict with every hyper-parameter; the model, optimizer/scheduler factory
# and training loop read their settings from here.
CASME2_POOLFORMER_CONFIG = {
    # Architecture configuration - PoolFormer specific
    'poolformer_model': POOLFORMER_MODEL_NAME,
    'model_variant': POOLFORMER_MODEL_VARIANT,
    'model_params': MODEL_PARAMS,
    'input_size': 384,
    'num_classes': 7,
    'dropout_rate': 0.2,
    # Informational only: the model reads the real width from
    # config.hidden_sizes[-1] of the loaded checkpoint.
    'expected_feature_dim': 768,

    # Training configuration (proven effective from medical imaging)
    'learning_rate': 1e-5,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,  # max global grad-norm passed to clip_grad_norm_
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'device': device,

    # Scheduler configuration: ReduceLROnPlateau maximizing the monitored metric
    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 3,
    'scheduler_min_lr': 1e-6,
    'scheduler_monitor': 'val_f1_macro',  # label printed by the factory

    # Optimized loss configuration
    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'class_weights': class_weights,

    # Evaluation configuration (presumably consumed by later training cells)
    'use_macro_avg': True,
    'early_stopping': False,
    'save_best_f1': True,
    'save_strategy': 'best_only',

    # Phase 2 specific configuration
    'dataset_phase': 'v2',
    'frame_types': processing_info['frame_types']
}

print(f"\nPoolFormer Configuration Summary:")
print(f"  Model: {CASME2_POOLFORMER_CONFIG['poolformer_model']}")
print(f"  Variant: {CASME2_POOLFORMER_CONFIG['model_variant'].upper()}")
print(f"  Parameters: {CASME2_POOLFORMER_CONFIG['model_params']}")
print(f"  Input size: {CASME2_POOLFORMER_CONFIG['input_size']}px")
print(f"  Learning rate: {CASME2_POOLFORMER_CONFIG['learning_rate']}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Dataset phase: {CASME2_POOLFORMER_CONFIG['dataset_phase']}")
print(f"  Frame types: {CASME2_POOLFORMER_CONFIG['frame_types']}")

# =====================================================
# ADVANCED FOCAL LOSS IMPLEMENTATION - Per-Class Alpha Support
# =====================================================

class OptimizedFocalLoss(nn.Module):
    """
    Focal Loss with optional per-class alpha weighting.
    Paper: "Focal Loss for Dense Object Detection" (Lin et al., 2017)

    Formula: FL(p_t) = -α_t (1 - p_t)^γ log(p_t)

    Args:
        alpha (list/tuple/ndarray/tensor, optional): Per-class alpha weights
            indexed by class id. Expected to sum to ~1.0; a warning is printed
            otherwise. ``None`` disables alpha weighting (α_t = 1).
        gamma (float): Focusing parameter. Larger values down-weight
            well-classified examples; ``gamma=0`` gives alpha-weighted CE.
        reduction (str): 'mean', 'sum' or 'none'.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _VALID_REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()

        # Reject typos up front instead of silently returning unreduced losses.
        if reduction not in self._VALID_REDUCTIONS:
            raise ValueError(
                f"Unsupported reduction '{reduction}', expected one of {self._VALID_REDUCTIONS}"
            )

        if alpha is not None:
            # as_tensor accepts lists, tuples, numpy arrays and tensors alike.
            # The previous list-only check crashed on tuples and arrays.
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)

            # Alpha is meant to be a normalized per-class distribution.
            alpha_sum = self.alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")
        else:
            self.alpha = None

        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` [N, C] and class-index ``targets`` [N]."""
        # Per-sample cross entropy, i.e. -log(p_t)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Recover p_t (probability of the true class) from the CE value
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            # alpha is a plain attribute, not a registered buffer, so
            # Module.to() does not move it: move it lazily and cache the result.
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        # α_t (1 - p_t)^γ * CE
        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# PoolFormer Architecture for CASME II
class PoolFormerCASME2Baseline(nn.Module):
    """PoolFormer backbone plus a two-hidden-layer MLP head for CASME II classification.

    Args:
        num_classes (int): Number of output logits.
        dropout_rate (float): Dropout probability after each hidden layer of the head.
        model_name (str, optional): Hugging Face checkpoint id of the backbone.
            Defaults to ``CASME2_POOLFORMER_CONFIG['poolformer_model']``, so
            existing callers that pass only the first two arguments behave as before.
    """

    def __init__(self, num_classes, dropout_rate=0.2, model_name=None):
        super(PoolFormerCASME2Baseline, self).__init__()

        # Hugging Face PoolFormer backbone (imported where it is needed)
        from transformers import PoolFormerModel

        if model_name is None:
            model_name = CASME2_POOLFORMER_CONFIG['poolformer_model']
        self.poolformer = PoolFormerModel.from_pretrained(model_name)

        # Full fine-tuning: every backbone parameter stays trainable
        for param in self.poolformer.parameters():
            param.requires_grad = True

        # Channel width of the last stage, read from the checkpoint's config
        self.poolformer_feature_dim = self.poolformer.config.hidden_sizes[-1]

        print(f"PoolFormer feature dimension: {self.poolformer_feature_dim}")

        # Classification head with LayerNorm for stability
        self.classifier_layers = nn.Sequential(
            nn.Linear(self.poolformer_feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
        )

        # Final classification layer for CASME II classes
        self.classifier = nn.Linear(128, num_classes)

        print(f"PoolFormer CASME II: {self.poolformer_feature_dim} -> 512 -> 128 -> {num_classes}")

    def forward(self, pixel_values):
        """Map ``pixel_values`` [B, 3, H, W] to class logits [B, num_classes]."""
        poolformer_outputs = self.poolformer(pixel_values=pixel_values)

        # last_hidden_state is a spatial feature map [B, C, H', W']
        poolformer_features = poolformer_outputs.last_hidden_state

        # Global average pooling over the spatial axes -> [B, C]
        # NOTE(review): transformers' PoolFormerForImageClassification applies a
        # final GroupNorm before this pooling, but this head pools the raw
        # features. Confirm that the omission is intended.
        poolformer_features = poolformer_features.mean(dim=[2, 3])

        processed_features = self.classifier_layers(poolformer_features)
        return self.classifier(processed_features)

# Enhanced optimizer and scheduler factory
def create_optimizer_scheduler_casme2(model, config):
    """Build the AdamW optimizer and, optionally, a ReduceLROnPlateau scheduler.

    Args:
        model: Module whose ``parameters()`` are optimized.
        config (dict): Must provide 'learning_rate', 'weight_decay' and
            'scheduler_type'. When scheduler_type == 'plateau' it must also provide
            'scheduler_mode', 'scheduler_factor', 'scheduler_patience',
            'scheduler_min_lr' and 'scheduler_monitor'; the last one is only printed.

    Returns:
        tuple: (optimizer, scheduler). scheduler is None for any
        scheduler_type other than 'plateau'.
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999),
    )

    # Only the plateau scheduler is supported; anything else means "no scheduler".
    if config['scheduler_type'] != 'plateau':
        return optimizer, None

    plateau_kwargs = dict(
        mode=config['scheduler_mode'],
        factor=config['scheduler_factor'],
        patience=config['scheduler_patience'],
        min_lr=config['scheduler_min_lr'],
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_kwargs)
    print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    return optimizer, scheduler

# PoolFormer Image Processor setup for 384px input
from transformers import PoolFormerImageProcessor

print("\nSetting up PoolFormer Image Processor for 384px input...")

poolformer_processor = PoolFormerImageProcessor.from_pretrained(
    CASME2_POOLFORMER_CONFIG['poolformer_model'],
    do_resize=True,
    size={'height': 384, 'width': 384},
    do_normalize=True,
    do_rescale=True,
    do_center_crop=False
)

# When crop_pct is set (documented default 0.9), PoolFormerImageProcessor
# resizes to size / crop_pct and relies on the center crop to get back to
# `size`. With do_center_crop=False the model would then receive ~426x426
# tensors instead of 384x384. Clearing crop_pct resizes straight to 384x384.
poolformer_processor.crop_pct = None

# Sanity-check the real output resolution once, on a blank image.
_probe_shape = tuple(
    poolformer_processor(Image.new('RGB', (384, 384)), return_tensors="pt")['pixel_values'].shape
)
if _probe_shape[-2:] != (384, 384):
    print(f"Warning: PoolFormer processor outputs {_probe_shape}, expected (..., 384, 384)")
del _probe_shape

# Transform functions for PoolFormer
def _poolformer_pixel_values(image):
    """Run the shared PoolFormer processor on one PIL image and drop the batch axis."""
    batch = poolformer_processor(image, return_tensors="pt")
    return batch['pixel_values'].squeeze(0)


def poolformer_transform_train(image):
    """Training-time transform: PoolFormer preprocessing only (no augmentation)."""
    return _poolformer_pixel_values(image)


def poolformer_transform_val(image):
    """Validation-time transform: the same PoolFormer preprocessing as training."""
    return _poolformer_pixel_values(image)


print("PoolFormer Image Processor configured for 384px with token mixing")

# Custom Dataset class for CASME II
class CASME2Dataset(Dataset):
    """Minimal CASME II dataset that reads each image from disk on access.

    Args:
        split_metadata (dict): Maps a split name to {'samples': [...]}. Each
            sample needs 'image_filename', 'emotion' and 'sample_id'.
        dataset_root (str): Directory containing one sub-folder per split.
        transform (callable, optional): Applied to the RGB PIL image.
        split (str): Which split to expose.

    Items are ``(image, label_index, sample_id)``, with labels taken from CLASS_TO_IDX.
    """

    def __init__(self, split_metadata, dataset_root, transform=None, split='train'):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split

        print(f"Loaded {len(self.metadata)} samples for {split} split")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        sample = self.metadata[idx]

        image_path = os.path.join(self.dataset_root, self.split, sample['image_filename'])
        # The context manager closes the file handle deterministically.
        # convert() returns a fully loaded image that stays valid after close.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Unknown emotion strings raise KeyError here
        label = CLASS_TO_IDX[sample['emotion']]

        return image, label, sample['sample_id']

# Create output directories (no-op when they already exist).
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

# Dataset paths - Phase 2 structure: one image folder per split under DATASET_ROOT.
# These are exported in GLOBAL_CONFIG_CASME2. The dataset classes build the same
# paths themselves from dataset_root + split.
TRAIN_PATH = f"{DATASET_ROOT}/train"
VAL_PATH = f"{DATASET_ROOT}/val"
TEST_PATH = f"{DATASET_ROOT}/test"

print(f"\nDataset paths:")
print(f"Train: {TRAIN_PATH}")
print(f"Validation: {VAL_PATH}")
print(f"Test: {TEST_PATH}")

# Enhanced architecture validation with PoolFormer feature calculation.
# Smoke test: build the model once and run a single dummy 384x384 forward pass
# to confirm the head emits [batch, 7] logits before any training starts.
print("\nPoolFormer CASME II architecture validation...")

# Pre-bind the names so the `finally` clean-up can always delete them, even if
# model construction fails part-way.
test_model = test_input = test_output = None
try:
    test_model = PoolFormerCASME2Baseline(num_classes=7, dropout_rate=0.2).to(device)
    # A shape check needs neither dropout nor an autograd graph. Without
    # no_grad, every backbone activation would be kept alive on the device.
    test_model.eval()
    with torch.no_grad():
        test_input = torch.randn(1, 3, 384, 384).to(device)
        test_output = test_model(test_input)

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"PoolFormer {POOLFORMER_MODEL_VARIANT.upper()} with {MODEL_PARAMS} parameters")
    print(f"Token mixing with pooling operations")

except Exception as e:
    # Best-effort check: report the failure and let the notebook continue.
    print(f"Validation failed: {e}")

finally:
    # Release the throw-away model on both the success and failure paths.
    del test_model, test_input, test_output
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Optimized loss function factory with advanced configuration
def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    """
    Build the training criterion from the experiment configuration.

    Args:
        weights (Tensor, sequence or None): Per-class weights for CrossEntropy.
            They are ignored when Focal Loss is selected. Sequences are
            converted to a float32 tensor; ``None`` gives unweighted CE.
        use_focal_loss (bool): Use OptimizedFocalLoss instead of CrossEntropy.
        alpha_weights (list/tuple/Tensor or None): Per-class alpha for Focal
            Loss, expected to sum to ~1.0. ``None`` disables alpha weighting.
        gamma (float): Focal loss focusing parameter.

    Returns:
        nn.Module: an ``OptimizedFocalLoss`` or an ``nn.CrossEntropyLoss``.
    """
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        # Explicit None check: truth-testing a multi-element tensor raises.
        if alpha_weights is not None:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {float(sum(alpha_weights)):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)

    print(f"Using CrossEntropy Loss with optimized class weights")
    # nn.CrossEntropyLoss only accepts a Tensor (or None) as `weight`.
    if weights is not None and not torch.is_tensor(weights):
        weights = torch.tensor(weights, dtype=torch.float32)
    if weights is not None:
        print(f"Class weights: {weights.cpu().numpy()}")
    else:
        print("Class weights: None (unweighted)")
    return nn.CrossEntropyLoss(weight=weights)

# Global configuration for training pipeline - enhanced with optimized features.
# Bundles the runtime objects (device, transforms, factories) with paths and
# metadata, presumably so later cells can read everything from one dict.
GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    'transform_train': poolformer_transform_train,
    'transform_val': poolformer_transform_val,
    'poolformer_config': CASME2_POOLFORMER_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'train_path': TRAIN_PATH,
    'val_path': VAL_PATH,
    'test_path': TEST_PATH,
    'metadata': casme2_metadata,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

# Configuration summary (print-only; nothing is re-validated here).
print("\n" + "=" * 60)
print("CASME II KEY-FRAME POOLFORMER CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: {POOLFORMER_MODEL_NAME}")
print(f"  Variant: {POOLFORMER_MODEL_VARIANT.upper()}")
print(f"  Parameters: {MODEL_PARAMS}")
print(f"  Input Resolution: 384px")
print(f"  Feature Dimension: {CASME2_POOLFORMER_CONFIG['expected_feature_dim']}")

print(f"\nDataset Configuration:")
print(f"  Phase: {CASME2_POOLFORMER_CONFIG['dataset_phase']}")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Frame types: {CASME2_POOLFORMER_CONFIG['frame_types']}")
print(f"  Weight Optimization: {'Per-class Alpha' if USE_FOCAL_LOSS else 'Inverse Sqrt Frequency'}")

print("\nNext: Cell 2 - Dataset Loading and PoolFormer Training Pipeline")

CASME II KEY-FRAME POOLFORMER INFRASTRUCTURE

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II Key-Frame PoolFormer - Infrastructure Configuration
Loading CASME II Phase 2 dataset metadata...
Dataset: CASME2_KeyFrames
Phase: Phase 2
Total images: 765
Frame types: ['onset', 'apex', 'offset']
Expansion strategy: onset_apex_offset_extraction
Using PoolFormer-M48 for enhanced micro-expression recognition (73M parameters)

OPTIMIZED EXPERIMENT CONFIGURATION
Loss Function: Focal Loss
  Gamma: 2.0
  Alpha Weights (per-class): [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum Validation: 0.999
PoolFormer Model: sail/poolformer_m48
Model Parameters: 73M

Device: cuda
GPU: NVIDIA L4 (23.8 GB)
L4: Balanced performance configuration
RAM preload workers: 12 (parallel image loading)

Loading class distribution...
Using class distribution from processing_summary (v2 format)

Train distribution: {'others'

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/283 [00:00<?, ?B/s]

PoolFormer Image Processor configured for 384px with token mixing

Dataset paths:
Train: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split_v2/train
Validation: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split_v2/val
Test: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split_v2/test

PoolFormer CASME II architecture validation...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/294M [00:00<?, ?B/s]

PoolFormer feature dimension: 768
PoolFormer CASME II: 768 -> 512 -> 128 -> 7
Validation successful: Output shape torch.Size([1, 7])
PoolFormer M48 with 73M parameters
Token mixing with pooling operations

CASME II KEY-FRAME POOLFORMER CONFIGURATION COMPLETE
Loss Configuration:
  Function: Optimized Focal Loss
  Gamma: 2.0
  Per-class Alpha: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum: 0.999

Model Configuration:
  Architecture: sail/poolformer_m48
  Variant: M48
  Parameters: 73M
  Input Resolution: 384px
  Feature Dimension: 768

Dataset Configuration:
  Phase: v2
  Classes: 7
  Frame types: ['onset', 'apex', 'offset']
  Weight Optimization: Per-class Alpha

Next: Cell 2 - Dataset Loading and PoolFormer Training Pipeline


In [None]:
# @title Cell 2: CASME II Key-Frame PoolFormer Training Pipeline

# File: 03_03_PoolFormer_CASME2_KFS_Cell2.py
# Location: experiments/03_03_PoolFormer_CASME2-KFS.ipynb
# Purpose: Enhanced training pipeline for CASME II Key-Frame PoolFormer with hardened checkpoint system

# Several of these repeat Cell 1 imports (harmless). This cell still depends on
# Cell 1's globals (CASME2_POOLFORMER_CONFIG, CLASS_TO_IDX, CASME2_CLASSES,
# RAM_PRELOAD_WORKERS, Dataset, Image, torch).
# NOTE(review): multiprocessing, shutil and tempfile are not used in the visible
# part of this cell.
import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import shutil
import tempfile

# Echo the loss / schedule settings inherited from Cell 1's config.
print("CASME II Key-Frame PoolFormer Training Pipeline with Hardened Checkpoint System")
print("=" * 80)
print(f"Loss Function: {'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy Loss'}")
if CASME2_POOLFORMER_CONFIG['use_focal_loss']:
    print(f"Focal Loss Parameters:")
    print(f"  Gamma: {CASME2_POOLFORMER_CONFIG['focal_loss_gamma']}")
    print(f"  Per-class Alpha: {CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']}")
    print(f"  Alpha Sum: {sum(CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']):.3f}")
else:
    print(f"CrossEntropy Parameters:")
    print(f"  Optimized Class Weights: {CASME2_POOLFORMER_CONFIG['crossentropy_class_weights']}")
print(f"Dataset Phase: {CASME2_POOLFORMER_CONFIG['dataset_phase']}")
print(f"Frame Types: {CASME2_POOLFORMER_CONFIG['frame_types']}")
print(f"Training epochs: {CASME2_POOLFORMER_CONFIG['num_epochs']}")
print(f"Scheduler patience: {CASME2_POOLFORMER_CONFIG['scheduler_patience']}")

# Smart metadata normalizer for v1 and v2 compatibility
def normalize_metadata_structure(metadata):
    """
    Return the per-split mapping, whichever metadata layout is used.

    Supported layouts:
        v2: {'splits': {'train': {'samples': [...]}, ...}, ...}
        v1: {'train': {'samples': [...]}, ...}

    If both keys are present, 'splits' wins. The returned object is the
    original dict (or sub-dict), not a copy.

    Raises:
        ValueError: when neither 'splits' nor 'train' is a top-level key.
    """
    if 'splits' not in metadata:
        if 'train' not in metadata:
            raise ValueError("Unknown metadata format: missing both 'splits' and 'train' keys")
        print("Detected v1 metadata format (direct split keys)")
        return metadata

    print("Detected v2 metadata format (with 'splits' key)")
    return metadata['splits']

# Enhanced CASME II Dataset with optimized RAM caching
class CASME2DatasetTraining(Dataset):
    """CASME II dataset with optional parallel RAM caching of decoded images.

    Every image is converted to RGB and resized to 384x384 (LANCZOS) before
    ``transform`` runs. Unreadable files are replaced by a gray 384x384
    placeholder so one bad file does not abort a run; preloading reports
    how many files failed.

    Args:
        split_metadata (dict): Maps a split name to {'samples': [...]}. Each
            sample needs 'image_filename', 'emotion' and 'sample_id'.
        dataset_root (str): Directory with one sub-folder per split.
        transform (callable, optional): Applied to the PIL image.
        split (str): Split to load.
        use_ram_cache (bool): Decode every image into memory at construction.

    Items are ``(image, label_index, sample_id)``.
    """

    def __init__(self, split_metadata, dataset_root, transform=None, split='train', use_ram_cache=True):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for training...")

        # Resolve paths and integer labels once; unknown emotions raise KeyError.
        for sample in self.metadata:
            image_path = os.path.join(dataset_root, split, sample['image_filename'])
            self.images.append(image_path)
            self.labels.append(CLASS_TO_IDX[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        self._print_distribution()

        # RAM caching for training efficiency
        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print per-class sample counts and percentages, in label order."""
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    @staticmethod
    def _load_image(img_path):
        """Load ``img_path`` as a 384x384 RGB PIL image.

        Returns:
            tuple: (image, ok). On any read or decode error, ``image`` is a gray
            384x384 placeholder and ``ok`` is False.
        """
        try:
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            if image.size != (384, 384):
                image = image.resize((384, 384), Image.Resampling.LANCZOS)
            return image, True
        except Exception:
            return Image.new('RGB', (384, 384), (128, 128, 128)), False

    def _preload_to_ram(self):
        """Decode all images on a thread pool and keep them in ``self.cached_images``."""
        print(f"Preloading {len(self.images)} {self.split} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        failed_paths = []

        # Threads suffice: loading from Drive is I/O-bound.
        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(self._load_image, path) for path in self.images]

            # Futures are consumed in submission order, so idx matches self.images.
            for idx, future in enumerate(tqdm(futures, desc=f"Loading {self.split} to RAM")):
                image, ok = future.result()
                self.cached_images[idx] = image
                if not ok:
                    failed_paths.append(self.images[idx])

        valid_images = len(self.images) - len(failed_paths)
        # Cached images are 8-bit RGB PIL images: 3 bytes per pixel
        # (the old estimate assumed 4-byte floats and overstated usage 4x).
        ram_usage_gb = len(self.cached_images) * 384 * 384 * 3 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")
        if failed_paths:
            print(f"Warning: {len(failed_paths)} {self.split} image(s) could not be read and were "
                  f"replaced with gray placeholders (first: {failed_paths[0]})")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label_index, sample_id)`` for sample ``idx``."""
        if self.use_ram_cache and self.cached_images[idx] is not None:
            # Hand out a copy so the cached original is never modified.
            image = self.cached_images[idx].copy()
        else:
            image, _ = self._load_image(self.images[idx])

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.sample_ids[idx]

# Enhanced metrics calculation with comprehensive error handling
def calculate_metrics_safe_robust(outputs, labels, class_names, average='macro'):
    """Compute accuracy, precision, recall and F1 without ever raising.

    Args:
        outputs: Either logits of shape [N, num_classes] (Tensor or array-like,
            arg-maxed over the class axis) or already-decoded predictions of
            shape [N].
        labels: Ground-truth class indices of shape [N] (Tensor or array-like).
        class_names (sequence): Only its length is used. It pins the label set
            to 0..len-1, so classes absent from this batch still count in the
            averages.
        average (str): Averaging mode forwarded to sklearn's
            precision_recall_fscore_support.

    Returns:
        dict: float 'accuracy', 'precision', 'recall' and 'f1_score'. On any
        error a warning is printed and all four values are 0.0.
    """
    try:
        # Decode predictions before the size check so array-like inputs work.
        # The old Tensor-only .size(0) check made those branches unreachable.
        if isinstance(outputs, torch.Tensor):
            outputs = outputs.detach()
            predictions = outputs.argmax(dim=1) if outputs.dim() == 2 else outputs
            predictions = predictions.cpu().numpy()
        else:
            predictions = np.asarray(outputs)
            if predictions.ndim == 2:
                predictions = predictions.argmax(axis=1)

        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()
        else:
            labels = np.asarray(labels)

        if len(predictions) != len(labels):
            raise ValueError(f"Batch size mismatch: outputs {len(predictions)} vs labels {len(labels)}")

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0,
            labels=list(range(len(class_names)))
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }
    except Exception as e:
        # Deliberate best-effort: a metrics failure must not abort training.
        print(f"Warning: Enhanced metrics calculation error: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

# Enhanced training epoch function with robust model output validation
def train_epoch(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    """Train ``model`` for one epoch and return ``(avg_loss, metrics)``.

    Args:
        model: Network being trained. Its forward output may be a tensor, a
            tuple/list whose first element is the logits, or a dict holding a
            ``'logits'`` or ``'prediction'`` entry.
        dataloader: Iterable of ``(images, labels, sample_ids)`` batches.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        optimizer: Stepped once per batch, after gradient-norm clipping to
            ``CASME2_POOLFORMER_CONFIG['gradient_clip']``.
        device: Device the image and label tensors are moved to.
        epoch: Zero-based epoch index (used only for the progress-bar label).
        total_epochs: Total number of epochs (used only for the progress-bar label).

    Returns:
        avg_loss: Mean of the per-batch loss values over the epoch.
        metrics: Dict from ``calculate_metrics_safe_robust`` over the whole epoch,
            or all-zero metrics if that computation fails.

    Raises:
        ValueError: If the logits are not shaped ``[batch, len(CASME2_CLASSES)]``,
            if a dict output has neither ``'logits'`` nor ``'prediction'``, or if
            the dataloader yields no batches.
    """
    model.train()
    num_classes = len(CASME2_CLASSES)
    total_batches = len(dataloader)
    running_loss = 0.0
    num_batches = 0
    all_outputs = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc=f"CASME II Training Epoch {epoch+1}/{total_epochs}")

    for batch_idx, (images, labels, sample_ids) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Accept the output formats a wrapped backbone may produce
        model_output = model(images)
        if isinstance(model_output, (tuple, list)):
            outputs = model_output[0]
        elif isinstance(model_output, dict):
            outputs = model_output.get('logits', model_output.get('prediction'))
            if outputs is None:
                raise ValueError(
                    f"Model output dict has no 'logits' or 'prediction' entry: {list(model_output.keys())}"
                )
        else:
            outputs = model_output

        # Validate output shape against the CASME II class list
        if outputs.dim() != 2 or outputs.size(1) != num_classes:
            raise ValueError(
                f"Invalid CASME II output shape: {outputs.shape}, expected [batch_size, {num_classes}]"
            )

        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), CASME2_POOLFORMER_CONFIG['gradient_clip'])

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

        # Memory optimized: move to CPU before accumulating
        all_outputs.append(outputs.detach().cpu())
        all_labels.append(labels.detach().cpu())

        # Refresh every 5 batches AND on the last batch, so the final value shown
        # on the bar equals the epoch loss this function returns.
        if batch_idx % 5 == 0 or batch_idx == total_batches - 1:
            avg_loss = running_loss / num_batches
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'Loss': f'{avg_loss:.4f}',
                'LR': f'{current_lr:.2e}'
            })

    if num_batches == 0:
        raise ValueError("Training dataloader yielded no batches")

    # Enhanced metrics calculation with error recovery
    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Training metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / num_batches
    return avg_loss, metrics

# Enhanced validation epoch function with robust model output validation
def validate_epoch(model, dataloader, criterion, device, epoch, total_epochs):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Args:
        model: Network under evaluation (switched to ``eval()`` mode). Its output
            may be a tensor, a tuple/list whose first element is the logits, or a
            dict holding a ``'logits'`` or ``'prediction'`` entry.
        dataloader: Iterable of ``(images, labels, sample_ids)`` batches.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device the image and label tensors are moved to.
        epoch: Zero-based epoch index (used only for the progress-bar label).
        total_epochs: Total number of epochs (used only for the progress-bar label).

    Returns:
        avg_loss: Mean of the per-batch loss values.
        metrics: Dict from ``calculate_metrics_safe_robust`` over all batches, or
            all-zero metrics if that computation fails.
        all_sample_ids: Sample ids in the order the dataloader produced them.

    Raises:
        ValueError: If the logits are not shaped ``[batch, len(CASME2_CLASSES)]``,
            if a dict output has neither ``'logits'`` nor ``'prediction'``, or if
            the dataloader yields no batches.
    """
    model.eval()
    num_classes = len(CASME2_CLASSES)
    total_batches = len(dataloader)
    running_loss = 0.0
    num_batches = 0
    all_outputs = []
    all_labels = []
    all_sample_ids = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc=f"CASME II Validation Epoch {epoch+1}/{total_epochs}")

        for batch_idx, (images, labels, sample_ids) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)

            # Accept the output formats a wrapped backbone may produce
            model_output = model(images)
            if isinstance(model_output, (tuple, list)):
                outputs = model_output[0]
            elif isinstance(model_output, dict):
                outputs = model_output.get('logits', model_output.get('prediction'))
                if outputs is None:
                    raise ValueError(
                        f"Model output dict has no 'logits' or 'prediction' entry: {list(model_output.keys())}"
                    )
            else:
                outputs = model_output

            # Same shape contract as training, so a bad head fails loudly here too
            if outputs.dim() != 2 or outputs.size(1) != num_classes:
                raise ValueError(
                    f"Invalid CASME II output shape: {outputs.shape}, expected [batch_size, {num_classes}]"
                )

            loss = criterion(outputs, labels)
            running_loss += loss.item()
            num_batches += 1

            # Memory optimized: move to CPU before accumulating
            all_outputs.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_sample_ids.extend(sample_ids)

            # Refresh every 3 batches AND on the last batch; otherwise the bar keeps
            # showing a partial average that differs from the returned loss.
            if batch_idx % 3 == 0 or batch_idx == total_batches - 1:
                avg_loss = running_loss / num_batches
                progress_bar.set_postfix({'Val Loss': f'{avg_loss:.4f}'})

    if num_batches == 0:
        raise ValueError("Validation dataloader yielded no batches")

    # Enhanced metrics calculation with error recovery
    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Validation metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / num_batches
    return avg_loss, metrics, all_sample_ids

# Hardened checkpoint saving function with atomic write and validation
def save_checkpoint_robust(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                         checkpoint_dir, best_metrics, config, max_retries=3):
    """
    Save the best-F1 checkpoint using write -> validate -> replace, with retries.

    Each attempt:
    1. Creates ``checkpoint_dir`` if needed and writes the checkpoint to a
       temporary file inside it.
    2. Reloads the temporary file with ``map_location='cpu'`` and checks that the
       required keys exist and the stored epoch matches ``epoch``.
    3. Moves it over ``casme2_poolformer_keyframe_best_f1.pth`` with ``os.replace``.
       Source and target share a directory, so this is an atomic rename on POSIX
       filesystems (FUSE/network mounts such as Google Drive may not guarantee
       this).
    On failure the temporary file is removed and the attempt is retried after
    ``2 ** attempt`` seconds (1s, 2s, ...).

    Metric and config entries are converted to plain Python values, with tensors
    moved to CPU first. Model, optimizer and scheduler ``state_dict()``s are
    stored as returned, so pass ``map_location`` when loading on a different
    device.

    Args:
        model, optimizer, scheduler: Objects whose ``state_dict()`` is saved;
            ``scheduler`` may be None.
        epoch: Zero-based epoch index stored in the checkpoint.
        train_metrics, val_metrics: Metric dicts for this epoch.
        checkpoint_dir: Directory that receives the checkpoint.
        best_metrics: Dict with ``'f1'``, ``'loss'`` and ``'accuracy'`` describing
            the model being saved (stored and printed).
        config: Experiment configuration dict stored alongside the weights.
        max_retries: Number of save attempts.

    Returns:
        Path of the saved checkpoint, or None if every attempt failed.
    """
    import tempfile  # local import keeps this helper self-contained

    # Convert all tensors to CPU and serializable format
    def make_serializable_cpu(obj):
        if isinstance(obj, torch.Tensor):
            # Force CPU conversion for all tensors
            cpu_obj = obj.detach().cpu()
            return cpu_obj.item() if cpu_obj.numel() == 1 else cpu_obj.tolist()
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, dict):
            return {k: make_serializable_cpu(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_serializable_cpu(item) for item in obj]
        else:
            return obj

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_metrics': make_serializable_cpu(train_metrics),
        'val_metrics': make_serializable_cpu(val_metrics),
        'casme2_config': make_serializable_cpu(config),
        'best_f1': float(best_metrics['f1']),
        'best_loss': float(best_metrics['loss']),
        'best_acc': float(best_metrics['accuracy']),
        'class_names': CASME2_CLASSES,
        'num_classes': 7
    }

    final_path = os.path.join(checkpoint_dir, "casme2_poolformer_keyframe_best_f1.pth")

    for attempt in range(max_retries):
        # Reset per attempt: if mkstemp itself fails there is no temp file, and the
        # cleanup below must not reference an unbound or stale path.
        temp_path = None
        try:
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Step 1: Save to temporary file
            temp_fd, temp_path = tempfile.mkstemp(dir=checkpoint_dir, suffix='.pth.tmp')
            os.close(temp_fd)

            print(f"Attempt {attempt + 1}: Saving checkpoint to temporary file...")
            torch.save(checkpoint, temp_path)

            # Step 2: Validate checkpoint by loading it back
            print("Validating checkpoint integrity...")
            validation_checkpoint = torch.load(temp_path, map_location='cpu')

            required_keys = ['model_state_dict', 'epoch', 'best_f1', 'num_classes']
            for key in required_keys:
                if key not in validation_checkpoint:
                    raise ValueError(f"Checkpoint validation failed: missing key '{key}'")

            if validation_checkpoint['epoch'] != epoch:
                raise ValueError(f"Checkpoint epoch mismatch: saved {epoch}, loaded {validation_checkpoint['epoch']}")

            # Drop the reloaded copy of every weight as soon as it has been checked
            del validation_checkpoint
            print("Checkpoint validation passed")

            # Step 3: Atomic rename (overwrite existing checkpoint)
            print("Moving validated checkpoint to final location...")
            os.replace(temp_path, final_path)
            temp_path = None  # the file now lives at final_path; nothing to clean up

            print(f"Checkpoint saved and validated successfully: {os.path.basename(final_path)}")
            print(f"  Epoch: {epoch + 1}")
            print(f"  Val F1: {best_metrics['f1']:.4f}")
            print(f"  Val Loss: {best_metrics['loss']:.4f}")
            print(f"  Val Acc: {best_metrics['accuracy']:.4f}")

            return final_path

        except Exception as e:
            print(f"Checkpoint save attempt {attempt + 1}/{max_retries} failed: {e}")

            # Clean up the temporary file if this attempt created one
            if temp_path is not None and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError as cleanup_error:
                    print(f"Warning: could not remove temporary file {temp_path}: {cleanup_error}")

            if attempt < max_retries - 1:
                wait_time = 2 ** attempt  # exponential backoff: 1s, 2s, ...
                print(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                print(f"All {max_retries} checkpoint save attempts failed")
                return None

    return None

# Safe JSON serialization function
def safe_json_serialize(obj):
    """Recursively convert ``obj`` into values that ``json.dump`` can encode.

    Conversion rules:
    - torch tensors: single-element -> Python scalar via ``item()``, otherwise a
      nested list. Tensors are detached first, so ones that require grad work too.
    - numpy arrays -> nested lists; numpy bool / integer / floating scalars ->
      ``bool`` / ``int`` / ``float``.
    - ``None``, ``bool``, ``int``, ``float`` and ``str`` are returned unchanged, so
      the JSON keeps ``null``, ``true``/``false`` and integers.
    - dicts are converted value-by-value (keys untouched); lists and tuples
      become lists.
    - any other object with a ``__dict__`` is serialized through that dict;
      everything else falls back to ``str(obj)``.
    """
    # numpy/torch first: np.float64 subclasses float and should become a plain float
    if isinstance(obj, torch.Tensor):
        cpu_tensor = obj.detach().cpu()
        return cpu_tensor.item() if cpu_tensor.numel() == 1 else cpu_tensor.tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    # JSON-native values pass through (previously None -> "None", True -> 1.0, 16 -> 16.0)
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {k: safe_json_serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [safe_json_serialize(item) for item in obj]
    if hasattr(obj, '__dict__'):
        return safe_json_serialize(obj.__dict__)
    return str(obj)

# Create enhanced datasets with normalized metadata
print("\nCreating CASME II Key-Frame training datasets...")

# Normalize metadata structure for v1/v2 compatibility
normalized_metadata = normalize_metadata_structure(GLOBAL_CONFIG_CASME2['metadata'])

# NOTE(review): str.replace strips EVERY '/train' ('/val') substring, not just the
# trailing path component; this assumes the project path contains no other match.
train_dataset = CASME2DatasetTraining(
    split_metadata=normalized_metadata,
    dataset_root=GLOBAL_CONFIG_CASME2['train_path'].replace('/train', ''),
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    split='train',
    use_ram_cache=True
)

val_dataset = CASME2DatasetTraining(
    split_metadata=normalized_metadata,
    dataset_root=GLOBAL_CONFIG_CASME2['val_path'].replace('/val', ''),
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    split='val',
    use_ram_cache=True
)

# Shared DataLoader settings. prefetch_factor is only valid with worker
# processes (PyTorch >= 2.0 raises ValueError if it is given with
# num_workers=0), and pinned host memory only helps when CUDA is present.
casme2_loader_kwargs = {
    'batch_size': CASME2_POOLFORMER_CONFIG['batch_size'],
    'num_workers': CASME2_POOLFORMER_CONFIG['num_workers'],
    'pin_memory': torch.cuda.is_available(),
}
if casme2_loader_kwargs['num_workers'] > 0:
    casme2_loader_kwargs['prefetch_factor'] = 2

# Create data loaders (only the training split is shuffled)
train_loader = DataLoader(train_dataset, shuffle=True, **casme2_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **casme2_loader_kwargs)

print(f"Training batches: {len(train_loader)} (samples: {len(train_dataset)})")
print(f"Validation batches: {len(val_loader)} (samples: {len(val_dataset)})")

# Initialize model, criterion, optimizer, scheduler
print("\nInitializing CASME II Key-Frame PoolFormer model...")
model = PoolFormerCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_POOLFORMER_CONFIG['dropout_rate'],
).to(GLOBAL_CONFIG_CASME2['device'])

# Single call into the configurable criterion factory. With focal loss enabled,
# alpha and gamma come from the config; otherwise alpha is None and gamma is
# passed as 2.0.
_focal_enabled = CASME2_POOLFORMER_CONFIG['use_focal_loss']
criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
    weights=GLOBAL_CONFIG_CASME2['class_weights'],
    use_focal_loss=bool(_focal_enabled),
    alpha_weights=CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights'] if _focal_enabled else None,
    gamma=CASME2_POOLFORMER_CONFIG['focal_loss_gamma'] if _focal_enabled else 2.0,
)

optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_POOLFORMER_CONFIG
)

# NOTE(review): the optimizer/scheduler names below are literals, not read back
# from the factory's return values.
print(f"Optimizer: AdamW (LR={CASME2_POOLFORMER_CONFIG['learning_rate']})")
print(f"Scheduler: ReduceLROnPlateau (patience={CASME2_POOLFORMER_CONFIG['scheduler_patience']})")
print(f"Criterion: {'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy'}")

# Per-epoch training history: one list per tracked quantity, appended once per epoch
training_history = {
    history_key: []
    for history_key in (
        'train_loss', 'val_loss',
        'train_f1', 'val_f1',
        'train_acc', 'val_acc',
        'learning_rate', 'epoch_time',
    )
}

# Best-so-far validation record for the multi-criteria checkpoint rule
# (higher F1 first, then lower loss, then higher accuracy)
best_metrics = dict(f1=0.0, loss=float('inf'), accuracy=0.0, epoch=0)

print("\nStarting CASME II Key-Frame PoolFormer training...")
print(f"Training configuration: {CASME2_POOLFORMER_CONFIG['num_epochs']} epochs")
print("=" * 80)

# Main training loop with hardened checkpoint system
start_time = time.time()

for epoch in range(CASME2_POOLFORMER_CONFIG['num_epochs']):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{CASME2_POOLFORMER_CONFIG['num_epochs']}")

    # Training phase
    train_loss, train_metrics = train_epoch(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_POOLFORMER_CONFIG['num_epochs']
    )

    # Validation phase
    val_loss, val_metrics, val_sample_ids = validate_epoch(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_POOLFORMER_CONFIG['num_epochs']
    )

    # Step the scheduler on validation macro-F1 (higher is better).
    # NOTE(review): assumes the factory builds the scheduler in 'max' mode — confirm.
    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    # Record training history
    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    # Print epoch summary
    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    # Multi-criteria evaluation hierarchy: F1, then loss, then accuracy
    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        candidate_metrics = {
            'f1': val_metrics['f1_score'],
            'loss': val_loss,
            'accuracy': val_metrics['accuracy'],
            'epoch': epoch + 1
        }

        best_model_path = save_checkpoint_robust(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            candidate_metrics, CASME2_POOLFORMER_CONFIG
        )

        if best_model_path:
            # Advance the best-so-far record only once the checkpoint on disk holds
            # this model, so the reported best F1 always matches a saved model.
            best_metrics.update(candidate_metrics)
            print(f"New best model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print("Warning: Failed to save checkpoint despite improvement")

    # Progress tracking
    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_POOLFORMER_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_POOLFORMER_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min")

# Training completion
total_time = time.time() - start_time
# The loop above has no early exit, so every configured epoch was run
actual_epochs = CASME2_POOLFORMER_CONFIG['num_epochs']

print(f"\n{'=' * 80}")
print("CASME II KEY-FRAME POOLFORMER TRAINING COMPLETED")
print('=' * 80)
print(f"Training time: {total_time / 60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print("Final train F1: {:.4f}".format(training_history['train_f1'][-1]))
print("Final validation F1: {:.4f}".format(training_history['val_f1'][-1]))

# Enhanced training documentation export
results_dir = GLOBAL_CONFIG_CASME2['results_root']
os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

training_history_path = f"{results_dir}/training_logs/casme2_poolformer_keyframe_training_history.json"

print("\nExporting enhanced training documentation...")

try:
    # Create comprehensive training summary with enhanced experiment configuration
    training_summary = {
        'experiment_type': 'CASME2_PoolFormer_KeyFrame_Baseline',
        'experiment_configuration': {
            'dataset_phase': CASME2_POOLFORMER_CONFIG['dataset_phase'],
            'frame_types': CASME2_POOLFORMER_CONFIG['frame_types'],
            'loss_function': 'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy',
            'weight_approach': 'Per-class Alpha (sum=1.0)' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'Inverse Sqrt Frequency',
            'focal_loss_gamma': CASME2_POOLFORMER_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights'],
            'crossentropy_class_weights': CASME2_POOLFORMER_CONFIG['crossentropy_class_weights'],
            'poolformer_model': CASME2_POOLFORMER_CONFIG['poolformer_model'],
            'model_variant': CASME2_POOLFORMER_CONFIG['model_variant'],
            'model_params': CASME2_POOLFORMER_CONFIG['model_params']
        },
        'training_history': safe_json_serialize(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'config': safe_json_serialize(CASME2_POOLFORMER_CONFIG),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        'model_checkpoint': 'casme2_poolformer_keyframe_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            'phase': CASME2_POOLFORMER_CONFIG['dataset_phase'],
            'frame_types': CASME2_POOLFORMER_CONFIG['frame_types'],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES
        },
        'architecture_info': {
            'model_type': 'PoolFormerCASME2Baseline',
            'backbone': CASME2_POOLFORMER_CONFIG['poolformer_model'],
            'variant': CASME2_POOLFORMER_CONFIG['model_variant'],
            'parameters': CASME2_POOLFORMER_CONFIG['model_params'],
            'input_size': f"{CASME2_POOLFORMER_CONFIG['input_size']}x{CASME2_POOLFORMER_CONFIG['input_size']}",
            'feature_dimension': CASME2_POOLFORMER_CONFIG['expected_feature_dim'],
            'classification_head': f"{CASME2_POOLFORMER_CONFIG['expected_feature_dim']}->512->128->7"
        },
        'enhanced_features': {
            'hardened_checkpoint_system': True,
            'atomic_checkpoint_save': True,
            'checkpoint_validation': True,
            'model_output_validation': True,
            'enhanced_error_handling': True,
            'multi_criteria_checkpoint_logic': True,
            'memory_optimized_training': True,
            'retry_with_backoff': True,
            'token_mixing_via_pooling': True
        }
    }

    # Encode fully before opening the file, so a failure cannot leave a truncated
    # JSON behind. Raw config entries above (e.g. alpha weights) may hold tensors
    # or numpy values, which json cannot encode natively; `default` hands those
    # to safe_json_serialize.
    summary_json = json.dumps(training_summary, indent=2, default=safe_json_serialize)
    with open(training_history_path, 'w', encoding='utf-8') as f:
        f.write(summary_json)

    print(f"Enhanced training documentation saved successfully: {training_history_path}")
    print(f"Experiment details: {training_summary['experiment_configuration']['loss_function']} loss")
    if CASME2_POOLFORMER_CONFIG['use_focal_loss']:
        print(f"  Gamma: {CASME2_POOLFORMER_CONFIG['focal_loss_gamma']}, Alpha Sum: {sum(CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']):.3f}")
    print(f"Model variant: {CASME2_POOLFORMER_CONFIG['poolformer_model']}")
    print(f"Dataset phase: {CASME2_POOLFORMER_CONFIG['dataset_phase']}")

except Exception as e:
    print(f"Warning: Could not save training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

# Release cached GPU memory, waiting first for any queued CUDA work to finish
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

for _closing_message in (
    "\nNext: Cell 3 - CASME II Key-Frame PoolFormer Evaluation with Test Version Selection",
    "Enhanced training pipeline with hardened checkpoint system completed successfully!",
):
    print(_closing_message)

CASME II Key-Frame PoolFormer Training Pipeline with Hardened Checkpoint System
Loss Function: Optimized Focal Loss
Focal Loss Parameters:
  Gamma: 2.0
  Per-class Alpha: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum: 0.999
Dataset Phase: v2
Frame Types: ['onset', 'apex', 'offset']
Training epochs: 50
Scheduler patience: 3

Creating CASME II Key-Frame training datasets...
Detected v2 metadata format (with 'splits' key)
Loading CASME II train dataset for training...
Loaded 603 CASME II train samples
  others: 237 samples (39.3%)
  disgust: 150 samples (24.9%)
  happiness: 75 samples (12.4%)
  repression: 63 samples (10.4%)
  surprise: 60 samples (10.0%)
  sadness: 15 samples (2.5%)
  fear: 3 samples (0.5%)
Preloading 603 train images to RAM with 12 workers...


Loading train to RAM: 100%|██████████| 603/603 [00:23<00:00, 25.19it/s]


TRAIN RAM caching completed: 603/603 images, ~1.07GB
Loading CASME II val dataset for training...
Loaded 78 CASME II val samples
  others: 30 samples (38.5%)
  disgust: 18 samples (23.1%)
  happiness: 9 samples (11.5%)
  repression: 9 samples (11.5%)
  surprise: 6 samples (7.7%)
  sadness: 3 samples (3.8%)
  fear: 3 samples (3.8%)
Preloading 78 val images to RAM with 12 workers...


Loading val to RAM: 100%|██████████| 78/78 [00:03<00:00, 21.69it/s]


VAL RAM caching completed: 78/78 images, ~0.14GB
Training batches: 38 (samples: 603)
Validation batches: 5 (samples: 78)

Initializing CASME II Key-Frame PoolFormer model...
PoolFormer feature dimension: 768
PoolFormer CASME II: 768 -> 512 -> 128 -> 7
Using Optimized Focal Loss with gamma=2.0
Per-class alpha weights: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
Alpha sum: 0.999
Scheduler: ReduceLROnPlateau monitoring val_f1_macro
Optimizer: AdamW (LR=1e-05)
Scheduler: ReduceLROnPlateau (patience=3)
Criterion: Optimized Focal Loss

Starting CASME II Key-Frame PoolFormer training...
Training configuration: 50 epochs

Epoch 1/50


CASME II Training Epoch 1/50: 100%|██████████| 38/38 [00:34<00:00,  1.09it/s, Loss=0.0994, LR=1.00e-05]
CASME II Validation Epoch 1/50: 100%|██████████| 5/5 [00:02<00:00,  1.80it/s, Val Loss=0.0673]


Train - Loss: 0.1003, F1: 0.1524, Acc: 0.2687
Val   - Loss: 0.1360, F1: 0.0774, Acc: 0.3718
Time  - Epoch: 37.8s, LR: 1.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_keyframe_best_f1.pth
  Epoch: 1
  Val F1: 0.0774
  Val Loss: 0.1360
  Val Acc: 0.3718
New best model: Higher F1 - F1: 0.0774
Progress: 2.0% | Best F1: 0.0774 | ETA: 33.3min

Epoch 2/50


CASME II Training Epoch 2/50: 100%|██████████| 38/38 [00:28<00:00,  1.33it/s, Loss=0.0835, LR=1.00e-05]
CASME II Validation Epoch 2/50: 100%|██████████| 5/5 [00:01<00:00,  2.74it/s, Val Loss=0.0591]


Train - Loss: 0.0834, F1: 0.2167, Acc: 0.4129
Val   - Loss: 0.1429, F1: 0.1571, Acc: 0.2692
Time  - Epoch: 30.5s, LR: 1.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_keyframe_best_f1.pth
  Epoch: 2
  Val F1: 0.1571
  Val Loss: 0.1429
  Val Acc: 0.2692
New best model: Higher F1 - F1: 0.1571
Progress: 4.0% | Best F1: 0.1571 | ETA: 29.6min

Epoch 3/50


CASME II Training Epoch 3/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0673, LR=1.00e-05]
CASME II Validation Epoch 3/50: 100%|██████████| 5/5 [00:01<00:00,  2.65it/s, Val Loss=0.0593]


Train - Loss: 0.0674, F1: 0.4190, Acc: 0.5357
Val   - Loss: 0.1482, F1: 0.2344, Acc: 0.3205
Time  - Epoch: 30.9s, LR: 1.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_keyframe_best_f1.pth
  Epoch: 3
  Val F1: 0.2344
  Val Loss: 0.1482
  Val Acc: 0.3205
New best model: Higher F1 - F1: 0.2344
Progress: 6.0% | Best F1: 0.2344 | ETA: 28.1min

Epoch 4/50


CASME II Training Epoch 4/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0556, LR=1.00e-05]
CASME II Validation Epoch 4/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0514]


Train - Loss: 0.0559, F1: 0.5526, Acc: 0.6003
Val   - Loss: 0.1493, F1: 0.2091, Acc: 0.3718
Time  - Epoch: 30.8s, LR: 1.00e-05
Progress: 8.0% | Best F1: 0.2344 | ETA: 26.6min

Epoch 5/50


CASME II Training Epoch 5/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0460, LR=1.00e-05]
CASME II Validation Epoch 5/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0538]


Train - Loss: 0.0457, F1: 0.6031, Acc: 0.6667
Val   - Loss: 0.1540, F1: 0.1959, Acc: 0.3333
Time  - Epoch: 30.8s, LR: 1.00e-05
Progress: 10.0% | Best F1: 0.2344 | ETA: 25.4min

Epoch 6/50


CASME II Training Epoch 6/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0372, LR=1.00e-05]
CASME II Validation Epoch 6/50: 100%|██████████| 5/5 [00:01<00:00,  2.67it/s, Val Loss=0.0532]


Train - Loss: 0.0372, F1: 0.6827, Acc: 0.7297
Val   - Loss: 0.1552, F1: 0.2316, Acc: 0.3590
Time  - Epoch: 30.8s, LR: 1.00e-05
Progress: 12.0% | Best F1: 0.2344 | ETA: 24.5min

Epoch 7/50


CASME II Training Epoch 7/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0313, LR=1.00e-05]
CASME II Validation Epoch 7/50: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s, Val Loss=0.0551]


Train - Loss: 0.0314, F1: 0.6686, Acc: 0.8043
Val   - Loss: 0.1579, F1: 0.2089, Acc: 0.3590
Time  - Epoch: 30.8s, LR: 5.00e-06
Progress: 14.0% | Best F1: 0.2344 | ETA: 23.7min

Epoch 8/50


CASME II Training Epoch 8/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0256, LR=5.00e-06]
CASME II Validation Epoch 8/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0535]


Train - Loss: 0.0256, F1: 0.8471, Acc: 0.8574
Val   - Loss: 0.1597, F1: 0.2148, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 5.00e-06
Progress: 16.0% | Best F1: 0.2344 | ETA: 22.9min

Epoch 9/50


CASME II Training Epoch 9/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0227, LR=5.00e-06]
CASME II Validation Epoch 9/50: 100%|██████████| 5/5 [00:01<00:00,  2.72it/s, Val Loss=0.0591]


Train - Loss: 0.0227, F1: 0.8809, Acc: 0.8955
Val   - Loss: 0.1643, F1: 0.1854, Acc: 0.2949
Time  - Epoch: 30.8s, LR: 5.00e-06
Progress: 18.0% | Best F1: 0.2344 | ETA: 22.2min

Epoch 10/50


CASME II Training Epoch 10/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0197, LR=5.00e-06]
CASME II Validation Epoch 10/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0565]


Train - Loss: 0.0195, F1: 0.8760, Acc: 0.9022
Val   - Loss: 0.1636, F1: 0.2071, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 5.00e-06
Progress: 20.0% | Best F1: 0.2344 | ETA: 21.6min

Epoch 11/50


CASME II Training Epoch 11/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0183, LR=5.00e-06]
CASME II Validation Epoch 11/50: 100%|██████████| 5/5 [00:01<00:00,  2.73it/s, Val Loss=0.0574]


Train - Loss: 0.0180, F1: 0.9282, Acc: 0.9254
Val   - Loss: 0.1654, F1: 0.1654, Acc: 0.2821
Time  - Epoch: 30.8s, LR: 2.50e-06
Progress: 22.0% | Best F1: 0.2344 | ETA: 20.9min

Epoch 12/50


CASME II Training Epoch 12/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0154, LR=2.50e-06]
CASME II Validation Epoch 12/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0591]


Train - Loss: 0.0155, F1: 0.9543, Acc: 0.9453
Val   - Loss: 0.1676, F1: 0.1468, Acc: 0.2821
Time  - Epoch: 30.8s, LR: 2.50e-06
Progress: 24.0% | Best F1: 0.2344 | ETA: 20.3min

Epoch 13/50


CASME II Training Epoch 13/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0147, LR=2.50e-06]
CASME II Validation Epoch 13/50: 100%|██████████| 5/5 [00:01<00:00,  2.74it/s, Val Loss=0.0587]


Train - Loss: 0.0143, F1: 0.9703, Acc: 0.9635
Val   - Loss: 0.1664, F1: 0.1940, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 2.50e-06
Progress: 26.0% | Best F1: 0.2344 | ETA: 19.7min

Epoch 14/50


CASME II Training Epoch 14/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0137, LR=2.50e-06]
CASME II Validation Epoch 14/50: 100%|██████████| 5/5 [00:01<00:00,  2.74it/s, Val Loss=0.0602]


Train - Loss: 0.0135, F1: 0.9716, Acc: 0.9585
Val   - Loss: 0.1685, F1: 0.1986, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 2.50e-06
Progress: 28.0% | Best F1: 0.2344 | ETA: 19.1min

Epoch 15/50


CASME II Training Epoch 15/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0131, LR=2.50e-06]
CASME II Validation Epoch 15/50: 100%|██████████| 5/5 [00:01<00:00,  2.73it/s, Val Loss=0.0583]


Train - Loss: 0.0131, F1: 0.9728, Acc: 0.9652
Val   - Loss: 0.1672, F1: 0.1933, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 1.25e-06
Progress: 30.0% | Best F1: 0.2344 | ETA: 18.6min

Epoch 16/50


CASME II Training Epoch 16/50: 100%|██████████| 38/38 [00:28<00:00,  1.32it/s, Loss=0.0122, LR=1.25e-06]
CASME II Validation Epoch 16/50: 100%|██████████| 5/5 [00:01<00:00,  2.65it/s, Val Loss=0.0588]


Train - Loss: 0.0121, F1: 0.9833, Acc: 0.9768
Val   - Loss: 0.1678, F1: 0.1849, Acc: 0.2949
Time  - Epoch: 30.8s, LR: 1.25e-06
Progress: 32.0% | Best F1: 0.2344 | ETA: 18.0min

Epoch 17/50


CASME II Training Epoch 17/50: 100%|██████████| 38/38 [00:28<00:00,  1.32it/s, Loss=0.0112, LR=1.25e-06]
CASME II Validation Epoch 17/50: 100%|██████████| 5/5 [00:01<00:00,  2.73it/s, Val Loss=0.0592]


Train - Loss: 0.0111, F1: 0.9851, Acc: 0.9751
Val   - Loss: 0.1685, F1: 0.1646, Acc: 0.2821
Time  - Epoch: 30.6s, LR: 1.25e-06
Progress: 34.0% | Best F1: 0.2344 | ETA: 17.4min

Epoch 18/50


CASME II Training Epoch 18/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0110, LR=1.25e-06]
CASME II Validation Epoch 18/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0595]


Train - Loss: 0.0109, F1: 0.9816, Acc: 0.9735
Val   - Loss: 0.1688, F1: 0.2220, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.25e-06
Progress: 36.0% | Best F1: 0.2344 | ETA: 16.9min

Epoch 19/50


CASME II Training Epoch 19/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0106, LR=1.25e-06]
CASME II Validation Epoch 19/50: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s, Val Loss=0.0593]


Train - Loss: 0.0106, F1: 0.9846, Acc: 0.9801
Val   - Loss: 0.1684, F1: 0.1684, Acc: 0.2821
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 38.0% | Best F1: 0.2344 | ETA: 16.3min

Epoch 20/50


CASME II Training Epoch 20/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0100, LR=1.00e-06]
CASME II Validation Epoch 20/50: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s, Val Loss=0.0602]


Train - Loss: 0.0099, F1: 0.9853, Acc: 0.9801
Val   - Loss: 0.1692, F1: 0.2220, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 40.0% | Best F1: 0.2344 | ETA: 15.8min

Epoch 21/50


CASME II Training Epoch 21/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0097, LR=1.00e-06]
CASME II Validation Epoch 21/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0614]


Train - Loss: 0.0100, F1: 0.9826, Acc: 0.9751
Val   - Loss: 0.1702, F1: 0.1977, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 42.0% | Best F1: 0.2344 | ETA: 15.2min

Epoch 22/50


CASME II Training Epoch 22/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0097, LR=1.00e-06]
CASME II Validation Epoch 22/50: 100%|██████████| 5/5 [00:01<00:00,  2.62it/s, Val Loss=0.0606]


Train - Loss: 0.0097, F1: 0.9905, Acc: 0.9851
Val   - Loss: 0.1699, F1: 0.1849, Acc: 0.2949
Time  - Epoch: 30.9s, LR: 1.00e-06
Progress: 44.0% | Best F1: 0.2344 | ETA: 14.7min

Epoch 23/50


CASME II Training Epoch 23/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0093, LR=1.00e-06]
CASME II Validation Epoch 23/50: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s, Val Loss=0.0598]


Train - Loss: 0.0094, F1: 0.9857, Acc: 0.9851
Val   - Loss: 0.1695, F1: 0.1900, Acc: 0.2949
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 46.0% | Best F1: 0.2344 | ETA: 14.2min

Epoch 24/50


CASME II Training Epoch 24/50: 100%|██████████| 38/38 [00:28<00:00,  1.32it/s, Loss=0.0092, LR=1.00e-06]
CASME II Validation Epoch 24/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0600]


Train - Loss: 0.0092, F1: 0.9880, Acc: 0.9834
Val   - Loss: 0.1697, F1: 0.2084, Acc: 0.3077
Time  - Epoch: 30.7s, LR: 1.00e-06
Progress: 48.0% | Best F1: 0.2344 | ETA: 13.6min

Epoch 25/50


CASME II Training Epoch 25/50: 100%|██████████| 38/38 [00:28<00:00,  1.32it/s, Loss=0.0090, LR=1.00e-06]
CASME II Validation Epoch 25/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0605]


Train - Loss: 0.0090, F1: 0.9789, Acc: 0.9768
Val   - Loss: 0.1701, F1: 0.2258, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 50.0% | Best F1: 0.2344 | ETA: 13.1min

Epoch 26/50


CASME II Training Epoch 26/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0088, LR=1.00e-06]
CASME II Validation Epoch 26/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0620]


Train - Loss: 0.0087, F1: 0.9921, Acc: 0.9884
Val   - Loss: 0.1712, F1: 0.2023, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 52.0% | Best F1: 0.2344 | ETA: 12.6min

Epoch 27/50


CASME II Training Epoch 27/50: 100%|██████████| 38/38 [00:28<00:00,  1.32it/s, Loss=0.0085, LR=1.00e-06]
CASME II Validation Epoch 27/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0619]


Train - Loss: 0.0084, F1: 0.9864, Acc: 0.9867
Val   - Loss: 0.1713, F1: 0.2092, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 54.0% | Best F1: 0.2344 | ETA: 12.0min

Epoch 28/50


CASME II Training Epoch 28/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0080, LR=1.00e-06]
CASME II Validation Epoch 28/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0625]


Train - Loss: 0.0079, F1: 0.9934, Acc: 0.9900
Val   - Loss: 0.1713, F1: 0.2221, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 56.0% | Best F1: 0.2344 | ETA: 11.5min

Epoch 29/50


CASME II Training Epoch 29/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0084, LR=1.00e-06]
CASME II Validation Epoch 29/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0621]


Train - Loss: 0.0084, F1: 0.9949, Acc: 0.9934
Val   - Loss: 0.1712, F1: 0.2221, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 58.0% | Best F1: 0.2344 | ETA: 11.0min

Epoch 30/50


CASME II Training Epoch 30/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0078, LR=1.00e-06]
CASME II Validation Epoch 30/50: 100%|██████████| 5/5 [00:01<00:00,  2.73it/s, Val Loss=0.0628]


Train - Loss: 0.0078, F1: 0.9922, Acc: 0.9900
Val   - Loss: 0.1723, F1: 0.2101, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 60.0% | Best F1: 0.2344 | ETA: 10.4min

Epoch 31/50


CASME II Training Epoch 31/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0078, LR=1.00e-06]
CASME II Validation Epoch 31/50: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s, Val Loss=0.0628]


Train - Loss: 0.0079, F1: 0.9907, Acc: 0.9867
Val   - Loss: 0.1719, F1: 0.2259, Acc: 0.3462
Time  - Epoch: 30.9s, LR: 1.00e-06
Progress: 62.0% | Best F1: 0.2344 | ETA: 9.9min

Epoch 32/50


CASME II Training Epoch 32/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0077, LR=1.00e-06]
CASME II Validation Epoch 32/50: 100%|██████████| 5/5 [00:01<00:00,  2.72it/s, Val Loss=0.0634]


Train - Loss: 0.0077, F1: 0.9907, Acc: 0.9884
Val   - Loss: 0.1730, F1: 0.2101, Acc: 0.3077
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 64.0% | Best F1: 0.2344 | ETA: 9.4min

Epoch 33/50


CASME II Training Epoch 33/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0072, LR=1.00e-06]
CASME II Validation Epoch 33/50: 100%|██████████| 5/5 [00:01<00:00,  2.67it/s, Val Loss=0.0632]


Train - Loss: 0.0073, F1: 0.9930, Acc: 0.9900
Val   - Loss: 0.1730, F1: 0.2094, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 66.0% | Best F1: 0.2344 | ETA: 8.9min

Epoch 34/50


CASME II Training Epoch 34/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0072, LR=1.00e-06]
CASME II Validation Epoch 34/50: 100%|██████████| 5/5 [00:01<00:00,  2.72it/s, Val Loss=0.0623]


Train - Loss: 0.0072, F1: 0.9928, Acc: 0.9917
Val   - Loss: 0.1722, F1: 0.2288, Acc: 0.3333
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 68.0% | Best F1: 0.2344 | ETA: 8.3min

Epoch 35/50


CASME II Training Epoch 35/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0065, LR=1.00e-06]
CASME II Validation Epoch 35/50: 100%|██████████| 5/5 [00:01<00:00,  2.67it/s, Val Loss=0.0625]


Train - Loss: 0.0067, F1: 0.9949, Acc: 0.9934
Val   - Loss: 0.1723, F1: 0.2184, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 70.0% | Best F1: 0.2344 | ETA: 7.8min

Epoch 36/50


CASME II Training Epoch 36/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0068, LR=1.00e-06]
CASME II Validation Epoch 36/50: 100%|██████████| 5/5 [00:01<00:00,  2.73it/s, Val Loss=0.0628]


Train - Loss: 0.0067, F1: 0.9943, Acc: 0.9934
Val   - Loss: 0.1727, F1: 0.2205, Acc: 0.3333
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 72.0% | Best F1: 0.2344 | ETA: 7.3min

Epoch 37/50


CASME II Training Epoch 37/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0065, LR=1.00e-06]
CASME II Validation Epoch 37/50: 100%|██████████| 5/5 [00:01<00:00,  2.65it/s, Val Loss=0.0629]


Train - Loss: 0.0065, F1: 0.9902, Acc: 0.9884
Val   - Loss: 0.1733, F1: 0.2324, Acc: 0.3333
Time  - Epoch: 30.9s, LR: 1.00e-06
Progress: 74.0% | Best F1: 0.2344 | ETA: 6.8min

Epoch 38/50


CASME II Training Epoch 38/50: 100%|██████████| 38/38 [00:28<00:00,  1.32it/s, Loss=0.0063, LR=1.00e-06]
CASME II Validation Epoch 38/50: 100%|██████████| 5/5 [00:01<00:00,  2.72it/s, Val Loss=0.0634]


Train - Loss: 0.0063, F1: 0.9957, Acc: 0.9950
Val   - Loss: 0.1739, F1: 0.2167, Acc: 0.3205
Time  - Epoch: 30.7s, LR: 1.00e-06
Progress: 76.0% | Best F1: 0.2344 | ETA: 6.2min

Epoch 39/50


CASME II Training Epoch 39/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0063, LR=1.00e-06]
CASME II Validation Epoch 39/50: 100%|██████████| 5/5 [00:01<00:00,  2.64it/s, Val Loss=0.0645]


Train - Loss: 0.0064, F1: 0.9943, Acc: 0.9934
Val   - Loss: 0.1746, F1: 0.2289, Acc: 0.3333
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 78.0% | Best F1: 0.2344 | ETA: 5.7min

Epoch 40/50


CASME II Training Epoch 40/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0063, LR=1.00e-06]
CASME II Validation Epoch 40/50: 100%|██████████| 5/5 [00:01<00:00,  2.72it/s, Val Loss=0.0639]


Train - Loss: 0.0063, F1: 0.9972, Acc: 0.9967
Val   - Loss: 0.1745, F1: 0.2177, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 80.0% | Best F1: 0.2344 | ETA: 5.2min

Epoch 41/50


CASME II Training Epoch 41/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0062, LR=1.00e-06]
CASME II Validation Epoch 41/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0639]


Train - Loss: 0.0062, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1740, F1: 0.2430, Acc: 0.3462
Time  - Epoch: 30.8s, LR: 1.00e-06
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_keyframe_best_f1.pth
  Epoch: 41
  Val F1: 0.2430
  Val Loss: 0.1740
  Val Acc: 0.3462
New best model: Higher F1 - F1: 0.2430
Progress: 82.0% | Best F1: 0.2430 | ETA: 4.7min

Epoch 42/50


CASME II Training Epoch 42/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0058, LR=1.00e-06]
CASME II Validation Epoch 42/50: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s, Val Loss=0.0638]


Train - Loss: 0.0058, F1: 0.9958, Acc: 0.9950
Val   - Loss: 0.1743, F1: 0.2430, Acc: 0.3462
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 84.0% | Best F1: 0.2430 | ETA: 4.2min

Epoch 43/50


CASME II Training Epoch 43/50: 100%|██████████| 38/38 [00:29<00:00,  1.31it/s, Loss=0.0058, LR=1.00e-06]
CASME II Validation Epoch 43/50: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s, Val Loss=0.0646]


Train - Loss: 0.0058, F1: 0.9937, Acc: 0.9967
Val   - Loss: 0.1754, F1: 0.2299, Acc: 0.3333
Time  - Epoch: 30.9s, LR: 1.00e-06
Progress: 86.0% | Best F1: 0.2430 | ETA: 3.6min

Epoch 44/50


CASME II Training Epoch 44/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0052, LR=1.00e-06]
CASME II Validation Epoch 44/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0657]


Train - Loss: 0.0052, F1: 0.9949, Acc: 0.9934
Val   - Loss: 0.1762, F1: 0.2433, Acc: 0.3462
Time  - Epoch: 30.8s, LR: 1.00e-06
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_keyframe_best_f1.pth
  Epoch: 44
  Val F1: 0.2433
  Val Loss: 0.1762
  Val Acc: 0.3462
New best model: Higher F1 - F1: 0.2433
Progress: 88.0% | Best F1: 0.2433 | ETA: 3.1min

Epoch 45/50


CASME II Training Epoch 45/50: 100%|██████████| 38/38 [00:29<00:00,  1.31it/s, Loss=0.0050, LR=1.00e-06]
CASME II Validation Epoch 45/50: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s, Val Loss=0.0653]


Train - Loss: 0.0051, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1762, F1: 0.2288, Acc: 0.3333
Time  - Epoch: 30.9s, LR: 1.00e-06
Progress: 90.0% | Best F1: 0.2433 | ETA: 2.6min

Epoch 46/50


CASME II Training Epoch 46/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0051, LR=1.00e-06]
CASME II Validation Epoch 46/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0648]


Train - Loss: 0.0052, F1: 0.9986, Acc: 0.9983
Val   - Loss: 0.1758, F1: 0.2389, Acc: 0.3462
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 92.0% | Best F1: 0.2433 | ETA: 2.1min

Epoch 47/50


CASME II Training Epoch 47/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0051, LR=1.00e-06]
CASME II Validation Epoch 47/50: 100%|██████████| 5/5 [00:01<00:00,  2.73it/s, Val Loss=0.0658]


Train - Loss: 0.0050, F1: 0.9972, Acc: 0.9967
Val   - Loss: 0.1764, F1: 0.2139, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 94.0% | Best F1: 0.2433 | ETA: 1.6min

Epoch 48/50


CASME II Training Epoch 48/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0050, LR=1.00e-06]
CASME II Validation Epoch 48/50: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s, Val Loss=0.0656]


Train - Loss: 0.0050, F1: 0.9986, Acc: 0.9983
Val   - Loss: 0.1767, F1: 0.2138, Acc: 0.3205
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 96.0% | Best F1: 0.2433 | ETA: 1.0min

Epoch 49/50


CASME II Training Epoch 49/50: 100%|██████████| 38/38 [00:28<00:00,  1.31it/s, Loss=0.0050, LR=1.00e-06]
CASME II Validation Epoch 49/50: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s, Val Loss=0.0664]


Train - Loss: 0.0049, F1: 0.9986, Acc: 0.9983
Val   - Loss: 0.1772, F1: 0.2389, Acc: 0.3462
Time  - Epoch: 30.8s, LR: 1.00e-06
Progress: 98.0% | Best F1: 0.2433 | ETA: 0.5min

Epoch 50/50


CASME II Training Epoch 50/50: 100%|██████████| 38/38 [00:28<00:00,  1.32it/s, Loss=0.0050, LR=1.00e-06]
CASME II Validation Epoch 50/50: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s, Val Loss=0.0648]


Train - Loss: 0.0049, F1: 0.9972, Acc: 0.9967
Val   - Loss: 0.1764, F1: 0.2378, Acc: 0.3462
Time  - Epoch: 30.7s, LR: 1.00e-06
Progress: 100.0% | Best F1: 0.2433 | ETA: 0.0min

CASME II KEY-FRAME POOLFORMER TRAINING COMPLETED
Training time: 26.0 minutes
Epochs completed: 50
Best validation F1: 0.2433 (epoch 44)
Final train F1: 0.9972
Final validation F1: 0.2378

Exporting enhanced training documentation...
Enhanced training documentation saved successfully: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/03_03_poolformer_casme2_kfs/training_logs/casme2_poolformer_keyframe_training_history.json
Experiment details: Optimized Focal Loss loss
  Gamma: 2.0, Alpha Sum: 0.999
Model variant: sail/poolformer_m48
Dataset phase: v2

Next: Cell 3 - CASME II Key-Frame PoolFormer Evaluation with Test Version Selection
Enhanced training pipeline with hardened checkpoint system completed successfully!


In [None]:
# @title Cell 3: CASME II Key-Frame PoolFormer Evaluation with Test Version Selection

# File: 03_03_PoolFormer_CASME2_KFS_Cell3.py
# Location: experiments/03_03_PoolFormer_CASME2-KFS.ipynb
# Purpose: Comprehensive evaluation framework with support for v1 (apex-only) or v2 (key-frames) test sets

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime

# Evaluation specific imports
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import pickle
import warnings
warnings.filterwarnings('ignore')

# =====================================================
# TEST DATASET VERSION SELECTOR
# =====================================================
# Choose the test set used for evaluation:
#   'v1' -> Phase 1, apex-only frames (28 samples, data_split/test/)
#   'v2' -> Phase 2, onset/apex/offset key-frames (84 samples, data_split_v2/test/)
TEST_DATASET_VERSION = 'v2'  # default: Phase 2 key-frames

print("\n".join([
    "CASME II Key-Frame PoolFormer Evaluation Framework",
    "=" * 60,
    f"Test Dataset Version: {TEST_DATASET_VERSION}",
    "=" * 60,
]))

# =====================================================
# DYNAMIC TEST DATASET CONFIGURATION
# =====================================================

def get_test_dataset_config(version, project_root):
    """
    Describe the test dataset that corresponds to a version tag.

    Args:
        version: 'v1' (Phase 1, apex-only) or 'v2' (Phase 2, onset/apex/offset key-frames)
        project_root: Project root path; dataset paths are built beneath it

    Returns:
        dict: A new configuration dict with keys 'version', 'phase', 'dataset_path',
        'metadata_file', 'processing_summary', 'description', 'expected_samples'
        and 'frame_types'

    Raises:
        ValueError: If version is not 'v1' or 'v2'
    """
    processed_root = f"{project_root}/datasets/processed_casme2"

    if version == 'v1':
        return {
            'version': 'v1',
            'phase': 'Phase 1',
            'dataset_path': f"{processed_root}/data_split",
            'metadata_file': 'split_metadata.json',
            'processing_summary': 'processing_summary.json',
            'description': 'Apex-only frames',
            'expected_samples': 28,
            'frame_types': ['apex'],
        }

    if version == 'v2':
        return {
            'version': 'v2',
            'phase': 'Phase 2',
            'dataset_path': f"{processed_root}/data_split_v2",
            'metadata_file': 'split_metadata_v2.json',
            'processing_summary': 'processing_summary_v2.json',
            'description': 'Key-frames (onset, apex, offset)',
            'expected_samples': 84,
            'frame_types': ['onset', 'apex', 'offset'],
        }

    raise ValueError(f"Invalid TEST_DATASET_VERSION: {version}. Must be 'v1' or 'v2'")

# Get test dataset configuration
test_config = get_test_dataset_config(TEST_DATASET_VERSION, PROJECT_ROOT)

print("Test Dataset Configuration:")
print(f"  Version: {test_config['version']}")
print(f"  Phase: {test_config['phase']}")
print(f"  Description: {test_config['description']}")
print(f"  Expected samples: {test_config['expected_samples']}")
print(f"  Frame types: {test_config['frame_types']}")
print(f"  Dataset path: {test_config['dataset_path']}")

# Independent test dataset metadata loading
print("\nLoading test metadata independently from selected version...")
test_metadata_path = f"{test_config['dataset_path']}/{test_config['metadata_file']}"
test_processing_path = f"{test_config['dataset_path']}/{test_config['processing_summary']}"

if not os.path.exists(test_metadata_path):
    raise FileNotFoundError(f"Test metadata not found: {test_metadata_path}")

# Cell 1 treats a missing processing summary as fatal for training. Here only
# the split metadata is loaded, so a missing summary is warned about rather than
# silently ignored.
if not os.path.exists(test_processing_path):
    print(f"WARNING: Test processing summary not found: {test_processing_path}")

print(f"Test metadata path: {test_metadata_path}")

with open(test_metadata_path, 'r', encoding='utf-8') as f:
    test_metadata = json.load(f)

# Normalize metadata structure for v1/v2 compatibility
def normalize_metadata_structure(metadata):
    """
    Return the per-split mapping from either supported metadata layout.

    v2 metadata nests the splits under a top-level 'splits' key, while v1
    metadata stores 'train' / 'test' keys directly at the top level (and is
    returned unchanged).

    Raises:
        ValueError: If neither layout is recognized
    """
    if 'splits' in metadata:
        print("  Metadata format: v2 (with 'splits' key)")
        return metadata['splits']

    has_direct_split_keys = any(key in metadata for key in ('train', 'test'))
    if not has_direct_split_keys:
        raise ValueError("Unknown metadata format")

    print("  Metadata format: v1 (direct split keys)")
    return metadata

normalized_test_metadata = normalize_metadata_structure(test_metadata)

# Evaluation reads the 'test' split below, so fail fast when it is absent
if 'test' not in normalized_test_metadata:
    raise ValueError("Test split not found in metadata")

actual_test_samples = len(normalized_test_metadata['test']['samples'])
print(
    f"Loaded {actual_test_samples} test samples "
    f"(expected: {test_config['expected_samples']})"
)

# A count mismatch is reported but deliberately not treated as fatal
if actual_test_samples != test_config['expected_samples']:
    print(
        "WARNING: Sample count mismatch! "
        f"Expected {test_config['expected_samples']}, got {actual_test_samples}"
    )

# Enhanced test dataset for CASME II evaluation
class CASME2DatasetEvaluation(Dataset):
    """CASME II evaluation dataset that yields images plus per-sample metadata.

    Each item is ``(image, label, sample_id, emotion, subject, filename)`` where
    ``label`` is ``CLASS_TO_IDX[emotion]`` and ``filename`` is the image basename.

    Images are converted to RGB and resized to 384x384 when needed. An image that
    cannot be read is replaced by a gray (128, 128, 128) placeholder so that
    evaluation can proceed; such failures are reported with their paths, because
    placeholder images silently degrade test metrics.

    Args:
        split_metadata: Mapping of split name -> {'samples': [...]}, where each
            sample provides 'image_filename', 'emotion', 'sample_id', 'subject'.
        dataset_root: Directory containing one sub-directory per split.
        transform: Optional callable applied to each PIL image.
        split: Split name to load (default 'test').
        use_ram_cache: When True, preload all images into RAM in parallel.
    """

    IMAGE_SIZE = (384, 384)
    _PLACEHOLDER_COLOR = (128, 128, 128)

    def __init__(self, split_metadata, dataset_root, transform=None, split='test', use_ram_cache=True):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.emotions = []
        self.subjects = []
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for evaluation...")

        # Process metadata for evaluation
        for sample in self.metadata:
            self.images.append(os.path.join(dataset_root, split, sample['image_filename']))
            self.labels.append(CLASS_TO_IDX[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])
            self.emotions.append(sample['emotion'])
            self.subjects.append(sample['subject'])

        print(f"Loaded {len(self.images)} CASME II {split} samples for evaluation")
        self._print_evaluation_distribution()

        # RAM caching for fast evaluation
        if self.use_ram_cache:
            self._preload_to_ram_evaluation()

    @classmethod
    def _load_image(cls, img_path):
        """Load *img_path* as an RGB image of IMAGE_SIZE.

        Returns:
            tuple: (image, error) where error is None on success, or the caught
            exception when the gray placeholder image was substituted.
        """
        try:
            # The context manager closes the file handle; convert() returns a
            # fully loaded copy that outlives it.
            with Image.open(img_path) as raw:
                image = raw.convert('RGB')
            if image.size != cls.IMAGE_SIZE:
                image = image.resize(cls.IMAGE_SIZE, Image.Resampling.LANCZOS)
            return image, None
        except Exception as exc:  # best-effort: unreadable files get a placeholder
            return Image.new('RGB', cls.IMAGE_SIZE, cls._PLACEHOLDER_COLOR), exc

    def _print_evaluation_distribution(self):
        """Print per-class counts, subject coverage and classes absent from the split."""
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        subject_counts = {}

        for label, subject in zip(self.labels, self.subjects):
            label_counts[label] = label_counts.get(label, 0) + 1
            subject_counts[subject] = subject_counts.get(subject, 0) + 1

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Test set covers {len(subject_counts)} subjects")

        missing_classes = [
            class_name for i, class_name in enumerate(CASME2_CLASSES) if i not in label_counts
        ]
        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        """Load every image into ``self.cached_images`` using a thread pool."""
        print(f"Preloading {len(self.images)} test images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        failed = []

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(self._load_image, path) for path in self.images]

            # Futures are consumed in submission order, so idx matches self.images
            for idx, future in enumerate(tqdm(futures, desc="Loading test images to RAM")):
                image, error = future.result()
                self.cached_images[idx] = image
                if error is not None:
                    failed.append((self.images[idx], error))

        valid_images = len(self.images) - len(failed)
        # Cached PIL RGB images store 1 byte per channel (not 4-byte floats)
        width, height = self.IMAGE_SIZE
        ram_usage_gb = len(self.cached_images) * width * height * 3 / 1e9
        print(f"Test RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

        if failed:
            print(f"WARNING: {len(failed)} test images could not be loaded and were "
                  f"replaced by gray placeholders:")
            for path, error in failed[:10]:
                print(f"  {path}: {error}")
            if len(failed) > 10:
                print(f"  ... and {len(failed) - 10} more")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            # Copy so a transform cannot mutate the cached image in place
            image = self.cached_images[idx].copy()
        else:
            image, error = self._load_image(self.images[idx])
            if error is not None:
                print(f"WARNING: using placeholder for unreadable image {self.images[idx]}: {error}")

        if self.transform:
            image = self.transform(image)

        return (image, self.labels[idx], self.sample_ids[idx],
                self.emotions[idx], self.subjects[idx], os.path.basename(self.images[idx]))

# CASME II evaluation configuration with test version tracking
EVALUATION_CONFIG_CASME2 = {
    'model_type': 'PoolFormer_CASME2_KeyFrame_Baseline',
    'task_type': 'micro_expression_recognition',
    # Derived from the class list (instead of a literal 7) so the model head size
    # used when building the model and the class names used for reporting
    # cannot drift apart.
    'num_classes': len(CASME2_CLASSES),
    'class_names': CASME2_CLASSES,
    'checkpoint_file': 'casme2_poolformer_keyframe_best_f1.pth',
    'dataset_name': 'CASME_II',
    'input_size': '384x384',
    'evaluation_protocol': 'stratified_split',
    'test_dataset_version': test_config['version'],
    'test_dataset_phase': test_config['phase'],
    'test_dataset_description': test_config['description'],
    'test_frame_types': test_config['frame_types']
}

print("\nCASME II PoolFormer Evaluation Configuration:")
print(f"  Model: {EVALUATION_CONFIG_CASME2['model_type']}")
print(f"  Task: {EVALUATION_CONFIG_CASME2['task_type']}")
print(f"  Test Version: {EVALUATION_CONFIG_CASME2['test_dataset_version']}")
print(f"  Test Description: {EVALUATION_CONFIG_CASME2['test_dataset_description']}")
print(f"  Classes: {EVALUATION_CONFIG_CASME2['class_names']}")
print(f"  Input size: {EVALUATION_CONFIG_CASME2['input_size']}")

def extract_logits_safe_casme2(outputs_all):
    """
    Pull the logits tensor out of whatever the PoolFormer forward pass returned.

    Lookup order:
      1. a bare tensor is returned unchanged;
      2. for a tuple/list, its first tensor element;
      3. for a dict, the first tensor stored under 'logits', 'logit',
         'predictions', 'outputs' or 'scores' (checked in that order),
         otherwise the first tensor value in iteration order.

    Raises:
        RuntimeError: If no tensor can be located in the output.
    """
    if isinstance(outputs_all, torch.Tensor):
        return outputs_all

    def _candidates():
        if isinstance(outputs_all, (tuple, list)):
            yield from outputs_all
        elif isinstance(outputs_all, dict):
            for preferred_key in ('logits', 'logit', 'predictions', 'outputs', 'scores'):
                yield outputs_all.get(preferred_key)
            yield from outputs_all.values()

    logits = next((item for item in _candidates() if isinstance(item, torch.Tensor)), None)
    if logits is None:
        raise RuntimeError("Unable to extract tensor logits from CASME II PoolFormer model output")
    return logits

def load_trained_model_casme2(checkpoint_path, device):
    """
    Load the trained CASME II PoolFormer model and its saved training metadata.

    Args:
        checkpoint_path: Path to a file written with ``torch.save`` containing
            either a training-checkpoint dict (weights under 'model_state_dict'
            plus optional metadata keys) or a bare state dict.
        device: Device the constructed model is moved to.

    Returns:
        tuple: (model in eval mode, training_info dict)

    Raises:
        FileNotFoundError: If checkpoint_path does not exist.
        RuntimeError: If the file cannot be deserialized, does not contain a
            dict, or its weights cannot be loaded into the model.
    """
    print(f"Loading trained CASME II PoolFormer model from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # First try torch.load with its default settings; then retry with full
    # unpickling (weights_only=False) for checkpoints carrying non-tensor
    # metadata -- only do this for trusted checkpoint files. Both attempts map to
    # CPU; load_state_dict copies weights onto the model's device afterwards.
    # A raw pickle.load fallback is deliberately not used: it cannot reconstruct
    # torch.save output and would hand back a meaningless object.
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        loading_method = "standard"
    except Exception as e1:
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
            loading_method = "weights_only_false"
        except Exception as e2:
            raise RuntimeError(f"All loading methods failed: {e1}, {e2}") from e2

    print(f"Checkpoint loaded using: {loading_method}")

    # Fail fast (before building the model) on content we cannot interpret
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f"Unsupported checkpoint content of type {type(checkpoint).__name__}; "
            "expected a training-checkpoint dict or a state dict"
        )

    # Initialize CASME II PoolFormer model
    model = PoolFormerCASME2Baseline(
        num_classes=EVALUATION_CONFIG_CASME2['num_classes'],
        dropout_rate=CASME2_POOLFORMER_CONFIG['dropout_rate']
    ).to(device)

    # A training checkpoint nests the weights; a bare state dict is used as-is
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    try:
        model.load_state_dict(state_dict, strict=True)
        print("Model state loaded with strict=True")
    except Exception as e:
        print(f"Strict loading failed, trying non-strict: {str(e)[:100]}...")
        try:
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        except Exception as e2:
            raise RuntimeError(f"Both loading approaches failed: {e2}") from e2
        if missing_keys or unexpected_keys:
            # Parameters listed as missing keep whatever values the freshly
            # constructed model has, so test metrics may not reflect the trained
            # weights -- show which keys were affected.
            print(f"Non-strict loading: Missing {len(missing_keys)}, Unexpected {len(unexpected_keys)}")
            print(f"  First missing keys: {list(missing_keys)[:5]}")
            print(f"  First unexpected keys: {list(unexpected_keys)[:5]}")
        else:
            print("Model state loaded with strict=False (no key mismatches)")

    model.eval()

    # Extract training information
    training_info = {
        'best_val_f1': float(checkpoint.get('best_f1', 0.0)),
        'best_val_loss': float(checkpoint.get('best_loss', float('inf'))),
        'best_val_accuracy': float(checkpoint.get('best_acc', 0.0)),
        # NOTE(review): assumes the saved 'epoch' is 0-based; if training stores
        # the 1-based epoch number this reports one epoch too many -- confirm
        # against the training cell's checkpoint-saving code.
        'best_epoch': int(checkpoint.get('epoch', 0)) + 1,
        'model_checkpoint': EVALUATION_CONFIG_CASME2['checkpoint_file'],
        'num_classes': EVALUATION_CONFIG_CASME2['num_classes'],
        'config': checkpoint.get('casme2_config', {})
    }

    print("Model loaded successfully:")
    print(f"  Best validation F1: {training_info['best_val_f1']:.4f}")
    print(f"  Best validation accuracy: {training_info['best_val_accuracy']:.4f}")
    print(f"  Best epoch: {training_info['best_epoch']}")
    print(f"  Model classes: {EVALUATION_CONFIG_CASME2['num_classes']}")

    return model, training_info

def run_model_inference_casme2(model, test_loader, device):
    """
    Run the CASME II PoolFormer model over the test loader and collect results.

    Args:
        model: Trained model; its forward output is passed through
            ``extract_logits_safe_casme2`` to obtain per-class logits.
        test_loader: Iterable of (images, labels, sample_ids, emotions, subjects,
            filenames) batches, matching the item layout of CASME2DatasetEvaluation.
        device: Device the image batches are moved to.

    Returns:
        dict: numpy arrays 'predictions', 'probabilities', 'labels'; lists
        'sample_ids', 'emotions', 'subjects', 'filenames'; 'inference_time'
        (seconds) and 'samples_count'.
    """
    print("Running CASME II PoolFormer model inference on test set...")

    num_classes = len(CASME2_CLASSES)
    model.eval()
    all_predictions = []
    all_probabilities = []
    all_labels = []
    all_sample_ids = []
    all_emotions = []
    all_subjects = []
    all_filenames = []

    inference_start = time.perf_counter()

    with torch.no_grad():
        for images, labels, sample_ids, emotions, subjects, filenames in tqdm(
                test_loader, desc="CASME II Inference"):

            images = images.to(device)

            # Forward/extraction errors propagate as-is: re-running the same
            # forward pass cannot succeed and would only obscure the real failure.
            outputs = extract_logits_safe_casme2(model(images))

            if outputs.shape[1] != num_classes:
                print(f"Warning: Expected {num_classes} classes output, got {outputs.shape[1]}")

            probabilities = torch.softmax(outputs, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            # Store results on CPU to keep accelerator memory free
            all_predictions.extend(predictions.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_sample_ids.extend(sample_ids)
            all_emotions.extend(emotions)
            all_subjects.extend(subjects)
            all_filenames.extend(filenames)

    inference_time = time.perf_counter() - inference_start

    print(f"CASME II inference completed: {len(all_predictions)} samples in {inference_time:.2f}s")

    # Explicit integer dtypes keep the arrays usable as class indices even when
    # the loader yields no batches (np.array([]) would be float64).
    predictions_array = np.asarray(all_predictions, dtype=np.int64)
    labels_array = np.asarray(all_labels, dtype=np.int64)
    if all_probabilities:
        probabilities_array = np.stack(all_probabilities)
    else:
        probabilities_array = np.empty((0, num_classes), dtype=np.float32)

    print(f"Predicted classes: {[CASME2_CLASSES[i] for i in np.unique(predictions_array)]}")
    print(f"True classes in test: {[CASME2_CLASSES[i] for i in np.unique(labels_array)]}")

    return {
        'predictions': predictions_array,
        'probabilities': probabilities_array,
        'labels': labels_array,
        'sample_ids': all_sample_ids,
        'emotions': all_emotions,
        'subjects': all_subjects,
        'filenames': all_filenames,
        'inference_time': inference_time,
        'samples_count': len(predictions_array)
    }

def analyze_wrong_predictions_casme2(inference_results):
    """Break down the misclassified CASME II test samples.

    Args:
        inference_results: Dict returned by ``run_model_inference_casme2``. Reads
            ``predictions`` and ``labels`` (class indices into ``CASME2_CLASSES``)
            and the parallel per-sample lists ``sample_ids``, ``emotions``,
            ``subjects`` and ``filenames``.

    Returns:
        Dict with:
          * ``analysis_metadata`` - run/config info, sample and error totals, and
            ``overall_error_rate`` as a percentage (0.0 for an empty result set);
          * ``wrong_predictions_by_class`` - true class name -> list of error records;
          * ``subject_error_analysis`` - subject -> ``total`` / ``wrong`` /
            ``errors`` / ``error_rate`` (fraction, not percent);
          * ``confusion_patterns`` - ``"<true>_to_<predicted>"`` -> count;
          * ``error_summary`` - true class name -> number of errors.
    """
    print("Analyzing wrong predictions for CASME II micro-expression recognition...")

    # np.asarray so the elementwise comparison below also works for plain lists
    # (list != list would yield a single bool instead of a mask).
    predictions = np.asarray(inference_results['predictions'])
    labels = np.asarray(inference_results['labels'])
    sample_ids = inference_results['sample_ids']
    emotions = inference_results['emotions']
    subjects = inference_results['subjects']
    filenames = inference_results['filenames']

    wrong_indices = np.flatnonzero(predictions != labels)

    # Every class gets an entry (possibly empty) so downstream code can index by name.
    wrong_predictions_by_class = {class_name: [] for class_name in CASME2_CLASSES}
    subject_error_analysis = {}
    confusion_patterns = {}

    for idx in wrong_indices:
        true_label = labels[idx]
        pred_label = predictions[idx]
        true_class = CASME2_CLASSES[true_label]
        pred_class = CASME2_CLASSES[pred_label]
        subject = subjects[idx]

        wrong_info = {
            'sample_id': sample_ids[idx],
            'filename': filenames[idx],
            'subject': subject,
            'true_label': int(true_label),
            'true_class': true_class,
            'predicted_label': int(pred_label),
            'predicted_class': pred_class,
            'emotion': emotions[idx]
        }
        wrong_predictions_by_class[true_class].append(wrong_info)

        # Subject error tracking ('total' is filled in by the pass below)
        stats = subject_error_analysis.setdefault(subject, {'total': 0, 'wrong': 0, 'errors': []})
        stats['wrong'] += 1
        stats['errors'].append(wrong_info)

        pattern = f"{true_class}_to_{pred_class}"
        confusion_patterns[pattern] = confusion_patterns.get(pattern, 0) + 1

    # Count every test sample per subject, including subjects with no errors.
    for subject in subjects:
        subject_error_analysis.setdefault(subject, {'total': 0, 'wrong': 0, 'errors': []})['total'] += 1

    for stats in subject_error_analysis.values():
        stats['error_rate'] = stats['wrong'] / stats['total'] if stats['total'] > 0 else 0.0

    # Summary statistics; guard the empty case instead of raising ZeroDivisionError.
    total_wrong = len(wrong_indices)
    total_samples = len(predictions)
    error_rate = (total_wrong / total_samples) * 100 if total_samples > 0 else 0.0

    analysis_results = {
        'analysis_metadata': {
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'model_type': EVALUATION_CONFIG_CASME2['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2['dataset_name'],
            'test_version': EVALUATION_CONFIG_CASME2['test_dataset_version'],
            'test_description': EVALUATION_CONFIG_CASME2['test_dataset_description'],
            'total_samples': int(total_samples),
            'total_wrong_predictions': int(total_wrong),
            'overall_error_rate': float(error_rate)
        },
        'wrong_predictions_by_class': wrong_predictions_by_class,
        'subject_error_analysis': subject_error_analysis,
        'confusion_patterns': confusion_patterns,
        'error_summary': {
            class_name: len(wrong_predictions_by_class[class_name])
            for class_name in CASME2_CLASSES
        }
    }

    return analysis_results

def _compute_roc_analysis_casme2(labels, probabilities, unique_test_labels):
    """One-vs-rest ROC curves and AUC for every class in ``CASME2_CLASSES``.

    Returns ``(auc_scores, fpr_dict, tpr_dict, macro_auc)``, the first three keyed
    by class name. A class whose AUC is undefined (absent from the test labels,
    or present with no negatives) gets AUC 0.0 and a flat placeholder curve. Such
    a class is left out of ``macro_auc``, because its 0.0 is a placeholder and
    not a measured score. If scikit-learn raises, all three maps fall back to
    placeholders for every class, so the report never mixes real curves with
    fallback scores, and ``macro_auc`` is 0.0.
    """
    num_classes = len(CASME2_CLASSES)
    present = set(unique_test_labels)
    auc_scores, fpr_dict, tpr_dict = {}, {}, {}
    measured_aucs = []

    try:
        labels_binarized = label_binarize(labels, classes=list(range(num_classes)))

        for i, class_name in enumerate(CASME2_CLASSES):
            column = labels_binarized[:, i]
            if i in present and len(np.unique(column)) > 1:
                fpr, tpr, _ = roc_curve(column, probabilities[:, i])
                auc_scores[class_name] = float(auc(fpr, tpr))
                fpr_dict[class_name] = fpr.tolist()
                tpr_dict[class_name] = tpr.tolist()
                measured_aucs.append(auc_scores[class_name])
            else:
                auc_scores[class_name] = 0.0
                fpr_dict[class_name] = [0.0, 1.0]
                tpr_dict[class_name] = [0.0, 0.0]

        macro_auc = float(np.mean(measured_aucs)) if measured_aucs else 0.0

    except Exception as e:
        print(f"Warning: AUC calculation failed: {e}")
        auc_scores = {class_name: 0.0 for class_name in CASME2_CLASSES}
        fpr_dict = {class_name: [0.0, 1.0] for class_name in CASME2_CLASSES}
        tpr_dict = {class_name: [0.0, 0.0] for class_name in CASME2_CLASSES}
        macro_auc = 0.0

    return auc_scores, fpr_dict, tpr_dict, macro_auc


def _compute_subject_performance_casme2(subjects, predictions, labels):
    """Accuracy, sample count and correct count for each test subject.

    Subjects are visited in sorted (string) order, so the resulting dict and the
    JSON written from it have the same key order on every run.
    """
    performance = {}
    for subject in sorted(set(subjects), key=str):
        mask = np.array([s == subject for s in subjects], dtype=bool)
        subject_predictions = predictions[mask]
        subject_labels = labels[mask]

        if len(subject_predictions) > 0:
            performance[subject] = {
                'accuracy': float(accuracy_score(subject_labels, subject_predictions)),
                'samples': int(len(subject_predictions)),
                'correct': int(np.sum(subject_predictions == subject_labels))
            }
    return performance


def calculate_comprehensive_metrics_casme2(inference_results):
    """Calculate comprehensive evaluation metrics for CASME II micro-expression recognition.

    Args:
        inference_results: Dict from ``run_model_inference_casme2`` with
            ``predictions`` / ``labels`` (class indices), ``probabilities``
            (per-class softmax scores, one column per class), ``subjects`` and
            ``inference_time`` (seconds).

    Returns:
        JSON-serialisable dict with ``evaluation_metadata``, ``overall_performance``,
        ``per_class_performance``, ``confusion_matrix``,
        ``subject_level_performance``, ``roc_analysis`` and ``inference_performance``.
        Macro precision/recall/F1 average only over classes that occur in the
        test labels. Per-class metrics and the confusion matrix cover all classes.

    Raises:
        ValueError: if there are no predictions.
    """
    print("Calculating comprehensive metrics for CASME II micro-expression recognition...")

    predictions = np.asarray(inference_results['predictions'])
    probabilities = np.asarray(inference_results['probabilities'])
    labels = np.asarray(inference_results['labels'])

    if len(predictions) == 0:
        raise ValueError("No predictions to evaluate!")

    num_classes = len(CASME2_CLASSES)
    all_class_ids = list(range(num_classes))

    # Identify available classes in test set (np.unique returns sorted values)
    unique_test_labels = [int(i) for i in np.unique(labels)]
    unique_predictions = [int(i) for i in np.unique(predictions)]

    print(f"Test set contains labels: {[CASME2_CLASSES[i] for i in unique_test_labels]}")
    print(f"Model predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    accuracy = accuracy_score(labels, predictions)

    # Macro metrics only over classes present in the test set, so classes with
    # zero support do not contribute a forced 0.0.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, labels=unique_test_labels, average='macro', zero_division=0
    )

    print(f"Macro F1 (available classes): {f1:.4f}")

    # Per-class metrics (all classes)
    precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
        labels, predictions, labels=all_class_ids, average=None, zero_division=0
    )

    cm = confusion_matrix(labels, predictions, labels=all_class_ids)

    auc_scores, fpr_dict, tpr_dict, macro_auc = _compute_roc_analysis_casme2(
        labels, probabilities, unique_test_labels
    )

    subject_performance = _compute_subject_performance_casme2(
        inference_results['subjects'], predictions, labels
    )

    present = set(unique_test_labels)

    # Comprehensive results with test version tracking
    comprehensive_results = {
        'evaluation_metadata': {
            'model_type': EVALUATION_CONFIG_CASME2['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2['dataset_name'],
            'test_version': EVALUATION_CONFIG_CASME2['test_dataset_version'],
            'test_phase': EVALUATION_CONFIG_CASME2['test_dataset_phase'],
            'test_description': EVALUATION_CONFIG_CASME2['test_dataset_description'],
            'test_frame_types': EVALUATION_CONFIG_CASME2['test_frame_types'],
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'num_classes': EVALUATION_CONFIG_CASME2['num_classes'],
            'class_names': EVALUATION_CONFIG_CASME2['class_names'],
            'test_samples': int(len(labels)),
            'available_classes': [CASME2_CLASSES[i] for i in unique_test_labels],
            'missing_classes': [CASME2_CLASSES[i] for i in all_class_ids if i not in present]
        },

        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': macro_auc
        },

        'per_class_performance': {},

        'confusion_matrix': cm.tolist(),

        'subject_level_performance': subject_performance,

        'roc_analysis': {
            'auc_scores': auc_scores,
            'fpr_curves': fpr_dict,
            'tpr_curves': tpr_dict
        },

        'inference_performance': {
            'total_time_seconds': float(inference_results['inference_time']),
            'average_time_ms_per_sample': float(inference_results['inference_time'] * 1000 / len(labels))
        }
    }

    for i, class_name in enumerate(CASME2_CLASSES):
        comprehensive_results['per_class_performance'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i]),
            'support': int(support_per_class[i]),
            'auc': auc_scores[class_name],
            'in_test_set': i in present
        }

    return comprehensive_results

def save_evaluation_results_casme2(evaluation_results, wrong_predictions_results, results_dir, test_version):
    """Write the metrics report and the wrong-prediction report as JSON files.

    Both files go into ``results_dir``, which is created if needed. Their names
    carry ``test_version``, so evaluations against different test splits do not
    overwrite each other. Values json cannot encode are written via ``str``.

    Returns:
        ``(results_file, wrong_predictions_file)`` - the two paths written.
    """
    os.makedirs(results_dir, exist_ok=True)

    stem = f"{results_dir}/casme2_poolformer_keyframe"
    outputs = [
        (f"{stem}_evaluation_results_{test_version}.json", evaluation_results),
        (f"{stem}_wrong_predictions_{test_version}.json", wrong_predictions_results),
    ]
    for target_path, payload in outputs:
        with open(target_path, 'w') as handle:
            json.dump(payload, handle, indent=2, default=str)

    results_file = outputs[0][0]
    wrong_predictions_file = outputs[1][0]

    print("Evaluation results saved:")
    print(f"  Main results: {os.path.basename(results_file)}")
    print(f"  Wrong predictions: {os.path.basename(wrong_predictions_file)}")

    return results_file, wrong_predictions_file

# Main evaluation execution.
# The whole pipeline runs inside try/except, so a failure prints a traceback
# instead of aborting the notebook. `evaluation_succeeded` records whether the
# try body ran to the end, so the closing status line reports a failure instead
# of always claiming completion.
evaluation_succeeded = False
try:
    print("\nStarting CASME II Key-Frame PoolFormer comprehensive evaluation...")
    print(f"Using test dataset: {test_config['description']} ({test_config['version']})")

    # Create test dataset with selected version
    print(f"\nCreating CASME II test dataset from {test_config['phase']}...")
    casme2_test_dataset = CASME2DatasetEvaluation(
        split_metadata=normalized_test_metadata,
        dataset_root=test_config['dataset_path'],
        transform=GLOBAL_CONFIG_CASME2['transform_val'],
        split='test',
        use_ram_cache=True
    )

    if len(casme2_test_dataset) == 0:
        raise ValueError("No test samples found! Check test data path.")

    # shuffle=False: batches follow dataset order, so per-sample results are reproducible
    casme2_test_loader = DataLoader(
        casme2_test_dataset,
        batch_size=CASME2_POOLFORMER_CONFIG['batch_size'],
        shuffle=False,
        num_workers=CASME2_POOLFORMER_CONFIG['num_workers'],
        pin_memory=True
    )

    # Load trained model
    checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/{EVALUATION_CONFIG_CASME2['checkpoint_file']}"
    casme2_model, training_info = load_trained_model_casme2(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

    # Run inference
    inference_results = run_model_inference_casme2(casme2_model, casme2_test_loader, GLOBAL_CONFIG_CASME2['device'])

    # Calculate comprehensive metrics
    evaluation_results = calculate_comprehensive_metrics_casme2(inference_results)

    # Analyze wrong predictions
    wrong_predictions_results = analyze_wrong_predictions_casme2(inference_results)

    # Add training information
    evaluation_results['training_information'] = training_info

    # Save results with test version in filename
    results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"
    results_file, wrong_file = save_evaluation_results_casme2(
        evaluation_results, wrong_predictions_results, results_dir, test_config['version']
    )

    # Display comprehensive results
    print("\n" + "=" * 60)
    print("CASME II KEY-FRAME POOLFORMER EVALUATION RESULTS")
    print("=" * 60)
    print(f"Test Dataset: {test_config['description']} ({test_config['version']})")

    # Overall performance
    overall = evaluation_results['overall_performance']
    print(f"\nOverall Performance (Macro - Available Classes):")
    print(f"  Accuracy:  {overall['accuracy']:.4f}")
    print(f"  Precision: {overall['macro_precision']:.4f}")
    print(f"  Recall:    {overall['macro_recall']:.4f}")
    print(f"  F1 Score:  {overall['macro_f1']:.4f}")
    print(f"  AUC:       {overall['macro_auc']:.4f}")

    # Per-class performance
    print(f"\nPer-Class Performance:")
    for class_name, metrics in evaluation_results['per_class_performance'].items():
        in_test = "Present" if metrics['in_test_set'] else "Missing"
        print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
              f"AUC={metrics['auc']:.4f}, Support={metrics['support']}")

    # Training vs test comparison (positive difference = lower on test than on validation)
    print(f"\nTraining vs Test Performance:")
    training_f1 = training_info['best_val_f1']
    training_acc = training_info['best_val_accuracy']
    test_f1 = overall['macro_f1']
    test_acc = overall['accuracy']

    print(f"  Training Val F1:  {training_f1:.4f}")
    print(f"  Test F1:          {test_f1:.4f}")
    print(f"  F1 Difference:    {training_f1 - test_f1:+.4f}")
    print(f"  Training Val Acc: {training_acc:.4f}")
    print(f"  Test Accuracy:    {test_acc:.4f}")
    print(f"  Acc Difference:   {training_acc - test_acc:+.4f}")
    print(f"  Best Epoch:       {training_info['best_epoch']}")

    # Wrong predictions summary
    print(f"\n" + "=" * 40)
    print("WRONG PREDICTIONS ANALYSIS")
    print("=" * 40)

    wrong_meta = wrong_predictions_results['analysis_metadata']
    print(f"Total wrong predictions: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
    print(f"Overall error rate: {wrong_meta['overall_error_rate']:.2f}%")

    print(f"\nErrors by True Class:")
    for class_name, error_count in wrong_predictions_results['error_summary'].items():
        if error_count > 0:
            wrong_samples = wrong_predictions_results['wrong_predictions_by_class'][class_name]
            print(f"  {class_name}: {error_count} errors")
            for sample in wrong_samples[:3]:  # Show first 3
                print(f"    - {sample['filename']} -> predicted as {sample['predicted_class']}")
            if len(wrong_samples) > 3:
                print(f"    ... and {len(wrong_samples) - 3} more")

    # Subject-level analysis
    print(f"\nSubject-Level Performance:")
    subject_perfs = list(evaluation_results['subject_level_performance'].items())
    subject_perfs.sort(key=lambda x: x[1]['accuracy'], reverse=True)
    for subject, perf in subject_perfs[:5]:  # Show top 5
        print(f"  {subject}: {perf['accuracy']:.3f} ({perf['correct']}/{perf['samples']})")

    # Most common confusion patterns
    print(f"\nMost Common Confusion Patterns:")
    patterns = sorted(wrong_predictions_results['confusion_patterns'].items(),
                     key=lambda x: x[1], reverse=True)
    for pattern, count in patterns[:3]:
        print(f"  {pattern}: {count} cases")

    print(f"\nInference Performance:")
    print(f"  Total time: {inference_results['inference_time']:.2f}s")
    print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    print(f"\nTest Dataset Info:")
    print(f"  Version: {test_config['version']}")
    print(f"  Phase: {test_config['phase']}")
    print(f"  Frame types: {test_config['frame_types']}")
    print(f"  Missing classes: {evaluation_results['evaluation_metadata']['missing_classes']}")

    print("\n" + "=" * 60)
    print("CASME II KEY-FRAME POOLFORMER EVALUATION COMPLETED")
    print("=" * 60)
    evaluation_succeeded = True

except Exception as e:
    print(f"Evaluation failed: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Memory cleanup: wait for queued GPU work, then release cached, unused blocks
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

if evaluation_succeeded:
    print(f"\nEvaluation completed with test dataset {TEST_DATASET_VERSION}")
    print("Next: Generate confusion matrix and comparative analysis")
else:
    print(f"\nEvaluation FAILED for test dataset {TEST_DATASET_VERSION} - see traceback above")

CASME II Key-Frame PoolFormer Evaluation Framework
Test Dataset Version: v2
Test Dataset Configuration:
  Version: v2
  Phase: Phase 2
  Description: Key-frames (onset, apex, offset)
  Expected samples: 84
  Frame types: ['onset', 'apex', 'offset']
  Dataset path: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split_v2

Loading test metadata independently from selected version...
Test metadata path: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split_v2/split_metadata_v2.json
  Metadata format: v2 (with 'splits' key)
Loaded 84 test samples (expected: 84)

CASME II PoolFormer Evaluation Configuration:
  Model: PoolFormer_CASME2_KeyFrame_Baseline
  Task: micro_expression_recognition
  Test Version: v2
  Test Description: Key-frames (onset, apex, offset)
  Classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
  Input size: 384x384

S

Loading test images to RAM: 100%|██████████| 84/84 [00:02<00:00, 29.85it/s]


Test RAM caching completed: 84/84 images, ~0.15GB
Loading trained CASME II PoolFormer model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/03_03_poolformer_casme2_kfs/casme2_poolformer_keyframe_best_f1.pth
Checkpoint loaded using: standard
PoolFormer feature dimension: 768
PoolFormer CASME II: 768 -> 512 -> 128 -> 7
Model state loaded with strict=True
Model loaded successfully:
  Best validation F1: 0.2433
  Best validation accuracy: 0.3462
  Best epoch: 44
  Model classes: 7
Running CASME II PoolFormer model inference on test set...


CASME II Inference: 100%|██████████| 6/6 [00:02<00:00,  2.07it/s]

CASME II inference completed: 84 samples in 2.90s
Predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise']
True classes in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Calculating comprehensive metrics for CASME II micro-expression recognition...
Test set contains labels: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Model predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise']
Macro F1 (available classes): 0.3795
Analyzing wrong predictions for CASME II micro-expression recognition...
Evaluation results saved:
  Main results: casme2_poolformer_keyframe_evaluation_results_v2.json
  Wrong predictions: casme2_poolformer_keyframe_wrong_predictions_v2.json

CASME II KEY-FRAME POOLFORMER EVALUATION RESULTS
Test Dataset: Key-frames (onset, apex, offset) (v2)

Overall Performance (Macro - Available Classes):
  Accuracy:  0.4881
  Precision: 0.4461
  Recall:    0.3722
  F1 Score:  0.3795
 




In [None]:
# @title Cell 4: CASME II Key-Frame PoolFormer Confusion Matrix Generation

# File: 03_03_PoolFormer_CASME2_KFS_Cell4.py
# Location: experiments/03_03_PoolFormer_CASME2-KFS.ipynb
# Purpose: Generate professional confusion matrix and comprehensive analysis with test version support

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

# Cell banner, followed by the paths this cell reads from. The values match
# Cell 1; presumably they are re-declared so this cell can run on its own.
print("CASME II Key-Frame PoolFormer Confusion Matrix Generation", "=" * 60, sep="\n")

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = PROJECT_ROOT + "/results/03_03_poolformer_casme2_kfs"

def find_evaluation_json_files_casme2(results_path, versions=('v1', 'v2')):
    """Locate the per-version CASME II evaluation JSON files written by Cell 3.

    Looks in ``<results_path>/evaluation_results`` for
    ``casme2_poolformer_keyframe_evaluation_results_<version>.json`` and
    ``casme2_poolformer_keyframe_wrong_predictions_<version>.json``.

    Args:
        results_path: Experiment results root.
        versions: Test-dataset version tags to look for (default ``('v1', 'v2')``).

    Returns:
        Dict mapping ``'main_<version>'`` / ``'wrong_<version>'`` to the file
        paths that exist. Empty if the directory or the files are missing; the
        reason is printed.
    """
    json_files = {}
    eval_dir = f"{results_path}/evaluation_results"

    if not os.path.exists(eval_dir):
        print(f"ERROR: Evaluation directory not found: {eval_dir}")
        return json_files

    for version in versions:
        # Exact filenames are checked with isfile, not glob. glob would interpret
        # '[', '*' or '?' in results_path as wildcards and miss existing files.
        eval_file = f"{eval_dir}/casme2_poolformer_keyframe_evaluation_results_{version}.json"
        if os.path.isfile(eval_file):
            json_files[f'main_{version}'] = eval_file
            print(f"Found {version.upper()} evaluation file: {os.path.basename(eval_file)}")

        wrong_file = f"{eval_dir}/casme2_poolformer_keyframe_wrong_predictions_{version}.json"
        if os.path.isfile(wrong_file):
            json_files[f'wrong_{version}'] = wrong_file
            print(f"Found {version.upper()} wrong predictions: {os.path.basename(wrong_file)}")

    if not json_files:
        print(f"WARNING: No evaluation results found in {eval_dir}")
        print("Make sure Cell 3 (evaluation) has been executed first!")

    return json_files

def load_evaluation_results_casme2(json_path):
    """Read one CASME II evaluation JSON file.

    Returns the decoded object. If anything goes wrong while opening or parsing
    the file, prints the error and returns ``None``, so callers can skip
    unreadable runs.
    """
    try:
        with open(json_path, 'r') as handle:
            payload = json.load(handle)
        print(f"Successfully loaded: {os.path.basename(json_path)}")
    except Exception as err:
        print(f"ERROR loading {json_path}: {str(err)}")
        payload = None
    return payload

def calculate_weighted_f1_casme2(per_class_performance):
    """Support-weighted mean F1 over the CASME II classes present in the test set.

    Args:
        per_class_performance: Mapping class name -> dict with at least
            ``f1_score`` and ``support``.

    Returns:
        Sum of ``f1_score * support / total_support`` over classes with
        ``support > 0``; 0.0 when no class has support.
    """
    present = [
        (stats['f1_score'], stats['support'])
        for stats in per_class_performance.values()
        if stats['support'] > 0
    ]
    total_support = sum(support for _, support in present)

    if total_support == 0:
        return 0.0

    return sum((score * (support / total_support) for score, support in present), 0.0)

def calculate_balanced_accuracy_casme2(confusion_matrix, *, one_vs_rest=False):
    """
    Calculate balanced accuracy (unweighted average recall, UAR) for CASME II
    7-class micro-expression recognition from a confusion matrix.

    Balanced accuracy is the mean per-class recall, ``cm[i, i] / cm[i, :].sum()``,
    taken over classes that occur in the test set (rows with a non-zero sum).
    Classes missing from the test set are skipped rather than counted as 0. This
    matches scikit-learn's ``balanced_accuracy_score``.

    Args:
        confusion_matrix: Square array-like; rows = true class, columns = predicted.
        one_vs_rest: Keyword-only. If True, return the earlier definition instead:
            the mean over present classes of the one-vs-rest
            ``(sensitivity + specificity) / 2``. With many classes that value is
            inflated by the large true-negative counts and is not balanced
            accuracy. It is kept only to reproduce previously reported numbers.

    Returns:
        float in [0, 1]; 0.0 when the matrix is empty or contains no samples.
    """
    cm = np.asarray(confusion_matrix, dtype=float)
    if cm.ndim != 2 or cm.size == 0:
        return 0.0

    support = cm.sum(axis=1)   # number of true samples per class
    present = support > 0      # classes that actually occur in the test set
    if not present.any():
        return 0.0

    tp = np.diag(cm)
    recall = tp[present] / support[present]

    if not one_vs_rest:
        return float(recall.mean())

    # Legacy one-vs-rest (sensitivity + specificity) / 2, averaged over present classes.
    fp = cm.sum(axis=0) - tp
    fn = support - tp
    tn = cm.sum() - tp - fn - fp
    negatives = tn + fp
    specificity = np.divide(tn, negatives, out=np.zeros_like(tn), where=negatives > 0)
    return float(((recall + specificity[present]) / 2).mean())

def determine_text_color_casme2(color_value, threshold=0.5):
    """Choose a readable annotation colour for a heatmap cell.

    Returns ``'white'`` when ``color_value`` is strictly greater than
    ``threshold`` (a dark cell), otherwise ``'black'``.
    """
    if color_value > threshold:
        return 'white'
    return 'black'

def analyze_missing_classes_casme2(data):
    """Print and return which CASME II classes occur in the evaluated test set.

    Reads ``available_classes`` / ``missing_classes`` (default ``[]``) and
    ``class_names`` from ``data['evaluation_metadata']``.

    Returns:
        Dict with ``available``, ``missing`` and ``total_classes``
        (the length of ``class_names``).
    """
    meta = data['evaluation_metadata']
    present_classes = meta.get('available_classes', [])
    absent_classes = meta.get('missing_classes', [])

    print("Class Analysis:")
    for heading, names in (("Available in test", present_classes),
                           ("Missing from test", absent_classes)):
        print(f"  {heading}: {names}")

    return {
        'available': present_classes,
        'missing': absent_classes,
        'total_classes': len(meta['class_names'])
    }

def create_confusion_matrix_plot_casme2(data, output_path, test_version):
    """Render and save the CASME II PoolFormer confusion-matrix heatmap.

    Cells are coloured by row-normalised percentage (share of each true class)
    and annotated with count and percentage. Rows of classes with no test
    samples show "(N/A)" and are drawn as 0%. The title reports accuracy, macro
    F1 (from ``overall_performance``), weighted F1 and balanced accuracy
    (recomputed here). The figure is saved to ``output_path`` at 300 dpi and
    closed.

    Args:
        data: Evaluation results dict as produced by Cell 3.
        output_path: Image file path to write.
        test_version: Test-dataset version tag (e.g. ``'v2'``) shown in the figure.

    Returns:
        Dict with ``accuracy``, ``macro_f1``, ``weighted_f1``,
        ``balanced_accuracy``, ``missing_classes`` and ``test_version``.
    """
    # Extract data
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']
    test_desc = meta.get('test_description', test_version)

    print(f"Processing confusion matrix for CASME II classes: {class_names}")
    print(f"Test version: {test_version} - {test_desc}")
    print(f"Confusion matrix shape: {cm.shape}")

    # Calculate comprehensive metrics
    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1_casme2(per_class)
    balanced_acc = calculate_balanced_accuracy_casme2(cm)

    print(f"Calculated metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Balanced Acc: {balanced_acc:.4f}, Accuracy: {accuracy:.4f}")

    # Row-wise normalization for percentage display.
    # `out=` is required: np.divide(..., where=mask) leaves masked-out entries
    # uninitialised. Without it, rows of classes absent from the test set would
    # carry arbitrary values into the heatmap colours and text-colour choice.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_pct = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=(row_sums != 0))

    fig, ax = plt.subplots(figsize=(12, 10))

    # vmax=0.8: cells at or above 80% share the darkest colour
    im = ax.imshow(cm_pct, interpolation='nearest', cmap='Blues', vmin=0.0, vmax=0.8)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    # Annotate cells with count and percentage
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            # Classes with no test samples have no meaningful percentage
            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"

            text_color = determine_text_color_casme2(cm_pct[i, j], threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                    color=text_color, fontsize=9, fontweight='bold')

    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    # Add test version and missing classes note
    missing_classes = meta.get('missing_classes', [])
    note_text = f"Test: {test_desc} ({test_version.upper()})"
    if missing_classes:
        note_text += f"\nMissing: {', '.join(missing_classes)}"

    ax.text(0.02, 0.98, note_text, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    title = f"CASME II Micro-Expression Recognition - PoolFormer - {test_version.upper()}\n"
    title += f"Acc: {accuracy:.4f}  |  Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes,
        'test_version': test_version
    }

def create_per_class_performance_chart_casme2(data, output_path, test_version):
    """Save a two-panel per-class performance figure for CASME II.

    Top panel: grouped F1 / precision / recall bars for each class in
    ``class_names`` order. Each bar is labelled with its value, or with a red
    "N/A" when the class's ``in_test_set`` flag is False. Bottom panel: number
    of test samples (support) per class. The figure is written to
    ``output_path`` at 300 dpi and closed.
    """
    per_class = data['per_class_performance']
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    test_desc = meta.get('test_description', test_version)

    # One record per class, in class_names order
    records = [per_class[name] for name in class_names]
    classes = list(class_names)
    f1_scores = [rec['f1_score'] for rec in records]
    precisions = [rec['precision'] for rec in records]
    recalls = [rec['recall'] for rec in records]
    supports = [rec['support'] for rec in records]
    in_test_flags = [rec['in_test_set'] for rec in records]

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    positions = np.arange(len(classes))
    width = 0.25

    # Top panel: three bar series side by side around each class position
    series = (
        (positions - width, f1_scores, 'F1 Score', 'steelblue'),
        (positions, precisions, 'Precision', 'orange'),
        (positions + width, recalls, 'Recall', 'green'),
    )
    metric_bar_groups = [
        ax1.bar(xs, values, width, label=label, alpha=0.8, color=color)
        for xs, values, label, color in series
    ]

    ax1.set_xlabel('Emotion Classes', fontweight='bold')
    ax1.set_ylabel('Score', fontweight='bold')
    ax1.set_title(f'CASME II Per-Class Performance - PoolFormer - {test_version.upper()} ({test_desc})',
                  fontweight='bold', pad=20)
    ax1.set_xticks(positions)
    ax1.set_xticklabels(classes, rotation=45, ha='right')
    ax1.legend()
    ax1.grid(axis='y', alpha=0.3)
    ax1.set_ylim(0, 1.0)

    for group in metric_bar_groups:
        for bar, in_test in zip(group, in_test_flags):
            height = bar.get_height()
            anchor = (bar.get_x() + bar.get_width() / 2, height)
            if in_test:
                caption, emphasis = f'{height:.3f}', {}
            else:
                caption, emphasis = 'N/A', {'color': 'red', 'fontweight': 'bold'}
            ax1.annotate(caption, xy=anchor, xytext=(0, 3), textcoords="offset points",
                         ha='center', va='bottom', fontsize=8, **emphasis)

    # Bottom panel: test-set support per class
    support_bars = ax2.bar(positions, supports, color='purple', alpha=0.7)
    ax2.set_xlabel('Emotion Classes', fontweight='bold')
    ax2.set_ylabel('Number of Test Samples', fontweight='bold')
    ax2.set_title(f'CASME II Test Set Class Distribution - {test_version.upper()}', fontweight='bold')
    ax2.set_xticks(positions)
    ax2.set_xticklabels(classes, rotation=45, ha='right')
    ax2.grid(axis='y', alpha=0.3)

    for bar, support in zip(support_bars, supports):
        ax2.annotate(f'{int(support)}',
                     xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                     xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Per-class performance chart saved to: {os.path.basename(output_path)}")

def _format_metric_casme2(value, spec='.4f'):
    """Return ``value`` formatted with ``spec``, or 'N/A' when it is None.

    Numeric values are rendered exactly as an inline ``f"{value:.4f}"`` would.
    NOTE(review): undefined metrics (e.g. AUC for a class with no test
    samples) may be serialized as None/null by the evaluation cell -- confirm.
    """
    if value is None:
        return 'N/A'
    return format(value, spec)


def generate_performance_summary_casme2(evaluation_data, wrong_predictions_data=None):
    """Print a human-readable CASME II performance summary to stdout.

    Args:
        evaluation_data: Parsed evaluation JSON. Must contain
            'overall_performance' (accuracy, macro_precision, macro_recall,
            macro_f1, macro_auc), 'evaluation_metadata' (dataset,
            test_samples, model_type, evaluation_timestamp; optionally
            test_version, test_description, available_classes,
            missing_classes) and 'per_class_performance' (class name ->
            f1_score, precision, recall, auc, support, in_test_set).
            The 'training_information' (best_val_f1, best_epoch) and
            'inference_performance' (total_time_seconds,
            average_time_ms_per_sample) sections are printed only when present.
        wrong_predictions_data: Optional parsed wrong-predictions JSON with
            'analysis_metadata' (total_wrong_predictions, total_samples,
            overall_error_rate) and optionally 'confusion_patterns'
            (pattern label -> count).

    Returns:
        None. The summary is printed only.
    """
    print("\n" + "=" * 60)
    print("CASME II MICRO-EXPRESSION RECOGNITION PERFORMANCE SUMMARY")
    print("=" * 60)

    # Overall performance
    overall = evaluation_data['overall_performance']
    meta = evaluation_data['evaluation_metadata']

    print(f"Dataset: {meta['dataset']}")
    print(f"Test version: {meta.get('test_version', 'N/A')}")
    print(f"Test description: {meta.get('test_description', 'N/A')}")
    print(f"Test samples: {meta['test_samples']}")
    print(f"Model: {meta['model_type']}")
    print(f"Evaluation date: {meta['evaluation_timestamp']}")

    print("\nOverall Performance:")
    print(f"  Accuracy:         {_format_metric_casme2(overall['accuracy'])}")
    print(f"  Macro Precision:  {_format_metric_casme2(overall['macro_precision'])}")
    print(f"  Macro Recall:     {_format_metric_casme2(overall['macro_recall'])}")
    print(f"  Macro F1:         {_format_metric_casme2(overall['macro_f1'])}")
    print(f"  Macro AUC:        {_format_metric_casme2(overall['macro_auc'])}")

    # Per-class performance
    print("\nPer-Class Performance:")
    per_class = evaluation_data['per_class_performance']

    print(f"{'Class':<12} {'F1':<8} {'Precision':<10} {'Recall':<8} {'AUC':<8} {'Support':<8} {'In Test'}")
    print("-" * 65)

    for class_name, metrics in per_class.items():
        in_test = "Yes" if metrics['in_test_set'] else "No"
        # Format first, then pad, so a None metric prints as 'N/A' while
        # keeping the column alignment of the numeric case.
        print(f"{class_name:<12} {_format_metric_casme2(metrics['f1_score']):<8} "
              f"{_format_metric_casme2(metrics['precision']):<10} "
              f"{_format_metric_casme2(metrics['recall']):<8} "
              f"{_format_metric_casme2(metrics['auc']):<8} {metrics['support']:<8} {in_test}")

    # Training vs test performance
    if 'training_information' in evaluation_data:
        training = evaluation_data['training_information']
        print("\nTraining vs Test Comparison:")
        print(f"  Training Val F1:  {training['best_val_f1']:.4f}")
        print(f"  Test F1:          {overall['macro_f1']:.4f}")
        print(f"  Performance Gap:  {training['best_val_f1'] - overall['macro_f1']:+.4f}")
        print(f"  Best Epoch:       {training['best_epoch']}")

    # Class imbalance analysis. Prefer the lists recorded in the metadata and
    # fall back to the per-class in_test_set flags, so the report never claims
    # 0 available classes just because the metadata omitted the list.
    available_classes = meta.get('available_classes')
    if available_classes is None:
        available_classes = [name for name, m in per_class.items() if m['in_test_set']]
    missing_classes = meta.get('missing_classes')
    if missing_classes is None:
        missing_classes = [name for name, m in per_class.items() if not m['in_test_set']]
    # Denominator comes from the evaluated class set (7 for CASME II) rather
    # than a hard-coded constant.
    total_classes = len(per_class) if per_class else 7

    print("\nClass Availability Analysis:")
    print(f"  Available classes: {len(available_classes)}/{total_classes}")
    print(f"  Missing classes: {missing_classes if missing_classes else 'None'}")

    # Wrong predictions summary if available
    if wrong_predictions_data:
        wrong_meta = wrong_predictions_data['analysis_metadata']
        print("\nError Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']}/{wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        # Top-3 confusion patterns, most frequent first
        patterns = wrong_predictions_data.get('confusion_patterns', {})
        if patterns:
            print("\nTop Confusion Patterns:")
            sorted_patterns = sorted(patterns.items(), key=lambda item: item[1], reverse=True)[:3]
            for pattern, count in sorted_patterns:
                print(f"  {pattern}: {count} cases")

    # Inference timing is optional; skip the section instead of raising
    # KeyError after its header has already been printed.
    inference = evaluation_data.get('inference_performance')
    if inference:
        print("\nInference Performance:")
        print(f"  Total time: {inference['total_time_seconds']:.2f}s")
        print(f"  Speed: {inference['average_time_ms_per_sample']:.1f} ms/sample")

import traceback

# Find evaluation JSON files (keys 'main_<version>' / 'wrong_<version>').
# Normalise a None return to {} so the membership checks below cannot raise.
json_files = find_evaluation_json_files_casme2(RESULTS_ROOT) or {}

if not json_files:
    print(f"ERROR: No evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (evaluation) has been executed first!")
else:
    main_result_count = sum(1 for key in json_files if key.startswith('main_'))
    print(f"\nFound {main_result_count} evaluation result(s)")

# Create output directory
output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Process evaluation results for each version
results_summary = {}
generated_files = []

for version in ['v1', 'v2']:
    main_key = f'main_{version}'
    wrong_key = f'wrong_{version}'

    if main_key in json_files:
        print(f"\n{'='*60}")
        print(f"Processing {version.upper()} Evaluation Results")
        print(f"{'='*60}")

        # Load evaluation data
        eval_data = load_evaluation_results_casme2(json_files[main_key])

        # Load wrong predictions data if available
        wrong_data = None
        if wrong_key in json_files:
            wrong_data = load_evaluation_results_casme2(json_files[wrong_key])

        if eval_data is not None:
            try:
                # Analyze missing classes
                class_analysis = analyze_missing_classes_casme2(eval_data)

                # Generate confusion matrix
                cm_output_path = os.path.join(output_dir, f"confusion_matrix_CASME2_PoolFormer_KeyFrame_{version}.png")
                metrics = create_confusion_matrix_plot_casme2(eval_data, cm_output_path, version)
                generated_files.append(cm_output_path)

                # Generate per-class performance chart
                perf_output_path = os.path.join(output_dir, f"per_class_performance_CASME2_PoolFormer_KeyFrame_{version}.png")
                create_per_class_performance_chart_casme2(eval_data, perf_output_path, version)
                generated_files.append(perf_output_path)

                results_summary[version] = metrics
                results_summary[version]['class_analysis'] = class_analysis

                print(f"SUCCESS: {version.upper()} visualization files generated successfully")

            except Exception as e:
                print(f"ERROR: Failed to generate {version.upper()} visualizations: {str(e)}")
                traceback.print_exc()

            # Generate comprehensive summary. Guarded separately so that a
            # malformed JSON section for one version cannot abort the loop
            # before the remaining version (and the final report) is processed.
            try:
                generate_performance_summary_casme2(eval_data, wrong_data)
            except Exception as e:
                print(f"ERROR: Failed to generate {version.upper()} performance summary: {str(e)}")
                traceback.print_exc()
        else:
            print(f"ERROR: Could not load {version.upper()} evaluation data")

# Final summary
if generated_files:
    print("\n" + "=" * 60)
    print("CASME II CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 60)

    print("\nGenerated visualization files:")
    for file_path in generated_files:
        filename = os.path.basename(file_path)
        print(f"  {filename}")

    # Summary for each version
    for version in ['v1', 'v2']:
        if version in results_summary:
            print(f"\n{version.upper()} Performance Summary:")
            metrics = results_summary[version]
            print(f"  Accuracy:       {metrics['accuracy']:.4f}")
            print(f"  Macro F1:       {metrics['macro_f1']:.4f}")
            print(f"  Weighted F1:    {metrics['weighted_f1']:.4f}")
            print(f"  Balanced Acc:   {metrics['balanced_accuracy']:.4f}")

            if 'class_analysis' in metrics:
                analysis = metrics['class_analysis']
                print(f"  Available classes: {len(analysis['available'])}/{analysis['total_classes']}")
                print(f"  Missing classes: {len(analysis['missing'])}")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print("\nERROR: No visualizations were generated")
    print("Please check:")
    print("1. Cell 3 evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

print("\nCell 4 completed - CASME II confusion matrix analysis generated")

CASME II Key-Frame PoolFormer Confusion Matrix Generation
Found V1 evaluation file: casme2_poolformer_keyframe_evaluation_results_v1.json
Found V1 wrong predictions: casme2_poolformer_keyframe_wrong_predictions_v1.json
Found V2 evaluation file: casme2_poolformer_keyframe_evaluation_results_v2.json
Found V2 wrong predictions: casme2_poolformer_keyframe_wrong_predictions_v2.json

Found 2 evaluation result(s)

Processing V1 Evaluation Results
Successfully loaded: casme2_poolformer_keyframe_evaluation_results_v1.json
Successfully loaded: casme2_poolformer_keyframe_wrong_predictions_v1.json
Class Analysis:
  Available in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
  Missing from test: ['fear']
Processing confusion matrix for CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
Test version: v1 - Apex-only frames
Confusion matrix shape: (7, 7)
Calculated metrics - Macro F1: 0.3974, Weighted F1: 0.4906, Balanced Ac