In [1]:
# @title Cell 1: CASME II Apex Frame PoolFormer Infrastructure Configuration

# File: 05_03_PoolFormer_CASME2_AF_Cell1.py
# Location: experiments/05_03_PoolFormer_CASME2-AF-PREP.ipynb
# Purpose: PoolFormer for CASME II micro-expression recognition with apex frame and face-aware preprocessing

# Mount Google Drive
from google.colab import drive
print("=" * 60)
print("CASME II APEX FRAME POOLFORMER WITH FACE-AWARE PREPROCESSING")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Project paths configuration
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/preprocessed_v7"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/05_03_poolformer_casme2_af_prep"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/05_03_poolformer_casme2_af_prep"

# Load CASME II v7 preprocessing metadata
PREPROCESSING_SUMMARY = f"{DATASET_ROOT}/preprocessing_summary.json"

print("CASME II Apex Frame PoolFormer - Face-Aware Preprocessing Infrastructure")
print("=" * 60)

# Validate preprocessing metadata exists
if not os.path.exists(PREPROCESSING_SUMMARY):
    raise FileNotFoundError(f"v7 preprocessing summary not found: {PREPROCESSING_SUMMARY}")

# Load v7 preprocessing metadata
print("Loading CASME II v7 preprocessing metadata...")
with open(PREPROCESSING_SUMMARY, 'r') as f:
    preprocessing_info = json.load(f)

print(f"Dataset variant: {preprocessing_info['variant']}")
print(f"Processing date: {preprocessing_info['processing_date']}")
print(f"Preprocessing method: {preprocessing_info['preprocessing_method']}")
print(f"Total images processed: {preprocessing_info['total_processed']}")
print(f"Face detection rate: {preprocessing_info['face_detection_stats']['detection_rate']:.2%}")

# Extract preprocessing parameters
preproc_params = preprocessing_info['preprocessing_parameters']
print(f"Target size: {preproc_params['target_size']}x{preproc_params['target_size']}px")
print(f"BBox expansion: {preproc_params['bbox_expansion']}px (all directions)")

# =====================================================
# EXPERIMENT CONFIGURATION - Apex Frame with Face-Aware Preprocessing
# =====================================================

# FOCAL LOSS CONFIGURATION - Toggle for experimentation
USE_FOCAL_LOSS = True  # Default: CrossEntropy, Set True to enable Focal Loss
FOCAL_LOSS_GAMMA = 2.0  # Focal loss focusing parameter (if enabled)

# OPTIMIZED CLASS WEIGHTS CONFIGURATION
# CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
# v7 Train distribution: [79, 50, 25, 21, 20, 5, 1] - inverse sqrt frequency approach

# CrossEntropy Loss - Optimized inverse square root frequency weights
CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.26, 1.78, 1.94, 1.99, 3.98, 8.89]

# Focal Loss - Normalized per-class alpha values (sum = 1.0)
FOCAL_LOSS_ALPHA_WEIGHTS = [0.052, 0.066, 0.093, 0.101, 0.104, 0.208, 0.376]

# POOLFORMER MODEL CONFIGURATION - Support m36 and m48 variants
POOLFORMER_MODEL_VARIANT = 'm48'  # Default: m36, Options: 'm36' or 'm48'

# Dynamic PoolFormer model selection based on variant
if POOLFORMER_MODEL_VARIANT == 'm36':
    POOLFORMER_MODEL_NAME = 'sail/poolformer_m36'
    MODEL_PARAMS = '56M'
    print("Using PoolFormer-M36 for micro-expression analysis (56M parameters)")
elif POOLFORMER_MODEL_VARIANT == 'm48':
    POOLFORMER_MODEL_NAME = 'sail/poolformer_m48'
    MODEL_PARAMS = '73M'
    print("Using PoolFormer-M48 for enhanced micro-expression recognition (73M parameters)")
else:
    raise ValueError(f"Unsupported POOLFORMER_MODEL_VARIANT: {POOLFORMER_MODEL_VARIANT}")

