In [1]:
# @title Cell 1: CASME II ConvNeXT-Tiny M2 MFS-PREP Infrastructure Configuration

# File: 09_03_ConvNeXT_CASME2_MFS_PREP_Cell1.py
# Location: experiments/09_03_ConvNeXT_CASME2-MFS-PREP.ipynb
# Purpose: ConvNeXT-Tiny for CASME II micro-expression recognition with M2 preprocessed methodology

from google.colab import drive
print("=" * 60)
print("CASME II CNN BASELINE - ConvNeXT-Tiny M2 MFS-PREP")
print("=" * 60)
print("\n[1] Mounting Google Drive...")
drive.mount('/content/drive')
print("Google Drive mounted successfully")

print("\n[2] Importing required libraries...")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import timm
import json
import os
import numpy as np
import pandas as pd
from PIL import Image
import time
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
import warnings
warnings.filterwarnings('ignore')

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
DATASET_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/preprocessed_v9"
CHECKPOINT_ROOT = f"{PROJECT_ROOT}/models/09_03_convnext_casme2_mfs_prep"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/09_03_convnext_casme2_mfs_prep"

PREPROCESSING_SUMMARY = f"{DATASET_ROOT}/preprocessing_summary.json"

print("CASME II ConvNeXT-Tiny M2 MFS-PREP - Infrastructure Configuration")
print("=" * 60)

if not os.path.exists(PREPROCESSING_SUMMARY):
    raise FileNotFoundError(f"v9 preprocessing summary not found: {PREPROCESSING_SUMMARY}")

print("Loading CASME II v9 preprocessing metadata...")
with open(PREPROCESSING_SUMMARY, 'r') as f:
    preprocessing_info = json.load(f)

print(f"Dataset variant: {preprocessing_info['variant']}")
print(f"Processing date: {preprocessing_info['processing_date']}")
print(f"Preprocessing method: {preprocessing_info['preprocessing_method']}")
print(f"Total images processed: {preprocessing_info['total_processed']}")
print(f"Face detection rate: {preprocessing_info['face_detection_stats']['detection_rate']:.2%}")

preproc_params = preprocessing_info['preprocessing_parameters']
print(f"Target size: {preproc_params['target_size']}x{preproc_params['target_size']}px")
print(f"BBox expansion: {preproc_params['bbox_expansion']}px")
print(f"Image format: Grayscale (1 channel)")

print(f"\nDataset split information:")
print(f"  Train: {preprocessing_info['splits']['train']['total_images']} frames")
print(f"  Validation: {preprocessing_info['splits']['val']['total_images']} frames")
print(f"  Test: {preprocessing_info['splits']['test']['total_images']} frames")

USE_FOCAL_LOSS = True
FOCAL_LOSS_GAMMA = 2.5

CROSSENTROPY_CLASS_WEIGHTS = [1.00, 1.25, 1.76, 1.91, 1.99, 3.76, 7.04]
FOCAL_LOSS_ALPHA_WEIGHTS = [0.053, 0.067, 0.094, 0.102, 0.106, 0.201, 0.376]

CONVNEXT_MODEL_NAME = 'convnext_tiny'
USE_PURE_GRAYSCALE = True

print("\n" + "=" * 50)
print("EXPERIMENT CONFIGURATION - CNN M2 MFS-PREP")
print("=" * 50)
print(f"Model: ConvNeXT-Tiny (TIMM)")
print(f"Methodology: M2 (Face-Aware Preprocessing)")
print(f"Input Resolution: 224x224 Pure Grayscale (1 channel)")
print(f"Training Strategy: From Scratch (No Pretrained Weights)")
print(f"Preprocessing: Face detection + crop + grayscale")
print(f"Loss Function: Focal Loss")
print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
print(f"  Alpha Weights: {FOCAL_LOSS_ALPHA_WEIGHTS}")
print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
print("=" * 50)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0

print(f"\nDevice: {device}")
print(f"GPU: {gpu_name} ({gpu_memory:.1f} GB)")

if 'A100' in gpu_name:
    BATCH_SIZE = 16
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("A100: Optimized batch size for ConvNeXT-Tiny (largest model)")
elif 'L4' in gpu_name:
    BATCH_SIZE = 12
    NUM_WORKERS = 8
    torch.backends.cudnn.benchmark = True
    print("L4: Balanced performance configuration")
else:
    BATCH_SIZE = 8
    NUM_WORKERS = 8
    print("Default GPU: Conservative settings for large model")

RAM_PRELOAD_WORKERS = 32
print(f"RAM preload workers: {RAM_PRELOAD_WORKERS}")

CASME2_CLASSES = ['others', 'disgust', 'happiness', 'repression', 'surprise', 'sadness', 'fear']
CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CASME2_CLASSES)}

print("\nLoading v9 class distribution...")
train_dist = preprocessing_info['splits']['train']['emotion_distribution']
val_dist = preprocessing_info['splits']['val']['emotion_distribution']
test_dist = preprocessing_info['splits']['test']['emotion_distribution']

def emotion_dist_to_list(emotion_dict, class_names):
    return [emotion_dict.get(cls, 0) for cls in class_names]

train_dist_list = emotion_dist_to_list(train_dist, CASME2_CLASSES)
val_dist_list = emotion_dist_to_list(val_dist, CASME2_CLASSES)
test_dist_list = emotion_dist_to_list(test_dist, CASME2_CLASSES)

print(f"\nv9 Train distribution: {train_dist_list}")
print(f"v9 Val distribution: {val_dist_list}")
print(f"v9 Test distribution: {test_dist_list}")

