In [1]:
# @title Cell 1: CASME II MobileNetV3-Small M1 MFS Infrastructure Configuration

# File: 08_01_MobileNet_CASME2_M1_MFS_Cell1.py
# Location: experiments/08_01_MobileNet_CASME2_M1_MFS.ipynb
# Purpose: MobileNetV3-Small for CASME II micro-expression recognition with M1 MFS methodology

from google.colab import drive
print("=" * 60)
print("CASME II CNN BASELINE - MobileNetV3-Small M1 MFS")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import timm
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/data_split_v3"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/08_01_mobilenet_casme2_mfs"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/08_01_mobilenet_casme2_mfs"

METADATA_TRAIN = f"{DATASET_ROOT}/split_metadata_v3.json"
PROCESSING_SUMMARY = f"{DATASET_ROOT}/processing_summary_v3.json"

print("\nCASME II MobileNetV3-Small M1 MFS Baseline - Infrastructure Configuration")
print("=" * 60)

if not os.path.exists(METADATA_TRAIN):
    raise FileNotFoundError(f"Phase 3 metadata not found: {METADATA_TRAIN}")
if not os.path.exists(PROCESSING_SUMMARY):
    raise FileNotFoundError(f"Phase 3 processing summary not found: {PROCESSING_SUMMARY}")

print("Loading CASME II Phase 3 dataset metadata...")
with open(METADATA_TRAIN, 'r') as f:
    casme2_metadata = json.load(f)

with open(PROCESSING_SUMMARY, 'r') as f:
    processing_info = json.load(f)

print(f"Dataset: {processing_info['dataset']}")
print(f"Phase: {processing_info['phase']}")
print(f"Total images: {processing_info['total_images_copied']}")
print(f"Extraction strategy: {processing_info.get('extraction_strategy', {})}")

USE_FOCAL_LOSS = True
FOCAL_LOSS_GAMMA = 2.5

CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.25, 1.76, 1.91, 1.99, 3.76, 7.04]
FOCAL_LOSS_ALPHA_WEIGHTS = [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]

MOBILENET_MODEL_NAME = 'mobilenetv3_small_100'

print("\n" + "=" * 50)
print("EXPERIMENT CONFIGURATION - CNN M1 MFS")
print("=" * 50)
print(f"Model: MobileNetV3-Small (TIMM)")
print(f"Methodology: M1 (Raw Images)")
print(f"Input Resolution: 640x480 RGB")
print(f"Preprocessing: None (raw CASME II resolution)")
print(f"Loss Function: Focal Loss")
print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
print(f"  Alpha Weights: {FOCAL_LOSS_ALPHA_WEIGHTS}")
print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
print("=" * 50)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

if 'A100' in gpu_name:
    BATCH_SIZE = 24
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("A100: Optimized batch size for 640x480 RGB input")
elif 'L4' in gpu_name:
    BATCH_SIZE = 16
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("L4: Balanced performance configuration")
else:
    BATCH_SIZE = 8
    NUM_WORKERS = 8
    print("Default GPU: Conservative settings for large input size")

RAM_PRELOAD_WORKERS = 32
print(f"RAM preload workers: {RAM_PRELOAD_WORKERS}")

CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

print("\nLoading class distribution...")
try:
    if 'splits' in casme2_metadata:
        train_dist = casme2_metadata['splits']['train']['class_distribution']
        val_dist = casme2_metadata['splits']['val']['class_distribution']
        test_dist = casme2_metadata['splits']['test']['class_distribution']
        print("Using class distribution from split_metadata (v3 format)")
    elif 'train' in casme2_metadata and 'class_distribution' in casme2_metadata['train']:
        train_dist = casme2_metadata['train']['class_distribution']
        val_dist = casme2_metadata['val']['class_distribution']
        test_dist = casme2_metadata['test']['class_distribution']
        print("Using class distribution from split_metadata (v1 format)")
    else:
        train_dist = processing_info['class_preservation']['train']
        val_dist = processing_info['class_preservation']['val']
        test_dist = processing_info['class_preservation']['test']
        print("Using class distribution from processing_summary (v2 format)")
except KeyError as e:
    raise KeyError(f"Could not load class distribution from metadata. Missing key: {e}")

print(f"\nTrain distribution: {train_dist}")
print(f"Validation distribution: {val_dist}")
print(f"Test distribution: {test_dist}")

if USE_FOCAL_LOSS:
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

CASME2_MOBILENET_CONFIG = {
    'model_name': MOBILENET_MODEL_NAME,
    'input_size': (640, 480),
    'num_classes': 7,
    'dropout_rate': 0.3,

    'learning_rate': 5e-5,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'device': device,

    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 5,
    'scheduler_min_lr': 1e-6,
    'scheduler_monitor': 'val_f1_macro',

    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'class_weights': class_weights,

    'use_macro_avg': True,
    'early_stopping': False,
    'save_best_f1': True,
    'save_strategy': 'best_only',

    'dataset_phase': 'v3',
    'methodology': 'M1',
    'preprocessing': 'raw_images',
    'frame_strategy': 'multi_frame_sampling',
    'train_augmentation': 'temporal_windows',
    'frame_types': ['onset', 'apex', 'offset'],
    'extraction_strategy': processing_info.get('extraction_strategy', {}),
    'copy_statistics': processing_info.get('copy_statistics', {})
}

print(f"\nMobileNetV3 Configuration Summary:")
print(f"  Model: {CASME2_MOBILENET_CONFIG['model_name']}")
print(f"  Input size: {CASME2_MOBILENET_CONFIG['input_size'][0]}x{CASME2_MOBILENET_CONFIG['input_size'][1]} RGB")
print(f"  Methodology: {CASME2_MOBILENET_CONFIG['methodology']} (raw images)")
print(f"  Learning rate: {CASME2_MOBILENET_CONFIG['learning_rate']}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Dataset phase: {CASME2_MOBILENET_CONFIG['dataset_phase']}")
print(f"  Frame strategy: {CASME2_MOBILENET_CONFIG['frame_strategy']}")

class OptimizedFocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()

        if alpha is not None:
            if isinstance(alpha, list):
                self.alpha = torch.tensor(alpha, dtype=torch.float32)
            else:
                self.alpha = alpha

            alpha_sum = self.alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")
        else:
            self.alpha = None

        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class MobileNetCASME2Baseline(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.2):
        super(MobileNetCASME2Baseline, self).__init__()

        self.mobilenet = timm.create_model(
            MOBILENET_MODEL_NAME,
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        for param in self.mobilenet.parameters():
            param.requires_grad = True

        with torch.no_grad():
            test_input = torch.randn(1, 3, 480, 640)
            test_output = self.mobilenet(test_input)
            self.mobilenet_feature_dim = test_output.shape[1]

        print(f"MobileNetV3-Small feature dimension: {self.mobilenet_feature_dim}")

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.mobilenet_feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        )

        self.classifier = nn.Linear(128, num_classes)

        print(f"MobileNet CASME II: {self.mobilenet_feature_dim} -> 512 -> 128 -> {num_classes}")

    def forward(self, x):
        features = self.mobilenet(x)
        processed_features = self.classifier_layers(features)
        output = self.classifier(processed_features)
        return output

def create_optimizer_scheduler_casme2(model, config):
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )

    if config['scheduler_type'] == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=config['scheduler_mode'],
            factor=config['scheduler_factor'],
            patience=config['scheduler_patience'],
            min_lr=config['scheduler_min_lr']
        )
        print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    else:
        scheduler = None

    return optimizer, scheduler

print("\nSetting up transforms for M1 methodology (raw 640x480 RGB)...")

mobilenet_transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

mobilenet_transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("M1 transforms configured: raw 640x480 RGB with ImageNet normalization")

class CASME2Dataset(Dataset):
    def __init__(self, split_metadata, dataset_root, transform=None, split='train'):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split

        print(f"Loaded {len(self.metadata)} samples for {split} split")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        sample = self.metadata[idx]

        image_path = os.path.join(self.dataset_root, self.split, sample['image_filename'])
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        emotion = sample['emotion']
        label = CLASS_TO_IDX[emotion]

        return image, label, sample['sample_id']

os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

TRAIN_PATH = f"{DATASET_ROOT}/train"
VAL_PATH = f"{DATASET_ROOT}/val"
TEST_PATH = f"{DATASET_ROOT}/test"

print(f"\nDataset paths:")
print(f"Train: {TRAIN_PATH}")
print(f"Validation: {VAL_PATH}")
print(f"Test: {TEST_PATH}")

print("\nMobileNetV3 CASME II architecture validation...")

try:
    test_model = MobileNetCASME2Baseline(num_classes=7, dropout_rate=0.2).to(device)
    test_input = torch.randn(1, 3, 480, 640).to(device)
    test_output = test_model(test_input)

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"Expected output shape: [1, 7] for CASME II 7 classes")

    del test_model, test_input, test_output
    torch.cuda.empty_cache()

except Exception as e:
    print(f"Validation failed: {e}")

def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        if alpha_weights:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {sum(alpha_weights):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)
    else:
        print(f"Using CrossEntropy Loss with optimized class weights")
        print(f"Class weights: {weights.cpu().numpy()}")
        return nn.CrossEntropyLoss(weight=weights)

GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    'transform_train': mobilenet_transform_train,
    'transform_val': mobilenet_transform_val,
    'mobilenet_config': CASME2_MOBILENET_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'train_path': TRAIN_PATH,
    'val_path': VAL_PATH,
    'test_path': TEST_PATH,
    'metadata': casme2_metadata,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

print("\n" + "=" * 60)
print("CASME II MOBILENETV3-SMALL M1 MFS BASELINE CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: MobileNetV3-Small")
print(f"  Parameters: ~2.5M")
print(f"  Input Resolution: 640x480 RGB (raw)")
print(f"  Methodology: M1 (No preprocessing)")

print(f"\nDataset Configuration:")
print(f"  Phase: {CASME2_MOBILENET_CONFIG['dataset_phase']}")
print(f"  Frame strategy: {CASME2_MOBILENET_CONFIG['frame_strategy']}")
print(f"  Train augmentation: {CASME2_MOBILENET_CONFIG['train_augmentation']}")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Train samples: {processing_info.get('copy_statistics', {}).get('train', {}).get('total_images', 2061)}")

print("\nNext: Cell 2 - Dataset Loading and Training Pipeline")

CASME II CNN BASELINE - MobileNetV3-Small M1 MFS

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...

CASME II MobileNetV3-Small M1 MFS Baseline - Infrastructure Configuration
Loading CASME II Phase 3 dataset metadata...
Dataset: CASME2_MultiFrameSampling
Phase: Phase 3
Total images: 2774
Extraction strategy: {'train': 'multi_frame_windows_with_fallback', 'val': 'key_frames_only', 'test': 'key_frames_only', 'fallback_method': 'nearest_frame_duplication'}

EXPERIMENT CONFIGURATION - CNN M1 MFS
Model: MobileNetV3-Small (TIMM)
Methodology: M1 (Raw Images)
Input Resolution: 640x480 RGB
Preprocessing: None (raw CASME II resolution)
Loss Function: Focal Loss
  Gamma: 2.5
  Alpha Weights: [0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285]
  Alpha Sum: 1.000

Device: cuda
GPU: NVIDIA L4 (23.8 GB)
L4: Balanced performance con

model.safetensors:   0%|          | 0.00/10.2M [00:00<?, ?B/s]

MobileNetV3-Small feature dimension: 1024
MobileNet CASME II: 1024 -> 512 -> 128 -> 7
Validation failed: Expected more than 1 value per channel when training, got input size torch.Size([1, 512])

CASME II MOBILENETV3-SMALL M1 MFS BASELINE CONFIGURATION COMPLETE
Loss Configuration:
  Function: Optimized Focal Loss
  Gamma: 2.5
  Per-class Alpha: [0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285]
  Alpha Sum: 1.000

Model Configuration:
  Architecture: MobileNetV3-Small
  Parameters: ~2.5M
  Input Resolution: 640x480 RGB (raw)
  Methodology: M1 (No preprocessing)

Dataset Configuration:
  Phase: v3
  Frame strategy: multi_frame_sampling
  Train augmentation: temporal_windows
  Classes: 7
  Train samples: 2613

Next: Cell 2 - Dataset Loading and Training Pipeline


In [2]:
# @title Cell 2: CASME II MobileNetV3-Small Training Pipeline