# Display experiment configuration
print("\n" + "=" * 50)
print("EXPERIMENT CONFIGURATION - APEX FRAME FACE-AWARE")
print("=" * 50)
print(f"Dataset: v7 Apex Frame with Face-Aware Preprocessing")
print(f"Loss Function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
if USE_FOCAL_LOSS:
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Alpha Weights (per-class): {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum Validation: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Class Weights (inverse sqrt freq): {CROSSENTROPY_CLASS_WEIGHTS}")
print(f"PoolFormer Model: {POOLFORMER_MODEL_NAME}")
print(f"Model Parameters: {MODEL_PARAMS}")
print(f"Input Resolution: 224x224px (face-centered with bbox expansion)")
print(f"Image Format: Grayscale converted to RGB (3-channel)")
print("=" * 50)

# Enhanced GPU configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

# Optimized batch size for small dataset (201 train samples)
# Using batch size 4 for better gradient estimates with extreme class imbalance
BATCH_SIZE = 4
NUM_WORKERS = 4

if 'A100' in gpu_name or 'L4' in gpu_name:
    torch.backends.cudnn.benchmark = True
    print(f"GPU optimization enabled for {gpu_name}")

print(f"Small dataset configuration: Batch size {BATCH_SIZE} (optimal for 201 samples)")
print(f"Iterations per epoch: {201 // BATCH_SIZE} (more frequent weight updates)")

# RAM preloading workers (separate from DataLoader workers)
RAM_PRELOAD_WORKERS = 32
print(f"RAM preload workers: {RAM_PRELOAD_WORKERS} (parallel image loading)")

# CASME II class mapping and analysis
CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

# Extract class distribution from v7 preprocessing metadata
print("\nLoading v7 class distribution...")
train_dist = preprocessing_info['splits']['train']['emotion_distribution']
val_dist = preprocessing_info['splits']['val']['emotion_distribution']
test_dist = preprocessing_info['splits']['test']['emotion_distribution']

# Convert to ordered list matching CASME2_CLASSES
def emotion_dist_to_list(emotion_dict, class_names):
    """Convert emotion distribution dict to ordered list"""
    return [emotion_dict.get(cls, 0) for cls in class_names]

train_dist_list = emotion_dist_to_list(train_dist, CASME2_CLASSES)
val_dist_list = emotion_dist_to_list(val_dist, CASME2_CLASSES)
test_dist_list = emotion_dist_to_list(test_dist, CASME2_CLASSES)

print(f"\nv7 Train distribution: {train_dist_list}")
print(f"v7 Val distribution: {val_dist_list}")
print(f"v7 Test distribution: {test_dist_list}")

# Apply optimized class weights based on loss function selection
if USE_FOCAL_LOSS:
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

# CASME II PoolFormer Configuration for Apex Frame with Face-Aware Preprocessing
# Optimized for small dataset (201 train samples) with extreme class imbalance
CASME2_POOLFORMER_CONFIG = {
    # Architecture configuration - simplified for small dataset
    'poolformer_model': POOLFORMER_MODEL_NAME,
    'model_variant': POOLFORMER_MODEL_VARIANT,
    'model_params': MODEL_PARAMS,
    'input_size': 224,
    'num_classes': 7,
    'dropout_rate': 0.3,  # Increased from 0.2 for stronger regularization
    'expected_feature_dim': 512,

    # Training configuration - optimized for small dataset
    'learning_rate': 2e-5,  # Slightly higher for small batch size (4)
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'device': device,

    # Scheduler configuration
    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 5,
    'scheduler_min_lr': 1e-6,
    'scheduler_monitor': 'val_f1_macro',

    # Loss configuration
    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'class_weights': class_weights,

    # Evaluation configuration
    'use_macro_avg': True,
    'early_stopping': False,
    'save_best_f1': True,
    'save_strategy': 'best_only',

    # v7 specific configuration
    'dataset_version': 'v7',
    'preprocessing_method': 'face_aware_bbox_expansion',
    'frame_strategy': 'apex_frame',
    'image_format': 'grayscale_to_rgb',
    'bbox_expansion': preproc_params['bbox_expansion'],
    'face_detection_rate': preprocessing_info['face_detection_stats']['detection_rate']
}

print(f"\nPoolFormer Configuration Summary:")
print(f"  Model: {CASME2_POOLFORMER_CONFIG['poolformer_model']}")
print(f"  Input size: {CASME2_POOLFORMER_CONFIG['input_size']}px")
print(f"  Learning rate: {CASME2_POOLFORMER_CONFIG['learning_rate']} (optimized for small batches)")
print(f"  Batch size: {BATCH_SIZE} (optimal for small dataset)")
print(f"  Dropout rate: {CASME2_POOLFORMER_CONFIG['dropout_rate']} (increased regularization)")
print(f"  Dataset version: {CASME2_POOLFORMER_CONFIG['dataset_version']}")
print(f"  Frame strategy: {CASME2_POOLFORMER_CONFIG['frame_strategy']}")
print(f"  Preprocessing: {CASME2_POOLFORMER_CONFIG['preprocessing_method']}")
print(f"  Train images: {preprocessing_info['splits']['train']['total_images']}")
print(f"  Architecture: Simplified 512->256->7 for small dataset")

# =====================================================
# FOCAL LOSS IMPLEMENTATION
# =====================================================

class OptimizedFocalLoss(nn.Module):
    """
    Advanced Focal Loss implementation with per-class alpha support
    Paper: Focal Loss for Dense Object Detection (Lin et al., 2017)
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()

        if alpha is not None:
            if isinstance(alpha, list):
                self.alpha = torch.tensor(alpha, dtype=torch.float32)
            else:
                self.alpha = alpha

            alpha_sum = self.alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")
        else:
            self.alpha = None

        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# PoolFormer Architecture for CASME II
class PoolFormerCASME2Baseline(nn.Module):
    """PoolFormer baseline for CASME II micro-expression recognition - Simplified for small dataset"""

    def __init__(self, num_classes, dropout_rate=0.3):
        super(PoolFormerCASME2Baseline, self).__init__()

        from transformers import PoolFormerModel

        self.poolformer = PoolFormerModel.from_pretrained(
            CASME2_POOLFORMER_CONFIG['poolformer_model']
        )

        # Enable fine-tuning
        for param in self.poolformer.parameters():
            param.requires_grad = True

        self.poolformer_feature_dim = self.poolformer.config.hidden_sizes[-1]

        print(f"PoolFormer feature dimension: {self.poolformer_feature_dim}")

        # Simplified classification head for small dataset (201 samples)
        # Reduced complexity: 512 -> 256 -> 7 (instead of 512 -> 256 -> 128 -> 7)
        self.classifier_layers = nn.Sequential(
            nn.Linear(self.poolformer_feature_dim, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(dropout_rate),
        )

        self.classifier = nn.Linear(256, num_classes)

        print(f"PoolFormer CASME II Simplified: {self.poolformer_feature_dim} -> 256 -> {num_classes}")
        print(f"Dropout rate: {dropout_rate} (increased for small dataset regularization)")

    def forward(self, pixel_values):
        poolformer_outputs = self.poolformer(pixel_values=pixel_values)

        # PoolFormer output: [batch_size, channels, height, width]
        poolformer_features = poolformer_outputs.last_hidden_state

        # Global average pooling across spatial dimensions
        # [batch_size, channels, height, width] -> [batch_size, channels]
        poolformer_features = poolformer_features.mean(dim=[2, 3])

        # Classification pipeline
        processed_features = self.classifier_layers(poolformer_features)
        output = self.classifier(processed_features)

        return output

# Optimizer and scheduler factory
def create_optimizer_scheduler_casme2(model, config):
    """Create optimizer and scheduler for CASME II PoolFormer training"""

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )

    if config['scheduler_type'] == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=config['scheduler_mode'],
            factor=config['scheduler_factor'],
            patience=config['scheduler_patience'],
            min_lr=config['scheduler_min_lr']
        )
        print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    else:
        scheduler = None

    return optimizer, scheduler

# PoolFormer Image Processor setup for 224px input
from transformers import PoolFormerImageProcessor

print("\nSetting up PoolFormer Image Processor for 224px input...")

poolformer_processor = PoolFormerImageProcessor.from_pretrained(
    CASME2_POOLFORMER_CONFIG['poolformer_model'],
    do_resize=True,
    size={'height': 224, 'width': 224},
    do_normalize=True,
    do_rescale=True,
    do_center_crop=False
)

# Transform functions for PoolFormer with grayscale to RGB conversion
def poolformer_transform_train(image):
    """
    Training transform with PoolFormer Image Processor
    Handles grayscale to RGB conversion via channel repetition
    """
    # Ensure RGB format (grayscale images will be converted via channel repetition)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = poolformer_processor(image, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)

def poolformer_transform_val(image):
    """
    Validation transform with PoolFormer Image Processor
    Handles grayscale to RGB conversion via channel repetition
    """
    # Ensure RGB format (grayscale images will be converted via channel repetition)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = poolformer_processor(image, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)

print("PoolFormer Image Processor configured for 224px")
print("Grayscale images will be converted to RGB via channel repetition")

# Custom Dataset class for CASME II v7
class CASME2Dataset(Dataset):
    """Custom dataset class for CASME II v7 with flat file structure"""

    def __init__(self, dataset_root, split, transform=None):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []
        self.emotions = []

        split_dir = os.path.join(dataset_root, split)

        if not os.path.exists(split_dir):
            raise ValueError(f"Split directory not found: {split_dir}")

        # Load all images from flat directory structure
        # Filename pattern: sub01_EP02_01f_happiness.jpg
        for img_file in os.listdir(split_dir):
            if img_file.endswith(('.jpg', '.jpeg', '.png')):
                # Extract emotion from filename (last part before extension)
                emotion = img_file.rsplit('_', 1)[-1].split('.')[0]

                if emotion in CASME2_CLASSES:
                    self.images.append(os.path.join(split_dir, img_file))
                    self.labels.append(CLASS_TO_IDX[emotion])
                    self.emotions.append(emotion)
                else:
                    print(f"Warning: Unknown emotion '{emotion}' in file {img_file}")

        print(f"Loaded {len(self.images)} samples for {split} split")
        self._print_distribution()

    def _print_distribution(self):
        """Print class distribution"""
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]

        try:
            # Load image (will be grayscale from v7 preprocessing)
            image = Image.open(image_path)

            # Verify size is 224x224
            if image.size != (224, 224):
                print(f"Warning: Image {os.path.basename(image_path)} has unexpected size {image.size}")
                image = image.resize((224, 224), Image.Resampling.LANCZOS)
        except Exception as e:
            print(f"Error loading {image_path}: {e}")
            # Fallback to gray dummy image
            image = Image.new('L', (224, 224), 128)

        # Apply transform (will handle grayscale to RGB conversion)
        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        sample_id = os.path.basename(self.images[idx])

        return image, label, sample_id

# Create directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

# Dataset paths - v7 structure
TRAIN_PATH = f"{DATASET_ROOT}/train"
VAL_PATH = f"{DATASET_ROOT}/val"
TEST_PATH = f"{DATASET_ROOT}/test"

print(f"\nDataset paths:")
print(f"Train: {TRAIN_PATH}")
print(f"Validation: {VAL_PATH}")
print(f"Test: {TEST_PATH}")

# Architecture validation
print("\nPoolFormer CASME II architecture validation...")

try:
    test_model = PoolFormerCASME2Baseline(num_classes=7, dropout_rate=0.3).to(device)
    test_input = torch.randn(1, 3, 224, 224).to(device)
    test_output = test_model(test_input)

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"PoolFormer {CASME2_POOLFORMER_CONFIG['model_variant'].upper()} with {MODEL_PARAMS} parameters")
    print(f"Simplified architecture: 512 -> 256 -> 7 (reduced parameters for small dataset)")

    del test_model, test_input, test_output
    torch.cuda.empty_cache()

except Exception as e:
    print(f"Validation failed: {e}")

# Loss function factory
def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    """Factory function to create loss criterion"""
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        if alpha_weights:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {sum(alpha_weights):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)
    else:
        print(f"Using CrossEntropy Loss with optimized class weights")
        print(f"Class weights: {weights.cpu().numpy()}")
        return nn.CrossEntropyLoss(weight=weights)

# Global configuration
GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    'transform_train': poolformer_transform_train,
    'transform_val': poolformer_transform_val,
    'poolformer_config': CASME2_POOLFORMER_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'train_path': TRAIN_PATH,
    'val_path': VAL_PATH,
    'test_path': TEST_PATH,
    'preprocessing_info': preprocessing_info,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

# Configuration validation and summary
print("\n" + "=" * 60)
print("CASME II APEX FRAME POOLFORMER CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: {POOLFORMER_MODEL_NAME}")
print(f"  Variant: {POOLFORMER_MODEL_VARIANT.upper()}")
print(f"  Parameters: {MODEL_PARAMS}")
print(f"  Input Resolution: 224px")
print(f"  Feature Dimension: {CASME2_POOLFORMER_CONFIG['expected_feature_dim']}")
print(f"  Classification Head: Simplified 512->256->7")
print(f"  Dropout: {CASME2_POOLFORMER_CONFIG['dropout_rate']} (increased regularization)")

print(f"\nDataset Configuration:")
print(f"  Version: {CASME2_POOLFORMER_CONFIG['dataset_version']}")
print(f"  Frame strategy: {CASME2_POOLFORMER_CONFIG['frame_strategy']}")
print(f"  Preprocessing: {CASME2_POOLFORMER_CONFIG['preprocessing_method']}")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Train samples: {preprocessing_info['splits']['train']['total_images']}")
print(f"  Val samples: {preprocessing_info['splits']['val']['total_images']}")
print(f"  Test samples: {preprocessing_info['splits']['test']['total_images']}")
print(f"  Face detection rate: {CASME2_POOLFORMER_CONFIG['face_detection_rate']:.2%}")

print(f"\nTraining Configuration (Optimized for Small Dataset):")
print(f"  Batch size: {BATCH_SIZE} (optimal for 201 samples)")
print(f"  Learning rate: {CASME2_POOLFORMER_CONFIG['learning_rate']} (adjusted for small batches)")
print(f"  Iterations/epoch: ~{201 // BATCH_SIZE} (frequent updates)")
print(f"  Strategy: Simplified architecture + stronger regularization")

print("\nNext: Cell 2 - Dataset Loading and Training Pipeline")

CASME II APEX FRAME POOLFORMER WITH FACE-AWARE PREPROCESSING

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II Apex Frame PoolFormer - Face-Aware Preprocessing Infrastructure
Loading CASME II v7 preprocessing metadata...
Dataset variant: AF
Processing date: 2025-10-19T08:16:29.397012
Preprocessing method: face_bbox_expansion_all_directions
Total images processed: 255
Face detection rate: 100.00%
Target size: 224x224px
BBox expansion: 20px (all directions)
Using PoolFormer-M48 for enhanced micro-expression recognition (73M parameters)

EXPERIMENT CONFIGURATION - APEX FRAME FACE-AWARE
Dataset: v7 Apex Frame with Face-Aware Preprocessing
Loss Function: Focal Loss
  Gamma: 2.0
  Alpha Weights (per-class): [0.052, 0.066, 0.093, 0.101, 0.104, 0.208, 0.376]
  Alpha Sum Validation: 1.000
PoolFormer Model: sail/poolformer_m48
Model Parameters: 73M
Input Resolution: 224x224px (face-centered with bbox expansion)

preprocessor_config.json:   0%|          | 0.00/283 [00:00<?, ?B/s]

PoolFormer Image Processor configured for 224px
Grayscale images will be converted to RGB via channel repetition

Dataset paths:
Train: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v7/train
Validation: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v7/val
Test: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v7/test

PoolFormer CASME II architecture validation...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/294M [00:00<?, ?B/s]

PoolFormer feature dimension: 768
PoolFormer CASME II Simplified: 768 -> 256 -> 7
Dropout rate: 0.3 (increased for small dataset regularization)
Validation successful: Output shape torch.Size([1, 7])
PoolFormer M48 with 73M parameters
Simplified architecture: 512 -> 256 -> 7 (reduced parameters for small dataset)

CASME II APEX FRAME POOLFORMER CONFIGURATION COMPLETE
Loss Configuration:
  Function: Optimized Focal Loss
  Gamma: 2.0
  Per-class Alpha: [0.052, 0.066, 0.093, 0.101, 0.104, 0.208, 0.376]
  Alpha Sum: 1.000

Model Configuration:
  Architecture: sail/poolformer_m48
  Variant: M48
  Parameters: 73M
  Input Resolution: 224px
  Feature Dimension: 512
  Classification Head: Simplified 512->256->7
  Dropout: 0.3 (increased regularization)

Dataset Configuration:
  Version: v7
  Frame strategy: apex_frame
  Preprocessing: face_aware_bbox_expansion
  Classes: 7
  Train samples: 201
  Val samples: 26
  Test samples: 28
  Face detection rate: 100.00%

Training Configuration (Optimized

In [2]:
# @title Cell 2: CASME II Apex Frame PoolFormer Training Pipeline

# File: 05_03_PoolFormer_CASME2_AF_Cell2.py
# Location: experiments/05_03_PoolFormer_CASME2-AF-PREP.ipynb
# Purpose: Training pipeline for CASME II Apex Frame with face-aware preprocessing

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import shutil
import tempfile

print("CASME II Apex Frame PoolFormer Training Pipeline")
print("=" * 70)
print(f"Loss Function: {'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy Loss'}")
if CASME2_POOLFORMER_CONFIG['use_focal_loss']:
    print(f"Focal Loss Parameters:")
    print(f"  Gamma: {CASME2_POOLFORMER_CONFIG['focal_loss_gamma']}")
    print(f"  Per-class Alpha: {CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']}")
    print(f"  Alpha Sum: {sum(CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights']):.3f}")
else:
    print(f"CrossEntropy Parameters:")
    print(f"  Optimized Class Weights: {CASME2_POOLFORMER_CONFIG['crossentropy_class_weights']}")
print(f"Dataset Version: {CASME2_POOLFORMER_CONFIG['dataset_version']}")
print(f"Frame Strategy: {CASME2_POOLFORMER_CONFIG['frame_strategy']}")
print(f"Preprocessing: {CASME2_POOLFORMER_CONFIG['preprocessing_method']}")
print(f"Training epochs: {CASME2_POOLFORMER_CONFIG['num_epochs']}")
print(f"Scheduler patience: {CASME2_POOLFORMER_CONFIG['scheduler_patience']}")

# Enhanced CASME II Dataset with RAM caching
class CASME2DatasetTraining(Dataset):
    """Enhanced CASME II dataset for training with RAM caching optimization"""

    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.emotions = []
        self.cached_images = []

        split_dir = os.path.join(dataset_root, split)

        if not os.path.exists(split_dir):
            raise ValueError(f"Split directory not found: {split_dir}")

        print(f"Loading CASME II {split} dataset for training...")

        # Load all images from flat directory structure
        # Filename pattern: sub01_EP02_01f_happiness.jpg
        for img_file in os.listdir(split_dir):
            if img_file.endswith(('.jpg', '.jpeg', '.png')):
                # Extract emotion from filename (last part before extension)
                emotion = img_file.rsplit('_', 1)[-1].split('.')[0]

                if emotion in CASME2_CLASSES:
                    self.images.append(os.path.join(split_dir, img_file))
                    self.labels.append(CLASS_TO_IDX[emotion])
                    self.emotions.append(emotion)
                else:
                    print(f"Warning: Unknown emotion '{emotion}' in file {img_file}")

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        self._print_distribution()

        # RAM caching for training efficiency
        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print class distribution"""
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def _preload_to_ram(self):
        """RAM preloading with parallel loading for training efficiency"""
        print(f"Preloading {len(self.images)} {self.split} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            """Load single image with error handling"""
            try:
                # Load image (grayscale from v7 preprocessing)
                image = Image.open(img_path)

                # Verify size is 224x224
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)

                return idx, image, True
            except Exception as e:
                return idx, Image.new('L', (224, 224), 128), False

        # Parallel loading with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc=f"Loading {self.split} to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_mb = len(self.cached_images) * 224 * 224 * 1 / 1e6
        print(f"{self.split.upper()} RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_mb:.1f}MB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx])
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('L', (224, 224), 128)

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], os.path.basename(self.images[idx])

# Metrics calculation with comprehensive error handling
def calculate_metrics_safe_robust(outputs, labels, class_names, average='macro'):
    """Calculate metrics with enhanced error handling and validation"""
    try:
        if outputs.size(0) != labels.size(0):
            raise ValueError(f"Batch size mismatch: outputs {outputs.size(0)} vs labels {labels.size(0)}")

        if isinstance(outputs, torch.Tensor):
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        else:
            predictions = np.array(outputs)

        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()
        else:
            labels = np.array(labels)

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0,
            labels=list(range(len(class_names)))
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }
    except Exception as e:
        print(f"Warning: Metrics calculation error: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

# Training epoch function
def train_epoch(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    """Training epoch with robust error handling"""
    model.train()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}/{total_epochs}")

    for batch_idx, (images, labels, sample_ids) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Model forward pass
        model_output = model(images)

        # Robust output extraction
        if isinstance(model_output, (tuple, list)):
            outputs = model_output[0]
        elif isinstance(model_output, dict):
            outputs = model_output.get('logits', model_output.get('prediction', model_output))
        else:
            outputs = model_output

        # Validate output shape for 7 CASME II classes
        if outputs.dim() != 2 or outputs.size(1) != 7:
            raise ValueError(f"Invalid output shape: {outputs.shape}, expected [batch_size, 7]")

        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), CASME2_POOLFORMER_CONFIG['gradient_clip'])

        optimizer.step()
        running_loss += loss.item()

        # Memory optimized: Move to CPU before accumulating
        all_outputs.append(outputs.detach().cpu())
        all_labels.append(labels.detach().cpu())

        # Update progress
        if batch_idx % 5 == 0:
            avg_loss = running_loss / (batch_idx + 1)
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'Loss': f'{avg_loss:.4f}',
                'LR': f'{current_lr:.2e}'
            })

    # Metrics calculation
    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Training metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics

# Validation epoch function
def validate_epoch(model, dataloader, criterion, device, epoch, total_epochs):
    """Validation epoch with robust error handling"""
    model.eval()
    running_loss = 0.0
    all_outputs = []
    all_labels = []
    all_sample_ids = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc=f"Validation Epoch {epoch+1}/{total_epochs}")

        for batch_idx, (images, labels, sample_ids) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)

            # Model forward pass
            model_output = model(images)

            # Robust output extraction
            if isinstance(model_output, (tuple, list)):
                outputs = model_output[0]
            elif isinstance(model_output, dict):
                outputs = model_output.get('logits', model_output.get('prediction', model_output))
            else:
                outputs = model_output

            loss = criterion(outputs, labels)
            running_loss += loss.item()

            # Memory optimized: Move to CPU before accumulating
            all_outputs.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_sample_ids.extend(sample_ids)

            if batch_idx % 3 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                progress_bar.set_postfix({'Val Loss': f'{avg_loss:.4f}'})

    # Metrics calculation
    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Validation metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics, all_sample_ids

# Hardened checkpoint saving function
def save_checkpoint_robust(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                         checkpoint_dir, best_metrics, config, max_retries=3):
    """Hardened checkpoint saving with atomic write and validation"""

    def make_serializable_cpu(obj):
        if isinstance(obj, torch.Tensor):
            cpu_obj = obj.detach().cpu()
            return cpu_obj.item() if cpu_obj.numel() == 1 else cpu_obj.tolist()
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, dict):
            return {k: make_serializable_cpu(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_serializable_cpu(item) for item in obj]
        else:
            return obj

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_metrics': make_serializable_cpu(train_metrics),
        'val_metrics': make_serializable_cpu(val_metrics),
        'casme2_config': make_serializable_cpu(config),
        'best_f1': float(best_metrics['f1']),
        'best_loss': float(best_metrics['loss']),
        'best_acc': float(best_metrics['accuracy']),
        'class_names': CASME2_CLASSES,
        'num_classes': 7
    }

    final_path = f"{checkpoint_dir}/casme2_poolformer_apex_frame_best_f1.pth"

    for attempt in range(max_retries):
        try:
            # Save to temporary file
            temp_fd, temp_path = tempfile.mkstemp(dir=checkpoint_dir, suffix='.pth.tmp')
            os.close(temp_fd)

            print(f"Attempt {attempt + 1}: Saving checkpoint to temporary file...")
            torch.save(checkpoint, temp_path)

            # Validate checkpoint
            print("Validating checkpoint integrity...")
            validation_checkpoint = torch.load(temp_path, map_location='cpu')

            required_keys = ['model_state_dict', 'epoch', 'best_f1', 'num_classes']
            for key in required_keys:
                if key not in validation_checkpoint:
                    raise ValueError(f"Checkpoint validation failed: missing key '{key}'")

            if validation_checkpoint['epoch'] != epoch:
                raise ValueError(f"Checkpoint epoch mismatch: saved {epoch}, loaded {validation_checkpoint['epoch']}")

            print("Checkpoint validation passed")

            # Atomic rename
            print(f"Moving validated checkpoint to final location...")
            shutil.move(temp_path, final_path)

            print(f"Checkpoint saved and validated successfully: {os.path.basename(final_path)}")
            print(f"  Epoch: {epoch + 1}")
            print(f"  Val F1: {best_metrics['f1']:.4f}")
            print(f"  Val Loss: {best_metrics['loss']:.4f}")
            print(f"  Val Acc: {best_metrics['accuracy']:.4f}")

            return final_path

        except Exception as e:
            print(f"Checkpoint save attempt {attempt + 1}/{max_retries} failed: {e}")

            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except:
                    pass

            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                print(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                print(f"All {max_retries} checkpoint save attempts failed")
                return None

    return None

# Safe JSON serialization
def safe_json_serialize(obj):
    """Convert objects to JSON-serializable format"""
    if isinstance(obj, torch.Tensor):
        return obj.cpu().item() if obj.numel() == 1 else obj.cpu().numpy().tolist()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.int32, np.int64)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float32, np.float64)):
        return float(obj)
    elif isinstance(obj, dict):
        return {k: safe_json_serialize(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [safe_json_serialize(item) for item in obj]
    elif hasattr(obj, '__dict__'):
        return safe_json_serialize(obj.__dict__)
    else:
        try:
            return float(obj) if isinstance(obj, (int, float)) else str(obj)
        except:
            return str(obj)

# Create training datasets
print("\nCreating CASME II Apex Frame training datasets...")

train_dataset = CASME2DatasetTraining(
    dataset_root=DATASET_ROOT,
    split='train',
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    use_ram_cache=True
)

val_dataset = CASME2DatasetTraining(
    dataset_root=DATASET_ROOT,
    split='val',
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    use_ram_cache=True
)

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=CASME2_POOLFORMER_CONFIG['batch_size'],
    shuffle=True,
    num_workers=CASME2_POOLFORMER_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CASME2_POOLFORMER_CONFIG['batch_size'],
    shuffle=False,
    num_workers=CASME2_POOLFORMER_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

print(f"Training batches: {len(train_loader)} (samples: {len(train_dataset)})")
print(f"Validation batches: {len(val_loader)} (samples: {len(val_dataset)})")

# Initialize model, criterion, optimizer, scheduler
print("\nInitializing CASME II Apex Frame PoolFormer model...")
model = PoolFormerCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_POOLFORMER_CONFIG['dropout_rate']
).to(GLOBAL_CONFIG_CASME2['device'])

# Create criterion
if CASME2_POOLFORMER_CONFIG['use_focal_loss']:
    criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
        weights=GLOBAL_CONFIG_CASME2['class_weights'],
        use_focal_loss=True,
        alpha_weights=CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights'],
        gamma=CASME2_POOLFORMER_CONFIG['focal_loss_gamma']
    )
else:
    criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
        weights=GLOBAL_CONFIG_CASME2['class_weights'],
        use_focal_loss=False,
        alpha_weights=None,
        gamma=2.0
    )

optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_POOLFORMER_CONFIG
)

print(f"Optimizer: AdamW (LR={CASME2_POOLFORMER_CONFIG['learning_rate']})")
print(f"Scheduler: ReduceLROnPlateau (patience={CASME2_POOLFORMER_CONFIG['scheduler_patience']})")
print(f"Criterion: {'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy'}")

# Training history tracking
training_history = {
    'train_loss': [],
    'val_loss': [],
    'train_f1': [],
    'val_f1': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rate': [],
    'epoch_time': []
}

# Best metrics tracking
best_metrics = {
    'f1': 0.0,
    'loss': float('inf'),
    'accuracy': 0.0,
    'epoch': 0
}

print("\nStarting CASME II Apex Frame PoolFormer training...")
print(f"Training configuration: {CASME2_POOLFORMER_CONFIG['num_epochs']} epochs")
print(f"Small dataset size: {len(train_dataset)} train samples (overfitting risk)")
print("=" * 70)

# Main training loop
start_time = time.time()

for epoch in range(CASME2_POOLFORMER_CONFIG['num_epochs']):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{CASME2_POOLFORMER_CONFIG['num_epochs']}")

    # Training phase
    train_loss, train_metrics = train_epoch(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_POOLFORMER_CONFIG['num_epochs']
    )

    # Validation phase
    val_loss, val_metrics, val_sample_ids = validate_epoch(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_POOLFORMER_CONFIG['num_epochs']
    )

    # Update scheduler
    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    # Record training history
    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    # Print epoch summary
    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    # Multi-criteria checkpoint saving
    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        best_metrics['f1'] = val_metrics['f1_score']
        best_metrics['loss'] = val_loss
        best_metrics['accuracy'] = val_metrics['accuracy']
        best_metrics['epoch'] = epoch + 1

        best_model_path = save_checkpoint_robust(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            best_metrics, CASME2_POOLFORMER_CONFIG
        )

        if best_model_path:
            print(f"New best model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print(f"Warning: Failed to save checkpoint despite improvement")

    # Progress tracking
    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_POOLFORMER_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_POOLFORMER_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min")

# Training completion
total_time = time.time() - start_time
actual_epochs = CASME2_POOLFORMER_CONFIG['num_epochs']

print("\n" + "=" * 70)
print("CASME II APEX FRAME POOLFORMER TRAINING COMPLETED")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print(f"Final train F1: {training_history['train_f1'][-1]:.4f}")
print(f"Final validation F1: {training_history['val_f1'][-1]:.4f}")

# Export training documentation
results_dir = GLOBAL_CONFIG_CASME2['results_root']
os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

training_history_path = f"{results_dir}/training_logs/casme2_poolformer_apex_frame_training_history.json"

print("\nExporting training documentation...")

try:
    training_summary = {
        'experiment_type': 'CASME2_PoolFormer_Apex_Frame_Face_Aware',
        'experiment_configuration': {
            'dataset_version': CASME2_POOLFORMER_CONFIG['dataset_version'],
            'frame_strategy': CASME2_POOLFORMER_CONFIG['frame_strategy'],
            'preprocessing_method': CASME2_POOLFORMER_CONFIG['preprocessing_method'],
            'image_format': CASME2_POOLFORMER_CONFIG['image_format'],
            'bbox_expansion': CASME2_POOLFORMER_CONFIG['bbox_expansion'],
            'loss_function': 'Optimized Focal Loss' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'CrossEntropy',
            'weight_approach': 'Per-class Alpha (sum=1.0)' if CASME2_POOLFORMER_CONFIG['use_focal_loss'] else 'Inverse Sqrt Frequency',
            'focal_loss_gamma': CASME2_POOLFORMER_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_POOLFORMER_CONFIG['focal_loss_alpha_weights'],
            'crossentropy_class_weights': CASME2_POOLFORMER_CONFIG['crossentropy_class_weights'],
            'poolformer_model': CASME2_POOLFORMER_CONFIG['poolformer_model'],
            'model_variant': CASME2_POOLFORMER_CONFIG['model_variant']
        },
        'training_history': safe_json_serialize(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'config': safe_json_serialize(CASME2_POOLFORMER_CONFIG),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        'model_checkpoint': 'casme2_poolformer_apex_frame_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            'version': CASME2_POOLFORMER_CONFIG['dataset_version'],
            'frame_strategy': CASME2_POOLFORMER_CONFIG['frame_strategy'],
            'preprocessing': CASME2_POOLFORMER_CONFIG['preprocessing_method'],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES,
            'face_detection_rate': CASME2_POOLFORMER_CONFIG['face_detection_rate']
        },
        'architecture_info': {
            'model_type': 'PoolFormerCASME2Baseline',
            'backbone': CASME2_POOLFORMER_CONFIG['poolformer_model'],
            'model_variant': CASME2_POOLFORMER_CONFIG['model_variant'],
            'model_params': CASME2_POOLFORMER_CONFIG['model_params'],
            'input_size': f"{CASME2_POOLFORMER_CONFIG['input_size']}x{CASME2_POOLFORMER_CONFIG['input_size']}",
            'classification_head': '512->256->7'
        }
    }

    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Training documentation saved successfully: {training_history_path}")
    print(f"Loss function: {training_summary['experiment_configuration']['loss_function']}")
    print(f"Model variant: {CASME2_POOLFORMER_CONFIG['poolformer_model']}")
    print(f"Dataset version: {CASME2_POOLFORMER_CONFIG['dataset_version']}")

except Exception as e:
    print(f"Warning: Could not save training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

# Memory cleanup
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\nNext: Cell 3 - CASME II Apex Frame PoolFormer Evaluation")
print("Training pipeline with face-aware preprocessing completed successfully!")

CASME II Apex Frame PoolFormer Training Pipeline
Loss Function: Optimized Focal Loss
Focal Loss Parameters:
  Gamma: 2.0
  Per-class Alpha: [0.052, 0.066, 0.093, 0.101, 0.104, 0.208, 0.376]
  Alpha Sum: 1.000
Dataset Version: v7
Frame Strategy: apex_frame
Preprocessing: face_aware_bbox_expansion
Training epochs: 50
Scheduler patience: 5

Creating CASME II Apex Frame training datasets...
Loading CASME II train dataset for training...
Loaded 201 CASME II train samples
  others: 79 samples (39.3%)
  disgust: 50 samples (24.9%)
  happiness: 25 samples (12.4%)
  repression: 21 samples (10.4%)
  surprise: 20 samples (10.0%)
  sadness: 5 samples (2.5%)
  fear: 1 samples (0.5%)
Preloading 201 train images to RAM with 32 workers...


Loading train to RAM: 100%|██████████| 201/201 [00:10<00:00, 18.75it/s]


TRAIN RAM caching completed: 201/201 images, ~10.1MB
Loading CASME II val dataset for training...
Loaded 26 CASME II val samples
  others: 10 samples (38.5%)
  disgust: 6 samples (23.1%)
  happiness: 3 samples (11.5%)
  repression: 3 samples (11.5%)
  surprise: 2 samples (7.7%)
  sadness: 1 samples (3.8%)
  fear: 1 samples (3.8%)
Preloading 26 val images to RAM with 32 workers...


Loading val to RAM: 100%|██████████| 26/26 [00:01<00:00, 14.06it/s]


VAL RAM caching completed: 26/26 images, ~1.3MB
Training batches: 51 (samples: 201)
Validation batches: 7 (samples: 26)

Initializing CASME II Apex Frame PoolFormer model...
PoolFormer feature dimension: 768
PoolFormer CASME II Simplified: 768 -> 256 -> 7
Dropout rate: 0.3 (increased for small dataset regularization)
Using Optimized Focal Loss with gamma=2.0
Per-class alpha weights: [0.052, 0.066, 0.093, 0.101, 0.104, 0.208, 0.376]
Alpha sum: 1.000
Scheduler: ReduceLROnPlateau monitoring val_f1_macro
Optimizer: AdamW (LR=2e-05)
Scheduler: ReduceLROnPlateau (patience=5)
Criterion: Optimized Focal Loss

Starting CASME II Apex Frame PoolFormer training...
Training configuration: 50 epochs
Small dataset size: 201 train samples (overfitting risk)

Epoch 1/50


Training Epoch 1/50: 100%|██████████| 51/51 [00:07<00:00,  6.85it/s, Loss=0.1026, LR=2.00e-05]
Validation Epoch 1/50: 100%|██████████| 7/7 [00:00<00:00, 14.97it/s, Val Loss=0.1303]


Train - Loss: 0.1026, F1: 0.1578, Acc: 0.3333
Val   - Loss: 0.1303, F1: 0.0905, Acc: 0.3077
Time  - Epoch: 7.9s, LR: 2.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_apex_frame_best_f1.pth
  Epoch: 1
  Val F1: 0.0905
  Val Loss: 0.1303
  Val Acc: 0.3077
New best model: Higher F1 - F1: 0.0905
Progress: 2.0% | Best F1: 0.0905 | ETA: 9.2min

Epoch 2/50


Training Epoch 2/50: 100%|██████████| 51/51 [00:05<00:00,  9.44it/s, Loss=0.0827, LR=2.00e-05]
Validation Epoch 2/50: 100%|██████████| 7/7 [00:00<00:00, 18.05it/s, Val Loss=0.1314]


Train - Loss: 0.0827, F1: 0.1728, Acc: 0.4179
Val   - Loss: 0.1314, F1: 0.1694, Acc: 0.3462
Time  - Epoch: 5.8s, LR: 2.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_apex_frame_best_f1.pth
  Epoch: 2
  Val F1: 0.1694
  Val Loss: 0.1314
  Val Acc: 0.3462
New best model: Higher F1 - F1: 0.1694
Progress: 4.0% | Best F1: 0.1694 | ETA: 7.9min

Epoch 3/50


Training Epoch 3/50: 100%|██████████| 51/51 [00:05<00:00,  9.24it/s, Loss=0.0675, LR=2.00e-05]
Validation Epoch 3/50: 100%|██████████| 7/7 [00:00<00:00, 17.72it/s, Val Loss=0.1340]


Train - Loss: 0.0675, F1: 0.3177, Acc: 0.5274
Val   - Loss: 0.1340, F1: 0.1444, Acc: 0.3077
Time  - Epoch: 5.9s, LR: 2.00e-05
Progress: 6.0% | Best F1: 0.1694 | ETA: 6.7min

Epoch 4/50


Training Epoch 4/50: 100%|██████████| 51/51 [00:05<00:00,  9.31it/s, Loss=0.0532, LR=2.00e-05]
Validation Epoch 4/50: 100%|██████████| 7/7 [00:00<00:00, 17.89it/s, Val Loss=0.1372]


Train - Loss: 0.0532, F1: 0.4743, Acc: 0.6617
Val   - Loss: 0.1372, F1: 0.1591, Acc: 0.3462
Time  - Epoch: 5.9s, LR: 2.00e-05
Progress: 8.0% | Best F1: 0.1694 | ETA: 6.1min

Epoch 5/50


Training Epoch 5/50: 100%|██████████| 51/51 [00:05<00:00,  9.35it/s, Loss=0.0367, LR=2.00e-05]
Validation Epoch 5/50: 100%|██████████| 7/7 [00:00<00:00, 16.32it/s, Val Loss=0.1327]


Train - Loss: 0.0367, F1: 0.6010, Acc: 0.7761
Val   - Loss: 0.1327, F1: 0.1400, Acc: 0.3077
Time  - Epoch: 5.9s, LR: 2.00e-05
Progress: 10.0% | Best F1: 0.1694 | ETA: 5.6min

Epoch 6/50


Training Epoch 6/50: 100%|██████████| 51/51 [00:05<00:00,  8.69it/s, Loss=0.0261, LR=2.00e-05]
Validation Epoch 6/50: 100%|██████████| 7/7 [00:00<00:00, 17.57it/s, Val Loss=0.1333]


Train - Loss: 0.0261, F1: 0.7038, Acc: 0.8408
Val   - Loss: 0.1333, F1: 0.2330, Acc: 0.3846
Time  - Epoch: 6.3s, LR: 2.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_poolformer_apex_frame_best_f1.pth
  Epoch: 6
  Val F1: 0.2330
  Val Loss: 0.1333
  Val Acc: 0.3846
New best model: Higher F1 - F1: 0.2330
Progress: 12.0% | Best F1: 0.2330 | ETA: 7.9min

Epoch 7/50


Training Epoch 7/50: 100%|██████████| 51/51 [00:05<00:00,  9.18it/s, Loss=0.0163, LR=2.00e-05]
Validation Epoch 7/50: 100%|██████████| 7/7 [00:00<00:00, 16.87it/s, Val Loss=0.1403]


Train - Loss: 0.0163, F1: 0.7813, Acc: 0.9254
Val   - Loss: 0.1403, F1: 0.1983, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 2.00e-05
Progress: 14.0% | Best F1: 0.2330 | ETA: 7.2min

Epoch 8/50


Training Epoch 8/50: 100%|██████████| 51/51 [00:05<00:00,  8.72it/s, Loss=0.0124, LR=2.00e-05]
Validation Epoch 8/50: 100%|██████████| 7/7 [00:00<00:00, 17.11it/s, Val Loss=0.1505]


Train - Loss: 0.0124, F1: 0.8259, Acc: 0.9652
Val   - Loss: 0.1505, F1: 0.1444, Acc: 0.3077
Time  - Epoch: 6.3s, LR: 2.00e-05
Progress: 16.0% | Best F1: 0.2330 | ETA: 6.7min

Epoch 9/50


Training Epoch 9/50: 100%|██████████| 51/51 [00:05<00:00,  9.36it/s, Loss=0.0073, LR=2.00e-05]
Validation Epoch 9/50: 100%|██████████| 7/7 [00:00<00:00, 16.84it/s, Val Loss=0.1473]


Train - Loss: 0.0073, F1: 0.8477, Acc: 0.9900
Val   - Loss: 0.1473, F1: 0.1922, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 2.00e-05
Progress: 18.0% | Best F1: 0.2330 | ETA: 6.3min

Epoch 10/50


Training Epoch 10/50: 100%|██████████| 51/51 [00:05<00:00,  9.12it/s, Loss=0.0055, LR=2.00e-05]
Validation Epoch 10/50: 100%|██████████| 7/7 [00:00<00:00, 17.00it/s, Val Loss=0.1626]


Train - Loss: 0.0055, F1: 0.8495, Acc: 0.9900
Val   - Loss: 0.1626, F1: 0.2031, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 2.00e-05
Progress: 20.0% | Best F1: 0.2330 | ETA: 5.9min

Epoch 11/50


Training Epoch 11/50: 100%|██████████| 51/51 [00:05<00:00,  9.35it/s, Loss=0.0041, LR=2.00e-05]
Validation Epoch 11/50: 100%|██████████| 7/7 [00:00<00:00, 17.02it/s, Val Loss=0.1593]


Train - Loss: 0.0041, F1: 0.8538, Acc: 0.9950
Val   - Loss: 0.1593, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 2.00e-05
Progress: 22.0% | Best F1: 0.2330 | ETA: 5.6min

Epoch 12/50


Training Epoch 12/50: 100%|██████████| 51/51 [00:05<00:00,  9.24it/s, Loss=0.0030, LR=2.00e-05]
Validation Epoch 12/50: 100%|██████████| 7/7 [00:00<00:00, 17.11it/s, Val Loss=0.1666]


Train - Loss: 0.0030, F1: 0.8538, Acc: 0.9950
Val   - Loss: 0.1666, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 5.9s, LR: 1.00e-05
Progress: 24.0% | Best F1: 0.2330 | ETA: 5.3min

Epoch 13/50


Training Epoch 13/50: 100%|██████████| 51/51 [00:05<00:00,  9.21it/s, Loss=0.0018, LR=1.00e-05]
Validation Epoch 13/50: 100%|██████████| 7/7 [00:00<00:00, 17.06it/s, Val Loss=0.1666]


Train - Loss: 0.0018, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1666, F1: 0.2102, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 1.00e-05
Progress: 26.0% | Best F1: 0.2330 | ETA: 5.0min

Epoch 14/50


Training Epoch 14/50: 100%|██████████| 51/51 [00:05<00:00,  9.15it/s, Loss=0.0018, LR=1.00e-05]
Validation Epoch 14/50: 100%|██████████| 7/7 [00:00<00:00, 16.16it/s, Val Loss=0.1704]


Train - Loss: 0.0018, F1: 0.8538, Acc: 0.9950
Val   - Loss: 0.1704, F1: 0.1857, Acc: 0.3846
Time  - Epoch: 6.0s, LR: 1.00e-05
Progress: 28.0% | Best F1: 0.2330 | ETA: 4.8min

Epoch 15/50


Training Epoch 15/50: 100%|██████████| 51/51 [00:05<00:00,  9.26it/s, Loss=0.0010, LR=1.00e-05]
Validation Epoch 15/50: 100%|██████████| 7/7 [00:00<00:00, 16.24it/s, Val Loss=0.1710]


Train - Loss: 0.0010, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1710, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 1.00e-05
Progress: 30.0% | Best F1: 0.2330 | ETA: 4.6min

Epoch 16/50


Training Epoch 16/50: 100%|██████████| 51/51 [00:05<00:00,  8.95it/s, Loss=0.0008, LR=1.00e-05]
Validation Epoch 16/50: 100%|██████████| 7/7 [00:00<00:00, 15.91it/s, Val Loss=0.1717]


Train - Loss: 0.0008, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1717, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.2s, LR: 1.00e-05
Progress: 32.0% | Best F1: 0.2330 | ETA: 4.4min

Epoch 17/50


Training Epoch 17/50: 100%|██████████| 51/51 [00:05<00:00,  9.24it/s, Loss=0.0005, LR=1.00e-05]
Validation Epoch 17/50: 100%|██████████| 7/7 [00:00<00:00, 16.92it/s, Val Loss=0.1752]


Train - Loss: 0.0005, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1752, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 1.00e-05
Progress: 34.0% | Best F1: 0.2330 | ETA: 4.2min

Epoch 18/50


Training Epoch 18/50: 100%|██████████| 51/51 [00:05<00:00,  9.06it/s, Loss=0.0006, LR=1.00e-05]
Validation Epoch 18/50: 100%|██████████| 7/7 [00:00<00:00, 16.48it/s, Val Loss=0.1750]


Train - Loss: 0.0006, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1750, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.1s, LR: 5.00e-06
Progress: 36.0% | Best F1: 0.2330 | ETA: 4.0min

Epoch 19/50


Training Epoch 19/50: 100%|██████████| 51/51 [00:05<00:00,  9.07it/s, Loss=0.0004, LR=5.00e-06]
Validation Epoch 19/50: 100%|██████████| 7/7 [00:00<00:00, 17.17it/s, Val Loss=0.1762]


Train - Loss: 0.0004, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1762, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 5.00e-06
Progress: 38.0% | Best F1: 0.2330 | ETA: 3.9min

Epoch 20/50


Training Epoch 20/50: 100%|██████████| 51/51 [00:05<00:00,  8.97it/s, Loss=0.0004, LR=5.00e-06]
Validation Epoch 20/50: 100%|██████████| 7/7 [00:00<00:00, 15.54it/s, Val Loss=0.1766]


Train - Loss: 0.0004, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1766, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.2s, LR: 5.00e-06
Progress: 40.0% | Best F1: 0.2330 | ETA: 3.7min

Epoch 21/50


Training Epoch 21/50: 100%|██████████| 51/51 [00:05<00:00,  9.15it/s, Loss=0.0003, LR=5.00e-06]
Validation Epoch 21/50: 100%|██████████| 7/7 [00:00<00:00, 16.20it/s, Val Loss=0.1776]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1776, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 5.00e-06
Progress: 42.0% | Best F1: 0.2330 | ETA: 3.6min

Epoch 22/50


Training Epoch 22/50: 100%|██████████| 51/51 [00:05<00:00,  9.27it/s, Loss=0.0004, LR=5.00e-06]
Validation Epoch 22/50: 100%|██████████| 7/7 [00:00<00:00, 16.28it/s, Val Loss=0.1777]


Train - Loss: 0.0004, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1777, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 5.9s, LR: 5.00e-06
Progress: 44.0% | Best F1: 0.2330 | ETA: 3.4min

Epoch 23/50


Training Epoch 23/50: 100%|██████████| 51/51 [00:05<00:00,  9.07it/s, Loss=0.0003, LR=5.00e-06]
Validation Epoch 23/50: 100%|██████████| 7/7 [00:00<00:00, 15.83it/s, Val Loss=0.1799]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1799, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.1s, LR: 5.00e-06
Progress: 46.0% | Best F1: 0.2330 | ETA: 3.3min

Epoch 24/50


Training Epoch 24/50: 100%|██████████| 51/51 [00:05<00:00,  9.14it/s, Loss=0.0003, LR=5.00e-06]
Validation Epoch 24/50: 100%|██████████| 7/7 [00:00<00:00, 15.92it/s, Val Loss=0.1798]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1798, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 2.50e-06
Progress: 48.0% | Best F1: 0.2330 | ETA: 3.1min

Epoch 25/50


Training Epoch 25/50: 100%|██████████| 51/51 [00:05<00:00,  8.97it/s, Loss=0.0003, LR=2.50e-06]
Validation Epoch 25/50: 100%|██████████| 7/7 [00:00<00:00, 16.43it/s, Val Loss=0.1804]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1804, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.1s, LR: 2.50e-06
Progress: 50.0% | Best F1: 0.2330 | ETA: 3.0min

Epoch 26/50


Training Epoch 26/50: 100%|██████████| 51/51 [00:05<00:00,  9.12it/s, Loss=0.0003, LR=2.50e-06]
Validation Epoch 26/50: 100%|██████████| 7/7 [00:00<00:00, 16.25it/s, Val Loss=0.1807]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1807, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.0s, LR: 2.50e-06
Progress: 52.0% | Best F1: 0.2330 | ETA: 2.8min

Epoch 27/50


Training Epoch 27/50: 100%|██████████| 51/51 [00:05<00:00,  9.13it/s, Loss=0.0002, LR=2.50e-06]
Validation Epoch 27/50: 100%|██████████| 7/7 [00:00<00:00, 16.60it/s, Val Loss=0.1811]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1811, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.0s, LR: 2.50e-06
Progress: 54.0% | Best F1: 0.2330 | ETA: 2.7min

Epoch 28/50


Training Epoch 28/50: 100%|██████████| 51/51 [00:05<00:00,  9.17it/s, Loss=0.0003, LR=2.50e-06]
Validation Epoch 28/50: 100%|██████████| 7/7 [00:00<00:00, 16.19it/s, Val Loss=0.1810]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1810, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 2.50e-06
Progress: 56.0% | Best F1: 0.2330 | ETA: 2.6min

Epoch 29/50


Training Epoch 29/50: 100%|██████████| 51/51 [00:05<00:00,  9.07it/s, Loss=0.0003, LR=2.50e-06]
Validation Epoch 29/50: 100%|██████████| 7/7 [00:00<00:00, 16.77it/s, Val Loss=0.1814]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1814, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.1s, LR: 2.50e-06
Progress: 58.0% | Best F1: 0.2330 | ETA: 2.5min

Epoch 30/50


Training Epoch 30/50: 100%|██████████| 51/51 [00:05<00:00,  9.03it/s, Loss=0.0003, LR=2.50e-06]
Validation Epoch 30/50: 100%|██████████| 7/7 [00:00<00:00, 16.92it/s, Val Loss=0.1820]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1820, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.1s, LR: 1.25e-06
Progress: 60.0% | Best F1: 0.2330 | ETA: 2.3min

Epoch 31/50


Training Epoch 31/50: 100%|██████████| 51/51 [00:05<00:00,  9.19it/s, Loss=0.0003, LR=1.25e-06]
Validation Epoch 31/50: 100%|██████████| 7/7 [00:00<00:00, 16.33it/s, Val Loss=0.1823]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1823, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.0s, LR: 1.25e-06
Progress: 62.0% | Best F1: 0.2330 | ETA: 2.2min

Epoch 32/50


Training Epoch 32/50: 100%|██████████| 51/51 [00:05<00:00,  9.30it/s, Loss=0.0003, LR=1.25e-06]
Validation Epoch 32/50: 100%|██████████| 7/7 [00:00<00:00, 17.11it/s, Val Loss=0.1824]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1824, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.25e-06
Progress: 64.0% | Best F1: 0.2330 | ETA: 2.1min

Epoch 33/50


Training Epoch 33/50: 100%|██████████| 51/51 [00:05<00:00,  9.33it/s, Loss=0.0002, LR=1.25e-06]
Validation Epoch 33/50: 100%|██████████| 7/7 [00:00<00:00, 16.62it/s, Val Loss=0.1827]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1827, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.25e-06
Progress: 66.0% | Best F1: 0.2330 | ETA: 1.9min

Epoch 34/50


Training Epoch 34/50: 100%|██████████| 51/51 [00:05<00:00,  8.95it/s, Loss=0.0002, LR=1.25e-06]
Validation Epoch 34/50: 100%|██████████| 7/7 [00:00<00:00, 16.87it/s, Val Loss=0.1829]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1829, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.1s, LR: 1.25e-06
Progress: 68.0% | Best F1: 0.2330 | ETA: 1.8min

Epoch 35/50


Training Epoch 35/50: 100%|██████████| 51/51 [00:05<00:00,  9.30it/s, Loss=0.0002, LR=1.25e-06]
Validation Epoch 35/50: 100%|██████████| 7/7 [00:00<00:00, 16.62it/s, Val Loss=0.1827]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1827, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.25e-06
Progress: 70.0% | Best F1: 0.2330 | ETA: 1.7min

Epoch 36/50


Training Epoch 36/50: 100%|██████████| 51/51 [00:05<00:00,  9.27it/s, Loss=0.0003, LR=1.25e-06]
Validation Epoch 36/50: 100%|██████████| 7/7 [00:00<00:00, 16.37it/s, Val Loss=0.1828]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1828, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 72.0% | Best F1: 0.2330 | ETA: 1.6min

Epoch 37/50


Training Epoch 37/50: 100%|██████████| 51/51 [00:05<00:00,  9.31it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 37/50: 100%|██████████| 7/7 [00:00<00:00, 17.35it/s, Val Loss=0.1828]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1828, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 74.0% | Best F1: 0.2330 | ETA: 1.5min

Epoch 38/50


Training Epoch 38/50: 100%|██████████| 51/51 [00:05<00:00,  9.37it/s, Loss=0.0003, LR=1.00e-06]
Validation Epoch 38/50: 100%|██████████| 7/7 [00:00<00:00, 16.93it/s, Val Loss=0.1829]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1829, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 76.0% | Best F1: 0.2330 | ETA: 1.4min

Epoch 39/50


Training Epoch 39/50: 100%|██████████| 51/51 [00:05<00:00,  9.31it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 39/50: 100%|██████████| 7/7 [00:00<00:00, 17.08it/s, Val Loss=0.1830]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1830, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 78.0% | Best F1: 0.2330 | ETA: 1.2min

Epoch 40/50


Training Epoch 40/50: 100%|██████████| 51/51 [00:05<00:00,  9.37it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 40/50: 100%|██████████| 7/7 [00:00<00:00, 16.87it/s, Val Loss=0.1832]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1832, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 80.0% | Best F1: 0.2330 | ETA: 1.1min

Epoch 41/50


Training Epoch 41/50: 100%|██████████| 51/51 [00:05<00:00,  9.26it/s, Loss=0.0003, LR=1.00e-06]
Validation Epoch 41/50: 100%|██████████| 7/7 [00:00<00:00, 17.27it/s, Val Loss=0.1832]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1832, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 82.0% | Best F1: 0.2330 | ETA: 1.0min

Epoch 42/50


Training Epoch 42/50: 100%|██████████| 51/51 [00:05<00:00,  8.68it/s, Loss=0.0003, LR=1.00e-06]
Validation Epoch 42/50: 100%|██████████| 7/7 [00:00<00:00, 17.30it/s, Val Loss=0.1833]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1833, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.3s, LR: 1.00e-06
Progress: 84.0% | Best F1: 0.2330 | ETA: 0.9min

Epoch 43/50


Training Epoch 43/50: 100%|██████████| 51/51 [00:05<00:00,  9.14it/s, Loss=0.0003, LR=1.00e-06]
Validation Epoch 43/50: 100%|██████████| 7/7 [00:00<00:00, 17.22it/s, Val Loss=0.1839]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1839, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.0s, LR: 1.00e-06
Progress: 86.0% | Best F1: 0.2330 | ETA: 0.8min

Epoch 44/50


Training Epoch 44/50: 100%|██████████| 51/51 [00:05<00:00,  9.30it/s, Loss=0.0003, LR=1.00e-06]
Validation Epoch 44/50: 100%|██████████| 7/7 [00:00<00:00, 16.53it/s, Val Loss=0.1838]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1838, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 88.0% | Best F1: 0.2330 | ETA: 0.7min

Epoch 45/50


Training Epoch 45/50: 100%|██████████| 51/51 [00:05<00:00,  9.21it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 45/50: 100%|██████████| 7/7 [00:00<00:00, 16.86it/s, Val Loss=0.1838]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1838, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 6.0s, LR: 1.00e-06
Progress: 90.0% | Best F1: 0.2330 | ETA: 0.6min

Epoch 46/50


Training Epoch 46/50: 100%|██████████| 51/51 [00:05<00:00,  9.33it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 46/50: 100%|██████████| 7/7 [00:00<00:00, 17.09it/s, Val Loss=0.1840]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1840, F1: 0.1888, Acc: 0.3846
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 92.0% | Best F1: 0.2330 | ETA: 0.4min

Epoch 47/50


Training Epoch 47/50: 100%|██████████| 51/51 [00:05<00:00,  9.33it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 47/50: 100%|██████████| 7/7 [00:00<00:00, 17.00it/s, Val Loss=0.1843]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1843, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 94.0% | Best F1: 0.2330 | ETA: 0.3min

Epoch 48/50


Training Epoch 48/50: 100%|██████████| 51/51 [00:05<00:00,  9.21it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 48/50: 100%|██████████| 7/7 [00:00<00:00, 17.04it/s, Val Loss=0.1843]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1843, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 6.0s, LR: 1.00e-06
Progress: 96.0% | Best F1: 0.2330 | ETA: 0.2min

Epoch 49/50


Training Epoch 49/50: 100%|██████████| 51/51 [00:05<00:00,  9.38it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 49/50: 100%|██████████| 7/7 [00:00<00:00, 16.86it/s, Val Loss=0.1843]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1843, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 98.0% | Best F1: 0.2330 | ETA: 0.1min

Epoch 50/50


Training Epoch 50/50: 100%|██████████| 51/51 [00:05<00:00,  9.37it/s, Loss=0.0002, LR=1.00e-06]
Validation Epoch 50/50: 100%|██████████| 7/7 [00:00<00:00, 16.98it/s, Val Loss=0.1844]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1844, F1: 0.2065, Acc: 0.4231
Time  - Epoch: 5.9s, LR: 1.00e-06
Progress: 100.0% | Best F1: 0.2330 | ETA: 0.0min

CASME II APEX FRAME POOLFORMER TRAINING COMPLETED
Training time: 5.5 minutes
Epochs completed: 50
Best validation F1: 0.2330 (epoch 6)
Final train F1: 1.0000
Final validation F1: 0.2065

Exporting training documentation...
Training documentation saved successfully: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/05_03_poolformer_casme2_af_prep/training_logs/casme2_poolformer_apex_frame_training_history.json
Loss function: Optimized Focal Loss
Model variant: sail/poolformer_m48
Dataset version: v7

Next: Cell 3 - CASME II Apex Frame PoolFormer Evaluation
Training pipeline with face-aware preprocessing completed successfully!


In [3]:
# @title Cell 3: CASME II Apex Frame PoolFormer Evaluation

# File: 05_03_PoolFormer_CASME2_AF_Cell3.py
# Location: experiments/05_03_PoolFormer_CASME2-AF-PREP.ipynb
# Purpose: Comprehensive evaluation framework for apex frame with face-aware preprocessing

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime

# Evaluation specific imports
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import pickle
import warnings
warnings.filterwarnings('ignore')

print("CASME II Apex Frame PoolFormer Evaluation Framework")
print("=" * 60)
print(f"Dataset Version: v7 - Face-Aware Preprocessing")
print(f"Frame Strategy: Apex Frame")
print("=" * 60)

# CASME II evaluation configuration
EVALUATION_CONFIG_CASME2 = {
    'model_type': 'PoolFormer_CASME2_Apex_Frame_Face_Aware',
    'task_type': 'micro_expression_recognition',
    'num_classes': 7,
    'class_names': CASME2_CLASSES,
    'checkpoint_file': 'casme2_poolformer_apex_frame_best_f1.pth',
    'dataset_name': 'CASME_II',
    'dataset_version': 'v7',
    'preprocessing_method': 'face_aware_bbox_expansion',
    'input_size': '224x224',
    'evaluation_protocol': 'stratified_split'
}

print(f"\nCASME II PoolFormer Evaluation Configuration:")
print(f"  Model: {EVALUATION_CONFIG_CASME2['model_type']}")
print(f"  Task: {EVALUATION_CONFIG_CASME2['task_type']}")
print(f"  Dataset Version: {EVALUATION_CONFIG_CASME2['dataset_version']}")
print(f"  Preprocessing: {EVALUATION_CONFIG_CASME2['preprocessing_method']}")
print(f"  Classes: {EVALUATION_CONFIG_CASME2['class_names']}")
print(f"  Input size: {EVALUATION_CONFIG_CASME2['input_size']}")

# Enhanced test dataset with RAM caching
class CASME2DatasetEvaluation(Dataset):
    """Enhanced CASME II test dataset with comprehensive evaluation support"""

    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.emotions = []
        self.subjects = []
        self.filenames = []
        self.cached_images = []

        split_dir = os.path.join(dataset_root, split)

        if not os.path.exists(split_dir):
            raise ValueError(f"Split directory not found: {split_dir}")

        print(f"Loading CASME II {split} dataset for evaluation...")

        # Load all images from flat directory structure
        # Filename pattern: sub01_EP02_01f_happiness.jpg
        for img_file in os.listdir(split_dir):
            if img_file.endswith(('.jpg', '.jpeg', '.png')):
                # Extract emotion from filename
                emotion = img_file.rsplit('_', 1)[-1].split('.')[0]

                # Extract subject from filename (sub01, sub02, etc)
                subject = img_file.split('_')[0]

                if emotion in CASME2_CLASSES:
                    self.images.append(os.path.join(split_dir, img_file))
                    self.labels.append(CLASS_TO_IDX[emotion])
                    self.emotions.append(emotion)
                    self.subjects.append(subject)
                    self.filenames.append(img_file)

        print(f"Loaded {len(self.images)} CASME II {split} samples for evaluation")
        self._print_evaluation_distribution()

        # RAM caching for fast evaluation
        if self.use_ram_cache:
            self._preload_to_ram_evaluation()

    def _print_evaluation_distribution(self):
        """Print comprehensive distribution for evaluation analysis"""
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        subject_counts = {}

        for label, subject in zip(self.labels, self.subjects):
            label_counts[label] = label_counts.get(label, 0) + 1
            subject_counts[subject] = subject_counts.get(subject, 0) + 1

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Test set covers {len(subject_counts)} subjects")

        # Check for missing classes
        missing_classes = []
        for i, class_name in enumerate(CASME2_CLASSES):
            if i not in label_counts:
                missing_classes.append(class_name)

        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        """RAM preloading with parallel loading optimized for evaluation"""
        print(f"Preloading {len(self.images)} test images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            """Load single image with error handling"""
            try:
                image = Image.open(img_path)
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('L', (224, 224), 128), False

        # Parallel loading with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc="Loading test images to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_mb = len(self.cached_images) * 224 * 224 * 1 / 1e6
        print(f"Test RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_mb:.1f}MB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx])
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('L', (224, 224), 128)

        if self.transform:
            image = self.transform(image)

        return (image, self.labels[idx], self.filenames[idx],
                self.emotions[idx], self.subjects[idx])

def extract_logits_safe_casme2(outputs_all):
    """Robust logits extraction for CASME II PoolFormer model"""
    if isinstance(outputs_all, torch.Tensor):
        return outputs_all
    if isinstance(outputs_all, (tuple, list)):
        for item in outputs_all:
            if isinstance(item, torch.Tensor):
                return item
    if isinstance(outputs_all, dict):
        for key in ('logits', 'logit', 'predictions', 'outputs', 'scores'):
            value = outputs_all.get(key)
            if isinstance(value, torch.Tensor):
                return value
        for value in outputs_all.values():
            if isinstance(value, torch.Tensor):
                return value
    raise RuntimeError("Unable to extract tensor logits from CASME II PoolFormer model output")

def load_trained_model_casme2(checkpoint_path, device):
    """Load trained CASME II PoolFormer model with comprehensive compatibility"""
    print(f"Loading trained CASME II Apex Frame PoolFormer model from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Multiple loading approaches for maximum compatibility
    checkpoint = None
    loading_method = "unknown"

    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        loading_method = "standard"
    except Exception as e1:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            loading_method = "weights_only_false"
        except Exception as e2:
            try:
                import pickle
                with open(checkpoint_path, 'rb') as f:
                    checkpoint = pickle.load(f)
                loading_method = "pickle"
            except Exception as e3:
                raise RuntimeError(f"All loading methods failed: {e1}, {e2}, {e3}")

    print(f"Checkpoint loaded using: {loading_method}")

    # Initialize CASME II PoolFormer model
    model = PoolFormerCASME2Baseline(
        num_classes=EVALUATION_CONFIG_CASME2['num_classes'],
        dropout_rate=CASME2_POOLFORMER_CONFIG['dropout_rate']
    ).to(device)

    # Load state dict with fallback approaches
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    try:
        model.load_state_dict(state_dict, strict=True)
        print("Model state loaded with strict=True")
    except Exception as e:
        print(f"Strict loading failed, trying non-strict: {str(e)[:100]}...")
        try:
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            if missing_keys or unexpected_keys:
                print(f"Non-strict loading: Missing {len(missing_keys)}, Unexpected {len(unexpected_keys)}")
            else:
                print("Model state loaded with strict=False (no key mismatches)")
        except Exception as e2:
            raise RuntimeError(f"Both loading approaches failed: {e2}")

    model.eval()

    # Extract training information
    training_info = {
        'best_val_f1': float(checkpoint.get('best_f1', 0.0)),
        'best_val_loss': float(checkpoint.get('best_loss', float('inf'))),
        'best_val_accuracy': float(checkpoint.get('best_acc', 0.0)),
        'best_epoch': int(checkpoint.get('epoch', 0)) + 1,
        'model_checkpoint': EVALUATION_CONFIG_CASME2['checkpoint_file'],
        'num_classes': EVALUATION_CONFIG_CASME2['num_classes'],
        'config': checkpoint.get('casme2_config', {})
    }

    print(f"Model loaded successfully:")
    print(f"  Best validation F1: {training_info['best_val_f1']:.4f}")
    print(f"  Best validation accuracy: {training_info['best_val_accuracy']:.4f}")
    print(f"  Best epoch: {training_info['best_epoch']}")
    print(f"  Model classes: {EVALUATION_CONFIG_CASME2['num_classes']}")

    return model, training_info

def run_model_inference_casme2(model, test_loader, device):
    """Run CASME II PoolFormer model inference with comprehensive tracking"""
    print("Running CASME II Apex Frame PoolFormer model inference on test set...")

    model.eval()
    all_predictions = []
    all_probabilities = []
    all_labels = []
    all_filenames = []
    all_emotions = []
    all_subjects = []

    inference_start = time.time()

    with torch.no_grad():
        for batch_idx, (images, labels, filenames, emotions, subjects) in enumerate(
            tqdm(test_loader, desc="CASME II Inference")):

            images = images.to(device)

            # Forward pass with robust output extraction
            try:
                outputs_raw = model(images)
                outputs = extract_logits_safe_casme2(outputs_raw)
            except Exception as e:
                print(f"Error in model forward pass: {e}")
                outputs = model(images)
                if not isinstance(outputs, torch.Tensor):
                    outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

            # Validate output shape for 7 CASME II classes
            if outputs.shape[1] != 7:
                print(f"Warning: Expected 7 classes output, got {outputs.shape[1]}")

            # Get probabilities and predictions
            probabilities = torch.softmax(outputs, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            # Store results (CPU for memory efficiency)
            all_predictions.extend(predictions.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_filenames.extend(filenames)
            all_emotions.extend(emotions)
            all_subjects.extend(subjects)

    inference_time = time.time() - inference_start

    print(f"CASME II inference completed: {len(all_predictions)} samples in {inference_time:.2f}s")

    # Analyze prediction distribution
    predictions_array = np.array(all_predictions)
    labels_array = np.array(all_labels)

    unique_predictions, pred_counts = np.unique(predictions_array, return_counts=True)
    print(f"Predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    unique_labels, label_counts = np.unique(labels_array, return_counts=True)
    print(f"True classes in test: {[CASME2_CLASSES[i] for i in unique_labels]}")

    return {
        'predictions': predictions_array,
        'probabilities': np.array(all_probabilities),
        'labels': labels_array,
        'filenames': all_filenames,
        'emotions': all_emotions,
        'subjects': all_subjects,
        'inference_time': inference_time,
        'samples_count': len(predictions_array)
    }

def analyze_wrong_predictions_casme2(inference_results):
    """Comprehensive wrong predictions analysis for CASME II"""
    print("Analyzing wrong predictions for CASME II micro-expression recognition...")

    predictions = inference_results['predictions']
    labels = inference_results['labels']
    filenames = inference_results['filenames']
    emotions = inference_results['emotions']
    subjects = inference_results['subjects']

    # Find wrong predictions
    wrong_mask = predictions != labels
    wrong_indices = np.where(wrong_mask)[0]

    # Organize by true emotion class
    wrong_predictions_by_class = {}
    subject_error_analysis = {}

    for class_name in CASME2_CLASSES:
        wrong_predictions_by_class[class_name] = []

    # Analyze wrong predictions
    for idx in wrong_indices:
        true_label = labels[idx]
        pred_label = predictions[idx]
        filename = filenames[idx]
        emotion = emotions[idx]
        subject = subjects[idx]

        true_class = CASME2_CLASSES[true_label]
        pred_class = CASME2_CLASSES[pred_label]

        wrong_info = {
            'filename': filename,
            'subject': subject,
            'true_label': int(true_label),
            'true_class': true_class,
            'predicted_label': int(pred_label),
            'predicted_class': pred_class,
            'emotion': emotion
        }

        wrong_predictions_by_class[true_class].append(wrong_info)

        # Subject error tracking
        if subject not in subject_error_analysis:
            subject_error_analysis[subject] = {'total': 0, 'wrong': 0, 'errors': []}
        subject_error_analysis[subject]['wrong'] += 1
        subject_error_analysis[subject]['errors'].append(wrong_info)

    # Count total samples per subject
    for subject in subjects:
        if subject in subject_error_analysis:
            subject_error_analysis[subject]['total'] += 1
        else:
            subject_error_analysis[subject] = {'total': 1, 'wrong': 0, 'errors': []}

    # Calculate error rates per subject
    for subject in subject_error_analysis:
        total = subject_error_analysis[subject]['total']
        wrong = subject_error_analysis[subject]['wrong']
        subject_error_analysis[subject]['error_rate'] = wrong / total if total > 0 else 0.0

    # Summary statistics
    total_wrong = len(wrong_indices)
    total_samples = len(predictions)
    error_rate = (total_wrong / total_samples) * 100

    # Confusion patterns analysis
    confusion_patterns = {}
    for idx in wrong_indices:
        true_label = labels[idx]
        pred_label = predictions[idx]
        pattern = f"{CASME2_CLASSES[true_label]}_to_{CASME2_CLASSES[pred_label]}"
        confusion_patterns[pattern] = confusion_patterns.get(pattern, 0) + 1

    analysis_results = {
        'analysis_metadata': {
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'model_type': EVALUATION_CONFIG_CASME2['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2['dataset_name'],
            'dataset_version': EVALUATION_CONFIG_CASME2['dataset_version'],
            'preprocessing': EVALUATION_CONFIG_CASME2['preprocessing_method'],
            'total_samples': int(total_samples),
            'total_wrong_predictions': int(total_wrong),
            'overall_error_rate': float(error_rate)
        },
        'wrong_predictions_by_class': wrong_predictions_by_class,
        'subject_error_analysis': subject_error_analysis,
        'confusion_patterns': confusion_patterns,
        'error_summary': {
            class_name: len(wrong_predictions_by_class[class_name])
            for class_name in CASME2_CLASSES
        }
    }

    return analysis_results

def calculate_comprehensive_metrics_casme2(inference_results):
    """Calculate comprehensive evaluation metrics for CASME II micro-expression recognition"""
    print("Calculating comprehensive metrics for CASME II micro-expression recognition...")

    predictions = inference_results['predictions']
    probabilities = inference_results['probabilities']
    labels = inference_results['labels']

    if len(predictions) == 0:
        raise ValueError("No predictions to evaluate!")

    # Identify available classes in test set
    unique_test_labels = sorted(np.unique(labels))
    unique_predictions = sorted(np.unique(predictions))

    print(f"Test set contains labels: {[CASME2_CLASSES[i] for i in unique_test_labels]}")
    print(f"Model predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    # Basic metrics
    accuracy = accuracy_score(labels, predictions)

    # Macro metrics (only for available classes)
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions, labels=unique_test_labels, average='macro', zero_division=0
    )

    print(f"Macro F1 (available classes): {f1:.4f}")

    # Per-class metrics (all 7 classes)
    precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
        labels, predictions, labels=range(7), average=None, zero_division=0
    )

    # Confusion matrix
    cm = confusion_matrix(labels, predictions, labels=range(7))

    # Multi-class AUC (only for classes with test samples)
    auc_scores = {}
    fpr_dict = {}
    tpr_dict = {}

    try:
        labels_binarized = label_binarize(labels, classes=range(7))

        for i, class_name in enumerate(CASME2_CLASSES):
            if i in unique_test_labels and len(np.unique(labels_binarized[:, i])) > 1:
                fpr, tpr, _ = roc_curve(labels_binarized[:, i], probabilities[:, i])
                auc_score = auc(fpr, tpr)
                auc_scores[class_name] = float(auc_score)
                fpr_dict[class_name] = fpr.tolist()
                tpr_dict[class_name] = tpr.tolist()
            else:
                auc_scores[class_name] = 0.0
                fpr_dict[class_name] = [0.0, 1.0]
                tpr_dict[class_name] = [0.0, 0.0]

        # Macro AUC for available classes
        available_auc_scores = [auc_scores[CASME2_CLASSES[i]] for i in unique_test_labels]
        macro_auc = float(np.mean(available_auc_scores)) if available_auc_scores else 0.0

    except Exception as e:
        print(f"Warning: AUC calculation failed: {e}")
        auc_scores = {class_name: 0.0 for class_name in CASME2_CLASSES}
        macro_auc = 0.0

    # Subject-level analysis
    subjects = inference_results['subjects']
    subject_performance = {}

    for subject in set(subjects):
        subject_mask = [s == subject for s in subjects]
        subject_predictions = predictions[subject_mask]
        subject_labels = labels[subject_mask]

        if len(subject_predictions) > 0:
            subject_acc = accuracy_score(subject_labels, subject_predictions)
            subject_performance[subject] = {
                'accuracy': float(subject_acc),
                'samples': int(len(subject_predictions)),
                'correct': int(np.sum(subject_predictions == subject_labels))
            }

    # Comprehensive results
    comprehensive_results = {
        'evaluation_metadata': {
            'model_type': EVALUATION_CONFIG_CASME2['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2['dataset_name'],
            'dataset_version': EVALUATION_CONFIG_CASME2['dataset_version'],
            'preprocessing_method': EVALUATION_CONFIG_CASME2['preprocessing_method'],
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'num_classes': EVALUATION_CONFIG_CASME2['num_classes'],
            'class_names': EVALUATION_CONFIG_CASME2['class_names'],
            'test_samples': int(len(labels)),
            'available_classes': [CASME2_CLASSES[i] for i in unique_test_labels],
            'missing_classes': [CASME2_CLASSES[i] for i in range(7) if i not in unique_test_labels]
        },

        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': macro_auc
        },

        'per_class_performance': {},

        'confusion_matrix': cm.tolist(),

        'subject_level_performance': subject_performance,

        'roc_analysis': {
            'auc_scores': auc_scores,
            'fpr_curves': fpr_dict,
            'tpr_curves': tpr_dict
        },

        'inference_performance': {
            'total_time_seconds': float(inference_results['inference_time']),
            'average_time_ms_per_sample': float(inference_results['inference_time'] * 1000 / len(labels))
        }
    }

    # Per-class performance details
    for i, class_name in enumerate(CASME2_CLASSES):
        comprehensive_results['per_class_performance'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i]),
            'support': int(support_per_class[i]),
            'auc': auc_scores[class_name],
            'in_test_set': i in unique_test_labels
        }

    return comprehensive_results

def save_evaluation_results_casme2(evaluation_results, wrong_predictions_results, results_dir):
    """Save comprehensive evaluation results for CASME II"""
    os.makedirs(results_dir, exist_ok=True)

    # Save main evaluation results
    results_file = f"{results_dir}/casme2_poolformer_apex_frame_evaluation_results.json"
    with open(results_file, 'w') as f:
        json.dump(evaluation_results, f, indent=2, default=str)

    # Save wrong predictions analysis
    wrong_predictions_file = f"{results_dir}/casme2_poolformer_apex_frame_wrong_predictions.json"
    with open(wrong_predictions_file, 'w') as f:
        json.dump(wrong_predictions_results, f, indent=2, default=str)

    print(f"Evaluation results saved:")
    print(f"  Main results: {os.path.basename(results_file)}")
    print(f"  Wrong predictions: {os.path.basename(wrong_predictions_file)}")

    return results_file, wrong_predictions_file

# Main evaluation execution
try:
    print("\nStarting CASME II Apex Frame PoolFormer comprehensive evaluation...")
    print(f"Using test dataset: v7 Face-Aware Preprocessing")

    # Create test dataset
    print(f"\nCreating CASME II test dataset from v7...")
    casme2_test_dataset = CASME2DatasetEvaluation(
        dataset_root=DATASET_ROOT,
        split='test',
        transform=GLOBAL_CONFIG_CASME2['transform_val'],
        use_ram_cache=True
    )

    if len(casme2_test_dataset) == 0:
        raise ValueError("No test samples found! Check test data path.")

    casme2_test_loader = DataLoader(
        casme2_test_dataset,
        batch_size=CASME2_POOLFORMER_CONFIG['batch_size'],
        shuffle=False,
        num_workers=CASME2_POOLFORMER_CONFIG['num_workers'],
        pin_memory=True
    )

    # Load trained model
    checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/{EVALUATION_CONFIG_CASME2['checkpoint_file']}"
    casme2_model, training_info = load_trained_model_casme2(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

    # Run inference
    inference_results = run_model_inference_casme2(casme2_model, casme2_test_loader, GLOBAL_CONFIG_CASME2['device'])

    # Calculate comprehensive metrics
    evaluation_results = calculate_comprehensive_metrics_casme2(inference_results)

    # Analyze wrong predictions
    wrong_predictions_results = analyze_wrong_predictions_casme2(inference_results)

    # Add training information
    evaluation_results['training_information'] = training_info

    # Save results
    results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"
    results_file, wrong_file = save_evaluation_results_casme2(
        evaluation_results, wrong_predictions_results, results_dir
    )

    # Display comprehensive results
    print("\n" + "=" * 60)
    print("CASME II APEX FRAME POOLFORMER EVALUATION RESULTS")
    print("=" * 60)
    print(f"Dataset: v7 Face-Aware Preprocessing (Apex Frame)")

    # Overall performance
    overall = evaluation_results['overall_performance']
    print(f"\nOverall Performance (Macro - Available Classes):")
    print(f"  Accuracy:  {overall['accuracy']:.4f}")
    print(f"  Precision: {overall['macro_precision']:.4f}")
    print(f"  Recall:    {overall['macro_recall']:.4f}")
    print(f"  F1 Score:  {overall['macro_f1']:.4f}")
    print(f"  AUC:       {overall['macro_auc']:.4f}")

    # Per-class performance
    print(f"\nPer-Class Performance:")
    for class_name, metrics in evaluation_results['per_class_performance'].items():
        in_test = "Present" if metrics['in_test_set'] else "Missing"
        print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
              f"AUC={metrics['auc']:.4f}, Support={metrics['support']}")

    # Training vs test comparison
    print(f"\nTraining vs Test Performance:")
    training_f1 = training_info['best_val_f1']
    training_acc = training_info['best_val_accuracy']
    test_f1 = overall['macro_f1']
    test_acc = overall['accuracy']

    print(f"  Training Val F1:  {training_f1:.4f}")
    print(f"  Test F1:          {test_f1:.4f}")
    print(f"  F1 Difference:    {training_f1 - test_f1:+.4f}")
    print(f"  Training Val Acc: {training_acc:.4f}")
    print(f"  Test Accuracy:    {test_acc:.4f}")
    print(f"  Acc Difference:   {training_acc - test_acc:+.4f}")
    print(f"  Best Epoch:       {training_info['best_epoch']}")

    # Wrong predictions summary
    print(f"\n" + "=" * 40)
    print("WRONG PREDICTIONS ANALYSIS")
    print("=" * 40)

    wrong_meta = wrong_predictions_results['analysis_metadata']
    print(f"Total wrong predictions: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
    print(f"Overall error rate: {wrong_meta['overall_error_rate']:.2f}%")

    print(f"\nErrors by True Class:")
    for class_name, error_count in wrong_predictions_results['error_summary'].items():
        if error_count > 0:
            wrong_samples = wrong_predictions_results['wrong_predictions_by_class'][class_name]
            print(f"  {class_name}: {error_count} errors")
            for sample in wrong_samples[:3]:
                print(f"    - {sample['filename']} -> predicted as {sample['predicted_class']}")
            if len(wrong_samples) > 3:
                print(f"    ... and {len(wrong_samples) - 3} more")

    # Subject-level analysis
    print(f"\nSubject-Level Performance:")
    subject_perfs = list(evaluation_results['subject_level_performance'].items())
    subject_perfs.sort(key=lambda x: x[1]['accuracy'], reverse=True)
    for subject, perf in subject_perfs[:5]:
        print(f"  {subject}: {perf['accuracy']:.3f} ({perf['correct']}/{perf['samples']})")

    # Most common confusion patterns
    print(f"\nMost Common Confusion Patterns:")
    patterns = sorted(wrong_predictions_results['confusion_patterns'].items(),
                     key=lambda x: x[1], reverse=True)
    for pattern, count in patterns[:3]:
        print(f"  {pattern}: {count} cases")

    print(f"\nInference Performance:")
    print(f"  Total time: {inference_results['inference_time']:.2f}s")
    print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    print(f"\nTest Dataset Info:")
    print(f"  Version: {EVALUATION_CONFIG_CASME2['dataset_version']}")
    print(f"  Preprocessing: {EVALUATION_CONFIG_CASME2['preprocessing_method']}")
    print(f"  Missing classes: {evaluation_results['evaluation_metadata']['missing_classes']}")

    print("\n" + "=" * 60)
    print("CASME II APEX FRAME POOLFORMER EVALUATION COMPLETED")
    print("=" * 60)

except Exception as e:
    print(f"Evaluation failed: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Memory cleanup
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

print(f"\nEvaluation completed successfully")
print("Next: Cell 4 - Generate confusion matrix visualization")

CASME II Apex Frame PoolFormer Evaluation Framework
Dataset Version: v7 - Face-Aware Preprocessing
Frame Strategy: Apex Frame

CASME II PoolFormer Evaluation Configuration:
  Model: PoolFormer_CASME2_Apex_Frame_Face_Aware
  Task: micro_expression_recognition
  Dataset Version: v7
  Preprocessing: face_aware_bbox_expansion
  Classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
  Input size: 224x224

Starting CASME II Apex Frame PoolFormer comprehensive evaluation...
Using test dataset: v7 Face-Aware Preprocessing

Creating CASME II test dataset from v7...
Loading CASME II test dataset for evaluation...
Loaded 28 CASME II test samples for evaluation
Test set class distribution:
  others: 10 samples (35.7%)
  disgust: 7 samples (25.0%)
  happiness: 4 samples (14.3%)
  repression: 3 samples (10.7%)
  surprise: 3 samples (10.7%)
  sadness: 1 samples (3.6%)
Test set covers 16 subjects
Missing classes in test set: ['fear']
Preloading 28 test images to RAM w

Loading test images to RAM: 100%|██████████| 28/28 [00:01<00:00, 19.20it/s]


Test RAM caching completed: 28/28 images, ~1.4MB
Loading trained CASME II Apex Frame PoolFormer model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/05_03_poolformer_casme2_af_prep/casme2_poolformer_apex_frame_best_f1.pth
Checkpoint loaded using: standard
PoolFormer feature dimension: 768
PoolFormer CASME II Simplified: 768 -> 256 -> 7
Dropout rate: 0.3 (increased for small dataset regularization)
Model state loaded with strict=True
Model loaded successfully:
  Best validation F1: 0.2330
  Best validation accuracy: 0.3846
  Best epoch: 6
  Model classes: 7
Running CASME II Apex Frame PoolFormer model inference on test set...


CASME II Inference: 100%|██████████| 7/7 [00:00<00:00, 15.39it/s]


CASME II inference completed: 28 samples in 0.46s
Predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
True classes in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Calculating comprehensive metrics for CASME II micro-expression recognition...
Test set contains labels: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Model predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Macro F1 (available classes): 0.2315
Analyzing wrong predictions for CASME II micro-expression recognition...
Evaluation results saved:
  Main results: casme2_poolformer_apex_frame_evaluation_results.json
  Wrong predictions: casme2_poolformer_apex_frame_wrong_predictions.json

CASME II APEX FRAME POOLFORMER EVALUATION RESULTS
Dataset: v7 Face-Aware Preprocessing (Apex Frame)

Overall Performance (Macro - Available Classes):
  Accuracy:  0.3929
  Precision: 0.2324
  Recall:    0.2369
  

In [4]:
# @title Cell 4: CASME II Apex Frame PoolFormer Confusion Matrix Generation

# File: 05_03_PoolFormer_CASME2_AF_Cell4.py
# Location: experiments/05_03_PoolFormer_CASME2-AF-PREP.ipynb
# Purpose: Generate professional confusion matrix visualization with comprehensive metrics

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

print("CASME II Apex Frame PoolFormer Confusion Matrix Generation")
print("=" * 60)

# Project paths configuration
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/05_03_poolformer_casme2_af_prep"

def find_evaluation_json_files_casme2(results_path):
    """Find CASME II evaluation JSON files"""
    json_files = {}
    eval_dir = f"{results_path}/evaluation_results"

    if os.path.exists(eval_dir):
        # Look for evaluation results
        eval_pattern = f"{eval_dir}/casme2_poolformer_apex_frame_evaluation_results.json"
        eval_files = glob.glob(eval_pattern)

        if eval_files:
            json_files['main'] = eval_files[0]
            print(f"Found evaluation file: {os.path.basename(eval_files[0])}")

        # Look for wrong predictions
        wrong_pattern = f"{eval_dir}/casme2_poolformer_apex_frame_wrong_predictions.json"
        wrong_files = glob.glob(wrong_pattern)

        if wrong_files:
            json_files['wrong'] = wrong_files[0]
            print(f"Found wrong predictions: {os.path.basename(wrong_files[0])}")

        if not json_files:
            print(f"WARNING: No evaluation results found in {eval_dir}")
            print("Make sure Cell 3 (evaluation) has been executed first!")
    else:
        print(f"ERROR: Evaluation directory not found: {eval_dir}")

    return json_files

def load_evaluation_results_casme2(json_path):
    """Load and parse CASME II evaluation results JSON"""
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        print(f"Successfully loaded: {os.path.basename(json_path)}")
        return data
    except Exception as e:
        print(f"ERROR loading {json_path}: {str(e)}")
        return None

def calculate_weighted_f1_casme2(per_class_performance):
    """Calculate weighted F1 score for CASME II micro-expression classes"""
    total_support = sum([class_data['support'] for class_data in per_class_performance.values()
                        if class_data['support'] > 0])

    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0

    for class_name, class_data in per_class_performance.items():
        if class_data['support'] > 0:
            weight = class_data['support'] / total_support
            weighted_f1 += class_data['f1_score'] * weight

    return weighted_f1

def calculate_balanced_accuracy_casme2(confusion_matrix):
    """Calculate balanced accuracy for CASME II 7-class micro-expression recognition"""
    cm = np.array(confusion_matrix)
    n_classes = cm.shape[0]

    per_class_balanced_acc = []

    # Find classes with actual test samples
    classes_with_samples = []
    for i in range(n_classes):
        if cm[i, :].sum() > 0:
            classes_with_samples.append(i)

    for i in classes_with_samples:
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        tn = cm.sum() - tp - fn - fp

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        class_balanced_acc = (sensitivity + specificity) / 2
        per_class_balanced_acc.append(class_balanced_acc)

    balanced_acc = np.mean(per_class_balanced_acc) if per_class_balanced_acc else 0.0

    return balanced_acc

def determine_text_color_casme2(color_value, threshold=0.5):
    """Determine optimal text color based on background intensity"""
    return 'white' if color_value > threshold else 'black'

def create_confusion_matrix_plot_casme2(data, output_path):
    """Create professional confusion matrix visualization for CASME II micro-expression recognition"""

    # Extract data
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']

    print(f"Processing confusion matrix for CASME II classes: {class_names}")
    print(f"Dataset version: {meta['dataset_version']}")
    print(f"Preprocessing: {meta['preprocessing_method']}")
    print(f"Confusion matrix shape: {cm.shape}")

    # Calculate comprehensive metrics
    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1_casme2(per_class)
    balanced_acc = calculate_balanced_accuracy_casme2(cm)

    print(f"Calculated metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Balanced Acc: {balanced_acc:.4f}, Accuracy: {accuracy:.4f}")

    # Row-wise normalization for percentage display
    row_sums = cm.sum(axis=1, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_pct = np.divide(cm, row_sums, where=(row_sums!=0))
        cm_pct = np.nan_to_num(cm_pct)

    # Create visualization
    fig, ax = plt.subplots(figsize=(12, 10))

    # Color scheme optimized for micro-expression research
    cmap = 'Blues'

    # Create heatmap with improved color scaling
    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    # Annotate cells with count and percentage
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            # Handle percentage calculation for classes with 0 samples
            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"

            # Determine text color based on cell intensity
            cell_value = cm_pct[i, j]
            text_color = determine_text_color_casme2(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    # Configure axes
    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    # Add preprocessing info note
    preprocessing_note = f"Preprocessing: {meta['preprocessing_method']}\nDataset: {meta['dataset_version']}"
    missing_classes = meta.get('missing_classes', [])
    if missing_classes:
        preprocessing_note += f"\nMissing: {', '.join(missing_classes)}"

    ax.text(0.02, 0.98, preprocessing_note, transform=ax.transAxes, fontsize=8,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    # Create comprehensive title with all metrics
    title = f"CASME II Apex Frame Micro-Expression Recognition - PoolFormer Face-Aware\n"
    title += f"Acc: {accuracy:.4f}  |  Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes
    }

def generate_performance_summary_casme2(evaluation_data, wrong_predictions_data=None):
    """Generate comprehensive performance summary for CASME II"""

    print("\n" + "=" * 60)
    print("CASME II APEX FRAME POOLFORMER PERFORMANCE SUMMARY")
    print("=" * 60)

    # Overall performance
    overall = evaluation_data['overall_performance']
    meta = evaluation_data['evaluation_metadata']

    print(f"Dataset: {meta['dataset']}")
    print(f"Dataset version: {meta.get('dataset_version', 'N/A')}")
    print(f"Preprocessing: {meta.get('preprocessing_method', 'N/A')}")
    print(f"Test samples: {meta['test_samples']}")
    print(f"Model: {meta['model_type']}")
    print(f"Evaluation date: {meta['evaluation_timestamp']}")

    print(f"\nOverall Performance:")
    print(f"  Accuracy:         {overall['accuracy']:.4f}")
    print(f"  Macro Precision:  {overall['macro_precision']:.4f}")
    print(f"  Macro Recall:     {overall['macro_recall']:.4f}")
    print(f"  Macro F1:         {overall['macro_f1']:.4f}")
    print(f"  Macro AUC:        {overall['macro_auc']:.4f}")

    # Per-class performance
    print(f"\nPer-Class Performance:")
    per_class = evaluation_data['per_class_performance']

    print(f"{'Class':<12} {'F1':<8} {'Precision':<10} {'Recall':<8} {'AUC':<8} {'Support':<8} {'In Test'}")
    print("-" * 65)

    for class_name, metrics in per_class.items():
        in_test = "Yes" if metrics['in_test_set'] else "No"
        print(f"{class_name:<12} {metrics['f1_score']:<8.4f} {metrics['precision']:<10.4f} "
              f"{metrics['recall']:<8.4f} {metrics['auc']:<8.4f} {metrics['support']:<8} {in_test}")

    # Training vs test performance
    if 'training_information' in evaluation_data:
        training = evaluation_data['training_information']
        print(f"\nTraining vs Test Comparison:")
        print(f"  Training Val F1:  {training['best_val_f1']:.4f}")
        print(f"  Test F1:          {overall['macro_f1']:.4f}")
        print(f"  Performance Gap:  {training['best_val_f1'] - overall['macro_f1']:+.4f}")
        print(f"  Best Epoch:       {training['best_epoch']}")

    # Class availability analysis
    missing_classes = meta.get('missing_classes', [])
    available_classes = meta.get('available_classes', [])

    print(f"\nClass Availability Analysis:")
    print(f"  Available classes: {len(available_classes)}/7")
    print(f"  Missing classes: {missing_classes if missing_classes else 'None'}")

    # Wrong predictions summary if available
    if wrong_predictions_data:
        wrong_meta = wrong_predictions_data['analysis_metadata']
        print(f"\nError Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']}/{wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        # Top confusion patterns
        patterns = wrong_predictions_data.get('confusion_patterns', {})
        if patterns:
            print(f"\nTop Confusion Patterns:")
            sorted_patterns = sorted(patterns.items(), key=lambda x: x[1], reverse=True)[:3]
            for pattern, count in sorted_patterns:
                print(f"  {pattern}: {count} cases")

    print(f"\nInference Performance:")
    inference = evaluation_data['inference_performance']
    print(f"  Total time: {inference['total_time_seconds']:.2f}s")
    print(f"  Speed: {inference['average_time_ms_per_sample']:.1f} ms/sample")

# Find evaluation JSON files
json_files = find_evaluation_json_files_casme2(RESULTS_ROOT)

if not json_files:
    print(f"ERROR: No evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (evaluation) has been executed first!")
else:
    print(f"\nFound evaluation results")

# Create output directory
output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Process evaluation results
if 'main' in json_files:
    print(f"\n{'='*60}")
    print(f"Processing Evaluation Results")
    print(f"{'='*60}")

    # Load evaluation data
    eval_data = load_evaluation_results_casme2(json_files['main'])

    # Load wrong predictions data if available
    wrong_data = None
    if 'wrong' in json_files:
        wrong_data = load_evaluation_results_casme2(json_files['wrong'])

    if eval_data is not None:
        try:
            # Generate confusion matrix
            cm_output_path = os.path.join(output_dir, "confusion_matrix_CASME2_PoolFormer_Apex_Frame_v7.png")
            metrics = create_confusion_matrix_plot_casme2(eval_data, cm_output_path)

            print(f"\nSUCCESS: Confusion matrix generated successfully")
            print(f"Output file: {os.path.basename(cm_output_path)}")

            # Display metrics summary
            print(f"\nPerformance Metrics Summary:")
            print(f"  Accuracy:        {metrics['accuracy']:.4f}")
            print(f"  Macro F1:        {metrics['macro_f1']:.4f}")
            print(f"  Weighted F1:     {metrics['weighted_f1']:.4f}")
            print(f"  Balanced Acc:    {metrics['balanced_accuracy']:.4f}")

            if metrics['missing_classes']:
                print(f"  Missing classes: {metrics['missing_classes']}")

        except Exception as e:
            print(f"ERROR: Failed to generate confusion matrix: {str(e)}")
            import traceback
            traceback.print_exc()

        # Generate comprehensive summary
        generate_performance_summary_casme2(eval_data, wrong_data)
    else:
        print(f"ERROR: Could not load evaluation data")

    # Final summary
    print(f"\n" + "=" * 60)
    print("CASME II APEX FRAME POOLFORMER CONFUSION MATRIX COMPLETED")
    print("=" * 60)

    print(f"\nGenerated file:")
    print(f"  confusion_matrix_CASME2_PoolFormer_Apex_Frame_v7.png")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print(f"\nERROR: No evaluation results found")
    print("Please run Cell 3 (evaluation) first to generate evaluation JSON files")

print("\nCell 4 completed - CASME II Apex Frame PoolFormer confusion matrix generated")

CASME II Apex Frame PoolFormer Confusion Matrix Generation
Found evaluation file: casme2_poolformer_apex_frame_evaluation_results.json
Found wrong predictions: casme2_poolformer_apex_frame_wrong_predictions.json

Found evaluation results

Processing Evaluation Results
Successfully loaded: casme2_poolformer_apex_frame_evaluation_results.json
Successfully loaded: casme2_poolformer_apex_frame_wrong_predictions.json
Processing confusion matrix for CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
Dataset version: v7
Preprocessing: face_aware_bbox_expansion
Confusion matrix shape: (7, 7)
Calculated metrics - Macro F1: 0.2315, Weighted F1: 0.3948, Balanced Acc: 0.5557, Accuracy: 0.3929
Confusion matrix saved to: confusion_matrix_CASME2_PoolFormer_Apex_Frame_v7.png

SUCCESS: Confusion matrix generated successfully
Output file: confusion_matrix_CASME2_PoolFormer_Apex_Frame_v7.png

Performance Metrics Summary:
  Accuracy:        0.3929
  Macro F1: