In [1]:
# @title Cell 1: CASME II MobileNetV3-Small M2 MFS-PREP Infrastructure Configuration

# File: 09_01_MobileNet_CASME2_MFS_PREP_Cell1.py
# Location: experiments/09_01_MobileNet_CASME2-MFS-PREP.ipynb
# Purpose: MobileNetV3-Small for CASME II micro-expression recognition with M2 preprocessed methodology

from google.colab import drive
print("=" * 60)
print("CASME II CNN BASELINE - MobileNetV3-Small M2 MFS-PREP")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import timm
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/preprocessed_v9"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/09_01_mobilenet_casme2_mfs_prep"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/09_01_mobilenet_casme2_mfs_prep"

PREPROCESSING_SUMMARY = f"{DATASET_ROOT}/preprocessing_summary.json"

print("CASME II MobileNetV3-Small M2 MFS-PREP - Infrastructure Configuration")
print("=" * 60)

if not os.path.exists(PREPROCESSING_SUMMARY):
    raise FileNotFoundError(f"v9 preprocessing summary not found: {PREPROCESSING_SUMMARY}")

print("Loading CASME II v9 preprocessing metadata...")
with open(PREPROCESSING_SUMMARY, 'r') as f:
    preprocessing_info = json.load(f)

print(f"Dataset variant: {preprocessing_info['variant']}")
print(f"Processing date: {preprocessing_info['processing_date']}")
print(f"Preprocessing method: {preprocessing_info['preprocessing_method']}")
print(f"Total images processed: {preprocessing_info['total_processed']}")
print(f"Face detection rate: {preprocessing_info['face_detection_stats']['detection_rate']:.2%}")

preproc_params = preprocessing_info['preprocessing_parameters']
print(f"Target size: {preproc_params['target_size']}x{preproc_params['target_size']}px")
print(f"BBox expansion: {preproc_params['bbox_expansion']}px")
print(f"Image format: Grayscale (1 channel)")

print(f"\nDataset split information:")
print(f"  Train: {preprocessing_info['splits']['train']['total_images']} frames")
print(f"  Validation: {preprocessing_info['splits']['val']['total_images']} frames")
print(f"  Test: {preprocessing_info['splits']['test']['total_images']} frames")

USE_FOCAL_LOSS = True
FOCAL_LOSS_GAMMA = 2.5

CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.25, 1.76, 1.91, 1.99, 3.76, 7.04]
FOCAL_LOSS_ALPHA_WEIGHTS = [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]

MOBILENET_MODEL_NAME = 'mobilenetv3_small_100'
USE_PURE_GRAYSCALE = True

print("\n" + "=" * 50)
print("EXPERIMENT CONFIGURATION - CNN M2 MFS-PREP")
print("=" * 50)
print(f"Model: MobileNetV3-Small (TIMM)")
print(f"Methodology: M2 (Face-Aware Preprocessing)")
print(f"Input Resolution: 224x224 Pure Grayscale (1 channel)")
print(f"Training Strategy: From Scratch (No Pretrained Weights)")
print(f"Preprocessing: Face detection + crop + grayscale")
print(f"Loss Function: Focal Loss")
print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
print(f"  Alpha Weights: {FOCAL_LOSS_ALPHA_WEIGHTS}")
print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
print("=" * 50)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

if 'A100' in gpu_name:
    BATCH_SIZE = 24
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("A100: Optimized batch size for 224x224 input")
elif 'L4' in gpu_name:
    BATCH_SIZE = 16
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("L4: Balanced performance configuration")
else:
    BATCH_SIZE = 8
    NUM_WORKERS = 8
    print("Default GPU: Conservative settings")

RAM_PRELOAD_WORKERS = 32
print(f"RAM preload workers: {RAM_PRELOAD_WORKERS}")

CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

print("\nLoading v9 class distribution...")
train_dist = preprocessing_info['splits']['train']['emotion_distribution']
val_dist = preprocessing_info['splits']['val']['emotion_distribution']
test_dist = preprocessing_info['splits']['test']['emotion_distribution']

def emotion_dist_to_list(emotion_dict, class_names):
    return [emotion_dict.get(cls, 0) for cls in class_names]

train_dist_list = emotion_dist_to_list(train_dist, CASME2_CLASSES)
val_dist_list = emotion_dist_to_list(val_dist, CASME2_CLASSES)
test_dist_list = emotion_dist_to_list(test_dist, CASME2_CLASSES)

print(f"\nv9 Train distribution: {train_dist_list}")
print(f"v9 Val distribution: {val_dist_list}")
print(f"v9 Test distribution: {test_dist_list}")

if USE_FOCAL_LOSS:
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

CASME2_MOBILENET_CONFIG = {
    'model_name': MOBILENET_MODEL_NAME,
    'input_size': (224, 224),
    'num_classes': 7,
    'dropout_rate': 0.3,

    'learning_rate': 5e-5,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'device': device,

    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 5,
    'scheduler_min_lr': 1e-6,
    'scheduler_monitor': 'val_f1_macro',

    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'class_weights': class_weights,

    'use_macro_avg': True,
    'early_stopping': False,
    'save_best_f1': True,
    'save_strategy': 'best_only',

    'dataset_version': 'v9',
    'methodology': 'M2',
    'preprocessing': 'face_aware_preprocessing',
    'frame_strategy': 'multi_frame_sampling',
    'train_augmentation': 'frame_level_independent',
    'image_format': 'pure_grayscale_1channel',
    'use_pure_grayscale': USE_PURE_GRAYSCALE,
    'use_pretrained_weights': False,
    'training_strategy': 'from_scratch',
    'preprocessing_details': {
        'face_detection': True,
        'crop_method': 'face_bbox_expansion',
        'target_size': preproc_params['target_size'],
        'bbox_expansion': preproc_params['bbox_expansion'],
        'grayscale_conversion': True,
        'input_channels': 1
    }
}

print(f"\nMobileNetV3 Configuration Summary:")
print(f"  Model: {CASME2_MOBILENET_CONFIG['model_name']}")
print(f"  Input size: {CASME2_MOBILENET_CONFIG['input_size'][0]}x{CASME2_MOBILENET_CONFIG['input_size'][1]} Pure Grayscale (1ch)")
print(f"  Methodology: {CASME2_MOBILENET_CONFIG['methodology']} (face-aware preprocessing)")
print(f"  Training: From Scratch (No Pretrained Weights)")
print(f"  Learning rate: {CASME2_MOBILENET_CONFIG['learning_rate']}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Dataset version: {CASME2_MOBILENET_CONFIG['dataset_version']}")
print(f"  Frame strategy: {CASME2_MOBILENET_CONFIG['frame_strategy']}")

class OptimizedFocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()

        if alpha is not None:
            if isinstance(alpha, list):
                self.alpha = torch.tensor(alpha, dtype=torch.float32)
            else:
                self.alpha = alpha

            alpha_sum = self.alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")
        else:
            self.alpha = None

        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class MobileNetCASME2Baseline(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.2, in_channels=1):
        super(MobileNetCASME2Baseline, self).__init__()

        self.mobilenet = timm.create_model(
            MOBILENET_MODEL_NAME,
            pretrained=False,
            num_classes=0,
            global_pool='avg',
            in_chans=in_channels
        )

        for param in self.mobilenet.parameters():
            param.requires_grad = True

        with torch.no_grad():
            test_input = torch.randn(1, in_channels, 224, 224)
            test_output = self.mobilenet(test_input)
            self.mobilenet_feature_dim = test_output.shape[1]

        print(f"MobileNetV3-Small feature dimension: {self.mobilenet_feature_dim}")
        print(f"Training from scratch with {in_channels}-channel input")

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.mobilenet_feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        )

        self.classifier = nn.Linear(128, num_classes)

        print(f"MobileNet CASME II: {self.mobilenet_feature_dim} -> 512 -> 128 -> {num_classes}")
        print(f"Architecture: Pure grayscale (1ch) from scratch")

    def forward(self, x):
        features = self.mobilenet(x)
        processed_features = self.classifier_layers(features)
        output = self.classifier(processed_features)
        return output

def create_optimizer_scheduler_casme2(model, config):
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )

    if config['scheduler_type'] == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=config['scheduler_mode'],
            factor=config['scheduler_factor'],
            patience=config['scheduler_patience'],
            min_lr=config['scheduler_min_lr']
        )
        print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    else:
        scheduler = None

    return optimizer, scheduler

print("\nSetting up transforms for M2 methodology (224x224 pure grayscale)...")

mobilenet_transform_train = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

mobilenet_transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

print("M2 transforms configured: 224x224 pure grayscale (1 channel) with standardization")

class CASME2Dataset(Dataset):
    def __init__(self, dataset_root, split, transform=None):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []
        self.filenames = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading {split} dataset from {split_path}...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        loaded_count = 0

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                loaded_count += 1

        print(f"Loaded {len(self.images)} samples for {split} split")
        self._print_distribution()

    def _print_distribution(self):
        if len(self.labels) == 0:
            print("No samples to display distribution")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path)

        if image.mode != 'L':
            image = image.convert('L')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

TRAIN_PATH = DATASET_ROOT
VAL_PATH = DATASET_ROOT
TEST_PATH = DATASET_ROOT

print(f"\nDataset paths:")
print(f"Root: {DATASET_ROOT}")
print(f"  Train: {TRAIN_PATH}/train")
print(f"  Validation: {VAL_PATH}/val")
print(f"  Test: {TEST_PATH}/test")

print("\nMobileNetV3 CASME II architecture validation...")

try:
    test_model = MobileNetCASME2Baseline(num_classes=7, dropout_rate=0.2, in_channels=1).to(device)
    test_input = torch.randn(1, 1, 224, 224).to(device)
    test_output = test_model(test_input)

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"Expected output shape: [1, 7] for CASME II 7 classes")
    print(f"Pure grayscale (1 channel) architecture validated")

    del test_model, test_input, test_output
    torch.cuda.empty_cache()

except Exception as e:
    print(f"Validation failed: {e}")

def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        if alpha_weights:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {sum(alpha_weights):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)
    else:
        print(f"Using CrossEntropy Loss with optimized class weights")
        print(f"Class weights: {weights.cpu().numpy()}")
        return nn.CrossEntropyLoss(weight=weights)

GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    'transform_train': mobilenet_transform_train,
    'transform_val': mobilenet_transform_val,
    'mobilenet_config': CASME2_MOBILENET_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'dataset_root': DATASET_ROOT,
    'train_path': TRAIN_PATH,
    'val_path': VAL_PATH,
    'test_path': TEST_PATH,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

print("\n" + "=" * 60)
print("CASME II MOBILENETV3-SMALL M2 MFS-PREP CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: MobileNetV3-Small")
print(f"  Parameters: ~2.5M")
print(f"  Input Resolution: 224x224 Pure Grayscale (1 channel)")
print(f"  Methodology: M2 (Face-aware preprocessing)")
print(f"  Training Strategy: From Scratch (No Pretrained Weights)")
print(f"  First Conv Layer: Modified to 1-channel input")

print(f"\nDataset Configuration:")
print(f"  Version: {CASME2_MOBILENET_CONFIG['dataset_version']}")
print(f"  Frame strategy: {CASME2_MOBILENET_CONFIG['frame_strategy']}")
print(f"  Train augmentation: {CASME2_MOBILENET_CONFIG['train_augmentation']}")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Train samples: {preprocessing_info['splits']['train']['total_images']} frames")

print("\nNext: Cell 2 - Dataset Loading and Training Pipeline")

CASME II CNN BASELINE - MobileNetV3-Small M2 MFS-PREP

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II MobileNetV3-Small M2 MFS-PREP - Infrastructure Configuration
Loading CASME II v9 preprocessing metadata...
Dataset variant: MFS
Processing date: 2025-10-19T08:20:12.098301
Preprocessing method: face_bbox_expansion_all_directions
Total images processed: 2774
Face detection rate: 100.00%
Target size: 224x224px
BBox expansion: 20px
Image format: Grayscale (1 channel)

Dataset split information:
  Train: 2613 frames
  Validation: 78 frames
  Test: 83 frames

