In [1]:
# Environment check: print the installed PyTorch version string and the CUDA
# version reported by torch.version.cuda.
import torch
print(torch.__version__)
print(torch.version.cuda)

2.9.1+cu126
12.6


In [2]:
"""
Complete Training Script for Facial Emotion Recognition using ResNet-34
5 emotions: angry, disgust, happy, low affect (neutral & sad), arousal (fear & surprise)
"""

# ============================================================================
# CONFIGURATION
# ============================================================================

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import transforms, datasets
from pathlib import Path
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score  #accuracy score used?
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
#import os
from tqdm import tqdm


# Paths - set the EMOTION_DATA_DIR environment variable, or MODIFY the default below
import os  # imported here because this cell's import block does not import os

_DEFAULT_DATA_DIR = r"C:\Users\avyes\DrAIv2\emotion_pipeline\master_dataset"

# Directory with 7 emotion folders. The environment variable (when set) wins, so the
# notebook can run on another machine without editing this cell.
DATA_DIR = Path(os.environ.get("EMOTION_DATA_DIR", _DEFAULT_DATA_DIR))


#Ensure path exists (explicit raise instead of assert, so the guard survives `python -O`)
if not DATA_DIR.exists():
    raise FileNotFoundError(f"DATA_DIR not found: {DATA_DIR}")

#Merge neutral/sad, fear/surprise
# Order matters: the position of each name is the integer label of the merged class.
NEW_CLASSES = ["Angry", "Disgust", "Happy", "LowAffect", "Arousal"]

# Original folder name -> merged class name (every value must appear in NEW_CLASSES).
MERGE_MAP = {
    "Angry": "Angry",
    "Disgust": "Disgust",
    "Happy": "Happy",
    "Neutral": "LowAffect",
    "Sad": "LowAffect",
    "Fear": "Arousal",
    "Surprise": "Arousal",
}

# Hyperparameters
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3       # Name may be misleading: train_model later switches to per-group learning rates
NUM_WORKERS = 0            # keep at 0 if DataLoader workers cause issues on Windows
TRAIN_SPLIT = 0.70
VAL_SPLIT = 0.15
PATIENCE = 7               # epochs without improvement before early stopping
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# For reproducibility
SEED = 42


def _seed_everything(seed):
    """Seed Python's `random`, PyTorch (CPU and, when present, all CUDA devices) and NumPy."""
    random.seed(seed)

    torch.manual_seed(seed)
    np.random.seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


_seed_everything(SEED)

# Ask cuDNN for deterministic kernels and turn off its benchmark autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
class MergedImageFolder(Dataset):
    """
    ImageFolder wrapper that relabels samples from the on-disk folder classes
    into a (possibly smaller) merged class set.

    Args:
        root: directory handed to `datasets.ImageFolder` (one sub-folder per original class).
        merge_map: dict mapping every original folder name to a merged class name.
        class_names: merged class names; a name's position is its integer label.
        transform: optional callable applied to each image in `__getitem__`.

    Attributes mirror ImageFolder but live in the merged label space:
    `classes`, `class_to_idx`, `targets`, `samples`.

    Raises:
        KeyError: a folder name is missing from `merge_map`, or `merge_map`
            points at a name that is not in `class_names`.
    """
    def __init__(self, root, merge_map, class_names, transform=None):
        # The base dataset is built without a transform; ours is applied in __getitem__.
        self.base = datasets.ImageFolder(root=root)
        self.merge_map = merge_map
        self.transform = transform

        # Merged-class interface (ImageFolder-like)
        self.classes = list(class_names)
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        # Lookup table: original folder index -> merged class index
        self.remap = {}
        for folder_name, folder_idx in self.base.class_to_idx.items():
            if folder_name not in self.merge_map:
                raise KeyError(f"Missing '{folder_name}' in merge_map keys: {list(self.merge_map.keys())}")
            merged_name = self.merge_map[folder_name]
            if merged_name not in self.class_to_idx:
                raise KeyError(f"merge_map maps '{folder_name}' -> '{merged_name}', but '{merged_name}' not in class_names")
            self.remap[folder_idx] = self.class_to_idx[merged_name]

        # Labels and (path, label) pairs expressed in the merged label space
        # (used for class weighting / splitting without loading any image).
        self.targets = [self.remap[label] for label in self.base.targets]
        self.samples = [(path, self.remap[label]) for path, label in self.base.samples]

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        # The base dataset returns the untransformed image together with its original label.
        image, original_label = self.base[i]
        merged_label = self.remap[original_label]
        if not self.transform:
            return image, merged_label
        return self.transform(image), merged_label


In [4]:
class FocalLoss(nn.Module):  #only used when make_class_weighted_criterion is called with use_focal=True
    """
    Multi-class Focal Loss for logits.
    - inputs: logits of shape (N, C)
    - targets: int labels of shape (N,)
    Supports optional class weights and label smoothing.

    Args:
        gamma: focusing exponent; the per-sample loss is scaled by (1 - p_t) ** gamma.
        weight: optional per-class weights of shape (C,). They are moved to the
            device/dtype of the log-probabilities on every forward call, so they do
            not have to be created on the same device as the inputs.
        label_smoothing: mass taken from the true class and spread evenly over the
            other C - 1 classes (0.0 disables smoothing).
        reduction: "mean" or "sum"; any other value returns the per-sample loss (N,).
    """
    def __init__(self, gamma=1.5, weight=None, label_smoothing=0.0, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.weight = weight  # tensor (or sequence) of shape (C,); aligned to the inputs in forward()
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, inputs, targets):
        # log_probs: (N, C)
        log_probs = F.log_softmax(inputs, dim=1)
        probs = log_probs.exp()

        n_classes = inputs.size(1)

        # Align the class weights with the log-probabilities; a weight tensor left on
        # another device (or with another dtype) would otherwise fail below.
        weight = self.weight
        if weight is not None:
            weight = torch.as_tensor(weight, dtype=log_probs.dtype, device=log_probs.device)

        if self.label_smoothing > 0.0:
            # Smoothed one-hot targets: (N, C)
            with torch.no_grad():
                # max(..., 1) avoids a division by zero when there is a single class
                off_value = self.label_smoothing / max(n_classes - 1, 1)
                true_dist = torch.full_like(log_probs, off_value)
                true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - self.label_smoothing)

            # CE per sample: -sum(y * logp)
            ce = -(true_dist * log_probs).sum(dim=1)

            # p_t: expected probability under the smoothed target distribution
            pt = (true_dist * probs).sum(dim=1).clamp(min=1e-8, max=1.0)

            if weight is not None:
                # Expected class weight under the smoothed target distribution
                w = (true_dist * weight.unsqueeze(0)).sum(dim=1)
                ce = ce * w

        else:
            # Standard CE per sample using the true class index
            ce = F.nll_loss(log_probs, targets, weight=weight, reduction="none")
            pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1).clamp(min=1e-8, max=1.0)

        # Down-weight well-classified samples (p_t close to 1)
        focal_factor = (1.0 - pt) ** self.gamma
        loss = focal_factor * ce

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss


