In [1]:
# @title Cell 1: CASME II Multi-Frame Sequence ViT Infrastructure Configuration

# File: 07_01_ViT_CASME2_MFS_Cell1.py
# Location: experiments/07_01_ViT_CASME2-MFS-PREP.ipynb
# Purpose: ViT-Base for CASME II micro-expression recognition with multi-frame sequence strategy and face-aware preprocessing

from google.colab import drive
print("=" * 60)
print("CASME II MULTI-FRAME SEQUENCE ViT WITH FACE-AWARE PREPROCESSING")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Project paths configuration
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/preprocessed_v9"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/07_01_vit_casme2_mfs_prep"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/07_01_vit_casme2_mfs_prep"

# Load CASME II v9 preprocessing metadata
PREPROCESSING_SUMMARY = f"{DATASET_ROOT}/preprocessing_summary.json"

print("CASME II Multi-Frame Sequence ViT - Face-Aware Preprocessing Infrastructure")
print("=" * 60)

# Validate preprocessing metadata exists
if not os.path.exists(PREPROCESSING_SUMMARY):
    raise FileNotFoundError(f"v9 preprocessing summary not found: {PREPROCESSING_SUMMARY}")

# Load v9 preprocessing metadata
print("Loading CASME II v9 preprocessing metadata...")
with open(PREPROCESSING_SUMMARY, 'r') as f:
    preprocessing_info = json.load(f)

print(f"Dataset variant: {preprocessing_info['variant']}")
print(f"Processing date: {preprocessing_info['processing_date']}")
print(f"Preprocessing method: {preprocessing_info['preprocessing_method']}")
print(f"Total images processed: {preprocessing_info['total_processed']}")
print(f"Face detection rate: {preprocessing_info['face_detection_stats']['detection_rate']:.2%}")

# Extract preprocessing parameters
preproc_params = preprocessing_info['preprocessing_parameters']
print(f"Target size: {preproc_params['target_size']}x{preproc_params['target_size']}px")
print(f"BBox expansion: {preproc_params['bbox_expansion']}px (all directions)")

# Display split information
print(f"\nDataset split information:")
print(f"  Train samples: {preprocessing_info['splits']['train']['total_images']}")
print(f"  Validation samples: {preprocessing_info['splits']['val']['total_images']}")
print(f"  Test samples: {preprocessing_info['splits']['test']['total_images']}")

# =====================================================
# EXPERIMENT CONFIGURATION - Multi-Frame Sequence with Face-Aware Preprocessing
# =====================================================
# This configuration supports 4 experiment scenarios:
# 1. ViT-Base Patch16 + CrossEntropy Loss
# 2. ViT-Base Patch16 + Focal Loss
# 3. ViT-Base Patch32 + CrossEntropy Loss
# 4. ViT-Base Patch32 + Focal Loss
#
# Toggle VIT_MODEL_VARIANT for model selection: 'patch16' or 'patch32'
# Toggle USE_FOCAL_LOSS for loss function: False (CrossEntropy) or True (Focal)
# =====================================================

# FOCAL LOSS CONFIGURATION - Toggle for experimentation
USE_FOCAL_LOSS = True  # Default: CrossEntropy, Set True to enable Focal Loss
FOCAL_LOSS_GAMMA = 2.0  # Focal loss focusing parameter (if enabled)

# OPTIMIZED CLASS WEIGHTS CONFIGURATION
# CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
# v9 Train distribution: [1027, 650, 325, 273, 260, 65, 13] - inverse sqrt frequency approach

# CrossEntropy Loss - Optimized inverse square root frequency weights
CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.26, 1.78, 1.94, 1.99, 3.98, 8.90]

# Focal Loss - Normalized per-class alpha values (sum = 1.0)
FOCAL_LOSS_ALPHA_WEIGHTS = [0.048, 0.060, 0.085, 0.093, 0.095, 0.191, 0.427]

# VIT MODEL CONFIGURATION - Support Patch16 and Patch32 variants
# ViT-Base Patch16: 86M parameters, fine-grained attention with 14x14 patches at 224px
# ViT-Base Patch32: 88M parameters, efficient attention with 7x7 patches at 224px
VIT_MODEL_VARIANT = 'patch16'  # Options: 'patch16' or 'patch32'

# Dynamic ViT model selection based on patch size
if VIT_MODEL_VARIANT == 'patch16':
    VIT_MODEL_NAME = 'google/vit-base-patch16-224-in21k'
    PATCH_SIZE = 16
    print("Using ViT-Base Patch16 for fine-grained micro-expression analysis (86M parameters)")
elif VIT_MODEL_VARIANT == 'patch32':
    VIT_MODEL_NAME = 'google/vit-base-patch32-224-in21k'
    PATCH_SIZE = 32
    print("Using ViT-Base Patch32 for efficient micro-expression recognition (88M parameters)")
else:
    raise ValueError(f"Unsupported VIT_MODEL_VARIANT: {VIT_MODEL_VARIANT}")