EXPERIMENT CONFIGURATION - CNN M2 MFS-PREP
Model: MobileNetV3-Small (TIMM)
Methodology: M2 (Face-Aware Preprocessing)
Input Resolution: 224x224 Pure Grayscale (1 channel)
Training Strategy: From Scratch (No Pretrained Weights)
Preprocessing: Face detection + crop + grayscale
Loss Function: Focal Loss
  Gamma: 2.5
  Alpha Weights: [0.053, 0.067, 0.09

In [2]:
# @title Cell 2: CASME II MobileNetV3-Small M2 MFS-PREP Training Pipeline

# File: 09_01_MobileNet_CASME2_MFS_PREP_Cell2.py
# Location: experiments/09_01_MobileNet_CASME2-MFS-PREP.ipynb
# Purpose: Training pipeline for pure grayscale MobileNetV3-Small from scratch

import os
import time
import json
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from concurrent.futures import ThreadPoolExecutor
import multiprocessing as mp
import shutil
import tempfile

print("CASME II MobileNetV3-Small M2 MFS-PREP Training Pipeline")
print("=" * 70)
print(f"Model: MobileNetV3-Small")
print(f"Methodology: M2 (Face-aware preprocessing)")
print(f"Input: 224x224 Pure Grayscale (1 channel)")
print(f"Training: From Scratch (No Pretrained Weights)")
print(f"Loss Function: Focal Loss")
print(f"  Gamma: {CASME2_MOBILENET_CONFIG['focal_loss_gamma']}")
print(f"  Per-class Alpha: {CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights']}")
print(f"  Alpha Sum: {sum(CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights']):.3f}")
print(f"Dataset Version: {CASME2_MOBILENET_CONFIG['dataset_version']}")
print(f"Frame Strategy: {CASME2_MOBILENET_CONFIG['frame_strategy']}")
print(f"Training epochs: {CASME2_MOBILENET_CONFIG['num_epochs']}")
print(f"Scheduler patience: {CASME2_MOBILENET_CONFIG['scheduler_patience']}")

class CASME2DatasetTraining(Dataset):
    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.filenames = []
        self.cached_images = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading CASME II {split} dataset for training...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)

        print(f"Loaded {len(self.images)} CASME II {split} samples")
        self._print_distribution()

        if self.use_ram_cache:
            self._preload_to_ram()

    def _print_distribution(self):
        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def _preload_to_ram(self):
        print(f"Preloading {len(self.images)} {self.split} images to RAM with {RAM_PRELOAD_WORKERS} workers...")

        self.cached_images = [None] * len(self.images)
        valid_images = 0

        def load_single_image(idx, img_path):
            try:
                image = Image.open(img_path)
                if image.mode != 'L':
                    image = image.convert('L')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
                return idx, image, True
            except Exception as e:
                return idx, Image.new('L', (224, 224), 128), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc=f"Loading {self.split} to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image
                if success:
                    valid_images += 1

        ram_usage_gb = len(self.cached_images) * 224 * 224 * 1 * 4 / 1e9
        print(f"{self.split.upper()} RAM caching completed: {valid_images}/{len(self.images)} images, ~{ram_usage_gb:.2f}GB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx])
                if image.mode != 'L':
                    image = image.convert('L')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('L', (224, 224), 128)

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

def calculate_metrics_safe_robust(outputs, labels, class_names, average='macro'):
    try:
        if outputs.size(0) != labels.size(0):
            raise ValueError(f"Batch size mismatch: outputs {outputs.size(0)} vs labels {labels.size(0)}")

        if isinstance(outputs, torch.Tensor):
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        else:
            predictions = np.array(outputs)

        if isinstance(labels, torch.Tensor):
            labels = labels.cpu().numpy()
        else:
            labels = np.array(labels)

        accuracy = accuracy_score(labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            labels, predictions,
            average=average,
            zero_division=0,
            labels=list(range(len(class_names)))
        )

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1)
        }
    except Exception as e:
        print(f"Warning: Metrics calculation error: {e}")
        return {
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1_score': 0.0
        }

def train_epoch(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    model.train()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    progress_bar = tqdm(dataloader, desc=f"CASME II Training Epoch {epoch+1}/{total_epochs}")

    for batch_idx, (images, labels, filenames) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        model_output = model(images)

        if isinstance(model_output, (tuple, list)):
            outputs = model_output[0]
        elif isinstance(model_output, dict):
            outputs = model_output.get('logits', model_output.get('prediction', model_output))
        else:
            outputs = model_output

        if outputs.dim() != 2 or outputs.size(1) != 7:
            raise ValueError(f"Invalid CASME II output shape: {outputs.shape}, expected [batch_size, 7]")

        loss = criterion(outputs, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), CASME2_MOBILENET_CONFIG['gradient_clip'])

        optimizer.step()
        running_loss += loss.item()

        all_outputs.append(outputs.detach().cpu())
        all_labels.append(labels.detach().cpu())

        if batch_idx % 5 == 0:
            avg_loss = running_loss / (batch_idx + 1)
            current_lr = optimizer.param_groups[0]['lr']
            progress_bar.set_postfix({
                'Loss': f'{avg_loss:.4f}',
                'LR': f'{current_lr:.2e}'
            })

    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Training metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics

def validate_epoch(model, dataloader, criterion, device, epoch, total_epochs):
    model.eval()
    running_loss = 0.0
    all_outputs = []
    all_labels = []
    all_filenames = []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc=f"CASME II Validation Epoch {epoch+1}/{total_epochs}")

        for batch_idx, (images, labels, filenames) in enumerate(progress_bar):
            images, labels = images.to(device), labels.to(device)

            model_output = model(images)

            if isinstance(model_output, (tuple, list)):
                outputs = model_output[0]
            elif isinstance(model_output, dict):
                outputs = model_output.get('logits', model_output.get('prediction', model_output))
            else:
                outputs = model_output

            loss = criterion(outputs, labels)
            running_loss += loss.item()

            all_outputs.append(outputs.detach().cpu())
            all_labels.append(labels.detach().cpu())
            all_filenames.extend(filenames)

            if batch_idx % 3 == 0:
                avg_loss = running_loss / (batch_idx + 1)
                progress_bar.set_postfix({'Val Loss': f'{avg_loss:.4f}'})

    try:
        epoch_outputs = torch.cat(all_outputs, dim=0)
        epoch_labels = torch.cat(all_labels, dim=0)
        metrics = calculate_metrics_safe_robust(epoch_outputs, epoch_labels, CASME2_CLASSES, average='macro')
    except Exception as e:
        print(f"Warning: Validation metrics calculation failed: {e}")
        metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}

    avg_loss = running_loss / len(dataloader)
    return avg_loss, metrics, all_filenames

def save_checkpoint_robust(model, optimizer, scheduler, epoch, train_metrics, val_metrics,
                         checkpoint_dir, best_metrics, config, max_retries=3):
    def make_serializable_cpu(obj):
        if isinstance(obj, torch.Tensor):
            cpu_obj = obj.detach().cpu()
            return cpu_obj.item() if cpu_obj.numel() == 1 else cpu_obj.tolist()
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, dict):
            return {k: make_serializable_cpu(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_serializable_cpu(item) for item in obj]
        else:
            return obj

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'train_metrics': make_serializable_cpu(train_metrics),
        'val_metrics': make_serializable_cpu(val_metrics),
        'casme2_config': make_serializable_cpu(config),
        'best_f1': float(best_metrics['f1']),
        'best_loss': float(best_metrics['loss']),
        'best_acc': float(best_metrics['accuracy']),
        'class_names': CASME2_CLASSES,
        'num_classes': 7
    }

    final_path = f"{checkpoint_dir}/casme2_mobilenet_mfs_prep_best_f1.pth"

    for attempt in range(max_retries):
        try:
            temp_fd, temp_path = tempfile.mkstemp(dir=checkpoint_dir, suffix='.pth.tmp')
            os.close(temp_fd)

            print(f"Attempt {attempt + 1}: Saving checkpoint to temporary file...")
            torch.save(checkpoint, temp_path)

            print("Validating checkpoint integrity...")
            validation_checkpoint = torch.load(temp_path, map_location='cpu')

            required_keys = ['model_state_dict', 'epoch', 'best_f1', 'num_classes']
            for key in required_keys:
                if key not in validation_checkpoint:
                    raise ValueError(f"Checkpoint validation failed: missing key '{key}'")

            if validation_checkpoint['epoch'] != epoch:
                raise ValueError(f"Checkpoint epoch mismatch: saved {epoch}, loaded {validation_checkpoint['epoch']}")

            print("Checkpoint validation passed")

            print(f"Moving validated checkpoint to final location...")
            shutil.move(temp_path, final_path)

            print(f"Checkpoint saved and validated successfully: {os.path.basename(final_path)}")
            print(f"  Epoch: {epoch + 1}")
            print(f"  Val F1: {best_metrics['f1']:.4f}")
            print(f"  Val Loss: {best_metrics['loss']:.4f}")
            print(f"  Val Acc: {best_metrics['accuracy']:.4f}")

            return final_path

        except Exception as e:
            print(f"Checkpoint save attempt {attempt + 1}/{max_retries} failed: {e}")

            if os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except:
                    pass

            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                print(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                print(f"All {max_retries} checkpoint save attempts failed")
                return None

    return None

def safe_json_serialize(obj):
    if isinstance(obj, torch.Tensor):
        return obj.cpu().item() if obj.numel() == 1 else obj.cpu().numpy().tolist()
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, (np.integer, np.int32, np.int64)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float32, np.float64)):
        return float(obj)
    elif isinstance(obj, dict):
        return {k: safe_json_serialize(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [safe_json_serialize(item) for item in obj]
    elif hasattr(obj, '__dict__'):
        return safe_json_serialize(obj.__dict__)
    else:
        try:
            return float(obj) if isinstance(obj, (int, float)) else str(obj)
        except:
            return str(obj)

print("\nCreating CASME II MobileNetV3-Small M2 training datasets...")

train_dataset = CASME2DatasetTraining(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='train',
    transform=GLOBAL_CONFIG_CASME2['transform_train'],
    use_ram_cache=True
)

val_dataset = CASME2DatasetTraining(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='val',
    transform=GLOBAL_CONFIG_CASME2['transform_val'],
    use_ram_cache=True
)

train_loader = DataLoader(
    train_dataset,
    batch_size=CASME2_MOBILENET_CONFIG['batch_size'],
    shuffle=True,
    num_workers=CASME2_MOBILENET_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CASME2_MOBILENET_CONFIG['batch_size'],
    shuffle=False,
    num_workers=CASME2_MOBILENET_CONFIG['num_workers'],
    pin_memory=True,
    prefetch_factor=2
)

print(f"Training batches: {len(train_loader)} (samples: {len(train_dataset)})")
print(f"Validation batches: {len(val_loader)} (samples: {len(val_dataset)})")

print("\nInitializing CASME II MobileNetV3-Small model...")
model = MobileNetCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_MOBILENET_CONFIG['dropout_rate'],
    in_channels=1
).to(GLOBAL_CONFIG_CASME2['device'])

if CASME2_MOBILENET_CONFIG['use_focal_loss']:
    criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
        weights=GLOBAL_CONFIG_CASME2['class_weights'],
        use_focal_loss=True,
        alpha_weights=CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights'],
        gamma=CASME2_MOBILENET_CONFIG['focal_loss_gamma']
    )
else:
    criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
        weights=GLOBAL_CONFIG_CASME2['class_weights'],
        use_focal_loss=False,
        alpha_weights=None,
        gamma=2.0
    )

optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_MOBILENET_CONFIG
)

print(f"Optimizer: AdamW (LR={CASME2_MOBILENET_CONFIG['learning_rate']})")
print(f"Scheduler: ReduceLROnPlateau (patience={CASME2_MOBILENET_CONFIG['scheduler_patience']})")
print(f"Criterion: Optimized Focal Loss")
print(f"Training: From Scratch (No Pretrained Weights)")

training_history = {
    'train_loss': [],
    'val_loss': [],
    'train_f1': [],
    'val_f1': [],
    'train_acc': [],
    'val_acc': [],
    'learning_rate': [],
    'epoch_time': []
}

best_metrics = {
    'f1': 0.0,
    'loss': float('inf'),
    'accuracy': 0.0,
    'epoch': 0
}

print("\nStarting CASME II MobileNetV3-Small M2 training...")
print(f"Training configuration: {CASME2_MOBILENET_CONFIG['num_epochs']} epochs")
print(f"Input resolution: 224x224 Pure Grayscale (1 channel)")
print(f"Training strategy: From Scratch")
print("=" * 70)

start_time = time.time()

for epoch in range(CASME2_MOBILENET_CONFIG['num_epochs']):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{CASME2_MOBILENET_CONFIG['num_epochs']}")

    train_loss, train_metrics = train_epoch(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_MOBILENET_CONFIG['num_epochs']
    )

    val_loss, val_metrics, val_filenames = validate_epoch(
        model, val_loader, criterion,
        GLOBAL_CONFIG_CASME2['device'], epoch, CASME2_MOBILENET_CONFIG['num_epochs']
    )

    if scheduler:
        scheduler.step(val_metrics['f1_score'])

    epoch_time = time.time() - epoch_start_time
    current_lr = optimizer.param_groups[0]['lr']

    training_history['train_loss'].append(float(train_loss))
    training_history['val_loss'].append(float(val_loss))
    training_history['train_f1'].append(float(train_metrics['f1_score']))
    training_history['val_f1'].append(float(val_metrics['f1_score']))
    training_history['train_acc'].append(float(train_metrics['accuracy']))
    training_history['val_acc'].append(float(val_metrics['accuracy']))
    training_history['learning_rate'].append(float(current_lr))
    training_history['epoch_time'].append(float(epoch_time))

    print(f"Train - Loss: {train_loss:.4f}, F1: {train_metrics['f1_score']:.4f}, Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, F1: {val_metrics['f1_score']:.4f}, Acc: {val_metrics['accuracy']:.4f}")
    print(f"Time  - Epoch: {epoch_time:.1f}s, LR: {current_lr:.2e}")

    save_model = False
    improvement_reason = ""

    if val_metrics['f1_score'] > best_metrics['f1']:
        save_model = True
        improvement_reason = "Higher F1"
    elif val_metrics['f1_score'] == best_metrics['f1']:
        if val_loss < best_metrics['loss']:
            save_model = True
            improvement_reason = "Same F1, Lower Loss"
        elif val_loss == best_metrics['loss'] and val_metrics['accuracy'] > best_metrics['accuracy']:
            save_model = True
            improvement_reason = "Same F1&Loss, Higher Accuracy"

    if save_model:
        best_metrics['f1'] = val_metrics['f1_score']
        best_metrics['loss'] = val_loss
        best_metrics['accuracy'] = val_metrics['accuracy']
        best_metrics['epoch'] = epoch + 1

        best_model_path = save_checkpoint_robust(
            model, optimizer, scheduler, epoch,
            train_metrics, val_metrics, GLOBAL_CONFIG_CASME2['checkpoint_root'],
            best_metrics, CASME2_MOBILENET_CONFIG
        )

        if best_model_path:
            print(f"New best model: {improvement_reason} - F1: {best_metrics['f1']:.4f}")
        else:
            print(f"Warning: Failed to save checkpoint despite improvement")

    elapsed_time = time.time() - start_time
    estimated_total = (elapsed_time / (epoch + 1)) * CASME2_MOBILENET_CONFIG['num_epochs']
    remaining_time = estimated_total - elapsed_time
    progress_pct = ((epoch + 1) / CASME2_MOBILENET_CONFIG['num_epochs']) * 100

    print(f"Progress: {progress_pct:.1f}% | Best F1: {best_metrics['f1']:.4f} | ETA: {remaining_time/60:.1f}min")

total_time = time.time() - start_time
actual_epochs = CASME2_MOBILENET_CONFIG['num_epochs']

print("\n" + "=" * 70)
print("CASME II MOBILENETV3-SMALL M2 MFS-PREP TRAINING COMPLETED")
print("=" * 70)
print(f"Training time: {total_time/60:.1f} minutes")
print(f"Epochs completed: {actual_epochs}")
print(f"Best validation F1: {best_metrics['f1']:.4f} (epoch {best_metrics['epoch']})")
print(f"Final train F1: {training_history['train_f1'][-1]:.4f}")
print(f"Final validation F1: {training_history['val_f1'][-1]:.4f}")

results_dir = GLOBAL_CONFIG_CASME2['results_root']
os.makedirs(f"{results_dir}/training_logs", exist_ok=True)

training_history_path = f"{results_dir}/training_logs/casme2_mobilenet_mfs_prep_training_history.json"

print("\nExporting training documentation...")

try:
    training_summary = {
        'experiment_type': 'CASME2_MobileNetV3Small_MFS_PREP_Baseline',
        'experiment_configuration': {
            'model_architecture': 'MobileNetV3-Small',
            'model_parameters': '2.5M',
            'dataset_version': CASME2_MOBILENET_CONFIG['dataset_version'],
            'methodology': CASME2_MOBILENET_CONFIG['methodology'],
            'preprocessing': CASME2_MOBILENET_CONFIG['preprocessing'],
            'input_resolution': '224x224 Pure Grayscale (1 channel)',
            'training_strategy': CASME2_MOBILENET_CONFIG['training_strategy'],
            'use_pretrained_weights': CASME2_MOBILENET_CONFIG['use_pretrained_weights'],
            'frame_strategy': CASME2_MOBILENET_CONFIG['frame_strategy'],
            'train_augmentation': CASME2_MOBILENET_CONFIG['train_augmentation'],
            'loss_function': 'Optimized Focal Loss',
            'focal_loss_gamma': CASME2_MOBILENET_CONFIG['focal_loss_gamma'],
            'focal_loss_alpha_weights': CASME2_MOBILENET_CONFIG['focal_loss_alpha_weights'],
            'model_name': CASME2_MOBILENET_CONFIG['model_name']
        },
        'training_history': safe_json_serialize(training_history),
        'best_val_f1': float(best_metrics['f1']),
        'best_val_loss': float(best_metrics['loss']),
        'best_val_accuracy': float(best_metrics['accuracy']),
        'best_epoch': int(best_metrics['epoch']),
        'total_epochs': int(actual_epochs),
        'total_time_minutes': float(total_time / 60),
        'average_epoch_time_seconds': float(np.mean(training_history['epoch_time'])),
        'config': safe_json_serialize(CASME2_MOBILENET_CONFIG),
        'final_train_f1': float(training_history['train_f1'][-1]),
        'final_val_f1': float(training_history['val_f1'][-1]),
        'model_checkpoint': 'casme2_mobilenet_mfs_prep_best_f1.pth',
        'dataset_info': {
            'name': 'CASME_II',
            'version': CASME2_MOBILENET_CONFIG['dataset_version'],
            'methodology': CASME2_MOBILENET_CONFIG['methodology'],
            'input_resolution': '224x224 Pure Grayscale (1 channel)',
            'frame_strategy': CASME2_MOBILENET_CONFIG['frame_strategy'],
            'train_augmentation': CASME2_MOBILENET_CONFIG['train_augmentation'],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'num_classes': 7,
            'class_names': CASME2_CLASSES
        },
        'architecture_info': {
            'model_type': 'MobileNetCASME2Baseline',
            'backbone': CASME2_MOBILENET_CONFIG['model_name'],
            'input_size': '224x224 Pure Grayscale (1 channel)',
            'classification_head': '576->512->128->7',
            'input_channels': 1,
            'training_strategy': 'from_scratch'
        }
    }

    with open(training_history_path, 'w') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Training documentation saved successfully: {training_history_path}")
    print(f"Model: {training_summary['experiment_configuration']['model_architecture']}")
    print(f"Methodology: {training_summary['experiment_configuration']['methodology']}")
    print(f"Input resolution: {training_summary['experiment_configuration']['input_resolution']}")
    print(f"Training strategy: {training_summary['experiment_configuration']['training_strategy']}")

except Exception as e:
    print(f"Warning: Could not save training documentation: {e}")
    print("Training completed successfully, but documentation export failed")

if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\nNext: Cell 3 - CASME II MobileNetV3-Small M2 Evaluation")
print("Training pipeline completed successfully!")

CASME II MobileNetV3-Small M2 MFS-PREP Training Pipeline
Model: MobileNetV3-Small
Methodology: M2 (Face-aware preprocessing)
Input: 224x224 Pure Grayscale (1 channel)
Training: From Scratch (No Pretrained Weights)
Loss Function: Focal Loss
  Gamma: 2.5
  Per-class Alpha: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
  Alpha Sum: 0.999
Dataset Version: v9
Frame Strategy: multi_frame_sampling
Training epochs: 50
Scheduler patience: 5

Creating CASME II MobileNetV3-Small M2 training datasets...
Loading CASME II train dataset for training...
Loaded 2613 CASME II train samples
  others: 1027 samples (39.3%)
  disgust: 650 samples (24.9%)
  happiness: 325 samples (12.4%)
  repression: 273 samples (10.4%)
  surprise: 260 samples (10.0%)
  sadness: 65 samples (2.5%)
  fear: 13 samples (0.5%)
Preloading 2613 train images to RAM with 32 workers...


Loading train to RAM: 100%|██████████| 2613/2613 [00:39<00:00, 65.85it/s] 


TRAIN RAM caching completed: 2613/2613 images, ~0.52GB
Loading CASME II val dataset for training...
Loaded 78 CASME II val samples
  others: 30 samples (38.5%)
  disgust: 18 samples (23.1%)
  happiness: 9 samples (11.5%)
  repression: 9 samples (11.5%)
  surprise: 6 samples (7.7%)
  sadness: 3 samples (3.8%)
  fear: 3 samples (3.8%)
Preloading 78 val images to RAM with 32 workers...


Loading val to RAM: 100%|██████████| 78/78 [00:01<00:00, 54.95it/s]


VAL RAM caching completed: 78/78 images, ~0.02GB
Training batches: 164 (samples: 2613)
Validation batches: 5 (samples: 78)

Initializing CASME II MobileNetV3-Small model...
MobileNetV3-Small feature dimension: 1024
Training from scratch with 1-channel input
MobileNet CASME II: 1024 -> 512 -> 128 -> 7
Architecture: Pure grayscale (1ch) from scratch
Using Optimized Focal Loss with gamma=2.5
Per-class alpha weights: [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]
Alpha sum: 0.999
Scheduler: ReduceLROnPlateau monitoring val_f1_macro
Optimizer: AdamW (LR=5e-05)
Scheduler: ReduceLROnPlateau (patience=5)
Criterion: Optimized Focal Loss
Training: From Scratch (No Pretrained Weights)

Starting CASME II MobileNetV3-Small M2 training...
Training configuration: 50 epochs
Input resolution: 224x224 Pure Grayscale (1 channel)
Training strategy: From Scratch

Epoch 1/50


CASME II Training Epoch 1/50: 100%|██████████| 164/164 [00:32<00:00,  5.07it/s, Loss=0.0904, LR=5.00e-05]
CASME II Validation Epoch 1/50: 100%|██████████| 5/5 [00:13<00:00,  2.71s/it, Val Loss=0.1270]


Train - Loss: 0.0900, F1: 0.2384, Acc: 0.2962
Val   - Loss: 0.1224, F1: 0.1967, Acc: 0.2949
Time  - Epoch: 45.9s, LR: 5.00e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_mobilenet_mfs_prep_best_f1.pth
  Epoch: 1
  Val F1: 0.1967
  Val Loss: 0.1224
  Val Acc: 0.2949
New best model: Higher F1 - F1: 0.1967
Progress: 2.0% | Best F1: 0.1967 | ETA: 37.7min

Epoch 2/50


CASME II Training Epoch 2/50: 100%|██████████| 164/164 [00:04<00:00, 33.89it/s, Loss=0.0603, LR=5.00e-05]
CASME II Validation Epoch 2/50: 100%|██████████| 5/5 [00:00<00:00,  9.50it/s, Val Loss=0.1455]


Train - Loss: 0.0601, F1: 0.4515, Acc: 0.5442
Val   - Loss: 0.1396, F1: 0.1248, Acc: 0.1795
Time  - Epoch: 5.4s, LR: 5.00e-05
Progress: 4.0% | Best F1: 0.1967 | ETA: 20.6min

Epoch 3/50


CASME II Training Epoch 3/50: 100%|██████████| 164/164 [00:04<00:00, 35.17it/s, Loss=0.0377, LR=5.00e-05]
CASME II Validation Epoch 3/50: 100%|██████████| 5/5 [00:00<00:00, 10.46it/s, Val Loss=0.1534]


Train - Loss: 0.0377, F1: 0.6628, Acc: 0.7202
Val   - Loss: 0.1453, F1: 0.1435, Acc: 0.2564
Time  - Epoch: 5.2s, LR: 5.00e-05
Progress: 6.0% | Best F1: 0.1967 | ETA: 14.8min

Epoch 4/50


CASME II Training Epoch 4/50: 100%|██████████| 164/164 [00:04<00:00, 34.15it/s, Loss=0.0219, LR=5.00e-05]
CASME II Validation Epoch 4/50: 100%|██████████| 5/5 [00:00<00:00, 10.55it/s, Val Loss=0.1648]


Train - Loss: 0.0219, F1: 0.8186, Acc: 0.8305
Val   - Loss: 0.1558, F1: 0.1251, Acc: 0.1923
Time  - Epoch: 5.3s, LR: 5.00e-05
Progress: 8.0% | Best F1: 0.1967 | ETA: 11.9min

Epoch 5/50


CASME II Training Epoch 5/50: 100%|██████████| 164/164 [00:04<00:00, 34.66it/s, Loss=0.0143, LR=5.00e-05]
CASME II Validation Epoch 5/50: 100%|██████████| 5/5 [00:00<00:00, 10.17it/s, Val Loss=0.1678]


Train - Loss: 0.0143, F1: 0.9051, Acc: 0.8990
Val   - Loss: 0.1602, F1: 0.1484, Acc: 0.2436
Time  - Epoch: 5.2s, LR: 5.00e-05
Progress: 10.0% | Best F1: 0.1967 | ETA: 10.1min

Epoch 6/50


CASME II Training Epoch 6/50: 100%|██████████| 164/164 [00:04<00:00, 34.66it/s, Loss=0.0100, LR=5.00e-05]
CASME II Validation Epoch 6/50: 100%|██████████| 5/5 [00:00<00:00,  9.47it/s, Val Loss=0.1713]


Train - Loss: 0.0100, F1: 0.9251, Acc: 0.9334
Val   - Loss: 0.1646, F1: 0.1369, Acc: 0.2179
Time  - Epoch: 5.3s, LR: 5.00e-05
Progress: 12.0% | Best F1: 0.1967 | ETA: 8.9min

Epoch 7/50


CASME II Training Epoch 7/50: 100%|██████████| 164/164 [00:04<00:00, 34.40it/s, Loss=0.0073, LR=5.00e-05]
CASME II Validation Epoch 7/50: 100%|██████████| 5/5 [00:00<00:00, 10.75it/s, Val Loss=0.1751]