In [5]:
class TransformedSubset(Dataset):   #obsolete, used for the 7 emotion classes version
    """
    Pairs an existing dataset (or subset) with its own transform.

    The transform is applied on access, so the wrapped dataset itself is
    never modified.
    """
    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        # feature is the sample, label is its class; only the feature is transformed
        feature, label = self.subset[index]
        if not self.transform:
            return feature, label
        return self.transform(feature), label

def get_data_loaders(data_dir, batch_size=32, train_split=TRAIN_SPLIT, val_split=VAL_SPLIT, num_workers=0, seed=SEED):  #used for 7 emotion classes version
    """
    Build train/val/test DataLoaders from one directory of class folders.

    The folder is read with ImageFolder (folder names become labels), split
    randomly with a generator seeded by `seed`, and each part is wrapped so that
    only the training part receives augmentation.

    Returns:
        (train_loader, val_loader, test_loader, train_raw, val_raw, test_raw, full_dataset)
        where the `*_raw` items are the untransformed random_split subsets.
    """
    # The two fractions must leave room for a test partition
    assert 0 < train_split < 1
    assert 0 <= val_split < 1
    assert train_split + val_split < 1, "train_split + val_split must be < 1"

    side = 224

    # Standard ResNet (ImageNet) normalization
    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Augmenting pipeline for training, deterministic pipeline for val/test
    augmenting = transforms.Compose([
        transforms.RandomResizedCrop(side, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),   #consider changing this for warmup stage
        transforms.ToTensor(),
        imagenet_norm,
    ])
    deterministic = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        imagenet_norm,
    ])

    # Whole dataset, no transform yet (the wrappers below add it)
    full_dataset = datasets.ImageFolder(root=data_dir)
    print("Classes:", full_dataset.classes)
    print("Counts:", np.bincount(full_dataset.targets))

    # Partition sizes: the test part takes whatever the integer truncation leaves over
    n_total = len(full_dataset)
    n_train = int(n_total*train_split)
    n_val = int(n_total*val_split)
    lengths = [n_train, n_val, n_total - n_train - n_val]

    split_rng = torch.Generator().manual_seed(seed)
    train_raw, val_raw, test_raw = random_split(full_dataset, lengths, generator=split_rng)

    # pin_memory is enabled exactly when CUDA is available
    pin = torch.cuda.is_available()

    def _make_loader(raw_subset, transform, shuffle):
        # Wrapping per subset keeps training augmentations away from val/test
        wrapped = TransformedSubset(raw_subset, transform=transform)
        return DataLoader(wrapped, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin)

    train_loader = _make_loader(train_raw, augmenting, True)
    val_loader = _make_loader(val_raw, deterministic, False)
    test_loader = _make_loader(test_raw, deterministic, False)

    return train_loader, val_loader, test_loader, train_raw, val_raw, test_raw, full_dataset

def get_data_loaders_merged(data_dir, batch_size=32, train_split=0.7, val_split=0.15, num_workers=0, seed=SEED,
                            merge_map=None, class_names=None):
    """
    Build train/val/test DataLoaders over the merged-class dataset.

    Two MergedImageFolder views of the same directory are created (one with the
    augmenting transform, one with the deterministic transform); a single seeded
    permutation of the indices is then cut into train/val/test so the three
    partitions never overlap.

    Args:
        data_dir: directory with one sub-folder per original class.
        batch_size, num_workers: forwarded to every DataLoader.
        train_split, val_split: fractions of the data; the remainder is the test set.
        seed: seed of the generator that produces the index permutation.
        merge_map: original-folder-name -> merged-name dict (defaults to MERGE_MAP).
        class_names: merged class names in label order (defaults to NEW_CLASSES).

    Returns:
        (train_loader, val_loader, test_loader, train_ds, val_ds, test_ds,
         full_train_ds, full_val_ds)
    """
    # Same guards as get_data_loaders: without them an over-sized split is silently
    # truncated by the slicing below and the test partition ends up empty.
    assert 0 < train_split < 1
    assert 0 <= val_split < 1
    assert train_split + val_split < 1, "train_split + val_split must be < 1"

    if merge_map is None:
        merge_map = MERGE_MAP
    if class_names is None:
        class_names = NEW_CLASSES

    IMG_SIZE = 224
    normalization = transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalization
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalization
    ])

    # Same files, two transforms: indices are interchangeable between the two views
    full_train_ds = MergedImageFolder(root=data_dir, merge_map=merge_map, class_names=class_names, transform=train_transform)
    full_val_ds = MergedImageFolder(root=data_dir, merge_map=merge_map, class_names=class_names, transform=val_transform)

    N = len(full_train_ds)
    train_size = int(N*train_split)
    val_size = int(N*val_split)
    test_size = N - train_size - val_size

    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(N, generator=g).tolist()

    train_idx = perm[:train_size]
    val_idx = perm[train_size:train_size + val_size]
    test_idx = perm[train_size + val_size:]

    # Train indexes the augmenting view; val/test index the deterministic view
    train_ds = Subset(full_train_ds, train_idx)
    val_ds   = Subset(full_val_ds, val_idx)
    test_ds  = Subset(full_val_ds, test_idx)

    # split size sanity check
    assert len(test_idx) == test_size
    assert len(train_idx) + len(val_idx) + len(test_idx) == N
    print("Split sizes:", len(train_idx), len(val_idx), len(test_idx))

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, 
                              pin_memory=torch.cuda.is_available())
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                              pin_memory=torch.cuda.is_available())
    test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                              pin_memory=torch.cuda.is_available())

    return train_loader, val_loader, test_loader, train_ds, val_ds, test_ds, full_train_ds, full_val_ds
    
        


In [6]:
from collections import Counter