# Display experiment configuration
print("\n" + "=" * 50)
print("EXPERIMENT CONFIGURATION - MULTI-FRAME SEQUENCE FACE-AWARE")
print("=" * 50)
print(f"Dataset: v9 Multi-Frame Sequence with Face-Aware Preprocessing")
print(f"Frame strategy: Multiple frames per video (dense sampling)")
print(f"Training approach: Frame-level independent learning")
print(f"Loss Function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
if USE_FOCAL_LOSS:
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Alpha Weights (per-class): {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum Validation: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Class Weights (inverse sqrt freq): {CROSSENTROPY_CLASS_WEIGHTS}")
print(f"ViT Model Variant: {VIT_MODEL_VARIANT.upper()}")
print(f"  Model: {VIT_MODEL_NAME}")
print(f"  Patch Size: {PATCH_SIZE}px")
print(f"Input Resolution: 384x384px (upscaled from 224px with position interpolation)")
print(f"Image Format: Grayscale converted to RGB (3-channel)")
print("=" * 50)

# Enhanced GPU configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

# Fixed batch size configuration for large dataset (2613 train samples)
BATCH_SIZE = 16  # Optimized for large dataset at 384px resolution
NUM_WORKERS = 4

if 'A100' in gpu_name or 'L4' in gpu_name:
    torch.backends.cudnn.benchmark = True
    print(f"GPU optimization enabled for {gpu_name}")

print(f"Large dataset configuration: Batch size {BATCH_SIZE} (optimal for 2613 samples at 384px)")
print(f"Iterations per epoch: {2613 // BATCH_SIZE} (~82 iterations per epoch)")

# RAM preloading workers (separate from DataLoader workers)
RAM_PRELOAD_WORKERS = 32
print(f"RAM preload workers: {RAM_PRELOAD_WORKERS} (parallel image loading)")

# CASME II class mapping and analysis
CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

# Extract class distribution from v9 preprocessing metadata
print("\nLoading v9 class distribution...")
train_dist = preprocessing_info['splits']['train']['emotion_distribution']
val_dist = preprocessing_info['splits']['val']['emotion_distribution']
test_dist = preprocessing_info['splits']['test']['emotion_distribution']

# Convert to ordered list matching CASME2_CLASSES
def emotion_dist_to_list(emotion_dict, class_names):
    """Convert emotion distribution dict to ordered list"""
    return [emotion_dict.get(cls, 0) for cls in class_names]

train_dist_list = emotion_dist_to_list(train_dist, CASME2_CLASSES)
val_dist_list = emotion_dist_to_list(val_dist, CASME2_CLASSES)
test_dist_list = emotion_dist_to_list(test_dist, CASME2_CLASSES)

print(f"\nv9 Train distribution: {train_dist_list}")
print(f"v9 Val distribution: {val_dist_list}")
print(f"v9 Test distribution: {test_dist_list}")

# Apply optimized class weights based on loss function selection
if USE_FOCAL_LOSS:
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

# CASME II ViT Configuration for Multi-Frame Sequence with Face-Aware Preprocessing
# Optimized for large dataset (2613 train samples) with multi-frame strategy
CASME2_VIT_CONFIG = {
    # Architecture configuration - ViT specific
    'vit_model': VIT_MODEL_NAME,
    'model_variant': VIT_MODEL_VARIANT,
    'patch_size': PATCH_SIZE,
    'input_size': 224,  # Upscaled from 224px with position interpolation
    'num_classes': 7,
    'dropout_rate': 0.3,  # Balanced regularization for large dataset
    'expected_feature_dim': 768,  # ViT-Base hidden dimension
    'interpolate_pos_encoding': True,  # Enable for 384px input

    # Training configuration - proven optimal from KFS-PREP
    'learning_rate': 2e-5,  # Proven optimal for transformer fine-tuning
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,

    # Scheduler configuration - stable for large dataset
    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 3,  # Stable patience for large dataset
    'scheduler_min_lr': 1e-7,
    'scheduler_monitor': 'validation F1',

    # Loss function configuration
    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,

    # Dataset configuration - v9 MFS specific
    'dataset_version': 'v9',
    'frame_strategy': 'multi_frame_sequence',
    'frame_types': ['multiple_frames_per_video'],
    'training_approach': 'frame_level_independent',
    'inference_strategy': 'frame_level_evaluation',

    # Regularization configuration
    'label_smoothing': 0.0,
    'mixup_alpha': 0.0,
    'cutmix_alpha': 0.0
}

# Optimized Focal Loss implementation
class OptimizedFocalLoss(nn.Module):
    """
    Optimized Focal Loss with per-class alpha weights
    Handles class imbalance with normalized alpha weights (sum = 1.0)
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()
        if alpha is not None:
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        else:
            self.alpha = None
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        if self.alpha is not None and self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)

        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# ViT CASME II Baseline Model - Enhanced Architecture for MFS
class ViTCASME2Baseline(nn.Module):
    """
    ViT baseline for CASME II micro-expression recognition
    Enhanced architecture: 768 -> 512 -> 128 -> 7 (deeper for large dataset)
    """

    def __init__(self, num_classes=7, dropout_rate=0.2):
        super(ViTCASME2Baseline, self).__init__()

        # Load pretrained ViT backbone from HuggingFace
        from transformers import ViTModel

        self.vit = ViTModel.from_pretrained(
            CASME2_VIT_CONFIG['vit_model'],
            add_pooling_layer=False
        )

        # Enable fine-tuning for micro-expression domain adaptation
        for param in self.vit.parameters():
            param.requires_grad = True

        # Get ViT feature dimension from configuration
        self.vit_feature_dim = self.vit.config.hidden_size

        print(f"ViT feature dimension: {self.vit_feature_dim}")

        # Enhanced classification head for large dataset
        # Architecture: 768 -> 512 -> 128 -> 7 (proven effective for MFS)
        self.classifier_layers = nn.Sequential(
            nn.Linear(self.vit_feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
        )

        self.classifier = nn.Linear(128, num_classes)

        print(f"Classification head: {self.vit_feature_dim} -> 512 -> 128 -> {num_classes}")
        print(f"Dropout rate: {dropout_rate} (balanced for large dataset)")

    def forward(self, pixel_values):
        # ViT forward pass with position embedding interpolation for 384px
        vit_outputs = self.vit(
            pixel_values=pixel_values,
            interpolate_pos_encoding=CASME2_VIT_CONFIG['interpolate_pos_encoding']
        )

        # Extract CLS token features
        # CLS token is at index 0 of last_hidden_state
        vit_features = vit_outputs.last_hidden_state[:, 0]

        # Classification pipeline
        processed_features = self.classifier_layers(vit_features)
        output = self.classifier(processed_features)

        return output

# Enhanced optimizer and scheduler factory
def create_optimizer_scheduler_casme2(model, config):
    """Create optimizer and scheduler for CASME II ViT training"""

    # AdamW optimizer with proven configuration
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )

    # ReduceLROnPlateau scheduler monitoring validation F1
    if config['scheduler_type'] == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=config['scheduler_mode'],
            factor=config['scheduler_factor'],
            patience=config['scheduler_patience'],
            min_lr=config['scheduler_min_lr']
        )
        print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    else:
        scheduler = None

    return optimizer, scheduler

# ViT Image Processor setup for 384px input
from transformers import ViTImageProcessor

print("\nSetting up ViT Image Processor for 384px input...")

vit_processor = ViTImageProcessor.from_pretrained(
    CASME2_VIT_CONFIG['vit_model'],
    do_resize=True,
    size={'height': 384, 'width': 384},
    do_normalize=True,
    do_rescale=True,
    do_center_crop=False
)

print(f"ViT preprocessing configured:")
print(f"  Input size: 384x384px")
print(f"  Resize from: 224x224px (v9 native)")
print(f"  Position encoding interpolation: Enabled")

# Transform functions for ViT
def vit_transform_train(image):
    """
    Training transform with ViT Image Processor
    Handles grayscale to RGB conversion
    """
    # Ensure RGB format (grayscale images will be converted via channel repetition)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = vit_processor(image, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)

def vit_transform_val(image):
    """
    Validation transform with ViT Image Processor
    Handles grayscale to RGB conversion
    """
    # Ensure RGB format (grayscale images will be converted via channel repetition)
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = vit_processor(image, return_tensors="pt")
    return inputs['pixel_values'].squeeze(0)

print("ViT Image Processor configured for 384px with position interpolation")

# CASME II Dataset for v9 MFS with flat directory structure
class CASME2DatasetMFS(Dataset):
    """CASME II v9 Multi-Frame Sequence dataset with flexible filename parsing"""

    def __init__(self, dataset_root, split, transform=None):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []
        self.filenames = []

        split_path = os.path.join(dataset_root, split)

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        # Load all images from flat directory structure
        print(f"Loading {split} dataset from {split_path}...")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        if len(all_files) > 0:
            print(f"Sample filename: {all_files[0]}")

        loaded_count = 0
        skipped_count = 0

        for filename in sorted(all_files):
            # Extract emotion from filename - flexible approach
            # Try multiple patterns to handle different filename formats
            emotion_found = None

            # Remove file extension first
            name_without_ext = filename.rsplit('.', 1)[0]

            # Pattern 1: Check if any emotion class appears in filename
            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                loaded_count += 1
            else:
                skipped_count += 1
                if skipped_count <= 3:
                    print(f"  Skipped (no emotion found): {filename}")

        print(f"Loaded {len(self.images)} samples for {split} split")
        if skipped_count > 0:
            print(f"  Skipped {skipped_count} files (no recognizable emotion)")

        if len(self.images) == 0:
            print(f"ERROR: No samples loaded! Check filename format and emotion labels.")
            print(f"Expected emotions in filenames: {CASME2_CLASSES}")

        self._print_distribution()

    def _print_distribution(self):
        """Print class distribution"""
        if len(self.labels) == 0:
            print("  No samples to display distribution")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

# Create directories
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

print(f"\nDataset paths:")
print(f"Train: {DATASET_ROOT}/train")
print(f"Validation: {DATASET_ROOT}/val")
print(f"Test: {DATASET_ROOT}/test")

# Architecture validation
print("\nViT CASME II architecture validation...")

try:
    test_model = ViTCASME2Baseline(num_classes=7, dropout_rate=0.2).to(device)
    test_input = torch.randn(1, 3, 384, 384).to(device)
    test_output = test_model(test_input)

    # Dynamic token calculation based on configured patch size
    expected_tokens = (384 // CASME2_VIT_CONFIG['patch_size']) ** 2

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"Expected tokens for 384px with patch{CASME2_VIT_CONFIG['patch_size']}: {expected_tokens} tokens")
    print(f"ViT {VIT_MODEL_VARIANT.upper()} architecture validated")

    del test_model, test_input, test_output
    torch.cuda.empty_cache()

except Exception as e:
    print(f"Validation failed: {e}")
    import traceback
    traceback.print_exc()

# Loss function factory
def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    """
    Factory function to create loss criterion based on configuration

    Args:
        weights: Class weights for CrossEntropy
        use_focal_loss: Whether to use Focal Loss or CrossEntropy
        alpha_weights: Per-class alpha weights for Focal Loss
        gamma: Focal loss gamma parameter

    Returns:
        Loss function
    """
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        if alpha_weights:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {sum(alpha_weights):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)
    else:
        print(f"Using CrossEntropy Loss with optimized class weights")
        print(f"Class weights: {weights.cpu().numpy()}")
        return nn.CrossEntropyLoss(weight=weights)

# Global configuration for training pipeline
GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    'transform_train': vit_transform_train,
    'transform_val': vit_transform_val,
    'vit_config': CASME2_VIT_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'dataset_root': DATASET_ROOT,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

# Configuration validation and summary
print("\n" + "=" * 60)
print("CASME II MULTI-FRAME SEQUENCE ViT CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: ViT-Base")
print(f"  Variant: {VIT_MODEL_VARIANT.upper()}")
print(f"  Model: {VIT_MODEL_NAME}")
print(f"  Patch Size: {PATCH_SIZE}px")
print(f"  Input Resolution: 384px (upscaled from 224px v9 native)")
print(f"  Feature Dimension: {CASME2_VIT_CONFIG['expected_feature_dim']}")
print(f"  Position Interpolation: Enabled")
print(f"  Classification Head: {CASME2_VIT_CONFIG['expected_feature_dim']} -> 512 -> 128 -> 7 (enhanced)")

print(f"\nDataset Configuration:")
print(f"  Version: {CASME2_VIT_CONFIG['dataset_version']}")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Frame strategy: {CASME2_VIT_CONFIG['frame_strategy']}")
print(f"  Training approach: {CASME2_VIT_CONFIG['training_approach']}")
print(f"  Inference strategy: {CASME2_VIT_CONFIG['inference_strategy']}")
print(f"  Weight Optimization: {'Per-class Alpha' if USE_FOCAL_LOSS else 'Inverse Sqrt Frequency'}")

print(f"\nTraining Configuration:")
print(f"  Train samples: {preprocessing_info['splits']['train']['total_images']} frames")
print(f"  Validation samples: {preprocessing_info['splits']['val']['total_images']} frames")
print(f"  Test samples: {preprocessing_info['splits']['test']['total_images']} frames")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {CASME2_VIT_CONFIG['learning_rate']}")
print(f"  Dropout rate: {CASME2_VIT_CONFIG['dropout_rate']}")
print(f"  Expected tokens (patch{PATCH_SIZE}): {(384 // PATCH_SIZE) ** 2}")

print("\nNext: Cell 2 - Dataset Loading and Multi-Frame Training Pipeline")

CASME II MULTI-FRAME SEQUENCE ViT WITH FACE-AWARE PREPROCESSING

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II Multi-Frame Sequence ViT - Face-Aware Preprocessing Infrastructure
Loading CASME II v9 preprocessing metadata...
Dataset variant: MFS
Processing date: 2025-10-19T08:20:12.098301
Preprocessing method: face_bbox_expansion_all_directions
Total images processed: 2774
Face detection rate: 100.00%
Target size: 224x224px
BBox expansion: 20px (all directions)

Dataset split information:
  Train samples: 2613
  Validation samples: 78
  Test samples: 83
Using ViT-Base Patch16 for fine-grained micro-expression analysis (86M parameters)

EXPERIMENT CONFIGURATION - MULTI-FRAME SEQUENCE FACE-AWARE
Dataset: v9 Multi-Frame Sequence with Face-Aware Preprocessing
Frame strategy: Multiple frames per video (dense sampling)
Training approach: Frame-level independent learning
Loss Function: Focal Loss
  Gamma: 

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

ViT preprocessing configured:
  Input size: 384x384px
  Resize from: 224x224px (v9 native)
  Position encoding interpolation: Enabled
ViT Image Processor configured for 384px with position interpolation

Dataset paths:
Train: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/train
Validation: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/val
Test: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/test

ViT CASME II architecture validation...


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ViT feature dimension: 768
Classification head: 768 -> 512 -> 128 -> 7
Dropout rate: 0.2 (balanced for large dataset)
Validation successful: Output shape torch.Size([1, 7])
Expected tokens for 384px with patch16: 576 tokens
ViT PATCH16 architecture validated

CASME II MULTI-FRAME SEQUENCE ViT CONFIGURATION COMPLETE
Loss Configuration:
  Function: Optimized Focal Loss
  Gamma: 2.0
  Per-class Alpha: [0.048, 0.06, 0.085, 0.093, 0.095, 0.191, 0.427]
  Alpha Sum: 0.999

Model Configuration:
  Architecture: ViT-Base
  Variant: PATCH16
  Model: google/vit-base-patch16-224-in21k
  Patch Size: 16px
  Input Resolution: 384px (upscaled from 224px v9 native)
  Feature Dimension: 768
  Position Interpolation: Enabled
  Classification Head: 768 -> 512 -> 128 -> 7 (enhanced)

Dataset Configuration:
  Version: v9
  Classes: 7
  Frame strategy: multi_frame_sequence
  Training approach: frame_level_independent
  Inference strategy: frame_level_evaluation
  Weight Optimization: Per-class Alpha

Training

In [2]:
# @title Cell 2: CASME II Multi-Frame Sequence ViT Training Pipeline

# File: 07_01_ViT_CASME2_MFS_Cell2.py
# Location: experiments/07_01_ViT_CASME2-MFS-PREP.ipynb
# Purpose: Enhanced training pipeline for CASME II Multi-Frame Sequence ViT with optimized RAM caching

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import shutil
import tempfile

print("CASME II Multi-Frame Sequence ViT Training Pipeline")
print("=" * 70)
print(f"Loss Function: {'Optimized Focal Loss' if CASME2_VIT_CONFIG['use_focal_loss'] else 'CrossEntropy Loss'}")
if CASME2_VIT_CONFIG['use_focal_loss']:
    print(f"Focal Loss Parameters:")
    print(f"  Gamma: {CASME2_VIT_CONFIG['focal_loss_gamma']}")
    print(f"  Per-class Alpha: {CASME2_VIT_CONFIG['focal_loss_alpha_weights']}")
    print(f"  Alpha Sum: {sum(CASME2_VIT_CONFIG['focal_loss_alpha_weights']):.3f}")
else:
    print(f"CrossEntropy Parameters:")
    print(f"  Optimized Class Weights: {CASME2_VIT_CONFIG['crossentropy_class_weights']}")
print(f"Dataset version: {CASME2_VIT_CONFIG['dataset_version']}")
print(f"Frame strategy: {CASME2_VIT_CONFIG['frame_strategy']}")
print(f"Training approach: {CASME2_VIT_CONFIG['training_approach']}")
print(f"ViT variant: {CASME2_VIT_CONFIG['model_variant'].upper()}")
print(f"Patch size: {CASME2_VIT_CONFIG['patch_size']}px")
print(f"Training epochs: {CASME2_VIT_CONFIG['num_epochs']}")
print(f"Batch size: {CASME2_VIT_CONFIG['batch_size']}")
print(f"Scheduler patience: {CASME2_VIT_CONFIG['scheduler_patience']}")

# Enhanced CASME II Dataset with optimized RAM caching for large dataset
class CASME2DatasetTraining(Dataset):
    """Enhanced CASME II dataset for training with RAM caching optimization for large dataset"""

    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.filenames = []
        self.cached_images = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading CASME II {split} dataset for training...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        if len(all_files) > 0:
            print(f"Sample filename: {all_files[0]}")

        loaded_count = 0
        skipped_count = 0

        for filename in sorted(all_files):
            # Extract emotion from filename - flexible approach
            emotion_found = None

            # Remove file extension first
            name_without_ext = filename.rsplit('.', 1)[0]

            # Check if any emotion class appears in filename
            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                loaded_count += 1
            else:
                skipped_count += 1
                if skipped_count <= 3:
                    print(f"  Skipped (no emotion found): {filename}")

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        if skipped_count > 0:
            print(f"  Skipped {skipped_count} files (no recognizable emotion)")

        if len(self.images) == 0:
            print(f"ERROR: No samples loaded! Check filename format and emotion labels.")
            print(f"Expected emotions in filenames: {CASME2_CLASSES}")

        self._print_distribution()

        if self.use_ram_cache and len(self.images) > 0:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print class distribution"""
        if len(self.labels) == 0:
            print("  No samples to display distribution")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def _preload_to_ram(self):
        """RAM preloading with parallel loading for training efficiency"""
        if len(self.images) == 0:
            print(f"Skipping RAM preload: No images to load")
            return

        print(f"Preloading {len(self.images)} {self.split} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            """Load single image with error handling"""
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('RGB', (224, 224), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc=f"Loading {self.split} to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * 224 * 224 * 3 * 4 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (224, 224), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

# Enhanced metrics calculation with comprehensive error handling
def calculate_metrics_safe_robust(predictions, labels, class_names, average='macro'):
    """
    Calculate metrics with enhanced error handling and validation

    Args:
        predictions: Predicted labels
        labels: True labels
        class_names: List of class names
        average: Averaging method for metrics

    Returns:
        dict: Computed metrics
    """
    try:
        # Ensure arrays are numpy arrays
        predictions = np.array(predictions)
        labels = np.array(labels)

        # Validate input
        if len(predictions) == 0 or len(labels) == 0:
            return {
                'accuracy': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'f1_score': 0.0
            }

        # Calculate accuracy
        accuracy = accuracy_score(labels, predictions)

        # Calculate precision, recall, F1 with zero_division handling
        precision, recall, f1, support = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }

    except Exception as e:
        print(f"Warning: Metrics calculation failed: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

# Enhanced checkpoint saving with atomic operations
def save_checkpoint_robust(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                          checkpoint_root, best_metrics, config):
    """
    Save checkpoint with enhanced robustness and atomic operations

    Args:
        model: Model to save
        optimizer: Optimizer state
        scheduler: Scheduler state
        epoch: Current epoch
        train_metrics: Training metrics
        val_metrics: Validation metrics
        checkpoint_root: Root directory for checkpoints
        best_metrics: Best metrics tracker
        config: Configuration dict

    Returns:
        str: Path to saved checkpoint or None if failed
    """
    try:
        os.makedirs(checkpoint_root, exist_ok=True)

        checkpoint_data = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'train_metrics': train_metrics,
            'val_metrics': val_metrics,
            'best_metrics': best_metrics,
            'config': config
        }

        # Atomic save using temporary file
        final_path = os.path.join(checkpoint_root, 'casme2_vit_mfs_best_f1.pth')

        with tempfile.NamedTemporaryFile(mode='wb', delete=False, dir=checkpoint_root) as tmp_file:
            tmp_path = tmp_file.name
            torch.save(checkpoint_data, tmp_file)

        # Atomic rename
        shutil.move(tmp_path, final_path)

        return final_path

    except Exception as e:
        print(f"ERROR: Checkpoint save failed: {e}")
        return None

# JSON serialization helper
def safe_json_serialize(obj):
    """Safely convert objects to JSON-serializable format"""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, dict):
        return {key: safe_json_serialize(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [safe_json_serialize(item) for item in obj]
    elif isinstance(obj, torch.Tensor):
        return obj.detach().cpu().numpy().tolist()
    else:
        return obj

# Enhanced training epoch function
def train_epoch(model, train_loader, criterion, optimizer, device, epoch, total_epochs):
    """Training epoch with enhanced progress tracking"""
    model.train()

    running_loss = 0.0
    all_predictions = []
    all_labels = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{total_epochs} [Train]")

    for batch_idx, (images, labels, filenames) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()

        if CASME2_VIT_CONFIG['gradient_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), CASME2_VIT_CONFIG['gradient_clip'])

        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)

        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'avg_loss': f"{running_loss/(batch_idx+1):.4f}"
        })

    epoch_loss = running_loss / len(train_loader)
    metrics = calculate_metrics_safe_robust(all_predictions, all_labels, CASME2_CLASSES)

    return epoch_loss, metrics, None

# Enhanced validation epoch function
def validate_epoch(model, val_loader, criterion, device, epoch, total_epochs):
    """Validation epoch with enhanced metrics"""
    model.eval()

    running_loss = 0.0
    all_predictions = []
    all_labels = []
    all_filenames = []

    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{total_epochs} [Val]")

        for batch_idx, (images, labels, filenames) in enumerate(pbar):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_filenames.extend(filenames)

            pbar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'avg_loss': f"{running_loss/(batch_idx+1):.4f}"
            })

    epoch_loss = running_loss / len(val_loader)
    metrics = calculate_metrics_safe_robust(all_predictions, all_labels, CASME2_CLASSES)

    return epoch_loss, metrics, all_filenames

# Create datasets with RAM caching
print("\nCreating CASME II datasets with RAM caching optimization...")

train_dataset = CASME2DatasetTraining(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='train',
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    use_ram_cache=True
)

val_dataset = CASME2DatasetTraining(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='val',
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    use_ram_cache=True
)

print(f"\nDataset loading complete:")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=True,
    num_workers=GLOBAL_CONFIG_CASME2['num_workers'],
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=False,
    num_workers=GLOBAL_CONFIG_CASME2['num_workers'],
    pin_memory=True
)

print(f"Data loaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

# Initialize model
print("\nInitializing ViT CASME II model...")
model = ViTCASME2Baseline(
    num_classes=CASME2_VIT_CONFIG['num_classes'],
    dropout_rate=CASME2_VIT_CONFIG['dropout_rate']
).to(GLOBAL_CONFIG_CASME2['device'])

print(f"Model initialized on {GLOBAL_CONFIG_CASME2['device']}")

# Create optimizer and scheduler
optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_VIT_CONFIG
)

# Create loss criterion
criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
    weights=GLOBAL_CONFIG_CASME2['class_weights'],
    use_focal_loss=CASME2_VIT_CONFIG['use_focal_loss'],
    alpha_weights=CASME2_VIT_CONFIG['focal_loss_alpha_weights'],
    gamma=CASME2_VIT_CONFIG['focal_loss_gamma']
)

print("\nTraining components initialized successfully")

# Training loop with enhanced tracking
print("\n" + "=" * 70)
print("Starting CASME II Multi-Frame Sequence ViT training...")
print("=" * 70)

training_history = {
    'train_loss': [],
    'val_loss': [],
    'train_f1': [],
    'val_f1': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rate': [],
    'epoch_time': []
}

best_metrics = {
    'f1': 0.0,
    'loss': float('inf'),
    'accuracy': 0.0,
    'epoch': 0
}

start_time = time.time()

for epoch in range(CASME2_VIT_CONFIG['num_epochs']):
    epoch_start_time = time.time()

    print(f"\n{'='*70}")
    print(f"Epoch {epoch+1}/{CASME2_VIT_CONFIG['num_epochs']}")
    print(f"{'='*70}")

    # Training phase
    train_loss, train_metrics, train_filenames = train_epoch(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_VIT_CONFIG['num_epochs']
    )

    # Validation phase
    val_loss, val_metrics, val_filenames = validate_epoch(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_VIT_CONFIG['num_epochs']
    )

    # Update scheduler
    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    # Record training history
    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    # Print epoch summary
    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    # Enhanced multi-criteria checkpoint saving logic
    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        best_metrics['f1'] = val_metrics['f1_score']
        best_metrics['loss'] = val_loss
        best_metrics['accuracy'] = val_metrics['accuracy']
        best_metrics['epoch'] = epoch + 1

        best_model_path = save_checkpoint_robust(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            best_metrics, CASME2_VIT_CONFIG
        )

        if best_model_path:
            print(f"New best model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print(f"Warning: Failed to save checkpoint despite improvement")

    # Progress tracking
    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_VIT_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_VIT_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min")

# Training completion
total_time = time.time() - start_time
actual_epochs = CASME2_VIT_CONFIG['num_epochs']

print("\n" + "=" * 70)
print("CASME II MULTI-FRAME SEQUENCE ViT TRAINING COMPLETED")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print(f"Final train F1: {training_history['train_f1'][-1]:.4f}")
print(f"Final validation F1: {training_history['val_f1'][-1]:.4f}")

# Enhanced training documentation export
results_dir = GLOBAL_CONFIG_CASME2['results_root']
os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

training_history_path = f"{results_dir}/training_logs/casme2_vit_mfs_training_history.json"

print("\nExporting enhanced training documentation...")

try:
    training_summary = {
        'experiment_type': 'CASME2_ViT_MultiFrameSequence',
        'experiment_configuration': {
            'dataset_version': CASME2_VIT_CONFIG['dataset_version'],
            'frame_strategy': CASME2_VIT_CONFIG['frame_strategy'],
            'training_approach': CASME2_VIT_CONFIG['training_approach'],
            'inference_strategy': CASME2_VIT_CONFIG['inference_strategy'],
            'loss_function': 'Optimized Focal Loss' if CASME2_VIT_CONFIG['use_focal_loss'] else 'CrossEntropy',
            'weight_approach': 'Per-class Alpha (sum=1.0)' if CASME2_VIT_CONFIG['use_focal_loss'] else 'Inverse Sqrt Frequency',
            'focal_loss_gamma': CASME2_VIT_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_VIT_CONFIG['focal_loss_alpha_weights'],
            'crossentropy_class_weights': CASME2_VIT_CONFIG['crossentropy_class_weights'],
            'vit_model': CASME2_VIT_CONFIG['vit_model'],
            'model_variant': CASME2_VIT_CONFIG['model_variant'],
            'patch_size': CASME2_VIT_CONFIG['patch_size']
        },
        'training_history': safe_json_serialize(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'config': safe_json_serialize(CASME2_VIT_CONFIG),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        'model_checkpoint': 'casme2_vit_mfs_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            'version': CASME2_VIT_CONFIG['dataset_version'],
            'frame_strategy': CASME2_VIT_CONFIG['frame_strategy'],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES
        },
        'architecture_info': {
            'model_type': 'ViTCASME2Baseline',
            'backbone': CASME2_VIT_CONFIG['vit_model'],
            'variant': CASME2_VIT_CONFIG['model_variant'],
            'patch_size': CASME2_VIT_CONFIG['patch_size'],
            'input_size': f"{CASME2_VIT_CONFIG['input_size']}x{CASME2_VIT_CONFIG['input_size']}",
            'expected_feature_dim': CASME2_VIT_CONFIG['expected_feature_dim'],
            'classification_head': f"{CASME2_VIT_CONFIG['expected_feature_dim']}->512->128->7",
            'position_interpolation': CASME2_VIT_CONFIG['interpolate_pos_encoding']
        },
        'enhanced_features': {
            'hardened_checkpoint_system': True,
            'atomic_checkpoint_save': True,
            'checkpoint_validation': True,
            'model_output_validation': True,
            'enhanced_error_handling': True,
            'multi_criteria_checkpoint_logic': True,
            'memory_optimized_training': True,
            'ram_caching': True,
            'position_encoding_interpolation': True
        }
    }

    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Enhanced training documentation saved successfully: {training_history_path}")
    print(f"Experiment details: {training_summary['experiment_configuration']['loss_function']} loss")
    if CASME2_VIT_CONFIG['use_focal_loss']:
        print(f"  Gamma: {CASME2_VIT_CONFIG['focal_loss_gamma']}, Alpha Sum: {sum(CASME2_VIT_CONFIG['focal_loss_alpha_weights']):.3f}")
    print(f"Model variant: {CASME2_VIT_CONFIG['model_variant'].upper()}")
    print(f"Patch size: {CASME2_VIT_CONFIG['patch_size']}px")
    print(f"Dataset version: {CASME2_VIT_CONFIG['dataset_version']}")

except Exception as e:
    print(f"Warning: Could not save training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

# Enhanced memory cleanup
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\nNext: Cell 3 - CASME II Multi-Frame Sequence ViT Evaluation")
print("Enhanced training pipeline completed successfully!")

CASME II Multi-Frame Sequence ViT Training Pipeline
Loss Function: Optimized Focal Loss
Focal Loss Parameters:
  Gamma: 2.0
  Per-class Alpha: [0.048, 0.06, 0.085, 0.093, 0.095, 0.191, 0.427]
  Alpha Sum: 0.999
Dataset version: v9
Frame strategy: multi_frame_sequence
Training approach: frame_level_independent
ViT variant: PATCH16
Patch size: 16px
Training epochs: 50
Batch size: 16
Scheduler patience: 3

Creating CASME II datasets with RAM caching optimization...
Loading CASME II train dataset for training...
Found 2613 image files in directory
Sample filename: sub13_EP01_01_apex_p+1_others.jpg
Loaded 2613 CASME II train samples
  others: 1027 samples (39.3%)
  disgust: 650 samples (24.9%)
  happiness: 325 samples (12.4%)
  repression: 273 samples (10.4%)
  surprise: 260 samples (10.0%)
  sadness: 65 samples (2.5%)
  fear: 13 samples (0.5%)
Preloading 2613 train images to RAM with 32 workers...


Loading train to RAM: 100%|██████████| 2613/2613 [00:54<00:00, 48.20it/s] 


TRAIN RAM caching completed: 2613/2613 images, ~1.57GB
Loading CASME II val dataset for training...
Found 78 image files in directory
Sample filename: sub01_EP03_02_onset_others.jpg
Loaded 78 CASME II val samples
  others: 30 samples (38.5%)
  disgust: 18 samples (23.1%)
  happiness: 9 samples (11.5%)
  repression: 9 samples (11.5%)
  surprise: 6 samples (7.7%)
  sadness: 3 samples (3.8%)
  fear: 3 samples (3.8%)
Preloading 78 val images to RAM with 32 workers...


Loading val to RAM: 100%|██████████| 78/78 [00:03<00:00, 25.06it/s]


VAL RAM caching completed: 78/78 images, ~0.05GB

Dataset loading complete:
  Train samples: 2613
  Validation samples: 78
Data loaders created:
  Train batches: 164
  Validation batches: 5

Initializing ViT CASME II model...


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ViT feature dimension: 768
Classification head: 768 -> 512 -> 128 -> 7
Dropout rate: 0.3 (balanced for large dataset)
Model initialized on cuda
Scheduler: ReduceLROnPlateau monitoring validation F1
Using Optimized Focal Loss with gamma=2.0
Per-class alpha weights: [0.048, 0.06, 0.085, 0.093, 0.095, 0.191, 0.427]
Alpha sum: 0.999

Training components initialized successfully

Starting CASME II Multi-Frame Sequence ViT training...

Epoch 1/50


Epoch 1/50 [Train]: 100%|██████████| 164/164 [01:59<00:00,  1.37it/s, loss=0.0611, avg_loss=0.0646]
Epoch 1/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  2.96it/s, loss=0.0712, avg_loss=0.1407]


Train - Loss: 0.0646, F1: 0.3965, Acc: 0.4868
Val   - Loss: 0.1407, F1: 0.1669, Acc: 0.3077
Time  - Epoch: 121.1s, LR: 2.00e-05
New best model: Higher F1 - F1: 0.1669
Progress: 2.0% | Best F1: 0.1669 | ETA: 103.0min

Epoch 2/50


Epoch 2/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0119, avg_loss=0.0279]
Epoch 2/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.07it/s, loss=0.0918, avg_loss=0.1680]


Train - Loss: 0.0279, F1: 0.7178, Acc: 0.7631
Val   - Loss: 0.1680, F1: 0.1313, Acc: 0.2949
Time  - Epoch: 123.7s, LR: 2.00e-05
Progress: 4.0% | Best F1: 0.1669 | ETA: 99.9min

Epoch 3/50


Epoch 3/50 [Train]: 100%|██████████| 164/164 [02:01<00:00,  1.34it/s, loss=0.0075, avg_loss=0.0116]
Epoch 3/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1247, avg_loss=0.1896]


Train - Loss: 0.0116, F1: 0.9197, Acc: 0.9116
Val   - Loss: 0.1896, F1: 0.1722, Acc: 0.3333
Time  - Epoch: 123.6s, LR: 2.00e-05
New best model: Higher F1 - F1: 0.1722
Progress: 6.0% | Best F1: 0.1722 | ETA: 98.1min

Epoch 4/50


Epoch 4/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0020, avg_loss=0.0049]
Epoch 4/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.12it/s, loss=0.1197, avg_loss=0.1956]


Train - Loss: 0.0049, F1: 0.9788, Acc: 0.9728
Val   - Loss: 0.1956, F1: 0.1564, Acc: 0.3205
Time  - Epoch: 123.7s, LR: 2.00e-05
Progress: 8.0% | Best F1: 0.1722 | ETA: 95.7min

Epoch 5/50


Epoch 5/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0008, avg_loss=0.0022]
Epoch 5/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.10it/s, loss=0.1318, avg_loss=0.2062]


Train - Loss: 0.0022, F1: 0.9925, Acc: 0.9908
Val   - Loss: 0.2062, F1: 0.1803, Acc: 0.3462
Time  - Epoch: 123.8s, LR: 2.00e-05
New best model: Higher F1 - F1: 0.1803
Progress: 10.0% | Best F1: 0.1803 | ETA: 93.9min

Epoch 6/50


Epoch 6/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0008, avg_loss=0.0022]
Epoch 6/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.0990, avg_loss=0.1991]


Train - Loss: 0.0022, F1: 0.9748, Acc: 0.9881
Val   - Loss: 0.1991, F1: 0.1728, Acc: 0.3462
Time  - Epoch: 123.8s, LR: 2.00e-05
Progress: 12.0% | Best F1: 0.1803 | ETA: 91.6min

Epoch 7/50


Epoch 7/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0012, avg_loss=0.0018]
Epoch 7/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.18it/s, loss=0.1282, avg_loss=0.2187]


Train - Loss: 0.0018, F1: 0.9922, Acc: 0.9908
Val   - Loss: 0.2187, F1: 0.1218, Acc: 0.3205
Time  - Epoch: 123.8s, LR: 2.00e-05
Progress: 14.0% | Best F1: 0.1803 | ETA: 89.4min

Epoch 8/50


Epoch 8/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0194, avg_loss=0.0016]
Epoch 8/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.13it/s, loss=0.1357, avg_loss=0.2187]


Train - Loss: 0.0016, F1: 0.9931, Acc: 0.9916
Val   - Loss: 0.2187, F1: 0.1453, Acc: 0.3846
Time  - Epoch: 123.8s, LR: 2.00e-05
Progress: 16.0% | Best F1: 0.1803 | ETA: 87.2min

Epoch 9/50


Epoch 9/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0018, avg_loss=0.0036]
Epoch 9/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.16it/s, loss=0.1371, avg_loss=0.2124]


Train - Loss: 0.0036, F1: 0.9755, Acc: 0.9721
Val   - Loss: 0.2124, F1: 0.1745, Acc: 0.3462
Time  - Epoch: 123.8s, LR: 1.00e-05
Progress: 18.0% | Best F1: 0.1803 | ETA: 85.1min

Epoch 10/50


Epoch 10/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0003, avg_loss=0.0009]
Epoch 10/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.09it/s, loss=0.1161, avg_loss=0.2190]


Train - Loss: 0.0009, F1: 0.9981, Acc: 0.9969
Val   - Loss: 0.2190, F1: 0.1809, Acc: 0.3333
Time  - Epoch: 123.8s, LR: 1.00e-05
New best model: Higher F1 - F1: 0.1809
Progress: 20.0% | Best F1: 0.1809 | ETA: 83.1min

Epoch 11/50


Epoch 11/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0006]
Epoch 11/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.1200, avg_loss=0.2138]


Train - Loss: 0.0006, F1: 0.9995, Acc: 0.9992
Val   - Loss: 0.2138, F1: 0.2273, Acc: 0.3974
Time  - Epoch: 123.7s, LR: 1.00e-05
New best model: Higher F1 - F1: 0.2273
Progress: 22.0% | Best F1: 0.2273 | ETA: 81.3min

Epoch 12/50


Epoch 12/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0004, avg_loss=0.0004]
Epoch 12/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.1242, avg_loss=0.2182]


Train - Loss: 0.0004, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2182, F1: 0.2210, Acc: 0.3846
Time  - Epoch: 123.9s, LR: 1.00e-05
Progress: 24.0% | Best F1: 0.2273 | ETA: 79.1min

Epoch 13/50


Epoch 13/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0003, avg_loss=0.0004]
Epoch 13/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1224, avg_loss=0.2197]


Train - Loss: 0.0004, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2197, F1: 0.1913, Acc: 0.3718
Time  - Epoch: 123.8s, LR: 1.00e-05
Progress: 26.0% | Best F1: 0.2273 | ETA: 77.0min

Epoch 14/50


Epoch 14/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0005, avg_loss=0.0004]
Epoch 14/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.1263, avg_loss=0.2219]


Train - Loss: 0.0004, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2219, F1: 0.1851, Acc: 0.3590
Time  - Epoch: 123.6s, LR: 1.00e-05
Progress: 28.0% | Best F1: 0.2273 | ETA: 74.9min

Epoch 15/50


Epoch 15/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0003]
Epoch 15/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.09it/s, loss=0.1369, avg_loss=0.2270]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2270, F1: 0.1660, Acc: 0.3462
Time  - Epoch: 123.7s, LR: 5.00e-06
Progress: 30.0% | Best F1: 0.2273 | ETA: 72.8min

Epoch 16/50


Epoch 16/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0003]
Epoch 16/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1357, avg_loss=0.2269]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2269, F1: 0.1800, Acc: 0.3462
Time  - Epoch: 123.8s, LR: 5.00e-06
Progress: 32.0% | Best F1: 0.2273 | ETA: 70.6min

Epoch 17/50


Epoch 17/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0003]
Epoch 17/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.10it/s, loss=0.1362, avg_loss=0.2276]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2276, F1: 0.1800, Acc: 0.3462
Time  - Epoch: 123.8s, LR: 5.00e-06
Progress: 34.0% | Best F1: 0.2273 | ETA: 68.5min

Epoch 18/50


Epoch 18/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0003]
Epoch 18/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1352, avg_loss=0.2280]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2280, F1: 0.1800, Acc: 0.3462
Time  - Epoch: 123.7s, LR: 5.00e-06
Progress: 36.0% | Best F1: 0.2273 | ETA: 66.4min

Epoch 19/50


Epoch 19/50 [Train]: 100%|██████████| 164/164 [02:01<00:00,  1.34it/s, loss=0.0000, avg_loss=0.0003]
Epoch 19/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1358, avg_loss=0.2289]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2289, F1: 0.1862, Acc: 0.3590
Time  - Epoch: 123.6s, LR: 2.50e-06
Progress: 38.0% | Best F1: 0.2273 | ETA: 64.3min

Epoch 20/50


Epoch 20/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0003]
Epoch 20/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.15it/s, loss=0.1386, avg_loss=0.2300]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2300, F1: 0.1871, Acc: 0.3590
Time  - Epoch: 123.8s, LR: 2.50e-06
Progress: 40.0% | Best F1: 0.2273 | ETA: 62.2min

Epoch 21/50


Epoch 21/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0003]
Epoch 21/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.12it/s, loss=0.1397, avg_loss=0.2307]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2307, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 2.50e-06
Progress: 42.0% | Best F1: 0.2273 | ETA: 60.1min

Epoch 22/50


Epoch 22/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0003]
Epoch 22/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.09it/s, loss=0.1357, avg_loss=0.2296]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2296, F1: 0.1983, Acc: 0.3846
Time  - Epoch: 123.8s, LR: 2.50e-06
Progress: 44.0% | Best F1: 0.2273 | ETA: 58.1min

Epoch 23/50


Epoch 23/50 [Train]: 100%|██████████| 164/164 [02:01<00:00,  1.35it/s, loss=0.0007, avg_loss=0.0003]
Epoch 23/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1415, avg_loss=0.2313]


Train - Loss: 0.0003, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2313, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.5s, LR: 1.25e-06
Progress: 46.0% | Best F1: 0.2273 | ETA: 56.0min

Epoch 24/50


Epoch 24/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 24/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1428, avg_loss=0.2318]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2318, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.8s, LR: 1.25e-06
Progress: 48.0% | Best F1: 0.2273 | ETA: 53.9min

Epoch 25/50


Epoch 25/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 25/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1443, avg_loss=0.2327]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2327, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 1.25e-06
Progress: 50.0% | Best F1: 0.2273 | ETA: 51.8min

Epoch 26/50


Epoch 26/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 26/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.1434, avg_loss=0.2331]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2331, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.8s, LR: 1.25e-06
Progress: 52.0% | Best F1: 0.2273 | ETA: 49.7min

Epoch 27/50


Epoch 27/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 27/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.18it/s, loss=0.1432, avg_loss=0.2335]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2335, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.8s, LR: 6.25e-07
Progress: 54.0% | Best F1: 0.2273 | ETA: 47.6min

Epoch 28/50


Epoch 28/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 28/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.16it/s, loss=0.1416, avg_loss=0.2331]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2331, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 6.25e-07
Progress: 56.0% | Best F1: 0.2273 | ETA: 45.6min

Epoch 29/50


Epoch 29/50 [Train]: 100%|██████████| 164/164 [02:01<00:00,  1.34it/s, loss=0.0000, avg_loss=0.0002]
Epoch 29/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.10it/s, loss=0.1391, avg_loss=0.2322]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2322, F1: 0.1990, Acc: 0.3846
Time  - Epoch: 123.6s, LR: 6.25e-07
Progress: 58.0% | Best F1: 0.2273 | ETA: 43.5min

Epoch 30/50


Epoch 30/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 30/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.13it/s, loss=0.1398, avg_loss=0.2325]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2325, F1: 0.1939, Acc: 0.3718
Time  - Epoch: 123.8s, LR: 6.25e-07
Progress: 60.0% | Best F1: 0.2273 | ETA: 41.4min

Epoch 31/50


Epoch 31/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0000, avg_loss=0.0002]
Epoch 31/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.15it/s, loss=0.1424, avg_loss=0.2334]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2334, F1: 0.1939, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 3.13e-07
Progress: 62.0% | Best F1: 0.2273 | ETA: 39.3min

Epoch 32/50


Epoch 32/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 32/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.15it/s, loss=0.1434, avg_loss=0.2337]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2337, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 3.13e-07
Progress: 64.0% | Best F1: 0.2273 | ETA: 37.3min

Epoch 33/50


Epoch 33/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 33/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.1446, avg_loss=0.2342]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2342, F1: 0.1933, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 3.13e-07
Progress: 66.0% | Best F1: 0.2273 | ETA: 35.2min

Epoch 34/50


Epoch 34/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 34/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1455, avg_loss=0.2346]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2346, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.7s, LR: 3.13e-07
Progress: 68.0% | Best F1: 0.2273 | ETA: 33.1min

Epoch 35/50


Epoch 35/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 35/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.06it/s, loss=0.1457, avg_loss=0.2347]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2347, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.6s, LR: 1.56e-07
Progress: 70.0% | Best F1: 0.2273 | ETA: 31.0min

Epoch 36/50


Epoch 36/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 36/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.08it/s, loss=0.1459, avg_loss=0.2349]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2349, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.8s, LR: 1.56e-07
Progress: 72.0% | Best F1: 0.2273 | ETA: 29.0min

Epoch 37/50


Epoch 37/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0000, avg_loss=0.0002]
Epoch 37/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.18it/s, loss=0.1462, avg_loss=0.2350]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2350, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.7s, LR: 1.56e-07
Progress: 74.0% | Best F1: 0.2273 | ETA: 26.9min

Epoch 38/50


Epoch 38/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 38/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.08it/s, loss=0.1471, avg_loss=0.2354]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2354, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.9s, LR: 1.56e-07
Progress: 76.0% | Best F1: 0.2273 | ETA: 24.8min

Epoch 39/50


Epoch 39/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 39/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.13it/s, loss=0.1455, avg_loss=0.2349]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2349, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.8s, LR: 1.00e-07
Progress: 78.0% | Best F1: 0.2273 | ETA: 22.8min

Epoch 40/50


Epoch 40/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 40/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.1458, avg_loss=0.2352]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2352, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 80.0% | Best F1: 0.2273 | ETA: 20.7min

Epoch 41/50


Epoch 41/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0005, avg_loss=0.0002]
Epoch 41/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.10it/s, loss=0.1456, avg_loss=0.2351]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2351, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 82.0% | Best F1: 0.2273 | ETA: 18.6min

Epoch 42/50


Epoch 42/50 [Train]: 100%|██████████| 164/164 [02:01<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 42/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.12it/s, loss=0.1457, avg_loss=0.2352]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2352, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.6s, LR: 1.00e-07
Progress: 84.0% | Best F1: 0.2273 | ETA: 16.5min

Epoch 43/50


Epoch 43/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 43/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.11it/s, loss=0.1453, avg_loss=0.2352]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2352, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 86.0% | Best F1: 0.2273 | ETA: 14.5min

Epoch 44/50


Epoch 44/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 44/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.18it/s, loss=0.1442, avg_loss=0.2348]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2348, F1: 0.1939, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 88.0% | Best F1: 0.2273 | ETA: 12.4min

Epoch 45/50


Epoch 45/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0006, avg_loss=0.0002]
Epoch 45/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.13it/s, loss=0.1441, avg_loss=0.2348]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2348, F1: 0.1939, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 90.0% | Best F1: 0.2273 | ETA: 10.3min

Epoch 46/50


Epoch 46/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 46/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.15it/s, loss=0.1439, avg_loss=0.2348]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2348, F1: 0.1939, Acc: 0.3718
Time  - Epoch: 123.6s, LR: 1.00e-07
Progress: 92.0% | Best F1: 0.2273 | ETA: 8.3min

Epoch 47/50


Epoch 47/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 47/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.08it/s, loss=0.1438, avg_loss=0.2349]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2349, F1: 0.1939, Acc: 0.3718
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 94.0% | Best F1: 0.2273 | ETA: 6.2min

Epoch 48/50


Epoch 48/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0001, avg_loss=0.0002]
Epoch 48/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.14it/s, loss=0.1432, avg_loss=0.2348]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2348, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 96.0% | Best F1: 0.2273 | ETA: 4.1min

Epoch 49/50


Epoch 49/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 49/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.10it/s, loss=0.1431, avg_loss=0.2349]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2349, F1: 0.1939, Acc: 0.3718
Time  - Epoch: 123.8s, LR: 1.00e-07
Progress: 98.0% | Best F1: 0.2273 | ETA: 2.1min

Epoch 50/50


Epoch 50/50 [Train]: 100%|██████████| 164/164 [02:02<00:00,  1.34it/s, loss=0.0002, avg_loss=0.0002]
Epoch 50/50 [Val]: 100%|██████████| 5/5 [00:01<00:00,  3.13it/s, loss=0.1451, avg_loss=0.2354]

Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2354, F1: 0.1722, Acc: 0.3590
Time  - Epoch: 123.7s, LR: 1.00e-07
Progress: 100.0% | Best F1: 0.2273 | ETA: 0.0min

CASME II MULTI-FRAME SEQUENCE ViT TRAINING COMPLETED
Training time: 103.3 minutes
Epochs completed: 50
Best validation F1: 0.2273 (epoch 11)
Final train F1: 1.0000
Final validation F1: 0.1722

Exporting enhanced training documentation...
Enhanced training documentation saved successfully: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/07_01_vit_casme2_mfs_prep/training_logs/casme2_vit_mfs_training_history.json
Experiment details: Optimized Focal Loss loss
  Gamma: 2.0, Alpha Sum: 0.999
Model variant: PATCH16
Patch size: 16px
Dataset version: v9

Next: Cell 3 - CASME II Multi-Frame Sequence ViT Evaluation
Enhanced training pipeline completed successfully!





In [3]:
# @title Cell 3: CASME II ViT Evaluation with Dual Dataset Support

# File: 07_01_ViT_CASME2_MFS_Cell3.py
# Location: experiments/07_01_ViT_CASME2-MFS-PREP.ipynb
# Purpose: Comprehensive evaluation framework with support for AF (v7) and KFS (v8) test datasets

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import warnings
warnings.filterwarnings('ignore')

# =====================================================
# DUAL DATASET EVALUATION CONFIGURATION
# =====================================================
# Configure which test datasets to evaluate:
# 'v7' = Apex Frame preprocessing (28 samples, frame-level evaluation)
# 'v8' = Key Frame Sequence preprocessing (84 frames -> 28 videos with late fusion)

EVALUATE_DATASETS = ['v7', 'v8']  # Can be ['v7'], ['v8'], or ['v7', 'v8']

print("CASME II ViT Evaluation Framework with Dual Dataset Support")
print("=" * 60)
print(f"Datasets to evaluate: {EVALUATE_DATASETS}")
print("=" * 60)

# =====================================================
# DATASET CONFIGURATION FUNCTION
# =====================================================

def get_test_dataset_config(version, project_root):
    """
    Get test dataset configuration based on version

    Args:
        version: 'v7' (AF) or 'v8' (KFS)
        project_root: Project root path

    Returns:
        dict: Configuration for selected test dataset
    """
    if version == 'v7':
        config = {
            'version': 'v7',
            'variant': 'AF',
            'dataset_path': f"{project_root}/datasets/processed_casme2/preprocessed_v7",
            'preprocessing_summary': 'preprocessing_summary.json',
            'description': 'Apex Frame with Face-Aware Preprocessing',
            'expected_samples': 28,
            'frame_strategy': 'apex_frame',
            'evaluation_mode': 'frame_level',
            'aggregation': None
        }
    elif version == 'v8':
        config = {
            'version': 'v8',
            'variant': 'KFS',
            'dataset_path': f"{project_root}/datasets/processed_casme2/preprocessed_v8",
            'preprocessing_summary': 'preprocessing_summary.json',
            'description': 'Key Frame Sequence with Face-Aware Preprocessing',
            'expected_frames': 84,
            'expected_videos': 28,
            'frame_strategy': 'key_frame_sequence',
            'frame_types': ['onset', 'apex', 'offset'],
            'evaluation_mode': 'video_level',
            'aggregation': 'late_fusion'
        }
    else:
        raise ValueError(f"Invalid version: {version}. Must be 'v7' or 'v8'")

    return config

# =====================================================
# VIDEO ID EXTRACTION FOR KFS LATE FUSION
# =====================================================

def extract_video_id_from_filename(filename):
    """
    Extract video ID from KFS filename by removing frame type suffix

    Expected format: sub01_EP02_01f_happiness_onset.jpg
    Video ID: sub01_EP02_01f_happiness

    Args:
        filename: Image filename with frame type

    Returns:
        str: Video ID without frame type
    """
    # Remove file extension
    name_without_ext = filename.rsplit('.', 1)[0]

    # Remove frame type suffix (onset, apex, offset)
    for frame_type in ['onset', 'apex', 'offset']:
        if name_without_ext.endswith(f'_{frame_type}'):
            video_id = name_without_ext.rsplit(f'_{frame_type}', 1)[0]
            return video_id

    # If no frame type found, return as is
    return name_without_ext

# =====================================================
# ENHANCED TEST DATASET CLASS
# =====================================================

class CASME2DatasetEvaluation(Dataset):
    """Enhanced CASME II test dataset with comprehensive evaluation support"""

    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.filenames = []
        self.emotions = []
        self.video_ids = []
        self.cached_images = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading CASME II {split} dataset for evaluation...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        loaded_count = 0

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                video_id = extract_video_id_from_filename(filename)

                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                self.emotions.append(emotion_found)
                self.video_ids.append(video_id)
                loaded_count += 1

        print(f"Loaded {len(self.images)} CASME II {split} samples for evaluation")
        self._print_evaluation_distribution()

        if self.use_ram_cache and len(self.images) > 0:
            self._preload_to_ram_evaluation()

    def _print_evaluation_distribution(self):
        """Print comprehensive distribution for evaluation analysis"""
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        unique_videos = set(self.video_ids)

        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Unique video IDs: {len(unique_videos)}")

        missing_classes = []
        for i, class_name in enumerate(CASME2_CLASSES):
            if i not in label_counts:
                missing_classes.append(class_name)

        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        """RAM preloading optimized for evaluation"""
        if len(self.images) == 0:
            return

        print(f"Preloading {len(self.images)} test images to RAM...")

        self.cached_images = [None] * len(self.images)

        def load_single_image(idx, img_path):
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
                return idx, image, True
            except:
                return idx, Image.new('RGB', (224, 224), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc="Loading test to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image

        print(f"RAM caching completed: {len(self.cached_images)} test images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (224, 224), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx], self.video_ids[idx]

# =====================================================
# MODEL LOADING FUNCTION
# =====================================================

def load_trained_model_casme2(checkpoint_path, device):
    """Load trained ViT model from checkpoint"""
    print(f"Loading trained model from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location=device)

    model = ViTCASME2Baseline(
        num_classes=7,
        dropout_rate=checkpoint['config']['dropout_rate']
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print(f"Model loaded successfully from epoch {checkpoint['epoch']}")

    training_info = {
        'best_epoch': checkpoint['epoch'],
        'best_val_f1': checkpoint['val_metrics']['f1_score'],
        'best_val_accuracy': checkpoint['val_metrics']['accuracy'],
        'config': checkpoint['config']
    }

    return model, training_info

# =====================================================
# FRAME-LEVEL INFERENCE (AF v7)
# =====================================================

def run_frame_level_inference(model, test_loader, device):
    """Run frame-level inference for AF evaluation"""
    print("Running frame-level inference...")

    model.eval()
    all_predictions = []
    all_labels = []
    all_filenames = []
    all_probabilities = []

    start_time = time.time()

    with torch.no_grad():
        for images, labels, filenames, video_ids in tqdm(test_loader, desc="Frame-level inference"):
            images = images.to(device)

            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)

            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_filenames.extend(filenames)
            all_probabilities.extend(probabilities.cpu().numpy())

    inference_time = time.time() - start_time

    return {
        'predictions': all_predictions,
        'labels': all_labels,
        'filenames': all_filenames,
        'probabilities': all_probabilities,
        'inference_time': inference_time,
        'evaluation_mode': 'frame_level'
    }

# =====================================================
# VIDEO-LEVEL INFERENCE WITH LATE FUSION (KFS v8)
# =====================================================

def run_video_level_inference_late_fusion(model, test_loader, device):
    """Run video-level inference with late fusion for KFS evaluation"""
    print("Running video-level inference with late fusion...")

    model.eval()

    # Collect frame-level predictions
    frame_predictions = []
    frame_labels = []
    frame_filenames = []
    frame_video_ids = []
    frame_probabilities = []

    start_time = time.time()

    with torch.no_grad():
        for images, labels, filenames, video_ids in tqdm(test_loader, desc="Frame inference"):
            images = images.to(device)

            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)

            _, predicted = torch.max(outputs, 1)

            frame_predictions.extend(predicted.cpu().numpy())
            frame_labels.extend(labels.numpy())
            frame_filenames.extend(filenames)
            frame_video_ids.extend(video_ids)
            frame_probabilities.extend(probabilities.cpu().numpy())

    # Aggregate to video-level using late fusion
    print("Aggregating frame predictions to video level...")

    video_data = {}
    for pred, label, filename, video_id, prob in zip(
        frame_predictions, frame_labels, frame_filenames, frame_video_ids, frame_probabilities
    ):
        if video_id not in video_data:
            video_data[video_id] = {
                'predictions': [],
                'probabilities': [],
                'label': label,
                'filenames': []
            }

        video_data[video_id]['predictions'].append(pred)
        video_data[video_id]['probabilities'].append(prob)
        video_data[video_id]['filenames'].append(filename)

    # Late fusion: average probabilities across frames
    video_predictions = []
    video_labels = []
    video_ids_list = []

    for video_id, data in video_data.items():
        avg_prob = np.mean(data['probabilities'], axis=0)
        final_pred = np.argmax(avg_prob)

        video_predictions.append(final_pred)
        video_labels.append(data['label'])
        video_ids_list.append(video_id)

    inference_time = time.time() - start_time

    return {
        'predictions': video_predictions,
        'labels': video_labels,
        'video_ids': video_ids_list,
        'inference_time': inference_time,
        'evaluation_mode': 'video_level',
        'kfs_late_fusion_info': {
            'total_frames': len(frame_predictions),
            'total_videos': len(video_predictions),
            'aggregation_method': 'average_probability'
        }
    }

# =====================================================
# COMPREHENSIVE METRICS CALCULATION
# =====================================================

def calculate_comprehensive_metrics(inference_results):
    """Calculate comprehensive evaluation metrics"""

    predictions = np.array(inference_results['predictions'])
    labels = np.array(inference_results['labels'])

    # Overall metrics
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    # Per-class metrics
    per_class_precision, per_class_recall, per_class_f1, per_class_support = \
        precision_recall_fscore_support(labels, predictions, average=None, zero_division=0)

    # Confusion matrix
    cm = confusion_matrix(labels, predictions)

    # AUC calculation (if probabilities available)
    try:
        if 'probabilities' in inference_results:
            labels_bin = label_binarize(labels, classes=range(len(CASME2_CLASSES)))
            auc_scores = []

            for i in range(len(CASME2_CLASSES)):
                if labels_bin[:, i].sum() > 0:
                    fpr, tpr, _ = roc_curve(labels_bin[:, i],
                                           np.array(inference_results['probabilities'])[:, i])
                    auc_scores.append(auc(fpr, tpr))
                else:
                    auc_scores.append(0.0)

            macro_auc = np.mean([score for score in auc_scores if score > 0])
        else:
            auc_scores = [0.0] * len(CASME2_CLASSES)
            macro_auc = 0.0
    except:
        auc_scores = [0.0] * len(CASME2_CLASSES)
        macro_auc = 0.0

    # Check which classes are in test set
    unique_labels = set(labels)
    available_classes = [CASME2_CLASSES[i] for i in unique_labels]
    missing_classes = [cls for i, cls in enumerate(CASME2_CLASSES) if i not in unique_labels]

    # Build per-class performance dict
    per_class_performance = {}
    for i, class_name in enumerate(CASME2_CLASSES):
        per_class_performance[class_name] = {
            'precision': float(per_class_precision[i]) if i < len(per_class_precision) else 0.0,
            'recall': float(per_class_recall[i]) if i < len(per_class_recall) else 0.0,
            'f1_score': float(per_class_f1[i]) if i < len(per_class_f1) else 0.0,
            'support': int(per_class_support[i]) if i < len(per_class_support) else 0,
            'auc': float(auc_scores[i]) if i < len(auc_scores) else 0.0,
            'in_test_set': i in unique_labels
        }

    # Inference performance
    inference_performance = {
        'total_time_seconds': inference_results['inference_time'],
        'average_time_ms_per_sample': (inference_results['inference_time'] / len(predictions)) * 1000
    }

    results = {
        'evaluation_metadata': {
            'dataset': 'CASME_II',
            'model_type': 'ViTCASME2Baseline',
            'evaluation_timestamp': datetime.now().isoformat(),
            'evaluation_mode': inference_results['evaluation_mode'],
            'test_samples': len(predictions),
            'class_names': CASME2_CLASSES,
            'available_classes': available_classes,
            'missing_classes': missing_classes
        },
        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': float(macro_auc)
        },
        'per_class_performance': per_class_performance,
        'confusion_matrix': cm.tolist(),
        'inference_performance': inference_performance
    }

    # Add KFS late fusion info if available
    if 'kfs_late_fusion_info' in inference_results:
        results['kfs_late_fusion_info'] = inference_results['kfs_late_fusion_info']

    return results

# =====================================================
# WRONG PREDICTIONS ANALYSIS
# =====================================================

def analyze_wrong_predictions(inference_results):
    """Analyze wrong predictions for error patterns"""

    predictions = np.array(inference_results['predictions'])
    labels = np.array(inference_results['labels'])

    if 'filenames' in inference_results:
        filenames = inference_results['filenames']
    elif 'video_ids' in inference_results:
        filenames = inference_results['video_ids']
    else:
        filenames = [f"sample_{i}" for i in range(len(predictions))]

    wrong_predictions = []
    wrong_by_class = defaultdict(int)
    confusion_patterns = defaultdict(int)

    for i, (pred, true_label) in enumerate(zip(predictions, labels)):
        if pred != true_label:
            true_class = CASME2_CLASSES[true_label]
            pred_class = CASME2_CLASSES[pred]

            wrong_predictions.append({
                'filename': filenames[i],
                'true_label': int(true_label),
                'true_class': true_class,
                'predicted_label': int(pred),
                'predicted_class': pred_class
            })

            wrong_by_class[true_class] += 1
            confusion_patterns[f"{true_class}->{pred_class}"] += 1

    error_summary = {}
    for class_name in CASME2_CLASSES:
        class_idx = CLASS_TO_IDX[class_name]
        class_mask = labels == class_idx
        class_total = class_mask.sum()

        if class_total > 0:
            class_wrong = wrong_by_class.get(class_name, 0)
            error_rate = (class_wrong / class_total) * 100
        else:
            class_wrong = 0
            error_rate = 0.0

        error_summary[class_name] = {
            'total_samples': int(class_total),
            'wrong_predictions': int(class_wrong),
            'error_rate_percent': float(error_rate)
        }

    results = {
        'analysis_metadata': {
            'total_samples': len(predictions),
            'total_wrong_predictions': len(wrong_predictions),
            'overall_error_rate': (len(wrong_predictions) / len(predictions) * 100) if len(predictions) > 0 else 0.0
        },
        'wrong_predictions': wrong_predictions,
        'wrong_predictions_by_class': dict(wrong_by_class),
        'error_summary': error_summary,
        'confusion_patterns': dict(confusion_patterns)
    }

    return results

# =====================================================
# SAVE EVALUATION RESULTS
# =====================================================

def save_evaluation_results(evaluation_results, wrong_predictions_results, results_dir, test_version):
    """Save comprehensive evaluation results"""
    os.makedirs(results_dir, exist_ok=True)

    results_file = f"{results_dir}/casme2_vit_evaluation_results_{test_version}.json"
    with open(results_file, 'w') as f:
        json.dump(evaluation_results, f, indent=2, default=str)

    wrong_predictions_file = f"{results_dir}/casme2_vit_wrong_predictions_{test_version}.json"
    with open(wrong_predictions_file, 'w') as f:
        json.dump(wrong_predictions_results, f, indent=2, default=str)

    print(f"Evaluation results saved:")
    print(f"  Main results: {os.path.basename(results_file)}")
    print(f"  Wrong predictions: {os.path.basename(wrong_predictions_file)}")

    return results_file, wrong_predictions_file

# =====================================================
# MAIN EVALUATION EXECUTION
# =====================================================

all_evaluation_results = {}

for dataset_version in EVALUATE_DATASETS:
    print("\n" + "=" * 70)
    print(f"EVALUATING DATASET: {dataset_version.upper()}")
    print("=" * 70)

    try:
        # Get dataset configuration
        test_config = get_test_dataset_config(dataset_version, PROJECT_ROOT)

        print(f"\nTest Dataset Configuration:")
        print(f"  Version: {test_config['version']}")
        print(f"  Variant: {test_config['variant']}")
        print(f"  Description: {test_config['description']}")
        print(f"  Frame strategy: {test_config['frame_strategy']}")
        print(f"  Evaluation mode: {test_config['evaluation_mode']}")
        if 'aggregation' in test_config and test_config['aggregation']:
            print(f"  Aggregation: {test_config['aggregation']}")
        print(f"  Dataset path: {test_config['dataset_path']}")

        # Create test dataset
        print(f"\nCreating CASME II test dataset from {test_config['variant']}...")
        test_dataset = CASME2DatasetEvaluation(
            dataset_root=test_config['dataset_path'],
            split='test',
            transform=GLOBAL_CONFIG_CASME2['transform_val'],
            use_ram_cache=True
        )

        if len(test_dataset) == 0:
            raise ValueError(f"No test samples found for {dataset_version}!")

        test_loader = DataLoader(
            test_dataset,
            batch_size=CASME2_VIT_CONFIG['batch_size'],
            shuffle=False,
            num_workers=CASME2_VIT_CONFIG['num_workers'],
            pin_memory=True
        )

        # Load trained model
        checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/casme2_vit_mfs_best_f1.pth"
        model, training_info = load_trained_model_casme2(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

        # Run inference based on evaluation mode
        if test_config['evaluation_mode'] == 'frame_level':
            print(f"\nRunning frame-level evaluation for {test_config['variant']}...")
            inference_results = run_frame_level_inference(model, test_loader, GLOBAL_CONFIG_CASME2['device'])

        elif test_config['evaluation_mode'] == 'video_level':
            print(f"\nRunning video-level evaluation with late fusion for {test_config['variant']}...")
            inference_results = run_video_level_inference_late_fusion(model, test_loader, GLOBAL_CONFIG_CASME2['device'])

        else:
            raise ValueError(f"Unknown evaluation mode: {test_config['evaluation_mode']}")

        # Calculate comprehensive metrics
        evaluation_results = calculate_comprehensive_metrics(inference_results)

        # Analyze wrong predictions
        wrong_predictions_results = analyze_wrong_predictions(inference_results)

        # Add training information
        evaluation_results['training_information'] = training_info
        evaluation_results['test_configuration'] = test_config

        # Save results
        results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"
        save_evaluation_results(
            evaluation_results, wrong_predictions_results, results_dir, test_config['version']
        )

        # Store for comparison
        all_evaluation_results[dataset_version] = {
            'evaluation': evaluation_results,
            'wrong_predictions': wrong_predictions_results,
            'config': test_config
        }

        # Display results
        print("\n" + "=" * 60)
        print(f"EVALUATION RESULTS - {test_config['variant']} ({dataset_version})")
        print("=" * 60)

        overall = evaluation_results['overall_performance']
        print(f"\nOverall Performance:")
        print(f"  Accuracy:  {overall['accuracy']:.4f}")
        print(f"  Precision: {overall['macro_precision']:.4f}")
        print(f"  Recall:    {overall['macro_recall']:.4f}")
        print(f"  F1 Score:  {overall['macro_f1']:.4f}")
        print(f"  AUC:       {overall['macro_auc']:.4f}")

        if 'kfs_late_fusion_info' in evaluation_results:
            fusion_info = evaluation_results['kfs_late_fusion_info']
            print(f"\nLate Fusion Info:")
            print(f"  Total frames processed: {fusion_info['total_frames']}")
            print(f"  Video-level predictions: {fusion_info['total_videos']}")
            print(f"  Aggregation method: {fusion_info['aggregation_method']}")

        print(f"\nPer-Class Performance:")
        for class_name, metrics in evaluation_results['per_class_performance'].items():
            in_test = "Present" if metrics['in_test_set'] else "Missing"
            print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
                  f"Support={metrics['support']}")

        wrong_meta = wrong_predictions_results['analysis_metadata']
        print(f"\nWrong Predictions Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        print(f"\nInference Performance:")
        print(f"  Total time: {inference_results['inference_time']:.2f}s")
        print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    except Exception as e:
        print(f"Evaluation failed for {dataset_version}: {e}")
        import traceback
        traceback.print_exc()
        continue

# =====================================================
# COMPARATIVE ANALYSIS (if both datasets evaluated)
# =====================================================

if len(all_evaluation_results) == 2 and 'v7' in all_evaluation_results and 'v8' in all_evaluation_results:
    print("\n" + "=" * 70)
    print("COMPARATIVE ANALYSIS: AF (v7) vs KFS (v8)")
    print("=" * 70)

    v7_results = all_evaluation_results['v7']['evaluation']
    v8_results = all_evaluation_results['v8']['evaluation']

    print("\nOverall Performance Comparison:")
    print(f"{'Metric':<20} {'AF (v7)':<15} {'KFS (v8)':<15} {'Difference':<15}")
    print("-" * 65)

    metrics_to_compare = ['accuracy', 'macro_precision', 'macro_recall', 'macro_f1', 'macro_auc']

    for metric in metrics_to_compare:
        v7_val = v7_results['overall_performance'][metric]
        v8_val = v8_results['overall_performance'][metric]
        diff = v8_val - v7_val

        print(f"{metric:<20} {v7_val:<15.4f} {v8_val:<15.4f} {diff:+.4f}")

    print(f"\nEvaluation Modes:")
    print(f"  AF (v7): {v7_results['evaluation_metadata']['evaluation_mode']}")
    print(f"  KFS (v8): {v8_results['evaluation_metadata']['evaluation_mode']}")

    if 'kfs_late_fusion_info' in v8_results:
        print(f"\nKFS Late Fusion Strategy:")
        print(f"  Frames used: {v8_results['kfs_late_fusion_info']['total_frames']}")
        print(f"  Video predictions: {v8_results['kfs_late_fusion_info']['total_videos']}")
        print(f"  Aggregation: {v8_results['kfs_late_fusion_info']['aggregation_method']}")

# Memory cleanup
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\n" + "=" * 70)
print("CASME II ViT EVALUATION COMPLETED")
print("=" * 70)
print(f"Evaluated datasets: {EVALUATE_DATASETS}")
print("Next: Cell 4 - Generate confusion matrices and visualization")

CASME II ViT Evaluation Framework with Dual Dataset Support
Datasets to evaluate: ['v7', 'v8']

EVALUATING DATASET: V7

Test Dataset Configuration:
  Version: v7
  Variant: AF
  Description: Apex Frame with Face-Aware Preprocessing
  Frame strategy: apex_frame
  Evaluation mode: frame_level
  Dataset path: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v7

Creating CASME II test dataset from AF...
Loading CASME II test dataset for evaluation...
Found 28 image files in directory
Loaded 28 CASME II test samples for evaluation
Test set class distribution:
  others: 10 samples (35.7%)
  disgust: 7 samples (25.0%)
  happiness: 4 samples (14.3%)
  repression: 3 samples (10.7%)
  surprise: 3 samples (10.7%)
  sadness: 1 samples (3.6%)
Unique video IDs: 28
Missing classes in test set: ['fear']
Preloading 28 test images to RAM...


Loading test to RAM: 100%|██████████| 28/28 [00:01<00:00, 15.16it/s]


RAM caching completed: 28 test images
Loading trained model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/07_01_vit_casme2_mfs_prep/casme2_vit_mfs_best_f1.pth


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ViT feature dimension: 768
Classification head: 768 -> 512 -> 128 -> 7
Dropout rate: 0.3 (balanced for large dataset)
Model loaded successfully from epoch 11

Running frame-level evaluation for AF...
Running frame-level inference...


Frame-level inference: 100%|██████████| 2/2 [00:00<00:00,  2.33it/s]


Evaluation results saved:
  Main results: casme2_vit_evaluation_results_v7.json
  Wrong predictions: casme2_vit_wrong_predictions_v7.json

EVALUATION RESULTS - AF (v7)

Overall Performance:
  Accuracy:  0.5000
  Precision: 0.4097
  Recall:    0.3329
  F1 Score:  0.3393
  AUC:       0.7100

Per-Class Performance:
  others [Present]: F1=0.5833, Support=10
  disgust [Present]: F1=0.6667, Support=7
  happiness [Present]: F1=0.2857, Support=4
  repression [Present]: F1=0.0000, Support=3
  surprise [Present]: F1=0.5000, Support=3
  sadness [Present]: F1=0.0000, Support=1
  fear [Missing]: F1=0.0000, Support=0

Wrong Predictions Analysis:
  Total errors: 14 / 28
  Error rate: 50.00%

Inference Performance:
  Total time: 0.86s
  Speed: 30.8 ms/sample

EVALUATING DATASET: V8

Test Dataset Configuration:
  Version: v8
  Variant: KFS
  Description: Key Frame Sequence with Face-Aware Preprocessing
  Frame strategy: key_frame_sequence
  Evaluation mode: video_level
  Aggregation: late_fusion
  Data

Loading test to RAM: 100%|██████████| 84/84 [00:02<00:00, 28.29it/s]


RAM caching completed: 84 test images
Loading trained model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/07_01_vit_casme2_mfs_prep/casme2_vit_mfs_best_f1.pth


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ViT feature dimension: 768
Classification head: 768 -> 512 -> 128 -> 7
Dropout rate: 0.3 (balanced for large dataset)
Model loaded successfully from epoch 11

Running video-level evaluation with late fusion for KFS...
Running video-level inference with late fusion...


Frame inference: 100%|██████████| 6/6 [00:01<00:00,  3.43it/s]

Aggregating frame predictions to video level...
Evaluation results saved:
  Main results: casme2_vit_evaluation_results_v8.json
  Wrong predictions: casme2_vit_wrong_predictions_v8.json

EVALUATION RESULTS - KFS (v8)

Overall Performance:
  Accuracy:  0.5000
  Precision: 0.3736
  Recall:    0.3459
  F1 Score:  0.3476
  AUC:       0.0000

Late Fusion Info:
  Total frames processed: 84
  Video-level predictions: 84
  Aggregation method: average_probability

Per-Class Performance:
  others [Present]: F1=0.5714, Support=30
  disgust [Present]: F1=0.6522, Support=21
  happiness [Present]: F1=0.3000, Support=12
  repression [Present]: F1=0.1333, Support=9
  surprise [Present]: F1=0.4286, Support=9
  sadness [Present]: F1=0.0000, Support=3
  fear [Missing]: F1=0.0000, Support=0

Wrong Predictions Analysis:
  Total errors: 42 / 84
  Error rate: 50.00%

Inference Performance:
  Total time: 1.76s
  Speed: 20.9 ms/sample

COMPARATIVE ANALYSIS: AF (v7) vs KFS (v8)

Overall Performance Comparison:





In [4]:
# @title Cell 4: CASME II ViT Confusion Matrix Generation

# File: 07_01_ViT_CASME2_MFS_Cell4.py
# Location: experiments/07_01_ViT_CASME2-MFS-PREP.ipynb
# Purpose: Generate professional confusion matrix visualization for AF and KFS evaluations

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

print("CASME II ViT Confusion Matrix Generation")
print("=" * 60)

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/07_01_vit_casme2_mfs_prep"

def find_evaluation_json_files(results_path):
    """Find evaluation JSON files with version detection"""
    json_files = {}
    eval_dir = f"{results_path}/evaluation_results"

    if os.path.exists(eval_dir):
        for version in ['v7', 'v8']:
            eval_pattern = f"{eval_dir}/casme2_vit_evaluation_results_{version}.json"
            eval_files = glob.glob(eval_pattern)

            if eval_files:
                json_files[f'main_{version}'] = eval_files[0]
                print(f"Found {version.upper()} evaluation file: {os.path.basename(eval_files[0])}")

            wrong_pattern = f"{eval_dir}/casme2_vit_wrong_predictions_{version}.json"
            wrong_files = glob.glob(wrong_pattern)

            if wrong_files:
                json_files[f'wrong_{version}'] = wrong_files[0]
                print(f"Found {version.upper()} wrong predictions: {os.path.basename(wrong_files[0])}")

        if not json_files:
            print(f"WARNING: No evaluation results found in {eval_dir}")
            print("Make sure Cell 3 (evaluation) has been executed first!")
    else:
        print(f"ERROR: Evaluation directory not found: {eval_dir}")

    return json_files

def load_evaluation_results(json_path):
    """Load and parse evaluation results JSON"""
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        print(f"Successfully loaded: {os.path.basename(json_path)}")
        return data
    except Exception as e:
        print(f"ERROR loading {json_path}: {str(e)}")
        return None

def calculate_weighted_f1(per_class_performance):
    """Calculate weighted F1 score"""
    total_support = sum([class_data['support'] for class_data in per_class_performance.values()
                        if class_data['support'] > 0])

    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0

    for class_name, class_data in per_class_performance.items():
        if class_data['support'] > 0:
            weight = class_data['support'] / total_support
            weighted_f1 += class_data['f1_score'] * weight

    return weighted_f1

def calculate_balanced_accuracy(confusion_matrix):
    """Calculate balanced accuracy handling classes with zero support"""
    cm = np.array(confusion_matrix)
    n_classes = cm.shape[0]

    per_class_balanced_acc = []
    classes_with_samples = []

    for i in range(n_classes):
        if cm[i, :].sum() > 0:
            classes_with_samples.append(i)

    for i in classes_with_samples:
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        tn = cm.sum() - tp - fn - fp

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        class_balanced_acc = (sensitivity + specificity) / 2
        per_class_balanced_acc.append(class_balanced_acc)

    balanced_acc = np.mean(per_class_balanced_acc) if per_class_balanced_acc else 0.0

    return balanced_acc

def determine_text_color(color_value, threshold=0.5):
    """Determine optimal text color based on background intensity"""
    return 'white' if color_value > threshold else 'black'

def create_confusion_matrix_plot(data, output_path, test_version):
    """Create professional confusion matrix visualization for CASME II micro-expression recognition"""

    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']

    test_config = data.get('test_configuration', {})
    variant = test_config.get('variant', test_version.upper())
    description = test_config.get('description', f'{test_version} preprocessing')
    eval_mode = meta.get('evaluation_mode', 'frame_level')

    print(f"Processing confusion matrix for {variant} ({test_version})")
    print(f"Dataset: {description}")
    print(f"Evaluation mode: {eval_mode}")
    print(f"Confusion matrix shape: {cm.shape}")

    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1(per_class)
    balanced_acc = calculate_balanced_accuracy(cm)

    print(f"Calculated metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Balanced Acc: {balanced_acc:.4f}, Accuracy: {accuracy:.4f}")

    row_sums = cm.sum(axis=1, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_pct = np.divide(cm, row_sums, where=(row_sums!=0))
        cm_pct = np.nan_to_num(cm_pct)

    fig, ax = plt.subplots(figsize=(12, 10))

    cmap = 'Blues'
    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"

            cell_value = cm_pct[i, j]
            text_color = determine_text_color(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    preprocessing_note = f"Preprocessing: {description}\n"
    preprocessing_note += f"Dataset: {test_version}\n"
    preprocessing_note += f"Evaluation: {eval_mode.replace('_', ' ').title()}"

    if 'kfs_late_fusion_info' in data:
        fusion_info = data['kfs_late_fusion_info']
        preprocessing_note += f"\nFrames: {fusion_info['total_frames']}, Videos: {fusion_info['total_videos']}"

    missing_classes = meta.get('missing_classes', [])
    if missing_classes:
        preprocessing_note += f"\nMissing: {', '.join(missing_classes)}"

    ax.text(0.02, 0.98, preprocessing_note, transform=ax.transAxes, fontsize=8,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    title = f"CASME II {variant} Micro-Expression Recognition - ViT\n"
    title += f"Acc: {accuracy:.4f}  |  Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes
    }

def generate_performance_summary(evaluation_data, wrong_predictions_data=None):
    """Generate comprehensive performance summary"""

    print("\n" + "=" * 60)
    print("CASME II MICRO-EXPRESSION RECOGNITION PERFORMANCE SUMMARY")
    print("=" * 60)

    overall = evaluation_data['overall_performance']
    meta = evaluation_data['evaluation_metadata']
    test_config = evaluation_data.get('test_configuration', {})

    variant = test_config.get('variant', 'N/A')

    print(f"Dataset: {meta['dataset']}")
    print(f"Variant: {variant}")
    print(f"Dataset version: {test_config.get('version', 'N/A')}")
    print(f"Preprocessing: {test_config.get('description', 'N/A')}")
    print(f"Test samples: {meta['test_samples']}")
    print(f"Model: {meta['model_type']}")
    print(f"Evaluation date: {meta['evaluation_timestamp']}")

    if 'kfs_late_fusion_info' in evaluation_data:
        fusion_info = evaluation_data['kfs_late_fusion_info']
        print(f"\nLate Fusion Information:")
        print(f"  Total frames: {fusion_info['total_frames']}")
        print(f"  Video predictions: {fusion_info['total_videos']}")
        print(f"  Aggregation: {fusion_info['aggregation_method']}")

    print(f"\nOverall Performance:")
    print(f"  Accuracy:         {overall['accuracy']:.4f}")
    print(f"  Macro Precision:  {overall['macro_precision']:.4f}")
    print(f"  Macro Recall:     {overall['macro_recall']:.4f}")
    print(f"  Macro F1:         {overall['macro_f1']:.4f}")
    print(f"  Macro AUC:        {overall['macro_auc']:.4f}")

    print(f"\nPer-Class Performance:")
    per_class = evaluation_data['per_class_performance']

    print(f"{'Class':<12} {'F1':<8} {'Precision':<10} {'Recall':<8} {'AUC':<8} {'Support':<8} {'In Test'}")
    print("-" * 65)

    for class_name, metrics in per_class.items():
        in_test = "Yes" if metrics['in_test_set'] else "No"
        print(f"{class_name:<12} {metrics['f1_score']:<8.4f} {metrics['precision']:<10.4f} "
              f"{metrics['recall']:<8.4f} {metrics['auc']:<8.4f} {metrics['support']:<8} {in_test}")

    if 'training_information' in evaluation_data:
        training = evaluation_data['training_information']
        print(f"\nTraining vs Test Comparison:")
        print(f"  Training Val F1:  {training['best_val_f1']:.4f}")
        print(f"  Test F1:          {overall['macro_f1']:.4f}")
        print(f"  Performance Gap:  {training['best_val_f1'] - overall['macro_f1']:+.4f}")
        print(f"  Best Epoch:       {training['best_epoch']}")

    missing_classes = meta.get('missing_classes', [])
    available_classes = meta.get('available_classes', [])

    print(f"\nClass Availability Analysis:")
    print(f"  Available classes: {len(available_classes)}/7")
    print(f"  Missing classes: {missing_classes if missing_classes else 'None'}")

    if wrong_predictions_data:
        wrong_meta = wrong_predictions_data['analysis_metadata']
        print(f"\nError Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']}/{wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        patterns = wrong_predictions_data.get('confusion_patterns', {})
        if patterns:
            print(f"\nTop Confusion Patterns:")
            sorted_patterns = sorted(patterns.items(), key=lambda x: x[1], reverse=True)[:3]
            for pattern, count in sorted_patterns:
                print(f"  {pattern}: {count} cases")

    print(f"\nInference Performance:")
    inference = evaluation_data['inference_performance']
    print(f"  Total time: {inference['total_time_seconds']:.2f}s")
    print(f"  Speed: {inference['average_time_ms_per_sample']:.1f} ms/sample")

json_files = find_evaluation_json_files(RESULTS_ROOT)

if not json_files:
    print(f"ERROR: No evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (evaluation) has been executed first!")
else:
    print(f"\nFound {len([k for k in json_files.keys() if k.startswith('main_')])} evaluation result(s)")

output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

results_summary = {}
generated_files = []

for version in ['v7', 'v8']:
    main_key = f'main_{version}'
    wrong_key = f'wrong_{version}'

    if main_key in json_files:
        print(f"\n{'='*60}")
        print(f"Processing {version.upper()} Evaluation Results")
        print(f"{'='*60}")

        eval_data = load_evaluation_results(json_files[main_key])

        wrong_data = None
        if wrong_key in json_files:
            wrong_data = load_evaluation_results(json_files[wrong_key])

        if eval_data is not None:
            try:
                cm_output_path = os.path.join(output_dir, f"confusion_matrix_CASME2_ViT_{version}.png")
                metrics = create_confusion_matrix_plot(eval_data, cm_output_path, version)
                generated_files.append(cm_output_path)

                results_summary[version] = metrics

                print(f"\nSUCCESS: {version.upper()} confusion matrix generated successfully")
                print(f"Output file: {os.path.basename(cm_output_path)}")

                print(f"\nPerformance Metrics Summary:")
                print(f"  Accuracy:        {metrics['accuracy']:.4f}")
                print(f"  Macro F1:        {metrics['macro_f1']:.4f}")
                print(f"  Weighted F1:     {metrics['weighted_f1']:.4f}")
                print(f"  Balanced Acc:    {metrics['balanced_accuracy']:.4f}")

                if metrics['missing_classes']:
                    print(f"  Missing classes: {metrics['missing_classes']}")

            except Exception as e:
                print(f"ERROR: Failed to generate {version.upper()} confusion matrix: {str(e)}")
                import traceback
                traceback.print_exc()

            generate_performance_summary(eval_data, wrong_data)
        else:
            print(f"ERROR: Could not load {version.upper()} evaluation data")

if generated_files:
    print(f"\n" + "=" * 60)
    print("CASME II CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 60)

    print(f"\nGenerated visualization files:")
    for file_path in generated_files:
        filename = os.path.basename(file_path)
        print(f"  {filename}")

    for version in ['v7', 'v8']:
        if version in results_summary:
            variant = 'AF' if version == 'v7' else 'KFS'
            print(f"\n{variant} ({version.upper()}) Performance Summary:")
            metrics = results_summary[version]
            print(f"  Accuracy:       {metrics['accuracy']:.4f}")
            print(f"  Macro F1:       {metrics['macro_f1']:.4f}")
            print(f"  Weighted F1:    {metrics['weighted_f1']:.4f}")
            print(f"  Balanced Acc:   {metrics['balanced_accuracy']:.4f}")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print(f"\nERROR: No visualizations were generated")
    print("Please check:")
    print("1. Cell 3 evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

print("\nCell 4 completed - CASME II confusion matrix analysis generated")

CASME II ViT Confusion Matrix Generation
Found V7 evaluation file: casme2_vit_evaluation_results_v7.json
Found V7 wrong predictions: casme2_vit_wrong_predictions_v7.json
Found V8 evaluation file: casme2_vit_evaluation_results_v8.json
Found V8 wrong predictions: casme2_vit_wrong_predictions_v8.json

Found 2 evaluation result(s)

Processing V7 Evaluation Results
Successfully loaded: casme2_vit_evaluation_results_v7.json
Successfully loaded: casme2_vit_wrong_predictions_v7.json
Processing confusion matrix for AF (v7)
Dataset: Apex Frame with Face-Aware Preprocessing
Evaluation mode: frame_level
Confusion matrix shape: (6, 6)
Calculated metrics - Macro F1: 0.3393, Weighted F1: 0.4694, Balanced Acc: 0.6085, Accuracy: 0.5000
Confusion matrix saved to: confusion_matrix_CASME2_ViT_v7.png

SUCCESS: V7 confusion matrix generated successfully
Output file: confusion_matrix_CASME2_ViT_v7.png

Performance Metrics Summary:
  Accuracy:        0.5000
  Macro F1:        0.3393
  Weighted F1:     0.4694