Train - Loss: 0.0075, F1: 0.9503, Acc: 0.9480
Val   - Loss: 0.1693, F1: 0.1562, Acc: 0.2436
Time  - Epoch: 5.2s, LR: 2.50e-05
Progress: 14.0% | Best F1: 0.1967 | ETA: 8.0min

Epoch 8/50


CASME II Training Epoch 8/50: 100%|██████████| 164/164 [00:04<00:00, 34.64it/s, Loss=0.0058, LR=2.50e-05]
CASME II Validation Epoch 8/50: 100%|██████████| 5/5 [00:00<00:00,  9.90it/s, Val Loss=0.1691]


Train - Loss: 0.0059, F1: 0.9587, Acc: 0.9640
Val   - Loss: 0.1655, F1: 0.1896, Acc: 0.2949
Time  - Epoch: 5.3s, LR: 2.50e-05
Progress: 16.0% | Best F1: 0.1967 | ETA: 7.3min

Epoch 9/50


CASME II Training Epoch 9/50: 100%|██████████| 164/164 [00:04<00:00, 34.22it/s, Loss=0.0046, LR=2.50e-05]
CASME II Validation Epoch 9/50: 100%|██████████| 5/5 [00:00<00:00,  9.90it/s, Val Loss=0.1755]


Train - Loss: 0.0046, F1: 0.9760, Acc: 0.9713
Val   - Loss: 0.1704, F1: 0.1867, Acc: 0.2949
Time  - Epoch: 5.3s, LR: 2.50e-05
Progress: 18.0% | Best F1: 0.1967 | ETA: 6.7min

Epoch 10/50


CASME II Training Epoch 10/50: 100%|██████████| 164/164 [00:04<00:00, 34.38it/s, Loss=0.0042, LR=2.50e-05]
CASME II Validation Epoch 10/50: 100%|██████████| 5/5 [00:00<00:00, 10.70it/s, Val Loss=0.1708]


Train - Loss: 0.0047, F1: 0.9664, Acc: 0.9698
Val   - Loss: 0.1657, F1: 0.1906, Acc: 0.3333
Time  - Epoch: 5.3s, LR: 2.50e-05
Progress: 20.0% | Best F1: 0.1967 | ETA: 6.2min

Epoch 11/50


CASME II Training Epoch 11/50: 100%|██████████| 164/164 [00:04<00:00, 33.33it/s, Loss=0.0032, LR=2.50e-05]
CASME II Validation Epoch 11/50: 100%|██████████| 5/5 [00:00<00:00, 10.85it/s, Val Loss=0.1760]


Train - Loss: 0.0032, F1: 0.9831, Acc: 0.9816
Val   - Loss: 0.1721, F1: 0.1592, Acc: 0.2436
Time  - Epoch: 5.4s, LR: 2.50e-05
Progress: 22.0% | Best F1: 0.1967 | ETA: 5.8min

Epoch 12/50


CASME II Training Epoch 12/50: 100%|██████████| 164/164 [00:04<00:00, 34.37it/s, Loss=0.0034, LR=2.50e-05]
CASME II Validation Epoch 12/50: 100%|██████████| 5/5 [00:00<00:00,  9.79it/s, Val Loss=0.1732]


Train - Loss: 0.0034, F1: 0.9839, Acc: 0.9812
Val   - Loss: 0.1699, F1: 0.2209, Acc: 0.3333
Time  - Epoch: 5.3s, LR: 2.50e-05
Attempt 1: Saving checkpoint to temporary file...
Validating checkpoint integrity...
Checkpoint validation passed
Moving validated checkpoint to final location...
Checkpoint saved and validated successfully: casme2_mobilenet_mfs_prep_best_f1.pth
  Epoch: 12
  Val F1: 0.2209
  Val Loss: 0.1699
  Val Acc: 0.3333
New best model: Higher F1 - F1: 0.2209
Progress: 24.0% | Best F1: 0.2209 | ETA: 5.5min

Epoch 13/50


CASME II Training Epoch 13/50: 100%|██████████| 164/164 [00:04<00:00, 34.25it/s, Loss=0.0028, LR=2.50e-05]
CASME II Validation Epoch 13/50: 100%|██████████| 5/5 [00:00<00:00,  9.54it/s, Val Loss=0.1765]


Train - Loss: 0.0028, F1: 0.9854, Acc: 0.9824
Val   - Loss: 0.1710, F1: 0.1629, Acc: 0.2949
Time  - Epoch: 5.3s, LR: 2.50e-05
Progress: 26.0% | Best F1: 0.2209 | ETA: 5.2min

Epoch 14/50


CASME II Training Epoch 14/50: 100%|██████████| 164/164 [00:04<00:00, 35.01it/s, Loss=0.0027, LR=2.50e-05]
CASME II Validation Epoch 14/50: 100%|██████████| 5/5 [00:00<00:00, 10.65it/s, Val Loss=0.1721]


Train - Loss: 0.0028, F1: 0.9828, Acc: 0.9824
Val   - Loss: 0.1683, F1: 0.1435, Acc: 0.2436
Time  - Epoch: 5.2s, LR: 2.50e-05
Progress: 28.0% | Best F1: 0.2209 | ETA: 4.9min

Epoch 15/50


CASME II Training Epoch 15/50: 100%|██████████| 164/164 [00:04<00:00, 34.52it/s, Loss=0.0025, LR=2.50e-05]
CASME II Validation Epoch 15/50: 100%|██████████| 5/5 [00:00<00:00,  9.81it/s, Val Loss=0.1761]


Train - Loss: 0.0026, F1: 0.9847, Acc: 0.9843
Val   - Loss: 0.1724, F1: 0.1917, Acc: 0.3333
Time  - Epoch: 5.3s, LR: 2.50e-05
Progress: 30.0% | Best F1: 0.2209 | ETA: 4.7min

Epoch 16/50


CASME II Training Epoch 16/50: 100%|██████████| 164/164 [00:04<00:00, 33.48it/s, Loss=0.0022, LR=2.50e-05]
CASME II Validation Epoch 16/50: 100%|██████████| 5/5 [00:00<00:00,  9.54it/s, Val Loss=0.1799]


Train - Loss: 0.0022, F1: 0.9910, Acc: 0.9878
Val   - Loss: 0.1754, F1: 0.2140, Acc: 0.3462
Time  - Epoch: 5.4s, LR: 2.50e-05
Progress: 32.0% | Best F1: 0.2209 | ETA: 4.5min

Epoch 17/50


CASME II Training Epoch 17/50: 100%|██████████| 164/164 [00:04<00:00, 35.59it/s, Loss=0.0023, LR=2.50e-05]
CASME II Validation Epoch 17/50: 100%|██████████| 5/5 [00:00<00:00,  9.78it/s, Val Loss=0.1883]


Train - Loss: 0.0023, F1: 0.9852, Acc: 0.9812
Val   - Loss: 0.1821, F1: 0.1866, Acc: 0.3077
Time  - Epoch: 5.1s, LR: 2.50e-05
Progress: 34.0% | Best F1: 0.2209 | ETA: 4.2min

Epoch 18/50


CASME II Training Epoch 18/50: 100%|██████████| 164/164 [00:04<00:00, 33.92it/s, Loss=0.0018, LR=2.50e-05]
CASME II Validation Epoch 18/50: 100%|██████████| 5/5 [00:00<00:00,  9.61it/s, Val Loss=0.1822]


Train - Loss: 0.0017, F1: 0.9907, Acc: 0.9881
Val   - Loss: 0.1773, F1: 0.1862, Acc: 0.3333
Time  - Epoch: 5.4s, LR: 1.25e-05
Progress: 36.0% | Best F1: 0.2209 | ETA: 4.0min

Epoch 19/50


CASME II Training Epoch 19/50: 100%|██████████| 164/164 [00:04<00:00, 35.83it/s, Loss=0.0013, LR=1.25e-05]
CASME II Validation Epoch 19/50: 100%|██████████| 5/5 [00:00<00:00, 10.73it/s, Val Loss=0.1871]


Train - Loss: 0.0013, F1: 0.9965, Acc: 0.9954
Val   - Loss: 0.1807, F1: 0.1934, Acc: 0.3333
Time  - Epoch: 5.1s, LR: 1.25e-05
Progress: 38.0% | Best F1: 0.2209 | ETA: 3.8min

Epoch 20/50


CASME II Training Epoch 20/50: 100%|██████████| 164/164 [00:04<00:00, 35.35it/s, Loss=0.0016, LR=1.25e-05]
CASME II Validation Epoch 20/50: 100%|██████████| 5/5 [00:00<00:00,  9.50it/s, Val Loss=0.1830]


Train - Loss: 0.0017, F1: 0.9911, Acc: 0.9897
Val   - Loss: 0.1811, F1: 0.1807, Acc: 0.3333
Time  - Epoch: 5.2s, LR: 1.25e-05
Progress: 40.0% | Best F1: 0.2209 | ETA: 3.7min

Epoch 21/50


CASME II Training Epoch 21/50: 100%|██████████| 164/164 [00:04<00:00, 34.31it/s, Loss=0.0013, LR=1.25e-05]
CASME II Validation Epoch 21/50: 100%|██████████| 5/5 [00:00<00:00, 10.66it/s, Val Loss=0.1868]


Train - Loss: 0.0013, F1: 0.9948, Acc: 0.9935
Val   - Loss: 0.1817, F1: 0.1625, Acc: 0.2949
Time  - Epoch: 5.3s, LR: 1.25e-05
Progress: 42.0% | Best F1: 0.2209 | ETA: 3.5min

Epoch 22/50


CASME II Training Epoch 22/50: 100%|██████████| 164/164 [00:04<00:00, 36.26it/s, Loss=0.0014, LR=1.25e-05]
CASME II Validation Epoch 22/50: 100%|██████████| 5/5 [00:00<00:00,  9.97it/s, Val Loss=0.1865]


Train - Loss: 0.0016, F1: 0.9936, Acc: 0.9935
Val   - Loss: 0.1815, F1: 0.1775, Acc: 0.3205
Time  - Epoch: 5.0s, LR: 1.25e-05
Progress: 44.0% | Best F1: 0.2209 | ETA: 3.3min

Epoch 23/50


CASME II Training Epoch 23/50: 100%|██████████| 164/164 [00:04<00:00, 34.19it/s, Loss=0.0014, LR=1.25e-05]
CASME II Validation Epoch 23/50: 100%|██████████| 5/5 [00:00<00:00, 10.68it/s, Val Loss=0.1830]


Train - Loss: 0.0014, F1: 0.9927, Acc: 0.9927
Val   - Loss: 0.1784, F1: 0.1823, Acc: 0.3333
Time  - Epoch: 5.3s, LR: 1.25e-05
Progress: 46.0% | Best F1: 0.2209 | ETA: 3.2min

Epoch 24/50


CASME II Training Epoch 24/50: 100%|██████████| 164/164 [00:04<00:00, 35.96it/s, Loss=0.0012, LR=1.25e-05]
CASME II Validation Epoch 24/50: 100%|██████████| 5/5 [00:00<00:00,  9.44it/s, Val Loss=0.1810]


Train - Loss: 0.0012, F1: 0.9954, Acc: 0.9943
Val   - Loss: 0.1790, F1: 0.1909, Acc: 0.3077
Time  - Epoch: 5.1s, LR: 6.25e-06
Progress: 48.0% | Best F1: 0.2209 | ETA: 3.0min

Epoch 25/50


CASME II Training Epoch 25/50: 100%|██████████| 164/164 [00:04<00:00, 35.23it/s, Loss=0.0012, LR=6.25e-06]
CASME II Validation Epoch 25/50: 100%|██████████| 5/5 [00:00<00:00,  9.89it/s, Val Loss=0.1814]


Train - Loss: 0.0012, F1: 0.9956, Acc: 0.9939
Val   - Loss: 0.1765, F1: 0.1770, Acc: 0.3205
Time  - Epoch: 5.2s, LR: 6.25e-06
Progress: 50.0% | Best F1: 0.2209 | ETA: 2.9min

Epoch 26/50


CASME II Training Epoch 26/50: 100%|██████████| 164/164 [00:04<00:00, 35.36it/s, Loss=0.0011, LR=6.25e-06]
CASME II Validation Epoch 26/50: 100%|██████████| 5/5 [00:00<00:00,  9.57it/s, Val Loss=0.1916]


Train - Loss: 0.0011, F1: 0.9946, Acc: 0.9943
Val   - Loss: 0.1856, F1: 0.1547, Acc: 0.2821
Time  - Epoch: 5.2s, LR: 6.25e-06
Progress: 52.0% | Best F1: 0.2209 | ETA: 2.7min

Epoch 27/50


CASME II Training Epoch 27/50: 100%|██████████| 164/164 [00:05<00:00, 32.14it/s, Loss=0.0011, LR=6.25e-06]
CASME II Validation Epoch 27/50: 100%|██████████| 5/5 [00:00<00:00, 10.54it/s, Val Loss=0.1876]


Train - Loss: 0.0012, F1: 0.9963, Acc: 0.9950
Val   - Loss: 0.1851, F1: 0.2142, Acc: 0.3333
Time  - Epoch: 5.6s, LR: 6.25e-06
Progress: 54.0% | Best F1: 0.2209 | ETA: 2.6min