def make_class_weighted_criterion(train_subset, full_dataset, device, use_focal=True,
                                  gamma=1.5, label_smoothing=0.0):
    """
    Build a loss whose class weights are the inverse class frequencies of the
    training subset, rescaled so the mean weight is 1.

    Args:
        train_subset: the Subset being trained on (must expose `.indices`).
        full_dataset: the dataset those indices refer to; must expose `.targets`
            and `.classes`.
        device: device the weight tensor is moved to.
        use_focal: build a FocalLoss when True, otherwise nn.CrossEntropyLoss.
        gamma: focal exponent (only used when use_focal is True).
        label_smoothing: forwarded to the chosen loss.

    Returns:
        (criterion, weights)

    Raises:
        TypeError: `train_subset` has no `.indices` attribute.
    """
    # guard
    if not hasattr(train_subset, "indices"):
        raise TypeError("train_subset must be a torch.utils.data.Subset (needs .indices)")

    # Count labels over the training indices only (val/test samples never contribute)
    label_counter = Counter(full_dataset.targets[idx] for idx in train_subset.indices)

    num_classes = len(full_dataset.classes)
    per_class = [label_counter.get(cls, 0) for cls in range(num_classes)]
    # A class with no training sample is counted as 1 to avoid dividing by zero
    class_counts = torch.tensor(per_class, dtype=torch.float).clamp_min(1.0)

    inverse_freq = 1.0 / class_counts
    weights = (inverse_freq / inverse_freq.mean()).to(device)   # average weight ~ 1

    # sanity check
    print("Class counts (train only):", class_counts.cpu().numpy())
    print("Weights:", weights.cpu().numpy())
    print("Classes:", full_dataset.classes)

    if use_focal:
        # NOTE: class weights are passed to the focal loss as well
        return FocalLoss(gamma=gamma, weight=weights, label_smoothing=label_smoothing), weights

    # Standard cross entropy; fall back to the call without label_smoothing when the
    # installed CrossEntropyLoss rejects that argument.
    try:
        criterion = torch.nn.CrossEntropyLoss(weight=weights, label_smoothing=label_smoothing)
    except TypeError:
        criterion = torch.nn.CrossEntropyLoss(weight=weights)
    return criterion, weights

In [7]:
# ============================================================================
# MODEL CREATION
# ============================================================================
from torchvision.models import resnet34, ResNet34_Weights

# helper for manipulating dropout
def set_head_dropout(model, p: float):
    """
    Set the drop probability of every nn.Dropout inside `model.fc`.

    Does nothing unless `model.fc` exists and is an nn.Sequential
    (e.g. Sequential(Dropout, Linear, ...)).
    """
    head = getattr(model, "fc", None)
    if not isinstance(head, nn.Sequential):
        return
    for layer in head:
        if isinstance(layer, nn.Dropout):
            layer.p = p

def create_resnet34_model(num_classes=5, pretrained=True, dropout_p=0.4):    #change dropout to 0.5 in case of overfitting
    """
    Create a ResNet-34 whose classifier head is Dropout -> Linear(num_classes).

    Args:
        num_classes: number of output logits.
        pretrained: load the IMAGENET1K_V1 weights when True, otherwise no weights.
        dropout_p: drop probability of the dropout layer placed before the linear layer.
    """
    backbone_weights = ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
    model = resnet34(weights=backbone_weights)

    # The new head keeps the input width of the stock fc layer; dropout randomly
    # zeroes features during training to reduce overfitting.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout_p),
        nn.Linear(in_features, num_classes),
    )
    return model
    
# =========================
# Freeze/unfreeze utilities
# =========================
def set_backbone_trainable(model, trainable: bool):
    """
    Toggle gradient tracking for the backbone.

    Parameters whose name starts with "fc." (the head) are always set trainable;
    every other parameter gets `requires_grad = trainable`.
    """
    for param_name, param in model.named_parameters():
        is_head = param_name.startswith("fc.")
        param.requires_grad = True if is_head else trainable
            
# =======================================
# Build optimizer with discriminative LRs:
# =======================================
def build_optimizer(model, lr_backbone=1e-4, lr_head=5e-4, weight_decay=1e-4):
    """
    Create an AdamW optimizer with two parameter groups.

    Group 0 holds the trainable backbone parameters (lr_backbone), group 1 the
    trainable head parameters, i.e. names starting with "fc." (lr_head).
    Parameters with requires_grad == False are left out of both groups.
    """
    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    head_params = [p for name, p in trainable if name.startswith("fc.")]
    backbone_params = [p for name, p in trainable if not name.startswith("fc.")]

    param_groups = [
        {"params": backbone_params, "lr": lr_backbone},
        {"params": head_params, "lr": lr_head},
    ]
    return optim.AdamW(param_groups, weight_decay=weight_decay)

# =======================================
# Helper functions
# =======================================
def count_trainable_params(model):
    """Return (trainable, total): element counts of the parameters with requires_grad and of all parameters."""
    trainable = total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total

In [8]:
# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

#flip_TTA helper function

#@torch.no_grad()
#def forward_tta(model, images, use_flip_tta: bool):
#    if not use_flip_tta:
#        return model(images)
#    
#    logits = model(images)
#    logits_flip = model(torch.flip(images, dims=[3])) # dim=3 is width (horizontal flip)
#    return 0.5 * (logits + logits_flip)

@torch.no_grad()
def forward_tta(model, images, use_flip_tta: bool):
    """
    Enhanced Test-Time Augmentation with multiple crops.

    With `use_flip_tta` False this is a plain `model(images)` call. Otherwise the
    model is run on 7 views of the batch and the outputs are averaged:
    the original, its horizontal flip, and five square crops (four corners and
    the center) that are resized back to the input resolution.

    Args:
        model: callable mapping an image batch to logits.
        images: batch indexed as (N, C, H, W) — dims 2 and 3 are treated as height and width.
        use_flip_tta: enables the 7-view averaging (the name is historical; it is
            more than a flip).

    Returns:
        The model output, or the mean of the 7 outputs.
    """
    if not use_flip_tta:
        return model(images)

    h, w = images.size(2), images.size(3)
    # Square crop covering 90% of the shorter side (never smaller than one pixel)
    crop_size = max(1, int(0.9 * min(h, w)))

    # The center crop needs its own margin per axis; reusing the height margin for
    # the width puts the crop off-center on non-square inputs.
    center_top = (h - crop_size) // 2
    center_left = (w - crop_size) // 2

    # (top, left) corner of each crop: top-left, top-right, bottom-left, bottom-right, center
    crop_origins = [
        (0, 0),
        (0, w - crop_size),
        (h - crop_size, 0),
        (h - crop_size, w - crop_size),
        (center_top, center_left),
    ]

    # 1. Original image, 2. horizontal flip (dim 3 is width)
    all_logits = [model(images), model(torch.flip(images, dims=[3]))]

    # 3. Five crops, each resized back to (h, w) before the forward pass
    for top, left in crop_origins:
        crop = images[:, :, top:top + crop_size, left:left + crop_size]
        crop_resized = F.interpolate(crop, size=(h, w), mode='bilinear', align_corners=False)
        all_logits.append(model(crop_resized))

    # Average all predictions (7 total)
    return torch.stack(all_logits).mean(dim=0)