if USE_FOCAL_LOSS:
    class_weights = torch.tensor(FOCAL_LOSS_ALPHA_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied Focal Loss alpha weights: {class_weights.cpu().numpy()}")
    print(f"Alpha weights sum: {class_weights.sum().item():.3f}")
else:
    class_weights = torch.tensor(CROSSENTROPY_CLASS_WEIGHTS, dtype=torch.float32).to(device)
    print(f"Applied CrossEntropy class weights: {class_weights.cpu().numpy()}")

CASME2_CONVNEXT_CONFIG = {
    'model_name': CONVNEXT_MODEL_NAME,
    'input_size': (224, 224),
    'num_classes': 7,
    'dropout_rate': 0.3,

    'learning_rate': 5e-5,
    'weight_decay': 1e-5,
    'gradient_clip': 1.0,
    'num_epochs': 50,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'device': device,

    'scheduler_type': 'plateau',
    'scheduler_mode': 'max',
    'scheduler_factor': 0.5,
    'scheduler_patience': 5,
    'scheduler_min_lr': 1e-6,
    'scheduler_monitor': 'val_f1_macro',

    'use_focal_loss': USE_FOCAL_LOSS,
    'focal_loss_gamma': FOCAL_LOSS_GAMMA,
    'focal_loss_alpha_weights': FOCAL_LOSS_ALPHA_WEIGHTS,
    'crossentropy_class_weights': CROSSENTROPY_CLASS_WEIGHTS,
    'class_weights': class_weights,

    'use_macro_avg': True,
    'early_stopping': False,
    'save_best_f1': True,
    'save_strategy': 'best_only',

    'dataset_version': 'v9',
    'methodology': 'M2',
    'preprocessing': 'face_aware_preprocessing',
    'frame_strategy': 'multi_frame_sampling',
    'train_augmentation': 'frame_level_independent',
    'image_format': 'pure_grayscale_1channel',
    'use_pure_grayscale': USE_PURE_GRAYSCALE,
    'use_pretrained_weights': False,
    'training_strategy': 'from_scratch',
    'preprocessing_details': {
        'face_detection': True,
        'crop_method': 'face_bbox_expansion',
        'target_size': preproc_params['target_size'],
        'bbox_expansion': preproc_params['bbox_expansion'],
        'grayscale_conversion': True,
        'input_channels': 1
    }
}

print(f"\nConvNeXT-Tiny Configuration Summary:")
print(f"  Model: {CASME2_CONVNEXT_CONFIG['model_name']}")
print(f"  Input size: {CASME2_CONVNEXT_CONFIG['input_size'][0]}x{CASME2_CONVNEXT_CONFIG['input_size'][1]} Pure Grayscale (1ch)")
print(f"  Methodology: {CASME2_CONVNEXT_CONFIG['methodology']} (face-aware preprocessing)")
print(f"  Training: From Scratch (No Pretrained Weights)")
print(f"  Learning rate: {CASME2_CONVNEXT_CONFIG['learning_rate']}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Dataset version: {CASME2_CONVNEXT_CONFIG['dataset_version']}")
print(f"  Frame strategy: {CASME2_CONVNEXT_CONFIG['frame_strategy']}")

class OptimizedFocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(OptimizedFocalLoss, self).__init__()

        if alpha is not None:
            if isinstance(alpha, list):
                self.alpha = torch.tensor(alpha, dtype=torch.float32)
            else:
                self.alpha = alpha

            alpha_sum = self.alpha.sum().item()
            if abs(alpha_sum - 1.0) > 0.01:
                print(f"Warning: Alpha weights sum to {alpha_sum:.3f}, expected 1.0")
        else:
            self.alpha = None

        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            if self.alpha.device != targets.device:
                self.alpha = self.alpha.to(targets.device)
            alpha_t = self.alpha.gather(0, targets)
        else:
            alpha_t = 1.0

        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class ConvNeXTCASME2Baseline(nn.Module):
    def __init__(self, num_classes, dropout_rate=0.3, in_channels=1):
        super(ConvNeXTCASME2Baseline, self).__init__()

        self.convnext = timm.create_model(
            CONVNEXT_MODEL_NAME,
            pretrained=False,
            num_classes=0,
            global_pool='avg',
            in_chans=in_channels
        )

        for param in self.convnext.parameters():
            param.requires_grad = True

        with torch.no_grad():
            test_input = torch.randn(1, in_channels, 224, 224)
            test_output = self.convnext(test_input)
            self.convnext_feature_dim = test_output.shape[1]

        print(f"ConvNeXT-Tiny feature dimension: {self.convnext_feature_dim}")
        print(f"Training from scratch with {in_channels}-channel input")

        self.classifier_layers = nn.Sequential(
            nn.Linear(self.convnext_feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
        )

        self.classifier = nn.Linear(128, num_classes)

        print(f"ConvNeXT CASME II: {self.convnext_feature_dim} -> 512 -> 128 -> {num_classes}")
        print(f"Architecture: Pure grayscale (1ch) from scratch")

    def forward(self, x):
        features = self.convnext(x)
        processed_features = self.classifier_layers(features)
        output = self.classifier(processed_features)
        return output

def create_optimizer_scheduler_casme2(model, config):
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999)
    )

    if config['scheduler_type'] == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=config['scheduler_mode'],
            factor=config['scheduler_factor'],
            patience=config['scheduler_patience'],
            min_lr=config['scheduler_min_lr']
        )
        print(f"Scheduler: ReduceLROnPlateau monitoring {config['scheduler_monitor']}")
    else:
        scheduler = None

    return optimizer, scheduler

print("\nSetting up transforms for M2 methodology (224x224 pure grayscale)...")

convnext_transform_train = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

convnext_transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

print("M2 transforms configured: 224x224 pure grayscale (1 channel) with standardization")

class CASME2Dataset(Dataset):
    def __init__(self, dataset_root, split, transform=None):
        self.dataset_root = dataset_root
        self.split = split
        self.transform = transform
        self.images = []
        self.labels = []
        self.filenames = []

        split_path = os.path.join(dataset_root, split)

        print(f"Loading {split} dataset from {split_path}...")

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Split directory not found: {split_path}")

        all_files = [f for f in os.listdir(split_path) if f.endswith(('.jpg', '.png', '.jpeg'))]
        print(f"Found {len(all_files)} image files in directory")

        loaded_count = 0

        for filename in sorted(all_files):
            emotion_found = None
            name_without_ext = filename.rsplit('.', 1)[0]

            for emotion_class in CASME2_CLASSES:
                if emotion_class in name_without_ext.lower():
                    emotion_found = emotion_class
                    break

            if emotion_found and emotion_found in CLASS_TO_IDX:
                image_path = os.path.join(split_path, filename)
                self.images.append(image_path)
                self.labels.append(CLASS_TO_IDX[emotion_found])
                self.filenames.append(filename)
                loaded_count += 1

        print(f"Loaded {len(self.images)} samples for {split} split")
        self._print_distribution()

    def _print_distribution(self):
        if len(self.labels) == 0:
            print("No samples to display distribution")
            return

        label_counts = {}
        for label in self.labels:
            label_counts[label] = label_counts.get(label, 0) + 1

        for label, count in sorted(label_counts.items()):
            class_name = CASME2_CLASSES[label]
            percentage = (count / len(self.labels)) * 100
            print(f"  {class_name}: {count} samples ({percentage:.1f}%)")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path)

        if image.mode != 'L':
            image = image.convert('L')

        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx], self.filenames[idx]

os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/training_logs", exist_ok=True)
os.makedirs(f"{RESULTS_ROOT}/evaluation_results", exist_ok=True)

TRAIN_PATH = DATASET_ROOT
VAL_PATH = DATASET_ROOT
TEST_PATH = DATASET_ROOT

print(f"\nDataset paths:")
print(f"Root: {DATASET_ROOT}")
print(f"  Train: {TRAIN_PATH}/train")
print(f"  Validation: {VAL_PATH}/val")
print(f"  Test: {TEST_PATH}/test")

print("\nConvNeXT-Tiny CASME II architecture validation...")

try:
    test_model = ConvNeXTCASME2Baseline(num_classes=7, dropout_rate=0.3, in_channels=1).to(device)
    test_input = torch.randn(1, 1, 224, 224).to(device)
    test_output = test_model(test_input)

    print(f"Validation successful: Output shape {test_output.shape}")
    print(f"Expected output shape: [1, 7] for CASME II 7 classes")
    print(f"Pure grayscale (1 channel) architecture validated")

    del test_model, test_input, test_output
    torch.cuda.empty_cache()

except Exception as e:
    print(f"Validation failed: {e}")

def create_criterion_casme2(weights, use_focal_loss=False, alpha_weights=None, gamma=2.0):
    if use_focal_loss:
        print(f"Using Optimized Focal Loss with gamma={gamma}")
        if alpha_weights:
            print(f"Per-class alpha weights: {alpha_weights}")
            print(f"Alpha sum: {sum(alpha_weights):.3f}")
        return OptimizedFocalLoss(alpha=alpha_weights, gamma=gamma)
    else:
        print(f"Using CrossEntropy Loss with optimized class weights")
        print(f"Class weights: {weights.cpu().numpy()}")
        return nn.CrossEntropyLoss(weight=weights)