Epoch 28/50


CASME II Training Epoch 28/50: 100%|██████████| 164/164 [00:04<00:00, 34.87it/s, Loss=0.0012, LR=6.25e-06]
CASME II Validation Epoch 28/50: 100%|██████████| 5/5 [00:00<00:00, 10.69it/s, Val Loss=0.1848]


Train - Loss: 0.0012, F1: 0.9937, Acc: 0.9916
Val   - Loss: 0.1793, F1: 0.2108, Acc: 0.3462
Time  - Epoch: 5.2s, LR: 6.25e-06
Progress: 56.0% | Best F1: 0.2209 | ETA: 2.5min

Epoch 29/50


CASME II Training Epoch 29/50: 100%|██████████| 164/164 [00:04<00:00, 35.42it/s, Loss=0.0012, LR=6.25e-06]
CASME II Validation Epoch 29/50: 100%|██████████| 5/5 [00:00<00:00, 10.02it/s, Val Loss=0.1917]


Train - Loss: 0.0012, F1: 0.9945, Acc: 0.9931
Val   - Loss: 0.1848, F1: 0.1759, Acc: 0.3205
Time  - Epoch: 5.1s, LR: 6.25e-06
Progress: 58.0% | Best F1: 0.2209 | ETA: 2.3min

Epoch 30/50


CASME II Training Epoch 30/50: 100%|██████████| 164/164 [00:04<00:00, 33.47it/s, Loss=0.0010, LR=6.25e-06]
CASME II Validation Epoch 30/50: 100%|██████████| 5/5 [00:00<00:00, 10.84it/s, Val Loss=0.1951]


Train - Loss: 0.0010, F1: 0.9961, Acc: 0.9946
Val   - Loss: 0.1880, F1: 0.1496, Acc: 0.3077
Time  - Epoch: 5.4s, LR: 3.13e-06
Progress: 60.0% | Best F1: 0.2209 | ETA: 2.2min

Epoch 31/50


CASME II Training Epoch 31/50: 100%|██████████| 164/164 [00:04<00:00, 35.34it/s, Loss=0.0012, LR=3.13e-06]
CASME II Validation Epoch 31/50: 100%|██████████| 5/5 [00:00<00:00,  9.73it/s, Val Loss=0.1858]


Train - Loss: 0.0013, F1: 0.9927, Acc: 0.9916
Val   - Loss: 0.1795, F1: 0.1711, Acc: 0.2949
Time  - Epoch: 5.2s, LR: 3.13e-06
Progress: 62.0% | Best F1: 0.2209 | ETA: 2.1min

Epoch 32/50


CASME II Training Epoch 32/50: 100%|██████████| 164/164 [00:04<00:00, 36.01it/s, Loss=0.0008, LR=3.13e-06]
CASME II Validation Epoch 32/50: 100%|██████████| 5/5 [00:00<00:00,  9.04it/s, Val Loss=0.1866]


Train - Loss: 0.0009, F1: 0.9972, Acc: 0.9962
Val   - Loss: 0.1818, F1: 0.2116, Acc: 0.3462
Time  - Epoch: 5.1s, LR: 3.13e-06
Progress: 64.0% | Best F1: 0.2209 | ETA: 2.0min

Epoch 33/50


CASME II Training Epoch 33/50: 100%|██████████| 164/164 [00:04<00:00, 35.72it/s, Loss=0.0012, LR=3.13e-06]
CASME II Validation Epoch 33/50: 100%|██████████| 5/5 [00:00<00:00,  9.90it/s, Val Loss=0.1909]


Train - Loss: 0.0012, F1: 0.9937, Acc: 0.9923
Val   - Loss: 0.1827, F1: 0.1651, Acc: 0.3077
Time  - Epoch: 5.1s, LR: 3.13e-06
Progress: 66.0% | Best F1: 0.2209 | ETA: 1.8min

Epoch 34/50


CASME II Training Epoch 34/50: 100%|██████████| 164/164 [00:04<00:00, 33.83it/s, Loss=0.0010, LR=3.13e-06]
CASME II Validation Epoch 34/50: 100%|██████████| 5/5 [00:00<00:00, 10.65it/s, Val Loss=0.1931]


Train - Loss: 0.0011, F1: 0.9950, Acc: 0.9935
Val   - Loss: 0.1866, F1: 0.1619, Acc: 0.3333
Time  - Epoch: 5.3s, LR: 3.13e-06
Progress: 68.0% | Best F1: 0.2209 | ETA: 1.7min

Epoch 35/50


CASME II Training Epoch 35/50: 100%|██████████| 164/164 [00:04<00:00, 34.12it/s, Loss=0.0007, LR=3.13e-06]
CASME II Validation Epoch 35/50: 100%|██████████| 5/5 [00:00<00:00,  9.72it/s, Val Loss=0.1960]


Train - Loss: 0.0008, F1: 0.9968, Acc: 0.9958
Val   - Loss: 0.1884, F1: 0.1850, Acc: 0.3333
Time  - Epoch: 5.3s, LR: 3.13e-06
Progress: 70.0% | Best F1: 0.2209 | ETA: 1.6min

Epoch 36/50


CASME II Training Epoch 36/50: 100%|██████████| 164/164 [00:04<00:00, 35.62it/s, Loss=0.0009, LR=3.13e-06]
CASME II Validation Epoch 36/50: 100%|██████████| 5/5 [00:00<00:00, 10.83it/s, Val Loss=0.1957]


Train - Loss: 0.0010, F1: 0.9957, Acc: 0.9939
Val   - Loss: 0.1879, F1: 0.1738, Acc: 0.3205
Time  - Epoch: 5.1s, LR: 1.56e-06
Progress: 72.0% | Best F1: 0.2209 | ETA: 1.5min

Epoch 37/50


CASME II Training Epoch 37/50: 100%|██████████| 164/164 [00:04<00:00, 35.23it/s, Loss=0.0009, LR=1.56e-06]
CASME II Validation Epoch 37/50: 100%|██████████| 5/5 [00:00<00:00,  9.20it/s, Val Loss=0.1901]


Train - Loss: 0.0009, F1: 0.9961, Acc: 0.9943
Val   - Loss: 0.1824, F1: 0.1913, Acc: 0.3333
Time  - Epoch: 5.2s, LR: 1.56e-06
Progress: 74.0% | Best F1: 0.2209 | ETA: 1.4min

Epoch 38/50


CASME II Training Epoch 38/50: 100%|██████████| 164/164 [00:04<00:00, 35.61it/s, Loss=0.0009, LR=1.56e-06]
CASME II Validation Epoch 38/50: 100%|██████████| 5/5 [00:00<00:00,  9.53it/s, Val Loss=0.1898]


Train - Loss: 0.0009, F1: 0.9955, Acc: 0.9939
Val   - Loss: 0.1853, F1: 0.2130, Acc: 0.3333
Time  - Epoch: 5.1s, LR: 1.56e-06
Progress: 76.0% | Best F1: 0.2209 | ETA: 1.3min

Epoch 39/50


CASME II Training Epoch 39/50: 100%|██████████| 164/164 [00:04<00:00, 34.93it/s, Loss=0.0010, LR=1.56e-06]
CASME II Validation Epoch 39/50: 100%|██████████| 5/5 [00:00<00:00,  9.57it/s, Val Loss=0.1910]


Train - Loss: 0.0010, F1: 0.9961, Acc: 0.9946
Val   - Loss: 0.1861, F1: 0.1990, Acc: 0.3205
Time  - Epoch: 5.2s, LR: 1.56e-06
Progress: 78.0% | Best F1: 0.2209 | ETA: 1.2min

Epoch 40/50


CASME II Training Epoch 40/50: 100%|██████████| 164/164 [00:04<00:00, 33.92it/s, Loss=0.0010, LR=1.56e-06]
CASME II Validation Epoch 40/50: 100%|██████████| 5/5 [00:00<00:00,  9.63it/s, Val Loss=0.1944]


Train - Loss: 0.0010, F1: 0.9953, Acc: 0.9958
Val   - Loss: 0.1884, F1: 0.1712, Acc: 0.3205
Time  - Epoch: 5.4s, LR: 1.56e-06
Progress: 80.0% | Best F1: 0.2209 | ETA: 1.0min

Epoch 41/50


CASME II Training Epoch 41/50: 100%|██████████| 164/164 [00:04<00:00, 34.99it/s, Loss=0.0009, LR=1.56e-06]
CASME II Validation Epoch 41/50: 100%|██████████| 5/5 [00:00<00:00,  9.07it/s, Val Loss=0.1941]


Train - Loss: 0.0009, F1: 0.9962, Acc: 0.9950
Val   - Loss: 0.1882, F1: 0.1551, Acc: 0.3205
Time  - Epoch: 5.3s, LR: 1.56e-06
Progress: 82.0% | Best F1: 0.2209 | ETA: 0.9min

Epoch 42/50


CASME II Training Epoch 42/50: 100%|██████████| 164/164 [00:05<00:00, 28.54it/s, Loss=0.0011, LR=1.56e-06]
CASME II Validation Epoch 42/50: 100%|██████████| 5/5 [00:00<00:00,  8.94it/s, Val Loss=0.1961]


Train - Loss: 0.0011, F1: 0.9923, Acc: 0.9916
Val   - Loss: 0.1901, F1: 0.1661, Acc: 0.3077
Time  - Epoch: 6.3s, LR: 1.00e-06
Progress: 84.0% | Best F1: 0.2209 | ETA: 0.8min

Epoch 43/50


CASME II Training Epoch 43/50: 100%|██████████| 164/164 [00:04<00:00, 34.09it/s, Loss=0.0008, LR=1.00e-06]
CASME II Validation Epoch 43/50: 100%|██████████| 5/5 [00:00<00:00, 10.40it/s, Val Loss=0.1930]


Train - Loss: 0.0008, F1: 0.9960, Acc: 0.9950
Val   - Loss: 0.1866, F1: 0.1761, Acc: 0.3205
Time  - Epoch: 5.3s, LR: 1.00e-06
Progress: 86.0% | Best F1: 0.2209 | ETA: 0.7min

Epoch 44/50


CASME II Training Epoch 44/50: 100%|██████████| 164/164 [00:04<00:00, 34.54it/s, Loss=0.0006, LR=1.00e-06]
CASME II Validation Epoch 44/50: 100%|██████████| 5/5 [00:00<00:00, 10.69it/s, Val Loss=0.1899]


Train - Loss: 0.0006, F1: 0.9985, Acc: 0.9981
Val   - Loss: 0.1841, F1: 0.1997, Acc: 0.3462
Time  - Epoch: 5.2s, LR: 1.00e-06
Progress: 88.0% | Best F1: 0.2209 | ETA: 0.6min

Epoch 45/50


CASME II Training Epoch 45/50: 100%|██████████| 164/164 [00:04<00:00, 35.32it/s, Loss=0.0012, LR=1.00e-06]
CASME II Validation Epoch 45/50: 100%|██████████| 5/5 [00:00<00:00,  9.53it/s, Val Loss=0.1917]


Train - Loss: 0.0012, F1: 0.9959, Acc: 0.9935
Val   - Loss: 0.1882, F1: 0.1813, Acc: 0.3077
Time  - Epoch: 5.2s, LR: 1.00e-06
Progress: 90.0% | Best F1: 0.2209 | ETA: 0.5min

Epoch 46/50


CASME II Training Epoch 46/50: 100%|██████████| 164/164 [00:04<00:00, 34.85it/s, Loss=0.0009, LR=1.00e-06]
CASME II Validation Epoch 46/50: 100%|██████████| 5/5 [00:00<00:00,  9.45it/s, Val Loss=0.1895]


Train - Loss: 0.0009, F1: 0.9974, Acc: 0.9966
Val   - Loss: 0.1833, F1: 0.1851, Acc: 0.3333
Time  - Epoch: 5.3s, LR: 1.00e-06
Progress: 92.0% | Best F1: 0.2209 | ETA: 0.4min

Epoch 47/50


CASME II Training Epoch 47/50: 100%|██████████| 164/164 [00:04<00:00, 34.77it/s, Loss=0.0011, LR=1.00e-06]
CASME II Validation Epoch 47/50: 100%|██████████| 5/5 [00:00<00:00,  9.67it/s, Val Loss=0.1949]


Train - Loss: 0.0018, F1: 0.9931, Acc: 0.9931
Val   - Loss: 0.1893, F1: 0.1765, Acc: 0.3333
Time  - Epoch: 5.2s, LR: 1.00e-06
Progress: 94.0% | Best F1: 0.2209 | ETA: 0.3min

Epoch 48/50


CASME II Training Epoch 48/50: 100%|██████████| 164/164 [00:04<00:00, 34.91it/s, Loss=0.0009, LR=1.00e-06]
CASME II Validation Epoch 48/50: 100%|██████████| 5/5 [00:00<00:00, 10.64it/s, Val Loss=0.1925]


Train - Loss: 0.0009, F1: 0.9944, Acc: 0.9950
Val   - Loss: 0.1859, F1: 0.2066, Acc: 0.3462
Time  - Epoch: 5.2s, LR: 1.00e-06
Progress: 96.0% | Best F1: 0.2209 | ETA: 0.2min