# File: 08_01_MobileNet_CASME2_MFS_Cell2.py
# Location: experiments/08_01_MobileNet_CASME2-MFS.ipynb
# Purpose: Enhanced training pipeline for CASME II MobileNetV3-Small with temporal window augmentation

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import shutil
import tempfile

print("CASME II MobileNetV3-Small Training Pipeline")
print("=" * 70)
print(f"Model: MobileNetV3-Small")
print(f"Methodology: M1 (Raw 640x480 RGB)")
print(f"Loss Function: Focal Loss")
print(f"  Gamma: {CASME2_MOBILENET_CONFIG['focal_loss_gamma']}")
print(f"  Per-class Alpha: {CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights']}")
print(f"  Alpha Sum: {sum(CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights']):.3f}")
print(f"Dataset Phase: {CASME2_MOBILENET_CONFIG['dataset_phase']}")
print(f"Frame Strategy: {CASME2_MOBILENET_CONFIG['frame_strategy']}")
print(f"Training epochs: {CASME2_MOBILENET_CONFIG['num_epochs']}")
print(f"Scheduler patience: {CASME2_MOBILENET_CONFIG['scheduler_patience']}")

def normalize_metadata_structure(metadata):
    if 'splits' in metadata:
        print("Detected v2/v3 metadata format (with 'splits' key)")
        return metadata['splits']
    elif 'train' in metadata:
        print("Detected v1 metadata format (direct split keys)")
        return metadata
    else:
        raise ValueError("Unknown metadata format: missing both 'splits' and 'train' keys")

class CASME2DatasetTraining(Dataset):
    def __init__(self, split_metadata, dataset_root, transform=None, split='train', use_ram_cache=True):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for training...")

        for sample in self.metadata:
            image_path = os.path.join(dataset_root, split, sample['image_filename'])
            self.images.append(image_path)
            self.labels.append(CLASS_TO_IDX[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        self._print_distribution()

        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def _preload_to_ram(self):
        print(f"Preloading {len(self.images)} {self.split} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (640, 480):
                    image = image.resize((640, 480), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('RGB', (640, 480), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc=f"Loading {self.split} to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * 640 * 480 * 3 * 4 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (640, 480):
                    image = image.resize((640, 480), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (640, 480), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.sample_ids[idx]

def calculate_metrics_safe_robust(outputs, labels, class_names, average='macro'):
    try:
        if outputs.size(0) != labels.size(0):
            raise ValueError(f"Batch size mismatch: outputs {outputs.size(0)} vs labels {labels.size(0)}")

        if isinstance(outputs, torch.Tensor):
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        else:
            predictions = np.array(outputs)

        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()
        else:
            labels = np.array(labels)

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0,
            labels=list(range(len(class_names)))
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }
    except Exception as e:
        print(f"Warning: Enhanced metrics calculation error: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

def train_epoch(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    model.train()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc=f"CASME II Training Epoch {epoch+1}/{total_epochs}")

    for batch_idx, (images, labels, sample_ids) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        model_output = model(images)

        if isinstance(model_output, (tuple, list)):
            outputs = model_output[0]
        elif isinstance(model_output, dict):
            outputs = model_output.get('logits', model_output.get('prediction', model_output))
        else:
            outputs = model_output

        if outputs.dim() != 2 or outputs.size(1) != 7:
            raise ValueError(f"Invalid CASME II output shape: {outputs.shape}, expected [batch_size, 7]")

        loss = criterion(outputs, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), CASME2_MOBILENET_CONFIG['gradient_clip'])

        optimizer.step()
        running_loss += loss.item()

        all_outputs.append(outputs.detach().cpu())
        all_labels.append(labels.detach().cpu())

        if batch_idx % 5 == 0:
            avg_loss = running_loss / (batch_idx + 1)
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'Loss': f'{avg_loss:.4f}',
                'LR': f'{current_lr:.2e}'
            })

    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Training metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics

def validate_epoch(model, dataloader, criterion, device, epoch, total_epochs):
    model.eval()
    running_loss = 0.0
    all_outputs = []
    all_labels = []
    all_sample_ids = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc=f"CASME II Validation Epoch {epoch+1}/{total_epochs}")

        for batch_idx, (images, labels, sample_ids) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)

            model_output = model(images)

            if isinstance(model_output, (tuple, list)):
                outputs = model_output[0]
            elif isinstance(model_output, dict):
                outputs = model_output.get('logits', model_output.get('prediction', model_output))
            else:
                outputs = model_output

            loss = criterion(outputs, labels)
            running_loss += loss.item()

            all_outputs.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_sample_ids.extend(sample_ids)

            if batch_idx % 3 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                progress_bar.set_postfix({'Val Loss': f'{avg_loss:.4f}'})

    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Validation metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics, all_sample_ids

def save_checkpoint_robust(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                         checkpoint_dir, best_metrics, config, max_retries=3):
    def make_serializable_cpu(obj):
        if isinstance(obj, torch.Tensor):
            cpu_obj = obj.detach().cpu()
            return cpu_obj.item() if cpu_obj.numel() == 1 else cpu_obj.tolist()
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, dict):
            return {k: make_serializable_cpu(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_serializable_cpu(item) for item in obj]
        else:
            return obj

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_metrics': make_serializable_cpu(train_metrics),
        'val_metrics': make_serializable_cpu(val_metrics),
        'casme2_config': make_serializable_cpu(config),
        'best_f1': float(best_metrics['f1']),
        'best_loss': float(best_metrics['loss']),
        'best_acc': float(best_metrics['accuracy']),
        'class_names': CASME2_CLASSES,
        'num_classes': 7
    }

    final_path = f"{checkpoint_dir}/casme2_mobilenet_mfs_best_f1.pth"

    for attempt in range(max_retries):
        try:
            temp_fd, temp_path = tempfile.mkstemp(dir=checkpoint_dir, suffix='.pth.tmp')
            os.close(temp_fd)

            print(f"Attempt {attempt + 1}: Saving checkpoint to temporary file...")
            torch.save(checkpoint, temp_path)

            print("Validating checkpoint integrity...")
            validation_checkpoint = torch.load(temp_path, map_location='cpu')

            required_keys = ['model_state_dict', 'epoch', 'best_f1', 'num_classes']
            for key in required_keys:
                if key not in validation_checkpoint:
                    raise ValueError(f"Checkpoint validation failed: missing key '{key}'")

            if validation_checkpoint['epoch'] != epoch:
                raise ValueError(f"Checkpoint epoch mismatch: saved {epoch}, loaded {validation_checkpoint['epoch']}")

            print("Checkpoint validation passed")

            print(f"Moving validated checkpoint to final location...")
            shutil.move(temp_path, final_path)

            print(f"Checkpoint saved and validated successfully: {os.path.basename(final_path)}")
            print(f"  Epoch: {epoch + 1}")
            print(f"  Val F1: {best_metrics['f1']:.4f}")
            print(f"  Val Loss: {best_metrics['loss']:.4f}")
            print(f"  Val Acc: {best_metrics['accuracy']:.4f}")

            return final_path

        except Exception as e:
            print(f"Checkpoint save attempt {attempt + 1}/{max_retries} failed: {e}")

            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except:
                    pass

            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                print(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                print(f"All {max_retries} checkpoint save attempts failed")
                return None

    return None

def safe_json_serialize(obj):
    if isinstance(obj, torch.Tensor):
        return obj.cpu().item() if obj.numel() == 1 else obj.cpu().numpy().tolist()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.int32, np.int64)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float32, np.float64)):
        return float(obj)
    elif isinstance(obj, dict):
        return {k: safe_json_serialize(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [safe_json_serialize(item) for item in obj]
    elif hasattr(obj, '__dict__'):
        return safe_json_serialize(obj.__dict__)
    else:
        try:
            return float(obj) if isinstance(obj, (int, float)) else str(obj)
        except:
            return str(obj)

print("\nCreating CASME II MobileNetV3-Small training datasets...")

normalized_metadata = normalize_metadata_structure(GLOBAL_CONFIG_CASME2['metadata'])

train_dataset = CASME2DatasetTraining(
    split_metadata=normalized_metadata,
    dataset_root=GLOBAL_CONFIG_CASME2['train_path'].replace('/train', ''),
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    split='train',
    use_ram_cache=True
)

val_dataset = CASME2DatasetTraining(
    split_metadata=normalized_metadata,
    dataset_root=GLOBAL_CONFIG_CASME2['val_path'].replace('/val', ''),
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    split='val',
    use_ram_cache=True
)

train_loader = DataLoader(
    train_dataset,
    batch_size=CASME2_MOBILENET_CONFIG['batch_size'],
    shuffle=True,
    num_workers=CASME2_MOBILENET_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CASME2_MOBILENET_CONFIG['batch_size'],
    shuffle=False,
    num_workers=CASME2_MOBILENET_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

print(f"Training batches: {len(train_loader)} (samples: {len(train_dataset)})")
print(f"Validation batches: {len(val_loader)} (samples: {len(val_dataset)})")

print("\nInitializing CASME II MobileNetV3-Small model...")
model = MobileNetCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_MOBILENET_CONFIG['dropout_rate']
).to(GLOBAL_CONFIG_CASME2['device'])

if CASME2_MOBILENET_CONFIG['use_focal_loss']:
    criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
        weights=GLOBAL_CONFIG_CASME2['class_weights'],
        use_focal_loss=True,
        alpha_weights=CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights'],
        gamma=CASME2_MOBILENET_CONFIG['focal_loss_gamma']
    )
else:
    criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
        weights=GLOBAL_CONFIG_CASME2['class_weights'],
        use_focal_loss=False,
        alpha_weights=None,
        gamma=2.0
    )

optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_MOBILENET_CONFIG
)

print(f"Optimizer: AdamW (LR={CASME2_MOBILENET_CONFIG['learning_rate']})")
print(f"Scheduler: ReduceLROnPlateau (patience={CASME2_MOBILENET_CONFIG['scheduler_patience']})")
print(f"Criterion: Optimized Focal Loss")

training_history = {
    'train_loss': [],
    'val_loss': [],
    'train_f1': [],
    'val_f1': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rate': [],
    'epoch_time': []
}

best_metrics = {
    'f1': 0.0,
    'loss': float('inf'),
    'accuracy': 0.0,
    'epoch': 0
}

print("\nStarting CASME II MobileNetV3-Small training...")
print(f"Training configuration: {CASME2_MOBILENET_CONFIG['num_epochs']} epochs")
print(f"Input resolution: 640x480 RGB (M1 methodology)")
print("=" * 70)

start_time = time.time()

