In [None]:
# @title Cell 1: CASME II DeiT Infrastructure Configuration

# File: 02_04_DeiT_Direct_Optimized_Baseline_Cell1.py
# Location: experiments/02_04_DeiT_Direct_Baseline.ipynb
# Purpose: Optimized DeiT for CASME II micro-expression recognition with distillation token mechanism and advanced class weight optimization
# This cell defines the configuration, loss, model, transforms and dataset
# classes that the following training cell relies on.
# NOTE(review): the classifier head defined in this cell reads only the CLS
# token (last_hidden_state[:, 0]); the distillation token is produced by the
# backbone but not used for classification.

# Mount Google Drive (Colab-only: google.colab is not available elsewhere).
# Every PROJECT_ROOT path below lives on the mounted Drive.
from google.colab import drive
print("=" * 60)
print("CASME II DEIT TRANSFORMER OPTIMIZED BASELINE INFRASTRUCTURE")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import timm
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Project paths configuration - updated for DeiT
# Inputs are read from DATASET_ROOT; this experiment's (02_04) checkpoints and
# results are written under CHECKPOINT_ROOT / RESULTS_ROOT.
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/data_split_v1"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/02_04_deit_casme2-af"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/02_04_deit_casme2-af"

# Load CASME II dataset metadata - preserved existing paths
# Despite its name, METADATA_TRAIN holds metadata for every split: later code
# reads its 'train', 'val' and 'test' entries.
METADATA_TRAIN = f"{DATASET_ROOT}/split_metadata.json"
PROCESSING_SUMMARY = f"{DATASET_ROOT}/processing_summary.json"

print("CASME II DeiT Optimized Baseline - Infrastructure Configuration")
print("=" * 70)

# Load dataset metadata
# processing_summary.json must provide the 'dataset', 'total_samples' and
# 'split_strategy' keys printed below.
print("Loading CASME II dataset metadata...")
with open(METADATA_TRAIN, 'r') as f:
    casme2_metadata = json.load(f)

with open(PROCESSING_SUMMARY, 'r') as f:
    processing_info = json.load(f)

print(f"Dataset: {processing_info['dataset']}")
print(f"Total samples: {processing_info['total_samples']}")
print(f"Split strategy: {processing_info['split_strategy']}")

# =====================================================
# ADVANCED EXPERIMENT CONFIGURATION - DeiT Transformer Optimized Parameters
# =====================================================

# FOCAL LOSS CONFIGURATION - Toggle and Advanced Parameters
USE_FOCAL_LOSS = True  # Set True to enable Focal Loss, False for CrossEntropy
FOCAL_LOSS_GAMMA = 2.0  # Focal loss focusing parameter (typically 1.0 - 3.0)

# OPTIMIZED CLASS WEIGHTS CONFIGURATION - Inverse Square Root Frequency Approach
# CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
# Class counts behind these weights: [99, 63, 32, 27, 25, 7, 2]. They sum to
# 255, i.e. the FULL dataset across all splits, not the train split alone
# (the train distribution printed further down is smaller). With a stratified
# split the proportions should be close, but note the counts include val/test.

# CrossEntropy Loss - inverse square root frequency weights:
#   w_c = sqrt(n_max / n_c), rounded to 2 decimals (e.g. sqrt(99 / 2) ~= 7.04)
CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.25, 1.76, 1.91, 1.99, 3.76, 7.04]

# Focal Loss - per-class alpha = the CE weights above divided by their sum
# (18.71), rounded to 3 decimals; hence the printed sum is 0.999, not 1.0.
FOCAL_LOSS_ALPHA_WEIGHTS = [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]

# DEIT TRANSFORMER MODEL CONFIGURATION - Support Small and Base distilled variants
DEIT_MODEL_VARIANT = 'base'  # Options: 'small' or 'base'

# Dynamic DeiT model selection based on variant
# Both checkpoints are patch16-224 models; running them at INPUT_SIZE = 384
# relies on position-embedding interpolation in the model's forward pass.
if DEIT_MODEL_VARIANT == 'small':
    DEIT_MODEL_NAME = 'facebook/deit-small-distilled-patch16-224'
    EXPECTED_HIDDEN_DIM = 384
    PATCH_SIZE = 16
    INPUT_SIZE = 384  # Upscaled for micro-expression detail preservation
    print("Using DeiT-Small-Distilled for efficient micro-expression analysis with distillation token")
elif DEIT_MODEL_VARIANT == 'base':
    DEIT_MODEL_NAME = 'facebook/deit-base-distilled-patch16-224'
    EXPECTED_HIDDEN_DIM = 768
    PATCH_SIZE = 16
    INPUT_SIZE = 384  # Upscaled for micro-expression detail preservation
    print("Using DeiT-Base-Distilled for advanced micro-expression recognition with distillation token")
else:
    raise ValueError(f"Unsupported DEIT_MODEL_VARIANT: {DEIT_MODEL_VARIANT}")