Epoch 49/50


CASME II Training Epoch 49/50: 100%|██████████| 164/164 [00:04<00:00, 33.61it/s, Loss=0.0008, LR=1.00e-06]
CASME II Validation Epoch 49/50: 100%|██████████| 5/5 [00:00<00:00,  9.37it/s, Val Loss=0.1940]


Train - Loss: 0.0011, F1: 0.9966, Acc: 0.9943
Val   - Loss: 0.1862, F1: 0.1916, Acc: 0.3462
Time  - Epoch: 5.4s, LR: 1.00e-06
Progress: 98.0% | Best F1: 0.2209 | ETA: 0.1min

Epoch 50/50


CASME II Training Epoch 50/50: 100%|██████████| 164/164 [00:04<00:00, 34.81it/s, Loss=0.0009, LR=1.00e-06]
CASME II Validation Epoch 50/50: 100%|██████████| 5/5 [00:00<00:00, 10.57it/s, Val Loss=0.1932]

Train - Loss: 0.0010, F1: 0.9958, Acc: 0.9962
Val   - Loss: 0.1862, F1: 0.1862, Acc: 0.3333
Time  - Epoch: 5.2s, LR: 1.00e-06
Progress: 100.0% | Best F1: 0.2209 | ETA: 0.0min

CASME II MOBILENETV3-SMALL M2 MFS-PREP TRAINING COMPLETED
Training time: 5.1 minutes
Epochs completed: 50
Best validation F1: 0.2209 (epoch 12)
Final train F1: 0.9958
Final validation F1: 0.1862

Exporting training documentation...
Training documentation saved successfully: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/results/09_01_mobilenet_casme2_mfs_prep/training_logs/casme2_mobilenet_mfs_prep_training_history.json
Model: MobileNetV3-Small
Methodology: M2
Input resolution: 224x224 Pure Grayscale (1 channel)
Training strategy: from_scratch

Next: Cell 3 - CASME II MobileNetV3-Small M2 Evaluation
Training pipeline completed successfully!





In [3]:
# @title Cell 3: CASME II MobileNetV3-Small M2 MFS-PREP Evaluation (Dual Dataset)

# File: 09_01_MobileNet_CASME2_MFS_PREP_Cell3.py
# Location: experiments/09_01_MobileNet_CASME2-MFS-PREP.ipynb
# Purpose: Comprehensive evaluation framework with v7 (AF) and v8 (KFS) test datasets

import os
import time
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict

from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix,
    roc_curve, auc
)
from sklearn.preprocessing import label_binarize
from concurrent.futures import ThreadPoolExecutor
import warnings
warnings.filterwarnings('ignore')

EVALUATE_DATASETS = ['v7', 'v8']

print("CASME II MobileNetV3-Small M2 MFS-PREP Evaluation Framework")
print("=" * 60)
print(f"Datasets to evaluate: {EVALUATE_DATASETS}")
print(f"Input: 224x224 Pure Grayscale (1 channel)")
print(f"Model: From Scratch Training")
print("=" * 60)

def get_test_dataset_config(version, project_root):
    if version == 'v7':
        config = {
            'version': 'v7',
            'variant': 'AF',
            'dataset_path': f"{project_root}/datasets/processed_casme2/preprocessed_v7",
            'preprocessing_summary': 'preprocessing_summary.json',
            'description': 'Apex Frame with Face-Aware Preprocessing',
            'expected_samples': 28,
            'frame_strategy': 'apex_frame',
            'evaluation_mode': 'frame_level',
            'aggregation': None
        }
    elif version == 'v8':
        config = {
            'version': 'v8',
            'variant': 'KFS',
            'dataset_path': f"{project_root}/datasets/processed_casme2/preprocessed_v8",
            'preprocessing_summary': 'preprocessing_summary.json',
            'description': 'Key Frame Sequence with Face-Aware Preprocessing',
            'expected_frames': 84,
            'expected_videos': 28,
            'frame_strategy': 'key_frame_sequence',
            'frame_types': ['onset', 'apex', 'offset'],
            'evaluation_mode': 'video_level',
            'aggregation': 'late_fusion'
        }
    else:
        raise ValueError(f"Invalid version: {version}. Must be 'v7' or 'v8'")

    return config

def extract_video_id_from_filename(filename):
    name_without_ext = filename.rsplit('.', 1)[0]

    for frame_type in ['onset', 'apex', 'offset']:
        if name_without_ext.endswith(f'_{frame_type}'):
            video_id = name_without_ext.rsplit(f'_{frame_type}', 1)[0]
            return video_id

    return name_without_ext

class CASME2DatasetEvaluation(Dataset):
    def __init__(self, dataset_root, split, transform=None, use_ram_cache=True):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.use_ram_cache = use_ram_cache
        self.images = []
        self.labels = []
        self.filenames = []
        self.emotions = []
        self.video_ids = []
        self.cached_images = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading CASME II {split} dataset for evaluation...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        loaded_count = 0

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                video_id = extract_video_id_from_filename(filename)

                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                self.emotions.append(emotion_found)
                self.video_ids.append(video_id)
                loaded_count += 1

        print(f"Loaded {len(self.images)} CASME II {split} samples for evaluation")
        self._print_evaluation_distribution()

        if self.use_ram_cache and len(self.images) > 0:
            self._preload_to_ram_evaluation()

    def _print_evaluation_distribution(self):
        if len(self.labels) == 0:
            print("No test samples found!")
            return

        label_counts = {}
        unique_videos = set(self.video_ids)

        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        print("Test set class distribution:")
        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

        print(f"Unique video IDs: {len(unique_videos)}")

        missing_classes = []
        for i, class_name in enumerate(CASME2_CLASSES):
            if i not in label_counts:
                missing_classes.append(class_name)

        if missing_classes:
            print(f"Missing classes in test set: {missing_classes}")

    def _preload_to_ram_evaluation(self):
        if len(self.images) == 0:
            return

        print(f"Preloading {len(self.images)} test images to RAM...")

        self.cached_images = [None] * len(self.images)

        def load_single_image(idx, img_path):
            try:
                image = Image.open(img_path)
                if image.mode != 'L':
                    image = image.convert('L')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
                return idx, image, True
            except:
                return idx, Image.new('L', (224, 224), 128), False

        with ThreadPoolExecutor(max_workers=RAM_PRELOAD_WORKERS) as executor:
            futures = [executor.submit(load_single_image, i, path)
                      for i, path in enumerate(self.images)]

            for future in tqdm(futures, desc="Loading test to RAM"):
                idx, image, success = future.result()
                self.cached_images[idx] = image

        print(f"RAM caching completed: {len(self.cached_images)} test images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.use_ram_cache and self.cached_images[idx] is not None:
            image = self.cached_images[idx].copy()
        else:
            try:
                image = Image.open(self.images[idx])
                if image.mode != 'L':
                    image = image.convert('L')
                if image.size != (224, 224):
                    image = image.resize((224, 224), Image.Resampling.LANCZOS)
            except:
                image = Image.new('L', (224, 224), 128)

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx], self.video_ids[idx]

EVALUATION_CONFIG_CASME2 = {
    'model_type': 'MobileNetV3Small_CASME2_MFS_PREP_Baseline',
    'task_type': 'micro_expression_recognition',
    'num_classes': 7,
    'class_names': CASME2_CLASSES,
    'checkpoint_file': 'casme2_mobilenet_mfs_prep_best_f1.pth',
    'dataset_name': 'CASME_II',
    'methodology': 'M2',
    'input_resolution': '224x224 Pure Grayscale (1 channel)',
    'training_strategy': 'from_scratch',
    'evaluation_protocol': 'dual_test_v7_v8'
}

print(f"\nCASME II MobileNetV3 M2 Evaluation Configuration:")
print(f"  Model: {EVALUATION_CONFIG_CASME2['model_type']}")
print(f"  Methodology: {EVALUATION_CONFIG_CASME2['methodology']}")
print(f"  Input resolution: {EVALUATION_CONFIG_CASME2['input_resolution']}")
print(f"  Training strategy: {EVALUATION_CONFIG_CASME2['training_strategy']}")

def load_trained_model_casme2(checkpoint_path, device):
    print(f"Loading trained model from: {checkpoint_path}")

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    checkpoint = None
    loading_method = "unknown"

    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        loading_method = "standard"
    except Exception as e1:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            loading_method = "weights_only_false"
        except Exception as e2:
            try:
                import pickle
                with open(checkpoint_path, 'rb') as f:
                    checkpoint = pickle.load(f)
                loading_method = "pickle"
            except Exception as e3:
                raise RuntimeError(f"All loading methods failed: {e1}, {e2}, {e3}")

    print(f"Checkpoint loaded using: {loading_method}")

    model = MobileNetCASME2Baseline(
        num_classes=EVALUATION_CONFIG_CASME2['num_classes'],
        dropout_rate=checkpoint['casme2_config']['dropout_rate'],
        in_channels=1
    ).to(device)

    state_dict = checkpoint.get('model_state_dict', checkpoint)

    try:
        model.load_state_dict(state_dict, strict=True)
        print("Model state loaded with strict=True")
    except Exception as e:
        print(f"Strict loading failed, trying non-strict: {str(e)[:100]}...")
        try:
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            if missing_keys or unexpected_keys:
                print(f"Non-strict loading: Missing {len(missing_keys)}, Unexpected {len(unexpected_keys)}")
            else:
                print("Model state loaded with strict=False (no key mismatches)")
        except Exception as e2:
            raise RuntimeError(f"Both loading approaches failed: {e2}")

    model.eval()

    training_info = {
        'best_val_f1': float(checkpoint.get('best_f1', 0.0)),
        'best_val_loss': float(checkpoint.get('best_loss', float('inf'))),
        'best_val_accuracy': float(checkpoint.get('best_acc', 0.0)),
        'best_epoch': int(checkpoint.get('epoch', 0)) + 1,
        'model_checkpoint': EVALUATION_CONFIG_CASME2['checkpoint_file'],
        'num_classes': EVALUATION_CONFIG_CASME2['num_classes'],
        'config': checkpoint.get('casme2_config', {})
    }

    print(f"Model loaded successfully:")
    print(f"  Best validation F1: {training_info['best_val_f1']:.4f}")
    print(f"  Best validation accuracy: {training_info['best_val_accuracy']:.4f}")
    print(f"  Best epoch: {training_info['best_epoch']}")

    return model, training_info