for epoch in range(CASME2_MOBILENET_CONFIG['num_epochs']):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{CASME2_MOBILENET_CONFIG['num_epochs']}")

    train_loss, train_metrics = train_epoch(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_MOBILENET_CONFIG['num_epochs']
    )

    val_loss, val_metrics, val_sample_ids = validate_epoch(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_MOBILENET_CONFIG['num_epochs']
    )

    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        best_metrics['f1'] = val_metrics['f1_score']
        best_metrics['loss'] = val_loss
        best_metrics['accuracy'] = val_metrics['accuracy']
        best_metrics['epoch'] = epoch + 1

        best_model_path = save_checkpoint_robust(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            best_metrics, CASME2_MOBILENET_CONFIG
        )

        if best_model_path:
            print(f"New best model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print(f"Warning: Failed to save checkpoint despite improvement")

    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_MOBILENET_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_MOBILENET_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min")

total_time = time.time() - start_time
actual_epochs = CASME2_MOBILENET_CONFIG['num_epochs']

print("\n" + "=" * 70)
print("CASME II MOBILENETV3-SMALL BASELINE TRAINING COMPLETED")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print(f"Final train F1: {training_history['train_f1'][-1]:.4f}")
print(f"Final validation F1: {training_history['val_f1'][-1]:.4f}")

results_dir = GLOBAL_CONFIG_CASME2['results_root']
os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

training_history_path = f"{results_dir}/training_logs/casme2_mobilenet_mfs_training_history.json"

print("\nExporting enhanced training documentation...")

try:
    training_summary = {
        'experiment_type': 'CASME2_MobileNetV3Small_MFS_Baseline',
        'experiment_configuration': {
            'model_architecture': 'MobileNetV3-Small',
            'model_parameters': '2.5M',
            'dataset_phase': CASME2_MOBILENET_CONFIG['dataset_phase'],
            'methodology': CASME2_MOBILENET_CONFIG['methodology'],
            'preprocessing': CASME2_MOBILENET_CONFIG['preprocessing'],
            'input_resolution': '640x480 RGB',
            'frame_strategy': CASME2_MOBILENET_CONFIG['frame_strategy'],
            'train_augmentation': CASME2_MOBILENET_CONFIG['train_augmentation'],
            'frame_types': CASME2_MOBILENET_CONFIG['frame_types'],
            'loss_function': 'Optimized Focal Loss',
            'focal_loss_gamma': CASME2_MOBILENET_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights'],
            'model_name': CASME2_MOBILENET_CONFIG['model_name']
        },
        'training_history': safe_json_serialize(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'config': safe_json_serialize(CASME2_MOBILENET_CONFIG),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        'model_checkpoint': 'casme2_mobilenet_mfs_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            'phase': CASME2_MOBILENET_CONFIG['dataset_phase'],
            'methodology': CASME2_MOBILENET_CONFIG['methodology'],
            'input_resolution': '640x480 RGB',
            'frame_strategy': CASME2_MOBILENET_CONFIG['frame_strategy'],
            'train_augmentation': CASME2_MOBILENET_CONFIG['train_augmentation'],
            'frame_types': CASME2_MOBILENET_CONFIG['frame_types'],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES
        },
        'architecture_info': {
            'model_type': 'MobileNetCASME2Baseline',
            'backbone': CASME2_MOBILENET_CONFIG['model_name'],
            'input_size': '640x480 RGB',
            'classification_head': '576->512->128->7'
        },
        'enhanced_features': {
            'hardened_checkpoint_system': True,
            'atomic_checkpoint_save': True,
            'checkpoint_validation': True,
            'model_output_validation': True,
            'enhanced_error_handling': True,
            'multi_criteria_checkpoint_logic': True,
            'memory_optimized_training': True,
            'retry_with_backoff': True,
            'multi_frame_temporal_windows': True
        }
    }

    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Enhanced training documentation saved successfully: {training_history_path}")
    print(f"Model: {training_summary['experiment_configuration']['model_architecture']}")
    print(f"Methodology: {training_summary['experiment_configuration']['methodology']}")
    print(f"Input resolution: {training_summary['experiment_configuration']['input_resolution']}")
    print(f"Loss function: {training_summary['experiment_configuration']['loss_function']}")

except Exception as e:
    print(f"Warning: Could not save training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\nNext: Cell 3 - CASME II MobileNetV3-Small Evaluation")
print("Enhanced training pipeline with multi-frame temporal windows completed successfully!")

CASME II MobileNetV3-Small Training Pipeline
Model: MobileNetV3-Small
Methodology: M1 (Raw 640x480 RGB)
Loss Function: Focal Loss
  Gamma: 2.5
  Per-class Alpha: [0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285]
  Alpha Sum: 1.000
Dataset Phase: v3
Frame Strategy: multi_frame_sampling
Training epochs: 50
Scheduler patience: 5

Creating CASME II MobileNetV3-Small training datasets...
Detected v2/v3 metadata format (with 'splits' key)
Loading CASME II train dataset for training...
Loaded 2613 CASME II train samples
  others: 1027 samples (39.3%)
  disgust: 650 samples (24.9%)
  happiness: 325 samples (12.4%)
  repression: 273 samples (10.4%)
  surprise: 260 samples (10.0%)
  sadness: 65 samples (2.5%)
  fear: 13 samples (0.5%)
Preloading 2613 train images to RAM with 32 workers...


Loading train to RAM: 100%|██████████| 2613/2613 [00:29<00:00, 87.98it/s]


TRAIN RAM caching completed: 2613/2613 images, ~9.63GB
Loading CASME II val dataset for training...
Loaded 78 CASME II val samples
  others: 30 samples (38.5%)
  disgust: 18 samples (23.1%)
  happiness: 9 samples (11.5%)
  repression: 9 samples (11.5%)
  surprise: 6 samples (7.7%)
  sadness: 3 samples (3.8%)
  fear: 3 samples (3.8%)
Preloading 78 val images to RAM with 32 workers...


Loading val to RAM: 100%|██████████| 78/78 [00:01<00:00, 73.14it/s]


VAL RAM caching completed: 78/78 images, ~0.29GB
Training batches: 164 (samples: 2613)
Validation batches: 5 (samples: 78)

Initializing CASME II MobileNetV3-Small model...
MobileNetV3-Small feature dimension: 1024
MobileNet CASME II: 1024 -> 512 -> 128 -> 7
Using Optimized Focal Loss with gamma=2.5
Per-class alpha weights: [0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285, 0.14285714285714285]
Alpha sum: 1.000
Scheduler: ReduceLROnPlateau monitoring val_f1_macro
Optimizer: AdamW (LR=5e-05)
Scheduler: ReduceLROnPlateau (patience=5)
Criterion: Optimized Focal Loss

Starting CASME II MobileNetV3-Small training...
Training configuration: 50 epochs
Input resolution: 640x480 RGB (M1 methodology)

Epoch 1/50


CASME II Training Epoch 1/50: 100%|██████████| 164/164 [00:37<00:00,  4.34it/s, Loss=0.1406, LR=5.00e-05]
CASME II Validation Epoch 1/50: 100%|██████████| 5/5 [00:14<00:00,  2.84s/it, Val Loss=0.1610]


Train - Loss: 0.1404, F1: 0.3477, Acc: 0.4259
Val   - Loss: 0.1741, F1: 0.1649, Acc: 0.2949
Time  - Epoch: 52.0s, LR: 5.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_mobilenet_mfs_best_f1.pth
  Epoch: 1
  Val F1: 0.1649
  Val Loss: 0.1741
  Val Acc: 0.2949
New best model: Higher F1 - F1: 0.1649
Progress: 2.0% | Best F1: 0.1649 | ETA: 42.7min

Epoch 2/50


CASME II Training Epoch 2/50: 100%|██████████| 164/164 [00:09<00:00, 17.95it/s, Loss=0.0645, LR=5.00e-05]
CASME II Validation Epoch 2/50: 100%|██████████| 5/5 [00:00<00:00,  6.27it/s, Val Loss=0.1316]


Train - Loss: 0.0640, F1: 0.6330, Acc: 0.7298
Val   - Loss: 0.1453, F1: 0.2466, Acc: 0.3974
Time  - Epoch: 9.9s, LR: 5.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_mobilenet_mfs_best_f1.pth
  Epoch: 2
  Val F1: 0.2466
  Val Loss: 0.1453
  Val Acc: 0.3974
New best model: Higher F1 - F1: 0.2466
Progress: 4.0% | Best F1: 0.2466 | ETA: 25.0min

Epoch 3/50


CASME II Training Epoch 3/50: 100%|██████████| 164/164 [00:09<00:00, 17.67it/s, Loss=0.0253, LR=5.00e-05]
CASME II Validation Epoch 3/50: 100%|██████████| 5/5 [00:00<00:00,  6.23it/s, Val Loss=0.1705]


Train - Loss: 0.0252, F1: 0.8007, Acc: 0.9112
Val   - Loss: 0.1951, F1: 0.1744, Acc: 0.2949
Time  - Epoch: 10.1s, LR: 5.00e-05
Progress: 6.0% | Best F1: 0.2466 | ETA: 18.9min

Epoch 4/50


CASME II Training Epoch 4/50: 100%|██████████| 164/164 [00:09<00:00, 17.79it/s, Loss=0.0103, LR=5.00e-05]
CASME II Validation Epoch 4/50: 100%|██████████| 5/5 [00:00<00:00,  6.20it/s, Val Loss=0.1643]


Train - Loss: 0.0104, F1: 0.9007, Acc: 0.9709
Val   - Loss: 0.1872, F1: 0.1944, Acc: 0.3462
Time  - Epoch: 10.0s, LR: 5.00e-05
Progress: 8.0% | Best F1: 0.2466 | ETA: 15.8min

Epoch 5/50


CASME II Training Epoch 5/50: 100%|██████████| 164/164 [00:09<00:00, 17.98it/s, Loss=0.0060, LR=5.00e-05]
CASME II Validation Epoch 5/50: 100%|██████████| 5/5 [00:00<00:00,  6.53it/s, Val Loss=0.1365]


Train - Loss: 0.0064, F1: 0.9765, Acc: 0.9820
Val   - Loss: 0.1817, F1: 0.1412, Acc: 0.3718
Time  - Epoch: 9.9s, LR: 5.00e-05
Progress: 10.0% | Best F1: 0.2466 | ETA: 13.9min

Epoch 6/50


CASME II Training Epoch 6/50: 100%|██████████| 164/164 [00:09<00:00, 17.44it/s, Loss=0.0031, LR=5.00e-05]
CASME II Validation Epoch 6/50: 100%|██████████| 5/5 [00:00<00:00,  6.55it/s, Val Loss=0.1676]


Train - Loss: 0.0032, F1: 0.9902, Acc: 0.9904
Val   - Loss: 0.1942, F1: 0.2892, Acc: 0.4231
Time  - Epoch: 10.2s, LR: 5.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_mobilenet_mfs_best_f1.pth
  Epoch: 6
  Val F1: 0.2892
  Val Loss: 0.1942
  Val Acc: 0.4231
New best model: Higher F1 - F1: 0.2892
Progress: 12.0% | Best F1: 0.2892 | ETA: 12.6min

Epoch 7/50


CASME II Training Epoch 7/50: 100%|██████████| 164/164 [00:09<00:00, 17.80it/s, Loss=0.0020, LR=5.00e-05]
CASME II Validation Epoch 7/50: 100%|██████████| 5/5 [00:00<00:00,  6.49it/s, Val Loss=0.1507]


Train - Loss: 0.0020, F1: 0.9963, Acc: 0.9958
Val   - Loss: 0.1977, F1: 0.2694, Acc: 0.4487
Time  - Epoch: 10.0s, LR: 5.00e-05
Progress: 14.0% | Best F1: 0.2892 | ETA: 11.6min

Epoch 8/50


CASME II Training Epoch 8/50: 100%|██████████| 164/164 [00:09<00:00, 17.90it/s, Loss=0.0018, LR=5.00e-05]
CASME II Validation Epoch 8/50: 100%|██████████| 5/5 [00:00<00:00,  6.35it/s, Val Loss=0.1892]


Train - Loss: 0.0019, F1: 0.9947, Acc: 0.9954
Val   - Loss: 0.2298, F1: 0.1589, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 5.00e-05
Progress: 16.0% | Best F1: 0.2892 | ETA: 10.8min

Epoch 9/50


CASME II Training Epoch 9/50: 100%|██████████| 164/164 [00:09<00:00, 17.77it/s, Loss=0.0015, LR=5.00e-05]
CASME II Validation Epoch 9/50: 100%|██████████| 5/5 [00:00<00:00,  6.25it/s, Val Loss=0.1870]


Train - Loss: 0.0015, F1: 0.9943, Acc: 0.9943
Val   - Loss: 0.2300, F1: 0.2199, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 5.00e-05
Progress: 18.0% | Best F1: 0.2892 | ETA: 10.1min

Epoch 10/50


CASME II Training Epoch 10/50: 100%|██████████| 164/164 [00:09<00:00, 17.47it/s, Loss=0.0013, LR=5.00e-05]
CASME II Validation Epoch 10/50: 100%|██████████| 5/5 [00:00<00:00,  6.79it/s, Val Loss=0.1980]


Train - Loss: 0.0013, F1: 0.9972, Acc: 0.9962
Val   - Loss: 0.2568, F1: 0.2022, Acc: 0.3462
Time  - Epoch: 10.1s, LR: 5.00e-05
Progress: 20.0% | Best F1: 0.2892 | ETA: 9.5min

Epoch 11/50


CASME II Training Epoch 11/50: 100%|██████████| 164/164 [00:09<00:00, 17.73it/s, Loss=0.0007, LR=5.00e-05]
CASME II Validation Epoch 11/50: 100%|██████████| 5/5 [00:00<00:00,  6.33it/s, Val Loss=0.1999]


Train - Loss: 0.0012, F1: 0.9952, Acc: 0.9958
Val   - Loss: 0.2441, F1: 0.1664, Acc: 0.3462
Time  - Epoch: 10.1s, LR: 5.00e-05
Progress: 22.0% | Best F1: 0.2892 | ETA: 9.0min

Epoch 12/50


CASME II Training Epoch 12/50: 100%|██████████| 164/164 [00:09<00:00, 17.48it/s, Loss=0.0005, LR=5.00e-05]
CASME II Validation Epoch 12/50: 100%|██████████| 5/5 [00:00<00:00,  6.18it/s, Val Loss=0.1476]


Train - Loss: 0.0005, F1: 0.9990, Acc: 0.9985
Val   - Loss: 0.2210, F1: 0.2338, Acc: 0.4103
Time  - Epoch: 10.2s, LR: 2.50e-05
Progress: 24.0% | Best F1: 0.2892 | ETA: 8.6min

Epoch 13/50


CASME II Training Epoch 13/50: 100%|██████████| 164/164 [00:09<00:00, 17.78it/s, Loss=0.0003, LR=2.50e-05]
CASME II Validation Epoch 13/50: 100%|██████████| 5/5 [00:00<00:00,  6.48it/s, Val Loss=0.1832]


Train - Loss: 0.0003, F1: 0.9995, Acc: 0.9996
Val   - Loss: 0.2419, F1: 0.2017, Acc: 0.3974
Time  - Epoch: 10.0s, LR: 2.50e-05
Progress: 26.0% | Best F1: 0.2892 | ETA: 8.2min

Epoch 14/50


CASME II Training Epoch 14/50: 100%|██████████| 164/164 [00:09<00:00, 18.01it/s, Loss=0.0004, LR=2.50e-05]
CASME II Validation Epoch 14/50: 100%|██████████| 5/5 [00:00<00:00,  6.59it/s, Val Loss=0.1794]


Train - Loss: 0.0004, F1: 0.9993, Acc: 0.9992
Val   - Loss: 0.2438, F1: 0.1813, Acc: 0.3718
Time  - Epoch: 9.9s, LR: 2.50e-05
Progress: 28.0% | Best F1: 0.2892 | ETA: 7.9min

Epoch 15/50


CASME II Training Epoch 15/50: 100%|██████████| 164/164 [00:09<00:00, 17.92it/s, Loss=0.0002, LR=2.50e-05]
CASME II Validation Epoch 15/50: 100%|██████████| 5/5 [00:00<00:00,  6.33it/s, Val Loss=0.1974]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2577, F1: 0.1866, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 2.50e-05
Progress: 30.0% | Best F1: 0.2892 | ETA: 7.5min

Epoch 16/50


CASME II Training Epoch 16/50: 100%|██████████| 164/164 [00:09<00:00, 17.94it/s, Loss=0.0002, LR=2.50e-05]
CASME II Validation Epoch 16/50: 100%|██████████| 5/5 [00:00<00:00,  6.33it/s, Val Loss=0.1983]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2519, F1: 0.1992, Acc: 0.3718
Time  - Epoch: 9.9s, LR: 2.50e-05
Progress: 32.0% | Best F1: 0.2892 | ETA: 7.2min

Epoch 17/50


CASME II Training Epoch 17/50: 100%|██████████| 164/164 [00:09<00:00, 18.15it/s, Loss=0.0002, LR=2.50e-05]
CASME II Validation Epoch 17/50: 100%|██████████| 5/5 [00:00<00:00,  6.65it/s, Val Loss=0.1875]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2564, F1: 0.1774, Acc: 0.3590
Time  - Epoch: 9.8s, LR: 2.50e-05
Progress: 34.0% | Best F1: 0.2892 | ETA: 6.9min

Epoch 18/50


CASME II Training Epoch 18/50: 100%|██████████| 164/164 [00:09<00:00, 17.72it/s, Loss=0.0002, LR=2.50e-05]
CASME II Validation Epoch 18/50: 100%|██████████| 5/5 [00:00<00:00,  6.52it/s, Val Loss=0.1885]


Train - Loss: 0.0002, F1: 0.9998, Acc: 0.9996
Val   - Loss: 0.2602, F1: 0.1794, Acc: 0.3590
Time  - Epoch: 10.0s, LR: 1.25e-05
Progress: 36.0% | Best F1: 0.2892 | ETA: 6.6min

Epoch 19/50


CASME II Training Epoch 19/50: 100%|██████████| 164/164 [00:09<00:00, 17.96it/s, Loss=0.0005, LR=1.25e-05]
CASME II Validation Epoch 19/50: 100%|██████████| 5/5 [00:00<00:00,  6.71it/s, Val Loss=0.2019]


Train - Loss: 0.0004, F1: 0.9994, Acc: 0.9989
Val   - Loss: 0.2656, F1: 0.1894, Acc: 0.3846
Time  - Epoch: 9.9s, LR: 1.25e-05
Progress: 38.0% | Best F1: 0.2892 | ETA: 6.3min

Epoch 20/50


CASME II Training Epoch 20/50: 100%|██████████| 164/164 [00:09<00:00, 17.54it/s, Loss=0.0001, LR=1.25e-05]
CASME II Validation Epoch 20/50: 100%|██████████| 5/5 [00:00<00:00,  6.47it/s, Val Loss=0.2042]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2731, F1: 0.1765, Acc: 0.3590
Time  - Epoch: 10.1s, LR: 1.25e-05
Progress: 40.0% | Best F1: 0.2892 | ETA: 6.1min

Epoch 21/50


CASME II Training Epoch 21/50: 100%|██████████| 164/164 [00:09<00:00, 17.75it/s, Loss=0.0002, LR=1.25e-05]
CASME II Validation Epoch 21/50: 100%|██████████| 5/5 [00:00<00:00,  6.49it/s, Val Loss=0.1958]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2643, F1: 0.1902, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 1.25e-05
Progress: 42.0% | Best F1: 0.2892 | ETA: 5.8min

Epoch 22/50


CASME II Training Epoch 22/50: 100%|██████████| 164/164 [00:09<00:00, 17.64it/s, Loss=0.0001, LR=1.25e-05]
CASME II Validation Epoch 22/50: 100%|██████████| 5/5 [00:00<00:00,  6.29it/s, Val Loss=0.2296]


Train - Loss: 0.0005, F1: 0.9994, Acc: 0.9992
Val   - Loss: 0.2933, F1: 0.1211, Acc: 0.3205
Time  - Epoch: 10.1s, LR: 1.25e-05
Progress: 44.0% | Best F1: 0.2892 | ETA: 5.6min

Epoch 23/50


CASME II Training Epoch 23/50: 100%|██████████| 164/164 [00:09<00:00, 17.88it/s, Loss=0.0001, LR=1.25e-05]
CASME II Validation Epoch 23/50: 100%|██████████| 5/5 [00:00<00:00,  6.55it/s, Val Loss=0.2065]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2748, F1: 0.1605, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 1.25e-05
Progress: 46.0% | Best F1: 0.2892 | ETA: 5.3min

Epoch 24/50


CASME II Training Epoch 24/50: 100%|██████████| 164/164 [00:09<00:00, 17.96it/s, Loss=0.0001, LR=1.25e-05]
CASME II Validation Epoch 24/50: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s, Val Loss=0.1973]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2643, F1: 0.1707, Acc: 0.3333
Time  - Epoch: 9.9s, LR: 6.25e-06
Progress: 48.0% | Best F1: 0.2892 | ETA: 5.1min

