In [None]:
# @title Cell 1: CASME II Swin Transformer Infrastructure Configuration

# File: 02_02_SwinT_Direct_Optimized_Baseline_Cell1.py
# Location: experiments/02_02_SwinT_Direct_Baseline.ipynb
# Purpose: Optimized Swin Transformer for CASME II micro-expression recognition with hierarchical vision transformer and advanced class weight optimization

# Mount Google Drive so the project folders under /content/drive are reachable
from google.colab import drive

_BANNER_RULE = "=" * 60
for _banner_line in (
    _BANNER_RULE,
    "CASME II SWIN TRANSFORMER OPTIMIZED BASELINE INFRASTRUCTURE",
    _BANNER_RULE,
    "\n[1] Mounting Google Drive...",
):
    print(_banner_line)

drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import timm
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Project locations on Google Drive (Swin Transformer experiment folders)
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/data_split_v1"
_EXPERIMENT_DIR = "02_02_swint_casme2-af"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/{_EXPERIMENT_DIR}"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/{_EXPERIMENT_DIR}"

# JSON files describing the prepared dataset split
METADATA_TRAIN = f"{DATASET_ROOT}/split_metadata.json"
PROCESSING_SUMMARY = f"{DATASET_ROOT}/processing_summary.json"

print("CASME II Swin Transformer Optimized Baseline - Infrastructure Configuration")
print("=========================================================================")

# Read both JSON documents into memory
print("Loading CASME II dataset metadata...")
with open(METADATA_TRAIN, 'r') as _metadata_file:
    casme2_metadata = json.load(_metadata_file)

with open(PROCESSING_SUMMARY, 'r') as _summary_file:
    processing_info = json.load(_summary_file)

# Echo the three headline fields of the processing summary
for _label, _key in (
    ("Dataset", 'dataset'),
    ("Total samples", 'total_samples'),
    ("Split strategy", 'split_strategy'),
):
    print(f"{_label}: {processing_info[_key]}")

# =====================================================
# ADVANCED EXPERIMENT CONFIGURATION - Swin Transformer Optimized Parameters
# =====================================================

# FOCAL LOSS CONFIGURATION - Toggle and Advanced Parameters
USE_FOCAL_LOSS = True  # Set True to enable Focal Loss, False for CrossEntropy
FOCAL_LOSS_GAMMA = 2.0  # Focal loss focusing parameter (typically 1.0 - 3.0)

# OPTIMIZED CLASS WEIGHTS CONFIGURATION - Inverse Square Root Frequency Approach
# CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
# Per-class sample counts behind the weights: [99, 63, 32, 27, 25, 7, 2] (sum = 255).
# NOTE(review): these counts add up to 255, which appears to be the size of the
# whole dataset rather than of the train split alone - confirm against
# split_metadata.json and recompute the weights if train-only counts are intended.

# CrossEntropy Loss - inverse square root frequency weights relative to the
# largest class: weight_i = sqrt(99 / count_i), rounded to two decimals
CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.25, 1.76, 1.91, 1.99, 3.76, 7.04]

# Focal Loss - per-class alpha values: the CrossEntropy weights above divided by
# their sum (18.71) and rounded to three decimals, so the rounded values add up
# to 0.999 rather than exactly 1.0
FOCAL_LOSS_ALPHA_WEIGHTS = [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]

# SWIN TRANSFORMER MODEL CONFIGURATION - Support Tiny and Base variants
SWIN_MODEL_VARIANT = 'base'  # Options: 'tiny' or 'base'

# Lookup table: variant -> (checkpoint name, expected final hidden size, banner)
_SWIN_VARIANT_TABLE = {
    'tiny': (
        'microsoft/swin-tiny-patch4-window7-224',
        768,
        "Using Swin-Tiny for efficient hierarchical micro-expression analysis",
    ),
    'base': (
        'microsoft/swin-base-patch4-window7-224',
        1024,
        "Using Swin-Base for advanced hierarchical micro-expression recognition",
    ),
}

if SWIN_MODEL_VARIANT not in _SWIN_VARIANT_TABLE:
    raise ValueError(f"Unsupported SWIN_MODEL_VARIANT: {SWIN_MODEL_VARIANT}")

SWIN_MODEL_NAME, EXPECTED_HIDDEN_DIM, _variant_banner = _SWIN_VARIANT_TABLE[SWIN_MODEL_VARIANT]
# Both supported checkpoints share the same window and patch geometry
WINDOW_SIZE = 7
PATCH_SIZE = 4
print(_variant_banner)

# Print the chosen loss and model settings as one framed report
_report_rule = "=" * 50

if USE_FOCAL_LOSS:
    _loss_report = [
        "Loss Function: Focal Loss",
        f"  Gamma: {FOCAL_LOSS_GAMMA}",
        f"  Alpha Weights (per-class): {FOCAL_LOSS_ALPHA_WEIGHTS}",
        f"  Alpha Sum Validation: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}",
    ]
else:
    _loss_report = [
        "Loss Function: CrossEntropy Loss",
        f"  Class Weights (inverse sqrt freq): {CROSSENTROPY_CLASS_WEIGHTS}",
    ]

_config_report = [
    "\n" + _report_rule,
    "OPTIMIZED EXPERIMENT CONFIGURATION",
    _report_rule,
    *_loss_report,
    f"Swin Model: {SWIN_MODEL_NAME}",
    f"Window Size: {WINDOW_SIZE}x{WINDOW_SIZE}",
    f"Patch Size: {PATCH_SIZE}x{PATCH_SIZE}",
    f"Expected Hidden Dim: {EXPECTED_HIDDEN_DIM}",
    _report_rule,
]
for _report_line in _config_report:
    print(_report_line)

# Pick the compute device and describe the GPU (if any)
_cuda_present = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_present else 'cpu')
if _cuda_present:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
else:
    gpu_name = "CPU"
    gpu_memory = 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

# Batch size / worker count per recognised GPU model, checked in this order:
# (substring of the GPU name, batch size, dataloader workers, message)
_GPU_PROFILES = (
    ('A100', 20, 8, "A100: Optimized batch size for Swin Transformer 384px"),
    ('L4', 14, 6, "L4: Balanced performance configuration for Swin Transformer"),
)

for _marker, _batch, _workers, _message in _GPU_PROFILES:
    if _marker in gpu_name:
        BATCH_SIZE, NUM_WORKERS = _batch, _workers
        torch.backends.cudnn.benchmark = True
        print(_message)
        break
else:
    # No known GPU matched: fall back to the conservative defaults
    BATCH_SIZE, NUM_WORKERS = 8, 4
    print("Default GPU: Conservative settings for Swin Transformer")

# CASME II label vocabulary; a class's position in the list is its integer id
CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = dict(zip(CASME2_CLASSES, range(len(CASME2_CLASSES))))

# Per-split class histograms as stored in the split metadata
train_dist = casme2_metadata['train']['class_distribution']
val_dist = casme2_metadata['val']['class_distribution']
test_dist = casme2_metadata['test']['class_distribution']

for _split_label, _split_dist in (
    ("\nTrain", train_dist),
    ("Validation", val_dist),
    ("Test", test_dist),
):
    print(f"{_split_label} distribution: {_split_dist}")

# The active loss decides which hand-tuned weight vector becomes `class_weights`:
# normalized alpha values for Focal Loss, inverse-sqrt weights for CrossEntropy
_active_weights = FOCAL_LOSS_ALPHA_WEIGHTS if USE_FOCAL_LOSS else CROSSENTROPY_CLASS_WEIGHTS
class_weights = torch.tensor(_active_weights, dtype=torch.float32).to(device)

if USE_FOCAL_LOSS:
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

# CASME II Swin Transformer Optimized Configuration
# One dictionary collecting the model, optimisation, scheduler, loss and
# evaluation settings of this experiment. Other code in this notebook reads it
# by key (e.g. SwinCASME2Baseline reads 'swin_model' and 'expected_hidden_dim',
# create_optimizer_scheduler_casme2 reads the learning-rate and scheduler keys).
# NOTE: 'device' holds a torch.device and 'class_weights' a tensor, so the
# dictionary cannot be passed to json.dump without converting those entries.
CASME2_SWIN_CONFIG = {
    # Architecture configuration - Swin Transformer specific
    'swin_model': SWIN_MODEL_NAME,
    'swin_variant': SWIN_MODEL_VARIANT,
    'window_size': WINDOW_SIZE,
    'patch_size': PATCH_SIZE,
    'input_size': 384,  # pixels per side; the checkpoint name itself ends in "-224"
    'num_classes': 7,
    'dropout_rate': 0.2,
    'expected_hidden_dim': EXPECTED_HIDDEN_DIM,
    # NOTE(review): SwinCASME2Baseline.forward passes interpolate_pos_encoding=True
    # as a literal and does not read this key
    'interpolate_pos_encoding': True,

    # Training configuration (proven effective from medical imaging)
    'learning_rate': 1e-5,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,  # max gradient norm given to clip_grad_norm_ in train_epoch_swin
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'device': device,

    # Scheduler configuration (consumed by create_optimizer_scheduler_casme2)
    'scheduler_type': 'plateau',  # any other value makes the factory return scheduler=None
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 3,
    'scheduler_min_lr': 1e-6,
    # Only printed by the factory; the value actually passed to scheduler.step()
    # is chosen by the training loop
    'scheduler_monitor': 'val_f1_macro',

    # Optimized loss configuration
    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'class_weights': class_weights,  # tensor already selected for the active loss

    # Evaluation configuration
    # (not read in the visible part of the notebook - presumably used by later cells)
    'use_macro_avg': True,
    'early_stopping': False,
    'save_best_f1': True,
    'save_strategy': 'best_only'
}

# Echo the headline settings so they are recorded in the cell output
print(f"\nSwin Transformer Configuration Summary:")
print(f"  Model: {CASME2_SWIN_CONFIG['swin_model']}")
print(f"  Variant: {CASME2_SWIN_CONFIG['swin_variant']}")
print(f"  Input size: {CASME2_SWIN_CONFIG['input_size']}px")
print(f"  Window size: {CASME2_SWIN_CONFIG['window_size']}")
print(f"  Learning rate: {CASME2_SWIN_CONFIG['learning_rate']}")
print(f"  Batch size: {BATCH_SIZE}")

# =====================================================
# ADVANCED FOCAL LOSS IMPLEMENTATION - Per-Class Alpha Support (Preserved from ViT)
# =====================================================

class OptimizedFocalLoss(nn.Module):
    """
    Focal Loss with optional per-class alpha weighting.
    Paper: "Focal Loss for Dense Object Detection" (Lin et al., 2017)

    Formula: FL(p_t) = -α_t(1-p_t)^γ log(p_t)

    Args:
        alpha (list/tuple/ndarray/Tensor or None): Per-class alpha weights,
            indexed by class id. A warning is printed when they do not sum
            to 1.0 (tolerance 0.01). None disables class weighting.
        gamma (float): Focusing parameter for hard examples (default: 2.0)
        reduction (str): 'mean' or 'sum'; any other value (e.g. 'none')
            returns the unreduced per-sample loss
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()

        if alpha is not None:
            # Tensors are kept exactly as given; every other array-like
            # (list, tuple, numpy array) is converted to a float32 tensor.
            if not isinstance(alpha, torch.Tensor):
                alpha = torch.as_tensor(alpha, dtype=torch.float32)

            # Validation: alpha should sum to 1.0 for proper normalization
            alpha_sum = alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")

        # Registered as a non-persistent buffer: it follows the module through
        # .to()/.cuda() while the state_dict keys stay unchanged.
        self.register_buffer('alpha', alpha, persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits `inputs` against class ids `targets`."""
        # Per-sample cross entropy, i.e. -log(p_t)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Recover p_t (probability assigned to the true class)
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            # The criterion may have been built on CPU and never moved; follow
            # the targets' device lazily so the lookup below always works.
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        # α_t (1 - p_t)^γ * CE
        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# Swin Transformer Architecture for CASME II - Hierarchical Vision Transformer
class SwinCASME2Baseline(nn.Module):
    """Swin Transformer backbone followed by an MLP head for CASME II classification.

    The backbone checkpoint is taken from CASME2_SWIN_CONFIG['swin_model'] and
    is fully trainable. Its final-stage token features are mean-pooled over the
    sequence axis and passed through a 512 -> 128 -> num_classes head.

    Args:
        num_classes (int): Number of output logits
        dropout_rate (float): Dropout probability used in both hidden head layers
    """

    def __init__(self, num_classes, dropout_rate=0.2):
        super().__init__()

        # Imported here so transformers is only required when a model is built
        from transformers import SwinModel

        self.swin = SwinModel.from_pretrained(
            CASME2_SWIN_CONFIG['swin_model'],
            add_pooling_layer=False,
        )

        # Fine-tune the whole backbone
        for backbone_param in self.swin.parameters():
            backbone_param.requires_grad = True

        # The channel width doubles at every stage after the first, so the last
        # stage has embed_dim * 2^(num_stages - 1) channels
        backbone_cfg = self.swin.config
        base_embed_dim = backbone_cfg.embed_dim
        num_stages = len(backbone_cfg.depths)
        self.swin_feature_dim = base_embed_dim * 2 ** (num_stages - 1)

        print(f"Swin feature dimension (final stage): {self.swin_feature_dim}")
        print(f"Base embed_dim: {base_embed_dim}, Stages: {num_stages}")

        # Report (but tolerate) a mismatch with the configured expectation
        expected_dim = CASME2_SWIN_CONFIG['expected_hidden_dim']
        if self.swin_feature_dim != expected_dim:
            print(f"Warning: Expected {CASME2_SWIN_CONFIG['expected_hidden_dim']}, got {self.swin_feature_dim}")
            print(f"Note: Swin hierarchical structure: {base_embed_dim} * (2^{num_stages-1}) = {self.swin_feature_dim}")

        # Mean over the token axis
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Two hidden blocks of Linear -> LayerNorm -> GELU -> Dropout
        head_modules = []
        for in_width, out_width in ((self.swin_feature_dim, 512), (512, 128)):
            head_modules += [
                nn.Linear(in_width, out_width),
                nn.LayerNorm(out_width),
                nn.GELU(),
                nn.Dropout(dropout_rate),
            ]
        self.classifier_layers = nn.Sequential(*head_modules)

        # Output projection to the class logits
        self.classifier = nn.Linear(128, num_classes)

        print(f"Swin CASME II: {self.swin_feature_dim} -> GAP -> 512 -> 128 -> {num_classes}")

    def forward(self, pixel_values):
        """Map a batch of images to class logits of shape [batch_size, num_classes]."""
        backbone_out = self.swin(
            pixel_values=pixel_values,
            interpolate_pos_encoding=True,
        )

        # [batch_size, sequence_length, hidden_size]
        token_features = backbone_out.last_hidden_state

        # AdaptiveAvgPool1d pools the last axis, so move tokens there first:
        # [batch, hidden, seq] -> [batch, hidden, 1] -> [batch, hidden]
        channels_first = token_features.transpose(1, 2)
        pooled_features = self.global_pool(channels_first).squeeze(-1)

        hidden = self.classifier_layers(pooled_features)
        return self.classifier(hidden)

# Enhanced optimizer and scheduler factory
def create_optimizer_scheduler_casme2(model, config):
    """Build the AdamW optimizer and the optional plateau scheduler.

    Args:
        model: Module whose parameters are optimized
        config (dict): Must provide 'learning_rate', 'weight_decay' and
            'scheduler_type'; with scheduler_type == 'plateau' it must also
            provide 'scheduler_mode', 'scheduler_factor', 'scheduler_patience',
            'scheduler_min_lr' and 'scheduler_monitor'

    Returns:
        tuple: (optimizer, scheduler), where scheduler is None unless
        config['scheduler_type'] is 'plateau'
    """
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999),
    )

    # Every scheduler type other than 'plateau' means "no scheduler"
    if config['scheduler_type'] != 'plateau':
        return optimizer, None

    plateau_options = {
        'mode': config['scheduler_mode'],
        'factor': config['scheduler_factor'],
        'patience': config['scheduler_patience'],
        'min_lr': config['scheduler_min_lr'],
    }
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_options)
    print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")

    return optimizer, scheduler

# Build the Hugging Face image processor used by both transform functions
from transformers import AutoImageProcessor

print("\nSetting up Swin Transformer Image Processor for 384px input...")

# Overrides applied on top of the checkpoint's stored preprocessing config:
# resize straight to 384x384, rescale and normalize, no center crop
_processor_overrides = dict(
    do_resize=True,
    size={'height': 384, 'width': 384},
    do_normalize=True,
    do_rescale=True,
    do_center_crop=False,
)
swin_processor = AutoImageProcessor.from_pretrained(
    CASME2_SWIN_CONFIG['swin_model'],
    **_processor_overrides,
)

# Transform functions for Swin Transformer
def swin_transform_train(image):
    """Preprocess one training image with swin_processor.

    Returns the processor's 'pixel_values' tensor with the leading batch axis
    removed. No augmentation is applied here.
    """
    batch = swin_processor(image, return_tensors="pt")
    pixel_values = batch['pixel_values']
    return pixel_values.squeeze(0)

def swin_transform_val(image):
    """Preprocess one validation image with swin_processor.

    Returns the processor's 'pixel_values' tensor with the leading batch axis
    removed.
    """
    batch = swin_processor(image, return_tensors="pt")
    pixel_values = batch['pixel_values']
    return pixel_values.squeeze(0)

print(f"Swin Transformer Image Processor configured for 384px with hierarchical processing")