GLOBAL_CONFIG_CASME2 = {
    'device': device,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'num_classes': 7,
    'class_weights': class_weights,
    'class_names': CASME2_CLASSES,
    'class_to_idx': CLASS_TO_IDX,
    'transform_train': convnext_transform_train,
    'transform_val': convnext_transform_val,
    'convnext_config': CASME2_CONVNEXT_CONFIG,
    'checkpoint_root': CHECKPOINT_ROOT,
    'results_root': RESULTS_ROOT,
    'dataset_root': DATASET_ROOT,
    'train_path': TRAIN_PATH,
    'val_path': VAL_PATH,
    'test_path': TEST_PATH,
    'optimizer_scheduler_factory': create_optimizer_scheduler_casme2,
    'criterion_factory': create_criterion_casme2
}

print("\n" + "=" * 60)
print("CASME II CONVNEXT-TINY M2 MFS-PREP CONFIGURATION COMPLETE")
print("=" * 60)

print(f"Loss Configuration:")
if USE_FOCAL_LOSS:
    print(f"  Function: Optimized Focal Loss")
    print(f"  Gamma: {FOCAL_LOSS_GAMMA}")
    print(f"  Per-class Alpha: {FOCAL_LOSS_ALPHA_WEIGHTS}")
    print(f"  Alpha Sum: {sum(FOCAL_LOSS_ALPHA_WEIGHTS):.3f}")
else:
    print(f"  Function: CrossEntropy with Optimized Weights")
    print(f"  Class Weights: {CROSSENTROPY_CLASS_WEIGHTS}")

print(f"\nModel Configuration:")
print(f"  Architecture: ConvNeXT-Tiny")
print(f"  Parameters: ~28M (largest CNN baseline)")
print(f"  Input Resolution: 224x224 Pure Grayscale (1 channel)")
print(f"  Methodology: M2 (Face-aware preprocessing)")
print(f"  Training Strategy: From Scratch (No Pretrained Weights)")
print(f"  First Conv Layer: Modified to 1-channel input")

print(f"\nDataset Configuration:")
print(f"  Version: {CASME2_CONVNEXT_CONFIG['dataset_version']}")
print(f"  Frame strategy: {CASME2_CONVNEXT_CONFIG['frame_strategy']}")
print(f"  Train augmentation: {CASME2_CONVNEXT_CONFIG['train_augmentation']}")
print(f"  Classes: {len(CASME2_CLASSES)}")
print(f"  Train samples: {preprocessing_info['splits']['train']['total_images']} frames")

print("\nNext: Cell 2 - Dataset Loading and Training Pipeline")

CASME II CNN BASELINE - ConvNeXT-Tiny M2 MFS-PREP

[1] Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully

[2] Importing required libraries...
CASME II ConvNeXT-Tiny M2 MFS-PREP - Infrastructure Configuration
Loading CASME II v9 preprocessing metadata...
Dataset variant: MFS
Processing date: 2025-10-19T08:20:12.098301
Preprocessing method: face_bbox_expansion_all_directions
Total images processed: 2774
Face detection rate: 100.00%
Target size: 224x224px
BBox expansion: 20px
Image format: Grayscale (1 channel)

Dataset split information:
  Train: 2613 frames
  Validation: 78 frames
  Test: 83 frames

EXPERIMENT CONFIGURATION - CNN M2 MFS-PREP
Model: ConvNeXT-Tiny (TIMM)
Methodology: M2 (Face-Aware Preprocessing)
Input Resolution: 224x224 Pure Grayscale (1 channel)
Training Strategy: From Scratch (No Pretrained Weights)
Preprocessing: Face detection + crop + grayscale
Loss Function: Focal Loss
  Gamma: 2.5
  Alpha Weights: [0.053, 0.067, 0.094, 0.102, 0.

In [2]:
# @title Cell 2: CASME II ConvNeXT-Tiny M2 MFS-PREP Training Pipeline

# File: 09_03_ConvNeXT_CASME2_MFS_PREP_Cell2.py
# Location: experiments/09_03_ConvNeXT_CASME2-MFS-PREP.ipynb
# Purpose: Training loop with RAM caching, robust checkpointing, and comprehensive logging

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import f1_score, precision_recall_fscore_support
import time
import json
import os
import shutil
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from PIL import Image

print("CASME II ConvNeXT-Tiny M2 MFS-PREP Training Pipeline")
print("=" * 60)

class RAMCachedDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, preload=True, num_workers=32):
        self.base_dataset = base_dataset
        self.transform = base_dataset.transform
        self.cached_images = {}
        self.labels = base_dataset.labels
        self.filenames = base_dataset.filenames

        if preload:
            print(f"Preloading {len(base_dataset)} images to RAM using {num_workers} workers...")
            start_time = time.time()

            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                futures = []
                for idx in range(len(base_dataset)):
                    future = executor.submit(self._load_image, idx)
                    futures.append((idx, future))

                for idx, future in futures:
                    self.cached_images[idx] = future.result()

            elapsed = time.time() - start_time

            sample_image = list(self.cached_images.values())[0]
            if isinstance(sample_image, Image.Image):
                image_array = np.array(sample_image)
                bytes_per_image = image_array.nbytes
            else:
                bytes_per_image = sample_image.size[0] * sample_image.size[1]

            total_memory_mb = (bytes_per_image * len(self.cached_images)) / (1024 * 1024)

            print(f"Preloading completed in {elapsed:.2f}s")
            print(f"Cached {len(self.cached_images)} images (~{total_memory_mb:.2f} MB)")

    def _load_image(self, idx):
        image_path = self.base_dataset.images[idx]
        image = Image.open(image_path)

        if image.mode != 'L':
            image = image.convert('L')

        return image

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        if idx in self.cached_images:
            image = self.cached_images[idx].copy()
        else:
            image = self._load_image(idx)

        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        filename = self.filenames[idx]

        return image, label, filename

print("\nInitializing datasets...")
train_dataset_base = CASME2Dataset(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='train',
    transform=None
)

val_dataset_base = CASME2Dataset(
    dataset_root=GLOBAL_CONFIG_CASME2['dataset_root'],
    split='val',
    transform=None
)

print("\nCreating RAM-cached datasets...")
train_dataset = RAMCachedDataset(
    train_dataset_base,
    preload=True,
    num_workers=RAM_PRELOAD_WORKERS
)
train_dataset.transform = GLOBAL_CONFIG_CASME2['transform_train']

val_dataset = RAMCachedDataset(
    val_dataset_base,
    preload=True,
    num_workers=RAM_PRELOAD_WORKERS
)
val_dataset.transform = GLOBAL_CONFIG_CASME2['transform_val']

print("\nCreating data loaders...")
train_loader = DataLoader(
    train_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

print("\nInitializing ConvNeXT-Tiny model...")
model = ConvNeXTCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_CONVNEXT_CONFIG['dropout_rate'],
    in_channels=1
).to(GLOBAL_CONFIG_CASME2['device'])

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params / 1e6:.1f}M parameters")

criterion = GLOBAL_CONFIG_CASME2['criterion_factory'](
    weights=GLOBAL_CONFIG_CASME2['class_weights'],
    use_focal_loss=CASME2_CONVNEXT_CONFIG['use_focal_loss'],
    alpha_weights=CASME2_CONVNEXT_CONFIG['focal_loss_alpha_weights'],
    gamma=CASME2_CONVNEXT_CONFIG['focal_loss_gamma']
)

optimizer, scheduler = GLOBAL_CONFIG_CASME2['optimizer_scheduler_factory'](
    model, CASME2_CONVNEXT_CONFIG
)

print(f"\nOptimizer: AdamW")
print(f"Learning rate: {CASME2_CONVNEXT_CONFIG['learning_rate']}")
print(f"Weight decay: {CASME2_CONVNEXT_CONFIG['weight_decay']}")