def run_frame_level_inference(model, test_loader, device):
    print("Running frame-level inference...")

    model.eval()
    all_predictions = []
    all_labels = []
    all_filenames = []
    all_probabilities = []

    start_time = time.time()

    with torch.no_grad():
        for images, labels, filenames, video_ids in tqdm(test_loader, desc="Frame-level inference"):
            images = images.to(device)

            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)

            _, predicted = torch.max(outputs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_filenames.extend(filenames)
            all_probabilities.extend(probabilities.cpu().numpy())

    inference_time = time.time() - start_time

    return {
        'predictions': all_predictions,
        'labels': all_labels,
        'filenames': all_filenames,
        'probabilities': all_probabilities,
        'inference_time': inference_time,
        'evaluation_mode': 'frame_level'
    }

def run_video_level_inference_late_fusion(model, test_loader, device):
    print("Running video-level inference with late fusion...")

    model.eval()

    frame_predictions = []
    frame_labels = []
    frame_filenames = []
    frame_video_ids = []
    frame_probabilities = []

    start_time = time.time()

    with torch.no_grad():
        for images, labels, filenames, video_ids in tqdm(test_loader, desc="Frame inference"):
            images = images.to(device)

            outputs = model(images)
            probabilities = F.softmax(outputs, dim=1)

            _, predicted = torch.max(outputs, 1)

            frame_predictions.extend(predicted.cpu().numpy())
            frame_labels.extend(labels.numpy())
            frame_filenames.extend(filenames)
            frame_video_ids.extend(video_ids)
            frame_probabilities.extend(probabilities.cpu().numpy())

    print("Aggregating frame predictions to video level...")

    video_data = {}
    for pred, label, filename, video_id, prob in zip(
        frame_predictions, frame_labels, frame_filenames, frame_video_ids, frame_probabilities
    ):
        if video_id not in video_data:
            video_data[video_id] = {
                'predictions': [],
                'probabilities': [],
                'label': label,
                'filenames': []
            }

        video_data[video_id]['predictions'].append(pred)
        video_data[video_id]['probabilities'].append(prob)
        video_data[video_id]['filenames'].append(filename)

    video_predictions = []
    video_labels = []
    video_ids_list = []

    for video_id, data in video_data.items():
        avg_prob = np.mean(data['probabilities'], axis=0)
        final_pred = np.argmax(avg_prob)

        video_predictions.append(final_pred)
        video_labels.append(data['label'])
        video_ids_list.append(video_id)

    inference_time = time.time() - start_time

    return {
        'predictions': video_predictions,
        'labels': video_labels,
        'video_ids': video_ids_list,
        'inference_time': inference_time,
        'evaluation_mode': 'video_level',
        'kfs_late_fusion_info': {
            'total_frames': len(frame_predictions),
            'total_videos': len(video_predictions),
            'aggregation_method': 'average_probability'
        }
    }

def calculate_comprehensive_metrics(inference_results):
    predictions = np.array(inference_results['predictions'])
    labels = np.array(inference_results['labels'])

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    per_class_precision, per_class_recall, per_class_f1, per_class_support = \
        precision_recall_fscore_support(labels, predictions, average=None, zero_division=0)

    cm = confusion_matrix(labels, predictions)

    try:
        if 'probabilities' in inference_results:
            labels_bin = label_binarize(labels, classes=range(len(CASME2_CLASSES)))
            auc_scores = []

            for i in range(len(CASME2_CLASSES)):
                if labels_bin[:, i].sum() > 0:
                    fpr, tpr, _ = roc_curve(labels_bin[:, i],
                                           np.array(inference_results['probabilities'])[:, i])
                    auc_scores.append(auc(fpr, tpr))
                else:
                    auc_scores.append(0.0)

            macro_auc = np.mean([score for score in auc_scores if score > 0])
        else:
            auc_scores = [0.0] * len(CASME2_CLASSES)
            macro_auc = 0.0
    except:
        auc_scores = [0.0] * len(CASME2_CLASSES)
        macro_auc = 0.0

    unique_labels = set(labels)
    available_classes = [CASME2_CLASSES[i] for i in unique_labels]
    missing_classes = [cls for i, cls in enumerate(CASME2_CLASSES) if i not in unique_labels]

    per_class_performance = {}
    for i, class_name in enumerate(CASME2_CLASSES):
        per_class_performance[class_name] = {
            'precision': float(per_class_precision[i]) if i < len(per_class_precision) else 0.0,
            'recall': float(per_class_recall[i]) if i < len(per_class_recall) else 0.0,
            'f1_score': float(per_class_f1[i]) if i < len(per_class_f1) else 0.0,
            'support': int(per_class_support[i]) if i < len(per_class_support) else 0,
            'auc': float(auc_scores[i]) if i < len(auc_scores) else 0.0,
            'in_test_set': i in unique_labels
        }

    inference_performance = {
        'total_time_seconds': inference_results['inference_time'],
        'average_time_ms_per_sample': (inference_results['inference_time'] / len(predictions)) * 1000
    }

    results = {
        'evaluation_metadata': {
            'dataset': 'CASME_II',
            'model_type': 'MobileNetCASME2Baseline',
            'methodology': 'M2',
            'input_resolution': '224x224 Pure Grayscale (1 channel)',
            'training_strategy': 'from_scratch',
            'evaluation_timestamp': datetime.now().isoformat(),
            'evaluation_mode': inference_results['evaluation_mode'],
            'test_samples': len(predictions),
            'class_names': CASME2_CLASSES,
            'available_classes': available_classes,
            'missing_classes': missing_classes
        },
        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_precision': float(precision),
            'macro_recall': float(recall),
            'macro_f1': float(f1),
            'macro_auc': float(macro_auc)
        },
        'per_class_performance': per_class_performance,
        'confusion_matrix': cm.tolist(),
        'inference_performance': inference_performance
    }

    if 'kfs_late_fusion_info' in inference_results:
        results['kfs_late_fusion_info'] = inference_results['kfs_late_fusion_info']

    return results

def analyze_wrong_predictions(inference_results):
    predictions = np.array(inference_results['predictions'])
    labels = np.array(inference_results['labels'])

    if 'filenames' in inference_results:
        filenames = inference_results['filenames']
    elif 'video_ids' in inference_results:
        filenames = inference_results['video_ids']
    else:
        filenames = [f"sample_{i}" for i in range(len(predictions))]

    wrong_predictions = []
    wrong_by_class = defaultdict(int)
    confusion_patterns = defaultdict(int)

    for i, (pred, true_label) in enumerate(zip(predictions, labels)):
        if pred != true_label:
            true_class = CASME2_CLASSES[true_label]
            pred_class = CASME2_CLASSES[pred]

            wrong_predictions.append({
                'filename': filenames[i],
                'true_label': int(true_label),
                'true_class': true_class,
                'predicted_label': int(pred),
                'predicted_class': pred_class
            })

            wrong_by_class[true_class] += 1
            confusion_patterns[f"{true_class}->{pred_class}"] += 1

    error_summary = {}
    for class_name in CASME2_CLASSES:
        class_idx = CLASS_TO_IDX[class_name]
        class_mask = labels == class_idx
        class_total = class_mask.sum()

        if class_total > 0:
            class_wrong = wrong_by_class.get(class_name, 0)
            error_rate = (class_wrong / class_total) * 100
        else:
            class_wrong = 0
            error_rate = 0.0

        error_summary[class_name] = {
            'total_samples': int(class_total),
            'wrong_predictions': int(class_wrong),
            'error_rate_percent': float(error_rate)
        }

    results = {
        'analysis_metadata': {
            'total_samples': len(predictions),
            'total_wrong_predictions': len(wrong_predictions),
            'overall_error_rate': (len(wrong_predictions) / len(predictions) * 100) if len(predictions) > 0 else 0.0
        },
        'wrong_predictions': wrong_predictions,
        'wrong_predictions_by_class': dict(wrong_by_class),
        'error_summary': error_summary,
        'confusion_patterns': dict(confusion_patterns)
    }

    return results

def save_evaluation_results(evaluation_results, wrong_predictions_results, results_dir, test_version):
    os.makedirs(results_dir, exist_ok=True)

    results_file = f"{results_dir}/casme2_mobilenet_mfs_prep_evaluation_results_{test_version}.json"
    with open(results_file, 'w') as f:
        json.dump(evaluation_results, f, indent=2, default=str)

    wrong_predictions_file = f"{results_dir}/casme2_mobilenet_mfs_prep_wrong_predictions_{test_version}.json"
    with open(wrong_predictions_file, 'w') as f:
        json.dump(wrong_predictions_results, f, indent=2, default=str)

    print(f"Evaluation results saved:")
    print(f"  Main results: {os.path.basename(results_file)}")
    print(f"  Wrong predictions: {os.path.basename(wrong_predictions_file)}")

    return results_file, wrong_predictions_file

all_evaluation_results = {}

checkpoint_path = f"{GLOBAL_CONFIG_CASME2['checkpoint_root']}/{EVALUATION_CONFIG_CASME2['checkpoint_file']}"
casme2_model, training_info = load_trained_model_casme2(checkpoint_path, GLOBAL_CONFIG_CASME2['device'])

results_dir = f"{GLOBAL_CONFIG_CASME2['results_root']}/evaluation_results"

for dataset_version in EVALUATE_DATASETS:
    print("\n" + "=" * 70)
    print(f"EVALUATING DATASET: {dataset_version.upper()}")
    print("=" * 70)

    try:
        test_config = get_test_dataset_config(dataset_version, PROJECT_ROOT)

        print(f"\nTest Dataset Configuration:")
        print(f"  Version: {test_config['version']}")
        print(f"  Variant: {test_config['variant']}")
        print(f"  Description: {test_config['description']}")
        print(f"  Frame strategy: {test_config['frame_strategy']}")
        print(f"  Evaluation mode: {test_config['evaluation_mode']}")
        if 'aggregation' in test_config and test_config['aggregation']:
            print(f"  Aggregation: {test_config['aggregation']}")
        print(f"  Dataset path: {test_config['dataset_path']}")

        print(f"\nCreating CASME II test dataset from {test_config['variant']}...")
        test_dataset = CASME2DatasetEvaluation(
            dataset_root=test_config['dataset_path'],
            split='test',
            transform=GLOBAL_CONFIG_CASME2['transform_val'],
            use_ram_cache=True
        )

        if len(test_dataset) == 0:
            raise ValueError(f"No test samples found for {dataset_version}!")

        test_loader = DataLoader(
            test_dataset,
            batch_size=CASME2_MOBILENET_CONFIG['batch_size'],
            shuffle=False,
            num_workers=CASME2_MOBILENET_CONFIG['num_workers'],
            pin_memory=True
        )

        if test_config['evaluation_mode'] == 'frame_level':
            print(f"\nRunning frame-level evaluation for {test_config['variant']}...")
            inference_results = run_frame_level_inference(casme2_model, test_loader, GLOBAL_CONFIG_CASME2['device'])

        elif test_config['evaluation_mode'] == 'video_level':
            print(f"\nRunning video-level evaluation with late fusion for {test_config['variant']}...")
            inference_results = run_video_level_inference_late_fusion(casme2_model, test_loader, GLOBAL_CONFIG_CASME2['device'])

        else:
            raise ValueError(f"Unknown evaluation mode: {test_config['evaluation_mode']}")

        evaluation_results = calculate_comprehensive_metrics(inference_results)

        wrong_predictions_results = analyze_wrong_predictions(inference_results)

        evaluation_results['training_information'] = training_info
        evaluation_results['test_configuration'] = test_config

        results_file, wrong_file = save_evaluation_results(
            evaluation_results, wrong_predictions_results, results_dir, test_config['version']
        )

        all_evaluation_results[dataset_version] = {
            'evaluation': evaluation_results,
            'wrong_predictions': wrong_predictions_results,
            'config': test_config
        }

        print("\n" + "=" * 60)
        print(f"EVALUATION RESULTS - {test_config['variant']} ({dataset_version})")
        print("=" * 60)

        overall = evaluation_results['overall_performance']
        print(f"\nOverall Performance:")
        print(f"  Accuracy:  {overall['accuracy']:.4f}")
        print(f"  Precision: {overall['macro_precision']:.4f}")
        print(f"  Recall:    {overall['macro_recall']:.4f}")
        print(f"  F1 Score:  {overall['macro_f1']:.4f}")
        print(f"  AUC:       {overall['macro_auc']:.4f}")

        if 'kfs_late_fusion_info' in evaluation_results:
            fusion_info = evaluation_results['kfs_late_fusion_info']
            print(f"\nLate Fusion Info:")
            print(f"  Total frames processed: {fusion_info['total_frames']}")
            print(f"  Video-level predictions: {fusion_info['total_videos']}")
            print(f"  Aggregation method: {fusion_info['aggregation_method']}")

        print(f"\nPer-Class Performance:")
        for class_name, metrics in evaluation_results['per_class_performance'].items():
            in_test = "Present" if metrics['in_test_set'] else "Missing"
            print(f"  {class_name} [{in_test}]: F1={metrics['f1_score']:.4f}, "
                  f"Support={metrics['support']}")

        wrong_meta = wrong_predictions_results['analysis_metadata']
        print(f"\nWrong Predictions Analysis:")
        print(f"  Total errors: {wrong_meta['total_wrong_predictions']} / {wrong_meta['total_samples']}")
        print(f"  Error rate: {wrong_meta['overall_error_rate']:.2f}%")

        print(f"\nInference Performance:")
        print(f"  Total time: {inference_results['inference_time']:.2f}s")
        print(f"  Speed: {evaluation_results['inference_performance']['average_time_ms_per_sample']:.1f} ms/sample")

    except Exception as e:
        print(f"Evaluation failed for {dataset_version}: {e}")
        import traceback
        traceback.print_exc()
        continue

if len(all_evaluation_results) == 2 and 'v7' in all_evaluation_results and 'v8' in all_evaluation_results:
    print("\n" + "=" * 70)
    print("COMPARATIVE ANALYSIS: AF (v7) vs KFS (v8)")
    print("=" * 70)

    v7_results = all_evaluation_results['v7']['evaluation']
    v8_results = all_evaluation_results['v8']['evaluation']

    print("\nOverall Performance Comparison:")
    print(f"{'Metric':<20} {'AF (v7)':<15} {'KFS (v8)':<15} {'Difference':<15}")
    print("-" * 65)

    metrics_to_compare = ['accuracy', 'macro_precision', 'macro_recall', 'macro_f1', 'macro_auc']

    for metric in metrics_to_compare:
        v7_val = v7_results['overall_performance'][metric]
        v8_val = v8_results['overall_performance'][metric]
        diff = v8_val - v7_val

        print(f"{metric:<20} {v7_val:<15.4f} {v8_val:<15.4f} {diff:+.4f}")

    print(f"\nEvaluation Modes:")
    print(f"  AF (v7): {v7_results['evaluation_metadata']['evaluation_mode']}")
    print(f"  KFS (v8): {v8_results['evaluation_metadata']['evaluation_mode']}")

    if 'kfs_late_fusion_info' in v8_results:
        print(f"\nKFS Late Fusion Strategy:")
        print(f"  Frames used: {v8_results['kfs_late_fusion_info']['total_frames']}")
        print(f"  Video predictions: {v8_results['kfs_late_fusion_info']['total_videos']}")
        print(f"  Aggregation: {v8_results['kfs_late_fusion_info']['aggregation_method']}")

if torch.cuda.is_available():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

print("\n" + "=" * 70)
print("CASME II MOBILENETV3-SMALL M2 MFS-PREP EVALUATION COMPLETED")
print("=" * 70)
print(f"Evaluated datasets: {EVALUATE_DATASETS}")
print("Next: Cell 4 - Generate confusion matrices and visualization")

CASME II MobileNetV3-Small M2 MFS-PREP Evaluation Framework
Datasets to evaluate: ['v7', 'v8']
Input: 224x224 Pure Grayscale (1 channel)
Model: From Scratch Training