Epoch 25/50


CASME II Training Epoch 25/50: 100%|██████████| 164/164 [00:09<00:00, 17.89it/s, Loss=0.0001, LR=6.25e-06]
CASME II Validation Epoch 25/50: 100%|██████████| 5/5 [00:00<00:00,  6.41it/s, Val Loss=0.2077]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2760, F1: 0.1783, Acc: 0.3462
Time  - Epoch: 10.0s, LR: 6.25e-06
Progress: 50.0% | Best F1: 0.2892 | ETA: 4.9min

Epoch 26/50


CASME II Training Epoch 26/50: 100%|██████████| 164/164 [00:09<00:00, 17.66it/s, Loss=0.0001, LR=6.25e-06]
CASME II Validation Epoch 26/50: 100%|██████████| 5/5 [00:00<00:00,  6.34it/s, Val Loss=0.2015]


Train - Loss: 0.0002, F1: 0.9988, Acc: 0.9996
Val   - Loss: 0.2649, F1: 0.1963, Acc: 0.3974
Time  - Epoch: 10.1s, LR: 6.25e-06
Progress: 52.0% | Best F1: 0.2892 | ETA: 4.7min

Epoch 27/50


CASME II Training Epoch 27/50: 100%|██████████| 164/164 [00:09<00:00, 17.80it/s, Loss=0.0003, LR=6.25e-06]
CASME II Validation Epoch 27/50: 100%|██████████| 5/5 [00:00<00:00,  6.06it/s, Val Loss=0.2114]


Train - Loss: 0.0003, F1: 0.9985, Acc: 0.9989
Val   - Loss: 0.2800, F1: 0.1761, Acc: 0.3590
Time  - Epoch: 10.1s, LR: 6.25e-06
Progress: 54.0% | Best F1: 0.2892 | ETA: 4.4min

Epoch 28/50


CASME II Training Epoch 28/50: 100%|██████████| 164/164 [00:09<00:00, 17.61it/s, Loss=0.0001, LR=6.25e-06]
CASME II Validation Epoch 28/50: 100%|██████████| 5/5 [00:00<00:00,  6.34it/s, Val Loss=0.2180]


Train - Loss: 0.0002, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2795, F1: 0.1824, Acc: 0.3590
Time  - Epoch: 10.1s, LR: 6.25e-06
Progress: 56.0% | Best F1: 0.2892 | ETA: 4.2min

Epoch 29/50


CASME II Training Epoch 29/50: 100%|██████████| 164/164 [00:09<00:00, 17.59it/s, Loss=0.0001, LR=6.25e-06]
CASME II Validation Epoch 29/50: 100%|██████████| 5/5 [00:00<00:00,  6.27it/s, Val Loss=0.1985]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2662, F1: 0.1904, Acc: 0.3846
Time  - Epoch: 10.1s, LR: 6.25e-06
Progress: 58.0% | Best F1: 0.2892 | ETA: 4.0min

Epoch 30/50


CASME II Training Epoch 30/50: 100%|██████████| 164/164 [00:09<00:00, 17.77it/s, Loss=0.0001, LR=6.25e-06]
CASME II Validation Epoch 30/50: 100%|██████████| 5/5 [00:00<00:00,  6.39it/s, Val Loss=0.2065]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2660, F1: 0.1956, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 3.13e-06
Progress: 60.0% | Best F1: 0.2892 | ETA: 3.8min

Epoch 31/50


CASME II Training Epoch 31/50: 100%|██████████| 164/164 [00:09<00:00, 17.77it/s, Loss=0.0001, LR=3.13e-06]
CASME II Validation Epoch 31/50: 100%|██████████| 5/5 [00:00<00:00,  6.41it/s, Val Loss=0.1992]


Train - Loss: 0.0001, F1: 0.9998, Acc: 0.9996
Val   - Loss: 0.2664, F1: 0.1839, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 3.13e-06
Progress: 62.0% | Best F1: 0.2892 | ETA: 3.6min

Epoch 32/50


CASME II Training Epoch 32/50: 100%|██████████| 164/164 [00:09<00:00, 18.04it/s, Loss=0.0001, LR=3.13e-06]
CASME II Validation Epoch 32/50: 100%|██████████| 5/5 [00:00<00:00,  6.32it/s, Val Loss=0.2067]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2765, F1: 0.1918, Acc: 0.3718
Time  - Epoch: 9.9s, LR: 3.13e-06
Progress: 64.0% | Best F1: 0.2892 | ETA: 3.4min

Epoch 33/50


CASME II Training Epoch 33/50: 100%|██████████| 164/164 [00:09<00:00, 17.92it/s, Loss=0.0001, LR=3.13e-06]
CASME II Validation Epoch 33/50: 100%|██████████| 5/5 [00:00<00:00,  6.46it/s, Val Loss=0.2017]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2662, F1: 0.1920, Acc: 0.3718
Time  - Epoch: 9.9s, LR: 3.13e-06
Progress: 66.0% | Best F1: 0.2892 | ETA: 3.2min

Epoch 34/50


CASME II Training Epoch 34/50: 100%|██████████| 164/164 [00:09<00:00, 17.78it/s, Loss=0.0001, LR=3.13e-06]
CASME II Validation Epoch 34/50: 100%|██████████| 5/5 [00:00<00:00,  6.40it/s, Val Loss=0.2100]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2821, F1: 0.1891, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 3.13e-06
Progress: 68.0% | Best F1: 0.2892 | ETA: 3.0min

Epoch 35/50