def train_one_epoch_casme2(model, train_loader, criterion, optimizer, device, epoch, gradient_clip=1.0):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    for batch_idx, (images, labels, _) in enumerate(train_loader):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()

        if gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)

        optimizer.step()

        running_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / len(train_loader)

    accuracy = np.mean(np.array(all_preds) == np.array(all_labels))

    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    return epoch_loss, accuracy, macro_f1

def validate_casme2(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels, _ in val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss = running_loss / len(val_loader)

    accuracy = np.mean(np.array(all_preds) == np.array(all_labels))

    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    precision, recall, f1_per_class, support = precision_recall_fscore_support(
        all_labels, all_preds, average=None, zero_division=0
    )

    return val_loss, accuracy, macro_f1, precision, recall, f1_per_class, support

def save_checkpoint_atomic_casme2(state, filepath, max_retries=3):
    temp_filepath = filepath + '.tmp'

    for attempt in range(max_retries):
        try:
            torch.save(state, temp_filepath)

            if os.path.exists(temp_filepath):
                file_size = os.path.getsize(temp_filepath)
                if file_size > 1000:
                    shutil.move(temp_filepath, filepath)
                    return True
                else:
                    print(f"Warning: Checkpoint file too small ({file_size} bytes), retrying...")

        except Exception as e:
            print(f"Save attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
            else:
                print(f"ERROR: Failed to save checkpoint after {max_retries} attempts")
                if os.path.exists(temp_filepath):
                    os.remove(temp_filepath)
                return False

    return False

training_history = {
    'train_loss': [],
    'train_accuracy': [],
    'train_f1_macro': [],
    'val_loss': [],
    'val_accuracy': [],
    'val_f1_macro': [],
    'val_f1_per_class': [],
    'learning_rates': [],
    'epoch_times': []
}

best_f1 = 0.0
best_loss = float('inf')
best_accuracy = 0.0

checkpoint_filename = 'casme2_convnext_mfs_prep_best_f1.pth'
checkpoint_path = os.path.join(GLOBAL_CONFIG_CASME2['checkpoint_root'], checkpoint_filename)

print("\n" + "=" * 60)
print("STARTING TRAINING - CONVNEXT-TINY M2 MFS-PREP")
print("=" * 60)
print(f"Total epochs: {CASME2_CONVNEXT_CONFIG['num_epochs']}")
print(f"Batch size: {CASME2_CONVNEXT_CONFIG['batch_size']}")
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Training strategy: From Scratch (No Pretrained Weights)")
print(f"Input: 224x224 Pure Grayscale (1 channel)")
print(f"Checkpoint: {checkpoint_filename}")
print("=" * 60)

for epoch in range(CASME2_CONVNEXT_CONFIG['num_epochs']):
    epoch_start_time = time.time()

    train_loss, train_acc, train_f1 = train_one_epoch_casme2(
        model, train_loader, criterion, optimizer,
        GLOBAL_CONFIG_CASME2['device'], epoch,
        gradient_clip=CASME2_CONVNEXT_CONFIG['gradient_clip']
    )

    val_loss, val_acc, val_f1, val_precision, val_recall, val_f1_per_class, val_support = validate_casme2(
        model, val_loader, criterion, GLOBAL_CONFIG_CASME2['device']
    )

    current_lr = optimizer.param_groups[0]['lr']

    if scheduler is not None:
        if CASME2_CONVNEXT_CONFIG['scheduler_type'] == 'plateau':
            scheduler.step(val_f1)

    epoch_time = time.time() - epoch_start_time

    training_history['train_loss'].append(float(train_loss))
    training_history['train_accuracy'].append(float(train_acc))
    training_history['train_f1_macro'].append(float(train_f1))
    training_history['val_loss'].append(float(val_loss))
    training_history['val_accuracy'].append(float(val_acc))
    training_history['val_f1_macro'].append(float(val_f1))
    training_history['val_f1_per_class'].append([float(f1) for f1 in val_f1_per_class])
    training_history['learning_rates'].append(float(current_lr))
    training_history['epoch_times'].append(float(epoch_time))

    print(f"\nEpoch [{epoch+1}/{CASME2_CONVNEXT_CONFIG['num_epochs']}] - {epoch_time:.2f}s")
    print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
    print(f"LR: {current_lr:.2e}")

    should_save = False
    save_reason = ""

    if val_f1 > best_f1:
        should_save = True
        save_reason = f"F1 improved: {best_f1:.4f} -> {val_f1:.4f}"
        best_f1 = val_f1
    elif val_f1 == best_f1 and val_loss < best_loss:
        should_save = True
        save_reason = f"F1 tied, Loss improved: {best_loss:.4f} -> {val_loss:.4f}"
        best_loss = val_loss
    elif val_f1 == best_f1 and val_loss == best_loss and val_acc > best_accuracy:
        should_save = True
        save_reason = f"F1 & Loss tied, Acc improved: {best_accuracy:.4f} -> {val_acc:.4f}"
        best_accuracy = val_acc

    if should_save:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'train_loss': train_loss,
            'train_accuracy': train_acc,
            'train_f1_macro': train_f1,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
            'val_f1_macro': val_f1,
            'val_f1_per_class': val_f1_per_class.tolist(),
            'val_precision': val_precision.tolist(),
            'val_recall': val_recall.tolist(),
            'val_support': val_support.tolist(),
            'best_f1': best_f1,
            'best_loss': best_loss,
            'best_accuracy': best_accuracy,
            'config': CASME2_CONVNEXT_CONFIG,
            'class_names': GLOBAL_CONFIG_CASME2['class_names']
        }

        success = save_checkpoint_atomic_casme2(checkpoint, checkpoint_path)

        if success:
            print(f"✓ Checkpoint saved: {save_reason}")
        else:
            print(f"✗ Failed to save checkpoint")

print("\n" + "=" * 60)
print("TRAINING COMPLETED - CONVNEXT-TINY M2 MFS-PREP")
print("=" * 60)
print(f"Best validation F1: {best_f1:.4f}")
print(f"Best validation Loss: {best_loss:.4f}")
print(f"Best validation Accuracy: {best_accuracy:.4f}")
print(f"Total training time: {sum(training_history['epoch_times']):.2f}s")
print(f"Average epoch time: {np.mean(training_history['epoch_times']):.2f}s")
print(f"Checkpoint saved: {checkpoint_path}")

training_log_path = os.path.join(
    GLOBAL_CONFIG_CASME2['results_root'],
    'training_logs',
    'casme2_convnext_mfs_prep_training_history.json'
)

training_log_data = {
    'model_name': 'ConvNeXT-Tiny',
    'methodology': 'M2',
    'preprocessing': 'face_aware_preprocessing',
    'input_resolution': '224x224_pure_grayscale_1ch',
    'training_strategy': 'from_scratch',
    'dataset_version': 'v9',
    'training_config': {
        'num_epochs': CASME2_CONVNEXT_CONFIG['num_epochs'],
        'batch_size': CASME2_CONVNEXT_CONFIG['batch_size'],
        'learning_rate': CASME2_CONVNEXT_CONFIG['learning_rate'],
        'weight_decay': CASME2_CONVNEXT_CONFIG['weight_decay'],
        'gradient_clip': CASME2_CONVNEXT_CONFIG['gradient_clip'],
        'dropout_rate': CASME2_CONVNEXT_CONFIG['dropout_rate'],
        'use_focal_loss': CASME2_CONVNEXT_CONFIG['use_focal_loss'],
        'focal_loss_gamma': CASME2_CONVNEXT_CONFIG['focal_loss_gamma'],
        'focal_loss_alpha_weights': CASME2_CONVNEXT_CONFIG['focal_loss_alpha_weights']
    },
    'best_metrics': {
        'best_f1_macro': float(best_f1),
        'best_loss': float(best_loss),
        'best_accuracy': float(best_accuracy)
    },
    'training_history': training_history,
    'training_completed_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}

with open(training_log_path, 'w') as f:
    json.dump(training_log_data, f, indent=2)

print(f"\nTraining history saved: {training_log_path}")
print("\nNext: Cell 3 - Dual Test Evaluation (v7 AF + v8 KFS)")

CASME II ConvNeXT-Tiny M2 MFS-PREP Training Pipeline

Initializing datasets...
Loading train dataset from /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/train...
Found 2613 image files in directory
Loaded 2613 samples for train split
  others: 1027 samples (39.3%)
  disgust: 650 samples (24.9%)
  happiness: 325 samples (12.4%)
  repression: 273 samples (10.4%)
  surprise: 260 samples (10.0%)
  sadness: 65 samples (2.5%)
  fear: 13 samples (0.5%)
Loading val dataset from /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v9/val...
Found 78 image files in directory
Loaded 78 samples for val split
  others: 30 samples (38.5%)
  disgust: 18 samples (23.1%)
  happiness: 9 samples (11.5%)
  repression: 9 samples (11.5%)
  surprise: 6 samples (7.7%)
  sadness: 3 samples (3.8%)
  fear: 3 samples (3.8%)

Creating RAM-cached datasets...
Preloading 2613 images t

In [5]:
# @title Cell 3: CASME II ConvNeXT-Tiny M2 MFS-PREP Dual Test Evaluation

# File: 09_03_ConvNeXT_CASME2_MFS_PREP_Cell3.py
# Location: experiments/09_03_ConvNeXT_CASME2-MFS-PREP.ipynb
# Purpose: Comprehensive evaluation on v7 (AF) and v8 (KFS) test sets with late fusion

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import (
    f1_score, precision_recall_fscore_support, accuracy_score,
    confusion_matrix, roc_auc_score
)
import json
import os
from collections import defaultdict
from datetime import datetime

print("CASME II ConvNeXT-Tiny M2 MFS-PREP Dual Test Evaluation")
print("=" * 60)

TEST_V7_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/preprocessed_v7"
TEST_V8_ROOT = f"{PROJECT_ROOT}/datasets/processed_casme2/preprocessed_v8"

print(f"\nTest dataset paths:")
print(f"v7 (AF): {TEST_V7_ROOT}")
print(f"v8 (KFS): {TEST_V8_ROOT}")

if not os.path.exists(TEST_V7_ROOT):
    raise FileNotFoundError(f"v7 test dataset not found: {TEST_V7_ROOT}")
if not os.path.exists(TEST_V8_ROOT):
    raise FileNotFoundError(f"v8 test dataset not found: {TEST_V8_ROOT}")

checkpoint_path = os.path.join(GLOBAL_CONFIG_CASME2['checkpoint_root'], 'casme2_convnext_mfs_prep_best_f1.pth')

if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

print(f"\nLoading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=GLOBAL_CONFIG_CASME2['device'], weights_only=False)

print(f"Checkpoint from epoch: {checkpoint['epoch']}")
print(f"Validation F1: {checkpoint['val_f1_macro']:.4f}")
print(f"Validation Accuracy: {checkpoint['val_accuracy']:.4f}")

model = ConvNeXTCASME2Baseline(
    num_classes=GLOBAL_CONFIG_CASME2['num_classes'],
    dropout_rate=CASME2_CONVNEXT_CONFIG['dropout_rate'],
    in_channels=1
).to(GLOBAL_CONFIG_CASME2['device'])

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Model loaded successfully and set to evaluation mode")

def extract_video_id(filename):
    parts = filename.rsplit('_', 1)
    if len(parts) == 2:
        video_id = parts[0]
        return video_id
    return filename.rsplit('.', 1)[0]

def evaluate_test_set_casme2(model, test_dataset, test_loader, device, class_names, test_version='v7'):
    model.eval()

    all_preds = []
    all_labels = []
    all_probs = []
    all_filenames = []

    print(f"\nEvaluating {test_version.upper()} test set...")

    with torch.no_grad():
        for images, labels, filenames in test_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_filenames.extend(filenames)

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    accuracy = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    precision, recall, f1_per_class, support = precision_recall_fscore_support(
        all_labels, all_preds, average=None, zero_division=0
    )

    cm = confusion_matrix(all_labels, all_preds)

    present_classes = np.unique(all_labels)

    try:
        if len(present_classes) > 1:
            auc_scores = []
            for i in range(len(class_names)):
                if i in present_classes:
                    binary_labels = (all_labels == i).astype(int)
                    binary_probs = all_probs[:, i]

                    if len(np.unique(binary_labels)) > 1:
                        auc = roc_auc_score(binary_labels, binary_probs)
                        auc_scores.append(auc)
                    else:
                        auc_scores.append(0.0)
                else:
                    auc_scores.append(0.0)
        else:
            auc_scores = [0.0] * len(class_names)
    except Exception as e:
        print(f"Warning: Could not calculate AUC scores: {e}")
        auc_scores = [0.0] * len(class_names)

    print(f"\n{test_version.upper()} Results:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Macro F1: {macro_f1:.4f}")
    print(f"  Weighted F1: {weighted_f1:.4f}")

    results = {
        'test_configuration': {
            'version': test_version,
            'variant': 'AF' if test_version == 'v7' else 'KFS',
            'description': 'Apex Frame' if test_version == 'v7' else 'Key Frame Sequence',
            'total_samples': len(all_labels)
        },
        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_f1': float(macro_f1),
            'weighted_f1': float(weighted_f1)
        },
        'per_class_performance': {},
        'confusion_matrix': cm.tolist(),
        'predictions': {
            'filenames': all_filenames,
            'true_labels': all_labels.tolist(),
            'predicted_labels': all_preds.tolist(),
            'prediction_probabilities': all_probs.tolist()
        },
        'evaluation_metadata': {
            'model': 'ConvNeXT-Tiny',
            'methodology': 'M2',
            'preprocessing': 'face_aware_preprocessing',
            'input_resolution': '224x224 Pure Grayscale (1 channel)',
            'training_strategy': 'from_scratch',
            'class_names': class_names,
            'checkpoint_epoch': checkpoint['epoch'],
            'evaluated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }
    }

    missing_classes = []
    for i, class_name in enumerate(class_names):
        if i not in present_classes:
            missing_classes.append(class_name)
            results['per_class_performance'][class_name] = {
                'precision': 0.0,
                'recall': 0.0,
                'f1_score': 0.0,
                'support': 0,
                'auc': 0.0,
                'missing': True
            }
        else:
            results['per_class_performance'][class_name] = {
                'precision': float(precision[i]),
                'recall': float(recall[i]),
                'f1_score': float(f1_per_class[i]),
                'support': int(support[i]),
                'auc': float(auc_scores[i]),
                'missing': False
            }

    if missing_classes:
        print(f"  Missing classes: {missing_classes}")
        results['evaluation_metadata']['missing_classes'] = missing_classes
    else:
        results['evaluation_metadata']['missing_classes'] = []

    wrong_predictions = []
    for i, (true_label, pred_label, filename) in enumerate(zip(all_labels, all_preds, all_filenames)):
        if true_label != pred_label:
            wrong_predictions.append({
                'filename': filename,
                'true_class': class_names[true_label],
                'predicted_class': class_names[pred_label],
                'true_label_idx': int(true_label),
                'predicted_label_idx': int(pred_label),
                'prediction_confidence': float(all_probs[i][pred_label]),
                'true_class_probability': float(all_probs[i][true_label])
            })

    results['wrong_predictions'] = {
        'count': len(wrong_predictions),
        'percentage': (len(wrong_predictions) / len(all_labels)) * 100,
        'details': wrong_predictions
    }

    print(f"  Wrong predictions: {len(wrong_predictions)}/{len(all_labels)} ({results['wrong_predictions']['percentage']:.1f}%)")

    return results

def evaluate_with_late_fusion_casme2(model, test_dataset, test_loader, device, class_names, test_version='v8'):
    model.eval()

    video_data = defaultdict(lambda: {'probs': [], 'true_label': None, 'filenames': []})

    print(f"\nCollecting frame predictions for {test_version.upper()} late fusion...")

    with torch.no_grad():
        for images, labels, filenames in test_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            probs = torch.softmax(outputs, dim=1)

            probs_np = probs.cpu().numpy()
            labels_np = labels.cpu().numpy()

            for prob, label, filename in zip(probs_np, labels_np, filenames):
                video_id = extract_video_id(filename)
                video_data[video_id]['probs'].append(prob)
                video_data[video_id]['true_label'] = label
                video_data[video_id]['filenames'].append(filename)

    print(f"Collected {sum(len(v['probs']) for v in video_data.values())} frames from {len(video_data)} videos")

    all_preds = []
    all_labels = []
    all_probs = []
    all_video_ids = []

    for video_id, data in sorted(video_data.items()):
        avg_probs = np.mean(data['probs'], axis=0)
        pred_label = np.argmax(avg_probs)
        true_label = data['true_label']

        all_preds.append(pred_label)
        all_labels.append(true_label)
        all_probs.append(avg_probs)
        all_video_ids.append(video_id)

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    accuracy = accuracy_score(all_labels, all_preds)
    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    precision, recall, f1_per_class, support = precision_recall_fscore_support(
        all_labels, all_preds, average=None, zero_division=0
    )

    cm = confusion_matrix(all_labels, all_preds)

    present_classes = np.unique(all_labels)

    try:
        if len(present_classes) > 1:
            auc_scores = []
            for i in range(len(class_names)):
                if i in present_classes:
                    binary_labels = (all_labels == i).astype(int)
                    binary_probs = all_probs[:, i]

                    if len(np.unique(binary_labels)) > 1:
                        auc = roc_auc_score(binary_labels, binary_probs)
                        auc_scores.append(auc)
                    else:
                        auc_scores.append(0.0)
                else:
                    auc_scores.append(0.0)
        else:
            auc_scores = [0.0] * len(class_names)
    except Exception as e:
        print(f"Warning: Could not calculate AUC scores: {e}")
        auc_scores = [0.0] * len(class_names)

    print(f"\n{test_version.upper()} Late Fusion Results:")
    print(f"  Videos evaluated: {len(all_video_ids)}")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Macro F1: {macro_f1:.4f}")
    print(f"  Weighted F1: {weighted_f1:.4f}")

    results = {
        'test_configuration': {
            'version': test_version,
            'variant': 'KFS',
            'description': 'Key Frame Sequence with Late Fusion',
            'fusion_method': 'average_probability',
            'total_videos': len(all_video_ids),
            'total_frames': sum(len(v['probs']) for v in video_data.values())
        },
        'overall_performance': {
            'accuracy': float(accuracy),
            'macro_f1': float(macro_f1),
            'weighted_f1': float(weighted_f1)
        },
        'per_class_performance': {},
        'confusion_matrix': cm.tolist(),
        'video_predictions': {
            'video_ids': all_video_ids,
            'true_labels': all_labels.tolist(),
            'predicted_labels': all_preds.tolist(),
            'prediction_probabilities': all_probs.tolist()
        },
        'evaluation_metadata': {
            'model': 'ConvNeXT-Tiny',
            'methodology': 'M2',
            'preprocessing': 'face_aware_preprocessing',
            'input_resolution': '224x224 Pure Grayscale (1 channel)',
            'training_strategy': 'from_scratch',
            'class_names': class_names,
            'checkpoint_epoch': checkpoint['epoch'],
            'evaluated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }
    }

    missing_classes = []
    for i, class_name in enumerate(class_names):
        if i not in present_classes:
            missing_classes.append(class_name)
            results['per_class_performance'][class_name] = {
                'precision': 0.0,
                'recall': 0.0,
                'f1_score': 0.0,
                'support': 0,
                'auc': 0.0,
                'missing': True
            }
        else:
            results['per_class_performance'][class_name] = {
                'precision': float(precision[i]),
                'recall': float(recall[i]),
                'f1_score': float(f1_per_class[i]),
                'support': int(support[i]),
                'auc': float(auc_scores[i]),
                'missing': False
            }

    if missing_classes:
        print(f"  Missing classes: {missing_classes}")
        results['evaluation_metadata']['missing_classes'] = missing_classes
    else:
        results['evaluation_metadata']['missing_classes'] = []

    wrong_predictions = []
    for video_id, true_label, pred_label, probs in zip(all_video_ids, all_labels, all_preds, all_probs):
        if true_label != pred_label:
            wrong_predictions.append({
                'video_id': video_id,
                'true_class': class_names[true_label],
                'predicted_class': class_names[pred_label],
                'true_label_idx': int(true_label),
                'predicted_label_idx': int(pred_label),
                'prediction_confidence': float(probs[pred_label]),
                'true_class_probability': float(probs[true_label]),
                'num_frames': len(video_data[video_id]['probs'])
            })

    results['wrong_predictions'] = {
        'count': len(wrong_predictions),
        'percentage': (len(wrong_predictions) / len(all_video_ids)) * 100,
        'details': wrong_predictions
    }

    print(f"  Wrong predictions: {len(wrong_predictions)}/{len(all_video_ids)} ({results['wrong_predictions']['percentage']:.1f}%)")

    return results

print("\n" + "=" * 60)
print("EVALUATING V7 TEST SET (APEX FRAME)")
print("=" * 60)

test_v7_dataset = CASME2Dataset(
    dataset_root=TEST_V7_ROOT,
    split='test',
    transform=GLOBAL_CONFIG_CASME2['transform_val']
)

test_v7_loader = DataLoader(
    test_v7_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

v7_results = evaluate_test_set_casme2(
    model, test_v7_dataset, test_v7_loader,
    GLOBAL_CONFIG_CASME2['device'],
    GLOBAL_CONFIG_CASME2['class_names'],
    test_version='v7'
)

print("\n" + "=" * 60)
print("EVALUATING V8 TEST SET (KEY FRAME SEQUENCE WITH LATE FUSION)")
print("=" * 60)

test_v8_dataset = CASME2Dataset(
    dataset_root=TEST_V8_ROOT,
    split='test',
    transform=GLOBAL_CONFIG_CASME2['transform_val']
)

test_v8_loader = DataLoader(
    test_v8_dataset,
    batch_size=GLOBAL_CONFIG_CASME2['batch_size'],
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

v8_results = evaluate_with_late_fusion_casme2(
    model, test_v8_dataset, test_v8_loader,
    GLOBAL_CONFIG_CASME2['device'],
    GLOBAL_CONFIG_CASME2['class_names'],
    test_version='v8'
)

print("\n" + "=" * 60)
print("DUAL TEST EVALUATION COMPLETED")
print("=" * 60)

eval_results_dir = os.path.join(GLOBAL_CONFIG_CASME2['results_root'], 'evaluation_results')
os.makedirs(eval_results_dir, exist_ok=True)

v7_output_path = os.path.join(eval_results_dir, 'casme2_convnext_mfs_prep_evaluation_results_v7.json')
with open(v7_output_path, 'w') as f:
    json.dump(v7_results, f, indent=2)
print(f"\nv7 results saved: {v7_output_path}")

v8_output_path = os.path.join(eval_results_dir, 'casme2_convnext_mfs_prep_evaluation_results_v8.json')
with open(v8_output_path, 'w') as f:
    json.dump(v8_results, f, indent=2)
print(f"v8 results saved: {v8_output_path}")

print("\nComparative Summary:")
print(f"v7 (AF)  - Macro F1: {v7_results['overall_performance']['macro_f1']:.4f}, Acc: {v7_results['overall_performance']['accuracy']:.4f}")
print(f"v8 (KFS) - Macro F1: {v8_results['overall_performance']['macro_f1']:.4f}, Acc: {v8_results['overall_performance']['accuracy']:.4f}")

v8_v7_delta = v8_results['overall_performance']['macro_f1'] - v7_results['overall_performance']['macro_f1']
print(f"Delta (v8 - v7): {v8_v7_delta:+.4f}")

if v8_v7_delta > 0:
    improvement_pct = (v8_v7_delta / v7_results['overall_performance']['macro_f1']) * 100
    print(f"KFS improves by {improvement_pct:.1f}% over AF")
else:
    degradation_pct = (abs(v8_v7_delta) / v8_results['overall_performance']['macro_f1']) * 100
    print(f"KFS degrades by {degradation_pct:.1f}% from AF")

print("\nNext: Cell 4 - Confusion Matrix Generation")

CASME II ConvNeXT-Tiny M2 MFS-PREP Dual Test Evaluation

Test dataset paths:
v7 (AF): /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v7
v8 (KFS): /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v8

Loading checkpoint: /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/models/09_03_convnext_casme2_mfs_prep/casme2_convnext_mfs_prep_best_f1.pth
Checkpoint from epoch: 28
Validation F1: 0.2473
Validation Accuracy: 0.3333
ConvNeXT-Tiny feature dimension: 768
Training from scratch with 1-channel input
ConvNeXT CASME II: 768 -> 512 -> 128 -> 7
Architecture: Pure grayscale (1ch) from scratch
Model loaded successfully and set to evaluation mode

EVALUATING V7 TEST SET (APEX FRAME)
Loading test dataset from /content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project/datasets/processed_casme2/preprocessed_v7/test...
Foun

In [6]:
# @title Cell 4: CASME II ConvNeXT-Tiny M2 MFS-PREP Confusion Matrix Generation

# File: 09_03_ConvNeXT_CASME2_MFS_PREP_Cell4.py
# Location: experiments/09_03_ConvNeXT_CASME2-MFS-PREP.ipynb
# Purpose: Generate professional confusion matrix visualizations for v7 (AF) and v8 (KFS) test sets

import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from datetime import datetime

print("CASME II ConvNeXT-Tiny M2 MFS-PREP Confusion Matrix Generation")
print("=" * 60)

PROJECT_ROOT = "/content/drive/MyDrive/RESEARCH-WORKSPACE/ACTIVE-PROJECTS/Thesis_MER_Project"
RESULTS_ROOT = f"{PROJECT_ROOT}/results/09_03_convnext_casme2_mfs_prep"

def find_evaluation_json_files_casme2(results_path):
    json_files = {}

    eval_dir = f"{results_path}/evaluation_results"

    if os.path.exists(eval_dir):
        for version in ['v7', 'v8']:
            eval_pattern = f"{eval_dir}/casme2_convnext_mfs_prep_evaluation_results_{version}.json"
            eval_files = glob.glob(eval_pattern)

            if eval_files:
                json_files[f'main_{version}'] = eval_files[0]
                print(f"Found {version.upper()} evaluation file: {os.path.basename(eval_files[0])}")

        if not json_files:
            print(f"WARNING: No evaluation results found in {eval_dir}")
            print("Make sure Cell 3 (evaluation) has been executed first!")
    else:
        print(f"ERROR: Evaluation directory not found: {eval_dir}")

    return json_files

def load_evaluation_results_casme2(json_path):
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        print(f"Successfully loaded: {os.path.basename(json_path)}")
        return data
    except Exception as e:
        print(f"ERROR loading {json_path}: {str(e)}")
        return None

def calculate_weighted_f1_casme2(per_class_performance):
    total_support = sum([class_data['support'] for class_data in per_class_performance.values()
                        if class_data['support'] > 0])

    if total_support == 0:
        return 0.0

    weighted_f1 = 0.0

    for class_name, class_data in per_class_performance.items():
        if class_data['support'] > 0:
            weight = class_data['support'] / total_support
            weighted_f1 += class_data['f1_score'] * weight

    return weighted_f1

def calculate_balanced_accuracy_casme2(confusion_matrix):
    cm = np.array(confusion_matrix)
    n_classes = cm.shape[0]

    per_class_balanced_acc = []

    classes_with_samples = []
    for i in range(n_classes):
        if cm[i, :].sum() > 0:
            classes_with_samples.append(i)

    for i in classes_with_samples:
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        tn = cm.sum() - tp - fn - fp

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

        class_balanced_acc = (sensitivity + specificity) / 2
        per_class_balanced_acc.append(class_balanced_acc)

    balanced_acc = np.mean(per_class_balanced_acc) if per_class_balanced_acc else 0.0

    return balanced_acc

def determine_text_color_casme2(color_value, threshold=0.5):
    return 'white' if color_value > threshold else 'black'

def create_confusion_matrix_plot_casme2(data, output_path, test_version):
    meta = data['evaluation_metadata']
    class_names = meta['class_names']
    cm = np.array(data['confusion_matrix'], dtype=int)
    overall = data['overall_performance']
    per_class = data['per_class_performance']
    test_config = data.get('test_configuration', {})

    test_desc = test_config.get('description', test_version)
    variant = test_config.get('variant', test_version.upper())
    methodology = meta.get('methodology', 'M2')
    input_res = meta.get('input_resolution', '224x224 Pure Grayscale (1 channel)')
    training_strategy = meta.get('training_strategy', 'from_scratch')

    print(f"Processing confusion matrix for {test_version.upper()}")
    print(f"Confusion matrix shape: {cm.shape}")

    macro_f1 = overall.get('macro_f1', 0.0)
    accuracy = overall.get('accuracy', 0.0)
    weighted_f1 = calculate_weighted_f1_casme2(per_class)
    balanced_acc = calculate_balanced_accuracy_casme2(cm)

    print(f"Metrics - Macro F1: {macro_f1:.4f}, Weighted F1: {weighted_f1:.4f}, "
          f"Acc: {accuracy:.4f}, Balanced Acc: {balanced_acc:.4f}")

    row_sums = cm.sum(axis=1, keepdims=True)
    with np.errstate(divide='ignore', invalid='ignore'):
        cm_pct = np.divide(cm, row_sums, where=(row_sums!=0))
        cm_pct = np.nan_to_num(cm_pct)

    fig, ax = plt.subplots(figsize=(12, 10))

    cmap = 'Blues'

    im = ax.imshow(cm_pct, interpolation='nearest', cmap=cmap, vmin=0.0, vmax=0.8)

    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('True Class Percentage', rotation=270, labelpad=15, fontsize=11)

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            count = cm[i, j]

            if row_sums[i, 0] > 0:
                percentage = cm_pct[i, j] * 100
                text = f"{count}\n{percentage:.1f}%"
            else:
                text = f"{count}\n(N/A)"

            cell_value = cm_pct[i, j]
            text_color = determine_text_color_casme2(cell_value, threshold=0.4)

            ax.text(j, i, text, ha="center", va="center",
                   color=text_color, fontsize=9, fontweight='bold')

    ax.set_xticks(np.arange(len(class_names)))
    ax.set_yticks(np.arange(len(class_names)))
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticklabels(class_names, fontsize=10)
    ax.set_xlabel("Predicted Label", fontsize=12, fontweight='bold')
    ax.set_ylabel("True Label", fontsize=12, fontweight='bold')

    missing_classes = meta.get('missing_classes', [])
    note_text = f"Test: {test_desc} ({variant})\n{methodology} | {input_res}\nTraining: {training_strategy}"
    if missing_classes:
        note_text += f"\nMissing: {', '.join(missing_classes)}"

    ax.text(0.02, 0.98, note_text, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

    title = f"CASME II ConvNeXT-Tiny M2 MFS-PREP - {variant}\n"
    title += f"Macro F1: {macro_f1:.4f}  |  Weighted F1: {weighted_f1:.4f}  |  Acc: {accuracy:.4f}  |  Balanced Acc: {balanced_acc:.4f}"
    ax.set_title(title, fontsize=12, pad=25, fontweight='bold')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    print(f"Confusion matrix saved to: {os.path.basename(output_path)}")

    return {
        'macro_f1': macro_f1,
        'weighted_f1': weighted_f1,
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'missing_classes': missing_classes,
        'test_version': test_version,
        'variant': variant
    }

json_files = find_evaluation_json_files_casme2(RESULTS_ROOT)

if not json_files:
    print(f"ERROR: No evaluation JSON files found in {RESULTS_ROOT}")
    print("Make sure Cell 3 (evaluation) has been executed first!")
else:
    print(f"\nFound {len([k for k in json_files.keys() if k.startswith('main_')])} evaluation result(s)")

output_dir = f"{RESULTS_ROOT}/confusion_matrix_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)

results_summary = {}
generated_files = []

for version in ['v7', 'v8']:
    main_key = f'main_{version}'

    if main_key in json_files:
        print(f"\n{'='*60}")
        print(f"Processing {version.upper()} Confusion Matrix")
        print(f"{'='*60}")

        eval_data = load_evaluation_results_casme2(json_files[main_key])

        if eval_data is not None:
            try:
                cm_output_path = os.path.join(output_dir, f"confusion_matrix_CASME2_ConvNeXT_MFS_PREP_{version.upper()}.png")
                metrics = create_confusion_matrix_plot_casme2(eval_data, cm_output_path, version)
                generated_files.append(cm_output_path)

                results_summary[version] = metrics

                print(f"SUCCESS: {version.upper()} confusion matrix generated")

            except Exception as e:
                print(f"ERROR: Failed to generate {version.upper()} confusion matrix: {str(e)}")
                import traceback
                traceback.print_exc()
        else:
            print(f"ERROR: Could not load {version.upper()} evaluation data")

if generated_files:
    print(f"\n" + "=" * 60)
    print("CASME II CONVNEXT-TINY M2 MFS-PREP CONFUSION MATRIX GENERATION COMPLETED")
    print("=" * 60)

    print(f"\nGenerated confusion matrix files:")
    for file_path in generated_files:
        filename = os.path.basename(file_path)
        print(f"  {filename}")

    print(f"\nPerformance Summary:")
    for version in ['v7', 'v8']:
        if version in results_summary:
            metrics = results_summary[version]
            variant = metrics.get('variant', version.upper())
            print(f"\n{variant}:")
            print(f"  Macro F1:       {metrics['macro_f1']:.4f}")
            print(f"  Weighted F1:    {metrics['weighted_f1']:.4f}")
            print(f"  Accuracy:       {metrics['accuracy']:.4f}")
            print(f"  Balanced Acc:   {metrics['balanced_accuracy']:.4f}")

            if metrics['missing_classes']:
                print(f"  Missing classes: {len(metrics['missing_classes'])}")

    if len(results_summary) == 2:
        print(f"\nComparative Analysis:")
        v8_f1 = results_summary['v8']['macro_f1']
        v7_f1 = results_summary['v7']['macro_f1']
        delta_f1 = v8_f1 - v7_f1

        v8_variant = results_summary['v8'].get('variant', 'KFS')
        v7_variant = results_summary['v7'].get('variant', 'AF')

        print(f"  {v8_variant} vs {v7_variant} (Macro F1): {v8_f1:.4f} vs {v7_f1:.4f}")
        print(f"  Delta ({v8_variant} - {v7_variant}): {delta_f1:+.4f}")

        if delta_f1 > 0:
            improvement_pct = (delta_f1 / v7_f1) * 100
            print(f"  {v8_variant} improves by {improvement_pct:.1f}% over {v7_variant}")
        else:
            degradation_pct = (abs(delta_f1) / v8_f1) * 100
            print(f"  {v8_variant} degrades by {degradation_pct:.1f}% from {v7_variant}")

    print(f"\nFiles saved in: {output_dir}")
    print(f"Analysis completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

else:
    print(f"\nERROR: No confusion matrices were generated")
    print("Please check:")
    print("1. Cell 3 evaluation results exist")
    print("2. JSON file structure is correct")
    print("3. No file permission issues")

print("\n" + "=" * 60)
print("CNN BASELINE M2 EXPERIMENTS COMPLETED")
print("=" * 60)
print("\nAll 3 CNN baseline models with M2 preprocessing have been successfully implemented:")
print("  1. MobileNetV3-Small M2 MFS-PREP (~2.5M params)")
print("  2. EfficientNet-B0 M2 MFS-PREP (~5.3M params)")
print("  3. ConvNeXT-Tiny M2 MFS-PREP (~28M params)")
print("\nEach model:")
print("  - Trained from scratch with pure grayscale (1-channel) input")
print("  - Evaluated on dual test sets (v7 AF + v8 KFS)")
print("  - Generated confusion matrices and comprehensive metrics")
print("\nNext steps:")
print("  - Comparative analysis across all 3 CNN models")
print("  - Cross-comparison with ViT baseline experiments (M1 vs M2)")
print("  - Preprocessing paradox validation on CNN architectures")

print("\nCell 4 completed - CASME II ConvNeXT-Tiny M2 MFS-PREP confusion matrix analysis generated")

CASME II ConvNeXT-Tiny M2 MFS-PREP Confusion Matrix Generation
Found V7 evaluation file: casme2_convnext_mfs_prep_evaluation_results_v7.json
Found V8 evaluation file: casme2_convnext_mfs_prep_evaluation_results_v8.json

Found 2 evaluation result(s)

Processing V7 Confusion Matrix
Successfully loaded: casme2_convnext_mfs_prep_evaluation_results_v7.json
Processing confusion matrix for V7
Confusion matrix shape: (6, 6)
Metrics - Macro F1: 0.3439, Weighted F1: 0.4229, Acc: 0.4286, Balanced Acc: 0.6182
Confusion matrix saved to: confusion_matrix_CASME2_ConvNeXT_MFS_PREP_V7.png
SUCCESS: V7 confusion matrix generated

Processing V8 Confusion Matrix
Successfully loaded: casme2_convnext_mfs_prep_evaluation_results_v8.json
Processing confusion matrix for V8
Confusion matrix shape: (6, 6)
Metrics - Macro F1: 0.3317, Weighted F1: 0.4172, Acc: 0.4286, Balanced Acc: 0.6117
Confusion matrix saved to: confusion_matrix_CASME2_ConvNeXT_MFS_PREP_V8.png
SUCCESS: V8 confusion matrix generated

CASME II CON