# Helper function 
def set_backbone_bn_eval(model):
    """Switch every nn.BatchNorm2d outside the head (module names not starting with "fc.") to eval mode."""
    for module_name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d) and not module_name.startswith("fc."):
            module.eval()

def _mixup_data(x, y, alpha=0.2):                #effectively disabled: train_one_epoch defaults to alpha=0.0
    """
    Mix each sample of the batch with a randomly chosen partner.

    Returns (mixed_x, y_a, y_b, lam): y_a are the original targets, y_b the
    partners' targets, and lam ~ Beta(alpha, alpha) the weight of the original.
    With alpha <= 0 nothing is mixed and (x, y, None, 1.0) is returned.
    """
    if alpha <= 0:
        return x, y, None, 1.0

    # One mixing coefficient for the whole batch, then one random partner per sample
    lam = np.random.beta(alpha, alpha)
    partner = torch.randperm(x.size(0), device=x.device)

    mixed_x = lam * x + (1.0 - lam) * x[partner]
    return mixed_x, y, y[partner], lam


    
def train_one_epoch(model, train_loader, criterion, optimizer,
                    device, freeze_backbone_bn=False, mixup_alpha=0.0):
    """
    Train for one epoch.

    Args:
        model: network to train (put into train mode here).
        train_loader: iterable of (images, labels) batches.
        criterion: loss function taking (logits, labels), e.g. nn.CrossEntropyLoss.
        optimizer: optimizer that updates the model's parameters.
        device: device the batches are moved to.
        freeze_backbone_bn: keep backbone BatchNorm layers in eval mode
            (usually during head-only warmup).
        mixup_alpha: Beta parameter for mixup; 0.0 disables mixup.

    Returns:
        (epoch_loss, epoch_acc): mean loss per batch and accuracy in percent.
        With mixup the accuracy is measured against the un-mixed labels, so it is
        only an approximate monitoring metric.
    """
    model.train()

    # model.train() switched every layer to train mode; undo that for backbone BN if asked
    if freeze_backbone_bn:
        set_backbone_bn_eval(model)

    loss_sum = 0.0
    n_batches = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc='Training')   # progress bar, display only

    for images, labels in progress:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Drop the gradients of the previous batch before computing new ones
        optimizer.zero_grad(set_to_none=True)

        # Mixup (returns the batch untouched and y_b=None when mixup_alpha <= 0)
        images, y_a, y_b, lam = _mixup_data(images, labels, alpha=mixup_alpha)

        # Raw logits, one row per sample
        outputs = model(images)

        if y_b is not None:
            loss = lam * criterion(outputs, y_a) + (1.0 - lam) * criterion(outputs, y_b)
        else:
            loss = criterion(outputs, y_a)              # y_a is the original labels

        loss.backward()                                 # back propagation
        optimizer.step()                                # parameter update

        loss_sum += loss.item()                         # Python float of the scalar loss
        n_batches += 1

        # Running accuracy against the original labels (argmax of the logits)
        predicted = outputs.max(1)[1]
        n_seen += y_a.size(0)
        n_correct += predicted.eq(y_a).sum().item()

        progress.set_postfix({
            'loss': loss_sum / n_batches,
            'acc': 100. * n_correct / n_seen
        })

    epoch_loss = loss_sum / n_batches
    epoch_acc = 100. * n_correct / n_seen

    return epoch_loss, epoch_acc


def validate(model, val_loader, criterion, device, use_flip_tta=False):
    """
    Evaluate the model on a loader without updating it.

    Returns:
        (epoch_loss, epoch_acc): loss averaged over batches and accuracy in percent.
    """
    # Evaluation mode: dropout off, BatchNorm uses its running statistics
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # No gradients are needed here, which saves memory and time
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='Validation'):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = forward_tta(model, images, use_flip_tta)
            loss_sum += criterion(outputs, labels).item()

            predicted = outputs.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

    epoch_loss = loss_sum / len(val_loader)
    epoch_acc = 100. * n_correct / n_seen
    return epoch_loss, epoch_acc