CASME II Training Epoch 35/50: 100%|██████████| 164/164 [00:09<00:00, 17.75it/s, Loss=0.0001, LR=3.13e-06]
CASME II Validation Epoch 35/50: 100%|██████████| 5/5 [00:00<00:00,  6.35it/s, Val Loss=0.2170]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2850, F1: 0.1880, Acc: 0.3590
Time  - Epoch: 10.0s, LR: 3.13e-06
Progress: 70.0% | Best F1: 0.2892 | ETA: 2.8min

Epoch 36/50


CASME II Training Epoch 36/50: 100%|██████████| 164/164 [00:09<00:00, 17.85it/s, Loss=0.0001, LR=3.13e-06]
CASME II Validation Epoch 36/50: 100%|██████████| 5/5 [00:00<00:00,  6.26it/s, Val Loss=0.2064]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2767, F1: 0.1945, Acc: 0.3846
Time  - Epoch: 10.0s, LR: 1.56e-06
Progress: 72.0% | Best F1: 0.2892 | ETA: 2.6min

Epoch 37/50


CASME II Training Epoch 37/50: 100%|██████████| 164/164 [00:09<00:00, 17.75it/s, Loss=0.0001, LR=1.56e-06]
CASME II Validation Epoch 37/50: 100%|██████████| 5/5 [00:00<00:00,  6.20it/s, Val Loss=0.2002]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2729, F1: 0.1958, Acc: 0.3974
Time  - Epoch: 10.1s, LR: 1.56e-06
Progress: 74.0% | Best F1: 0.2892 | ETA: 2.4min

Epoch 38/50


CASME II Training Epoch 38/50: 100%|██████████| 164/164 [00:09<00:00, 17.95it/s, Loss=0.0001, LR=1.56e-06]
CASME II Validation Epoch 38/50: 100%|██████████| 5/5 [00:00<00:00,  6.20it/s, Val Loss=0.2147]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2908, F1: 0.1857, Acc: 0.3974
Time  - Epoch: 10.0s, LR: 1.56e-06
Progress: 76.0% | Best F1: 0.2892 | ETA: 2.2min

Epoch 39/50


CASME II Training Epoch 39/50: 100%|██████████| 164/164 [00:09<00:00, 17.55it/s, Loss=0.0001, LR=1.56e-06]
CASME II Validation Epoch 39/50: 100%|██████████| 5/5 [00:00<00:00,  6.57it/s, Val Loss=0.2176]


Train - Loss: 0.0002, F1: 0.9994, Acc: 0.9992
Val   - Loss: 0.2841, F1: 0.1853, Acc: 0.3590
Time  - Epoch: 10.1s, LR: 1.56e-06
Progress: 78.0% | Best F1: 0.2892 | ETA: 2.0min

Epoch 40/50


CASME II Training Epoch 40/50: 100%|██████████| 164/164 [00:09<00:00, 17.56it/s, Loss=0.0001, LR=1.56e-06]
CASME II Validation Epoch 40/50: 100%|██████████| 5/5 [00:00<00:00,  6.47it/s, Val Loss=0.2127]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2898, F1: 0.1921, Acc: 0.3846
Time  - Epoch: 10.1s, LR: 1.56e-06
Progress: 80.0% | Best F1: 0.2892 | ETA: 1.8min

Epoch 41/50


CASME II Training Epoch 41/50: 100%|██████████| 164/164 [00:09<00:00, 18.14it/s, Loss=0.0001, LR=1.56e-06]
CASME II Validation Epoch 41/50: 100%|██████████| 5/5 [00:00<00:00,  6.13it/s, Val Loss=0.2085]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2869, F1: 0.1786, Acc: 0.3718
Time  - Epoch: 9.9s, LR: 1.56e-06
Progress: 82.0% | Best F1: 0.2892 | ETA: 1.7min

Epoch 42/50


CASME II Training Epoch 42/50: 100%|██████████| 164/164 [00:09<00:00, 17.83it/s, Loss=0.0002, LR=1.56e-06]
CASME II Validation Epoch 42/50: 100%|██████████| 5/5 [00:00<00:00,  6.34it/s, Val Loss=0.2163]


Train - Loss: 0.0002, F1: 0.9997, Acc: 0.9996
Val   - Loss: 0.2893, F1: 0.1828, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 1.00e-06
Progress: 84.0% | Best F1: 0.2892 | ETA: 1.5min

Epoch 43/50


CASME II Training Epoch 43/50: 100%|██████████| 164/164 [00:09<00:00, 17.98it/s, Loss=0.0001, LR=1.00e-06]
CASME II Validation Epoch 43/50: 100%|██████████| 5/5 [00:00<00:00,  6.56it/s, Val Loss=0.2056]


Train - Loss: 0.0001, F1: 0.9997, Acc: 0.9996
Val   - Loss: 0.2796, F1: 0.1913, Acc: 0.3846
Time  - Epoch: 9.9s, LR: 1.00e-06
Progress: 86.0% | Best F1: 0.2892 | ETA: 1.3min

Epoch 44/50


CASME II Training Epoch 44/50: 100%|██████████| 164/164 [00:09<00:00, 17.79it/s, Loss=0.0001, LR=1.00e-06]
CASME II Validation Epoch 44/50: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s, Val Loss=0.2073]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2774, F1: 0.1828, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 1.00e-06
Progress: 88.0% | Best F1: 0.2892 | ETA: 1.1min

Epoch 45/50


CASME II Training Epoch 45/50: 100%|██████████| 164/164 [00:09<00:00, 17.86it/s, Loss=0.0001, LR=1.00e-06]
CASME II Validation Epoch 45/50: 100%|██████████| 5/5 [00:00<00:00,  6.43it/s, Val Loss=0.2116]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2826, F1: 0.1771, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 1.00e-06
Progress: 90.0% | Best F1: 0.2892 | ETA: 0.9min

Epoch 46/50


CASME II Training Epoch 46/50: 100%|██████████| 164/164 [00:09<00:00, 17.79it/s, Loss=0.0001, LR=1.00e-06]
CASME II Validation Epoch 46/50: 100%|██████████| 5/5 [00:00<00:00,  6.52it/s, Val Loss=0.2192]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2888, F1: 0.1866, Acc: 0.3718
Time  - Epoch: 10.0s, LR: 1.00e-06
Progress: 92.0% | Best F1: 0.2892 | ETA: 0.7min

Epoch 47/50


CASME II Training Epoch 47/50: 100%|██████████| 164/164 [00:09<00:00, 17.86it/s, Loss=0.0001, LR=1.00e-06]
CASME II Validation Epoch 47/50: 100%|██████████| 5/5 [00:00<00:00,  6.07it/s, Val Loss=0.2234]


Train - Loss: 0.0003, F1: 0.9997, Acc: 0.9996
Val   - Loss: 0.2909, F1: 0.1976, Acc: 0.3846
Time  - Epoch: 10.0s, LR: 1.00e-06
Progress: 94.0% | Best F1: 0.2892 | ETA: 0.5min

Epoch 48/50


CASME II Training Epoch 48/50: 100%|██████████| 164/164 [00:09<00:00, 17.65it/s, Loss=0.0001, LR=1.00e-06]
CASME II Validation Epoch 48/50: 100%|██████████| 5/5 [00:00<00:00,  6.17it/s, Val Loss=0.2077]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2796, F1: 0.1803, Acc: 0.3718
Time  - Epoch: 10.1s, LR: 1.00e-06
Progress: 96.0% | Best F1: 0.2892 | ETA: 0.4min

Epoch 49/50


CASME II Training Epoch 49/50: 100%|██████████| 164/164 [00:09<00:00, 17.54it/s, Loss=0.0000, LR=1.00e-06]
CASME II Validation Epoch 49/50: 100%|██████████| 5/5 [00:00<00:00,  6.36it/s, Val Loss=0.2186]


Train - Loss: 0.0006, F1: 0.9994, Acc: 0.9992
Val   - Loss: 0.2904, F1: 0.1931, Acc: 0.3846
Time  - Epoch: 10.2s, LR: 1.00e-06
Progress: 98.0% | Best F1: 0.2892 | ETA: 0.2min

Epoch 50/50


CASME II Training Epoch 50/50: 100%|██████████| 164/164 [00:09<00:00, 17.65it/s, Loss=0.0001, LR=1.00e-06]
CASME II Validation Epoch 50/50: 100%|██████████| 5/5 [00:00<00:00,  6.19it/s, Val Loss=0.2169]


Train - Loss: 0.0001, F1: 1.0000, Acc: 1.0000
Val   - Loss: 0.2913, F1: 0.1924, Acc: 0.3846
Time  - Epoch: 10.1s, LR: 1.00e-06
Progress: 100.0% | Best F1: 0.2892 | ETA: 0.0min

CASME II MOBILENETV3-SMALL BASELINE TRAINING COMPLETED
Training time: 9.1 minutes
Epochs completed: 50
Best validation F1: 0.2892 (epoch 6)
Final train F1: 1.0000
Final validation F1: 0.1924

Exporting enhanced training documentation...
Enhanced training documentation saved successfully: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/08_01_mobilenet_casme2_mfs/training_logs/casme2_mobilenet_mfs_training_history.json
Model: MobileNetV3-Small
Methodology: M1
Input resolution: 640x480 RGB
Loss function: Optimized Focal Loss

Next: Cell 3 - CASME II MobileNetV3-Small Evaluation
Enhanced training pipeline with multi-frame temporal windows completed successfully!


In [3]:
# @title Cell 3: CASME II MobileNetV3-Small Evaluation (Configurable)

# File: 08_01_MobileNet_CASME2_MFS_Cell3.py
# Location: experiments/08_01_MobileNet_CASME2-MFS.ipynb
# Purpose: Configurable evaluation framework for KFS and AF test sets

# CONFIGURATION: Choose which test version to evaluate
# Options: 'kfs', 'af', or 'both'
# - 'kfs': Evaluate Key-Frame Sampling test set only (84 samples)
# - 'af': Evaluate Apex-Frame test set only (28 samples)
# - 'both': Evaluate both KFS and AF sequentially

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import pickle
import warnings
warnings.filterwarnings('ignore')

print("CASME II MobileNetV3-Small Evaluation Framework")
print("=" * 60)

TEST_VERSION_TO_EVALUATE = 'both'

print(f"Evaluation Configuration: {TEST_VERSION_TO_EVALUATE.upper()}")
if TEST_VERSION_TO_EVALUATE == 'both':
    print("  Strategy: Sequential KFS → AF")
    print("  Expected: ~84 samples (KFS) + ~28 samples (AF)")
elif TEST_VERSION_TO_EVALUATE == 'kfs':
    print("  Strategy: KFS only")
    print("  Expected: ~84 samples (Key-Frame Sampling)")
elif TEST_VERSION_TO_EVALUATE == 'af':
    print("  Strategy: AF only")
    print("  Expected: ~28 samples (Apex-Frame)")
else:
    raise ValueError(f"Invalid TEST_VERSION_TO_EVALUATE: {TEST_VERSION_TO_EVALUATE}. Must be 'kfs', 'af', or 'both'")
print("=" * 60)

def get_test_dataset_config(version, project_root):
    if version == 'kfs':
        config = {
            'version': 'kfs',
            'phase': 'Phase 2',
            'dataset_path': f"{project_root}/datasets/processed_casme2/data_split_v2",
            'metadata_file': 'split_metadata_v2.json',
            'processing_summary': 'processing_summary_v2.json',
            'description': 'Key-frames (onset, apex, offset)',
            'expected_samples': 84,
            'frame_types': ['onset', 'apex', 'offset']
        }
    elif version == 'af':
        config = {
            'version': 'af',
            'phase': 'Phase 1',
            'dataset_path': f"{project_root}/datasets/processed_casme2/data_split_v1",
            'metadata_file': 'split_metadata.json',
            'processing_summary': 'processing_summary.json',
            'description': 'Apex-only frames',
            'expected_samples': 28,
            'frame_types': ['apex']
        }
    else:
        raise ValueError(f"Invalid version: {version}. Must be 'kfs' or 'af'")

    return config