# Display experiment configuration
print("\n" + "=" * 50)
print("OPTIMIZED EXPERIMENT CONFIGURATION")
print("=" * 50)
print(f"Loss Function: {'Focal Loss' if USE_FOCAL_LOSS else 'CrossEntropy Loss'}")
if USE_FOCAL_LOSS:
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Alpha Weights (per-class): {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum Validation: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Class Weights (inverse sqrt freq): {CROSSENTROPY_CLASS_WEIGHTS}")
print(f"DeiT Model: {DEIT_MODEL_NAME}")
print(f"Input Size: {INPUT_SIZE}x{INPUT_SIZE}")
print(f"Patch Size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"Expected Hidden Dim: {EXPECTED_HIDDEN_DIM}")
print(f"Distillation Token: Enabled")  # static label, not derived from any flag
print("=" * 50)

# Enhanced GPU configuration
# Falls back to CPU (gpu_name "CPU", gpu_memory 0) when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
# Total memory of device 0 in GB (1e9 bytes).
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

# Hardware-optimized batch size for 384px input with DeiT
# Chosen by substring match on the device name; anything else (including CPU)
# gets the conservative default. cudnn.benchmark is only enabled for the two
# recognised GPU families.
if 'A100' in gpu_name:
    BATCH_SIZE = 24
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("A100: Optimized batch size for DeiT 384px")
elif 'L4' in gpu_name:
    BATCH_SIZE = 16
    NUM_WORKERS = 6
    torch.backends.cudnn.benchmark = True
    print("L4: Balanced performance configuration for DeiT")
else:
    BATCH_SIZE = 8
    NUM_WORKERS = 4
    print("Default GPU: Conservative settings for DeiT")

# CASME II class mapping and analysis - preserved existing structure
# The index order here defines the integer labels and must match the order of
# CROSSENTROPY_CLASS_WEIGHTS / FOCAL_LOSS_ALPHA_WEIGHTS above.
CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

# Analyze class distribution from metadata
# Printed for reference only: the weights above are fixed constants and are
# not recomputed from these counts.
train_dist = casme2_metadata['train']['class_distribution']
val_dist = casme2_metadata['val']['class_distribution']
test_dist = casme2_metadata['test']['class_distribution']

print(f"\nTrain distribution: {train_dist}")
print(f"Validation distribution: {val_dist}")
print(f"Test distribution: {test_dist}")

# Apply optimized class weights based on loss function selection
# NOTE: when USE_FOCAL_LOSS is True, `class_weights` holds the focal alpha
# values, not the CrossEntropy weights.
if USE_FOCAL_LOSS:
    # For Focal Loss - use normalized alpha weights (per-class importance)
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    # For CrossEntropy - use inverse sqrt frequency weights
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

# CASME II DeiT Transformer Optimized Configuration
# Central settings dict; the model, the optimizer/scheduler factory and the
# training loop in the next cell read values from it by key.
CASME2_DEIT_CONFIG = {
    # Architecture configuration - DeiT Transformer specific
    'deit_model': DEIT_MODEL_NAME,
    'deit_variant': DEIT_MODEL_VARIANT,
    'input_size': INPUT_SIZE,
    'patch_size': PATCH_SIZE,
    'num_classes': 7,  # must equal len(CASME2_CLASSES)
    'dropout_rate': 0.2,
    'expected_hidden_dim': EXPECTED_HIDDEN_DIM,  # used for a sanity warning in the model
    # NOTE(review): the two distillation flags are not read by any code visible
    # in this notebook section; the model head uses only the CLS token.
    'use_distillation_token': True,
    'distillation_enabled': True,
    'interpolate_pos_encoding': True,

    # Training configuration (carried over from an earlier medical-imaging setup)
    'learning_rate': 1e-5,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,  # max gradient norm passed to clip_grad_norm_
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'device': device,

    # Scheduler configuration (ReduceLROnPlateau; mode 'max' because an F1 score is tracked)
    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 3,
    'scheduler_min_lr': 1e-6,
    # In the visible code this is only printed by create_optimizer_scheduler_casme2.
    'scheduler_monitor': 'val_f1_macro',

    # Optimized loss configuration
    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'class_weights': class_weights,  # focal alpha or CE weights, depending on USE_FOCAL_LOSS

    # Evaluation configuration
    'use_macro_avg': True,
    'early_stopping': False,
    'save_best_f1': True,
    'save_strategy': 'best_only'
}

print(f"\nDeiT Transformer Configuration Summary:")
print(f"  Model: {CASME2_DEIT_CONFIG['deit_model']}")
print(f"  Variant: {CASME2_DEIT_CONFIG['deit_variant']}")
print(f"  Input size: {CASME2_DEIT_CONFIG['input_size']}px")
print(f"  Patch size: {CASME2_DEIT_CONFIG['patch_size']}")
print(f"  Learning rate: {CASME2_DEIT_CONFIG['learning_rate']}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Distillation token: {CASME2_DEIT_CONFIG['use_distillation_token']}")

# =====================================================
# ADVANCED FOCAL LOSS IMPLEMENTATION - Per-Class Alpha Support (Preserved from Swin)
# =====================================================

class OptimizedFocalLoss(nn.Module):
    """
    Focal Loss with optional per-class alpha weighting.
    Paper: "Focal Loss for Dense Object Detection" (Lin et al., 2017)

    Formula: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    Args:
        alpha (list/tuple/ndarray/tensor, optional): Per-class alpha weights,
            indexed by class id. Expected to sum to ~1.0 (a warning is printed
            otherwise). None disables alpha weighting (alpha_t = 1).
        gamma (float): Focusing parameter for hard examples (default: 2.0).
            gamma = 0 reduces to (alpha-weighted) cross entropy.
        reduction (str): 'mean', 'sum' or 'none'.

    Raises:
        ValueError: if `reduction` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()

        # Fail early instead of silently returning an unreduced loss.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Unsupported reduction '{reduction}'; expected one of {self._REDUCTIONS}"
            )

        if alpha is not None:
            # as_tensor accepts lists, tuples, numpy arrays and tensors alike
            # (tensors keep their device).
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)

            # Validation: alpha should sum to 1.0 for proper normalization
            alpha_sum = self.alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")
        else:
            self.alpha = None

        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Logits, shape (N, C).
            targets: Integer class indices, shape (N,).

        Returns:
            Scalar loss for 'mean'/'sum', per-sample losses of shape (N,) for 'none'.
        """
        # Per-sample cross entropy = -log(p_t)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # p_t: model probability of the true class
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            # Lazily follow the targets' device (alpha is a plain attribute,
            # not a registered buffer, so .to(device) on the module won't move it).
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# FIXED: DeiT Architecture for CASME II with proven medical approach
class DeiTCASME2Baseline(nn.Module):
    """DeiT backbone with an MLP head for CASME II micro-expression classification.

    The backbone is a pretrained ``transformers.DeiTModel`` loaded from
    ``CASME2_DEIT_CONFIG['deit_model']`` without its pooling layer. The head
    classifies from the CLS token (position 0 of ``last_hidden_state``); the
    distillation token produced by distilled checkpoints is not used.

    Args:
        num_classes (int): Number of output logits.
        dropout_rate (float): Dropout applied after each hidden head layer.
    """

    def __init__(self, num_classes, dropout_rate=0.2):
        super(DeiTCASME2Baseline, self).__init__()

        # Use DeiTModel + custom head instead of DeiTForImageClassificationWithTeacher
        from transformers import DeiTModel

        self.deit = DeiTModel.from_pretrained(
            CASME2_DEIT_CONFIG['deit_model'],
            add_pooling_layer=False  # CLS token is extracted manually in forward()
        )

        # Full fine-tuning: every backbone parameter is trainable.
        for param in self.deit.parameters():
            param.requires_grad = True

        # Honour the configured flag instead of hard-coding it; default True
        # (the previous hard-coded value) if the key is missing.
        self.interpolate_pos_encoding = bool(
            CASME2_DEIT_CONFIG.get('interpolate_pos_encoding', True)
        )

        self.deit_feature_dim = self.deit.config.hidden_size

        print(f"DeiT feature dimension: {self.deit_feature_dim}")
        print(f"DeiT distillation token: Available in model")

        # Sanity check against the variant's expected width (warning only).
        if self.deit_feature_dim != CASME2_DEIT_CONFIG['expected_hidden_dim']:
            print(f"Warning: Expected {CASME2_DEIT_CONFIG['expected_hidden_dim']}, got {self.deit_feature_dim}")
            print(f"Note: DeiT-{CASME2_DEIT_CONFIG['deit_variant']} hidden_size: {self.deit_feature_dim}")

        # Classification head: hidden -> 512 -> 128, each with LayerNorm/GELU/Dropout.
        self.classifier_layers = nn.Sequential(
            nn.Linear(self.deit_feature_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(dropout_rate),
        )

        # Final classification layer for CASME II classes
        self.classifier = nn.Linear(128, num_classes)

        print(f"DeiT CASME II: {self.deit_feature_dim} -> 512 -> 128 -> {num_classes}")

    def forward(self, pixel_values):
        """Return logits of shape (batch, num_classes).

        `pixel_values` is passed straight to the backbone; presumably a
        (batch, 3, H, W) tensor produced by deit_processor — TODO confirm.
        """
        deit_outputs = self.deit(
            pixel_values=pixel_values,
            # Required when the input resolution (384px here) differs from the
            # checkpoint's pretraining resolution (224px).
            interpolate_pos_encoding=self.interpolate_pos_encoding,
        )

        # CLS token only (position 0); the distillation token is ignored.
        deit_features = deit_outputs.last_hidden_state[:, 0]  # [batch, hidden_size]

        processed_features = self.classifier_layers(deit_features)
        return self.classifier(processed_features)

# Enhanced optimizer and scheduler factory
def create_optimizer_scheduler_casme2(model, config):
    """Build an AdamW optimizer and, optionally, a ReduceLROnPlateau scheduler.

    Args:
        model: nn.Module whose parameters are optimized.
        config (dict): Must provide 'learning_rate', 'weight_decay' and
            'scheduler_type'. When scheduler_type == 'plateau' it must also
            provide 'scheduler_mode', 'scheduler_factor', 'scheduler_patience',
            'scheduler_min_lr' and 'scheduler_monitor'.

    Returns:
        (optimizer, scheduler); scheduler is None for any scheduler_type other
        than 'plateau'.
    """
    adamw = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999),
    )

    if config['scheduler_type'] != 'plateau':
        return adamw, None

    # Plateau schedulers need the monitored metric passed to scheduler.step().
    plateau = optim.lr_scheduler.ReduceLROnPlateau(
        adamw,
        mode=config['scheduler_mode'],
        factor=config['scheduler_factor'],
        patience=config['scheduler_patience'],
        min_lr=config['scheduler_min_lr'],
    )
    print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    return adamw, plateau

# FIXED: DeiT Image Processor setup for 384px input (medical approach)
# The processor only rescales to [0, 1] and normalizes with the checkpoint's
# mean/std. Because do_resize and do_center_crop are off, images must already
# be at the target size: CASME2DatasetTrainingDeiT resizes to 384px before the
# transform, whereas CASME2Dataset passes images through at their stored size.
from transformers import DeiTImageProcessor

print("\nSetting up DeiT Image Processor for 384px input...")

deit_processor = DeiTImageProcessor.from_pretrained(
    CASME2_DEIT_CONFIG['deit_model'],
    do_resize=False,     # CRITICAL: Don't resize to 224px - keep 384px (medical approach)
    do_normalize=True,   # Apply ImageNet normalization
    do_rescale=True,     # Rescale pixel values to [0,1]
    do_center_crop=False # No center crop - use full 384px image
)

# Transform functions for DeiT
def deit_transform_train(image):
    """Training transform: run deit_processor on a single PIL image.

    Returns the ``pixel_values`` tensor with its leading batch dimension
    squeezed away. No augmentation is applied (same as deit_transform_val).
    """
    encoded = deit_processor(image, return_tensors="pt")
    pixel_values = encoded['pixel_values']
    return pixel_values.squeeze(0)

def deit_transform_val(image):
    """Validation/test transform: run deit_processor on a single PIL image.

    Returns the ``pixel_values`` tensor with its leading batch dimension
    squeezed away.
    """
    encoded = deit_processor(image, return_tensors="pt")
    pixel_values = encoded['pixel_values']
    return pixel_values.squeeze(0)

print(f"DeiT Image Processor configured for 384px with interpolate_pos_encoding")

# Custom Dataset class for CASME II - preserved existing structure
class CASME2Dataset(Dataset):
    """CASME II dataset backed by the split metadata JSON.

    Each item is ``(image, label, sample_id)``: the (optionally transformed)
    RGB image, its integer label from CLASS_TO_IDX, and the sample id.
    Unlike CASME2DatasetTrainingDeiT, images are not resized here.

    Args:
        split_metadata (dict): Parsed split metadata; ``split_metadata[split]['samples']``
            is a list of dicts with 'image_filename', 'emotion' and 'sample_id'.
        dataset_root (str): Directory containing one sub-folder per split.
        transform (callable, optional): Applied to the RGB PIL image.
        split (str): Split name, e.g. 'train', 'val' or 'test'.
    """

    def __init__(self, split_metadata, dataset_root, transform=None, split='train'):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split

        print(f"Loaded {len(self.metadata)} samples for {split} split")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        sample = self.metadata[idx]

        image_path = os.path.join(self.dataset_root, self.split, sample['image_filename'])
        # The context manager guarantees the file handle is closed; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        label = CLASS_TO_IDX[sample['emotion']]

        return image, label, sample['sample_id']

# Create directories - updated for DeiT
# exist_ok=True keeps re-running the cell safe.
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

# Dataset paths - preserved existing structure
# Informational: the dataset classes build image paths themselves from
# dataset_root + split + image_filename rather than reading these constants.
TRAIN_PATH = f"{DATASET_ROOT}/train"
VAL_PATH = f"{DATASET_ROOT}/val"
TEST_PATH = f"{DATASET_ROOT}/test"

print(f"\nDataset paths:")
print(f"Train: {TRAIN_PATH}")
print(f"Validation: {VAL_PATH}")
print(f"Test: {TEST_PATH}")

# Smoke test: build the model once and push a single random input through it
# to confirm the backbone/head wiring at the configured input resolution.
print("\nDeiT CASME II architecture validation...")

test_model = test_input = test_output = None
try:
    test_model = DeiTCASME2Baseline(
        num_classes=CASME2_DEIT_CONFIG['num_classes'],
        dropout_rate=CASME2_DEIT_CONFIG['dropout_rate'],
    ).to(device)
    # eval() disables dropout; no_grad() avoids building an autograd graph
    # for this throwaway forward pass.
    test_model.eval()
    with torch.no_grad():
        test_input = torch.randn(
            1, 3, CASME2_DEIT_CONFIG['input_size'], CASME2_DEIT_CONFIG['input_size']
        ).to(device)
        test_output = test_model(test_input)

    # Calculate expected patches for DeiT at the configured resolution
    expected_patches = (CASME2_DEIT_CONFIG['input_size'] // CASME2_DEIT_CONFIG['patch_size']) ** 2  # 384/16 = 24, 24^2 = 576
    total_tokens = expected_patches + 2  # +2 for CLS and distillation tokens

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"Expected patches: {expected_patches}")
    print(f"Total tokens (CLS + DIST + patches): {total_tokens}")
    print(f"Patch size: {CASME2_DEIT_CONFIG['patch_size']}x{CASME2_DEIT_CONFIG['patch_size']}")
    print(f"Input resolution: {CASME2_DEIT_CONFIG['input_size']}x{CASME2_DEIT_CONFIG['input_size']}")
    print(f"Interpolate position encoding: {CASME2_DEIT_CONFIG['interpolate_pos_encoding']}")
    print(f"Distillation token: Active")

except Exception as e:
    # Best-effort check: report and let the rest of the cell continue.
    print(f"Validation failed: {e}")
finally:
    # Release the model and tensors even when validation fails.
    del test_model, test_input, test_output
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Optimized loss function factory with advanced configuration (preserved from Swin)
def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    """
    Build the training loss from the experiment configuration.

    Args:
        weights (Tensor or None): Class weights for CrossEntropy (ignored when
            focal loss is used). None gives unweighted CrossEntropy.
        use_focal_loss (bool): Return OptimizedFocalLoss instead of CrossEntropyLoss.
        alpha_weights (sequence/tensor or None): Per-class alpha weights for
            Focal Loss (expected to sum to ~1.0).
        gamma (float): Focal loss gamma parameter.

    Returns:
        nn.Module: OptimizedFocalLoss or nn.CrossEntropyLoss.
    """
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        # `is not None` instead of truthiness: `if tensor:` raises for
        # multi-element tensors.
        if alpha_weights is not None:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {float(sum(alpha_weights)):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)

    print(f"Using CrossEntropy Loss with optimized class weights")
    if weights is not None:
        print(f"Class weights: {weights.cpu().numpy()}")
    else:
        print("Class weights: None (unweighted)")
    return nn.CrossEntropyLoss(weight=weights)

# Global configuration for training pipeline - enhanced for DeiT
# Bundles the objects defined in this cell (device, transforms, factories,
# paths, metadata) into one dict for the training pipeline.
GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,  # must equal len(CASME2_CLASSES)
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    # The train and val transforms currently behave identically (no augmentation).
    'transform_train': deit_transform_train,
    'transform_val': deit_transform_val,
    'deit_config': CASME2_DEIT_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'train_path': TRAIN_PATH,
    'val_path': VAL_PATH,
    'test_path': TEST_PATH,
    'metadata': casme2_metadata,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

# Configuration validation and summary
# Print-only recap of the settings defined above; no state is changed here.
print("\n" + "=" * 60)
print("CASME II DEIT TRANSFORMER OPTIMIZED BASELINE CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: {DEIT_MODEL_NAME}")
print(f"  Variant: {DEIT_MODEL_VARIANT}")
print(f"  Input Resolution: {INPUT_SIZE}px")
print(f"  Patch Size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"  Hidden Dimension: {EXPECTED_HIDDEN_DIM}")
print(f"  Distillation Token: {CASME2_DEIT_CONFIG['use_distillation_token']}")
print(f"  Position Interpolation: {CASME2_DEIT_CONFIG['interpolate_pos_encoding']}")

print(f"\nDataset Configuration:")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Weight Optimization: {'Per-class Alpha' if USE_FOCAL_LOSS else 'Inverse Sqrt Frequency'}")
print(f"  Token Processing: CLS + Distillation + Patch Tokens")

# NOTE(review): the lines below are static labels; the classifier head in this
# cell reads only the CLS token, so the distillation token does not reach it.
print(f"\nArchitecture Highlights:")
print(f"  Medical-Proven Approach: DeiTModel + Custom Classification Head")
print(f"  Position Encoding: Interpolate from 224px -> 384px")
print(f"  Distillation Capability: Available through model architecture")
print(f"  Stability Features: LayerNorm + GELU + Proper Dropout")

print("\nNext: Cell 2 - Dataset Loading and DeiT Training Pipeline")

CASME II DEIT TRANSFORMER OPTIMIZED BASELINE INFRASTRUCTURE

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II DeiT Optimized Baseline - Infrastructure Configuration
Loading CASME II dataset metadata...
Dataset: CASME2
Total samples: 255
Split strategy: stratified_80_10_10
Using DeiT-Base-Distilled for advanced micro-expression recognition with distillation token

OPTIMIZED EXPERIMENT CONFIGURATION
Loss Function: Focal Loss
  Gamma: 2.0
  Alpha Weights (per-class): [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum Validation: 0.999
DeiT Model: facebook/deit-base-distilled-patch16-224
Input Size: 384x384
Patch Size: 16x16
Expected Hidden Dim: 768
Distillation Token: Enabled

Device: cuda
GPU: NVIDIA L4 (23.8 GB)
L4: Balanced performance configuration for DeiT

Train distribution: {'others': 79, 'disgust': 50, 'happiness': 25, 'repression': 21, 'surprise': 20, 'sadness': 5, 'fear': 1}
Validat

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

DeiT Image Processor configured for 384px with interpolate_pos_encoding

Dataset paths:
Train: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split/train
Validation: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split/val
Test: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split/test

DeiT CASME II architecture validation...


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/349M [00:00<?, ?B/s]

DeiT feature dimension: 768
DeiT distillation token: Available in model
DeiT CASME II: 768 -> 512 -> 128 -> 7
Validation successful: Output shape torch.Size([1, 7])
Expected patches: 576
Total tokens (CLS + DIST + patches): 578
Patch size: 16x16
Input resolution: 384x384
Interpolate position encoding: True
Distillation token: Active

CASME II DEIT TRANSFORMER OPTIMIZED BASELINE CONFIGURATION COMPLETE
Loss Configuration:
  Function: Optimized Focal Loss
  Gamma: 2.0
  Per-class Alpha: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum: 0.999

Model Configuration:
  Architecture: facebook/deit-base-distilled-patch16-224
  Variant: base
  Input Resolution: 384px
  Patch Size: 16x16
  Hidden Dimension: 768
  Distillation Token: True
  Position Interpolation: True

Dataset Configuration:
  Classes: 7
  Weight Optimization: Per-class Alpha
  Token Processing: CLS + Distillation + Patch Tokens

Architecture Highlights:
  Medical-Proven Approach: DeiTModel + Custom Classification He

In [None]:
# @title Cell 2: CASME II DeiT Training Pipeline

# File: 02_04_DeiT_Direct_Enhanced_Baseline_Cell2.py
# Location: experiments/02_04_DeiT_Direct_Baseline.ipynb
# Purpose: Training pipeline for CASME II DeiT micro-expression recognition with distillation token mechanism

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp

# Echo the active loss / architecture settings from CASME2_DEIT_CONFIG (Cell 1)
# so they are recorded at the top of the training log.
print("CASME II DeiT Enhanced Training Pipeline with Fixed Checkpoints")
print("=" * 70)
print(f"Loss Function: {'Optimized Focal Loss' if CASME2_DEIT_CONFIG['use_focal_loss'] else 'CrossEntropy Loss'}")
if CASME2_DEIT_CONFIG['use_focal_loss']:
    print(f"Focal Loss Parameters:")
    print(f"  Gamma: {CASME2_DEIT_CONFIG['focal_loss_gamma']}")
    print(f"  Per-class Alpha: {CASME2_DEIT_CONFIG['focal_loss_alpha_weights']}")
    print(f"  Alpha Sum: {sum(CASME2_DEIT_CONFIG['focal_loss_alpha_weights']):.3f}")
else:
    print(f"CrossEntropy Parameters:")
    print(f"  Optimized Class Weights: {CASME2_DEIT_CONFIG['crossentropy_class_weights']}")
print(f"DeiT Architecture: {CASME2_DEIT_CONFIG['deit_variant']} variant with distillation token")
print(f"Input resolution: {CASME2_DEIT_CONFIG['input_size']}x{CASME2_DEIT_CONFIG['input_size']}")
print(f"Position interpolation: {CASME2_DEIT_CONFIG['interpolate_pos_encoding']}")
print(f"Training epochs: {CASME2_DEIT_CONFIG['num_epochs']}")
print(f"Scheduler patience: {CASME2_DEIT_CONFIG['scheduler_patience']}")

# Enhanced CASME II Dataset with clean RAM caching for DeiT
class CASME2DatasetTrainingDeiT(Dataset):
    """CASME II dataset for DeiT training with optional RAM caching.

    Images are loaded as RGB and resized (LANCZOS) to ``image_size`` x
    ``image_size`` when their size differs. Unreadable files are replaced by a
    mid-gray placeholder and a warning is printed, so one bad file does not
    abort an epoch. Items are ``(image, label, sample_id)``.

    Args:
        split_metadata (dict): ``split_metadata[split]['samples']`` is a list of
            dicts with 'image_filename', 'emotion' and 'sample_id'.
        dataset_root (str): Directory containing one sub-folder per split.
        transform (callable, optional): Applied to the RGB PIL image.
        split (str): Split name, e.g. 'train', 'val' or 'test'.
        use_ram_cache (bool): Decode every image once at construction time.
        image_size (int): Target square side length in pixels (default 384).
    """

    def __init__(self, split_metadata, dataset_root, transform=None, split='train',
                 use_ram_cache=True, image_size=384):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.image_size = image_size
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for DeiT training...")

        for sample in self.metadata:
            self.images.append(os.path.join(dataset_root, split, sample['image_filename']))
            self.labels.append(CLASS_TO_IDX[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        self._print_distribution()

        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print per-class sample counts and percentages, ordered by class index."""
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def _load_image(self, img_path):
        """Open img_path as RGB, resized to (image_size, image_size) if needed."""
        # Context manager closes the file; convert() returns a loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        target = (self.image_size, self.image_size)
        if image.size != target:
            image = image.resize(target, Image.Resampling.LANCZOS)
        return image

    def _placeholder(self):
        """Return the mid-gray RGB image used in place of unreadable files."""
        return Image.new('RGB', (self.image_size, self.image_size), (128, 128, 128))

    def _preload_to_ram(self):
        """Decode every image once and keep the PIL images in memory."""
        print(f"Preloading {len(self.images)} {self.split} images to RAM for DeiT...")

        valid_images = 0
        self.cached_images = [None] * len(self.images)

        for i, img_path in enumerate(self.images):
            try:
                self.cached_images[i] = self._load_image(img_path)
                valid_images += 1
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                self.cached_images[i] = self._placeholder()

        # PIL RGB images store 1 byte (uint8) per channel per pixel.
        ram_usage_gb = len(self.cached_images) * self.image_size * self.image_size * 3 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            # Copy so the transform can never mutate the cached original.
            image = self.cached_images[idx].copy()
        else:
            try:
                image = self._load_image(self.images[idx])
            except Exception as e:
                # Narrower than a bare except (lets KeyboardInterrupt through)
                # and no longer silent.
                print(f"Error loading {self.images[idx]}: {e}")
                image = self._placeholder()

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.sample_ids[idx]

# Enhanced metrics calculation with comprehensive error handling
def calculate_metrics_safe_robust(outputs, labels, class_names, average='macro'):
    """Compute accuracy / precision / recall / F1 without ever raising.

    Args:
        outputs: A 2-D tensor of scores/logits (N, num_classes), from which
            predictions are taken by argmax over dim 1; or a 1-D tensor or
            sequence of already-predicted class indices.
        labels: Tensor or sequence of N ground-truth class indices.
        class_names (list): Only its length is used, to fix the label set to
            0..len(class_names)-1 so classes absent from this batch still
            count in macro averages.
        average (str): Averaging mode for precision_recall_fscore_support.

    Returns:
        dict with float 'accuracy', 'precision', 'recall' and 'f1_score'.
        On any error a warning is printed and all four values are 0.0.
    """
    try:
        # Convert first, so the length check below works for every input type
        # (previously `.size(0)` ran before conversion and failed on lists).
        if isinstance(outputs, torch.Tensor):
            if outputs.dim() == 1:
                predictions = outputs.cpu().numpy()
            else:
                predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        else:
            predictions = np.asarray(outputs)

        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()
        else:
            labels = np.asarray(labels)

        if len(predictions) != len(labels):
            raise ValueError(f"Batch size mismatch: outputs {len(predictions)} vs labels {len(labels)}")

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0,
            labels=list(range(len(class_names)))
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }
    except Exception as e:
        # Deliberate best-effort: metrics must not abort a training epoch.
        print(f"Warning: Enhanced metrics calculation error: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

# FIXED: Enhanced training epoch function with DeiT fixed architecture compatibility
def train_epoch_deit(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    """Run one training epoch of the DeiT model.

    Args:
        model: Classifier returning logits of shape ``[batch, len(CASME2_CLASSES)]``.
        dataloader: Yields ``(images, labels, sample_ids)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (display only).
        total_epochs: Total number of epochs (display only).

    Returns:
        ``(avg_loss, metrics)``: the mean of the per-batch loss values
        (0.0 if the loader yielded no batches) and the metrics dict from
        ``calculate_metrics_safe_robust`` over all training predictions.

    Raises:
        ValueError: if the model output is not ``[batch, len(CASME2_CLASSES)]``.
    """
    model.train()
    # Loop invariants read once instead of on every batch.
    num_classes = len(CASME2_CLASSES)
    grad_clip = CASME2_DEIT_CONFIG['gradient_clip']
    deit_variant = CASME2_DEIT_CONFIG['deit_variant']

    running_loss = 0.0
    num_batches = 0
    all_outputs = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc=f"CASME II DeiT Training Epoch {epoch+1}/{total_epochs}")

    for batch_idx, (images, labels, _sample_ids) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Fail fast with a clear message if the head does not match the class set.
        if outputs.dim() != 2 or outputs.size(1) != num_classes:
            raise ValueError(
                f"Invalid CASME II DeiT output shape: {outputs.shape}, expected [batch_size, {num_classes}]"
            )

        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping for DeiT stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

        # Keep only detached CPU copies so GPU memory is not held across the epoch.
        all_outputs.append(outputs.detach().cpu())
        all_labels.append(labels.detach().cpu())

        if batch_idx % 5 == 0:
            progress_bar.set_postfix({
                'Loss': f'{running_loss / num_batches:.4f}',
                'LR': f"{optimizer.param_groups[0]['lr']:.2e}",
                'DeiT': deit_variant
            })

    # Enhanced metrics calculation with error recovery
    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: DeiT training metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = running_loss / num_batches if num_batches else 0.0
    return avg_loss, metrics

# FIXED: Enhanced validation epoch function with DeiT fixed architecture compatibility
def validate_epoch_deit(model, dataloader, criterion, device, epoch, total_epochs):
    """Evaluate the DeiT model for one epoch without gradient tracking.

    Args:
        model: Classifier returning logits of shape ``[batch, len(CASME2_CLASSES)]``.
        dataloader: Yields ``(images, labels, sample_ids)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (display only).
        total_epochs: Total number of epochs (display only).

    Returns:
        ``(avg_loss, metrics, all_sample_ids)``:
        - ``avg_loss``: the mean per-batch validation loss, 0.0 if no batches.
        - ``metrics``: the dict from ``calculate_metrics_safe_robust``.
        - ``all_sample_ids``: sample ids in loader order.

    Raises:
        ValueError: if the model output is not ``[batch, len(CASME2_CLASSES)]``.
    """
    model.eval()
    num_classes = len(CASME2_CLASSES)
    deit_variant = CASME2_DEIT_CONFIG['deit_variant']

    running_loss = 0.0
    num_batches = 0
    all_outputs = []
    all_labels = []
    all_sample_ids = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc=f"CASME II DeiT Validation Epoch {epoch+1}/{total_epochs}")

        for batch_idx, (images, labels, sample_ids) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)

            # Same shape check as the training step, for a clear error message.
            if outputs.dim() != 2 or outputs.size(1) != num_classes:
                raise ValueError(
                    f"Invalid CASME II DeiT output shape: {outputs.shape}, expected [batch_size, {num_classes}]"
                )

            loss = criterion(outputs, labels)
            running_loss += loss.item()
            num_batches += 1

            # Keep only CPU copies so GPU memory is not held across the epoch.
            all_outputs.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_sample_ids.extend(sample_ids)

            if batch_idx % 3 == 0:
                progress_bar.set_postfix({
                    'Val Loss': f'{running_loss / num_batches:.4f}',
                    'DeiT': deit_variant
                })

    # Enhanced metrics calculation with error recovery
    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: DeiT validation metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = running_loss / num_batches if num_batches else 0.0
    return avg_loss, metrics, all_sample_ids

def _clone_state_to_cpu(state):
    """Recursively copy every tensor in a (nested) state structure onto the CPU.

    dicts, lists and tuples are rebuilt at any depth; other values are
    returned unchanged.
    """
    if isinstance(state, torch.Tensor):
        return state.detach().cpu().clone()
    if isinstance(state, dict):
        return {key: _clone_state_to_cpu(value) for key, value in state.items()}
    if isinstance(state, list):
        return [_clone_state_to_cpu(value) for value in state]
    if isinstance(state, tuple):
        return tuple(_clone_state_to_cpu(value) for value in state)
    return state


def _to_checkpoint_serializable(obj):
    """Convert metric/config values to plain Python types for the checkpoint.

    JSON-native scalars (None, bool, int, float, str) are kept with their
    type. Previously ints and bools were coerced to float and None to "None".
    """
    if isinstance(obj, torch.Tensor):
        obj = obj.detach().cpu()
        return obj.item() if obj.numel() == 1 else obj.tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # NumPy scalar checks come before the builtin check: np.float64 subclasses float.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {k: _to_checkpoint_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_checkpoint_serializable(item) for item in obj]
    return str(obj)


# FIXED: Enhanced checkpoint saving function with complete device migration
def save_checkpoint_robust_fixed(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                                checkpoint_dir, best_metrics, config, max_retries=3):
    """Save the best-F1 checkpoint with all state on CPU, retrying on failure.

    The checkpoint is written to ``{checkpoint_dir}/casme2_deit_direct_best_f1.pth``
    via a ``.tmp`` file that is then moved into place.

    Args:
        model, optimizer, scheduler: Objects whose ``state_dict()`` is saved.
            ``scheduler`` may be None.
        epoch: Zero-based epoch index stored in the checkpoint.
        train_metrics, val_metrics: Metric dicts stored after serialization.
        checkpoint_dir: Target directory; created if missing.
        best_metrics: Dict with ``'f1'``, ``'loss'`` and ``'accuracy'`` entries.
        config: Experiment config; must contain ``'deit_variant'`` and ``'deit_model'``.
        max_retries: Number of save attempts.

    Returns:
        The checkpoint path on success. None if device migration fails or
        every save attempt fails.
    """
    import shutil

    print("Saving checkpoint with enhanced device migration...")

    # Copy every state tensor to CPU so the file can be loaded on any device.
    try:
        print("  Migrating model state dict to CPU...")
        model_state_cpu = _clone_state_to_cpu(model.state_dict())

        # Recursive copy handles optimizer state at any nesting depth.
        print("  Migrating optimizer state dict to CPU...")
        optimizer_state_cpu = _clone_state_to_cpu(optimizer.state_dict())

        scheduler_state_cpu = None
        if scheduler is not None:
            print("  Migrating scheduler state dict to CPU...")
            scheduler_state_cpu = _clone_state_to_cpu(scheduler.state_dict())

        # Release cached GPU memory before serialization
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        print("  Device migration completed successfully")

    except Exception as e:
        print(f"ERROR: Device migration failed: {e}")
        return None

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state_cpu,
        'optimizer_state_dict': optimizer_state_cpu,
        'scheduler_state_dict': scheduler_state_cpu,
        'train_metrics': _to_checkpoint_serializable(train_metrics),
        'val_metrics': _to_checkpoint_serializable(val_metrics),
        'casme2_deit_config': _to_checkpoint_serializable(config),
        'best_f1': float(best_metrics['f1']),
        'best_loss': float(best_metrics['loss']),
        'best_acc': float(best_metrics['accuracy']),
        'class_names': CASME2_CLASSES,
        'num_classes': len(CASME2_CLASSES),
        'deit_variant': config['deit_variant'],
        'deit_model': config['deit_model']
    }

    best_path = f"{checkpoint_dir}/casme2_deit_direct_best_f1.pth"
    temp_path = f"{best_path}.tmp"

    for attempt in range(max_retries):
        try:
            print(f"  Attempt {attempt + 1}: Saving to temporary file...")
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(checkpoint, temp_path)

            # Write-then-move so a crash mid-save never truncates the previous best file.
            shutil.move(temp_path, best_path)

            print(f"Checkpoint saved successfully: {os.path.basename(best_path)}")
            print(f"  DeiT variant: {config['deit_variant']}")
            print(f"  Model: {config['deit_model']}")
            return best_path

        except Exception as e:
            print(f"Checkpoint save attempt {attempt + 1} failed: {e}")

            # Remove a partially written temporary file; cleanup is best-effort.
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass

            if attempt < max_retries - 1:
                print("  Retrying in 2 seconds...")
                time.sleep(2)
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
            else:
                print(f"All {max_retries} checkpoint save attempts failed")

    return None

# Safe JSON serialization function - enhanced for DeiT
def safe_json_serialize_deit(obj):
    """Convert ``obj`` recursively into a JSON-serializable structure.

    Conversion rules:
    - Tensors: a single-element tensor becomes a Python scalar; otherwise a
      nested list.
    - NumPy arrays and scalars: converted to the matching Python types.
    - JSON-native scalars (None, bool, int, float, str): returned with their
      type intact.
    - dicts: converted value by value.
    - lists and tuples: become lists.
    - Objects with ``__dict__``: their attribute dict is serialized.
    - Anything else: stringified.
    """
    if isinstance(obj, torch.Tensor):
        # .tolist()/.item() work on grad-tracking and bfloat16 tensors, unlike .numpy().
        obj = obj.detach().cpu()
        return obj.item() if obj.numel() == 1 else obj.tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # NumPy scalar checks come before the builtin check: np.float64 subclasses float.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    # Keep ints/bools/None as-is (previously ints and bools became floats
    # and None became the string "None").
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {k: safe_json_serialize_deit(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [safe_json_serialize_deit(item) for item in obj]
    if hasattr(obj, '__dict__'):
        return safe_json_serialize_deit(obj.__dict__)
    return str(obj)

def _split_dataset_root(split_path, split_name):
    """Return the dataset root directory that contains ``split_name``.

    Only a trailing ``/<split_name>`` component is stripped. A matching
    substring elsewhere in the path (e.g. ``/training_runs``) is left intact.
    Paths that do not end in the split directory fall back to the previous
    ``str.replace`` behaviour.
    """
    suffix = f"/{split_name}"
    trimmed = split_path.rstrip('/')
    if trimmed.endswith(suffix):
        return trimmed[:-len(suffix)]
    return split_path.replace(suffix, '')


def _deit_loader_kwargs(num_workers):
    """Return DataLoader worker options valid for the given ``num_workers``."""
    kwargs = {'num_workers': num_workers, 'pin_memory': True}
    if num_workers > 0:
        # Recent PyTorch versions raise ValueError when prefetch_factor is set
        # while num_workers == 0, so only pass it with worker processes.
        kwargs['prefetch_factor'] = 2
    return kwargs


# Create enhanced datasets for DeiT
print("\nCreating CASME II DeiT training datasets...")

train_dataset = CASME2DatasetTrainingDeiT(
    split_metadata=GLOBAL_CONFIG_CASME2['metadata'],
    dataset_root=_split_dataset_root(GLOBAL_CONFIG_CASME2['train_path'], 'train'),
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    split='train',
    use_ram_cache=True
)

val_dataset = CASME2DatasetTrainingDeiT(
    split_metadata=GLOBAL_CONFIG_CASME2['metadata'],
    dataset_root=_split_dataset_root(GLOBAL_CONFIG_CASME2['val_path'], 'val'),
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    split='val',
    use_ram_cache=True
)

# Create data loaders with DeiT-optimized settings
train_loader = DataLoader(
    train_dataset,
    batch_size=CASME2_DEIT_CONFIG['batch_size'],
    shuffle=True,
    **_deit_loader_kwargs(CASME2_DEIT_CONFIG['num_workers'])
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CASME2_DEIT_CONFIG['batch_size'],
    shuffle=False,
    **_deit_loader_kwargs(CASME2_DEIT_CONFIG['num_workers'])
)

print(f"Training batches: {len(train_loader)} (samples: {len(train_dataset)})")
print(f"Validation batches: {len(val_loader)} (samples: {len(val_dataset)})")

# Build the DeiT classifier on the configured device, then the loss,
# optimizer and scheduler from the factories registered in GLOBAL_CONFIG_CASME2.
print("\nInitializing CASME II DeiT enhanced model...")
model = DeiTCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_DEIT_CONFIG['dropout_rate'],
).to(GLOBAL_CONFIG_CASME2['device'])

# One factory call covers both loss types. The focal-loss config keys are only
# read when focal loss is enabled; CrossEntropy gets alpha=None, gamma=2.0.
_use_focal_loss = bool(CASME2_DEIT_CONFIG['use_focal_loss'])
criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
    weights=GLOBAL_CONFIG_CASME2['class_weights'],
    use_focal_loss=_use_focal_loss,
    alpha_weights=CASME2_DEIT_CONFIG['focal_loss_alpha_weights'] if _use_focal_loss else None,
    gamma=CASME2_DEIT_CONFIG['focal_loss_gamma'] if _use_focal_loss else 2.0,
)

optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](model, CASME2_DEIT_CONFIG)

_criterion_label = 'Optimized Focal Loss' if _use_focal_loss else 'CrossEntropy'
print("\n".join((
    f"Optimizer: AdamW (LR={CASME2_DEIT_CONFIG['learning_rate']})",
    f"Scheduler: ReduceLROnPlateau (patience={CASME2_DEIT_CONFIG['scheduler_patience']})",
    f"Criterion: {_criterion_label}",
    f"DeiT Architecture: {CASME2_DEIT_CONFIG['deit_variant']} variant with position interpolation",
)))
del _use_focal_loss, _criterion_label

# Per-epoch training history: one list per tracked quantity, appended once per epoch.
training_history = {
    history_key: []
    for history_key in (
        'train_loss', 'val_loss',
        'train_f1', 'val_f1',
        'train_acc', 'val_acc',
        'learning_rate', 'epoch_time',
    )
}

# Best-so-far validation metrics for checkpoint selection: F1 first, with
# lower loss and then higher accuracy as tie-breakers.
best_metrics = dict(f1=0.0, loss=float('inf'), accuracy=0.0, epoch=0)

print("\n".join((
    "\nStarting CASME II DeiT enhanced training with fixed checkpoints...",
    f"Training configuration: {CASME2_DEIT_CONFIG['num_epochs']} epochs",
    "Position interpolation: 224px -> 384px with DeiTModel + custom head",
    "=" * 70,
)))

# Main training loop with enhanced checkpoint reliability for DeiT
start_time = time.time()

for epoch in range(CASME2_DEIT_CONFIG['num_epochs']):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{CASME2_DEIT_CONFIG['num_epochs']}")

    # Training phase with DeiT-specific adaptations
    train_loss, train_metrics = train_epoch_deit(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_DEIT_CONFIG['num_epochs']
    )

    # Validation phase with DeiT-specific adaptations
    val_loss, val_metrics, val_sample_ids = validate_epoch_deit(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_DEIT_CONFIG['num_epochs']
    )

    # The scheduler is stepped with validation macro F1 (higher is better),
    # so it must be configured to maximize the monitored value.
    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    # Record training history (learning rate is read after the scheduler step)
    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    # Print epoch summary
    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    # Multi-criteria improvement test: higher F1, then lower loss, then higher accuracy
    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        # Stage the new best values and commit them only after the checkpoint
        # is on disk, so best_metrics always describes the saved model file.
        candidate_metrics = {
            'f1': val_metrics['f1_score'],
            'loss': val_loss,
            'accuracy': val_metrics['accuracy'],
            'epoch': epoch + 1
        }

        best_model_path = save_checkpoint_robust_fixed(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            candidate_metrics, CASME2_DEIT_CONFIG
        )

        if best_model_path:
            best_metrics.update(candidate_metrics)
            print(f"New best DeiT model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print("Warning: Failed to save DeiT checkpoint despite improvement; best metrics unchanged")

    # Progress tracking
    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_DEIT_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_DEIT_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min | DeiT-{CASME2_DEIT_CONFIG['deit_variant']}")

# Training completion
total_time = time.time() - start_time
# Count the epochs actually recorded instead of assuming the configured number
# ran. This also keeps the summary from indexing an empty history when
# num_epochs is 0.
actual_epochs = len(training_history['train_loss'])
final_train_f1 = training_history['train_f1'][-1] if actual_epochs else float('nan')
final_val_f1 = training_history['val_f1'][-1] if actual_epochs else float('nan')
average_epoch_time = float(np.mean(training_history['epoch_time'])) if actual_epochs else float('nan')

print("\n" + "=" * 70)
print("CASME II DEIT TRANSFORMER ENHANCED BASELINE TRAINING COMPLETED")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print(f"Final train F1: {final_train_f1:.4f}")
print(f"Final validation F1: {final_val_f1:.4f}")
print(f"DeiT Architecture: {CASME2_DEIT_CONFIG['deit_variant']} variant with position interpolation")

# Enhanced training documentation export for DeiT
results_dir = GLOBAL_CONFIG_CASME2['results_root']
training_history_path = f"{results_dir}/training_logs/casme2_deit_direct_training_history.json"

print("\nExporting enhanced DeiT training documentation...")

try:
    # Created inside the try: a filesystem/Drive error should only skip the
    # export, not abort the cell after training has already finished.
    os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

    # Create comprehensive training summary with DeiT-specific configuration
    training_summary = {
        'experiment_type': 'CASME2_DeiT_Enhanced_Baseline',
        'experiment_configuration': {
            'loss_function': 'Optimized Focal Loss' if CASME2_DEIT_CONFIG['use_focal_loss'] else 'CrossEntropy',
            'weight_approach': 'Per-class Alpha (sum=1.0)' if CASME2_DEIT_CONFIG['use_focal_loss'] else 'Inverse Sqrt Frequency',
            'focal_loss_gamma': CASME2_DEIT_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_DEIT_CONFIG['focal_loss_alpha_weights'],
            'crossentropy_class_weights': CASME2_DEIT_CONFIG['crossentropy_class_weights'],
            'deit_model': CASME2_DEIT_CONFIG['deit_model'],
            'deit_variant': CASME2_DEIT_CONFIG['deit_variant'],
            'input_size': CASME2_DEIT_CONFIG['input_size'],
            'patch_size': CASME2_DEIT_CONFIG['patch_size'],
            'distillation_token': CASME2_DEIT_CONFIG['use_distillation_token'],
            'interpolate_pos_encoding': CASME2_DEIT_CONFIG['interpolate_pos_encoding']
        },
        'training_history': safe_json_serialize_deit(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': average_epoch_time,
        'config': safe_json_serialize_deit(CASME2_DEIT_CONFIG),
        'final_train_f1': float(final_train_f1),
        'final_val_f1': float(final_val_f1),
        'model_checkpoint': 'casme2_deit_direct_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            'total_samples': 255,
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES
        },
        'architecture_info': {
            'model_type': 'DeiTCASME2Baseline',
            'backbone': CASME2_DEIT_CONFIG['deit_model'],
            'variant': CASME2_DEIT_CONFIG['deit_variant'],
            'input_size': f"{CASME2_DEIT_CONFIG['input_size']}x{CASME2_DEIT_CONFIG['input_size']}",
            'patch_size': f"{CASME2_DEIT_CONFIG['patch_size']}x{CASME2_DEIT_CONFIG['patch_size']}",
            'hidden_dim': CASME2_DEIT_CONFIG['expected_hidden_dim'],
            'distillation_token': CASME2_DEIT_CONFIG['use_distillation_token'],
            'position_interpolation': CASME2_DEIT_CONFIG['interpolate_pos_encoding'],
            'architecture_approach': 'DeiTModel + Custom Classification Head (Medical Proven)',
            'total_tokens': f"CLS + Distillation + {(CASME2_DEIT_CONFIG['input_size']//CASME2_DEIT_CONFIG['patch_size'])**2} patches"
        },
        'enhanced_features': {
            'medical_proven_architecture': True,
            'position_encoding_interpolation': True,
            'fixed_checkpoint_saving': True,
            'device_migration_complete': True,
            'robust_error_handling': True,
            'multi_criteria_checkpoint_logic': True,
            'memory_optimized_training': True,
            'clean_dataset_loading': True,
            'retry_checkpoint_logic': True
        }
    }

    # Save with proper JSON serialization
    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Enhanced DeiT training documentation saved: {training_history_path}")
    print(f"Experiment details: {training_summary['experiment_configuration']['loss_function']} loss")
    if CASME2_DEIT_CONFIG['use_focal_loss']:
        print(f"  Gamma: {CASME2_DEIT_CONFIG['focal_loss_gamma']}, Alpha Sum: {sum(CASME2_DEIT_CONFIG['focal_loss_alpha_weights']):.3f}")
    print(f"Model variant: {CASME2_DEIT_CONFIG['deit_model']} ({CASME2_DEIT_CONFIG['deit_variant']})")
    print(f"Architecture approach: DeiTModel + Custom Head (Medical Proven)")
    print(f"Position interpolation: {CASME2_DEIT_CONFIG['interpolate_pos_encoding']}")
    print(f"Checkpoint saving: FIXED with complete device migration")

except Exception as e:
    print(f"Warning: Could not save DeiT training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

# Enhanced memory cleanup for DeiT
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\nNext: Cell 3 - CASME II DeiT Enhanced Evaluation")
print("Enhanced DeiT training pipeline completed successfully!")

CASME II DeiT Enhanced Training Pipeline with Fixed Checkpoints
Loss Function: Optimized Focal Loss
Focal Loss Parameters:
  Gamma: 2.0
  Per-class Alpha: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum: 0.999
DeiT Architecture: base variant with distillation token
Input resolution: 384x384
Position interpolation: True
Training epochs: 50
Scheduler patience: 3

Creating CASME II DeiT training datasets...
Loading CASME II train dataset for DeiT training...
Loaded 201 CASME II train samples
  others: 79 samples (39.3%)
  disgust: 50 samples (24.9%)
  happiness: 25 samples (12.4%)
  repression: 21 samples (10.4%)
  surprise: 20 samples (10.0%)
  sadness: 5 samples (2.5%)
  fear: 1 samples (0.5%)
Preloading 201 train images to RAM for DeiT...


model.safetensors:   0%|          | 0.00/349M [00:00<?, ?B/s]

TRAIN RAM caching completed: 201 images, ~0.36GB
Loading CASME II val dataset for DeiT training...
Loaded 26 CASME II val samples
  others: 10 samples (38.5%)
  disgust: 6 samples (23.1%)
  happiness: 3 samples (11.5%)
  repression: 3 samples (11.5%)
  surprise: 2 samples (7.7%)
  sadness: 1 samples (3.8%)
  fear: 1 samples (3.8%)
Preloading 26 val images to RAM for DeiT...
VAL RAM caching completed: 26 images, ~0.05GB
Training batches: 13 (samples: 201)
Validation batches: 2 (samples: 26)

Initializing CASME II DeiT enhanced model...
DeiT feature dimension: 768
DeiT distillation token: Available in model
DeiT CASME II: 768 -> 512 -> 128 -> 7
Using Optimized Focal Loss with gamma=2.0
Per-class alpha weights: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
Alpha sum: 0.999
Scheduler: ReduceLROnPlateau monitoring val_f1_macro
Optimizer: AdamW (LR=1e-05)
Scheduler: ReduceLROnPlateau (patience=3)
Criterion: Optimized Focal Loss
DeiT Architecture: base variant with position interpolation


CASME II DeiT Training Epoch 1/50: 100%|██████████| 13/13 [00:10<00:00,  1.23it/s, Loss=0.1076, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 1/50: 100%|██████████| 2/2 [00:00<00:00,  2.71it/s, Val Loss=0.0475, DeiT=base]


Train - Loss: 0.1044, F1: 0.1245, Acc: 0.2388
Val   - Loss: 0.1552, F1: 0.1076, Acc: 0.3077
Time  - Epoch: 11.4s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_deit_direct_best_f1.pth
  DeiT variant: base
  Model: facebook/deit-base-distilled-patch16-224
New best DeiT model: Higher F1 - F1: 0.1076
Progress: 2.0% | Best F1: 0.1076 | ETA: 11.3min | DeiT-base

Epoch 2/50


CASME II DeiT Training Epoch 2/50: 100%|██████████| 13/13 [00:09<00:00,  1.35it/s, Loss=0.0808, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 2/50: 100%|██████████| 2/2 [00:00<00:00,  2.43it/s, Val Loss=0.0476, DeiT=base]


Train - Loss: 0.0866, F1: 0.1755, Acc: 0.3980
Val   - Loss: 0.1510, F1: 0.2459, Acc: 0.3846
Time  - Epoch: 10.5s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_deit_direct_best_f1.pth
  DeiT variant: base
  Model: facebook/deit-base-distilled-patch16-224
New best DeiT model: Higher F1 - F1: 0.2459
Progress: 4.0% | Best F1: 0.2459 | ETA: 10.7min | DeiT-base

Epoch 3/50


CASME II DeiT Training Epoch 3/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0768, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 3/50: 100%|██████████| 2/2 [00:00<00:00,  2.37it/s, Val Loss=0.0496, DeiT=base]


Train - Loss: 0.0753, F1: 0.2678, Acc: 0.4179
Val   - Loss: 0.1435, F1: 0.2022, Acc: 0.3462
Time  - Epoch: 10.7s, LR: 1.00e-05
Progress: 6.0% | Best F1: 0.2459 | ETA: 9.8min | DeiT-base

Epoch 4/50


CASME II DeiT Training Epoch 4/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0685, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 4/50: 100%|██████████| 2/2 [00:00<00:00,  2.37it/s, Val Loss=0.0461, DeiT=base]


Train - Loss: 0.0672, F1: 0.3470, Acc: 0.4876
Val   - Loss: 0.1479, F1: 0.2173, Acc: 0.3462
Time  - Epoch: 10.7s, LR: 1.00e-05
Progress: 8.0% | Best F1: 0.2459 | ETA: 9.3min | DeiT-base

Epoch 5/50


CASME II DeiT Training Epoch 5/50: 100%|██████████| 13/13 [00:10<00:00,  1.29it/s, Loss=0.0590, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 5/50: 100%|██████████| 2/2 [00:00<00:00,  2.35it/s, Val Loss=0.0436, DeiT=base]


Train - Loss: 0.0592, F1: 0.4234, Acc: 0.5423
Val   - Loss: 0.1467, F1: 0.2007, Acc: 0.3846
Time  - Epoch: 10.9s, LR: 1.00e-05
Progress: 10.0% | Best F1: 0.2459 | ETA: 8.9min | DeiT-base

Epoch 6/50


CASME II DeiT Training Epoch 6/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0526, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 6/50: 100%|██████████| 2/2 [00:00<00:00,  2.34it/s, Val Loss=0.0425, DeiT=base]


Train - Loss: 0.0524, F1: 0.4825, Acc: 0.6269
Val   - Loss: 0.1464, F1: 0.2667, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_deit_direct_best_f1.pth
  DeiT variant: base
  Model: facebook/deit-base-distilled-patch16-224
New best DeiT model: Higher F1 - F1: 0.2667
Progress: 12.0% | Best F1: 0.2667 | ETA: 9.0min | DeiT-base

Epoch 7/50


CASME II DeiT Training Epoch 7/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0479, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 7/50: 100%|██████████| 2/2 [00:00<00:00,  2.27it/s, Val Loss=0.0422, DeiT=base]


Train - Loss: 0.0502, F1: 0.5003, Acc: 0.6269
Val   - Loss: 0.1466, F1: 0.3030, Acc: 0.4231
Time  - Epoch: 10.8s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_deit_direct_best_f1.pth
  DeiT variant: base
  Model: facebook/deit-base-distilled-patch16-224
New best DeiT model: Higher F1 - F1: 0.3030
Progress: 14.0% | Best F1: 0.3030 | ETA: 8.9min | DeiT-base

Epoch 8/50


CASME II DeiT Training Epoch 8/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0430, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 8/50: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s, Val Loss=0.0411, DeiT=base]


Train - Loss: 0.0434, F1: 0.5417, Acc: 0.6816
Val   - Loss: 0.1508, F1: 0.2060, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-05
Progress: 16.0% | Best F1: 0.3030 | ETA: 8.5min | DeiT-base

Epoch 9/50


CASME II DeiT Training Epoch 9/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0391, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 9/50: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s, Val Loss=0.0420, DeiT=base]


Train - Loss: 0.0378, F1: 0.7425, Acc: 0.7313
Val   - Loss: 0.1500, F1: 0.1857, Acc: 0.3077
Time  - Epoch: 10.9s, LR: 1.00e-05
Progress: 18.0% | Best F1: 0.3030 | ETA: 8.2min | DeiT-base

Epoch 10/50


CASME II DeiT Training Epoch 10/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0349, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 10/50: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s, Val Loss=0.0394, DeiT=base]


Train - Loss: 0.0347, F1: 0.6428, Acc: 0.7711
Val   - Loss: 0.1541, F1: 0.1887, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-05
Progress: 20.0% | Best F1: 0.3030 | ETA: 8.0min | DeiT-base

Epoch 11/50


CASME II DeiT Training Epoch 11/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0280, LR=1.00e-05, DeiT=base]
CASME II DeiT Validation Epoch 11/50: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s, Val Loss=0.0399, DeiT=base]


Train - Loss: 0.0287, F1: 0.8594, Acc: 0.8358
Val   - Loss: 0.1528, F1: 0.1857, Acc: 0.3077
Time  - Epoch: 10.9s, LR: 5.00e-06
Progress: 22.0% | Best F1: 0.3030 | ETA: 7.7min | DeiT-base

Epoch 12/50


CASME II DeiT Training Epoch 12/50: 100%|██████████| 13/13 [00:10<00:00,  1.30it/s, Loss=0.0252, LR=5.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 12/50: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s, Val Loss=0.0400, DeiT=base]


Train - Loss: 0.0259, F1: 0.8746, Acc: 0.8756
Val   - Loss: 0.1550, F1: 0.1857, Acc: 0.3077
Time  - Epoch: 11.0s, LR: 5.00e-06
Progress: 24.0% | Best F1: 0.3030 | ETA: 7.5min | DeiT-base

Epoch 13/50


CASME II DeiT Training Epoch 13/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0229, LR=5.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 13/50: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s, Val Loss=0.0410, DeiT=base]


Train - Loss: 0.0228, F1: 0.8788, Acc: 0.8706
Val   - Loss: 0.1514, F1: 0.1755, Acc: 0.3077
Time  - Epoch: 10.9s, LR: 5.00e-06
Progress: 26.0% | Best F1: 0.3030 | ETA: 7.2min | DeiT-base

Epoch 14/50


CASME II DeiT Training Epoch 14/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0212, LR=5.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 14/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0401, DeiT=base]


Train - Loss: 0.0211, F1: 0.9182, Acc: 0.9154
Val   - Loss: 0.1527, F1: 0.1857, Acc: 0.3077
Time  - Epoch: 11.0s, LR: 5.00e-06
Progress: 28.0% | Best F1: 0.3030 | ETA: 7.0min | DeiT-base

Epoch 15/50


CASME II DeiT Training Epoch 15/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0195, LR=5.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 15/50: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s, Val Loss=0.0399, DeiT=base]


Train - Loss: 0.0196, F1: 0.9074, Acc: 0.8955
Val   - Loss: 0.1542, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 11.0s, LR: 2.50e-06
Progress: 30.0% | Best F1: 0.3030 | ETA: 6.8min | DeiT-base

Epoch 16/50


CASME II DeiT Training Epoch 16/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0173, LR=2.50e-06, DeiT=base]
CASME II DeiT Validation Epoch 16/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0399, DeiT=base]


Train - Loss: 0.0172, F1: 0.9158, Acc: 0.9254
Val   - Loss: 0.1515, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 2.50e-06
Progress: 32.0% | Best F1: 0.3030 | ETA: 6.6min | DeiT-base

Epoch 17/50


CASME II DeiT Training Epoch 17/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0160, LR=2.50e-06, DeiT=base]
CASME II DeiT Validation Epoch 17/50: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s, Val Loss=0.0383, DeiT=base]


Train - Loss: 0.0167, F1: 0.9345, Acc: 0.9353
Val   - Loss: 0.1554, F1: 0.1755, Acc: 0.3077
Time  - Epoch: 11.0s, LR: 2.50e-06
Progress: 34.0% | Best F1: 0.3030 | ETA: 6.3min | DeiT-base

Epoch 18/50


CASME II DeiT Training Epoch 18/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0147, LR=2.50e-06, DeiT=base]
CASME II DeiT Validation Epoch 18/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0392, DeiT=base]


Train - Loss: 0.0156, F1: 0.9699, Acc: 0.9701
Val   - Loss: 0.1543, F1: 0.1911, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 2.50e-06
Progress: 36.0% | Best F1: 0.3030 | ETA: 6.1min | DeiT-base

Epoch 19/50


CASME II DeiT Training Epoch 19/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0139, LR=2.50e-06, DeiT=base]
CASME II DeiT Validation Epoch 19/50: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s, Val Loss=0.0393, DeiT=base]


Train - Loss: 0.0150, F1: 0.9835, Acc: 0.9701
Val   - Loss: 0.1530, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.25e-06
Progress: 38.0% | Best F1: 0.3030 | ETA: 5.9min | DeiT-base

Epoch 20/50


CASME II DeiT Training Epoch 20/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0146, LR=1.25e-06, DeiT=base]
CASME II DeiT Validation Epoch 20/50: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s, Val Loss=0.0388, DeiT=base]


Train - Loss: 0.0148, F1: 0.9770, Acc: 0.9652
Val   - Loss: 0.1542, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.25e-06
Progress: 40.0% | Best F1: 0.3030 | ETA: 5.7min | DeiT-base

Epoch 21/50


CASME II DeiT Training Epoch 21/50: 100%|██████████| 13/13 [00:10<00:00,  1.30it/s, Loss=0.0144, LR=1.25e-06, DeiT=base]
CASME II DeiT Validation Epoch 21/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0396, DeiT=base]


Train - Loss: 0.0140, F1: 0.9793, Acc: 0.9701
Val   - Loss: 0.1551, F1: 0.1911, Acc: 0.3462
Time  - Epoch: 11.0s, LR: 1.25e-06
Progress: 42.0% | Best F1: 0.3030 | ETA: 5.5min | DeiT-base

Epoch 22/50


CASME II DeiT Training Epoch 22/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0137, LR=1.25e-06, DeiT=base]
CASME II DeiT Validation Epoch 22/50: 100%|██████████| 2/2 [00:00<00:00,  2.11it/s, Val Loss=0.0391, DeiT=base]


Train - Loss: 0.0137, F1: 0.9839, Acc: 0.9701
Val   - Loss: 0.1552, F1: 0.1911, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.25e-06
Progress: 44.0% | Best F1: 0.3030 | ETA: 5.3min | DeiT-base

Epoch 23/50


CASME II DeiT Training Epoch 23/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0128, LR=1.25e-06, DeiT=base]
CASME II DeiT Validation Epoch 23/50: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s, Val Loss=0.0395, DeiT=base]


Train - Loss: 0.0130, F1: 0.9743, Acc: 0.9751
Val   - Loss: 0.1552, F1: 0.1911, Acc: 0.3462
Time  - Epoch: 10.8s, LR: 1.00e-06
Progress: 46.0% | Best F1: 0.3030 | ETA: 5.1min | DeiT-base

Epoch 24/50


CASME II DeiT Training Epoch 24/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0129, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 24/50: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s, Val Loss=0.0395, DeiT=base]


Train - Loss: 0.0131, F1: 0.9915, Acc: 0.9851
Val   - Loss: 0.1551, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 48.0% | Best F1: 0.3030 | ETA: 4.9min | DeiT-base

Epoch 25/50


CASME II DeiT Training Epoch 25/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0126, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 25/50: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s, Val Loss=0.0399, DeiT=base]


Train - Loss: 0.0129, F1: 0.9790, Acc: 0.9851
Val   - Loss: 0.1550, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 50.0% | Best F1: 0.3030 | ETA: 4.7min | DeiT-base

Epoch 26/50


CASME II DeiT Training Epoch 26/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0116, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 26/50: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s, Val Loss=0.0393, DeiT=base]


Train - Loss: 0.0117, F1: 0.9916, Acc: 0.9851
Val   - Loss: 0.1541, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 52.0% | Best F1: 0.3030 | ETA: 4.5min | DeiT-base

Epoch 27/50


CASME II DeiT Training Epoch 27/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0137, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 27/50: 100%|██████████| 2/2 [00:00<00:00,  2.11it/s, Val Loss=0.0385, DeiT=base]


Train - Loss: 0.0142, F1: 0.9856, Acc: 0.9751
Val   - Loss: 0.1555, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 54.0% | Best F1: 0.3030 | ETA: 4.3min | DeiT-base

Epoch 28/50


CASME II DeiT Training Epoch 28/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0115, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 28/50: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s, Val Loss=0.0386, DeiT=base]


Train - Loss: 0.0116, F1: 0.9977, Acc: 0.9950
Val   - Loss: 0.1553, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 56.0% | Best F1: 0.3030 | ETA: 4.1min | DeiT-base

Epoch 29/50


CASME II DeiT Training Epoch 29/50: 100%|██████████| 13/13 [00:09<00:00,  1.33it/s, Loss=0.0130, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 29/50: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s, Val Loss=0.0389, DeiT=base]


Train - Loss: 0.0131, F1: 0.9893, Acc: 0.9801
Val   - Loss: 0.1559, F1: 0.1911, Acc: 0.3462
Time  - Epoch: 10.8s, LR: 1.00e-06
Progress: 58.0% | Best F1: 0.3030 | ETA: 3.9min | DeiT-base

Epoch 30/50


CASME II DeiT Training Epoch 30/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0125, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 30/50: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s, Val Loss=0.0395, DeiT=base]


Train - Loss: 0.0123, F1: 0.9954, Acc: 0.9950
Val   - Loss: 0.1561, F1: 0.1911, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 60.0% | Best F1: 0.3030 | ETA: 3.8min | DeiT-base

Epoch 31/50


CASME II DeiT Training Epoch 31/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0112, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 31/50: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s, Val Loss=0.0396, DeiT=base]


Train - Loss: 0.0114, F1: 0.9909, Acc: 0.9851
Val   - Loss: 0.1562, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 62.0% | Best F1: 0.3030 | ETA: 3.6min | DeiT-base

Epoch 32/50


CASME II DeiT Training Epoch 32/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0106, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 32/50: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s, Val Loss=0.0393, DeiT=base]


Train - Loss: 0.0108, F1: 0.9909, Acc: 0.9851
Val   - Loss: 0.1553, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 64.0% | Best F1: 0.3030 | ETA: 3.4min | DeiT-base

Epoch 33/50


CASME II DeiT Training Epoch 33/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0107, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 33/50: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s, Val Loss=0.0393, DeiT=base]


Train - Loss: 0.0107, F1: 0.9953, Acc: 0.9900
Val   - Loss: 0.1543, F1: 0.2130, Acc: 0.3846
Time  - Epoch: 11.0s, LR: 1.00e-06
Progress: 66.0% | Best F1: 0.3030 | ETA: 3.2min | DeiT-base

Epoch 34/50


CASME II DeiT Training Epoch 34/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0110, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 34/50: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s, Val Loss=0.0392, DeiT=base]


Train - Loss: 0.0113, F1: 0.9929, Acc: 0.9851
Val   - Loss: 0.1551, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 68.0% | Best F1: 0.3030 | ETA: 3.0min | DeiT-base

Epoch 35/50


CASME II DeiT Training Epoch 35/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0101, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 35/50: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s, Val Loss=0.0388, DeiT=base]


Train - Loss: 0.0100, F1: 0.9793, Acc: 0.9900
Val   - Loss: 0.1554, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 70.0% | Best F1: 0.3030 | ETA: 2.8min | DeiT-base

Epoch 36/50


CASME II DeiT Training Epoch 36/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0109, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 36/50: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s, Val Loss=0.0383, DeiT=base]


Train - Loss: 0.0105, F1: 0.9930, Acc: 0.9851
Val   - Loss: 0.1552, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 72.0% | Best F1: 0.3030 | ETA: 2.6min | DeiT-base

Epoch 37/50


CASME II DeiT Training Epoch 37/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0092, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 37/50: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s, Val Loss=0.0390, DeiT=base]


Train - Loss: 0.0096, F1: 0.9977, Acc: 0.9950
Val   - Loss: 0.1549, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 11.0s, LR: 1.00e-06
Progress: 74.0% | Best F1: 0.3030 | ETA: 2.4min | DeiT-base

Epoch 38/50


CASME II DeiT Training Epoch 38/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0098, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 38/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0390, DeiT=base]


Train - Loss: 0.0111, F1: 0.9903, Acc: 0.9851
Val   - Loss: 0.1554, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 76.0% | Best F1: 0.3030 | ETA: 2.2min | DeiT-base

Epoch 39/50


CASME II DeiT Training Epoch 39/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0101, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 39/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0385, DeiT=base]


Train - Loss: 0.0100, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1564, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 78.0% | Best F1: 0.3030 | ETA: 2.1min | DeiT-base

Epoch 40/50


CASME II DeiT Training Epoch 40/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0099, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 40/50: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s, Val Loss=0.0388, DeiT=base]


Train - Loss: 0.0099, F1: 0.9953, Acc: 0.9900
Val   - Loss: 0.1557, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 80.0% | Best F1: 0.3030 | ETA: 1.9min | DeiT-base

Epoch 41/50


CASME II DeiT Training Epoch 41/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0093, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 41/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0394, DeiT=base]


Train - Loss: 0.0094, F1: 0.9954, Acc: 0.9900
Val   - Loss: 0.1555, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 82.0% | Best F1: 0.3030 | ETA: 1.7min | DeiT-base

Epoch 42/50


CASME II DeiT Training Epoch 42/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0093, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 42/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0392, DeiT=base]


Train - Loss: 0.0093, F1: 0.9892, Acc: 0.9801
Val   - Loss: 0.1557, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 11.0s, LR: 1.00e-06
Progress: 84.0% | Best F1: 0.3030 | ETA: 1.5min | DeiT-base

Epoch 43/50


CASME II DeiT Training Epoch 43/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0092, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 43/50: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s, Val Loss=0.0386, DeiT=base]


Train - Loss: 0.0090, F1: 0.9954, Acc: 0.9950
Val   - Loss: 0.1558, F1: 0.2130, Acc: 0.3846
Time  - Epoch: 11.0s, LR: 1.00e-06
Progress: 86.0% | Best F1: 0.3030 | ETA: 1.3min | DeiT-base

Epoch 44/50


CASME II DeiT Training Epoch 44/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0095, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 44/50: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s, Val Loss=0.0388, DeiT=base]


Train - Loss: 0.0093, F1: 0.9977, Acc: 0.9950
Val   - Loss: 0.1558, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 88.0% | Best F1: 0.3030 | ETA: 1.1min | DeiT-base

Epoch 45/50


CASME II DeiT Training Epoch 45/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0088, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 45/50: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s, Val Loss=0.0391, DeiT=base]


Train - Loss: 0.0088, F1: 0.9977, Acc: 0.9950
Val   - Loss: 0.1557, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 90.0% | Best F1: 0.3030 | ETA: 0.9min | DeiT-base

Epoch 46/50


CASME II DeiT Training Epoch 46/50: 100%|██████████| 13/13 [00:09<00:00,  1.30it/s, Loss=0.0089, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 46/50: 100%|██████████| 2/2 [00:00<00:00,  2.09it/s, Val Loss=0.0393, DeiT=base]


Train - Loss: 0.0086, F1: 0.9977, Acc: 0.9950
Val   - Loss: 0.1549, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 92.0% | Best F1: 0.3030 | ETA: 0.7min | DeiT-base

Epoch 47/50


CASME II DeiT Training Epoch 47/50: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s, Loss=0.0080, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 47/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0388, DeiT=base]


Train - Loss: 0.0080, F1: 0.9977, Acc: 0.9950
Val   - Loss: 0.1562, F1: 0.1966, Acc: 0.3462
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 94.0% | Best F1: 0.3030 | ETA: 0.6min | DeiT-base

Epoch 48/50


CASME II DeiT Training Epoch 48/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0083, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 48/50: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s, Val Loss=0.0384, DeiT=base]


Train - Loss: 0.0083, F1: 0.9856, Acc: 0.9950
Val   - Loss: 0.1561, F1: 0.2130, Acc: 0.3846
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 96.0% | Best F1: 0.3030 | ETA: 0.4min | DeiT-base

Epoch 49/50


CASME II DeiT Training Epoch 49/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0095, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 49/50: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s, Val Loss=0.0387, DeiT=base]


Train - Loss: 0.0093, F1: 0.9953, Acc: 0.9900
Val   - Loss: 0.1557, F1: 0.2130, Acc: 0.3846
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 98.0% | Best F1: 0.3030 | ETA: 0.2min | DeiT-base

Epoch 50/50


CASME II DeiT Training Epoch 50/50: 100%|██████████| 13/13 [00:09<00:00,  1.31it/s, Loss=0.0079, LR=1.00e-06, DeiT=base]
CASME II DeiT Validation Epoch 50/50: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s, Val Loss=0.0386, DeiT=base]

Train - Loss: 0.0081, F1: 0.9977, Acc: 0.9950
Val   - Loss: 0.1563, F1: 0.2130, Acc: 0.3846
Time  - Epoch: 10.9s, LR: 1.00e-06
Progress: 100.0% | Best F1: 0.3030 | ETA: 0.0min | DeiT-base

CASME II DEIT TRANSFORMER ENHANCED BASELINE TRAINING COMPLETED
Training time: 9.3 minutes
Epochs completed: 50
Best validation F1: 0.3030 (epoch 7)
Final train F1: 0.9977
Final validation F1: 0.2130
DeiT Architecture: base variant with position interpolation

Exporting enhanced DeiT training documentation...
Enhanced DeiT training documentation saved: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/02_04_deit_casme2-af/training_logs/casme2_deit_direct_training_history.json
Experiment details: Optimized Focal Loss loss
  Gamma: 2.0, Alpha Sum: 0.999
Model variant: facebook/deit-base-distilled-patch16-224 (base)
Architecture approach: DeiTModel + Custom Head (Medical Proven)
Position interpolation: True
Checkpoint saving: FIXED with complete device migration

Next: 




In [None]:
# @title Cell 3: CASME II DeiT Direct Baseline Evaluation

# File: 02_04_DeiT_Direct_Baseline_Cell3.py
# Location: experiments/02_04_DeiT_Direct_Baseline.ipynb
# Purpose: Evaluation framework for trained CASME II DeiT micro-expression recognition model with distillation token

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime

# Evaluation specific imports
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import pickle
import warnings
warnings.filterwarnings('ignore')

# Test-split dataset for CASME II DeiT evaluation with optional RAM preloading
class CASME2DatasetEvaluationDeiT(Dataset):
    """CASME II evaluation dataset for the DeiT baseline.

    Reads ``split_metadata[split]['samples']`` and records, per sample, the image
    path ``<dataset_root>/<split>/<image_filename>``, the integer label (looked up
    through the module-level ``CLASS_TO_IDX``), the sample id, emotion string and
    subject. Images can be preloaded into RAM as resized PIL images.

    ``__getitem__`` returns ``(image, label, sample_id, emotion, subject, filename)``
    where ``image`` has been passed through ``transform`` (when one is given) and
    ``filename`` is the basename of the image path.

    Args:
        split_metadata: Mapping whose ``[split]['samples']`` entries carry the keys
            'image_filename', 'emotion', 'sample_id' and 'subject'.
        dataset_root: Directory containing one sub-directory per split.
        transform: Optional callable applied to each PIL image.
        split: Split name to load (default 'test').
        use_ram_cache: When True, every image is loaded once at construction.
        image_size: Side length of the square size images are resized to.
            Defaults to 384, the value that used to be hard-coded.
    """

    def __init__(self, split_metadata, dataset_root, transform=None, split='test',
                 use_ram_cache=True, image_size=384):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.image_size = (image_size, image_size)
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.emotions = []
        self.subjects = []
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for DeiT evaluation...")

        # Parallel lists, one entry per metadata sample, kept index-aligned.
        for sample in self.metadata:
            self.images.append(os.path.join(dataset_root, split, sample['image_filename']))
            self.labels.append(CLASS_TO_IDX[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])
            self.emotions.append(sample['emotion'])
            self.subjects.append(sample['subject'])

        print(f"Loaded {len(self.images)} CASME II {split} samples for DeiT evaluation")
        self._print_evaluation_distribution()

        if self.use_ram_cache:
            self._preload_to_ram_evaluation()

    def _load_image(self, img_path):
        """Open ``img_path`` as RGB and resize it to ``self.image_size`` if needed."""
        # The context manager closes the file handle right away; convert() returns
        # a new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        if image.size != self.image_size:
            image = image.resize(self.image_size, Image.Resampling.LANCZOS)
        return image

    def _placeholder(self):
        """Return a mid-gray image used in place of an unreadable file."""
        return Image.new('RGB', self.image_size, (128, 128, 128))

    def _print_evaluation_distribution(self):
        """Print per-class counts, the number of subjects and any absent classes."""
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1
        num_subjects = len(set(self.subjects))

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Test set covers {num_subjects} subjects")

        missing_classes = [name for i, name in enumerate(CASME2_CLASSES) if i not in label_counts]
        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        """Load every image into ``self.cached_images``, resized to ``image_size``.

        Unreadable files are reported and replaced by a gray placeholder, so cache
        indices stay aligned with ``self.labels``.
        """
        width, height = self.image_size
        print(f"Preloading {len(self.images)} test images to RAM for DeiT evaluation...")

        valid_images = 0
        self.cached_images = [None] * len(self.images)

        for i, img_path in enumerate(self.images):
            try:
                self.cached_images[i] = self._load_image(img_path)
                valid_images += 1
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                self.cached_images[i] = self._placeholder()

        # PIL 'RGB' images hold one byte per channel (uint8): 3 bytes per pixel.
        # The previous estimate multiplied by 4 more, as if pixels were float32.
        ram_usage_gb = len(self.cached_images) * width * height * 3 / 1e9
        print(f"Test RAM caching completed: {valid_images} valid images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label, sample_id, emotion, subject, filename)`` for ``idx``."""
        if self.use_ram_cache and self.cached_images[idx] is not None:
            # Copy so that a transform mutating the image cannot corrupt the cache.
            image = self.cached_images[idx].copy()
        else:
            try:
                image = self._load_image(self.images[idx])
            except Exception as e:
                # Best-effort: keep evaluation running, but make the failure visible.
                print(f"Error loading {self.images[idx]}: {e}")
                image = self._placeholder()

        if self.transform:
            image = self.transform(image)

        return (image, self.labels[idx], self.sample_ids[idx],
                self.emotions[idx], self.subjects[idx], os.path.basename(self.images[idx]))

# Static settings for evaluating the CASME II DeiT direct baseline. The loader
# below reads 'num_classes' and 'checkpoint_file'; the rest is descriptive
# metadata echoed to stdout for the notebook record.
EVALUATION_CONFIG_CASME2_DEIT = dict(
    model_type='DeiT_CASME2_Direct_Baseline',
    task_type='micro_expression_recognition',
    num_classes=7,
    class_names=CASME2_CLASSES,
    checkpoint_file='casme2_deit_direct_best_f1.pth',
    dataset_name='CASME_II',
    input_size='384x384',
    evaluation_protocol='stratified_split',
    architecture='deit_transformer_distillation',
)

# Banner: title, rule, then one "Label: value" line per displayed setting.
print("CASME II DeiT Direct Baseline Evaluation Framework")
print("=" * 65)
print("\n".join(
    f"{label}: {EVALUATION_CONFIG_CASME2_DEIT[key]}"
    for label, key in (
        ("Model", "model_type"),
        ("Task", "task_type"),
        ("Classes", "class_names"),
        ("Input size", "input_size"),
        ("Architecture", "architecture"),
    )
))

# Normalise a DeiT forward-pass output to its logits tensor
def extract_logits_safe_casme2_deit(outputs_all):
    """Return the logits tensor from a CASME II DeiT forward-pass output.

    Accepted forms of ``outputs_all``:
      * a ``torch.Tensor`` -- returned unchanged (the DeiTModel + custom head
        path produces this directly);
      * a non-empty tuple/list whose first element is a tensor -- that element;
      * a dict holding a tensor under 'logits', 'prediction' or 'outputs'
        (checked in that order) -- the first such tensor.

    Raises:
        RuntimeError: if no tensor logits can be found. This includes an empty
            tuple/list, or one whose first element is not a tensor.
    """
    if isinstance(outputs_all, torch.Tensor):
        return outputs_all
    if isinstance(outputs_all, (tuple, list)):
        # Previously outputs_all[0] was returned unchecked: an empty sequence raised
        # IndexError and a non-tensor element only failed later at `.shape`.
        if outputs_all and isinstance(outputs_all[0], torch.Tensor):
            return outputs_all[0]
    elif isinstance(outputs_all, dict):
        for key in ('logits', 'prediction', 'outputs'):
            value = outputs_all.get(key)
            if isinstance(value, torch.Tensor):
                return value
    raise RuntimeError("Unable to extract tensor logits from CASME II DeiT model output")

def _read_checkpoint_casme2_deit(checkpoint_path, device):
    """Load ``checkpoint_path`` with progressively more permissive loaders.

    Returns ``(checkpoint, loading_method)``. Raises RuntimeError, chained to the
    last failure, when every loader fails.
    """
    try:
        return torch.load(checkpoint_path, map_location='cpu'), "standard"
    except Exception as e1:
        try:
            # Retry without the weights-only unpickler (the torch.load default
            # since PyTorch 2.6), which rejects some non-tensor objects.
            # Only appropriate for trusted checkpoint files.
            return (torch.load(checkpoint_path, map_location=device, weights_only=False),
                    "weights_only_false")
        except Exception as e2:
            try:
                # NOTE(review): raw pickle executes arbitrary code from the file;
                # keep this path for trusted, self-produced checkpoints only.
                import pickle
                with open(checkpoint_path, 'rb') as f:
                    return pickle.load(f), "pickle"
            except Exception as e3:
                raise RuntimeError(f"All DeiT loading methods failed: {e1}, {e2}, {e3}") from e3


def _load_state_dict_casme2_deit(model, state_dict):
    """Load ``state_dict`` into ``model``: strict first, then non-strict.

    A non-strict load that reports missing keys leaves those parameters at their
    initial values (e.g. an untrained head), so the offending key names are printed.
    """
    try:
        model.load_state_dict(state_dict, strict=True)
        print("DeiT model state loaded with strict=True")
        return
    except Exception as e:
        print(f"Strict loading failed, trying non-strict: {str(e)[:100]}...")

    try:
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    except Exception as e2:
        raise RuntimeError(f"Both DeiT loading approaches failed: {e2}") from e2

    if missing_keys or unexpected_keys:
        print(f"Non-strict loading: Missing {len(missing_keys)}, Unexpected {len(unexpected_keys)}")
        if missing_keys:
            print(f"  Missing keys (first 5): {list(missing_keys)[:5]}")
        if unexpected_keys:
            print(f"  Unexpected keys (first 5): {list(unexpected_keys)[:5]}")
    else:
        print("DeiT model state loaded with strict=False (no key mismatches)")


def load_trained_model_casme2_deit(checkpoint_path, device):
    """Rebuild the CASME II DeiT model from a training checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint written by the training cell.
        device: Device the model is moved to.

    Returns:
        ``(model, training_info)``. ``model`` is a ``DeiTCASME2Baseline`` in eval
        mode. ``training_info`` is a dict of best-epoch metrics and DeiT metadata
        read from the checkpoint, with defaults where keys are absent.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        RuntimeError: if the file cannot be loaded, does not contain a dict, or
            its state dict cannot be applied even non-strictly.
    """
    print(f"Loading trained CASME II DeiT model from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"DeiT checkpoint not found: {checkpoint_path}")

    checkpoint, loading_method = _read_checkpoint_casme2_deit(checkpoint_path, device)
    print(f"DeiT checkpoint loaded using: {loading_method}")

    # Everything below reads the checkpoint via .get(); fail clearly otherwise.
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f"Unexpected DeiT checkpoint type: {type(checkpoint).__name__} (expected a dict)"
        )

    # `or {}` also tolerates a stored None config.
    deit_config = checkpoint.get('casme2_deit_config') or {}
    # NOTE(review): these fallbacks name the *small* variant, while this notebook's
    # training log reports 'base'; they only apply when the keys are absent.
    deit_variant = checkpoint.get('deit_variant', 'small')
    deit_model_name = checkpoint.get('deit_model', 'facebook/deit-small-distilled-patch16-224')

    print(f"Detected DeiT variant: {deit_variant}")
    print(f"Detected DeiT model: {deit_model_name}")
    print("Architecture approach: DeiTModel + Custom Head (Medical Proven)")

    # NOTE(review): the variant is not passed to the constructor; presumably
    # DeiTCASME2Baseline picks it up from global config -- verify they agree.
    model = DeiTCASME2Baseline(
        num_classes=EVALUATION_CONFIG_CASME2_DEIT['num_classes'],
        dropout_rate=deit_config.get('dropout_rate', 0.2)
    ).to(device)

    # Checkpoints may wrap the weights under 'model_state_dict' or be a bare state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    _load_state_dict_casme2_deit(model, state_dict)

    model.eval()

    training_info = {
        'best_val_f1': float(checkpoint.get('best_f1', 0.0)),
        'best_val_loss': float(checkpoint.get('best_loss', float('inf'))),
        'best_val_accuracy': float(checkpoint.get('best_acc', 0.0)),
        # +1 converts the stored epoch index to a 1-based epoch number.
        'best_epoch': int(checkpoint.get('epoch', 0)) + 1,
        'model_checkpoint': EVALUATION_CONFIG_CASME2_DEIT['checkpoint_file'],
        'num_classes': EVALUATION_CONFIG_CASME2_DEIT['num_classes'],
        'deit_variant': deit_variant,
        'deit_model': deit_model_name,
        'architecture_approach': 'DeiTModel + Custom Classification Head',
        'position_interpolation': deit_config.get('interpolate_pos_encoding', True),
        'config': deit_config
    }

    print("DeiT model loaded successfully:")
    print(f"  Best validation F1: {training_info['best_val_f1']:.4f}")
    print(f"  Best validation accuracy: {training_info['best_val_accuracy']:.4f}")
    print(f"  Best epoch: {training_info['best_epoch']}")
    print(f"  DeiT variant: {training_info['deit_variant']}")
    print(f"  Position interpolation: {training_info['position_interpolation']}")
    print(f"  Model classes: {EVALUATION_CONFIG_CASME2_DEIT['num_classes']}")

    return model, training_info

def run_model_inference_casme2_deit(model, test_loader, device):
    """Run the model over ``test_loader`` and collect per-sample results.

    Each batch is unpacked as
    ``(images, labels, sample_ids, emotions, subjects, filenames)``, which is the
    per-item tuple returned by ``CASME2DatasetEvaluationDeiT``.

    Returns:
        dict with keys 'predictions' (argmax class indices), 'probabilities'
        (softmax over dim 1) and 'labels' as NumPy arrays; 'sample_ids',
        'emotions', 'subjects', 'filenames' as lists; 'inference_time' (seconds)
        and 'samples_count'.

    Raises:
        Whatever the forward pass or logits extraction raises; the error is
        printed first.
    """
    print("Running CASME II DeiT model inference on test set...")

    expected_classes = EVALUATION_CONFIG_CASME2_DEIT['num_classes']

    model.eval()
    all_predictions = []
    all_probabilities = []
    all_labels = []
    all_sample_ids = []
    all_emotions = []
    all_subjects = []
    all_filenames = []

    inference_start = time.time()

    with torch.no_grad():
        for images, labels, sample_ids, emotions, subjects, filenames in tqdm(
                test_loader, desc="CASME II DeiT Inference"):

            images = images.to(device)

            try:
                outputs = extract_logits_safe_casme2_deit(model(images))
            except Exception as e:
                # The former "fallback" re-ran model(images). That either raised the
                # same error again or returned the same unextractable output, which
                # then failed at `.shape`. Surface the original failure instead.
                print(f"Error in DeiT model forward pass: {e}")
                raise

            if outputs.shape[1] != expected_classes:
                print(f"Warning: Expected {expected_classes} classes output, got {outputs.shape[1]}")

            probabilities = torch.softmax(outputs, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            # Move results to CPU/NumPy so GPU memory is not held across batches.
            all_predictions.extend(predictions.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_sample_ids.extend(sample_ids)
            all_emotions.extend(emotions)
            all_subjects.extend(subjects)
            all_filenames.extend(filenames)

    inference_time = time.time() - inference_start

    predictions_array = np.array(all_predictions)
    probabilities_array = np.array(all_probabilities)
    labels_array = np.array(all_labels)

    print(f"CASME II DeiT inference completed: {len(predictions_array)} samples in {inference_time:.2f}s")

    unique_predictions = np.unique(predictions_array)
    print(f"Predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    unique_labels = np.unique(labels_array)
    print(f"True classes in test: {[CASME2_CLASSES[i] for i in unique_labels]}")

    return {
        'predictions': predictions_array,
        'probabilities': probabilities_array,
        'labels': labels_array,
        'sample_ids': all_sample_ids,
        'emotions': all_emotions,
        'subjects': all_subjects,
        'filenames': all_filenames,
        'inference_time': inference_time,
        'samples_count': len(predictions_array)
    }

def analyze_wrong_predictions_casme2_deit(inference_results):
    """Comprehensive wrong predictions analysis for CASME II DeiT.

    Args:
        inference_results: dict returned by run_model_inference_casme2_deit.
            Reads 'predictions' and 'labels' (numpy arrays of class indices)
            and the per-sample lists 'sample_ids', 'emotions', 'subjects',
            'filenames'.

    Returns:
        dict with:
          - 'analysis_metadata': timestamp, model info, sample / error counts and
            'overall_error_rate' in percent (0.0 for an empty test set),
          - 'wrong_predictions_by_class': true class name -> list of error records,
          - 'subject_error_analysis': subject -> {'total', 'wrong', 'errors',
            'error_rate'}; subjects with errors come first (in order of their
            first error), followed by error-free subjects in order of appearance,
          - 'confusion_patterns': "<true>_to_<pred>" -> count,
          - 'error_summary': true class name -> number of errors.
    """
    from collections import Counter

    print("Analyzing wrong predictions for CASME II DeiT micro-expression recognition...")

    predictions = inference_results['predictions']
    labels = inference_results['labels']
    sample_ids = inference_results['sample_ids']
    emotions = inference_results['emotions']
    subjects = inference_results['subjects']
    filenames = inference_results['filenames']

    # Indices of misclassified samples
    wrong_indices = np.where(predictions != labels)[0]

    # Every class gets an (initially empty) list so it always appears in the output
    wrong_predictions_by_class = {class_name: [] for class_name in CASME2_CLASSES}
    subject_error_analysis = {}
    confusion_patterns = {}

    # Single pass over the errors: per-class lists, per-subject errors, confusion pairs
    for idx in wrong_indices:
        true_label = labels[idx]
        pred_label = predictions[idx]
        true_class = CASME2_CLASSES[true_label]
        pred_class = CASME2_CLASSES[pred_label]
        subject = subjects[idx]

        wrong_info = {
            'sample_id': sample_ids[idx],
            'filename': filenames[idx],
            'subject': subject,
            'true_label': int(true_label),
            'true_class': true_class,
            'predicted_label': int(pred_label),
            'predicted_class': pred_class,
            'emotion': emotions[idx]
        }
        wrong_predictions_by_class[true_class].append(wrong_info)

        if subject not in subject_error_analysis:
            subject_error_analysis[subject] = {'total': 0, 'wrong': 0, 'errors': []}
        subject_error_analysis[subject]['wrong'] += 1
        subject_error_analysis[subject]['errors'].append(wrong_info)

        pattern = f"{true_class}_to_{pred_class}"
        confusion_patterns[pattern] = confusion_patterns.get(pattern, 0) + 1

    # Register error-free subjects after the ones with errors, then fill totals/rates
    for subject in subjects:
        if subject not in subject_error_analysis:
            subject_error_analysis[subject] = {'total': 0, 'wrong': 0, 'errors': []}

    subject_totals = Counter(subjects)
    for subject, entry in subject_error_analysis.items():
        entry['total'] = subject_totals[subject]
        entry['error_rate'] = entry['wrong'] / entry['total'] if entry['total'] > 0 else 0.0

    # Summary statistics (guarded so an empty test set does not divide by zero)
    total_wrong = len(wrong_indices)
    total_samples = len(predictions)
    error_rate = (total_wrong / total_samples) * 100 if total_samples > 0 else 0.0

    analysis_results = {
        'analysis_metadata': {
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'model_type': EVALUATION_CONFIG_CASME2_DEIT['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2_DEIT['dataset_name'],
            'architecture': EVALUATION_CONFIG_CASME2_DEIT['architecture'],
            'total_samples': int(total_samples),
            'total_wrong_predictions': int(total_wrong),
            'overall_error_rate': float(error_rate)
        },
        'wrong_predictions_by_class': wrong_predictions_by_class,
        'subject_error_analysis': subject_error_analysis,
        'confusion_patterns': confusion_patterns,
        'error_summary': {
            class_name: len(wrong_predictions_by_class[class_name])
            for class_name in CASME2_CLASSES
        }
    }

    return analysis_results

def _compute_roc_auc_casme2_deit(labels, probabilities, present_labels):
    """One-vs-rest ROC curves and AUC for every CASME II class.

    Classes absent from the test set (or whose binarised labels are constant)
    get AUC 0.0 and a placeholder curve. If anything fails, every class falls
    back to that placeholder, so the AUC, FPR and TPR dicts always stay consistent.

    Returns:
        (auc_scores, fpr_dict, tpr_dict, macro_auc), where macro_auc averages
        only over ``present_labels``.
    """
    class_indices = list(range(len(CASME2_CLASSES)))
    present = {int(i) for i in present_labels}
    auc_scores, fpr_dict, tpr_dict = {}, {}, {}

    try:
        labels_binarized = label_binarize(labels, classes=class_indices)

        for i, class_name in enumerate(CASME2_CLASSES):
            if i in present and len(np.unique(labels_binarized[:, i])) > 1:
                fpr, tpr, _ = roc_curve(labels_binarized[:, i], probabilities[:, i])
                auc_scores[class_name] = float(auc(fpr, tpr))
                fpr_dict[class_name] = fpr.tolist()
                tpr_dict[class_name] = tpr.tolist()
            else:
                auc_scores[class_name] = 0.0
                fpr_dict[class_name] = [0.0, 1.0]
                tpr_dict[class_name] = [0.0, 0.0]

        # Macro AUC for available classes
        available_auc_scores = [auc_scores[CASME2_CLASSES[i]] for i in present_labels]
        macro_auc = float(np.mean(available_auc_scores)) if available_auc_scores else 0.0

    except Exception as e:
        # AUC is best-effort: report and fall back instead of aborting the evaluation.
        print(f"Warning: AUC calculation failed: {e}")
        auc_scores = {class_name: 0.0 for class_name in CASME2_CLASSES}
        fpr_dict = {class_name: [0.0, 1.0] for class_name in CASME2_CLASSES}
        tpr_dict = {class_name: [0.0, 0.0] for class_name in CASME2_CLASSES}
        macro_auc = 0.0

    return auc_scores, fpr_dict, tpr_dict, macro_auc


def _compute_subject_performance_casme2_deit(subjects, predictions, labels):
    """Per-subject accuracy, sample count and correct count, in first-appearance order."""
    indices_by_subject = {}
    for idx, subject in enumerate(subjects):
        indices_by_subject.setdefault(subject, []).append(idx)

    subject_performance = {}
    for subject, indices in indices_by_subject.items():
        subject_predictions = predictions[indices]
        subject_labels = labels[indices]
        subject_performance[subject] = {
            'accuracy': float(accuracy_score(subject_labels, subject_predictions)),
            'samples': int(len(subject_predictions)),
            'correct': int(np.sum(subject_predictions == subject_labels))
        }
    return subject_performance


def calculate_comprehensive_metrics_casme2_deit(inference_results):
    """Calculate comprehensive evaluation metrics for CASME II DeiT micro-expression recognition.

    Args:
        inference_results: dict returned by run_model_inference_casme2_deit.
            Reads 'predictions', 'labels' (class-index arrays), 'probabilities'
            (per-class softmax array), 'subjects' and 'inference_time'.

    Returns:
        dict with evaluation metadata, overall macro metrics (computed over the
        classes present in the test set only), per-class metrics for every class
        in CASME2_CLASSES, the confusion matrix, per-subject accuracy, ROC data
        and inference timing.

    Raises:
        ValueError: if there are no predictions.
    """
    print("Calculating comprehensive metrics for CASME II DeiT micro-expression recognition...")

    predictions = inference_results['predictions']
    probabilities = inference_results['probabilities']
    labels = inference_results['labels']

    if len(predictions) == 0:
        raise ValueError("No predictions to evaluate!")

    # Derive the class range from the class list instead of hard-coding 7
    all_class_indices = list(range(len(CASME2_CLASSES)))

    # Identify available classes in test set
    unique_test_labels = sorted(np.unique(labels))
    unique_predictions = sorted(np.unique(predictions))

    print(f"Test set contains labels: {[CASME2_CLASSES[i] for i in unique_test_labels]}")
    print(f"DeiT model predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    # Basic metrics
    accuracy = accuracy_score(labels, predictions)

    # Macro metrics (only for classes present in the test set)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, labels=unique_test_labels, average='macro', zero_division=0
    )

    print(f"Macro F1 (available classes): {f1:.4f}")

    # Per-class metrics (all classes, including those missing from the test set)
    precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
        labels, predictions, labels=all_class_indices, average=None, zero_division=0
    )

    # Confusion matrix over every class so its shape is fixed
    cm = confusion_matrix(labels, predictions, labels=all_class_indices)

    auc_scores, fpr_dict, tpr_dict, macro_auc = _compute_roc_auc_casme2_deit(
        labels, probabilities, unique_test_labels
    )

    subject_performance = _compute_subject_performance_casme2_deit(
        inference_results['subjects'], predictions, labels
    )

    comprehensive_results = {
        'evaluation_metadata': {
            'model_type': EVALUATION_CONFIG_CASME2_DEIT['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2_DEIT['dataset_name'],
            'architecture': EVALUATION_CONFIG_CASME2_DEIT['architecture'],
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'num_classes': EVALUATION_CONFIG_CASME2_DEIT['num_classes'],
            'class_names': EVALUATION_CONFIG_CASME2_DEIT['class_names'],
            'test_samples': int(len(labels)),
            'available_classes': [CASME2_CLASSES[i] for i in unique_test_labels],
            'missing_classes': [CASME2_CLASSES[i] for i in all_class_indices if i not in unique_test_labels]
        },

        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': macro_auc
        },

        'per_class_performance': {},

        'confusion_matrix': cm.tolist(),

        'subject_level_performance': subject_performance,

        'roc_analysis': {
            'auc_scores': auc_scores,
            'fpr_curves': fpr_dict,
            'tpr_curves': tpr_dict
        },

        'inference_performance': {
            'total_time_seconds': float(inference_results['inference_time']),
            'average_time_ms_per_sample': float(inference_results['inference_time'] * 1000 / len(labels))
        }
    }

    # Per-class performance details
    for i, class_name in enumerate(CASME2_CLASSES):
        comprehensive_results['per_class_performance'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i]),
            'support': int(support_per_class[i]),
            'auc': auc_scores[class_name],
            'in_test_set': i in unique_test_labels
        }

    return comprehensive_results

def save_evaluation_results_casme2_deit(evaluation_results, wrong_predictions_results, results_dir):
    """Persist the DeiT evaluation and wrong-prediction analyses as indented JSON.

    Creates ``results_dir`` if needed and writes two fixed-name files into it.
    Values json cannot encode natively are stringified via ``default=str``.

    Returns:
        (main_results_path, wrong_predictions_path)
    """
    os.makedirs(results_dir, exist_ok=True)

    targets = (
        (f"{results_dir}/casme2_deit_direct_evaluation_results.json", evaluation_results),
        (f"{results_dir}/casme2_deit_direct_wrong_predictions.json", wrong_predictions_results),
    )
    for destination, payload in targets:
        with open(destination, 'w') as handle:
            json.dump(payload, handle, indent=2, default=str)

    main_path = targets[0][0]
    wrong_path = targets[1][0]

    print("DeiT evaluation results saved:")
    print(f"  Main results: {main_path}")
    print(f"  Wrong predictions: {wrong_path}")

    return main_path, wrong_path

# Main evaluation execution for DeiT
# Runs the Cell 3 pipeline end to end:
#   1. build the test dataset and loader,
#   2. load the checkpoint,
#   3. run inference,
#   4. compute metrics and the wrong-prediction analysis,
#   5. save both JSON files,
#   6. print a console report.
# Depends on names defined in earlier cells that are not visible here:
#   GLOBAL_CONFIG_CASME2, CASME2_DEIT_CONFIG, EVALUATION_CONFIG_CASME2_DEIT,
#   CASME2DatasetEvaluationDeiT, load_trained_model_casme2_deit,
#   run_model_inference_casme2_deit.
try:
    print("Starting CASME II DeiT Direct Baseline comprehensive evaluation...")

    # Create test dataset
    print("Creating CASME II test dataset for DeiT...")
    # The dataset root is derived from the configured test path by removing '/test'.
    # NOTE(review): str.replace removes every '/test' occurrence, not only a trailing
    # one — confirm the configured path cannot contain it elsewhere.
    casme2_test_dataset = CASME2DatasetEvaluationDeiT(
        split_metadata=GLOBAL_CONFIG_CASME2['metadata'],
        dataset_root=GLOBAL_CONFIG_CASME2['test_path'].replace('/test', ''),
        transform=GLOBAL_CONFIG_CASME2['transform_val'],
        split='test',
        use_ram_cache=True
    )

    if len(casme2_test_dataset) == 0:
        raise ValueError("No test samples found! Check test data path.")

    # shuffle=False gives a deterministic sample order across evaluation runs.
    casme2_test_loader = DataLoader(
        casme2_test_dataset,
        batch_size=CASME2_DEIT_CONFIG['batch_size'],
        shuffle=False,
        num_workers=CASME2_DEIT_CONFIG['num_workers'],
        pin_memory=True
    )

    # Load trained DeiT model
    # training_info (second return value) is embedded in the saved results and
    # read again for the training-vs-test report below.
    checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/{EVALUATION_CONFIG_CASME2_DEIT['checkpoint_file']}"
    casme2_deit_model, training_info = load_trained_model_casme2_deit(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

    # Run DeiT inference
    inference_results = run_model_inference_casme2_deit(casme2_deit_model, casme2_test_loader, GLOBAL_CONFIG_CASME2['device'])

    # Calculate comprehensive metrics
    evaluation_results = calculate_comprehensive_metrics_casme2_deit(inference_results)

    # Analyze wrong predictions
    wrong_predictions_results = analyze_wrong_predictions_casme2_deit(inference_results)

    # Add training information
    evaluation_results['training_information'] = training_info

    # Save results
    # Files are written before the console report, so an error while printing
    # below does not lose the saved JSON.
    results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"
    results_file, wrong_file = save_evaluation_results_casme2_deit(
        evaluation_results, wrong_predictions_results, results_dir
    )

    # Display comprehensive results
    print("\n" + "=" * 65)
    print("CASME II DEIT TRANSFORMER DIRECT BASELINE EVALUATION RESULTS")
    print("=" * 65)

    # Overall performance
    overall = evaluation_results['overall_performance']
    print(f"Overall Performance (Macro - Available Classes):")
    print(f"  Accuracy:  {overall['accuracy']:.4f}")
    print(f"  Precision: {overall['macro_precision']:.4f}")
    print(f"  Recall:    {overall['macro_recall']:.4f}")
    print(f"  F1 Score:  {overall['macro_f1']:.4f}")
    print(f"  AUC:       {overall['macro_auc']:.4f}")

    # Per-class performance
    print(f"\nPer-Class Performance:")
    for class_name, metrics in evaluation_results['per_class_performance'].items():
        in_test = "Present" if metrics['in_test_set'] else "Missing"
        print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
              f"AUC={metrics['auc']:.4f}, Support={metrics['support']}")

    # Training vs test comparison
    # Differences are training minus test, so a positive value means validation
    # scored higher than test.
    # NOTE(review): training_info keys are indexed directly; if one is missing, the
    # resulting KeyError is reported by the except block as "DeiT evaluation failed"
    # even though results were already saved.
    print(f"\nTraining vs Test Performance:")
    training_f1 = training_info['best_val_f1']
    training_acc = training_info['best_val_accuracy']
    test_f1 = overall['macro_f1']
    test_acc = overall['accuracy']

    print(f"  Training Val F1:  {training_f1:.4f}")
    print(f"  Test F1:          {test_f1:.4f}")
    print(f"  F1 Difference:    {training_f1 - test_f1:+.4f}")
    print(f"  Training Val Acc: {training_acc:.4f}")
    print(f"  Test Accuracy:    {test_acc:.4f}")
    print(f"  Acc Difference:   {training_acc - test_acc:+.4f}")
    print(f"  Best Epoch:       {training_info['best_epoch']}")
    print(f"  DeiT Variant:     {training_info['deit_variant']}")
    print(f"  Architecture:     {training_info['architecture_approach']}")

    # Wrong predictions summary
    print(f"\n" + "=" * 40)
    print("WRONG PREDICTIONS ANALYSIS")
    print("=" * 40)

    wrong_meta = wrong_predictions_results['analysis_metadata']
    print(f"Total wrong predictions: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
    print(f"Overall error rate: {wrong_meta['overall_error_rate']:.2f}%")

    print(f"\nErrors by True Class:")
    for class_name, error_count in wrong_predictions_results['error_summary'].items():
        if error_count > 0:
            wrong_samples = wrong_predictions_results['wrong_predictions_by_class'][class_name]
            print(f"  {class_name}: {error_count} errors")
            for sample in wrong_samples[:3]:  # Show first 3
                print(f"    - {sample['filename']} -> predicted as {sample['predicted_class']}")
            if len(wrong_samples) > 3:
                print(f"    ... and {len(wrong_samples) - 3} more")

    # Subject-level analysis
    # Subjects are sorted by accuracy (best first); only the first five are printed.
    print(f"\nSubject-Level Performance:")
    subject_perfs = list(evaluation_results['subject_level_performance'].items())
    subject_perfs.sort(key=lambda x: x[1]['accuracy'], reverse=True)
    for subject, perf in subject_perfs[:5]:  # Show top 5
        print(f"  {subject}: {perf['accuracy']:.3f} ({perf['correct']}/{perf['samples']})")

    # Most common confusion patterns
    print(f"\nMost Common Confusion Patterns:")
    patterns = sorted(wrong_predictions_results['confusion_patterns'].items(),
                     key=lambda x: x[1], reverse=True)
    for pattern, count in patterns[:3]:
        print(f"  {pattern}: {count} cases")

    print(f"\nInference Performance:")
    print(f"  Total time: {inference_results['inference_time']:.2f}s")
    print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    print(f"\nMissing Classes: {evaluation_results['evaluation_metadata']['missing_classes']}")

    print(f"\nArchitecture Features:")
    print(f"  Fixed Architecture: DeiTModel + Custom Classification Head")
    print(f"  Position Interpolation: {training_info['position_interpolation']}")
    print(f"  Medical-Proven Approach: Verified compatibility with 384px input")

    print("\n" + "=" * 65)
    print("CASME II DEIT TRANSFORMER DIRECT BASELINE EVALUATION COMPLETED")
    print("=" * 65)

except Exception as e:
    # Top-level boundary of the notebook cell: print the error and traceback
    # instead of re-raising, so the finally block and the trailing print still run.
    print(f"DeiT evaluation failed: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Memory cleanup
    # Release cached CUDA memory whether or not the evaluation succeeded.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

print("Next: Cell 4 - DeiT Confusion Matrix Analysis and Visualization")

CASME II DeiT Direct Baseline Evaluation Framework
Model: DeiT_CASME2_Direct_Baseline
Task: micro_expression_recognition
Classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
Input size: 384x384
Architecture: deit_transformer_distillation
Starting CASME II DeiT Direct Baseline comprehensive evaluation...
Creating CASME II test dataset for DeiT...
Loading CASME II test dataset for DeiT evaluation...
Loaded 28 CASME II test samples for DeiT evaluation
Test set class distribution:
  others: 10 samples (35.7%)
  disgust: 7 samples (25.0%)
  happiness: 4 samples (14.3%)
  repression: 3 samples (10.7%)
  surprise: 3 samples (10.7%)
  sadness: 1 samples (3.6%)
Test set covers 16 subjects
Missing classes in test set: ['fear']
Preloading 28 test images to RAM for DeiT evaluation...
Test RAM caching completed: 28 valid images, ~0.05GB
Loading trained CASME II DeiT model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/02

CASME II DeiT Inference: 100%|██████████| 2/2 [00:01<00:00,  1.88it/s]

CASME II DeiT inference completed: 28 samples in 1.06s
Predicted classes: ['others', 'disgust', 'happiness', 'repression', 'sadness']
True classes in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Calculating comprehensive metrics for CASME II DeiT micro-expression recognition...
Test set contains labels: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
DeiT model predicted classes: ['others', 'disgust', 'happiness', 'repression', 'sadness']
Macro F1 (available classes): 0.2554
Analyzing wrong predictions for CASME II DeiT micro-expression recognition...
DeiT evaluation results saved:
  Main results: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/02_04_deit_casme2-af/evaluation_results/casme2_deit_direct_evaluation_results.json
  Wrong predictions: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/02_04_deit_casme2-af/evaluation_results/casme2_deit_direct_wrong




In [None]:
# @title Cell 4: CASME II DeiT Direct Baseline Confusion Matrix Generation

# File: 02_04_DeiT_Direct_Baseline_Cell4.py
# Location: experiments/02_04_DeiT_Direct_Baseline.ipynb
# Purpose: Generate professional confusion matrix and comprehensive analysis for CASME II DeiT micro-expression recognition
# Dependencies: Trained model evaluation results from Cell 3

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

# Cell banner: title line followed by a 65-character rule.
print("CASME II DeiT Direct Baseline Confusion Matrix Generation", "=" * 65, sep="\n")

# Project paths configuration - DeiT specific
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = PROJECT_ROOT + "/results/02_04_deit_casme2-af"

def find_evaluation_json_files_casme2_deit(results_path):
    """Find CASME II DeiT evaluation JSON files.

    Looks in ``{results_path}/evaluation_results`` for the two fixed filenames
    written by the Cell 3 evaluation.

    Args:
        results_path: experiment results root directory.

    Returns:
        dict that may contain 'main' (evaluation results) and
        'wrong_predictions' (wrong-prediction analysis) mapped to file paths;
        empty if the directory or files are missing.
    """
    json_files = {}

    eval_dir = f"{results_path}/evaluation_results"

    if not os.path.exists(eval_dir):
        print(f"ERROR: DeiT evaluation directory not found: {eval_dir}")
        return json_files

    # The filenames are fixed, so check them directly. glob.glob would treat
    # characters such as '[', '*' or '?' in results_path as wildcards and could
    # silently miss existing files.
    main_file = f"{eval_dir}/casme2_deit_direct_evaluation_results.json"
    if os.path.isfile(main_file):
        json_files['main'] = main_file
        print(f"Found CASME II DeiT evaluation file: {os.path.basename(main_file)}")

    wrong_file = f"{eval_dir}/casme2_deit_direct_wrong_predictions.json"
    if os.path.isfile(wrong_file):
        json_files['wrong_predictions'] = wrong_file
        print(f"Found DeiT wrong predictions file: {os.path.basename(wrong_file)}")

    if not json_files:
        print(f"WARNING: No DeiT evaluation results found in {eval_dir}")
        print("Make sure Cell 3 (DeiT evaluation) has been executed first!")

    return json_files

def load_evaluation_results_casme2_deit(json_path):
    """Load and parse CASME II DeiT evaluation results JSON.

    Args:
        json_path: path to a JSON file written by the Cell 3 evaluation.

    Returns:
        The parsed JSON object, or None if the file cannot be read or does not
        contain valid JSON (the error is printed).
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError covers json.JSONDecodeError
        # and UnicodeDecodeError. Anything else is a real bug and propagates.
        print(f"ERROR loading {json_path}: {str(e)}")
        return None
    print(f"Successfully loaded DeiT evaluation results from: {os.path.basename(json_path)}")
    return data

def calculate_weighted_f1_casme2_deit(per_class_performance):
    """Support-weighted mean F1 over CASME II classes that have test samples.

    Args:
        per_class_performance: mapping of class name -> dict with at least
            'support' (number of test samples) and 'f1_score'.

    Returns:
        Sum of f1_score * support / total_support over classes with support > 0,
        or 0.0 when no class has any test samples.
    """
    # Classes absent from the test set (support 0) carry no weight.
    represented = [entry for entry in per_class_performance.values() if entry['support'] > 0]
    total_support = sum(entry['support'] for entry in represented)

    if total_support == 0:
        return 0.0

    return sum(entry['f1_score'] * (entry['support'] / total_support) for entry in represented)

def calculate_balanced_accuracy_casme2_deit(confusion_matrix, *, method="ovr"):
    """
    Calculate balanced accuracy for CASME II DeiT micro-expression recognition.

    Classes with zero support (rows summing to 0, i.e. missing from the test set)
    are excluded from the average.

    Args:
        confusion_matrix: square matrix (nested lists or array), rows = true
            class, columns = predicted class.
        method: which definition to use (keyword-only).
            - "ovr" (default, the historical behavior of this notebook): for each
              present class, compute the one-vs-rest binary balanced accuracy
              (sensitivity + specificity) / 2, then average across classes.
              NOTE: this is NOT sklearn's multiclass balanced accuracy and is
              typically higher, because per-class specificity is usually large.
            - "recall": mean per-class recall over present classes, which matches
              sklearn.metrics.balanced_accuracy_score.

    Returns:
        float balanced accuracy, or 0.0 if no class has test samples.

    Raises:
        ValueError: if ``method`` is not "ovr" or "recall".
    """
    if method not in ("ovr", "recall"):
        raise ValueError(f"Unknown balanced accuracy method: {method!r} (expected 'ovr' or 'recall')")

    cm = np.asarray(confusion_matrix, dtype=float)
    support = cm.sum(axis=1)      # true samples per class (TP + FN)
    present = support > 0         # classes actually present in the test set
    if not present.any():
        return 0.0

    tp = np.diag(cm)
    sensitivity = tp[present] / support[present]

    if method == "recall":
        return float(np.mean(sensitivity))

    fp = cm.sum(axis=0) - tp          # column sum minus diagonal
    tn = cm.sum() - support - fp      # total - (TP + FN) - FP
    negatives = (tn + fp)[present]
    specificity = np.divide(tn[present], negatives,
                            out=np.zeros_like(negatives), where=negatives > 0)

    return float(np.mean((sensitivity + specificity) / 2))

def determine_text_color_casme2_deit(color_value, threshold=0.5):
    """Return 'white' for cells strictly above ``threshold`` (dark background), else 'black'."""
    if color_value > threshold:
        return 'white'
    return 'black'

def analyze_missing_classes_casme2_deit(data):
    """Print and return the test-set class coverage and DeiT settings.

    Args:
        data: evaluation results dict. Reads 'evaluation_metadata'
            ('available_classes', 'missing_classes' optional; 'class_names'
            required) and optional 'training_information'
            ('deit_variant', 'distillation_token').

    Returns:
        dict with 'available', 'missing', 'total_classes', 'deit_variant'
        and 'distillation_token'.
    """
    meta = data['evaluation_metadata']
    train_meta = data.get('training_information', {})

    summary = {
        'available': meta.get('available_classes', []),
        'missing': meta.get('missing_classes', []),
    }
    variant = train_meta.get('deit_variant', 'unknown')
    uses_distillation = train_meta.get('distillation_token', False)

    print("DeiT Analysis:")
    print(f"  Variant: {variant}")
    print(f"  Distillation Token: {uses_distillation}")
    print(f"  Available in test: {summary['available']}")
    print(f"  Missing from test: {summary['missing']}")

    # class_names is looked up only after reporting, matching the original order.
    summary['total_classes'] = len(meta['class_names'])
    summary['deit_variant'] = variant
    summary['distillation_token'] = uses_distillation
    return summary

def create_confusion_matrix_plot_casme2_deit(data, output_path):
    """Create professional confusion matrix visualization for CASME II DeiT micro-expression recognition.

    Draws a row-normalised heatmap. Each cell is annotated with its raw count
    and row percentage, or "(N/A)" for classes with no test samples. The plot
    also gets a note box (missing classes, DeiT variant) and a title with
    summary metrics, and is saved to ``output_path``.

    Args:
        data: evaluation results dict as saved by Cell 3. Reads
            'evaluation_metadata' ('class_names', optional 'missing_classes'),
            'confusion_matrix', 'overall_performance', 'per_class_performance'
            and optional 'training_information' ('deit_variant',
            'distillation_token').
        output_path: destination image path passed to plt.savefig.

    Returns:
        dict with accuracy, macro_f1, weighted_f1, balanced_accuracy,
        missing_classes, deit_variant and distillation_token.
    """

    # Extract data
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']
    training_info = data.get('training_information', {})
    deit_variant = training_info.get('deit_variant', 'unknown')
    distillation_token = training_info.get('distillation_token', False)

    print(f"Processing DeiT confusion matrix for CASME II classes: {class_names}")
    print(f"DeiT variant: {deit_variant}")
    print(f"Distillation token: {distillation_token}")
    print(f"Confusion matrix shape: {cm.shape}")

    # Calculate comprehensive metrics
    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1_casme2_deit(per_class)
    balanced_acc = calculate_balanced_accuracy_casme2_deit(cm)

    print(f"Calculated metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Balanced Acc: {balanced_acc:.4f}, Accuracy: {accuracy:.4f}")

    # Row-wise normalization for percentage display.
    # np.divide with `where=` leaves masked-out entries UNINITIALIZED unless `out=`
    # is given, so rows of classes with no test samples must start from zeros.
    # Otherwise they would contain arbitrary values that drive the cell colours.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_pct = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=(row_sums != 0))

    # Create visualization with appropriate size for 7 classes
    fig, ax = plt.subplots(figsize=(12, 10))

    # Color scheme optimized for DeiT micro-expression research
    cmap = 'Blues'

    # Create heatmap; fractions above 0.8 saturate to the darkest colour
    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    # Annotate cells with count and percentage
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            # Handle percentage calculation for classes with 0 samples
            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"  # Class has no test samples

            # Determine text color based on cell intensity
            cell_value = cm_pct[i, j]
            text_color = determine_text_color_casme2_deit(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    # Configure axes with micro-expression class names
    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    # Add note about missing classes and DeiT architecture
    missing_classes = meta.get('missing_classes', [])
    note_lines = []
    if missing_classes:
        note_lines.append(f"Missing classes: {', '.join(missing_classes)}")

    # DeiT-specific architecture note
    distillation_text = "With Distillation Token" if distillation_token else "No Distillation Token"
    note_lines.append(f"DeiT-{deit_variant.capitalize()} | {distillation_text}")

    if note_lines:
        note_text = "\n".join(note_lines)
        ax.text(0.02, 0.98, note_text, transform=ax.transAxes, fontsize=9,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    # Create comprehensive title for DeiT micro-expression research
    title = f"CASME II DeiT ({deit_variant.capitalize()}) Micro-Expression Recognition\n"
    title += f"Acc: {accuracy:.4f}  |  Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"DeiT confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes,
        'deit_variant': deit_variant,
        'distillation_token': distillation_token
    }

def create_per_class_performance_chart_casme2_deit(data, output_path):
    """Render and save a two-panel per-class chart for CASME II DeiT.

    The top panel shows grouped F1 / precision / recall bars per class. Each bar
    is annotated with its value, or with a red 'N/A' when the class is flagged as
    absent from the test set. The bottom panel shows the number of test samples
    per class.

    Args:
        data: evaluation results dict. Reads 'per_class_performance',
            'evaluation_metadata' ('class_names') and optional
            'training_information' ('deit_variant', 'distillation_token').
        output_path: destination image path passed to plt.savefig.
    """
    per_class = data['per_class_performance']
    class_names = data['evaluation_metadata']['class_names']
    train_meta = data.get('training_information', {})
    variant = train_meta.get('deit_variant', 'unknown')
    uses_distillation = train_meta.get('distillation_token', False)

    # Gather one column per metric, in class_names order
    metric_keys = ('f1_score', 'precision', 'recall', 'support', 'in_test_set')
    columns = {key: [] for key in metric_keys}
    classes = []
    for name in class_names:
        entry = per_class[name]
        classes.append(name)
        for key in metric_keys:
            columns[key].append(entry[key])

    f1_scores = columns['f1_score']
    precisions = columns['precision']
    recalls = columns['recall']
    supports = columns['support']
    in_test_flags = columns['in_test_set']

    fig, (score_ax, support_ax) = plt.subplots(2, 1, figsize=(12, 10))

    positions = np.arange(len(classes))
    bar_width = 0.25

    # Top panel: three bars per class, offset left / centre / right
    bar_specs = (
        (positions - bar_width, f1_scores, 'F1 Score', 'steelblue'),
        (positions, precisions, 'Precision', 'orange'),
        (positions + bar_width, recalls, 'Recall', 'green'),
    )
    metric_bar_groups = [
        score_ax.bar(xs, heights, bar_width, label=label, alpha=0.8, color=color)
        for xs, heights, label, color in bar_specs
    ]

    score_ax.set_xlabel('Emotion Classes', fontweight='bold')
    score_ax.set_ylabel('Score', fontweight='bold')

    if uses_distillation:
        distill_phrase = "with Distillation Token"
    else:
        distill_phrase = "without Distillation Token"
    score_ax.set_title(f'CASME II DeiT ({variant.capitalize()}) Per-Class Performance - {distill_phrase}',
                       fontweight='bold', pad=20)

    score_ax.set_xticks(positions)
    score_ax.set_xticklabels(classes, rotation=45, ha='right')
    score_ax.legend()
    score_ax.grid(axis='y', alpha=0.3)
    score_ax.set_ylim(0, 1.0)

    for group in metric_bar_groups:
        for bar, present in zip(group, in_test_flags):
            bar_top = bar.get_height()
            anchor = (bar.get_x() + bar.get_width() / 2, bar_top)
            if not present:
                # Classes absent from the test set are flagged instead of scored
                score_ax.annotate('N/A', xy=anchor,
                                  xytext=(0, 3), textcoords="offset points", ha='center', va='bottom',
                                  fontsize=8, color='red', fontweight='bold')
                continue
            score_ax.annotate(f'{bar_top:.3f}', xy=anchor,
                              xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)

    # Bottom panel: test-set sample counts
    support_bars = support_ax.bar(positions, supports, color='purple', alpha=0.7)
    support_ax.set_xlabel('Emotion Classes', fontweight='bold')
    support_ax.set_ylabel('Number of Test Samples', fontweight='bold')
    support_ax.set_title('CASME II Test Set Class Distribution (DeiT)', fontweight='bold')
    support_ax.set_xticks(positions)
    support_ax.set_xticklabels(classes, rotation=45, ha='right')
    support_ax.grid(axis='y', alpha=0.3)

    for bar, count in zip(support_bars, supports):
        support_ax.annotate(f'{int(count)}', xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                            xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=10)

    # Architecture footnote
    arch_phrase = "Teacher-Student Knowledge Transfer" if uses_distillation else "Standard Vision Transformer"
    fig.text(0.02, 0.02, f"Architecture: DeiT-{variant.capitalize()} | {arch_phrase}",
             fontsize=9, style='italic', alpha=0.7)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"DeiT per-class performance chart saved to: {os.path.basename(output_path)}")

def _format_metric(value, precision=4, width=0):
    """Return ``value`` as a fixed-point string, left-aligned to ``width``.

    Falls back to ``'N/A'`` when the value is missing (``None``) or not numeric,
    so one absent metric (e.g. an AUC that could not be computed for a class
    with no test samples) does not abort the whole summary with a TypeError.
    """
    try:
        text = f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        text = "N/A"
    return text.ljust(width)


def generate_performance_summary_casme2_deit(evaluation_data, wrong_predictions_data=None):
    """Print a human-readable performance summary for the CASME II DeiT run.

    Args:
        evaluation_data: Parsed evaluation-results dict. Must contain
            'overall_performance', 'evaluation_metadata' and
            'per_class_performance'. 'training_information' and
            'inference_performance' are optional: missing fields are shown
            as 'N/A' and a missing inference section is skipped.
        wrong_predictions_data: Optional parsed wrong-predictions dict with
            'analysis_metadata' and, optionally, 'confusion_patterns'
            (pattern name -> count).

    Returns:
        None. Everything is written to stdout.
    """
    print("\n" + "=" * 65)
    print("CASME II DEIT TRANSFORMER MICRO-EXPRESSION RECOGNITION PERFORMANCE SUMMARY")
    print("=" * 65)

    # Required sections: fail loudly if the evaluation JSON is malformed.
    overall = evaluation_data['overall_performance']
    meta = evaluation_data['evaluation_metadata']
    per_class = evaluation_data['per_class_performance']

    # `or` also covers an explicit JSON null, which a .get() default would not.
    training_info = evaluation_data.get('training_information') or {}
    deit_variant = str(training_info.get('deit_variant') or 'unknown')
    distillation_token = training_info.get('distillation_token', False)

    print(f"Dataset: {meta['dataset']}")
    print(f"Test samples: {meta['test_samples']}")
    print(f"Model: {meta['model_type']}")
    print(f"Architecture: DeiT-{deit_variant.capitalize()} with "
          f"{'Distillation Token' if distillation_token else 'Standard Processing'}")
    print(f"Evaluation date: {meta['evaluation_timestamp']}")

    print("\nOverall Performance:")
    print(f"  Accuracy:         {_format_metric(overall.get('accuracy'))}")
    print(f"  Macro Precision:  {_format_metric(overall.get('macro_precision'))}")
    print(f"  Macro Recall:     {_format_metric(overall.get('macro_recall'))}")
    print(f"  Macro F1:         {_format_metric(overall.get('macro_f1'))}")
    print(f"  Macro AUC:        {_format_metric(overall.get('macro_auc'))}")

    # Per-class table; classes absent from the test set may carry undefined metrics.
    print("\nPer-Class Performance:")
    print(f"{'Class':<12} {'F1':<8} {'Precision':<10} {'Recall':<8} {'AUC':<8} {'Support':<8} {'In Test'}")
    print("-" * 65)

    for class_name, metrics in per_class.items():
        in_test = "Yes" if metrics.get('in_test_set') else "No"
        print(f"{class_name:<12} "
              f"{_format_metric(metrics.get('f1_score'), width=8)} "
              f"{_format_metric(metrics.get('precision'), width=10)} "
              f"{_format_metric(metrics.get('recall'), width=8)} "
              f"{_format_metric(metrics.get('auc'), width=8)} "
              f"{metrics.get('support', 'N/A'):<8} {in_test}")

    # Training vs test performance (same presence check as before, but every
    # field is read defensively so a partial training record still prints).
    if 'training_information' in evaluation_data:
        best_val_f1 = training_info.get('best_val_f1')
        test_f1 = overall.get('macro_f1')
        try:
            gap = f"{float(best_val_f1) - float(test_f1):+.4f}"
        except (TypeError, ValueError):
            gap = "N/A"
        print("\nTraining vs Test Comparison:")
        print(f"  Training Val F1:  {_format_metric(best_val_f1)}")
        print(f"  Test F1:          {_format_metric(test_f1)}")
        print(f"  Performance Gap:  {gap}")
        print(f"  Best Epoch:       {training_info.get('best_epoch', 'N/A')}")
        print(f"  DeiT Variant:     {deit_variant}")

    # Class imbalance analysis
    missing_classes = meta.get('missing_classes', [])
    available_classes = meta.get('available_classes', [])
    # per_class_performance is assumed to list every class, including those
    # absent from the test set (its 'in_test_set' flag marks them), so its
    # length is the total; 7 (the CASME II class count) is only a fallback.
    total_classes = len(per_class) or 7

    print("\nClass Availability Analysis:")
    print(f"  Available classes: {len(available_classes)}/{total_classes}")
    print(f"  Missing classes: {missing_classes if missing_classes else 'None'}")

    # Architecture-specific information
    print("\nDeiT Architecture Details:")
    print(f"  Variant: {deit_variant}")
    print(f"  Distillation Token: {'Enabled' if distillation_token else 'Disabled'}")
    if distillation_token:
        print("  Knowledge Transfer: CNN Teacher -> DeiT Student")
        print("  Token Processing: CLS + Distillation + Patch Tokens")
        print("  Training Strategy: Data-Efficient Learning")
    else:
        print("  Token Processing: Standard CLS + Patch Tokens")
        print("  Training Strategy: Standard Vision Transformer")

    # Wrong predictions summary if available
    if wrong_predictions_data:
        wrong_meta = wrong_predictions_data.get('analysis_metadata') or {}
        print("\nError Analysis:")
        print(f"  Total errors: {wrong_meta.get('total_wrong_predictions', 'N/A')}"
              f"/{wrong_meta.get('total_samples', 'N/A')}")
        print(f"  Error rate: {_format_metric(wrong_meta.get('overall_error_rate'), precision=2)}%")

        # Top confusion patterns, most frequent first
        patterns = wrong_predictions_data.get('confusion_patterns', {})
        if patterns:
            print("\nTop Confusion Patterns:")
            top_patterns = sorted(patterns.items(), key=lambda item: item[1], reverse=True)[:3]
            for pattern, count in top_patterns:
                print(f"  {pattern}: {count} cases")

    inference = evaluation_data.get('inference_performance')
    if inference:
        print("\nInference Performance:")
        print(f"  Total time: {_format_metric(inference.get('total_time_seconds'), precision=2)}s")
        print(f"  Speed: {_format_metric(inference.get('average_time_ms_per_sample'), precision=1)} ms/sample")

    # Distillation-specific insights
    if distillation_token:
        print("\nDistillation Token Analysis:")
        print("  Teacher-Student Knowledge Transfer: Active")
        print("  Dual-Head Classification: CLS + Distillation tokens")
        print("  Data Efficiency: Enhanced through distillation")

# Locate the DeiT evaluation JSON outputs produced by Cell 3.
json_files = find_evaluation_json_files_casme2_deit(RESULTS_ROOT)

if not json_files:
    print(f"ERROR: No DeiT evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (DeiT evaluation) has been executed first!")
else:
    print(f"Found {len(json_files)} DeiT evaluation file(s)")

# Create output directory for the generated figures.
output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Process DeiT evaluation results
results_summary = {}
generated_files = []

# The truthiness check keeps a None result from the finder from raising
# TypeError on `in`; an empty result falls through to the same error branch.
if json_files and 'main' in json_files:
    # Load main evaluation data
    eval_data = load_evaluation_results_casme2_deit(json_files['main'])

    # Load wrong predictions data if available
    wrong_data = None
    if 'wrong_predictions' in json_files:
        wrong_data = load_evaluation_results_casme2_deit(json_files['wrong_predictions'])

    if eval_data is not None:
        try:
            # Analyze missing classes for DeiT
            class_analysis = analyze_missing_classes_casme2_deit(eval_data)

            # Generate DeiT confusion matrix
            cm_output_path = os.path.join(output_dir, "confusion_matrix_CASME2_DeiT_Direct.png")
            metrics = create_confusion_matrix_plot_casme2_deit(eval_data, cm_output_path)
            generated_files.append(cm_output_path)

            # Generate DeiT per-class performance chart
            perf_output_path = os.path.join(output_dir, "per_class_performance_CASME2_DeiT_Direct.png")
            create_per_class_performance_chart_casme2_deit(eval_data, perf_output_path)
            generated_files.append(perf_output_path)

            results_summary['casme2_deit'] = metrics
            results_summary['casme2_deit']['class_analysis'] = class_analysis

            print("SUCCESS: DeiT visualization files generated successfully")

        except Exception as e:
            print(f"ERROR: Failed to generate DeiT visualizations: {str(e)}")
            import traceback
            traceback.print_exc()

        # The text summary is independent of the plots. It is guarded so that a
        # missing or malformed JSON field does not abort the cell before the
        # generated-files listing and final summary below are printed.
        try:
            generate_performance_summary_casme2_deit(eval_data, wrong_data)
        except Exception as e:
            print(f"ERROR: Failed to generate DeiT performance summary: {str(e)}")
            import traceback
            traceback.print_exc()

    else:
        print("ERROR: Could not load DeiT evaluation data")
else:
    print("ERROR: No main DeiT evaluation results found")

# Final summary
if generated_files:
    print("\n" + "=" * 65)
    print("CASME II DEIT TRANSFORMER CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 65)

    print("Generated visualization files:")
    for file_path in generated_files:
        print(f"  {os.path.basename(file_path)}")

    if 'casme2_deit' in results_summary:
        casme2_deit = results_summary['casme2_deit']
        # Architecture metadata may be absent from the metrics dict; fall back
        # to neutral defaults instead of raising KeyError at the very end.
        variant_label = str(casme2_deit.get('deit_variant') or 'unknown').capitalize()
        print("\nFinal Performance Summary:")
        print(f"  Architecture:      DeiT-{variant_label}")
        print(f"  Distillation:      {'Enabled' if casme2_deit.get('distillation_token', False) else 'Disabled'}")
        print(f"  Accuracy:          {casme2_deit['accuracy']:.4f}")
        print(f"  Macro F1:          {casme2_deit['macro_f1']:.4f}")
        print(f"  Weighted F1:       {casme2_deit['weighted_f1']:.4f}")
        print(f"  Balanced Acc:      {casme2_deit['balanced_accuracy']:.4f}")

        if 'class_analysis' in casme2_deit:
            analysis = casme2_deit['class_analysis']
            print(f"  Available classes: {len(analysis['available'])}/{analysis['total_classes']}")
            print(f"  Missing classes:   {len(analysis['missing'])}")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print("\nERROR: No DeiT visualizations were generated")
    print("Please check:")
    print("1. Cell 3 DeiT evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

print("\nCell 4 completed - CASME II DeiT confusion matrix analysis generated")
print("Ready for multi-architecture comparative analysis (ViT vs Swin vs EfficientViT vs DeiT)")

CASME II DeiT Direct Baseline Confusion Matrix Generation
Found CASME II DeiT evaluation file: casme2_deit_direct_evaluation_results.json
Found DeiT wrong predictions file: casme2_deit_direct_wrong_predictions.json
Found 2 DeiT evaluation file(s)
Successfully loaded DeiT evaluation results from: casme2_deit_direct_evaluation_results.json
Successfully loaded DeiT evaluation results from: casme2_deit_direct_wrong_predictions.json
DeiT Analysis:
  Variant: base
  Distillation Token: False
  Available in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
  Missing from test: ['fear']
Processing DeiT confusion matrix for CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
DeiT variant: base
Distillation token: False
Confusion matrix shape: (7, 7)
Calculated metrics - Macro F1: 0.2554, Weighted F1: 0.3719, Balanced Acc: 0.5721, Accuracy: 0.3929
DeiT confusion matrix saved to: confusion_matrix_CASME2_DeiT_Direct.png
DeiT