# Custom Dataset class for CASME II - preserved existing structure
class CASME2Dataset(Dataset):
    """CASME II dataset driven by the JSON split metadata.

    Images are read from disk on every access; nothing is cached.

    Args:
        split_metadata (dict): Metadata with a 'samples' list under each split name
        dataset_root (str): Directory containing one sub-directory per split
        transform (callable, optional): Applied to the loaded RGB PIL image
        split (str): Name of the split to serve (default: 'train')
    """

    def __init__(self, split_metadata, dataset_root, transform=None, split='train'):
        samples = split_metadata[split]['samples']

        self.metadata = samples
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split

        print(f"Loaded {len(samples)} samples for {split} split")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return (image, label index, sample id) for the sample at `idx`."""
        record = self.metadata[idx]

        # Images live at <dataset_root>/<split>/<image_filename>
        path = os.path.join(self.dataset_root, self.split, record['image_filename'])
        picture = Image.open(path).convert('RGB')

        if self.transform:
            picture = self.transform(picture)

        # Emotion name -> integer class id
        class_id = CLASS_TO_IDX[record['emotion']]
        return picture, class_id, record['sample_id']

# Make sure the checkpoint and result folders exist before training starts
for _output_dir in (
    CHECKPOINT_ROOT,
    f"{RESULTS_ROOT}/training_logs",
    f"{RESULTS_ROOT}/evaluation_results",
):
    os.makedirs(_output_dir, exist_ok=True)

# One image folder per split underneath the dataset root
TRAIN_PATH, VAL_PATH, TEST_PATH = (
    f"{DATASET_ROOT}/{_split_name}" for _split_name in ('train', 'val', 'test')
)

print(f"\nDataset paths:")
for _path_label, _path_value in (
    ("Train", TRAIN_PATH),
    ("Validation", VAL_PATH),
    ("Test", TEST_PATH),
):
    print(f"{_path_label}: {_path_value}")

# Architecture smoke test: build the model once, push a random image through it
# and report the resulting shapes. A failure is reported but does not stop the
# cell, and the temporary model/tensors are always released afterwards.
print("\nSwin CASME II architecture validation...")

# Pre-bind the names so the cleanup in `finally` works even when construction fails
test_model = test_input = test_output = None
try:
    input_size = CASME2_SWIN_CONFIG['input_size']
    test_model = SwinCASME2Baseline(
        num_classes=CASME2_SWIN_CONFIG['num_classes'],
        dropout_rate=CASME2_SWIN_CONFIG['dropout_rate'],
    ).to(device)
    test_input = torch.randn(1, 3, input_size, input_size).to(device)

    # Only the output shape matters here, so skip building the autograd graph
    with torch.no_grad():
        test_output = test_model(test_input)

    # Token bookkeeping: the patch embedding yields (input_size / patch_size)^2
    # tokens, and each of the (num_stages - 1) patch-merging steps halves both
    # spatial sides, i.e. divides the token count by 4.
    stage1_side = input_size // CASME2_SWIN_CONFIG['patch_size']
    stage1_patches = stage1_side ** 2
    num_merges = len(test_model.swin.config.depths) - 1
    expected_sequence_length = (stage1_side // (2 ** num_merges)) ** 2

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"Initial patches (Stage 1): {stage1_patches}")
    print(f"Final sequence length (after stages): {expected_sequence_length}")
    print(f"Window size: {CASME2_SWIN_CONFIG['window_size']}x{CASME2_SWIN_CONFIG['window_size']}")

except Exception as e:
    print(f"Validation failed: {e}")

finally:
    # Release the throw-away model on success and on failure alike
    del test_model, test_input, test_output
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Optimized loss function factory with advanced configuration (preserved from ViT)
def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    """
    Create the loss criterion selected by the experiment configuration.

    Args:
        weights (Tensor or None): Class weights for CrossEntropy (ignored if
            focal loss is used); None gives an unweighted CrossEntropy loss
        use_focal_loss (bool): Whether to use Focal Loss or CrossEntropy
        alpha_weights (list/Tensor or None): Per-class alpha weights for Focal
            Loss (expected to sum to 1.0)
        gamma (float): Focal loss gamma parameter

    Returns:
        Loss function (nn.Module)
    """
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        # Explicit None/length test: a multi-element tensor or ndarray has no
        # unambiguous truth value, so a plain `if alpha_weights:` would raise.
        if alpha_weights is not None and len(alpha_weights) > 0:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {float(sum(alpha_weights)):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)

    print(f"Using CrossEntropy Loss with optimized class weights")
    if weights is None:
        # nn.CrossEntropyLoss accepts weight=None; only the report needs a guard
        print("Class weights: None (unweighted)")
    else:
        print(f"Class weights: {weights.detach().cpu().numpy()}")
    return nn.CrossEntropyLoss(weight=weights)

# Global configuration for training pipeline - enhanced for Swin Transformer
# Bundles the objects created in this cell (device, class weights, transforms,
# paths, metadata and the two factory functions) so that later cells can fetch
# them from one dictionary instead of relying on individual globals.
# NOTE: several values are live Python objects (torch.device, tensor, functions),
# so this dictionary is not JSON-serialisable as is.
GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,  # tensor selected for the active loss function
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    # Both transforms currently run the same swin_processor pipeline
    'transform_train': swin_transform_train,
    'transform_val': swin_transform_val,
    'swin_config': CASME2_SWIN_CONFIG,  # nested: the full model/training configuration
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'train_path': TRAIN_PATH,
    'val_path': VAL_PATH,
    'test_path': TEST_PATH,
    'metadata': casme2_metadata,
    # Factories are stored uncalled: (model, config) -> (optimizer, scheduler)
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    # (weights, use_focal_loss, alpha_weights, gamma) -> loss module
    'criterion_factory': create_criterion_casme2
}

# Closing report of Cell 1: loss, model and dataset settings in one listing
_summary_rule = "=" * 60

if USE_FOCAL_LOSS:
    _loss_summary = [
        f"  Function: Optimized Focal Loss",
        f"  Gamma: {FOCAL_LOSS_GAMMA}",
        f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}",
        f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}",
    ]
    _weight_scheme = 'Per-class Alpha'
else:
    _loss_summary = [
        f"  Function: CrossEntropy with Optimized Weights",
        f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}",
    ]
    _weight_scheme = 'Inverse Sqrt Frequency'

_summary_lines = [
    "\n" + _summary_rule,
    "CASME II SWIN TRANSFORMER OPTIMIZED BASELINE CONFIGURATION COMPLETE",
    _summary_rule,
    f"Loss Configuration:",
    *_loss_summary,
    f"\nModel Configuration:",
    f"  Architecture: {SWIN_MODEL_NAME}",
    f"  Variant: {SWIN_MODEL_VARIANT}",
    f"  Window Size: {WINDOW_SIZE}x{WINDOW_SIZE}",
    f"  Patch Size: {PATCH_SIZE}x{PATCH_SIZE}",
    f"  Input Resolution: 384px",
    f"  Hidden Dimension: {EXPECTED_HIDDEN_DIM}",
    f"\nDataset Configuration:",
    f"  Classes: {len(CASME2_CLASSES)}",
    f"  Weight Optimization: {_weight_scheme}",
    f"  Hierarchical Processing: Multi-stage feature extraction",
    "\nNext: Cell 2 - Dataset Loading and Swin Transformer Training Pipeline",
]
for _summary_line in _summary_lines:
    print(_summary_line)

CASME II SWIN TRANSFORMER OPTIMIZED BASELINE INFRASTRUCTURE

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II Swin Transformer Optimized Baseline - Infrastructure Configuration
Loading CASME II dataset metadata...
Dataset: CASME2
Total samples: 255
Split strategy: stratified_80_10_10
Using Swin-Base for advanced hierarchical micro-expression recognition

OPTIMIZED EXPERIMENT CONFIGURATION
Loss Function: Focal Loss
  Gamma: 2.0
  Alpha Weights (per-class): [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum Validation: 0.999
Swin Model: microsoft/swin-base-patch4-window7-224
Window Size: 7x7
Patch Size: 4x4
Expected Hidden Dim: 1024

Device: cuda
GPU: NVIDIA L4 (23.8 GB)
L4: Balanced performance configuration for Swin Transformer

Train distribution: {'others': 79, 'disgust': 50, 'happiness': 25, 'repression': 21, 'surprise': 20, 'sadness': 5, 'fear': 1}
Validation distribution: {'others': 10

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Swin Transformer Image Processor configured for 384px with hierarchical processing

Dataset paths:
Train: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split/train
Validation: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split/val
Test: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/data_split/test

Swin CASME II architecture validation...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/352M [00:00<?, ?B/s]

Swin feature dimension (final stage): 1024
Base embed_dim: 128, Stages: 4
Swin CASME II: 1024 -> GAP -> 512 -> 128 -> 7
Validation successful: Output shape torch.Size([1, 7])
Initial patches (Stage 1): 9216
Final sequence length (after stages): Variable due to hierarchical downsampling
Window size: 7x7

CASME II SWIN TRANSFORMER OPTIMIZED BASELINE CONFIGURATION COMPLETE
Loss Configuration:
  Function: Optimized Focal Loss
  Gamma: 2.0
  Per-class Alpha: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum: 0.999

Model Configuration:
  Architecture: microsoft/swin-base-patch4-window7-224
  Variant: base
  Window Size: 7x7
  Patch Size: 4x4
  Input Resolution: 384px
  Hidden Dimension: 1024

Dataset Configuration:
  Classes: 7
  Weight Optimization: Per-class Alpha
  Hierarchical Processing: Multi-stage feature extraction

Next: Cell 2 - Dataset Loading and Swin Transformer Training Pipeline


In [None]:
# @title Cell 2: CASME II Swin Transformer Training Pipeline

# File: 02_02_SwinT_Direct_Enhanced_Baseline_Cell2.py
# Location: experiments/02_02_SwinT_Direct_Baseline.ipynb
# Purpose: Training pipeline for CASME II Swin Transformer micro-expression recognition

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp

# Announce the training run using the values stored in CASME2_SWIN_CONFIG
_cfg = CASME2_SWIN_CONFIG

if _cfg['use_focal_loss']:
    _loss_banner = [
        "Loss Function: Optimized Focal Loss",
        f"Focal Loss Parameters:",
        f"  Gamma: {_cfg['focal_loss_gamma']}",
        f"  Per-class Alpha: {_cfg['focal_loss_alpha_weights']}",
        f"  Alpha Sum: {sum(_cfg['focal_loss_alpha_weights']):.3f}",
    ]
else:
    _loss_banner = [
        "Loss Function: CrossEntropy Loss",
        f"CrossEntropy Parameters:",
        f"  Optimized Class Weights: {_cfg['crossentropy_class_weights']}",
    ]

_training_banner = [
    "CASME II Swin Transformer Enhanced Training Pipeline with Fixed Checkpoints",
    "=" * 70,
    *_loss_banner,
    f"Swin Architecture: {_cfg['swin_variant']} variant",
    f"Window mechanism: {_cfg['window_size']}x{_cfg['window_size']}",
    f"Training epochs: {_cfg['num_epochs']}",
    f"Scheduler patience: {_cfg['scheduler_patience']}",
]
for _banner_row in _training_banner:
    print(_banner_row)
# Enhanced CASME II Dataset with optimized RAM caching for Swin Transformer
class CASME2DatasetTrainingSwin(Dataset):
    """CASME II dataset for Swin Transformer training with optional RAM caching.

    Each item is ``(image, label, sample_id)``. Images are opened as RGB and
    resized to IMAGE_SIZE when they have a different size. An image that
    cannot be loaded is reported and replaced by a uniform grey placeholder,
    so one broken file does not abort an epoch.

    Args:
        split_metadata (dict): Metadata with a 'samples' list under each split name
        dataset_root (str): Directory containing one sub-directory per split
        transform (callable, optional): Applied to the PIL image on every access
        split (str): Name of the split to serve (default: 'train')
        use_ram_cache (bool): Preload every image into memory at construction
    """

    # (width, height) every image is brought to before the transform runs
    IMAGE_SIZE = (384, 384)
    # RGB value of the placeholder substituted for unreadable images
    PLACEHOLDER_COLOR = (128, 128, 128)

    def __init__(self, split_metadata, dataset_root, transform=None, split='train', use_ram_cache=True):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for Swin Transformer training...")

        # Parallel lists: image path, integer label and sample id per sample
        self.images = [
            os.path.join(dataset_root, split, sample['image_filename'])
            for sample in self.metadata
        ]
        self.labels = [CLASS_TO_IDX[sample['emotion']] for sample in self.metadata]
        self.sample_ids = [sample['sample_id'] for sample in self.metadata]

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        self._print_distribution()

        # RAM caching for training efficiency
        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        """Print the number and share of samples per class, ordered by class id."""
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        total = len(self.labels)
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / total) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def _load_image(self, img_path):
        """Open `img_path` as RGB and resize it to IMAGE_SIZE if it differs."""
        image = Image.open(img_path).convert('RGB')
        if image.size != self.IMAGE_SIZE:
            image = image.resize(self.IMAGE_SIZE, Image.Resampling.LANCZOS)
        return image

    def _placeholder_image(self):
        """Return the grey stand-in used when an image cannot be loaded."""
        return Image.new('RGB', self.IMAGE_SIZE, self.PLACEHOLDER_COLOR)

    def _preload_to_ram(self):
        """Load every image of the split into `self.cached_images`."""
        print(f"Preloading {len(self.images)} {self.split} images to RAM for Swin Transformer...")

        valid_images = 0
        self.cached_images = [None] * len(self.images)

        for i, img_path in enumerate(tqdm(self.images, desc=f"Loading {self.split}")):
            try:
                self.cached_images[i] = self._load_image(img_path)
                valid_images += 1
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                self.cached_images[i] = self._placeholder_image()

        # NOTE(review): the factor 4 budgets four bytes per channel value while
        # the cache holds PIL RGB images, so this is probably an upper bound.
        width, height = self.IMAGE_SIZE
        ram_usage_gb = len(self.cached_images) * width * height * 3 * 4 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (image, label index, sample id) for the sample at `idx`."""
        cached = self.cached_images[idx] if self.use_ram_cache else None

        if cached is not None:
            # Hand out a copy so transforms never touch the cached original
            image = cached.copy()
        else:
            img_path = self.images[idx]
            try:
                image = self._load_image(img_path)
            except Exception as e:
                # Best effort, like the preloader: report the failure and use a
                # placeholder. Unlike a bare `except:` this lets
                # KeyboardInterrupt / SystemExit propagate.
                print(f"Error loading {img_path}: {e}")
                image = self._placeholder_image()

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.sample_ids[idx]