class CASME2DatasetEvaluation(Dataset):
    def __init__(self, split_metadata, dataset_root, transform=None, split='test', use_ram_cache=True):
        self.metadata = split_metadata[split]['samples']
        self.dataset_root = dataset_root
        self.transform = transform
        self.split = split
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.sample_ids = []
        self.emotions = []
        self.subjects = []
        self.cached_images = []

        print(f"Loading CASME II {split} dataset for evaluation...")

        for sample in self.metadata:
            image_path = os.path.join(dataset_root, split, sample['image_filename'])
            self.images.append(image_path)
            self.labels.append(CLASS_TO_IDX[sample['emotion']])
            self.sample_ids.append(sample['sample_id'])
            self.emotions.append(sample['emotion'])
            self.subjects.append(sample['subject'])

        print(f"Loaded {len(self.images)} CASME II {split} samples for evaluation")
        self._print_evaluation_distribution()

        if self.use_ram_cache:
            self._preload_to_ram_evaluation()

    def _print_evaluation_distribution(self):
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        subject_counts = {}

        for label, subject in zip(self.labels, self.subjects):
            label_counts[label] = label_counts.get(label, 0) + 1
            subject_counts[subject] = subject_counts.get(subject, 0) + 1

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Test set covers {len(subject_counts)} subjects")

        missing_classes = []
        for i, class_name in enumerate(CASME2_CLASSES):
            if i not in label_counts:
                missing_classes.append(class_name)

        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        print(f"Preloading {len(self.images)} test images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            try:
                image = Image.open(img_path).convert('RGB')
                if image.size != (640, 480):
                    image = image.resize((640, 480), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('RGB', (640, 480), (128, 128, 128)), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc="Loading test images to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * 640 * 480 * 3 * 4 / 1e9
        print(f"Test RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx]).convert('RGB')
                if image.size != (640, 480):
                    image = image.resize((640, 480), Image.Resampling.LANCZOS)
            except:
                image = Image.new('RGB', (640, 480), (128, 128, 128))

        if self.transform:
            image = self.transform(image)

        return (image, self.labels[idx], self.sample_ids[idx],
                self.emotions[idx], self.subjects[idx], os.path.basename(self.images[idx]))

EVALUATION_CONFIG_CASME2 = {
    'model_type': 'MobileNetV3Small_CASME2_MFS_Baseline',
    'task_type': 'micro_expression_recognition',
    'num_classes': 7,
    'class_names': CASME2_CLASSES,
    'checkpoint_file': 'casme2_mobilenet_mfs_best_f1.pth',
    'dataset_name': 'CASME_II',
    'methodology': 'M1',
    'input_resolution': '640x480 RGB',
    'evaluation_protocol': 'stratified_split'
}

print(f"\nCASME II MobileNetV3 Evaluation Configuration:")
print(f"  Model: {EVALUATION_CONFIG_CASME2['model_type']}")
print(f"  Methodology: {EVALUATION_CONFIG_CASME2['methodology']}")
print(f"  Input resolution: {EVALUATION_CONFIG_CASME2['input_resolution']}")
print(f"  Classes: {EVALUATION_CONFIG_CASME2['class_names']}")

def extract_logits_safe_casme2(outputs_all):
    if isinstance(outputs_all, torch.Tensor):
        return outputs_all
    if isinstance(outputs_all, (tuple, list)):
        for item in outputs_all:
            if isinstance(item, torch.Tensor):
                return item
    if isinstance(outputs_all, dict):
        for key in ('logits', 'logit', 'predictions', 'outputs', 'scores'):
            value = outputs_all.get(key)
            if isinstance(value, torch.Tensor):
                return value
        for value in outputs_all.values():
            if isinstance(value, torch.Tensor):
                return value
    raise RuntimeError("Unable to extract tensor logits from model output")

def load_trained_model_casme2(checkpoint_path, device):
    print(f"Loading trained CASME II MobileNetV3-Small model from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    checkpoint = None
    loading_method = "unknown"

    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        loading_method = "standard"
    except Exception as e1:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            loading_method = "weights_only_false"
        except Exception as e2:
            try:
                import pickle
                with open(checkpoint_path, 'rb') as f:
                    checkpoint = pickle.load(f)
                loading_method = "pickle"
            except Exception as e3:
                raise RuntimeError(f"All loading methods failed: {e1}, {e2}, {e3}")

    print(f"Checkpoint loaded using: {loading_method}")

    model = MobileNetCASME2Baseline(
        num_classes=EVALUATION_CONFIG_CASME2['num_classes'],
        dropout_rate=CASME2_MOBILENET_CONFIG['dropout_rate']
    ).to(device)

    state_dict = checkpoint.get('model_state_dict', checkpoint)

    try:
        model.load_state_dict(state_dict, strict=True)
        print("Model state loaded with strict=True")
    except Exception as e:
        print(f"Strict loading failed, trying non-strict: {str(e)[:100]}...")
        try:
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            if missing_keys or unexpected_keys:
                print(f"Non-strict loading: Missing {len(missing_keys)}, Unexpected {len(unexpected_keys)}")
            else:
                print("Model state loaded with strict=False (no key mismatches)")
        except Exception as e2:
            raise RuntimeError(f"Both loading approaches failed: {e2}")

    model.eval()

    training_info = {
        'best_val_f1': float(checkpoint.get('best_f1', 0.0)),
        'best_val_loss': float(checkpoint.get('best_loss', float('inf'))),
        'best_val_accuracy': float(checkpoint.get('best_acc', 0.0)),
        'best_epoch': int(checkpoint.get('epoch', 0)) + 1,
        'model_checkpoint': EVALUATION_CONFIG_CASME2['checkpoint_file'],
        'num_classes': EVALUATION_CONFIG_CASME2['num_classes'],
        'config': checkpoint.get('casme2_config', {})
    }

    print(f"Model loaded successfully:")
    print(f"  Best validation F1: {training_info['best_val_f1']:.4f}")
    print(f"  Best validation accuracy: {training_info['best_val_accuracy']:.4f}")
    print(f"  Best epoch: {training_info['best_epoch']}")

    return model, training_info

def run_model_inference_casme2(model, test_loader, device, test_version):
    print(f"Running CASME II MobileNetV3 inference on {test_version.upper()} test set...")

    model.eval()
    all_predictions = []
    all_probabilities = []
    all_labels = []
    all_sample_ids = []
    all_emotions = []
    all_subjects = []
    all_filenames = []

    inference_start = time.time()

    with torch.no_grad():
        for batch_idx, (images, labels, sample_ids, emotions, subjects, filenames) in enumerate(
            tqdm(test_loader, desc=f"CASME II Inference ({test_version.upper()})")):

            images = images.to(device)

            try:
                outputs_raw = model(images)
                outputs = extract_logits_safe_casme2(outputs_raw)
            except Exception as e:
                print(f"Error in model forward pass: {e}")
                outputs = model(images)
                if not isinstance(outputs, torch.Tensor):
                    outputs = outputs[0] if isinstance(outputs, (list, tuple)) else outputs

            if outputs.shape[1] != 7:
                print(f"Warning: Expected 7 classes output, got {outputs.shape[1]}")

            probabilities = torch.softmax(outputs, dim=1)
            predictions = torch.argmax(probabilities, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_sample_ids.extend(sample_ids)
            all_emotions.extend(emotions)
            all_subjects.extend(subjects)
            all_filenames.extend(filenames)

    inference_time = time.time() - inference_start

    print(f"CASME II inference completed: {len(all_predictions)} samples in {inference_time:.2f}s")

    predictions_array = np.array(all_predictions)
    labels_array = np.array(all_labels)

    unique_predictions, pred_counts = np.unique(predictions_array, return_counts=True)
    print(f"Predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    unique_labels, label_counts = np.unique(labels_array, return_counts=True)
    print(f"True classes in test: {[CASME2_CLASSES[i] for i in unique_labels]}")

    return {
        'predictions': predictions_array,
        'probabilities': np.array(all_probabilities),
        'labels': labels_array,
        'sample_ids': all_sample_ids,
        'emotions': all_emotions,
        'subjects': all_subjects,
        'filenames': all_filenames,
        'inference_time': inference_time,
        'samples_count': len(predictions_array)
    }

def analyze_wrong_predictions_casme2(inference_results, test_version):
    print(f"Analyzing wrong predictions for {test_version.upper()}...")

    predictions = inference_results['predictions']
    labels = inference_results['labels']
    sample_ids = inference_results['sample_ids']
    emotions = inference_results['emotions']
    subjects = inference_results['subjects']
    filenames = inference_results['filenames']

    wrong_mask = predictions != labels
    wrong_indices = np.where(wrong_mask)[0]

    wrong_predictions_by_class = {}
    subject_error_analysis = {}

    for class_name in CASME2_CLASSES:
        wrong_predictions_by_class[class_name] = []

    for idx in wrong_indices:
        true_label = labels[idx]
        pred_label = predictions[idx]
        sample_id = sample_ids[idx]
        emotion = emotions[idx]
        subject = subjects[idx]
        filename = filenames[idx]

        true_class = CASME2_CLASSES[true_label]
        pred_class = CASME2_CLASSES[pred_label]

        wrong_info = {
            'sample_id': sample_id,
            'filename': filename,
            'subject': subject,
            'true_label': int(true_label),
            'true_class': true_class,
            'predicted_label': int(pred_label),
            'predicted_class': pred_class,
            'emotion': emotion
        }

        wrong_predictions_by_class[true_class].append(wrong_info)

        if subject not in subject_error_analysis:
            subject_error_analysis[subject] = {'total': 0, 'wrong': 0, 'errors': []}
        subject_error_analysis[subject]['wrong'] += 1
        subject_error_analysis[subject]['errors'].append(wrong_info)

    for subject in subjects:
        if subject in subject_error_analysis:
            subject_error_analysis[subject]['total'] += 1
        else:
            subject_error_analysis[subject] = {'total': 1, 'wrong': 0, 'errors': []}

    for subject in subject_error_analysis:
        total = subject_error_analysis[subject]['total']
        wrong = subject_error_analysis[subject]['wrong']
        subject_error_analysis[subject]['error_rate'] = wrong / total if total > 0 else 0.0

    total_wrong = len(wrong_indices)
    total_samples = len(predictions)
    error_rate = (total_wrong / total_samples) * 100

    confusion_patterns = {}
    for idx in wrong_indices:
        true_label = labels[idx]
        pred_label = predictions[idx]
        pattern = f"{CASME2_CLASSES[true_label]}_to_{CASME2_CLASSES[pred_label]}"
        confusion_patterns[pattern] = confusion_patterns.get(pattern, 0) + 1

    analysis_results = {
        'analysis_metadata': {
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'model_type': EVALUATION_CONFIG_CASME2['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2['dataset_name'],
            'test_version': test_version,
            'methodology': EVALUATION_CONFIG_CASME2['methodology'],
            'total_samples': int(total_samples),
            'total_wrong_predictions': int(total_wrong),
            'overall_error_rate': float(error_rate)
        },
        'wrong_predictions_by_class': wrong_predictions_by_class,
        'subject_error_analysis': subject_error_analysis,
        'confusion_patterns': confusion_patterns,
        'error_summary': {
            class_name: len(wrong_predictions_by_class[class_name])
            for class_name in CASME2_CLASSES
        }
    }

    return analysis_results

def calculate_comprehensive_metrics_casme2(inference_results, test_version, test_config):
    print(f"Calculating comprehensive metrics for {test_version.upper()}...")

    predictions = inference_results['predictions']
    probabilities = inference_results['probabilities']
    labels = inference_results['labels']

    if len(predictions) == 0:
        raise ValueError("No predictions to evaluate!")

    unique_test_labels = sorted(np.unique(labels))
    unique_predictions = sorted(np.unique(predictions))

    print(f"Test set contains labels: {[CASME2_CLASSES[i] for i in unique_test_labels]}")
    print(f"Model predicted classes: {[CASME2_CLASSES[i] for i in unique_predictions]}")

    accuracy = accuracy_score(labels, predictions)

    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions, labels=unique_test_labels, average='macro', zero_division=0
    )

    print(f"Macro F1 (available classes): {f1:.4f}")

    precision_per_class, recall_per_class, f1_per_class, support_per_class = precision_recall_fscore_support(
        labels, predictions, labels=range(7), average=None, zero_division=0
    )

    cm = confusion_matrix(labels, predictions, labels=range(7))

    auc_scores = {}
    fpr_dict = {}
    tpr_dict = {}

    try:
        labels_binarized = label_binarize(labels, classes=range(7))

        for i, class_name in enumerate(CASME2_CLASSES):
            if i in unique_test_labels and len(np.unique(labels_binarized[:, i])) > 1:
                fpr, tpr, _ = roc_curve(labels_binarized[:, i], probabilities[:, i])
                auc_score = auc(fpr, tpr)
                auc_scores[class_name] = float(auc_score)
                fpr_dict[class_name] = fpr.tolist()
                tpr_dict[class_name] = tpr.tolist()
            else:
                auc_scores[class_name] = 0.0
                fpr_dict[class_name] = [0.0, 1.0]
                tpr_dict[class_name] = [0.0, 0.0]

        available_auc_scores = [auc_scores[CASME2_CLASSES[i]] for i in unique_test_labels]
        macro_auc = float(np.mean(available_auc_scores)) if available_auc_scores else 0.0

    except Exception as e:
        print(f"Warning: AUC calculation failed: {e}")
        auc_scores = {class_name: 0.0 for class_name in CASME2_CLASSES}
        macro_auc = 0.0

    subjects = inference_results['subjects']
    subject_performance = {}

    for subject in set(subjects):
        subject_mask = [s == subject for s in subjects]
        subject_predictions = predictions[subject_mask]
        subject_labels = labels[subject_mask]

        if len(subject_predictions) > 0:
            subject_acc = accuracy_score(subject_labels, subject_predictions)
            subject_performance[subject] = {
                'accuracy': float(subject_acc),
                'samples': int(len(subject_predictions)),
                'correct': int(np.sum(subject_predictions == subject_labels))
            }

    comprehensive_results = {
        'evaluation_metadata': {
            'model_type': EVALUATION_CONFIG_CASME2['model_type'],
            'dataset': EVALUATION_CONFIG_CASME2['dataset_name'],
            'methodology': EVALUATION_CONFIG_CASME2['methodology'],
            'input_resolution': EVALUATION_CONFIG_CASME2['input_resolution'],
            'test_version': test_version,
            'test_phase': test_config['phase'],
            'test_description': test_config['description'],
            'test_frame_types': test_config['frame_types'],
            'evaluation_timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
            'num_classes': EVALUATION_CONFIG_CASME2['num_classes'],
            'class_names': EVALUATION_CONFIG_CASME2['class_names'],
            'test_samples': int(len(labels)),
            'available_classes': [CASME2_CLASSES[i] for i in unique_test_labels],
            'missing_classes': [CASME2_CLASSES[i] for i in range(7) if i not in unique_test_labels]
        },

        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': macro_auc
        },

        'per_class_performance': {},

        'confusion_matrix': cm.tolist(),

        'subject_level_performance': subject_performance,

        'roc_analysis': {
            'auc_scores': auc_scores,
            'fpr_curves': fpr_dict,
            'tpr_curves': tpr_dict
        },

        'inference_performance': {
            'total_time_seconds': float(inference_results['inference_time']),
            'average_time_ms_per_sample': float(inference_results['inference_time'] * 1000 / len(labels))
        }
    }

    for i, class_name in enumerate(CASME2_CLASSES):
        comprehensive_results['per_class_performance'][class_name] = {
            'precision': float(precision_per_class[i]),
            'recall': float(recall_per_class[i]),
            'f1_score': float(f1_per_class[i]),
            'support': int(support_per_class[i]),
            'auc': auc_scores[class_name],
            'in_test_set': i in unique_test_labels
        }

    return comprehensive_results

def save_evaluation_results_casme2(evaluation_results, wrong_predictions_results, results_dir, test_version):
    os.makedirs(results_dir, exist_ok=True)

    results_file = f"{results_dir}/casme2_mobilenet_mfs_evaluation_results_{test_version}.json"
    with open(results_file, 'w') as f:
        json.dump(evaluation_results, f, indent=2, default=str)

    wrong_predictions_file = f"{results_dir}/casme2_mobilenet_mfs_wrong_predictions_{test_version}.json"
    with open(wrong_predictions_file, 'w') as f:
        json.dump(wrong_predictions_results, f, indent=2, default=str)

    print(f"Evaluation results saved:")
    print(f"  Main results: {os.path.basename(results_file)}")
    print(f"  Wrong predictions: {os.path.basename(wrong_predictions_file)}")

    return results_file, wrong_predictions_file

def normalize_metadata_structure(metadata):
    if 'splits' in metadata:
        return metadata['splits']
    elif 'train' in metadata or 'test' in metadata:
        return metadata
    else:
        raise ValueError("Unknown metadata format")

def display_evaluation_summary(evaluation_results, wrong_predictions_results, training_info, test_version):
    print("\n" + "=" * 60)
    print(f"CASME II MOBILENETV3-SMALL {test_version.upper()} EVALUATION RESULTS")
    print("=" * 60)

    overall = evaluation_results['overall_performance']
    meta = evaluation_results['evaluation_metadata']

    print(f"Test Dataset: {meta['test_description']} ({test_version.upper()})")
    print(f"Methodology: {meta['methodology']}")
    print(f"Input Resolution: {meta['input_resolution']}")

    print(f"\nOverall Performance (Macro - Available Classes):")
    print(f"  Accuracy:  {overall['accuracy']:.4f}")
    print(f"  Precision: {overall['macro_precision']:.4f}")
    print(f"  Recall:    {overall['macro_recall']:.4f}")
    print(f"  F1 Score:  {overall['macro_f1']:.4f}")
    print(f"  AUC:       {overall['macro_auc']:.4f}")

    print(f"\nPer-Class Performance:")
    for class_name, metrics in evaluation_results['per_class_performance'].items():
        in_test = "Present" if metrics['in_test_set'] else "Missing"
        print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
              f"AUC={metrics['auc']:.4f}, Support={metrics['support']}")

    print(f"\nTraining vs Test Performance:")
    training_f1 = training_info['best_val_f1']
    training_acc = training_info['best_val_accuracy']
    test_f1 = overall['macro_f1']
    test_acc = overall['accuracy']

    print(f"  Training Val F1:  {training_f1:.4f}")
    print(f"  Test F1:          {test_f1:.4f}")
    print(f"  F1 Difference:    {training_f1 - test_f1:+.4f}")
    print(f"  Training Val Acc: {training_acc:.4f}")
    print(f"  Test Accuracy:    {test_acc:.4f}")
    print(f"  Acc Difference:   {training_acc - test_acc:+.4f}")

    if wrong_predictions_results:
        wrong_meta = wrong_predictions_results['analysis_metadata']
        print(f"\nWrong Predictions Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        print(f"\n  Errors by True Class:")
        for class_name, error_count in wrong_predictions_results['error_summary'].items():
            if error_count > 0:
                print(f"    {class_name}: {error_count} errors")

        patterns = sorted(wrong_predictions_results['confusion_patterns'].items(),
                         key=lambda x: x[1], reverse=True)
        if patterns:
            print(f"\n  Top Confusion Patterns:")
            for pattern, count in patterns[:3]:
                print(f"    {pattern}: {count} cases")

    print(f"\nInference Performance:")
    print(f"  Total time: {evaluation_results['inference_performance']['total_time_seconds']:.2f}s")
    print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    print(f"\nTest Dataset Info:")
    print(f"  Missing classes: {meta['missing_classes']}")

print("\n" + "=" * 60)
if TEST_VERSION_TO_EVALUATE == 'both':
    print("SEQUENTIAL EVALUATION: KFS → AF")
else:
    print(f"SINGLE EVALUATION: {TEST_VERSION_TO_EVALUATE.upper()}")
print("=" * 60)

checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/{EVALUATION_CONFIG_CASME2['checkpoint_file']}"
casme2_model, training_info = load_trained_model_casme2(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"
all_results = {}

if TEST_VERSION_TO_EVALUATE == 'both':
    versions_to_evaluate = ['kfs', 'af']
elif TEST_VERSION_TO_EVALUATE in ['kfs', 'af']:
    versions_to_evaluate = [TEST_VERSION_TO_EVALUATE]
else:
    raise ValueError(f"Invalid TEST_VERSION_TO_EVALUATE: {TEST_VERSION_TO_EVALUATE}")

print(f"\nVersions to evaluate: {[v.upper() for v in versions_to_evaluate]}")

for test_version in versions_to_evaluate:
    print("\n" + "=" * 60)
    print(f"EVALUATING: {test_version.upper()}")
    print("=" * 60)

    try:
        test_config = get_test_dataset_config(test_version, PROJECT_ROOT)

        print(f"\nTest Dataset Configuration:")
        print(f"  Version: {test_config['version']}")
        print(f"  Phase: {test_config['phase']}")
        print(f"  Description: {test_config['description']}")
        print(f"  Expected samples: {test_config['expected_samples']}")
        print(f"  Frame types: {test_config['frame_types']}")

        test_metadata_path = f"{test_config['dataset_path']}/{test_config['metadata_file']}"

        if not os.path.exists(test_metadata_path):
            print(f"WARNING: Test metadata not found: {test_metadata_path}")
            print(f"Skipping {test_version.upper()} evaluation")
            continue

        with open(test_metadata_path, 'r') as f:
            test_metadata = json.load(f)

        normalized_test_metadata = normalize_metadata_structure(test_metadata)

        if 'test' not in normalized_test_metadata:
            print(f"WARNING: Test split not found in metadata for {test_version.upper()}")
            print(f"Skipping {test_version.upper()} evaluation")
            continue

        actual_test_samples = len(normalized_test_metadata['test']['samples'])
        print(f"Loaded {actual_test_samples} test samples (expected: {test_config['expected_samples']})")

        casme2_test_dataset = CASME2DatasetEvaluation(
            split_metadata=normalized_test_metadata,
            dataset_root=test_config['dataset_path'],
            transform=GLOBAL_CONFIG_CASME2['transform_val'],
            split='test',
            use_ram_cache=True
        )

        if len(casme2_test_dataset) == 0:
            print(f"WARNING: No test samples found for {test_version.upper()}")
            print(f"Skipping {test_version.upper()} evaluation")
            continue

        casme2_test_loader = DataLoader(
            casme2_test_dataset,
            batch_size=CASME2_MOBILENET_CONFIG['batch_size'],
            shuffle=False,
            num_workers=CASME2_MOBILENET_CONFIG['num_workers'],
            pin_memory=True
        )

        inference_results = run_model_inference_casme2(
            casme2_model, casme2_test_loader, GLOBAL_CONFIG_CASME2['device'], test_version
        )

        evaluation_results = calculate_comprehensive_metrics_casme2(
            inference_results, test_version, test_config
        )

        wrong_predictions_results = analyze_wrong_predictions_casme2(
            inference_results, test_version
        )

        evaluation_results['training_information'] = training_info

        results_file, wrong_file = save_evaluation_results_casme2(
            evaluation_results, wrong_predictions_results, results_dir, test_version
        )

        display_evaluation_summary(
            evaluation_results, wrong_predictions_results, training_info, test_version
        )

        all_results[test_version] = {
            'evaluation': evaluation_results,
            'wrong_predictions': wrong_predictions_results,
            'files': {'main': results_file, 'wrong': wrong_file}
        }

        print(f"\n{test_version.upper()} evaluation completed successfully!")

    except Exception as e:
        print(f"ERROR in {test_version.upper()} evaluation: {e}")
        import traceback
        traceback.print_exc()
        print(f"Continuing to next evaluation...")

print("\n" + "=" * 60)
if TEST_VERSION_TO_EVALUATE == 'both':
    print("SEQUENTIAL EVALUATION COMPLETED")
else:
    print(f"{TEST_VERSION_TO_EVALUATE.upper()} EVALUATION COMPLETED")
print("=" * 60)

if all_results:
    print(f"\nEvaluated datasets: {[v.upper() for v in all_results.keys()]}")

    print(f"\nPerformance Summary:")
    for version, results in all_results.items():
        overall = results['evaluation']['overall_performance']
        print(f"\n{version.upper()}:")
        print(f"  Accuracy:  {overall['accuracy']:.4f}")
        print(f"  Macro F1:  {overall['macro_f1']:.4f}")
        print(f"  Macro AUC: {overall['macro_auc']:.4f}")

    if len(all_results) == 2:
        print(f"\nComparative Analysis:")
        kfs_f1 = all_results['kfs']['evaluation']['overall_performance']['macro_f1']
        af_f1 = all_results['af']['evaluation']['overall_performance']['macro_f1']
        delta_f1 = kfs_f1 - af_f1

        print(f"  KFS Macro F1: {kfs_f1:.4f}")
        print(f"  AF Macro F1:  {af_f1:.4f}")
        print(f"  Delta (KFS - AF): {delta_f1:+.4f}")

        if delta_f1 > 0:
            improvement_pct = (delta_f1 / af_f1) * 100
            print(f"  KFS improves by {improvement_pct:.1f}% over AF")
        else:
            degradation_pct = (abs(delta_f1) / kfs_f1) * 100
            print(f"  KFS degrades by {degradation_pct:.1f}% from AF")

    print(f"\nAll evaluation files saved in: {results_dir}")
else:
    print("\nWARNING: No evaluations completed successfully")
    print("Please check:")
    print("  1. Model checkpoint exists")
    print("  2. Test dataset paths are correct")
    print("  3. Metadata files are present")

if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print(f"\nEvaluation strategy used: {TEST_VERSION_TO_EVALUATE.upper()}")
print("Next: Cell 4 - Generate confusion matrices")
print("Evaluation completed successfully!")

CASME II MobileNetV3-Small Evaluation Framework
Evaluation Configuration: BOTH
  Strategy: Sequential KFS → AF
  Expected: ~84 samples (KFS) + ~28 samples (AF)

CASME II MobileNetV3 Evaluation Configuration:
  Model: MobileNetV3Small_CASME2_MFS_Baseline
  Methodology: M1
  Input resolution: 640x480 RGB
  Classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']

SEQUENTIAL EVALUATION: KFS → AF
Loading trained CASME II MobileNetV3-Small model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/08_01_mobilenet_casme2_mfs/casme2_mobilenet_mfs_best_f1.pth
Checkpoint loaded using: standard
MobileNetV3-Small feature dimension: 1024
MobileNet CASME II: 1024 -> 512 -> 128 -> 7
Model state loaded with strict=True
Model loaded successfully:
  Best validation F1: 0.2892
  Best validation accuracy: 0.4231
  Best epoch: 6

Versions to evaluate: ['KFS', 'AF']

EVALUATING: KFS

Test Dataset Configuration:
  Version: kfs
  Phase: Pha

Loading test images to RAM: 100%|██████████| 84/84 [00:01<00:00, 53.98it/s]


Test RAM caching completed: 84/84 images, ~0.31GB
Running CASME II MobileNetV3 inference on KFS test set...


CASME II Inference (KFS): 100%|██████████| 6/6 [00:14<00:00,  2.38s/it]


CASME II inference completed: 84 samples in 14.30s
Predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
True classes in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Calculating comprehensive metrics for KFS...
Test set contains labels: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Model predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Macro F1 (available classes): 0.3504
Analyzing wrong predictions for KFS...
Evaluation results saved:
  Main results: casme2_mobilenet_mfs_evaluation_results_kfs.json
  Wrong predictions: casme2_mobilenet_mfs_wrong_predictions_kfs.json

CASME II MOBILENETV3-SMALL KFS EVALUATION RESULTS
Test Dataset: Key-frames (onset, apex, offset) (KFS)
Methodology: M1
Input Resolution: 640x480 RGB

Overall Performance (Macro - Available Classes):
  Accuracy:  0.4881
  Precision: 0.3586
  Recall:    0.3533
  F1 Score:  0.3504
  AUC:  

Loading test images to RAM: 100%|██████████| 28/28 [00:00<00:00, 47.30it/s]


Test RAM caching completed: 28/28 images, ~0.10GB
Running CASME II MobileNetV3 inference on AF test set...


CASME II Inference (AF): 100%|██████████| 2/2 [00:14<00:00,  7.05s/it]

CASME II inference completed: 28 samples in 14.09s
Predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise']
True classes in test: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Calculating comprehensive metrics for AF...
Test set contains labels: ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness']
Model predicted classes: ['others', 'disgust', 'happiness', 'repression', 'surprise']
Macro F1 (available classes): 0.3016
Analyzing wrong predictions for AF...
Evaluation results saved:
  Main results: casme2_mobilenet_mfs_evaluation_results_af.json
  Wrong predictions: casme2_mobilenet_mfs_wrong_predictions_af.json

CASME II MOBILENETV3-SMALL AF EVALUATION RESULTS
Test Dataset: Apex-only frames (AF)
Methodology: M1
Input Resolution: 640x480 RGB

Overall Performance (Macro - Available Classes):
  Accuracy:  0.4643
  Precision: 0.2992
  Recall:    0.3163
  F1 Score:  0.3016
  AUC:       0.6492

Per-Class Performance:
  others




In [4]:
# @title Cell 4: CASME II MobileNetV3-Small Confusion Matrix Generation

# File: 08_01_MobileNet_CASME2_MFS_Cell4.py
# Location: experiments/08_01_MobileNet_CASME2-MFS.ipynb
# Purpose: Generate professional confusion matrix visualizations for KFS and AF test sets

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

print("CASME II MobileNetV3-Small Confusion Matrix Generation")
print("=" * 60)

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/08_01_mobilenet_casme2_mfs"

def find_evaluation_json_files_casme2(results_path):
    json_files = {}

    eval_dir = f"{results_path}/evaluation_results"

    if os.path.exists(eval_dir):
        for version in ['kfs', 'af']:
            eval_pattern = f"{eval_dir}/casme2_mobilenet_mfs_evaluation_results_{version}.json"
            eval_files = glob.glob(eval_pattern)

            if eval_files:
                json_files[f'main_{version}'] = eval_files[0]
                print(f"Found {version.upper()} evaluation file: {os.path.basename(eval_files[0])}")

        if not json_files:
            print(f"WARNING: No evaluation results found in {eval_dir}")
            print("Make sure Cell 3 (evaluation) has been executed first!")
    else:
        print(f"ERROR: Evaluation directory not found: {eval_dir}")

    return json_files

def load_evaluation_results_casme2(json_path):
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        print(f"Successfully loaded: {os.path.basename(json_path)}")
        return data
    except Exception as e:
        print(f"ERROR loading {json_path}: {str(e)}")
        return None

def calculate_weighted_f1_casme2(per_class_performance):
    total_support = sum([class_data['support'] for class_data in per_class_performance.values()
                        if class_data['support'] > 0])

    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0

    for class_name, class_data in per_class_performance.items():
        if class_data['support'] > 0:
            weight = class_data['support'] / total_support
            weighted_f1 += class_data['f1_score'] * weight

    return weighted_f1

def calculate_balanced_accuracy_casme2(confusion_matrix):
    cm = np.array(confusion_matrix)
    n_classes = cm.shape[0]

    per_class_balanced_acc = []

    classes_with_samples = []
    for i in range(n_classes):
        if cm[i, :].sum() > 0:
            classes_with_samples.append(i)

    for i in classes_with_samples:
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        tn = cm.sum() - tp - fn - fp

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        class_balanced_acc = (sensitivity + specificity) / 2
        per_class_balanced_acc.append(class_balanced_acc)

    balanced_acc = np.mean(per_class_balanced_acc) if per_class_balanced_acc else 0.0

    return balanced_acc

def determine_text_color_casme2(color_value, threshold=0.5):
    return 'white' if color_value > threshold else 'black'

def create_confusion_matrix_plot_casme2(data, output_path, test_version):
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']
    test_desc = meta.get('test_description', test_version)
    methodology = meta.get('methodology', 'M1')
    input_res = meta.get('input_resolution', '640x480 RGB')

    print(f"Processing confusion matrix for {test_version.upper()}")
    print(f"Confusion matrix shape: {cm.shape}")

    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1_casme2(per_class)
    balanced_acc = calculate_balanced_accuracy_casme2(cm)

    print(f"Metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Acc: {accuracy:.4f}, Balanced Acc: {balanced_acc:.4f}")

    row_sums = cm.sum(axis=1, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_pct = np.divide(cm, row_sums, where=(row_sums!=0))
        cm_pct = np.nan_to_num(cm_pct)

    fig, ax = plt.subplots(figsize=(12, 10))

    cmap = 'Blues'

    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"

            cell_value = cm_pct[i, j]
            text_color = determine_text_color_casme2(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    missing_classes = meta.get('missing_classes', [])
    note_text = f"Test: {test_desc} ({test_version.upper()})\n{methodology} | {input_res}"
    if missing_classes:
        note_text += f"\nMissing: {', '.join(missing_classes)}"

    ax.text(0.02, 0.98, note_text, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    title = f"CASME II MobileNetV3-Small Micro-Expression Recognition - {test_version.upper()}\n"
    title += f"Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Acc: {accuracy:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes,
        'test_version': test_version
    }

json_files = find_evaluation_json_files_casme2(RESULTS_ROOT)

if not json_files:
    print(f"ERROR: No evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (evaluation) has been executed first!")
else:
    print(f"\nFound {len([k for k in json_files.keys() if k.startswith('main_')])} evaluation result(s)")

output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

results_summary = {}
generated_files = []

for version in ['kfs', 'af']:
    main_key = f'main_{version}'

    if main_key in json_files:
        print(f"\n{'='*60}")
        print(f"Processing {version.upper()} Confusion Matrix")
        print(f"{'='*60}")

        eval_data = load_evaluation_results_casme2(json_files[main_key])

        if eval_data is not None:
            try:
                cm_output_path = os.path.join(output_dir, f"confusion_matrix_CASME2_MobileNet_MFS_{version.upper()}.png")
                metrics = create_confusion_matrix_plot_casme2(eval_data, cm_output_path, version)
                generated_files.append(cm_output_path)

                results_summary[version] = metrics

                print(f"SUCCESS: {version.upper()} confusion matrix generated")

            except Exception as e:
                print(f"ERROR: Failed to generate {version.upper()} confusion matrix: {str(e)}")
                import traceback
                traceback.print_exc()
        else:
            print(f"ERROR: Could not load {version.upper()} evaluation data")

if generated_files:
    print(f"\n" + "=" * 60)
    print("CASME II MOBILENETV3-SMALL CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 60)

    print(f"\nGenerated confusion matrix files:")
    for file_path in generated_files:
        filename = os.path.basename(file_path)
        print(f"  {filename}")

    print(f"\nPerformance Summary:")
    for version in ['kfs', 'af']:
        if version in results_summary:
            metrics = results_summary[version]
            print(f"\n{version.upper()}:")
            print(f"  Macro F1:       {metrics['macro_f1']:.4f}")
            print(f"  Weighted F1:    {metrics['weighted_f1']:.4f}")
            print(f"  Accuracy:       {metrics['accuracy']:.4f}")
            print(f"  Balanced Acc:   {metrics['balanced_accuracy']:.4f}")

            if metrics['missing_classes']:
                print(f"  Missing classes: {len(metrics['missing_classes'])}")

    if len(results_summary) == 2:
        print(f"\nComparative Analysis:")
        kfs_f1 = results_summary['kfs']['macro_f1']
        af_f1 = results_summary['af']['macro_f1']
        delta_f1 = kfs_f1 - af_f1

        print(f"  KFS vs AF (Macro F1): {kfs_f1:.4f} vs {af_f1:.4f}")
        print(f"  Delta (KFS - AF): {delta_f1:+.4f}")

        if delta_f1 > 0:
            improvement_pct = (delta_f1 / af_f1) * 100
            print(f"  KFS improves by {improvement_pct:.1f}% over AF")
        else:
            degradation_pct = (abs(delta_f1) / kfs_f1) * 100
            print(f"  KFS degrades by {degradation_pct:.1f}% from AF")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print(f"\nERROR: No confusion matrices were generated")
    print("Please check:")
    print("1. Cell 3 evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

print("\nCell 4 completed - CASME II MobileNetV3-Small confusion matrix analysis generated")

CASME II MobileNetV3-Small Confusion Matrix Generation
Found KFS evaluation file: casme2_mobilenet_mfs_evaluation_results_kfs.json
Found AF evaluation file: casme2_mobilenet_mfs_evaluation_results_af.json

Found 2 evaluation result(s)

Processing KFS Confusion Matrix
Successfully loaded: casme2_mobilenet_mfs_evaluation_results_kfs.json
Processing confusion matrix for KFS
Confusion matrix shape: (7, 7)
Metrics - Macro F1: 0.3504, Weighted F1: 0.4703, Acc: 0.4881, Balanced Acc: 0.6204
Confusion matrix saved to: confusion_matrix_CASME2_MobileNet_MFS_KFS.png
SUCCESS: KFS confusion matrix generated

Processing AF Confusion Matrix
Successfully loaded: casme2_mobilenet_mfs_evaluation_results_af.json
Processing confusion matrix for AF
Confusion matrix shape: (7, 7)
Metrics - Macro F1: 0.3016, Weighted F1: 0.4297, Acc: 0.4643, Balanced Acc: 0.5981
Confusion matrix saved to: confusion_matrix_CASME2_MobileNet_MFS_AF.png
SUCCESS: AF confusion matrix generated

CASME II MOBILENETV3-SMALL CONFUSION 