CASME II MobileNetV3 M2 Evaluation Configuration:
  Model: MobileNetV3Small_CASME2_MFS_PREP_Baseline
  Methodology: M2
  Input resolution: 224x224 Pure Grayscale (1 channel)
  Training strategy: from_scratch
Loading trained model from: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/09_01_mobilenet_casme2_mfs_prep/casme2_mobilenet_mfs_prep_best_f1.pth
Checkpoint loaded using: standard
MobileNetV3-Small feature dimension: 1024
Training from scratch with 1-channel input
MobileNet CASME II: 1024 -> 512 -> 128 -> 7
Architecture: Pure grayscale (1ch) from scratch
Model state loaded with strict=True
Model loaded successfully:
  Best validation F1: 0.2209
  Best validation accuracy: 0.3333
  Best epoch: 12

EVALUATING DATASET: V7

Test Dataset Configuration:
  Version: v7
  Variant: AF
  Descr

Loading test to RAM: 100%|██████████| 28/28 [00:00<00:00, 41.39it/s]


RAM caching completed: 28 test images

Running frame-level evaluation for AF...
Running frame-level inference...


Frame-level inference: 100%|██████████| 2/2 [00:13<00:00,  6.76s/it]


Evaluation results saved:
  Main results: casme2_mobilenet_mfs_prep_evaluation_results_v7.json
  Wrong predictions: casme2_mobilenet_mfs_prep_wrong_predictions_v7.json

EVALUATION RESULTS - AF (v7)

Overall Performance:
  Accuracy:  0.3571
  Precision: 0.1981
  Recall:    0.2202
  F1 Score:  0.2037
  AUC:       0.4922

Per-Class Performance:
  others [Present]: F1=0.5556, Support=10
  disgust [Present]: F1=0.4444, Support=7
  happiness [Present]: F1=0.2222, Support=4
  repression [Present]: F1=0.0000, Support=3
  surprise [Present]: F1=0.0000, Support=3
  sadness [Present]: F1=0.0000, Support=1
  fear [Missing]: F1=0.0000, Support=0

Wrong Predictions Analysis:
  Total errors: 18 / 28
  Error rate: 64.29%

Inference Performance:
  Total time: 13.52s
  Speed: 482.9 ms/sample

EVALUATING DATASET: V8

Test Dataset Configuration:
  Version: v8
  Variant: KFS
  Description: Key Frame Sequence with Face-Aware Preprocessing
  Frame strategy: key_frame_sequence
  Evaluation mode: video_level
 

Loading test to RAM: 100%|██████████| 84/84 [00:02<00:00, 39.11it/s]


RAM caching completed: 84 test images

Running video-level evaluation with late fusion for KFS...
Running video-level inference with late fusion...


Frame inference: 100%|██████████| 6/6 [00:13<00:00,  2.23s/it]

Aggregating frame predictions to video level...
Evaluation results saved:
  Main results: casme2_mobilenet_mfs_prep_evaluation_results_v8.json
  Wrong predictions: casme2_mobilenet_mfs_prep_wrong_predictions_v8.json

EVALUATION RESULTS - KFS (v8)

Overall Performance:
  Accuracy:  0.3452
  Precision: 0.2009
  Recall:    0.2229
  F1 Score:  0.2084
  AUC:       0.0000

Late Fusion Info:
  Total frames processed: 84
  Video-level predictions: 84
  Aggregation method: average_probability

Per-Class Performance:
  others [Present]: F1=0.5172, Support=30
  disgust [Present]: F1=0.3774, Support=21
  happiness [Present]: F1=0.2308, Support=12
  repression [Present]: F1=0.1250, Support=9
  surprise [Present]: F1=0.0000, Support=9
  sadness [Present]: F1=0.0000, Support=3
  fear [Missing]: F1=0.0000, Support=0

Wrong Predictions Analysis:
  Total errors: 55 / 84
  Error rate: 65.48%

Inference Performance:
  Total time: 13.40s
  Speed: 159.6 ms/sample

COMPARATIVE ANALYSIS: AF (v7) vs KFS (v8)






In [4]:
# @title Cell 4: CASME II MobileNetV3-Small M2 MFS-PREP Confusion Matrix Generation

# File: 09_01_MobileNet_CASME2_MFS_PREP_Cell4.py
# Location: experiments/09_01_MobileNet_CASME2-MFS-PREP.ipynb
# Purpose: Generate professional confusion matrix visualizations for v7 (AF) and v8 (KFS) test sets

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

print("CASME II MobileNetV3-Small M2 MFS-PREP Confusion Matrix Generation")
print("=" * 60)

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/09_01_mobilenet_casme2_mfs_prep"

def find_evaluation_json_files_casme2(results_path):
    json_files = {}

    eval_dir = f"{results_path}/evaluation_results"

    if os.path.exists(eval_dir):
        for version in ['v7', 'v8']:
            eval_pattern = f"{eval_dir}/casme2_mobilenet_mfs_prep_evaluation_results_{version}.json"
            eval_files = glob.glob(eval_pattern)

            if eval_files:
                json_files[f'main_{version}'] = eval_files[0]
                print(f"Found {version.upper()} evaluation file: {os.path.basename(eval_files[0])}")

        if not json_files:
            print(f"WARNING: No evaluation results found in {eval_dir}")
            print("Make sure Cell 3 (evaluation) has been executed first!")
    else:
        print(f"ERROR: Evaluation directory not found: {eval_dir}")

    return json_files

def load_evaluation_results_casme2(json_path):
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        print(f"Successfully loaded: {os.path.basename(json_path)}")
        return data
    except Exception as e:
        print(f"ERROR loading {json_path}: {str(e)}")
        return None

def calculate_weighted_f1_casme2(per_class_performance):
    total_support = sum([class_data['support'] for class_data in per_class_performance.values()
                        if class_data['support'] > 0])

    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0

    for class_name, class_data in per_class_performance.items():
        if class_data['support'] > 0:
            weight = class_data['support'] / total_support
            weighted_f1 += class_data['f1_score'] * weight

    return weighted_f1

def calculate_balanced_accuracy_casme2(confusion_matrix):
    cm = np.array(confusion_matrix)
    n_classes = cm.shape[0]

    per_class_balanced_acc = []

    classes_with_samples = []
    for i in range(n_classes):
        if cm[i, :].sum() > 0:
            classes_with_samples.append(i)

    for i in classes_with_samples:
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        tn = cm.sum() - tp - fn - fp

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        class_balanced_acc = (sensitivity + specificity) / 2
        per_class_balanced_acc.append(class_balanced_acc)

    balanced_acc = np.mean(per_class_balanced_acc) if per_class_balanced_acc else 0.0

    return balanced_acc

def determine_text_color_casme2(color_value, threshold=0.5):
    return 'white' if color_value > threshold else 'black'

def create_confusion_matrix_plot_casme2(data, output_path, test_version):
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']
    test_config = data.get('test_configuration', {})

    test_desc = test_config.get('description', test_version)
    variant = test_config.get('variant', test_version.upper())
    methodology = meta.get('methodology', 'M2')
    input_res = meta.get('input_resolution', '224x224 Pure Grayscale (1 channel)')
    training_strategy = meta.get('training_strategy', 'from_scratch')

    print(f"Processing confusion matrix for {test_version.upper()}")
    print(f"Confusion matrix shape: {cm.shape}")

    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1_casme2(per_class)
    balanced_acc = calculate_balanced_accuracy_casme2(cm)

    print(f"Metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Acc: {accuracy:.4f}, Balanced Acc: {balanced_acc:.4f}")

    row_sums = cm.sum(axis=1, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_pct = np.divide(cm, row_sums, where=(row_sums!=0))
        cm_pct = np.nan_to_num(cm_pct)

    fig, ax = plt.subplots(figsize=(12, 10))

    cmap = 'Blues'

    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"

            cell_value = cm_pct[i, j]
            text_color = determine_text_color_casme2(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    missing_classes = meta.get('missing_classes', [])
    note_text = f"Test: {test_desc} ({variant})\n{methodology} | {input_res}\nTraining: {training_strategy}"
    if missing_classes:
        note_text += f"\nMissing: {', '.join(missing_classes)}"

    ax.text(0.02, 0.98, note_text, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    title = f"CASME II MobileNetV3-Small M2 MFS-PREP - {variant}\n"
    title += f"Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Acc: {accuracy:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes,
        'test_version': test_version,
        'variant': variant
    }

json_files = find_evaluation_json_files_casme2(RESULTS_ROOT)

if not json_files:
    print(f"ERROR: No evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (evaluation) has been executed first!")
else:
    print(f"\nFound {len([k for k in json_files.keys() if k.startswith('main_')])} evaluation result(s)")

output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

results_summary = {}
generated_files = []

for version in ['v7', 'v8']:
    main_key = f'main_{version}'

    if main_key in json_files:
        print(f"\n{'='*60}")
        print(f"Processing {version.upper()} Confusion Matrix")
        print(f"{'='*60}")

        eval_data = load_evaluation_results_casme2(json_files[main_key])

        if eval_data is not None:
            try:
                cm_output_path = os.path.join(output_dir, f"confusion_matrix_CASME2_MobileNet_MFS_PREP_{version.upper()}.png")
                metrics = create_confusion_matrix_plot_casme2(eval_data, cm_output_path, version)
                generated_files.append(cm_output_path)

                results_summary[version] = metrics

                print(f"SUCCESS: {version.upper()} confusion matrix generated")

            except Exception as e:
                print(f"ERROR: Failed to generate {version.upper()} confusion matrix: {str(e)}")
                import traceback
                traceback.print_exc()
        else:
            print(f"ERROR: Could not load {version.upper()} evaluation data")

if generated_files:
    print(f"\n" + "=" * 60)
    print("CASME II MOBILENETV3-SMALL M2 MFS-PREP CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 60)

    print(f"\nGenerated confusion matrix files:")
    for file_path in generated_files:
        filename = os.path.basename(file_path)
        print(f"  {filename}")

    print(f"\nPerformance Summary:")
    for version in ['v7', 'v8']:
        if version in results_summary:
            metrics = results_summary[version]
            variant = metrics.get('variant', version.upper())
            print(f"\n{variant}:")
            print(f"  Macro F1:       {metrics['macro_f1']:.4f}")
            print(f"  Weighted F1:    {metrics['weighted_f1']:.4f}")
            print(f"  Accuracy:       {metrics['accuracy']:.4f}")
            print(f"  Balanced Acc:   {metrics['balanced_accuracy']:.4f}")

            if metrics['missing_classes']:
                print(f"  Missing classes: {len(metrics['missing_classes'])}")

    if len(results_summary) == 2:
        print(f"\nComparative Analysis:")
        v8_f1 = results_summary['v8']['macro_f1']
        v7_f1 = results_summary['v7']['macro_f1']
        delta_f1 = v8_f1 - v7_f1

        v8_variant = results_summary['v8'].get('variant', 'KFS')
        v7_variant = results_summary['v7'].get('variant', 'AF')

        print(f"  {v8_variant} vs {v7_variant} (Macro F1): {v8_f1:.4f} vs {v7_f1:.4f}")
        print(f"  Delta ({v8_variant} - {v7_variant}): {delta_f1:+.4f}")

        if delta_f1 > 0:
            improvement_pct = (delta_f1 / v7_f1) * 100
            print(f"  {v8_variant} improves by {improvement_pct:.1f}% over {v7_variant}")
        else:
            degradation_pct = (abs(delta_f1) / v8_f1) * 100
            print(f"  {v8_variant} degrades by {degradation_pct:.1f}% from {v7_variant}")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print(f"\nERROR: No confusion matrices were generated")
    print("Please check:")
    print("1. Cell 3 evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

print("\nCell 4 completed - CASME II MobileNetV3-Small M2 MFS-PREP confusion matrix analysis generated")

CASME II MobileNetV3-Small M2 MFS-PREP Confusion Matrix Generation
Found V7 evaluation file: casme2_mobilenet_mfs_prep_evaluation_results_v7.json
Found V8 evaluation file: casme2_mobilenet_mfs_prep_evaluation_results_v8.json

Found 2 evaluation result(s)

Processing V7 Confusion Matrix
Successfully loaded: casme2_mobilenet_mfs_prep_evaluation_results_v7.json
Processing confusion matrix for V7
Confusion matrix shape: (6, 6)
Metrics - Macro F1: 0.2037, Weighted F1: 0.3413, Acc: 0.3571, Balanced Acc: 0.5412
Confusion matrix saved to: confusion_matrix_CASME2_MobileNet_MFS_PREP_V7.png
SUCCESS: V7 confusion matrix generated

Processing V8 Confusion Matrix
Successfully loaded: casme2_mobilenet_mfs_prep_evaluation_results_v8.json
Processing confusion matrix for V8
Confusion matrix shape: (6, 6)
Metrics - Macro F1: 0.2084, Weighted F1: 0.3254, Acc: 0.3452, Balanced Acc: 0.5395
Confusion matrix saved to: confusion_matrix_CASME2_MobileNet_MFS_PREP_V8.png
SUCCESS: V8 confusion matrix generated

CA