# Enhanced metrics calculation with comprehensive error handling
def calculate_metrics_safe_robust(outputs, labels, class_names, average='macro'):
    """Compute accuracy, precision, recall and F1 without ever raising.

    Args:
        outputs: Either a tensor of per-class scores, reduced to predictions
            with argmax over dim 1, or an array-like of already predicted
            class indices
        labels: Tensor or array-like of true class indices
        class_names: Sequence of class names; its length fixes the label set
            0..len(class_names)-1 used for averaging
        average (str): Averaging mode forwarded to
            precision_recall_fscore_support (default: 'macro')

    Returns:
        dict: 'accuracy', 'precision', 'recall' and 'f1_score' as floats. Any
        error (including a length mismatch) is reported with a warning and
        yields all-zero metrics.
    """
    try:
        # Convert first, validate afterwards: only tensors have `.size(0)`, so
        # checking sizes up front made list / ndarray inputs fail every time.
        if isinstance(outputs, torch.Tensor):
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        else:
            predictions = np.asarray(outputs)

        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()
        else:
            labels = np.asarray(labels)

        if predictions.shape[0] != labels.shape[0]:
            raise ValueError(f"Batch size mismatch: outputs {predictions.shape[0]} vs labels {labels.shape[0]}")

        accuracy = accuracy_score(labels, predictions)
        # Passing the full label list keeps classes absent from this batch in
        # the average; zero_division=0 scores them as 0 instead of warning.
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0,
            labels=list(range(len(class_names)))
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }
    except Exception as e:
        # Deliberate best effort: a metrics failure must not abort training
        print(f"Warning: Enhanced metrics calculation error: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

# Enhanced training epoch function with Swin Transformer hierarchical output validation
def train_epoch_swin(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to ``device``, pushed through ``model``, checked to be
    a ``[batch_size, 7]`` score matrix, back-propagated with gradient-norm
    clipping (limit read from ``CASME2_SWIN_CONFIG['gradient_clip']``) and
    applied with ``optimizer``.

    Returns:
        (avg_loss, metrics): the summed batch losses divided by the number of
        batches, and the dict from ``calculate_metrics_safe_robust`` computed
        over the whole epoch (all zeros if that computation fails).

    Raises:
        ValueError: if the model output is not 2-D with 7 columns.
    """
    model.train()
    loss_sum = 0.0
    score_chunks = []
    target_chunks = []

    bar = tqdm(dataloader, desc=f"CASME II Swin Training Epoch {epoch+1}/{total_epochs}")

    for step, (images, labels, _) in enumerate(bar):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        raw = model(images)

        # Unwrap whatever container the model returned down to the score tensor.
        if isinstance(raw, (tuple, list)):
            outputs = raw[0]
        elif isinstance(raw, dict):
            if 'logits' in raw:
                outputs = raw['logits']
            elif 'prediction' in raw:
                outputs = raw['prediction']
            else:
                outputs = raw
        else:
            outputs = raw

        # Guard against a head that does not emit one score per CASME II class.
        if not (outputs.dim() == 2 and outputs.size(1) == 7):
            raise ValueError(f"Invalid CASME II Swin output shape: {outputs.shape}, expected [batch_size, 7]")

        loss = criterion(outputs, labels)
        loss.backward()

        # Clip before stepping so one bad batch cannot blow up the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), CASME2_SWIN_CONFIG['gradient_clip'])

        optimizer.step()
        loss_sum += loss.item()

        # Keep only CPU copies so the epoch history does not pin GPU memory.
        score_chunks.append(outputs.detach().cpu())
        target_chunks.append(labels.detach().cpu())

        # Refresh the progress bar on every fifth batch.
        if not step % 5:
            bar.set_postfix({
                'Loss': f'{loss_sum / (step + 1):.4f}',
                'LR': f"{optimizer.param_groups[0]['lr']:.2e}",
                'Swin': CASME2_SWIN_CONFIG['swin_variant']
            })

    # Epoch-level metrics; a failure here must not discard the trained epoch.
    try:
        metrics = calculate_metrics_safe_robust(
            torch.cat(score_chunks, dim=0),
            torch.cat(target_chunks, dim=0),
            CASME2_CLASSES,
            average='macro'
        )
    except Exception as e:
        print(f"Warning: Swin training metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    return loss_sum / len(dataloader), metrics

# Enhanced validation epoch function with Swin Transformer hierarchical output validation
def validate_epoch_swin(model, dataloader, criterion, device, epoch, total_epochs):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Returns:
        (avg_loss, metrics, sample_ids): the summed batch losses divided by
        the number of batches, the dict from ``calculate_metrics_safe_robust``
        over all batches (all zeros if that computation fails), and the
        sample ids in the order the batches were seen.
    """
    model.eval()
    loss_sum = 0.0
    score_chunks = []
    target_chunks = []
    seen_ids = []

    with torch.no_grad():
        bar = tqdm(dataloader, desc=f"CASME II Swin Validation Epoch {epoch+1}/{total_epochs}")

        for step, (images, labels, sample_ids) in enumerate(bar):
            images = images.to(device)
            labels = labels.to(device)

            raw = model(images)

            # Unwrap whatever container the model returned down to the scores.
            if isinstance(raw, (tuple, list)):
                outputs = raw[0]
            elif isinstance(raw, dict):
                if 'logits' in raw:
                    outputs = raw['logits']
                elif 'prediction' in raw:
                    outputs = raw['prediction']
                else:
                    outputs = raw
            else:
                outputs = raw

            loss_sum += criterion(outputs, labels).item()

            # Keep only CPU copies so the history does not pin GPU memory.
            score_chunks.append(outputs.detach().cpu())
            target_chunks.append(labels.detach().cpu())
            seen_ids.extend(sample_ids)

            # Refresh the progress bar on every third batch.
            if not step % 3:
                bar.set_postfix({
                    'Val Loss': f'{loss_sum / (step + 1):.4f}',
                    'Swin': CASME2_SWIN_CONFIG['swin_variant']
                })

    # Epoch-level metrics; a failure here falls back to zeros.
    try:
        metrics = calculate_metrics_safe_robust(
            torch.cat(score_chunks, dim=0),
            torch.cat(target_chunks, dim=0),
            CASME2_CLASSES,
            average='macro'
        )
    except Exception as e:
        print(f"Warning: Swin validation metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    return loss_sum / len(dataloader), metrics, seen_ids

# FIXED: Enhanced checkpoint saving function with complete device migration
def _checkpoint_state_to_cpu(obj):
    """Return a copy of ``obj`` in which every tensor is a detached CPU clone.

    Recurses through dicts, lists and tuples at any depth; every other value
    is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().clone()
    if isinstance(obj, dict):
        return {key: _checkpoint_state_to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_checkpoint_state_to_cpu(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_checkpoint_state_to_cpu(item) for item in obj)
    return obj


def _checkpoint_plain_value(obj):
    """Convert metrics/config values to plain Python types for the checkpoint.

    Tensors and NumPy values become Python scalars or lists, containers are
    converted recursively, None/bool/int/float/str pass through with their
    type intact, and anything else is stringified.
    """
    if isinstance(obj, torch.Tensor):
        detached = obj.detach().cpu()
        return detached.item() if detached.numel() == 1 else detached.tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # NumPy scalars are tested before the Python primitives because
    # np.float64 is also a Python float and would otherwise slip through
    # without being converted.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {key: _checkpoint_plain_value(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_checkpoint_plain_value(item) for item in obj]
    return str(obj)


def save_checkpoint_robust_fixed(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                                checkpoint_dir, best_metrics, config, max_retries=3):
    """Write the best-model checkpoint to ``checkpoint_dir`` with retries.

    All model, optimizer and scheduler state is copied to the CPU first, the
    metrics and config are reduced to plain Python values, and the file is
    written to ``<name>.tmp`` before being moved over the final name so a
    failed write does not leave a truncated checkpoint under that name.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: Scheduler whose ``state_dict()`` is saved; a falsy value
            stores ``None`` instead.
        epoch: Epoch index stored under ``'epoch'``.
        train_metrics: Training metrics for this epoch.
        val_metrics: Validation metrics for this epoch.
        checkpoint_dir: Directory receiving ``casme2_swint_direct_best_f1.pth``.
        best_metrics: Mapping with ``'f1'``, ``'loss'`` and ``'accuracy'``.
        config: Mapping stored in full; ``'swin_variant'`` and
            ``'swin_model'`` are also stored as top-level keys.
        max_retries: Number of write attempts before giving up.

    Returns:
        The checkpoint path on success, otherwise ``None`` (device migration
        failed, or every write attempt failed).
    """
    import shutil

    print(f"Saving checkpoint with enhanced device migration...")

    # CRITICAL FIX: Complete device migration to CPU before serialization
    try:
        print("  Migrating model state dict to CPU...")
        model_state_cpu = _checkpoint_state_to_cpu(model.state_dict())

        print("  Migrating optimizer state dict to CPU...")
        optimizer_state_cpu = _checkpoint_state_to_cpu(optimizer.state_dict())

        scheduler_state_cpu = None
        if scheduler:
            print("  Migrating scheduler state dict to CPU...")
            scheduler_state_cpu = _checkpoint_state_to_cpu(scheduler.state_dict())

        # Release cached GPU blocks before the (memory-hungry) serialization.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        print("  Device migration completed successfully")

    except Exception as e:
        print(f"ERROR: Device migration failed: {e}")
        return None

    # Create checkpoint with CPU-migrated state dicts
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state_cpu,
        'optimizer_state_dict': optimizer_state_cpu,
        'scheduler_state_dict': scheduler_state_cpu,
        'train_metrics': _checkpoint_plain_value(train_metrics),
        'val_metrics': _checkpoint_plain_value(val_metrics),
        'casme2_swin_config': _checkpoint_plain_value(config),
        'best_f1': float(best_metrics['f1']),
        'best_loss': float(best_metrics['loss']),
        'best_acc': float(best_metrics['accuracy']),
        'class_names': CASME2_CLASSES,
        'num_classes': 7,
        'swin_variant': config['swin_variant'],
        'swin_model': config['swin_model']
    }

    best_path = f"{checkpoint_dir}/casme2_swint_direct_best_f1.pth"
    temp_path = f"{best_path}.tmp"

    for attempt in range(max_retries):
        try:
            # Inside the retry so a transient failure creating the directory
            # is retried like any other write error.
            os.makedirs(checkpoint_dir, exist_ok=True)

            print(f"  Attempt {attempt + 1}: Saving to temporary file...")
            torch.save(checkpoint, temp_path)

            # Both paths share a directory, so this is normally a plain
            # rename; shutil.move falls back to copy + delete if it is not.
            shutil.move(temp_path, best_path)

            print(f"Checkpoint saved successfully: {os.path.basename(best_path)}")
            print(f"  Swin variant: {config['swin_variant']}")
            print(f"  Model: {config['swin_model']}")
            return best_path

        except Exception as e:
            print(f"Checkpoint save attempt {attempt + 1} failed: {e}")

            # Best-effort removal of a partially written temporary file.
            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    pass

            if attempt < max_retries - 1:
                print(f"  Retrying in 2 seconds...")
                time.sleep(2)  # Brief pause before retry

                # Additional memory cleanup before retry
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
            else:
                print(f"All {max_retries} checkpoint save attempts failed")
                return None

    # Only reached when max_retries is zero or negative.
    return None

# Safe JSON serialization function - enhanced for Swin Transformer
def safe_json_serialize_swin(obj):
    """Convert ``obj`` into a structure that ``json.dump`` can write.

    Conversions, applied recursively:
      * torch.Tensor -> Python scalar (one element) or nested list
      * np.ndarray   -> nested list
      * NumPy bool / integer / floating scalars -> bool / int / float
      * None, bool, int, float, str -> returned unchanged, so integers stay
        integers and None becomes JSON ``null`` rather than the text 'None'
      * dict -> dict with converted values (keys are left as they are)
      * list / tuple -> list
      * objects exposing ``__dict__`` -> their attribute dict, converted
      * anything else -> ``str(obj)``
    """
    if isinstance(obj, torch.Tensor):
        # detach() first so tensors that require grad can be exported too.
        detached = obj.detach().cpu()
        return detached.item() if detached.numel() == 1 else detached.tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # NumPy scalars are tested before the Python primitives because
    # np.float64 is also a Python float and would otherwise be returned
    # without conversion.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {k: safe_json_serialize_swin(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [safe_json_serialize_swin(item) for item in obj]
    if hasattr(obj, '__dict__'):
        return safe_json_serialize_swin(obj.__dict__)
    return str(obj)

# Create enhanced datasets for Swin Transformer
print("\nCreating CASME II Swin Transformer training datasets...")

# Both splits are built from the same metadata; the dataset root is derived by
# stripping the split suffix from the configured split path.
train_dataset = CASME2DatasetTrainingSwin(
    split_metadata=GLOBAL_CONFIG_CASME2['metadata'],
    dataset_root=GLOBAL_CONFIG_CASME2['train_path'].replace('/train', ''),
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    split='train',
    use_ram_cache=True
)

val_dataset = CASME2DatasetTrainingSwin(
    split_metadata=GLOBAL_CONFIG_CASME2['metadata'],
    dataset_root=GLOBAL_CONFIG_CASME2['val_path'].replace('/val', ''),
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    split='val',
    use_ram_cache=True
)

# Create data loaders with Swin-optimized settings.
# DataLoader rejects an explicit prefetch_factor when num_workers == 0, so the
# option is only passed when worker processes are actually requested.
_loader_kwargs = {
    'batch_size': CASME2_SWIN_CONFIG['batch_size'],
    'num_workers': CASME2_SWIN_CONFIG['num_workers'],
    'pin_memory': True,
}
if CASME2_SWIN_CONFIG['num_workers'] > 0:
    _loader_kwargs['prefetch_factor'] = 2

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Training batches: {len(train_loader)} (samples: {len(train_dataset)})")
print(f"Validation batches: {len(val_loader)} (samples: {len(val_dataset)})")

# Build the model, loss, optimizer and scheduler for this run
print("\nInitializing CASME II Swin Transformer enhanced model...")
model = SwinCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_SWIN_CONFIG['dropout_rate']
)
model = model.to(GLOBAL_CONFIG_CASME2['device'])

# One factory call covers both loss variants: the focal-loss settings are only
# read from the config when focal loss is enabled, otherwise the factory gets
# alpha_weights=None and gamma=2.0.
criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
    weights=GLOBAL_CONFIG_CASME2['class_weights'],
    use_focal_loss=bool(CASME2_SWIN_CONFIG['use_focal_loss']),
    alpha_weights=(
        CASME2_SWIN_CONFIG['focal_loss_alpha_weights']
        if CASME2_SWIN_CONFIG['use_focal_loss'] else None
    ),
    gamma=(
        CASME2_SWIN_CONFIG['focal_loss_gamma']
        if CASME2_SWIN_CONFIG['use_focal_loss'] else 2.0
    )
)

optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_SWIN_CONFIG
)

print(f"Optimizer: AdamW (LR={CASME2_SWIN_CONFIG['learning_rate']})")
print(f"Scheduler: ReduceLROnPlateau (patience={CASME2_SWIN_CONFIG['scheduler_patience']})")
print(f"Criterion: {'Optimized Focal Loss' if CASME2_SWIN_CONFIG['use_focal_loss'] else 'CrossEntropy'}")
print(f"Swin Architecture: {CASME2_SWIN_CONFIG['swin_variant']} variant")

# Per-epoch curves; one value is appended to every list after each epoch
training_history = {
    series: []
    for series in (
        'train_loss', 'val_loss',
        'train_f1', 'val_f1',
        'train_acc', 'val_acc',
        'learning_rate', 'epoch_time',
    )
}

# Best validation results seen so far, used to decide when to checkpoint
best_metrics = dict(f1=0.0, loss=float('inf'), accuracy=0.0, epoch=0)

print("\nStarting CASME II Swin Transformer enhanced training with fixed checkpoints...")
print(f"Training configuration: {CASME2_SWIN_CONFIG['num_epochs']} epochs")
print(f"Hierarchical processing: {CASME2_SWIN_CONFIG['window_size']}x{CASME2_SWIN_CONFIG['window_size']} windows")
print("=" * 70)

# Main training loop with enhanced checkpoint reliability for Swin Transformer
start_time = time.time()

for epoch in range(CASME2_SWIN_CONFIG['num_epochs']):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{CASME2_SWIN_CONFIG['num_epochs']}")

    # Training phase with Swin-specific adaptations
    train_loss, train_metrics = train_epoch_swin(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_SWIN_CONFIG['num_epochs']
    )

    # Validation phase with Swin-specific adaptations
    val_loss, val_metrics, val_sample_ids = validate_epoch_swin(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_SWIN_CONFIG['num_epochs']
    )

    # The scheduler is driven by validation macro F1, not by the loss.
    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    # Record training history. The learning rate is read after the scheduler
    # step, so it is the rate the next epoch will use.
    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    # Print epoch summary
    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    # Multi-criteria checkpoint decision: F1 first, then lower loss on an F1
    # tie, then higher accuracy when both F1 and loss tie.
    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        # Stage the new best values and commit them only once the file is on
        # disk. Otherwise a failed save would leave best_metrics describing a
        # model that was never written, and later epochs that beat the saved
        # checkpoint (but not the lost one) would be skipped.
        candidate_metrics = {
            'f1': val_metrics['f1_score'],
            'loss': val_loss,
            'accuracy': val_metrics['accuracy'],
            'epoch': epoch + 1
        }

        # Use FIXED checkpoint saving function
        best_model_path = save_checkpoint_robust_fixed(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            candidate_metrics, CASME2_SWIN_CONFIG
        )

        if best_model_path:
            best_metrics.update(candidate_metrics)
            print(f"New best Swin model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print(f"Warning: Failed to save Swin checkpoint despite improvement")

    # Progress tracking: ETA extrapolates the mean epoch duration so far.
    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_SWIN_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_SWIN_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min | Swin-{CASME2_SWIN_CONFIG['swin_variant']}")

# Training completion: report wall-clock time and the headline numbers.
total_time = time.time() - start_time
# The loop above has no early exit, so the configured epoch count is also the
# number of epochs that actually ran.
actual_epochs = CASME2_SWIN_CONFIG['num_epochs']

print("\n" + "=" * 70)
print("CASME II SWIN TRANSFORMER ENHANCED BASELINE TRAINING COMPLETED")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
# NOTE(review): the [-1] lookups raise IndexError if num_epochs is 0, and
# these two prints sit outside the try block below.
print(f"Final train F1: {training_history['train_f1'][-1]:.4f}")
print(f"Final validation F1: {training_history['val_f1'][-1]:.4f}")
print(f"Swin Architecture: {CASME2_SWIN_CONFIG['swin_variant']} variant")

# Enhanced training documentation export for Swin Transformer
results_dir = GLOBAL_CONFIG_CASME2['results_root']
os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

training_history_path = f"{results_dir}/training_logs/casme2_swint_direct_training_history.json"

print("\nExporting enhanced Swin Transformer training documentation...")

# The export is best-effort: any failure while building or writing the summary
# is reported as a warning and does not affect the trained checkpoint.
try:
    # Create comprehensive training summary with Swin-specific configuration
    training_summary = {
        'experiment_type': 'CASME2_SwinT_Enhanced_Baseline',
        # Values in this sub-dict are taken from the config as-is (not passed
        # through safe_json_serialize_swin), so they must already be JSON-safe.
        'experiment_configuration': {
            'loss_function': 'Optimized Focal Loss' if CASME2_SWIN_CONFIG['use_focal_loss'] else 'CrossEntropy',
            'weight_approach': 'Per-class Alpha (sum=1.0)' if CASME2_SWIN_CONFIG['use_focal_loss'] else 'Inverse Sqrt Frequency',
            'focal_loss_gamma': CASME2_SWIN_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_SWIN_CONFIG['focal_loss_alpha_weights'],
            'crossentropy_class_weights': CASME2_SWIN_CONFIG['crossentropy_class_weights'],
            'swin_model': CASME2_SWIN_CONFIG['swin_model'],
            'swin_variant': CASME2_SWIN_CONFIG['swin_variant'],
            'window_size': CASME2_SWIN_CONFIG['window_size'],
            'patch_size': CASME2_SWIN_CONFIG['patch_size']
        },
        'training_history': safe_json_serialize_swin(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'config': safe_json_serialize_swin(CASME2_SWIN_CONFIG),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        # File name only; the checkpoint itself is written by
        # save_checkpoint_robust_fixed under the checkpoint root.
        'model_checkpoint': 'casme2_swint_direct_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            # NOTE(review): hard-coded; presumably the size of the full
            # dataset across all splits - confirm against the split metadata.
            'total_samples': 255,
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES
        },
        'architecture_info': {
            'model_type': 'SwinCASME2Baseline',
            'backbone': CASME2_SWIN_CONFIG['swin_model'],
            'variant': CASME2_SWIN_CONFIG['swin_variant'],
            'input_size': f"{CASME2_SWIN_CONFIG['input_size']}x{CASME2_SWIN_CONFIG['input_size']}",
            'window_size': f"{CASME2_SWIN_CONFIG['window_size']}x{CASME2_SWIN_CONFIG['window_size']}",
            'patch_size': f"{CASME2_SWIN_CONFIG['patch_size']}x{CASME2_SWIN_CONFIG['patch_size']}",
            'hidden_dim': CASME2_SWIN_CONFIG['expected_hidden_dim'],
            'classification_head': f'{CASME2_SWIN_CONFIG["expected_hidden_dim"]}->GAP->512->128->7'
        },
        # Static descriptive flags; none of them is computed from the run.
        'enhanced_features': {
            'hierarchical_processing': True,
            'shifted_window_attention': True,
            'fixed_checkpoint_saving': True,
            'device_migration_complete': True,
            'robust_error_handling': True,
            'multi_criteria_checkpoint_logic': True,
            'memory_optimized_training': True,
            'retry_checkpoint_logic': True
        }
    }

    # Save with proper JSON serialization
    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Enhanced Swin Transformer training documentation saved: {training_history_path}")
    print(f"Experiment details: {training_summary['experiment_configuration']['loss_function']} loss")
    if CASME2_SWIN_CONFIG['use_focal_loss']:
        print(f"  Gamma: {CASME2_SWIN_CONFIG['focal_loss_gamma']}, Alpha Sum: {sum(CASME2_SWIN_CONFIG['focal_loss_alpha_weights']):.3f}")
    print(f"Model variant: {CASME2_SWIN_CONFIG['swin_model']} ({CASME2_SWIN_CONFIG['swin_variant']})")
    print(f"Checkpoint saving: FIXED with complete device migration")

except Exception as e:
    print(f"Warning: Could not save Swin training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

# Enhanced memory cleanup for Swin Transformer: wait for pending GPU work,
# then release cached allocator blocks.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\nNext: Cell 3 - CASME II Swin Transformer Enhanced Evaluation")
print("Enhanced Swin Transformer training pipeline completed successfully!")

CASME II Swin Transformer Enhanced Training Pipeline with Fixed Checkpoints
Loss Function: Optimized Focal Loss
Focal Loss Parameters:
  Gamma: 2.0
  Per-class Alpha: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum: 0.999
Swin Architecture: base variant
Window mechanism: 7x7
Training epochs: 50
Scheduler patience: 3

Creating CASME II Swin Transformer training datasets...
Loading CASME II train dataset for Swin Transformer training...
Loaded 201 CASME II train samples
  others: 79 samples (39.3%)
  disgust: 50 samples (24.9%)
  happiness: 25 samples (12.4%)
  repression: 21 samples (10.4%)
  surprise: 20 samples (10.0%)
  sadness: 5 samples (2.5%)
  fear: 1 samples (0.5%)
Preloading 201 train images to RAM for Swin Transformer...


Loading train: 100%|██████████| 201/201 [01:48<00:00,  1.85it/s]


TRAIN RAM caching completed: 201 images, ~0.36GB
Loading CASME II val dataset for Swin Transformer training...
Loaded 26 CASME II val samples
  others: 10 samples (38.5%)
  disgust: 6 samples (23.1%)
  happiness: 3 samples (11.5%)
  repression: 3 samples (11.5%)
  surprise: 2 samples (7.7%)
  sadness: 1 samples (3.8%)
  fear: 1 samples (3.8%)
Preloading 26 val images to RAM for Swin Transformer...


Loading val: 100%|██████████| 26/26 [00:10<00:00,  2.45it/s]


VAL RAM caching completed: 26 images, ~0.05GB
Training batches: 15 (samples: 201)
Validation batches: 2 (samples: 26)

Initializing CASME II Swin Transformer enhanced model...
Swin feature dimension (final stage): 1024
Base embed_dim: 128, Stages: 4
Swin CASME II: 1024 -> GAP -> 512 -> 128 -> 7
Using Optimized Focal Loss with gamma=2.0
Per-class alpha weights: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
Alpha sum: 0.999
Scheduler: ReduceLROnPlateau monitoring val_f1_macro
Optimizer: AdamW (LR=1e-05)
Scheduler: ReduceLROnPlateau (patience=3)
Criterion: Optimized Focal Loss
Swin Architecture: base variant

Starting CASME II Swin Transformer enhanced training with fixed checkpoints...
Training configuration: 50 epochs
Hierarchical processing: 7x7 windows

Epoch 1/50


CASME II Swin Training Epoch 1/50: 100%|██████████| 15/15 [00:13<00:00,  1.11it/s, Loss=0.1099, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 1/50: 100%|██████████| 2/2 [00:00<00:00,  2.25it/s, Val Loss=0.0462, Swin=base]


Train - Loss: 0.1038, F1: 0.1208, Acc: 0.2736
Val   - Loss: 0.1375, F1: 0.0680, Acc: 0.1923
Time  - Epoch: 14.4s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_swint_direct_best_f1.pth
  Swin variant: base
  Model: microsoft/swin-base-patch4-window7-224
New best Swin model: Higher F1 - F1: 0.0680
Progress: 2.0% | Best F1: 0.0680 | ETA: 13.9min | Swin-base

Epoch 2/50


CASME II Swin Training Epoch 2/50: 100%|██████████| 15/15 [00:12<00:00,  1.22it/s, Loss=0.0995, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 2/50: 100%|██████████| 2/2 [00:00<00:00,  2.15it/s, Val Loss=0.0441, Swin=base]


Train - Loss: 0.0940, F1: 0.1858, Acc: 0.3781
Val   - Loss: 0.1385, F1: 0.0779, Acc: 0.3462
Time  - Epoch: 13.3s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_swint_direct_best_f1.pth
  Swin variant: base
  Model: microsoft/swin-base-patch4-window7-224
New best Swin model: Higher F1 - F1: 0.0779
Progress: 4.0% | Best F1: 0.0779 | ETA: 13.2min | Swin-base

Epoch 3/50


CASME II Swin Training Epoch 3/50: 100%|██████████| 15/15 [00:12<00:00,  1.20it/s, Loss=0.0819, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 3/50: 100%|██████████| 2/2 [00:01<00:00,  1.99it/s, Val Loss=0.0447, Swin=base]


Train - Loss: 0.0872, F1: 0.2553, Acc: 0.4577
Val   - Loss: 0.1422, F1: 0.0756, Acc: 0.3462
Time  - Epoch: 13.5s, LR: 1.00e-05
Progress: 6.0% | Best F1: 0.0779 | ETA: 12.2min | Swin-base

Epoch 4/50


CASME II Swin Training Epoch 4/50: 100%|██████████| 15/15 [00:12<00:00,  1.20it/s, Loss=0.0747, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 4/50: 100%|██████████| 2/2 [00:01<00:00,  1.93it/s, Val Loss=0.0461, Swin=base]


Train - Loss: 0.0783, F1: 0.3095, Acc: 0.4478
Val   - Loss: 0.1426, F1: 0.0952, Acc: 0.2692
Time  - Epoch: 13.6s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_swint_direct_best_f1.pth
  Swin variant: base
  Model: microsoft/swin-base-patch4-window7-224
New best Swin model: Higher F1 - F1: 0.0952
Progress: 8.0% | Best F1: 0.0952 | ETA: 12.5min | Swin-base

Epoch 5/50


CASME II Swin Training Epoch 5/50: 100%|██████████| 15/15 [00:12<00:00,  1.19it/s, Loss=0.0665, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 5/50: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s, Val Loss=0.0437, Swin=base]


Train - Loss: 0.0696, F1: 0.3593, Acc: 0.5025
Val   - Loss: 0.1434, F1: 0.0949, Acc: 0.3077
Time  - Epoch: 13.7s, LR: 1.00e-05
Progress: 10.0% | Best F1: 0.0952 | ETA: 11.8min | Swin-base

Epoch 6/50


CASME II Swin Training Epoch 6/50: 100%|██████████| 15/15 [00:12<00:00,  1.19it/s, Loss=0.0657, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 6/50: 100%|██████████| 2/2 [00:01<00:00,  1.85it/s, Val Loss=0.0444, Swin=base]


Train - Loss: 0.0616, F1: 0.4775, Acc: 0.6318
Val   - Loss: 0.1400, F1: 0.1786, Acc: 0.3077
Time  - Epoch: 13.7s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_swint_direct_best_f1.pth
  Swin variant: base
  Model: microsoft/swin-base-patch4-window7-224
New best Swin model: Higher F1 - F1: 0.1786
Progress: 12.0% | Best F1: 0.1786 | ETA: 11.6min | Swin-base

Epoch 7/50


CASME II Swin Training Epoch 7/50: 100%|██████████| 15/15 [00:12<00:00,  1.18it/s, Loss=0.0566, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 7/50: 100%|██████████| 2/2 [00:01<00:00,  1.83it/s, Val Loss=0.0437, Swin=base]


Train - Loss: 0.0570, F1: 0.4333, Acc: 0.5672
Val   - Loss: 0.1440, F1: 0.1871, Acc: 0.3462
Time  - Epoch: 13.8s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_swint_direct_best_f1.pth
  Swin variant: base
  Model: microsoft/swin-base-patch4-window7-224
New best Swin model: Higher F1 - F1: 0.1871
Progress: 14.0% | Best F1: 0.1871 | ETA: 11.4min | Swin-base

Epoch 8/50


CASME II Swin Training Epoch 8/50: 100%|██████████| 15/15 [00:12<00:00,  1.17it/s, Loss=0.0515, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 8/50: 100%|██████████| 2/2 [00:01<00:00,  1.67it/s, Val Loss=0.0401, Swin=base]


Train - Loss: 0.0517, F1: 0.5620, Acc: 0.7015
Val   - Loss: 0.1405, F1: 0.1840, Acc: 0.3462
Time  - Epoch: 14.1s, LR: 1.00e-05
Progress: 16.0% | Best F1: 0.1871 | ETA: 11.0min | Swin-base

Epoch 9/50


CASME II Swin Training Epoch 9/50: 100%|██████████| 15/15 [00:12<00:00,  1.17it/s, Loss=0.0443, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 9/50: 100%|██████████| 2/2 [00:01<00:00,  1.70it/s, Val Loss=0.0396, Swin=base]


Train - Loss: 0.0460, F1: 0.5566, Acc: 0.6965
Val   - Loss: 0.1377, F1: 0.2435, Acc: 0.3846
Time  - Epoch: 14.1s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_swint_direct_best_f1.pth
  Swin variant: base
  Model: microsoft/swin-base-patch4-window7-224
New best Swin model: Higher F1 - F1: 0.2435
Progress: 18.0% | Best F1: 0.2435 | ETA: 10.8min | Swin-base

Epoch 10/50


CASME II Swin Training Epoch 10/50: 100%|██████████| 15/15 [00:12<00:00,  1.16it/s, Loss=0.0439, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 10/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0371, Swin=base]


Train - Loss: 0.0412, F1: 0.6174, Acc: 0.7662
Val   - Loss: 0.1347, F1: 0.2381, Acc: 0.4231
Time  - Epoch: 14.2s, LR: 1.00e-05
Progress: 20.0% | Best F1: 0.2435 | ETA: 10.4min | Swin-base

Epoch 11/50


CASME II Swin Training Epoch 11/50: 100%|██████████| 15/15 [00:12<00:00,  1.16it/s, Loss=0.0351, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 11/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0421, Swin=base]


Train - Loss: 0.0351, F1: 0.7881, Acc: 0.7761
Val   - Loss: 0.1361, F1: 0.2619, Acc: 0.4231
Time  - Epoch: 14.3s, LR: 1.00e-05
Saving checkpoint with enhanced device migration...
  Migrating model state dict to CPU...
  Migrating optimizer state dict to CPU...
  Migrating scheduler state dict to CPU...
  Device migration completed successfully
  Attempt 1: Saving to temporary file...
Checkpoint saved successfully: casme2_swint_direct_best_f1.pth
  Swin variant: base
  Model: microsoft/swin-base-patch4-window7-224
New best Swin model: Higher F1 - F1: 0.2619
Progress: 22.0% | Best F1: 0.2619 | ETA: 10.3min | Swin-base

Epoch 12/50


CASME II Swin Training Epoch 12/50: 100%|██████████| 15/15 [00:12<00:00,  1.15it/s, Loss=0.0305, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 12/50: 100%|██████████| 2/2 [00:01<00:00,  1.53it/s, Val Loss=0.0351, Swin=base]


Train - Loss: 0.0324, F1: 0.8501, Acc: 0.8507
Val   - Loss: 0.1348, F1: 0.2303, Acc: 0.4231
Time  - Epoch: 14.3s, LR: 1.00e-05
Progress: 24.0% | Best F1: 0.2619 | ETA: 9.9min | Swin-base

Epoch 13/50


CASME II Swin Training Epoch 13/50: 100%|██████████| 15/15 [00:12<00:00,  1.15it/s, Loss=0.0263, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 13/50: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s, Val Loss=0.0360, Swin=base]


Train - Loss: 0.0274, F1: 0.8824, Acc: 0.8706
Val   - Loss: 0.1361, F1: 0.2395, Acc: 0.4615
Time  - Epoch: 14.3s, LR: 1.00e-05
Progress: 26.0% | Best F1: 0.2619 | ETA: 9.6min | Swin-base

Epoch 14/50


CASME II Swin Training Epoch 14/50: 100%|██████████| 15/15 [00:13<00:00,  1.15it/s, Loss=0.0230, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 14/50: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s, Val Loss=0.0315, Swin=base]


Train - Loss: 0.0230, F1: 0.9171, Acc: 0.9005
Val   - Loss: 0.1355, F1: 0.2347, Acc: 0.4615
Time  - Epoch: 14.3s, LR: 1.00e-05
Progress: 28.0% | Best F1: 0.2619 | ETA: 9.3min | Swin-base

Epoch 15/50


CASME II Swin Training Epoch 15/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0230, LR=1.00e-05, Swin=base]
CASME II Swin Validation Epoch 15/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0375, Swin=base]


Train - Loss: 0.0218, F1: 0.9295, Acc: 0.9254
Val   - Loss: 0.1342, F1: 0.2293, Acc: 0.4231
Time  - Epoch: 14.5s, LR: 5.00e-06
Progress: 30.0% | Best F1: 0.2619 | ETA: 9.0min | Swin-base

Epoch 16/50


CASME II Swin Training Epoch 16/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0184, LR=5.00e-06, Swin=base]
CASME II Swin Validation Epoch 16/50: 100%|██████████| 2/2 [00:01<00:00,  1.54it/s, Val Loss=0.0366, Swin=base]


Train - Loss: 0.0192, F1: 0.9296, Acc: 0.9353
Val   - Loss: 0.1344, F1: 0.2245, Acc: 0.4231
Time  - Epoch: 14.4s, LR: 5.00e-06
Progress: 32.0% | Best F1: 0.2619 | ETA: 8.7min | Swin-base

Epoch 17/50


CASME II Swin Training Epoch 17/50: 100%|██████████| 15/15 [00:13<00:00,  1.15it/s, Loss=0.0166, LR=5.00e-06, Swin=base]
CASME II Swin Validation Epoch 17/50: 100%|██████████| 2/2 [00:01<00:00,  1.60it/s, Val Loss=0.0358, Swin=base]


Train - Loss: 0.0159, F1: 0.9418, Acc: 0.9453
Val   - Loss: 0.1372, F1: 0.2208, Acc: 0.4231
Time  - Epoch: 14.3s, LR: 5.00e-06
Progress: 34.0% | Best F1: 0.2619 | ETA: 8.4min | Swin-base

Epoch 18/50


CASME II Swin Training Epoch 18/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0165, LR=5.00e-06, Swin=base]
CASME II Swin Validation Epoch 18/50: 100%|██████████| 2/2 [00:01<00:00,  1.59it/s, Val Loss=0.0366, Swin=base]


Train - Loss: 0.0171, F1: 0.9634, Acc: 0.9602
Val   - Loss: 0.1375, F1: 0.2208, Acc: 0.4231
Time  - Epoch: 14.4s, LR: 5.00e-06
Progress: 36.0% | Best F1: 0.2619 | ETA: 8.1min | Swin-base

Epoch 19/50


CASME II Swin Training Epoch 19/50: 100%|██████████| 15/15 [00:12<00:00,  1.16it/s, Loss=0.0138, LR=5.00e-06, Swin=base]
CASME II Swin Validation Epoch 19/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0390, Swin=base]


Train - Loss: 0.0142, F1: 0.9730, Acc: 0.9701
Val   - Loss: 0.1393, F1: 0.2293, Acc: 0.4231
Time  - Epoch: 14.3s, LR: 2.50e-06
Progress: 38.0% | Best F1: 0.2619 | ETA: 7.8min | Swin-base

Epoch 20/50


CASME II Swin Training Epoch 20/50: 100%|██████████| 15/15 [00:13<00:00,  1.15it/s, Loss=0.0138, LR=2.50e-06, Swin=base]
CASME II Swin Validation Epoch 20/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0377, Swin=base]


Train - Loss: 0.0147, F1: 0.9688, Acc: 0.9701
Val   - Loss: 0.1391, F1: 0.2381, Acc: 0.4231
Time  - Epoch: 14.3s, LR: 2.50e-06
Progress: 40.0% | Best F1: 0.2619 | ETA: 7.6min | Swin-base

Epoch 21/50


CASME II Swin Training Epoch 21/50: 100%|██████████| 15/15 [00:13<00:00,  1.15it/s, Loss=0.0121, LR=2.50e-06, Swin=base]
CASME II Swin Validation Epoch 21/50: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s, Val Loss=0.0371, Swin=base]


Train - Loss: 0.0123, F1: 0.9774, Acc: 0.9851
Val   - Loss: 0.1393, F1: 0.2483, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 2.50e-06
Progress: 42.0% | Best F1: 0.2619 | ETA: 7.3min | Swin-base

Epoch 22/50


CASME II Swin Training Epoch 22/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0130, LR=2.50e-06, Swin=base]
CASME II Swin Validation Epoch 22/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0351, Swin=base]


Train - Loss: 0.0130, F1: 0.9645, Acc: 0.9751
Val   - Loss: 0.1397, F1: 0.2446, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 2.50e-06
Progress: 44.0% | Best F1: 0.2619 | ETA: 7.0min | Swin-base

Epoch 23/50


CASME II Swin Training Epoch 23/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0127, LR=2.50e-06, Swin=base]
CASME II Swin Validation Epoch 23/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0367, Swin=base]


Train - Loss: 0.0123, F1: 0.9852, Acc: 0.9801
Val   - Loss: 0.1402, F1: 0.2347, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.25e-06
Progress: 46.0% | Best F1: 0.2619 | ETA: 6.8min | Swin-base

Epoch 24/50


CASME II Swin Training Epoch 24/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0124, LR=1.25e-06, Swin=base]
CASME II Swin Validation Epoch 24/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0367, Swin=base]


Train - Loss: 0.0120, F1: 0.9739, Acc: 0.9751
Val   - Loss: 0.1397, F1: 0.2347, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.25e-06
Progress: 48.0% | Best F1: 0.2619 | ETA: 6.5min | Swin-base

Epoch 25/50


CASME II Swin Training Epoch 25/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0107, LR=1.25e-06, Swin=base]
CASME II Swin Validation Epoch 25/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0367, Swin=base]


Train - Loss: 0.0105, F1: 0.9705, Acc: 0.9851
Val   - Loss: 0.1388, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.25e-06
Progress: 50.0% | Best F1: 0.2619 | ETA: 6.2min | Swin-base

Epoch 26/50


CASME II Swin Training Epoch 26/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0105, LR=1.25e-06, Swin=base]
CASME II Swin Validation Epoch 26/50: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s, Val Loss=0.0367, Swin=base]


Train - Loss: 0.0126, F1: 0.9620, Acc: 0.9652
Val   - Loss: 0.1387, F1: 0.2310, Acc: 0.4231
Time  - Epoch: 14.5s, LR: 1.25e-06
Progress: 52.0% | Best F1: 0.2619 | ETA: 6.0min | Swin-base

Epoch 27/50


CASME II Swin Training Epoch 27/50: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s, Loss=0.0105, LR=1.25e-06, Swin=base]
CASME II Swin Validation Epoch 27/50: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s, Val Loss=0.0382, Swin=base]


Train - Loss: 0.0110, F1: 0.9895, Acc: 0.9851
Val   - Loss: 0.1387, F1: 0.2395, Acc: 0.4615
Time  - Epoch: 14.6s, LR: 1.00e-06
Progress: 54.0% | Best F1: 0.2619 | ETA: 5.7min | Swin-base

Epoch 28/50


CASME II Swin Training Epoch 28/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0101, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 28/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0375, Swin=base]


Train - Loss: 0.0104, F1: 0.9901, Acc: 0.9851
Val   - Loss: 0.1394, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.00e-06
Progress: 56.0% | Best F1: 0.2619 | ETA: 5.5min | Swin-base

Epoch 29/50


CASME II Swin Training Epoch 29/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0105, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 29/50: 100%|██████████| 2/2 [00:01<00:00,  1.54it/s, Val Loss=0.0379, Swin=base]


Train - Loss: 0.0106, F1: 0.9839, Acc: 0.9801
Val   - Loss: 0.1398, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 58.0% | Best F1: 0.2619 | ETA: 5.2min | Swin-base

Epoch 30/50


CASME II Swin Training Epoch 30/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0116, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 30/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0390, Swin=base]


Train - Loss: 0.0114, F1: 0.9730, Acc: 0.9701
Val   - Loss: 0.1397, F1: 0.2395, Acc: 0.4231
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 60.0% | Best F1: 0.2619 | ETA: 5.0min | Swin-base

Epoch 31/50


CASME II Swin Training Epoch 31/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0099, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 31/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0376, Swin=base]


Train - Loss: 0.0098, F1: 0.9753, Acc: 0.9751
Val   - Loss: 0.1406, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.00e-06
Progress: 62.0% | Best F1: 0.2619 | ETA: 4.7min | Swin-base

Epoch 32/50


CASME II Swin Training Epoch 32/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0094, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 32/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0376, Swin=base]


Train - Loss: 0.0096, F1: 0.9882, Acc: 0.9851
Val   - Loss: 0.1409, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 64.0% | Best F1: 0.2619 | ETA: 4.5min | Swin-base

Epoch 33/50


CASME II Swin Training Epoch 33/50: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s, Loss=0.0097, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 33/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0379, Swin=base]


Train - Loss: 0.0092, F1: 0.9777, Acc: 0.9851
Val   - Loss: 0.1406, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 66.0% | Best F1: 0.2619 | ETA: 4.2min | Swin-base

Epoch 34/50


CASME II Swin Training Epoch 34/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0087, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 34/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0378, Swin=base]


Train - Loss: 0.0087, F1: 0.9838, Acc: 0.9900
Val   - Loss: 0.1401, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.00e-06
Progress: 68.0% | Best F1: 0.2619 | ETA: 4.0min | Swin-base

Epoch 35/50


CASME II Swin Training Epoch 35/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0100, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 35/50: 100%|██████████| 2/2 [00:01<00:00,  1.59it/s, Val Loss=0.0375, Swin=base]


Train - Loss: 0.0096, F1: 0.9863, Acc: 0.9851
Val   - Loss: 0.1405, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.00e-06
Progress: 70.0% | Best F1: 0.2619 | ETA: 3.7min | Swin-base

Epoch 36/50


CASME II Swin Training Epoch 36/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0099, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 36/50: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s, Val Loss=0.0374, Swin=base]


Train - Loss: 0.0102, F1: 0.9774, Acc: 0.9652
Val   - Loss: 0.1415, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 72.0% | Best F1: 0.2619 | ETA: 3.5min | Swin-base

Epoch 37/50


CASME II Swin Training Epoch 37/50: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s, Loss=0.0095, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 37/50: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s, Val Loss=0.0379, Swin=base]


Train - Loss: 0.0090, F1: 0.9800, Acc: 0.9851
Val   - Loss: 0.1417, F1: 0.2395, Acc: 0.4615
Time  - Epoch: 14.6s, LR: 1.00e-06
Progress: 74.0% | Best F1: 0.2619 | ETA: 3.2min | Swin-base

Epoch 38/50


CASME II Swin Training Epoch 38/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0093, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 38/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0386, Swin=base]


Train - Loss: 0.0102, F1: 0.9823, Acc: 0.9900
Val   - Loss: 0.1413, F1: 0.2395, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 76.0% | Best F1: 0.2619 | ETA: 3.0min | Swin-base

Epoch 39/50


CASME II Swin Training Epoch 39/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0098, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 39/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0379, Swin=base]


Train - Loss: 0.0101, F1: 0.9911, Acc: 0.9851
Val   - Loss: 0.1418, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.00e-06
Progress: 78.0% | Best F1: 0.2619 | ETA: 2.7min | Swin-base

Epoch 40/50


CASME II Swin Training Epoch 40/50: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s, Loss=0.0089, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 40/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0382, Swin=base]


Train - Loss: 0.0096, F1: 0.9916, Acc: 0.9851
Val   - Loss: 0.1412, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.6s, LR: 1.00e-06
Progress: 80.0% | Best F1: 0.2619 | ETA: 2.5min | Swin-base

Epoch 41/50


CASME II Swin Training Epoch 41/50: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s, Loss=0.0082, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 41/50: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s, Val Loss=0.0385, Swin=base]


Train - Loss: 0.0083, F1: 0.9962, Acc: 0.9950
Val   - Loss: 0.1409, F1: 0.2395, Acc: 0.4231
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 82.0% | Best F1: 0.2619 | ETA: 2.2min | Swin-base

Epoch 42/50


CASME II Swin Training Epoch 42/50: 100%|██████████| 15/15 [00:13<00:00,  1.13it/s, Loss=0.0093, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 42/50: 100%|██████████| 2/2 [00:01<00:00,  1.53it/s, Val Loss=0.0379, Swin=base]


Train - Loss: 0.0093, F1: 0.9680, Acc: 0.9851
Val   - Loss: 0.1410, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.6s, LR: 1.00e-06
Progress: 84.0% | Best F1: 0.2619 | ETA: 2.0min | Swin-base

Epoch 43/50


CASME II Swin Training Epoch 43/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0097, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 43/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0381, Swin=base]


Train - Loss: 0.0093, F1: 0.9962, Acc: 0.9950
Val   - Loss: 0.1412, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 86.0% | Best F1: 0.2619 | ETA: 1.7min | Swin-base

Epoch 44/50


CASME II Swin Training Epoch 44/50: 100%|██████████| 15/15 [00:13<00:00,  1.15it/s, Loss=0.0088, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 44/50: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s, Val Loss=0.0379, Swin=base]


Train - Loss: 0.0092, F1: 0.9823, Acc: 0.9900
Val   - Loss: 0.1415, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.4s, LR: 1.00e-06
Progress: 88.0% | Best F1: 0.2619 | ETA: 1.5min | Swin-base

Epoch 45/50


CASME II Swin Training Epoch 45/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0082, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 45/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0382, Swin=base]


Train - Loss: 0.0094, F1: 0.9744, Acc: 0.9900
Val   - Loss: 0.1415, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 90.0% | Best F1: 0.2619 | ETA: 1.2min | Swin-base

Epoch 46/50


CASME II Swin Training Epoch 46/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0080, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 46/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0385, Swin=base]


Train - Loss: 0.0082, F1: 0.9931, Acc: 0.9900
Val   - Loss: 0.1421, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 92.0% | Best F1: 0.2619 | ETA: 1.0min | Swin-base

Epoch 47/50


CASME II Swin Training Epoch 47/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0069, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 47/50: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s, Val Loss=0.0388, Swin=base]


Train - Loss: 0.0072, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.1424, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 94.0% | Best F1: 0.2619 | ETA: 0.7min | Swin-base

Epoch 48/50


CASME II Swin Training Epoch 48/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0075, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 48/50: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s, Val Loss=0.0390, Swin=base]


Train - Loss: 0.0080, F1: 0.9701, Acc: 0.9851
Val   - Loss: 0.1431, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 96.0% | Best F1: 0.2619 | ETA: 0.5min | Swin-base

Epoch 49/50


CASME II Swin Training Epoch 49/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0079, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 49/50: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s, Val Loss=0.0390, Swin=base]


Train - Loss: 0.0080, F1: 0.9911, Acc: 0.9851
Val   - Loss: 0.1429, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 98.0% | Best F1: 0.2619 | ETA: 0.2min | Swin-base

Epoch 50/50


CASME II Swin Training Epoch 50/50: 100%|██████████| 15/15 [00:13<00:00,  1.14it/s, Loss=0.0080, LR=1.00e-06, Swin=base]
CASME II Swin Validation Epoch 50/50: 100%|██████████| 2/2 [00:01<00:00,  1.56it/s, Val Loss=0.0389, Swin=base]


Train - Loss: 0.0079, F1: 0.9861, Acc: 0.9950
Val   - Loss: 0.1434, F1: 0.2494, Acc: 0.4615
Time  - Epoch: 14.5s, LR: 1.00e-06
Progress: 100.0% | Best F1: 0.2619 | ETA: 0.0min | Swin-base

CASME II SWIN TRANSFORMER ENHANCED BASELINE TRAINING COMPLETED
Training time: 12.3 minutes
Epochs completed: 50
Best validation F1: 0.2619 (epoch 11)
Final train F1: 0.9861
Final validation F1: 0.2494
Swin Architecture: base variant

Exporting enhanced Swin Transformer training documentation...
Enhanced Swin Transformer training documentation saved: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/02_02_swint_direct/training_logs/casme2_swint_direct_training_history.json
Experiment details: Optimized Focal Loss loss
  Gamma: 2.0, Alpha Sum: 0.999
Model variant: microsoft/swin-base-patch4-window7-224 (base)
Checkpoint saving: FIXED with complete device migration

Next: Cell 3 - CASME II Swin Transformer Enhanced Evaluation
Enhanced Swin Transformer training pipeline

In [None]:
# @title Cell 3: CASME II Swin Transformer Direct Baseline Evaluation

# File: 02_02_SwinT_Direct_Baseline_Cell3.py
# Location: experiments/02_02_SwinT_Direct_Baseline.ipynb
# Purpose: Evaluation framework for trained CASME II Swin Transformer micro-expression recognition model

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime

# Evaluation specific imports
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import pickle
import warnings
warnings.filterwarnings('ignore')

# Enhanced test dataset for CASME II Swin Transformer evaluation
class CASME2DatasetEvaluationSwin(Dataset):
    """CASME II evaluation dataset yielding images together with per-sample metadata.

    Each item is the 6-tuple
    ``(image, label, sample_id, emotion, subject, filename)`` where ``label`` is
    ``CLASS_TO_IDX[emotion]`` and ``filename`` is the base name of the image path.

    Args:
        split_metadata: Mapping indexed as ``split_metadata[split]['samples']``.
            Every sample is read for the keys ``image_filename``, ``emotion``,
            ``sample_id`` and ``subject``.
        dataset_root: Directory that contains one sub-directory per split.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        split: Name of the split to load (also used as the sub-directory name).
        use_ram_cache: When true, every image is decoded once in the constructor
            and kept in memory; ``__getitem__`` then hands out copies.
    """

    # Every image is brought to this (width, height) before the transform runs.
    IMAGE_SIZE = (384, 384)
    # RGB fill colour of the placeholder substituted for unreadable images.
    PLACEHOLDER_COLOR = (128, 128, 128)

    def __init__(self, split_metadata, dataset_root, transform=None, split='test', use_ram_cache=True):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.emotions = []
        self.subjects = []
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for Swin Transformer evaluation...")

        # Flatten the metadata into parallel lists indexed by sample position.
        for sample in self.metadata:
            image_path = os.path.join(dataset_root, split, sample['image_filename'])
            self.images.append(image_path)
            self.labels.append(CLASS_TO_IDX[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])
            self.emotions.append(sample['emotion'])
            self.subjects.append(sample['subject'])

        print(f"Loaded {len(self.images)} CASME II {split} samples for Swin evaluation")
        self._print_evaluation_distribution()

        # Optional eager decode so evaluation does not touch the disk again.
        if self.use_ram_cache:
            self._preload_to_ram_evaluation()

    def _load_image(self, img_path):
        """Decode ``img_path`` as RGB and resize it to ``IMAGE_SIZE`` when it differs."""
        # ``convert`` returns a fully loaded image, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        if image.size != self.IMAGE_SIZE:
            image = image.resize(self.IMAGE_SIZE, Image.Resampling.LANCZOS)
        return image

    def _placeholder_image(self):
        """Return the uniform grey image used in place of an unreadable file."""
        return Image.new('RGB', self.IMAGE_SIZE, self.PLACEHOLDER_COLOR)

    def _print_evaluation_distribution(self):
        """Print per-class counts, subject coverage and any classes absent from the split."""
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1
        subject_total = len(set(self.subjects))

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Test set covers {subject_total} subjects")

        # Classes without a single sample make per-class metrics undefined,
        # so they are called out explicitly.
        missing_classes = [
            class_name
            for i, class_name in enumerate(CASME2_CLASSES)
            if i not in label_counts
        ]
        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        """Decode every image once and keep the PIL objects in ``self.cached_images``.

        Files that cannot be read are reported and replaced by a grey
        placeholder so that indices stay aligned with ``self.images``.
        """
        print(f"Preloading {len(self.images)} test images to RAM for Swin evaluation...")

        valid_images = 0
        self.cached_images = [None] * len(self.images)

        # Sequential loading keeps evaluation simple and deterministic.
        for i, img_path in enumerate(tqdm(self.images, desc="Loading test images")):
            try:
                self.cached_images[i] = self._load_image(img_path)
                valid_images += 1
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                self.cached_images[i] = self._placeholder_image()

        # The cache holds 8-bit RGB PIL images: one byte per channel, three
        # channels per pixel (not four-byte floats).
        width, height = self.IMAGE_SIZE
        ram_usage_gb = len(self.cached_images) * width * height * 3 / 1e9
        print(f"Test RAM caching completed: {valid_images} valid images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label, sample_id, emotion, subject, filename)`` for ``idx``."""
        if self.use_ram_cache and self.cached_images[idx] is not None:
            # Hand out a copy so the cached original is never modified.
            image = self.cached_images[idx].copy()
        else:
            try:
                image = self._load_image(self.images[idx])
            except Exception as e:
                # Report the failure instead of swallowing it silently, then
                # fall back to the same placeholder the preloader uses.
                print(f"Error loading {self.images[idx]}: {e}")
                image = self._placeholder_image()

        if self.transform:
            image = self.transform(image)

        return (image, self.labels[idx], self.sample_ids[idx],
                self.emotions[idx], self.subjects[idx], os.path.basename(self.images[idx]))

# CASME II Swin Transformer evaluation configuration
# Static description of the evaluation run: which checkpoint file is evaluated,
# how many classes it predicts, and descriptive labels for the model and data.
EVALUATION_CONFIG_CASME2_SWIN = dict(
    model_type='SwinT_CASME2_Direct_Baseline',
    task_type='micro_expression_recognition',
    num_classes=7,
    class_names=CASME2_CLASSES,
    checkpoint_file='casme2_swint_direct_best_f1.pth',
    dataset_name='CASME_II',
    input_size='384x384',
    evaluation_protocol='stratified_split',
    architecture='swin_transformer_hierarchical',
)

# Banner summarising the configuration for the notebook output.
print("CASME II Swin Transformer Direct Baseline Evaluation Framework")
print("=" * 65)
print("Model:", EVALUATION_CONFIG_CASME2_SWIN['model_type'])
print("Task:", EVALUATION_CONFIG_CASME2_SWIN['task_type'])
print("Classes:", EVALUATION_CONFIG_CASME2_SWIN['class_names'])
print("Input size:", EVALUATION_CONFIG_CASME2_SWIN['input_size'])
print("Architecture:", EVALUATION_CONFIG_CASME2_SWIN['architecture'])

def extract_logits_safe_casme2_swin(outputs_all):
    """Pull a single tensor out of whatever container the model returned.

    Resolution order:
      * a tensor is returned unchanged;
      * for a tuple or list, the first element that is a tensor;
      * for a dict, the first tensor found under the well-known keys below
        (checked in that order), otherwise the first tensor among its values.

    Raises:
        RuntimeError: if no tensor can be found in ``outputs_all``.
    """
    def _first_tensor(candidates):
        # First tensor in iteration order, or None when there is none.
        return next((c for c in candidates if isinstance(c, torch.Tensor)), None)

    if isinstance(outputs_all, torch.Tensor):
        return outputs_all

    found = None
    if isinstance(outputs_all, (tuple, list)):
        found = _first_tensor(outputs_all)
    elif isinstance(outputs_all, dict):
        # Named outputs are preferred over arbitrary dict ordering.
        preferred_keys = ('logits', 'logit', 'predictions', 'outputs', 'scores',
                          'pooler_output', 'last_hidden_state')
        found = _first_tensor(outputs_all.get(key) for key in preferred_keys)
        if found is None:
            found = _first_tensor(outputs_all.values())

    if found is None:
        raise RuntimeError("Unable to extract tensor logits from CASME II Swin Transformer model output")
    return found

def load_trained_model_casme2_swin(checkpoint_path, device):
    """Load a trained CASME II Swin checkpoint and return the model in eval mode.

    The checkpoint is read with up to three strategies, tried in order:
    ``torch.load`` with default arguments, ``torch.load`` with
    ``weights_only=False``, and finally a plain ``pickle.load``. The weights
    are then applied with ``strict=True``, falling back to ``strict=False``.

    Args:
        checkpoint_path: Path of the checkpoint file.
        device: Device the freshly built model is moved to.

    Returns:
        tuple: ``(model, training_info)`` where ``training_info`` is a dict of
        the metrics and Swin settings stored in the checkpoint.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If every loading strategy fails, or if the state dict
            cannot be applied even with ``strict=False``.
    """
    print(f"Loading trained CASME II Swin Transformer model from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Swin checkpoint not found: {checkpoint_path}")

    def _load_with_pickle():
        import pickle
        with open(checkpoint_path, 'rb') as f:
            return pickle.load(f)

    # SECURITY NOTE: both ``weights_only=False`` and raw ``pickle.load`` can run
    # arbitrary code embedded in the file. Only use checkpoints you produced.
    loaders = (
        ("standard", lambda: torch.load(checkpoint_path, map_location='cpu')),
        ("weights_only_false",
         lambda: torch.load(checkpoint_path, map_location=device, weights_only=False)),
        ("pickle", _load_with_pickle),
    )

    checkpoint = None
    loading_method = "unknown"
    load_errors = []
    for method_name, loader in loaders:
        try:
            checkpoint = loader()
        except Exception as exc:
            load_errors.append(exc)
        else:
            loading_method = method_name
            break
    else:
        # Every strategy failed: report all messages and chain the last error
        # so its traceback is preserved.
        failure_summary = ", ".join(str(err) for err in load_errors)
        raise RuntimeError(
            f"All Swin loading methods failed: {failure_summary}"
        ) from load_errors[-1]

    print(f"Swin checkpoint loaded using: {loading_method}")

    # Swin-specific settings stored next to the weights; the defaults apply
    # when the checkpoint does not carry them.
    swin_config = checkpoint.get('casme2_swin_config', {})
    swin_variant = checkpoint.get('swin_variant', 'tiny')
    swin_model_name = checkpoint.get('swin_model', 'microsoft/swin-tiny-patch4-window7-224')

    print(f"Detected Swin variant: {swin_variant}")
    print(f"Detected Swin model: {swin_model_name}")

    # NOTE(review): the detected variant/model name is only reported; the model
    # is built from num_classes and dropout_rate alone - confirm that
    # SwinCASME2Baseline picks the matching backbone on its own.
    model = SwinCASME2Baseline(
        num_classes=EVALUATION_CONFIG_CASME2_SWIN['num_classes'],
        dropout_rate=swin_config.get('dropout_rate', 0.2)
    ).to(device)

    # A checkpoint without 'model_state_dict' is treated as a bare state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    try:
        model.load_state_dict(state_dict, strict=True)
        print("Swin model state loaded with strict=True")
    except Exception as e:
        print(f"Strict loading failed, trying non-strict: {str(e)[:100]}...")
        try:
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        except Exception as e2:
            raise RuntimeError(f"Both Swin loading approaches failed: {e2}") from e2
        if missing_keys or unexpected_keys:
            print(f"Non-strict loading: Missing {len(missing_keys)}, Unexpected {len(unexpected_keys)}")
        else:
            print("Swin model state loaded with strict=False (no key mismatches)")

    model.eval()

    # Metrics recorded at save time; missing entries fall back to neutral values.
    training_info = {
        'best_val_f1': float(checkpoint.get('best_f1', 0.0)),
        'best_val_loss': float(checkpoint.get('best_loss', float('inf'))),
        'best_val_accuracy': float(checkpoint.get('best_acc', 0.0)),
        # NOTE(review): the +1 assumes the stored epoch is zero-based - confirm
        # against the training cell.
        'best_epoch': int(checkpoint.get('epoch', 0)) + 1,
        'model_checkpoint': EVALUATION_CONFIG_CASME2_SWIN['checkpoint_file'],
        'num_classes': EVALUATION_CONFIG_CASME2_SWIN['num_classes'],
        'swin_variant': swin_variant,
        'swin_model': swin_model_name,
        'config': swin_config
    }

    print("Swin Transformer model loaded successfully:")
    print(f"  Best validation F1: {training_info['best_val_f1']:.4f}")
    print(f"  Best validation accuracy: {training_info['best_val_accuracy']:.4f}")
    print(f"  Best epoch: {training_info['best_epoch']}")
    print(f"  Swin variant: {training_info['swin_variant']}")
    print(f"  Model classes: {EVALUATION_CONFIG_CASME2_SWIN['num_classes']}")

    return model, training_info

def run_model_inference_casme2_swin(model, test_loader, device):
    """Run the Swin model over ``test_loader`` and collect per-sample outputs.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        test_loader: Iterable yielding 6-tuples
            ``(images, labels, sample_ids, emotions, subjects, filenames)``.
        device: Device the image batch is moved to before the forward pass.

    Returns:
        dict with numpy arrays under 'predictions', 'probabilities' and
        'labels', plain lists under 'sample_ids', 'emotions', 'subjects' and
        'filenames', plus 'inference_time' (seconds) and 'samples_count'.
    """
    print("Running CASME II Swin Transformer model inference on test set...")

    model.eval()

    collected_preds, collected_probs, collected_labels = [], [], []
    collected_ids, collected_emotions = [], []
    collected_subjects, collected_files = [], []

    started_at = time.time()

    with torch.no_grad():
        progress = tqdm(test_loader, desc="CASME II Swin Inference")
        for images, labels, sample_ids, emotions, subjects, filenames in progress:
            images = images.to(device)

            # Preferred path: let the shared helper pull logits out of whatever
            # structure the model returns; fall back to a manual unwrap.
            try:
                logits = extract_logits_safe_casme2_swin(model(images))
            except Exception as e:
                print(f"Error in Swin model forward pass: {e}")
                logits = model(images)
                if isinstance(logits, (list, tuple)):
                    logits = logits[0]

            # Only a warning: a mismatched head is reported, not rejected.
            if logits.shape[1] != 7:
                print(f"Warning: Expected 7 classes output, got {logits.shape[1]}")

            batch_probs = torch.softmax(logits, dim=1)
            batch_preds = torch.argmax(batch_probs, dim=1)

            # Move results to host memory before accumulating them.
            collected_preds.extend(batch_preds.cpu().numpy())
            collected_probs.extend(batch_probs.cpu().numpy())
            collected_labels.extend(labels.numpy())
            collected_ids.extend(sample_ids)
            collected_emotions.extend(emotions)
            collected_subjects.extend(subjects)
            collected_files.extend(filenames)

    inference_time = time.time() - started_at

    predictions_array = np.array(collected_preds)
    probabilities_array = np.array(collected_probs)
    labels_array = np.array(collected_labels)

    print(f"CASME II Swin inference completed: {len(predictions_array)} samples in {inference_time:.2f}s")

    # Report which classes were predicted and which actually occur.
    predicted_ids = np.unique(predictions_array)
    print(f"Predicted classes: {[CASME2_CLASSES[i] for i in predicted_ids]}")

    true_ids = np.unique(labels_array)
    print(f"True classes in test: {[CASME2_CLASSES[i] for i in true_ids]}")

    return {
        'predictions': predictions_array,
        'probabilities': probabilities_array,
        'labels': labels_array,
        'sample_ids': collected_ids,
        'emotions': collected_emotions,
        'subjects': collected_subjects,
        'filenames': collected_files,
        'inference_time': inference_time,
        'samples_count': len(predictions_array)
    }

def analyze_wrong_predictions_casme2_swin(inference_results):
    """Summarise every misclassified sample of a CASME II Swin inference run.

    Args:
        inference_results: dict as returned by
            ``run_model_inference_casme2_swin``. The keys read here are
            'predictions' and 'labels' (arrays compared element-wise) and the
            parallel sequences 'sample_ids', 'emotions', 'subjects' and
            'filenames'.

    Returns:
        dict with:
            'analysis_metadata': timestamp, model/dataset identifiers from
                ``EVALUATION_CONFIG_CASME2_SWIN``, sample count, error count
                and overall error rate in percent (0.0 when there are no
                samples).
            'wrong_predictions_by_class': true class name -> list of error
                records.
            'subject_error_analysis': subject -> {'total', 'wrong', 'errors',
                'error_rate'}.
            'confusion_patterns': "<true>_to_<predicted>" -> count.
            'error_summary': class name -> number of errors.
    """
    print("Analyzing wrong predictions for CASME II Swin Transformer micro-expression recognition...")

    predictions = inference_results['predictions']
    labels = inference_results['labels']
    sample_ids = inference_results['sample_ids']
    emotions = inference_results['emotions']
    subjects = inference_results['subjects']
    filenames = inference_results['filenames']

    # Positions where the predicted class differs from the ground truth
    wrong_indices = np.where(predictions != labels)[0]

    # Every known class gets an entry, even if it has no errors
    wrong_predictions_by_class = {class_name: [] for class_name in CASME2_CLASSES}
    subject_error_analysis = {}
    confusion_patterns = {}

    # Single pass over the errors: per-class records, per-subject records and
    # confusion-pattern counts are all derived from the same sample.
    for idx in wrong_indices:
        true_label = labels[idx]
        pred_label = predictions[idx]
        subject = subjects[idx]

        true_class = CASME2_CLASSES[true_label]
        pred_class = CASME2_CLASSES[pred_label]

        wrong_info = {
            'sample_id': sample_ids[idx],
            'filename': filenames[idx],
            'subject': subject,
            'true_label': int(true_label),
            'true_class': true_class,
            'predicted_label': int(pred_label),
            'predicted_class': pred_class,
            'emotion': emotions[idx]
        }

        wrong_predictions_by_class[true_class].append(wrong_info)

        # The same record object is shared with the per-subject error list
        subject_entry = subject_error_analysis.setdefault(
            subject, {'total': 0, 'wrong': 0, 'errors': []}
        )
        subject_entry['wrong'] += 1
        subject_entry['errors'].append(wrong_info)

        pattern = f"{true_class}_to_{pred_class}"
        confusion_patterns[pattern] = confusion_patterns.get(pattern, 0) + 1

    # Count total samples per subject (subjects without errors are added here)
    for subject in subjects:
        subject_entry = subject_error_analysis.setdefault(
            subject, {'total': 0, 'wrong': 0, 'errors': []}
        )
        subject_entry['total'] += 1

    # Calculate error rates per subject
    for subject_entry in subject_error_analysis.values():
        total = subject_entry['total']
        subject_entry['error_rate'] = subject_entry['wrong'] / total if total > 0 else 0.0

    # Summary statistics; an empty result set reports 0.0 instead of raising
    # ZeroDivisionError.
    total_wrong = len(wrong_indices)
    total_samples = len(predictions)
    error_rate = (total_wrong / total_samples) * 100 if total_samples > 0 else 0.0

    analysis_results = {
        'analysis_metadata': {
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'model_type': EVALUATION_CONFIG_CASME2_SWIN['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2_SWIN['dataset_name'],
            'architecture': EVALUATION_CONFIG_CASME2_SWIN['architecture'],
            'total_samples': int(total_samples),
            'total_wrong_predictions': int(total_wrong),
            'overall_error_rate': float(error_rate)
        },
        'wrong_predictions_by_class': wrong_predictions_by_class,
        'subject_error_analysis': subject_error_analysis,
        'confusion_patterns': confusion_patterns,
        'error_summary': {
            class_name: len(wrong_predictions_by_class[class_name])
            for class_name in CASME2_CLASSES
        }
    }

    return analysis_results

def calculate_comprehensive_metrics_casme2_swin(inference_results):
    """Compute overall, per-class, ROC and per-subject metrics for a test run.

    Args:
        inference_results: dict as returned by
            ``run_model_inference_casme2_swin``. Keys read here:
            'predictions', 'probabilities', 'labels', 'subjects' and
            'inference_time'.

    Returns:
        dict with sections 'evaluation_metadata', 'overall_performance'
        (accuracy plus macro precision/recall/F1/AUC restricted to classes
        that occur in the test labels), 'per_class_performance',
        'confusion_matrix' (nested lists), 'subject_level_performance',
        'roc_analysis' and 'inference_performance'.

    Raises:
        ValueError: if there are no predictions to evaluate.
    """
    print("Calculating comprehensive metrics for CASME II Swin Transformer micro-expression recognition...")

    predictions = inference_results['predictions']
    probabilities = inference_results['probabilities']
    labels = inference_results['labels']

    if len(predictions) == 0:
        raise ValueError("No predictions to evaluate!")

    # The per-class arrays below are indexed by position in CASME2_CLASSES, so
    # the label space is taken from that list rather than a literal 7.
    num_classes = len(CASME2_CLASSES)
    class_indices = list(range(num_classes))

    # Identify available classes in test set
    unique_test_labels = sorted(np.unique(labels))
    unique_predictions = sorted(np.unique(predictions))

    print(f"Test set contains labels: {[CASME2_CLASSES[i] for i in unique_test_labels]}")
    print(f"Swin model predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    # Basic metrics
    accuracy = accuracy_score(labels, predictions)

    # Macro metrics (only for classes that occur in the test labels)
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions, labels=unique_test_labels, average='macro', zero_division=0
    )

    print(f"Macro F1 (available classes): {f1:.4f}")

    # Per-class metrics (every class, including ones absent from the test set)
    precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
        labels, predictions, labels=class_indices, average=None, zero_division=0
    )

    # Confusion matrix
    cm = confusion_matrix(labels, predictions, labels=class_indices)

    # One-vs-rest ROC/AUC (only for classes with both positives and negatives)
    auc_scores = {}
    fpr_dict = {}
    tpr_dict = {}

    try:
        labels_binarized = label_binarize(labels, classes=class_indices)

        for i, class_name in enumerate(CASME2_CLASSES):
            if i in unique_test_labels and len(np.unique(labels_binarized[:, i])) > 1:
                fpr, tpr, _ = roc_curve(labels_binarized[:, i], probabilities[:, i])
                auc_score = auc(fpr, tpr)
                auc_scores[class_name] = float(auc_score)
                fpr_dict[class_name] = fpr.tolist()
                tpr_dict[class_name] = tpr.tolist()
            else:
                # Placeholder curve for classes where ROC is undefined
                auc_scores[class_name] = 0.0
                fpr_dict[class_name] = [0.0, 1.0]
                tpr_dict[class_name] = [0.0, 0.0]

        # Macro AUC for available classes
        available_auc_scores = [auc_scores[CASME2_CLASSES[i]] for i in unique_test_labels]
        macro_auc = float(np.mean(available_auc_scores)) if available_auc_scores else 0.0

    except Exception as e:
        print(f"Warning: AUC calculation failed: {e}")
        # Reset all three ROC containers together so that a failure part-way
        # through the loop cannot leave curves for only some classes.
        auc_scores = {class_name: 0.0 for class_name in CASME2_CLASSES}
        fpr_dict = {class_name: [0.0, 1.0] for class_name in CASME2_CLASSES}
        tpr_dict = {class_name: [0.0, 0.0] for class_name in CASME2_CLASSES}
        macro_auc = 0.0

    # Subject-level analysis
    subjects = inference_results['subjects']
    subject_performance = {}

    for subject in set(subjects):
        # Boolean mask selecting this subject's samples
        subject_mask = [s == subject for s in subjects]
        subject_predictions = predictions[subject_mask]
        subject_labels = labels[subject_mask]

        if len(subject_predictions) > 0:
            subject_acc = accuracy_score(subject_labels, subject_predictions)
            subject_performance[subject] = {
                'accuracy': float(subject_acc),
                'samples': int(len(subject_predictions)),
                'correct': int(np.sum(subject_predictions == subject_labels))
            }

    # Comprehensive results
    comprehensive_results = {
        'evaluation_metadata': {
            'model_type': EVALUATION_CONFIG_CASME2_SWIN['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2_SWIN['dataset_name'],
            'architecture': EVALUATION_CONFIG_CASME2_SWIN['architecture'],
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'num_classes': EVALUATION_CONFIG_CASME2_SWIN['num_classes'],
            'class_names': EVALUATION_CONFIG_CASME2_SWIN['class_names'],
            'test_samples': int(len(labels)),
            'available_classes': [CASME2_CLASSES[i] for i in unique_test_labels],
            'missing_classes': [CASME2_CLASSES[i] for i in class_indices if i not in unique_test_labels]
        },

        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': macro_auc
        },

        'per_class_performance': {},

        'confusion_matrix': cm.tolist(),

        'subject_level_performance': subject_performance,

        'roc_analysis': {
            'auc_scores': auc_scores,
            'fpr_curves': fpr_dict,
            'tpr_curves': tpr_dict
        },

        'inference_performance': {
            'total_time_seconds': float(inference_results['inference_time']),
            'average_time_ms_per_sample': float(inference_results['inference_time'] * 1000 / len(labels))
        }
    }

    # Per-class performance details
    for i, class_name in enumerate(CASME2_CLASSES):
        comprehensive_results['per_class_performance'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i]),
            'support': int(support_per_class[i]),
            'auc': auc_scores[class_name],
            'in_test_set': i in unique_test_labels
        }

    return comprehensive_results

def save_evaluation_results_casme2_swin(evaluation_results, wrong_predictions_results, results_dir):
    """Write the metrics and the wrong-prediction analysis as JSON files.

    ``results_dir`` is created when missing. Values that ``json`` cannot
    encode natively are written via ``str``.

    Returns:
        Tuple ``(results_file, wrong_predictions_file)`` with the two paths
        that were written.
    """
    os.makedirs(results_dir, exist_ok=True)

    results_file = f"{results_dir}/casme2_swint_direct_evaluation_results.json"
    wrong_predictions_file = f"{results_dir}/casme2_swint_direct_wrong_predictions.json"

    # Main metrics first, then the error analysis
    outputs = (
        (results_file, evaluation_results),
        (wrong_predictions_file, wrong_predictions_results),
    )
    for target_path, payload in outputs:
        with open(target_path, 'w') as handle:
            json.dump(payload, handle, indent=2, default=str)

    print("Swin Transformer evaluation results saved:")
    print(f"  Main results: {results_file}")
    print(f"  Wrong predictions: {wrong_predictions_file}")

    return results_file, wrong_predictions_file

# Main evaluation execution for Swin Transformer
# Runs at module level, so every name assigned below (dataset, loader, model,
# result dicts, ...) stays bound in the notebook namespace after the cell ends.
# Pipeline: build test dataset -> load checkpoint -> inference -> metrics ->
# wrong-prediction analysis -> save JSON -> print a textual report.
try:
    print("Starting CASME II Swin Transformer Direct Baseline comprehensive evaluation...")

    # Create test dataset
    print("Creating CASME II test dataset for Swin Transformer...")
    # NOTE: str.replace removes every '/test' occurrence in the configured
    # test path, not only a trailing one.
    casme2_test_dataset = CASME2DatasetEvaluationSwin(
        split_metadata=GLOBAL_CONFIG_CASME2['metadata'],
        dataset_root=GLOBAL_CONFIG_CASME2['test_path'].replace('/test', ''),
        transform=GLOBAL_CONFIG_CASME2['transform_val'],
        split='test',
        use_ram_cache=True
    )

    # Abort early: the steps below all assume at least one test sample.
    if len(casme2_test_dataset) == 0:
        raise ValueError("No test samples found! Check test data path.")

    # shuffle=False keeps the loader's iteration order fixed from run to run.
    casme2_test_loader = DataLoader(
        casme2_test_dataset,
        batch_size=CASME2_SWIN_CONFIG['batch_size'],
        shuffle=False,
        num_workers=CASME2_SWIN_CONFIG['num_workers'],
        pin_memory=True
    )

    # Load trained Swin Transformer model
    checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/{EVALUATION_CONFIG_CASME2_SWIN['checkpoint_file']}"
    casme2_swin_model, training_info = load_trained_model_casme2_swin(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

    # Run Swin Transformer inference
    inference_results = run_model_inference_casme2_swin(casme2_swin_model, casme2_test_loader, GLOBAL_CONFIG_CASME2['device'])

    # Calculate comprehensive metrics
    evaluation_results = calculate_comprehensive_metrics_casme2_swin(inference_results)

    # Analyze wrong predictions
    wrong_predictions_results = analyze_wrong_predictions_casme2_swin(inference_results)

    # Add training information
    # (attached before saving so the JSON carries the checkpoint's stats)
    evaluation_results['training_information'] = training_info

    # Save results
    results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"
    results_file, wrong_file = save_evaluation_results_casme2_swin(
        evaluation_results, wrong_predictions_results, results_dir
    )

    # Display comprehensive results
    print("\n" + "=" * 65)
    print("CASME II SWIN TRANSFORMER DIRECT BASELINE EVALUATION RESULTS")
    print("=" * 65)

    # Overall performance
    overall = evaluation_results['overall_performance']
    print(f"Overall Performance (Macro - Available Classes):")
    print(f"  Accuracy:  {overall['accuracy']:.4f}")
    print(f"  Precision: {overall['macro_precision']:.4f}")
    print(f"  Recall:    {overall['macro_recall']:.4f}")
    print(f"  F1 Score:  {overall['macro_f1']:.4f}")
    print(f"  AUC:       {overall['macro_auc']:.4f}")

    # Per-class performance
    print(f"\nPer-Class Performance:")
    for class_name, metrics in evaluation_results['per_class_performance'].items():
        in_test = "Present" if metrics['in_test_set'] else "Missing"
        print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
              f"AUC={metrics['auc']:.4f}, Support={metrics['support']}")

    # Training vs test comparison
    # The "Training Val" figures are the best validation scores stored in the
    # checkpoint; positive differences mean validation scored higher than test.
    print(f"\nTraining vs Test Performance:")
    training_f1 = training_info['best_val_f1']
    training_acc = training_info['best_val_accuracy']
    test_f1 = overall['macro_f1']
    test_acc = overall['accuracy']

    print(f"  Training Val F1:  {training_f1:.4f}")
    print(f"  Test F1:          {test_f1:.4f}")
    print(f"  F1 Difference:    {training_f1 - test_f1:+.4f}")
    print(f"  Training Val Acc: {training_acc:.4f}")
    print(f"  Test Accuracy:    {test_acc:.4f}")
    print(f"  Acc Difference:   {training_acc - test_acc:+.4f}")
    print(f"  Best Epoch:       {training_info['best_epoch']}")
    print(f"  Swin Variant:     {training_info['swin_variant']}")

    # Wrong predictions summary
    print(f"\n" + "=" * 40)
    print("WRONG PREDICTIONS ANALYSIS")
    print("=" * 40)

    wrong_meta = wrong_predictions_results['analysis_metadata']
    print(f"Total wrong predictions: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
    print(f"Overall error rate: {wrong_meta['overall_error_rate']:.2f}%")

    # Only classes with at least one error are listed, each with up to three
    # example files.
    print(f"\nErrors by True Class:")
    for class_name, error_count in wrong_predictions_results['error_summary'].items():
        if error_count > 0:
            wrong_samples = wrong_predictions_results['wrong_predictions_by_class'][class_name]
            print(f"  {class_name}: {error_count} errors")
            for sample in wrong_samples[:3]:  # Show first 3
                print(f"    - {sample['filename']} -> predicted as {sample['predicted_class']}")
            if len(wrong_samples) > 3:
                print(f"    ... and {len(wrong_samples) - 3} more")

    # Subject-level analysis
    # Sorted by accuracy, best first; only the five best subjects are printed.
    print(f"\nSubject-Level Performance:")
    subject_perfs = list(evaluation_results['subject_level_performance'].items())
    subject_perfs.sort(key=lambda x: x[1]['accuracy'], reverse=True)
    for subject, perf in subject_perfs[:5]:  # Show top 5
        print(f"  {subject}: {perf['accuracy']:.3f} ({perf['correct']}/{perf['samples']})")

    # Most common confusion patterns
    # Three most frequent "<true>_to_<predicted>" pairs.
    print(f"\nMost Common Confusion Patterns:")
    patterns = sorted(wrong_predictions_results['confusion_patterns'].items(),
                     key=lambda x: x[1], reverse=True)
    for pattern, count in patterns[:3]:
        print(f"  {pattern}: {count} cases")

    print(f"\nInference Performance:")
    print(f"  Total time: {inference_results['inference_time']:.2f}s")
    print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    print(f"\nMissing Classes: {evaluation_results['evaluation_metadata']['missing_classes']}")

    print("\n" + "=" * 65)
    print("CASME II SWIN TRANSFORMER DIRECT BASELINE EVALUATION COMPLETED")
    print("=" * 65)

# Cell-level boundary: any failure is reported with a traceback rather than
# re-raised, so the cell itself always finishes.
except Exception as e:
    print(f"Swin Transformer evaluation failed: {e}")
    import traceback
    traceback.print_exc()

finally:
    # Memory cleanup
    # Runs on success and on failure alike.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

print("Next: Cell 4 - Swin Transformer Confusion Matrix Analysis and Visualization")

CASME II Swin Transformer Direct Baseline Evaluation Framework
Model: SwinT_CASME2_Direct_Baseline
Task: micro_expression_recognition
Classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
Input size: 384x384
Architecture: swin_transformer_hierarchical
Starting CASME II Swin Transformer Direct Baseline comprehensive evaluation...
Creating CASME II test dataset for Swin Transformer...
Loading CASME II test dataset for Swin Transformer evaluation...
Loaded 28 CASME II test samples for Swin evaluation
Test set class distribution:
  others: 10 samples (35.7%)
  disgust: 7 samples (25.0%)
  happiness: 4 samples (14.3%)
  repression: 3 samples (10.7%)
  surprise: 3 samples (10.7%)
  sadness: 1 samples (3.6%)
Test set covers 16 subjects
Missing classes in test set: ['fear']
Preloading 28 test images to RAM for Swin evaluation...


Loading test images: 100%|██████████| 28/28 [00:10<00:00,  2.76it/s]


Test RAM caching completed: 28 valid images, ~0.05GB
Loading trained CASME II Swin Transformer model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/02_02_swint_direct/casme2_swint_direct_best_f1.pth
Swin checkpoint loaded using: standard
Detected Swin variant: base
Detected Swin model: microsoft/swin-base-patch4-window7-224
Swin feature dimension (final stage): 1024
Base embed_dim: 128, Stages: 4
Swin CASME II: 1024 -> GAP -> 512 -> 128 -> 7
Swin model state loaded with strict=True
Swin Transformer model loaded successfully:
  Best validation F1: 0.2619
  Best validation accuracy: 0.4231
  Best epoch: 11
  Swin variant: base
  Model classes: 7
Running CASME II Swin Transformer model inference on test set...


CASME II Swin Inference: 100%|██████████| 2/2 [00:01<00:00,  1.49it/s]

CASME II Swin inference completed: 28 samples in 1.35s
Predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
True classes in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Calculating comprehensive metrics for CASME II Swin Transformer micro-expression recognition...
Test set contains labels: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Swin model predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Macro F1 (available classes): 0.4103
Analyzing wrong predictions for CASME II Swin Transformer micro-expression recognition...
Swin Transformer evaluation results saved:
  Main results: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/02_02_swint_direct/evaluation_results/casme2_swint_direct_evaluation_results.json
  Wrong predictions: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/02_02




In [None]:
# @title Cell 4: CASME II Swin Transformer Direct Baseline Confusion Matrix Generation

# File: 02_02_SwinT_Direct_Baseline_Cell4.py
# Location: experiments/02_02_SwinT_Direct_Baseline.ipynb
# Purpose: Generate professional confusion matrix and comprehensive analysis for CASME II Swin Transformer micro-expression recognition
# Dependencies: Trained model evaluation results from Cell 3

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

# Banner for the Cell 4 output
print("CASME II Swin Transformer Direct Baseline Confusion Matrix Generation")
print("=" * 65)

# Project paths configuration - Swin Transformer specific
# These assignments rebind the module-level names set in Cell 1.
# NOTE(review): Cell 1 sets RESULTS_ROOT to ".../results/02_02_swint_casme2-af",
# while this cell uses ".../results/02_02_swint_direct" - confirm this is the
# folder Cell 3 actually wrote its evaluation_results into.
PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/02_02_swint_direct"

def find_evaluation_json_files_casme2_swin(results_path):
    """Locate the evaluation JSON files under ``<results_path>/evaluation_results``.

    Returns:
        dict that may contain 'main' (metrics file) and 'wrong_predictions'
        (error analysis file), each mapped to the path that was found. The
        dict is empty when the directory or both files are missing; a message
        is printed in either case.
    """
    eval_dir = f"{results_path}/evaluation_results"
    located = {}

    # Guard clause: nothing to search when the directory is absent
    if not os.path.exists(eval_dir):
        print(f"ERROR: Swin evaluation directory not found: {eval_dir}")
        return located

    # (result key, expected file name, message printed on success)
    expected = (
        ('main',
         "casme2_swint_direct_evaluation_results.json",
         "Found CASME II Swin evaluation file: "),
        ('wrong_predictions',
         "casme2_swint_direct_wrong_predictions.json",
         "Found Swin wrong predictions file: "),
    )
    for key, file_name, found_message in expected:
        matches = glob.glob(f"{eval_dir}/{file_name}")
        if matches:
            located[key] = matches[0]
            print(found_message + os.path.basename(matches[0]))

    if not located:
        print(f"WARNING: No Swin Transformer evaluation results found in {eval_dir}")
        print("Make sure Cell 3 (Swin evaluation) has been executed first!")

    return located

def load_evaluation_results_casme2_swin(json_path):
    """Read an evaluation results JSON file.

    Returns:
        The parsed JSON content, or ``None`` when the file cannot be opened
        or parsed (the error is printed, not raised).
    """
    try:
        with open(json_path, 'r') as handle:
            parsed = json.load(handle)
        file_label = os.path.basename(json_path)
        print(f"Successfully loaded Swin evaluation results from: {file_label}")
    except Exception as exc:
        print(f"ERROR loading {json_path}: {exc}")
        return None
    return parsed

def calculate_weighted_f1_casme2_swin(per_class_performance):
    """Return the support-weighted mean of the per-class F1 scores.

    Classes whose 'support' is not positive are ignored. Returns 0.0 when no
    class has any support.
    """
    # Keep only the classes that actually have test samples
    supported = [entry for entry in per_class_performance.values()
                 if entry['support'] > 0]

    total_support = sum(entry['support'] for entry in supported)
    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0
    for entry in supported:
        share = entry['support'] / total_support
        weighted_f1 += entry['f1_score'] * share

    return weighted_f1

def calculate_balanced_accuracy_casme2_swin(confusion_matrix):
    """
    Average the one-vs-rest (sensitivity + specificity) / 2 score over classes.

    Rows are treated as true classes and columns as predictions. Classes whose
    row sums to zero (no test samples) are left out of the average; 0.0 is
    returned when no class qualifies.
    """
    matrix = np.array(confusion_matrix)
    grand_total = matrix.sum()

    class_scores = []
    for idx in range(matrix.shape[0]):
        row_total = matrix[idx, :].sum()
        if not (row_total > 0):
            # No test samples for this class
            continue

        hits = matrix[idx, idx]                        # true positives
        misses = row_total - hits                      # false negatives
        false_alarms = matrix[:, idx].sum() - hits     # false positives
        rejections = grand_total - hits - misses - false_alarms  # true negatives

        positives = hits + misses
        negatives = rejections + false_alarms
        sensitivity = hits / positives if positives > 0 else 0.0
        specificity = rejections / negatives if negatives > 0 else 0.0

        class_scores.append((sensitivity + specificity) / 2)

    if not class_scores:
        return 0.0
    return np.mean(class_scores)

def determine_text_color_casme2_swin(color_value, threshold=0.5):
    """Pick the annotation colour: 'white' above ``threshold``, else 'black'."""
    if color_value > threshold:
        return 'white'
    return 'black'

def analyze_missing_classes_casme2_swin(data):
    """Print and return which classes are present in / absent from the test set.

    Reads 'evaluation_metadata' (required) and 'training_information'
    (optional) from ``data``.

    Returns:
        dict with 'available', 'missing', 'total_classes' and 'swin_variant'.
    """
    metadata = data['evaluation_metadata']
    present = metadata.get('available_classes', [])
    absent = metadata.get('missing_classes', [])
    variant = data.get('training_information', {}).get('swin_variant', 'unknown')

    print(
        "Swin Transformer Class Analysis:",
        f"  Variant: {variant}",
        f"  Available in test: {present}",
        f"  Missing from test: {absent}",
        sep="\n",
    )

    summary = {'available': present, 'missing': absent}
    summary['total_classes'] = len(metadata['class_names'])
    summary['swin_variant'] = variant
    return summary

def create_confusion_matrix_plot_casme2_swin(data, output_path):
    """Render the evaluation confusion matrix as an annotated heatmap image.

    Each cell shows the raw count and its share of the true-class row. Rows
    without any test samples are drawn with zero intensity and labelled
    "(N/A)".

    Args:
        data: Parsed evaluation results. Keys read here:
            'evaluation_metadata' ('class_names', optional 'missing_classes'),
            'confusion_matrix', 'overall_performance',
            'per_class_performance' and optional 'training_information'.
        output_path: File path the figure is saved to.

    Returns:
        dict with 'accuracy', 'macro_f1', 'weighted_f1', 'balanced_accuracy',
        'missing_classes' and 'swin_variant'.
    """

    # Extract data
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']
    training_info = data.get('training_information', {})
    swin_variant = training_info.get('swin_variant', 'unknown')

    print(f"Processing Swin Transformer confusion matrix for CASME II classes: {class_names}")
    print(f"Swin variant: {swin_variant}")
    print(f"Confusion matrix shape: {cm.shape}")

    # Calculate comprehensive metrics
    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1_casme2_swin(per_class)
    balanced_acc = calculate_balanced_accuracy_casme2_swin(cm)

    print(f"Calculated metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Balanced Acc: {balanced_acc:.4f}, Accuracy: {accuracy:.4f}")

    # Row-wise normalization for percentage display.
    # `where=` skips rows whose sum is zero; without an explicit `out` those
    # entries would be left uninitialised, so a zero-filled buffer is supplied
    # to make empty rows render as exactly 0.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_pct = np.divide(
        cm,
        row_sums,
        out=np.zeros(cm.shape, dtype=float),
        where=(row_sums != 0),
    )

    # Create visualization with appropriate size for 7 classes
    fig, ax = plt.subplots(figsize=(12, 10))

    # Color scheme optimized for Swin Transformer micro-expression research
    cmap = 'Blues'

    # Create heatmap; the colour scale saturates at 80% of a row
    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    # Annotate cells with count and percentage
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            # Handle percentage calculation for classes with 0 samples
            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"  # Class has no test samples

            # Determine text color based on cell intensity
            cell_value = cm_pct[i, j]
            text_color = determine_text_color_casme2_swin(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    # Configure axes with micro-expression class names
    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    # Add note about missing classes and Swin architecture. The architecture
    # line is always present, so the note box is always drawn.
    missing_classes = meta.get('missing_classes', [])
    note_lines = []
    if missing_classes:
        note_lines.append(f"Missing classes: {', '.join(missing_classes)}")
    note_lines.append(f"Swin-{swin_variant.capitalize()} | Hierarchical Processing")

    note_text = "\n".join(note_lines)
    ax.text(0.02, 0.98, note_text, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    # Create comprehensive title for Swin Transformer micro-expression research
    title = f"CASME II Swin Transformer ({swin_variant.capitalize()}) Micro-Expression Recognition\n"
    title += f"Acc: {accuracy:.4f}  |  Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Swin Transformer confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes,
        'swin_variant': swin_variant
    }

def create_per_class_performance_chart_casme2_swin(data, output_path):
    """Create per-class performance visualization for CASME II Swin Transformer.

    Renders a two-panel figure and saves it to ``output_path``:
    the top panel shows grouped F1 / precision / recall bars per class, the
    bottom panel shows the per-class support (test sample count).

    Args:
        data: Evaluation results dict. Reads ``per_class_performance`` (one
            entry per class with ``f1_score``, ``precision``, ``recall``,
            ``support`` and ``in_test_set``), ``evaluation_metadata``
            (``class_names``) and, optionally, ``training_information``
            (``swin_variant``, defaulting to ``'unknown'``).
        output_path: Destination path handed to ``plt.savefig``.

    Raises:
        KeyError: If a required section, a class entry, or one of the
            per-class fields is missing from ``data``.
    """
    per_class = data['per_class_performance']
    class_names = data['evaluation_metadata']['class_names']
    training_info = data.get('training_information', {})
    swin_variant = training_info.get('swin_variant', 'unknown')

    # Collect every plotted series in class_names order so bars line up
    classes = list(class_names)
    class_rows = [per_class[name] for name in classes]
    f1_scores = [row['f1_score'] for row in class_rows]
    precisions = [row['precision'] for row in class_rows]
    recalls = [row['recall'] for row in class_rows]
    supports = [row['support'] for row in class_rows]
    in_test_flags = [row['in_test_set'] for row in class_rows]

    # Create grouped bar chart
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    try:
        x = np.arange(len(classes))
        width = 0.25

        # Top plot: F1, Precision, Recall
        bars1 = ax1.bar(x - width, f1_scores, width, label='F1 Score', alpha=0.8, color='steelblue')
        bars2 = ax1.bar(x, precisions, width, label='Precision', alpha=0.8, color='orange')
        bars3 = ax1.bar(x + width, recalls, width, label='Recall', alpha=0.8, color='green')

        ax1.set_xlabel('Emotion Classes', fontweight='bold')
        ax1.set_ylabel('Score', fontweight='bold')
        ax1.set_title(f'CASME II Swin Transformer ({swin_variant.capitalize()}) Per-Class Performance Metrics',
                      fontweight='bold', pad=20)
        ax1.set_xticks(x)
        ax1.set_xticklabels(classes, rotation=45, ha='right')
        ax1.legend()
        ax1.grid(axis='y', alpha=0.3)
        ax1.set_ylim(0, 1.0)

        # Add value labels on bars
        for bars in (bars1, bars2, bars3):
            for bar, in_test in zip(bars, in_test_flags):
                height = bar.get_height()
                anchor = (bar.get_x() + bar.get_width() / 2, height)
                if in_test:
                    ax1.annotate(f'{height:.3f}', xy=anchor,
                                 xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)
                else:
                    # Mark missing classes: their scores are not meaningful
                    ax1.annotate('N/A', xy=anchor,
                                 xytext=(0, 3), textcoords="offset points", ha='center', va='bottom',
                                 fontsize=8, color='red', fontweight='bold')

        # Bottom plot: Support (sample count)
        bars4 = ax2.bar(x, supports, color='purple', alpha=0.7)
        ax2.set_xlabel('Emotion Classes', fontweight='bold')
        ax2.set_ylabel('Number of Test Samples', fontweight='bold')
        ax2.set_title('CASME II Test Set Class Distribution (Swin Transformer)', fontweight='bold')
        ax2.set_xticks(x)
        ax2.set_xticklabels(classes, rotation=45, ha='right')
        ax2.grid(axis='y', alpha=0.3)

        # Add value labels on support bars
        for bar, support in zip(bars4, supports):
            height = bar.get_height()
            ax2.annotate(f'{int(support)}', xy=(bar.get_x() + bar.get_width() / 2, height),
                         xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=10)

        # Add architecture note
        fig.text(0.02, 0.02, f"Architecture: Swin-{swin_variant.capitalize()} Hierarchical Processing",
                 fontsize=9, style='italic', alpha=0.7)

        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Close the figure even when plotting or saving fails; otherwise it
        # stays registered with pyplot after the caller handles the error.
        plt.close(fig)

    print(f"Swin Transformer per-class performance chart saved to: {os.path.basename(output_path)}")

def generate_performance_summary_casme2_swin(evaluation_data, wrong_predictions_data=None):
    """Print a comprehensive performance summary for CASME II Swin Transformer.

    The report is written to stdout; nothing is returned.

    Args:
        evaluation_data: Evaluation results dict. The sections
            ``overall_performance``, ``evaluation_metadata``,
            ``per_class_performance`` and ``inference_performance`` are
            required; ``training_information`` is optional.
        wrong_predictions_data: Optional error-analysis dict with an
            ``analysis_metadata`` section and, optionally,
            ``confusion_patterns`` (pattern -> count). Skipped when falsy.

    Raises:
        KeyError: If a required section, a descriptive metadata field, a
            per-class ``in_test_set`` flag, or an inference/error-analysis
            field is missing.

    Numeric metrics (overall, per-class and training values) that are absent
    or ``None`` are printed as ``N/A`` instead of aborting the report.
    """

    def fmt(value, spec, width=0):
        """Format a metric with ``spec``; left-pad 'N/A' to ``width`` if None."""
        if value is None:
            return f"{'N/A':<{width}}"
        return format(value, spec)

    print("\n" + "=" * 65)
    print("CASME II SWIN TRANSFORMER MICRO-EXPRESSION RECOGNITION PERFORMANCE SUMMARY")
    print("=" * 65)

    # Overall performance
    overall = evaluation_data['overall_performance']
    meta = evaluation_data['evaluation_metadata']
    training_info = evaluation_data.get('training_information') or {}
    # Resolve the variant once so every section reports the same value
    swin_variant = str(training_info.get('swin_variant') or 'unknown')

    print(f"Dataset: {meta['dataset']}")
    print(f"Test samples: {meta['test_samples']}")
    print(f"Model: {meta['model_type']}")
    print(f"Architecture: Swin-{swin_variant.capitalize()} Hierarchical Transformer")
    print(f"Evaluation date: {meta['evaluation_timestamp']}")

    print("\nOverall Performance:")
    print(f"  Accuracy:         {fmt(overall.get('accuracy'), '.4f')}")
    print(f"  Macro Precision:  {fmt(overall.get('macro_precision'), '.4f')}")
    print(f"  Macro Recall:     {fmt(overall.get('macro_recall'), '.4f')}")
    print(f"  Macro F1:         {fmt(overall.get('macro_f1'), '.4f')}")
    print(f"  Macro AUC:        {fmt(overall.get('macro_auc'), '.4f')}")

    # Per-class performance
    print("\nPer-Class Performance:")
    per_class = evaluation_data['per_class_performance']

    print(f"{'Class':<12} {'F1':<8} {'Precision':<10} {'Recall':<8} {'AUC':<8} {'Support':<8} {'In Test'}")
    print("-" * 65)

    for class_name, metrics in per_class.items():
        in_test = "Yes" if metrics['in_test_set'] else "No"
        print(
            f"{class_name:<12} {fmt(metrics.get('f1_score'), '<8.4f', 8)} "
            f"{fmt(metrics.get('precision'), '<10.4f', 10)} "
            f"{fmt(metrics.get('recall'), '<8.4f', 8)} {fmt(metrics.get('auc'), '<8.4f', 8)} "
            f"{fmt(metrics.get('support'), '<8', 8)} {in_test}"
        )

    # Training vs test performance
    if training_info:
        best_val_f1 = training_info.get('best_val_f1')
        test_f1 = overall.get('macro_f1')
        # The gap is only defined when both sides were recorded
        gap = None if best_val_f1 is None or test_f1 is None else best_val_f1 - test_f1
        print("\nTraining vs Test Comparison:")
        print(f"  Training Val F1:  {fmt(best_val_f1, '.4f')}")
        print(f"  Test F1:          {fmt(test_f1, '.4f')}")
        print(f"  Performance Gap:  {fmt(gap, '+.4f')}")
        print(f"  Best Epoch:       {training_info.get('best_epoch', 'N/A')}")
        print(f"  Swin Variant:     {swin_variant}")

    # Class imbalance analysis
    missing_classes = meta.get('missing_classes') or []
    available_classes = meta.get('available_classes') or []
    # Derive the denominator from the recorded class list; fall back to the
    # previously hard-coded 7 when the metadata does not list class names.
    total_classes = len(meta.get('class_names') or []) or 7

    print("\nClass Availability Analysis:")
    print(f"  Available classes: {len(available_classes)}/{total_classes}")
    print(f"  Missing classes: {missing_classes if missing_classes else 'None'}")

    # Architecture-specific information
    print("\nSwin Transformer Architecture Details:")
    print(f"  Variant: {swin_variant}")
    print("  Hierarchical processing: Multi-stage feature extraction")
    print("  Window-based attention: Shifted window mechanism")
    print("  Global pooling: Adaptive average pooling")

    # Wrong predictions summary if available
    if wrong_predictions_data:
        wrong_meta = wrong_predictions_data['analysis_metadata']
        print("\nError Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']}/{wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        # Top confusion patterns (three most frequent)
        patterns = wrong_predictions_data.get('confusion_patterns', {})
        if patterns:
            print("\nTop Confusion Patterns:")
            sorted_patterns = sorted(patterns.items(), key=lambda x: x[1], reverse=True)[:3]
            for pattern, count in sorted_patterns:
                print(f"  {pattern}: {count} cases")

    print("\nInference Performance:")
    inference = evaluation_data['inference_performance']
    print(f"  Total time: {inference['total_time_seconds']:.2f}s")
    print(f"  Speed: {inference['average_time_ms_per_sample']:.1f} ms/sample")

# ---------------------------------------------------------------------------
# Cell 4 driver: locate the evaluation JSON files, render the confusion matrix
# and per-class charts, then print the textual summaries.
# NOTE(review): `Path` and `datetime` are used below but their imports are not
# visible in this part of the file - presumably imported earlier in the cell.
# ---------------------------------------------------------------------------

# Find Swin Transformer evaluation JSON files
# json_files is indexed by 'main' / 'wrong_predictions' below, so it is
# presumably a dict mapping those roles to file paths - confirm in the helper.
json_files = find_evaluation_json_files_casme2_swin(RESULTS_ROOT)

# This check only reports; execution continues either way and the
# `'main' in json_files` test further down is what actually gates processing.
if not json_files:
    print(f"ERROR: No Swin Transformer evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (Swin evaluation) has been executed first!")
else:
    print(f"Found {len(json_files)} Swin evaluation file(s)")

# Create output directory - Fixed path as requested
output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Process Swin Transformer evaluation results
# results_summary: metrics keyed by experiment name, filled on success.
# generated_files: paths of figures written so far; drives the final report.
results_summary = {}
generated_files = []

if 'main' in json_files:
    # Load main evaluation data
    eval_data = load_evaluation_results_casme2_swin(json_files['main'])

    # Load wrong predictions data if available
    wrong_data = None
    if 'wrong_predictions' in json_files:
        wrong_data = load_evaluation_results_casme2_swin(json_files['wrong_predictions'])

    # The loader result is checked against None before use
    if eval_data is not None:
        try:
            # Analyze missing classes for Swin Transformer
            class_analysis = analyze_missing_classes_casme2_swin(eval_data)

            # Generate Swin Transformer confusion matrix
            cm_output_path = os.path.join(output_dir, "confusion_matrix_CASME2_SwinT_Direct.png")
            metrics = create_confusion_matrix_plot_casme2_swin(eval_data, cm_output_path)
            # Recorded right after each plot returns, so a later failure still
            # leaves the already-written figure listed in the final report.
            generated_files.append(cm_output_path)

            # Generate Swin Transformer per-class performance chart
            perf_output_path = os.path.join(output_dir, "per_class_performance_CASME2_SwinT_Direct.png")
            create_per_class_performance_chart_casme2_swin(eval_data, perf_output_path)
            generated_files.append(perf_output_path)

            # Both names refer to the same dict, so adding 'class_analysis'
            # here also mutates `metrics`.
            results_summary['casme2_swin'] = metrics
            results_summary['casme2_swin']['class_analysis'] = class_analysis

            print(f"SUCCESS: Swin Transformer visualization files generated successfully")

        except Exception as e:
            # Best-effort: report the plotting failure and keep going so the
            # textual summary below is still printed.
            print(f"ERROR: Failed to generate Swin visualizations: {str(e)}")
            import traceback
            traceback.print_exc()

        # Generate comprehensive Swin Transformer summary
        # Runs outside the try block: an exception raised here propagates and
        # stops the cell before the final summary.
        generate_performance_summary_casme2_swin(eval_data, wrong_data)

    else:
        print("ERROR: Could not load Swin Transformer evaluation data")
else:
    print("ERROR: No main Swin Transformer evaluation results found")

# Final summary
# Printed whenever at least one figure was written, even if a later step failed.
if generated_files:
    print(f"\n" + "=" * 65)
    print("CASME II SWIN TRANSFORMER CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 65)

    print(f"Generated visualization files:")
    for file_path in generated_files:
        filename = os.path.basename(file_path)
        print(f"  {filename}")

    # Only present when both plots and the class analysis succeeded
    if 'casme2_swin' in results_summary:
        casme2_swin = results_summary['casme2_swin']
        print(f"\nFinal Performance Summary:")
        print(f"  Architecture:     Swin-{casme2_swin['swin_variant'].capitalize()}")
        print(f"  Accuracy:         {casme2_swin['accuracy']:.4f}")
        print(f"  Macro F1:         {casme2_swin['macro_f1']:.4f}")
        print(f"  Weighted F1:      {casme2_swin['weighted_f1']:.4f}")
        print(f"  Balanced Acc:     {casme2_swin['balanced_accuracy']:.4f}")

        if 'class_analysis' in casme2_swin:
            analysis = casme2_swin['class_analysis']
            print(f"  Available classes: {len(analysis['available'])}/{analysis['total_classes']}")
            print(f"  Missing classes: {len(analysis['missing'])}")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print(f"\nERROR: No Swin Transformer visualizations were generated")
    print("Please check:")
    print("1. Cell 3 Swin evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

# These closing lines are printed unconditionally, including after errors above
print("\nCell 4 completed - CASME II Swin Transformer confusion matrix analysis generated")
print("Ready for ViT vs Swin Transformer comparative analysis")

CASME II Swin Transformer Direct Baseline Confusion Matrix Generation
Found CASME II Swin evaluation file: casme2_swint_direct_evaluation_results.json
Found Swin wrong predictions file: casme2_swint_direct_wrong_predictions.json
Found 2 Swin evaluation file(s)
Successfully loaded Swin evaluation results from: casme2_swint_direct_evaluation_results.json
Successfully loaded Swin evaluation results from: casme2_swint_direct_wrong_predictions.json
Swin Transformer Class Analysis:
  Variant: base
  Available in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
  Missing from test: ['fear']
Processing Swin Transformer confusion matrix for CASME II classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
Swin variant: base
Confusion matrix shape: (7, 7)
Calculated metrics - Macro F1: 0.4103, Weighted F1: 0.4930, Balanced Acc: 0.6461, Accuracy: 0.5000
Swin Transformer confusion matrix saved to: confusion_matrix_CASME2_SwinT_Direct.png