def train_model(model, train_loader, val_loader, criterion, optimizer, 
                scheduler, device, num_epochs=50, patience=7, freeze_epochs=5,
                checkpoint_path='best_model.pth'):
    """
    Main training loop with early stopping.

    Phase A (epochs < freeze_epochs): the backbone is frozen and the passed-in
    `optimizer` is used; the scheduler is not stepped.
    Phase B (from epoch == freeze_epochs): the backbone is unfrozen and BOTH the
    optimizer and the scheduler are rebuilt (discriminative learning rates,
    ReduceLROnPlateau on the validation loss). For freeze_epochs >= 0 the
    passed-in `scheduler` is therefore never stepped.

    Args:
        model, train_loader, val_loader, criterion, device: as used by
            train_one_epoch / validate.
        optimizer: optimizer for the warmup phase.
        scheduler: kept for interface compatibility (see above).
        num_epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a new best
            validation accuracy (warmup epochs count too).
        freeze_epochs: number of head-only warmup epochs.
        checkpoint_path: file the best state_dict is saved to.

    Returns:
        (model, best_val_acc): the model holds the weights of the LAST epoch
        trained, not necessarily the best ones — those are on disk at
        `checkpoint_path`.
    """
    best_val_acc = 0.0
    patience_counter = 0
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)

    # Phase A: freeze backbone
    set_backbone_trainable(model, trainable=False)
    set_backbone_bn_eval(model)
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 60)
        
        # Unfreeze after warmup and rebuild optimizer/scheduler once
        if epoch == freeze_epochs:
            print("→ Unfreezing backbone and switching to discriminative learning rates")

            set_backbone_trainable(model, trainable=True)

            optimizer = build_optimizer(model, lr_backbone=1e-4, lr_head=5e-4, weight_decay=1e-4)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=2
            )
            
        
        # Train (mixup stays disabled in both phases)
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            freeze_backbone_bn=(epoch < freeze_epochs),
            mixup_alpha=0.0
        )
        
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device, use_flip_tta=False) #set to True to enable TTA
        
        # Update learning rate (only once the backbone is being fine-tuned)
        if epoch >= freeze_epochs:
            scheduler.step(val_loss)
            
        
        if len(optimizer.param_groups) == 1:
            lr = optimizer.param_groups[0]['lr']
            print(f'Learning Rate (head-only): {lr:.6f}')
        else:
            lr_backbone = optimizer.param_groups[0]['lr']
            lr_head = optimizer.param_groups[1]['lr']
            print(f'Learning Rate (backbone): {lr_backbone:.6f} | (head): {lr_head:.6f}')
        
        
        # Store metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f'✓ New best model saved! (Val Acc: {val_acc:.2f}%)')
            patience_counter = 0
        else:
            patience_counter += 1
            print(f'No improvement. Patience: {patience_counter}/{patience}')
        
        # Early stopping
        if patience_counter >= patience:
            print(f'\n⚠ Early stopping triggered at epoch {epoch+1}')
            break
    
    # Plot training history
    plot_training_history(train_losses, val_losses, train_accs, val_accs)      

    
    return model, best_val_acc                                                   #model carries last-epoch weights; the best weights
                                                                                #   are the ones saved to checkpoint_path


In [9]:
# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================

def plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """
    Plot training history.

    Draws loss (left) and accuracy (right) curves for train and validation,
    one point per epoch, and saves the figure as 'training_history.png'.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # One row per panel: axis, train curve, val curve, y label, title
    panels = [
        (axes[0], (train_losses, 'Train Loss'), (val_losses, 'Val Loss'),
         'Loss', 'Training and Validation Loss'),
        (axes[1], (train_accs, 'Train Acc'), (val_accs, 'Val Acc'),
         'Accuracy (%)', 'Training and Validation Accuracy'),
    ]

    for ax, (train_series, train_label), (val_series, val_label), y_label, title in panels:
        ax.plot(train_series, label=train_label, marker='o')
        ax.plot(val_series, label=val_label, marker='s')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
    print("\n✓ Training history plot saved as 'training_history.png'")
    # Close the figure so it is not kept open after saving
    plt.close()


def plot_confusion_matrix(y_true, y_pred, class_names):
    """
    Plot the confusion matrix and save it as 'confusion_matrix.png'.

    Args:
        y_true: true integer labels.
        y_pred: predicted integer labels.
        class_names: class names; index i names label i.
    """
    # Pin the label set so the matrix always has one row/column per class name.
    # Without it, a class absent from both y_true and y_pred shrinks the matrix
    # and the tick labels no longer line up with the cells.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'})
    plt.title('Confusion Matrix', fontsize=16, pad=20)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    print("✓ Confusion matrix saved as 'confusion_matrix.png'")
    plt.close()


In [10]:
# ============================================================================
# Model Evaluation
# ============================================================================

def evaluate_model(model, test_loader, device, class_names, use_flip_tta=False):
    """
    Evaluate model on test set and plot confusion matrix.

    Args:
        model: network to evaluate (put into eval mode here).
        test_loader: iterable of (images, labels) batches.
        device: device the batches are moved to.
        class_names: class names; index i names label i.
        use_flip_tta: forwarded to forward_tta.

    Returns:
        Test accuracy in percent.

    Raises:
        ValueError: the loader yielded no samples.
    """
    
    model.eval()
    all_preds = []
    all_labels = []
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            
            outputs = forward_tta(model, images, use_flip_tta)
            predicted = outputs.argmax(dim=1)
            
            # Accumulate predictions and labels as plain Python ints
            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())
            
            # Accuracy update
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader would otherwise surface as a bare ZeroDivisionError below
    if total == 0:
        raise ValueError("test_loader yielded no samples; nothing to evaluate")
    
    test_acc = 100.0 * correct / total
    print(f"\n✓ Test Accuracy: {test_acc:.2f}%")

    # Sanity checks
    num_classes = len(class_names)
    assert min(all_labels) >= 0 and max(all_labels) < num_classes
    assert min(all_preds)  >= 0 and max(all_preds)  < num_classes

    # Per-class metrics (high value). `labels` pins the report to every class:
    # without it, a class missing from the test data makes target_names
    # disagree with the labels found and the call fails.
    print(classification_report(all_labels, all_preds, labels=list(range(num_classes)),
                                target_names=class_names, digits=4))
    
    # Plot confusion matrix
    plot_confusion_matrix(all_labels, all_preds, class_names)
    
    return test_acc


In [11]:
# ============================================================================
# MAIN EXECUTION
# ============================================================================

# NOTE(review): debug marker; prints once when this cell is executed, before main() is defined.
print("About to define main()")

def main():
    """Run the full pipeline: load data, train, reload the best checkpoint, evaluate.

    Steps, in order:
      1. Print the run configuration from the module-level constants.
      2. Build the train/val/test loaders with ``get_data_loaders_merged``.
      3. Create the ResNet-34 model and move it to ``DEVICE``.
      4. Build the class-weighted criterion, an AdamW optimizer over the
         ``model.fc`` parameters, and a ``ReduceLROnPlateau`` scheduler.
      5. Train with ``train_model``.
      6. Load ``best_model.pth`` and evaluate on the test loader twice: once
         with ``use_flip_tta=False`` and once with ``use_flip_tta=True``.
      7. Print a comparison of the two test accuracies and a final summary.

    Side effects:
        Reads ``best_model.pth`` from the working directory. Each evaluation
        writes ``confusion_matrix.png`` (via ``evaluate_model``); the plot from
        the non-TTA run is moved to ``confusion_matrix_standard.png`` so the
        TTA run does not overwrite it.

    Returns:
        None.
    """
    print("\n" + "=" * 60)
    print("FACIAL EMOTION RECOGNITION - ResNet34")
    print("=" * 60)
    print(f"Device: {DEVICE}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Initial Head Learning Rate: {LEARNING_RATE}")
    print(f"Number of Epochs: {NUM_EPOCHS}")
    print(f"Validation Split: {VAL_SPLIT}")
    print("=" * 60)

    # Debug: show the merged label configuration.
    print("Merged classes:", NEW_CLASSES)
    print("MERGE_MAP keys example:", list(MERGE_MAP.keys())[:5] if isinstance(MERGE_MAP, dict) else MERGE_MAP)

    # Load data
    train_loader, val_loader, test_loader, train_ds, val_ds, test_ds, full_train_ds, full_val_ds = get_data_loaders_merged(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        train_split=TRAIN_SPLIT,
        val_split=VAL_SPLIT,
        num_workers=NUM_WORKERS
    )

    class_names = full_train_ds.classes

    # Debug: confirm the class-to-index mapping used by the dataset.
    print("Merged class_to_idx:", full_train_ds.class_to_idx)
    print("Merged classes:", full_train_ds.classes)

    print("\nData loaders ready:")
    print(f"  Train batches: {len(train_loader)}")
    print(f"  Val batches: {len(val_loader)}")
    print(f"  Test batches: {len(test_loader)}")

    # Debug: pull one batch to confirm tensor shapes before training starts.
    images, labels = next(iter(train_loader))
    print("One batch:", images.shape, labels.shape)

    # Create model
    print("\n" + "=" * 60)
    print("INITIALIZING MODEL")
    print("=" * 60)
    model = create_resnet34_model(num_classes=len(class_names), pretrained=True)
    model = model.to(DEVICE)
    print("✓ ResNet-34 model created (pretrained on ImageNet)")
    print(f"✓ Final layer modified for {len(class_names)} classes")

    # Dropout check: the first module of the head is expected to expose `.p`.
    print("Initial head dropout:", model.fc[0].p)

    # Loss and optimizer
    criterion, class_weights = make_class_weighted_criterion(
        train_ds, full_train_ds, DEVICE,
        use_focal=False,
        gamma=1.5,
        label_smoothing=0.08         # 0.05 if not using Focal
    )
    print("Class weights:", class_weights.detach().cpu().numpy())
    print("Criterion:", criterion)

    # Only the head (model.fc) parameters are handed to the optimizer here.
    optimizer = optim.AdamW(model.fc.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    # Train model
    model, best_val_acc = train_model(
        model, train_loader, val_loader, criterion,
        optimizer, scheduler, DEVICE,
        num_epochs=NUM_EPOCHS, patience=PATIENCE,
        freeze_epochs=5
    )

    # Load best model for testing
    print("\n" + "=" * 60)
    print("LOADING BEST MODEL FOR TESTING")
    print("=" * 60)
    state_dict = torch.load('best_model.pth', map_location=DEVICE)  # always load onto the active device
    model.load_state_dict(state_dict)
    model.to(DEVICE)  # redundant after map_location, kept as a safeguard
    print(f"✓ Loaded best model (Val Acc: {best_val_acc:.2f}%)")

    # Evaluate on test set WITHOUT TTA (baseline)
    print("\n" + "="*60)
    print("STANDARD EVALUATION (No TTA)")
    print("="*60)
    test_acc_standard = evaluate_model(model, test_loader, DEVICE, class_names, use_flip_tta=False)

    # Each evaluation writes its plot to the fixed path 'confusion_matrix.png', so
    # the TTA run below would overwrite the baseline plot. Move it aside first.
    baseline_cm = Path('confusion_matrix.png')
    if baseline_cm.exists():
        baseline_cm.replace('confusion_matrix_standard.png')
        print("✓ Baseline confusion matrix kept as 'confusion_matrix_standard.png'")

    # Evaluate on test set WITH TTA
    # NOTE(review): the banner says "7 augmentations" but the flag passed is
    # `use_flip_tta`; confirm against forward_tta that the count is accurate.
    print("\n" + "="*60)
    print("ENHANCED TTA EVALUATION (7 augmentations)")
    print("="*60)
    test_acc_tta = evaluate_model(model, test_loader, DEVICE, class_names, use_flip_tta=True)

    print("\n" + "="*60)
    print("FINAL COMPARISON")
    print("="*60)
    print(f"Standard Accuracy: {test_acc_standard:.2f}%")
    print(f"TTA Accuracy:      {test_acc_tta:.2f}%")
    # The '+' format flag prints the real sign, so a regression shows as "-0.50%"
    # instead of the malformed "+-0.50%".
    print(f"Improvement:       {(test_acc_tta - test_acc_standard):+.2f}%")

    # Final summary
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE - SUMMARY")
    print("=" * 60)
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc_tta:.2f}%")
    print("Model saved as: best_model.pth")
    print("=" * 60)

# Debug trace printed immediately before the pipeline starts.
print("About to call main() explicitly")
# The usual entry-point guard is left disabled below; main() is invoked
# unconditionally every time this cell runs.
# if __name__ == "__main__":
#     main()
main()
# Debug trace: only reached if main() returned without raising.
print("returned from main")

About to define main()
About to call main() explicitly

FACIAL EMOTION RECOGNITION - ResNet34
Device: cuda
Batch Size: 32
Initial Head Learning Rate: 0.001
Number of Epochs: 50
Validation Split: 0.15
Merged classes: ['Angry', 'Disgust', 'Happy', 'LowAffect', 'Arousal']
MERGE_MAP keys example: ['Angry', 'Disgust', 'Happy', 'Neutral', 'Sad']
Split sizes: 36149 7746 7747
Merged class_to_idx: {'Angry': 0, 'Disgust': 1, 'Happy': 2, 'LowAffect': 3, 'Arousal': 4}
Merged classes: ['Angry', 'Disgust', 'Happy', 'LowAffect', 'Arousal']

Data loaders ready:
  Train batches: 1130
  Val batches: 243
  Test batches: 243
One batch: torch.Size([32, 3, 224, 224]) torch.Size([32])

INITIALIZING MODEL
✓ ResNet-34 model created (pretrained on ImageNet)
✓ Final layer modified for 5 classes
Initial head dropout: 0.4
Class counts (train only): [ 4808.  3008.  6472. 13318.  8543.]
Weights: [1.1723021  1.8738128  0.87089443 0.42321885 0.6597716 ]
Classes: ['Angry', 'Disgust', 'Happy', 'LowAffect', 'Arousal']
Cl

Training: 100%|████████████████████████████████████████████████| 1130/1130 [26:36<00:00,  1.41s/it, loss=1.6, acc=31.1]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [05:20<00:00,  1.32s/it]


Learning Rate (head-only): 0.001000
Train Loss: 1.6026, Train Acc: 31.07%
Val Loss: 1.5358, Val Acc: 38.26%
✓ New best model saved! (Val Acc: 38.26%)

Epoch 2/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [17:45<00:00,  1.06it/s, loss=1.57, acc=33.1]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:59<00:00,  1.35it/s]


Learning Rate (head-only): 0.001000
Train Loss: 1.5687, Train Acc: 33.12%
Val Loss: 1.4982, Val Acc: 37.27%
No improvement. Patience: 1/7

Epoch 3/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [16:25<00:00,  1.15it/s, loss=1.57, acc=33.2]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:50<00:00,  1.43it/s]


Learning Rate (head-only): 0.001000
Train Loss: 1.5718, Train Acc: 33.19%
Val Loss: 1.5110, Val Acc: 38.05%
No improvement. Patience: 2/7

Epoch 4/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [15:11<00:00,  1.24it/s, loss=1.57, acc=33.2]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:46<00:00,  1.46it/s]


Learning Rate (head-only): 0.001000
Train Loss: 1.5707, Train Acc: 33.16%
Val Loss: 1.5112, Val Acc: 33.02%
No improvement. Patience: 3/7

Epoch 5/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [14:56<00:00,  1.26it/s, loss=1.56, acc=33.9]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:46<00:00,  1.46it/s]


Learning Rate (head-only): 0.001000
Train Loss: 1.5587, Train Acc: 33.88%
Val Loss: 1.4955, Val Acc: 36.16%
No improvement. Patience: 4/7

Epoch 6/50
------------------------------------------------------------
→ Unfreezing backbone and switching to discriminative learning rates


Training: 100%|█████████████████████████████████████████████████| 1130/1130 [21:08<00:00,  1.12s/it, loss=1.21, acc=59]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:48<00:00,  1.44it/s]


Learning Rate (backbone): 0.000100 | (head): 0.000500
Train Loss: 1.2060, Train Acc: 59.03%
Val Loss: 1.0421, Val Acc: 67.47%
✓ New best model saved! (Val Acc: 67.47%)

Epoch 7/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [21:25<00:00,  1.14s/it, loss=1.04, acc=67.6]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:55<00:00,  1.38it/s]


Learning Rate (backbone): 0.000100 | (head): 0.000500
Train Loss: 1.0377, Train Acc: 67.63%
Val Loss: 1.0253, Val Acc: 69.49%
✓ New best model saved! (Val Acc: 69.49%)

Epoch 8/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [21:18<00:00,  1.13s/it, loss=0.99, acc=70.4]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:55<00:00,  1.38it/s]


Learning Rate (backbone): 0.000100 | (head): 0.000500
Train Loss: 0.9905, Train Acc: 70.38%
Val Loss: 0.9861, Val Acc: 69.80%
✓ New best model saved! (Val Acc: 69.80%)

Epoch 9/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [21:26<00:00,  1.14s/it, loss=0.96, acc=71.9]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:56<00:00,  1.38it/s]


Learning Rate (backbone): 0.000100 | (head): 0.000500
Train Loss: 0.9599, Train Acc: 71.86%
Val Loss: 1.0013, Val Acc: 69.98%
✓ New best model saved! (Val Acc: 69.98%)

Epoch 10/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [21:15<00:00,  1.13s/it, loss=0.93, acc=73.5]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:57<00:00,  1.37it/s]


Learning Rate (backbone): 0.000100 | (head): 0.000500
Train Loss: 0.9305, Train Acc: 73.49%
Val Loss: 0.9798, Val Acc: 69.17%
No improvement. Patience: 1/7

Epoch 11/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [21:28<00:00,  1.14s/it, loss=0.907, acc=74.5]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:57<00:00,  1.37it/s]


Learning Rate (backbone): 0.000100 | (head): 0.000500
Train Loss: 0.9070, Train Acc: 74.48%
Val Loss: 0.9895, Val Acc: 70.68%
✓ New best model saved! (Val Acc: 70.68%)

Epoch 12/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [25:06<00:00,  1.33s/it, loss=0.884, acc=75.6]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [03:49<00:00,  1.06it/s]


Learning Rate (backbone): 0.000100 | (head): 0.000500
Train Loss: 0.8841, Train Acc: 75.60%
Val Loss: 0.9802, Val Acc: 71.48%
✓ New best model saved! (Val Acc: 71.48%)

Epoch 13/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [23:09<00:00,  1.23s/it, loss=0.859, acc=77.1]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [03:07<00:00,  1.30it/s]


Learning Rate (backbone): 0.000050 | (head): 0.000250
Train Loss: 0.8587, Train Acc: 77.07%
Val Loss: 0.9810, Val Acc: 72.46%
✓ New best model saved! (Val Acc: 72.46%)

Epoch 14/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [21:13<00:00,  1.13s/it, loss=0.778, acc=80.9]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [03:01<00:00,  1.34it/s]


Learning Rate (backbone): 0.000050 | (head): 0.000250
Train Loss: 0.7777, Train Acc: 80.93%
Val Loss: 0.9760, Val Acc: 71.43%
No improvement. Patience: 1/7

Epoch 15/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [21:22<00:00,  1.14s/it, loss=0.746, acc=82.4]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [03:05<00:00,  1.31it/s]


Learning Rate (backbone): 0.000050 | (head): 0.000250
Train Loss: 0.7462, Train Acc: 82.39%
Val Loss: 1.0125, Val Acc: 71.44%
No improvement. Patience: 2/7

Epoch 16/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [19:13<00:00,  1.02s/it, loss=0.72, acc=83.7]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:18<00:00,  1.75it/s]


Learning Rate (backbone): 0.000050 | (head): 0.000250
Train Loss: 0.7203, Train Acc: 83.71%
Val Loss: 1.0156, Val Acc: 70.90%
No improvement. Patience: 3/7

Epoch 17/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [20:26<00:00,  1.09s/it, loss=0.696, acc=84.9]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:16<00:00,  1.78it/s]


Learning Rate (backbone): 0.000025 | (head): 0.000125
Train Loss: 0.6963, Train Acc: 84.90%
Val Loss: 1.0450, Val Acc: 72.41%
No improvement. Patience: 4/7

Epoch 18/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [20:20<00:00,  1.08s/it, loss=0.645, acc=87.4]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:25<00:00,  1.67it/s]


Learning Rate (backbone): 0.000025 | (head): 0.000125
Train Loss: 0.6445, Train Acc: 87.40%
Val Loss: 1.0620, Val Acc: 72.67%
✓ New best model saved! (Val Acc: 72.67%)

Epoch 19/50
------------------------------------------------------------


Training: 100%|███████████████████████████████████████████████| 1130/1130 [20:37<00:00,  1.10s/it, loss=0.62, acc=88.7]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:25<00:00,  1.67it/s]


Learning Rate (backbone): 0.000025 | (head): 0.000125
Train Loss: 0.6202, Train Acc: 88.73%
Val Loss: 1.0707, Val Acc: 72.75%
✓ New best model saved! (Val Acc: 72.75%)

Epoch 20/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [20:38<00:00,  1.10s/it, loss=0.607, acc=89.4]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:20<00:00,  1.72it/s]


Learning Rate (backbone): 0.000013 | (head): 0.000063
Train Loss: 0.6072, Train Acc: 89.38%
Val Loss: 1.0780, Val Acc: 72.95%
✓ New best model saved! (Val Acc: 72.95%)

Epoch 21/50
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████| 1130/1130 [21:18<00:00,  1.13s/it, loss=0.576, acc=91]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:36<00:00,  1.56it/s]


Learning Rate (backbone): 0.000013 | (head): 0.000063
Train Loss: 0.5763, Train Acc: 91.01%
Val Loss: 1.1008, Val Acc: 73.04%
✓ New best model saved! (Val Acc: 73.04%)

Epoch 22/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [22:10<00:00,  1.18s/it, loss=0.565, acc=91.7]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [03:13<00:00,  1.26it/s]


Learning Rate (backbone): 0.000013 | (head): 0.000063
Train Loss: 0.5648, Train Acc: 91.65%
Val Loss: 1.1271, Val Acc: 71.88%
No improvement. Patience: 1/7

Epoch 23/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [24:18<00:00,  1.29s/it, loss=0.553, acc=92.1]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [03:11<00:00,  1.27it/s]


Learning Rate (backbone): 0.000006 | (head): 0.000031
Train Loss: 0.5527, Train Acc: 92.11%
Val Loss: 1.1378, Val Acc: 72.27%
No improvement. Patience: 2/7

Epoch 24/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [28:55<00:00,  1.54s/it, loss=0.541, acc=92.8]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [03:34<00:00,  1.13it/s]


Learning Rate (backbone): 0.000006 | (head): 0.000031
Train Loss: 0.5412, Train Acc: 92.77%
Val Loss: 1.1364, Val Acc: 72.82%
No improvement. Patience: 3/7

Epoch 25/50
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████| 1130/1130 [22:22<00:00,  1.19s/it, loss=0.537, acc=93]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:24<00:00,  1.68it/s]


Learning Rate (backbone): 0.000006 | (head): 0.000031
Train Loss: 0.5372, Train Acc: 93.05%
Val Loss: 1.1399, Val Acc: 72.57%
No improvement. Patience: 4/7

Epoch 26/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [20:45<00:00,  1.10s/it, loss=0.531, acc=93.5]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:51<00:00,  1.41it/s]


Learning Rate (backbone): 0.000003 | (head): 0.000016
Train Loss: 0.5315, Train Acc: 93.49%
Val Loss: 1.1472, Val Acc: 72.17%
No improvement. Patience: 5/7

Epoch 27/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [22:13<00:00,  1.18s/it, loss=0.525, acc=93.6]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [02:29<00:00,  1.63it/s]


Learning Rate (backbone): 0.000003 | (head): 0.000016
Train Loss: 0.5246, Train Acc: 93.57%
Val Loss: 1.1469, Val Acc: 72.48%
No improvement. Patience: 6/7

Epoch 28/50
------------------------------------------------------------


Training: 100%|██████████████████████████████████████████████| 1130/1130 [23:34<00:00,  1.25s/it, loss=0.523, acc=93.7]
Validation: 100%|████████████████████████████████████████████████████████████████████| 243/243 [04:51<00:00,  1.20s/it]


Learning Rate (backbone): 0.000003 | (head): 0.000016
Train Loss: 0.5229, Train Acc: 93.71%
Val Loss: 1.1561, Val Acc: 72.82%
No improvement. Patience: 7/7

⚠ Early stopping triggered at epoch 28

✓ Training history plot saved as 'training_history.png'

LOADING BEST MODEL FOR TESTING
✓ Loaded best model (Val Acc: 73.04%)

STANDARD EVALUATION (No TTA)


Testing: 100%|███████████████████████████████████████████████████████████████████████| 243/243 [05:07<00:00,  1.26s/it]



✓ Test Accuracy: 74.65%
              precision    recall  f1-score   support

       Angry     0.5941    0.6337    0.6133       961
     Disgust     0.5260    0.6012    0.5611       657
       Happy     0.8677    0.8751    0.8714      1417
   LowAffect     0.7721    0.7650    0.7685      2808
     Arousal     0.7903    0.7306    0.7593      1904

    accuracy                         0.7465      7747
   macro avg     0.7101    0.7211    0.7147      7747
weighted avg     0.7511    0.7465    0.7482      7747

✓ Confusion matrix saved as 'confusion_matrix.png'

ENHANCED TTA EVALUATION (7 augmentations)


Testing: 100%|███████████████████████████████████████████████████████████████████████| 243/243 [06:03<00:00,  1.50s/it]



✓ Test Accuracy: 75.42%
              precision    recall  f1-score   support

       Angry     0.5967    0.6389    0.6171       961
     Disgust     0.5630    0.5921    0.5772       657
       Happy     0.8696    0.8751    0.8723      1417
   LowAffect     0.7735    0.7796    0.7765      2808
     Arousal     0.7967    0.7411    0.7679      1904

    accuracy                         0.7542      7747
   macro avg     0.7199    0.7253    0.7222      7747
weighted avg     0.7570    0.7542    0.7552      7747

✓ Confusion matrix saved as 'confusion_matrix.png'

FINAL COMPARISON
Standard Accuracy: 74.65%
TTA Accuracy:      75.42%
Improvement:       +0.77%

TRAINING COMPLETE - SUMMARY
Best Validation Accuracy: 73.04%
Test Accuracy: 75.42%
Model saved as: best_model.pth
returned from main
