In [None]:
# Experiment parameters; the next cell aliases them as upper-case constants.
epochs = 5  # number of passes over the training sampler
oversampling_strength = 0.5  # exponent on inverse class frequency for sampling weights (0 = no rebalancing)
loss_weight_strength = 0.3  # exponent on inverse class frequency for loss weights (0 = uniform)
experiment_family = "default"  # TensorBoard sub-directory that groups related runs

In [None]:
# Data loading
batch_size = 128
img_size = 48  # side length (pixels) of the square model input
num_workers = 0

# Optimisation
learning_rate = 1e-3
seed = 42  # random seed
val_ratio = 0.2  # NOTE(review): not referenced elsewhere in this notebook; validation uses the PublicTest split

# Upper-case aliases of the experiment parameters from the previous cell
EPOCHS, OVERSAMPLING_STRENGTH, LOSS_WEIGHT_STRENGTH = (
    epochs,
    oversampling_strength,
    loss_weight_strength,
)
TENSORBOARD_LOG_DIR = "runs"  # root directory for TensorBoard event files

In [None]:
import matplotlib.pyplot as plt
# Render figures inline in the notebook, using the high-DPI ("retina") PNG format.
%matplotlib inline
%config InlineBackend.figure_format='retina'


In [None]:
import json
import os
import pprint
from datetime import datetime
from pathlib import Path
from typing import Literal

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.ops import sigmoid_focal_loss

# Create dataset


In [None]:
class FERPlusDataset(Dataset):
    """FER+ images paired with soft emotion targets derived from annotator votes.

    All images of the requested split are decoded eagerly in ``__init__`` and kept
    in memory as RGB PIL images; ``__getitem__`` only applies ``transform``.
    """

    # CSV "Usage" value -> image sub-directory under img_root.
    _SPLIT_DIRS = {
        "Training": "FER2013Train",
        "PublicTest": "FER2013Valid",
        "PrivateTest": "FER2013Test",
    }

    def __init__(
        self,
        csv_path,
        img_root: Path,
        split: Literal["Training", "PublicTest", "PrivateTest"],
        transform=None,
        drop_unlabeled: bool = False,
    ):
        """
        csv_path: path to fer2013new.csv
        img_root: directory containing the FER2013Train / FER2013Valid / FER2013Test folders
        split: 'Training', 'PublicTest' or 'PrivateTest' (value of the CSV "Usage" column)
        transform: torchvision transforms applied to the PIL image in __getitem__
        drop_unlabeled: if True, skip rows in which none of the eight emotion columns
            received a vote. By default such rows are kept with an all-zero target,
            which a later ``argmax`` reads as class 0 ("neutral").
        """
        self.transform = transform

        # Follows order in FER+ csv; this is also the class-index order of the targets.
        self.classes = [
            "neutral",
            "happiness",
            "surprise",
            "sadness",
            "anger",
            "disgust",
            "fear",
            "contempt",
        ]

        df = pd.read_csv(csv_path)
        # Drop rows without filenames, then keep only the requested split.
        df = df[df["Image name"].notna()]
        df = df[df["Usage"] == split].reset_index(drop=True)

        votes = df[self.classes].values.astype("float32")  # shape [N, C]
        sums = votes.sum(axis=1, keepdims=True)  # shape [N, 1]

        if drop_unlabeled:
            keep = sums[:, 0] > 0
            df = df[keep].reset_index(drop=True)
            votes, sums = votes[keep], sums[keep]

        split_dir = img_root / self._SPLIT_DIRS[split]
        self.images = [
            self._load_rgb(split_dir / filename)
            for filename in df["Image name"].tolist()
        ]

        # Convert votes -> probability distributions (soft targets).
        sums[sums == 0.0] = 1.0  # all-zero rows stay all-zero instead of dividing by zero
        self.targets = torch.from_numpy(votes / sums)  # [N, C]; non-empty rows sum to 1

    @staticmethod
    def _load_rgb(path):
        """Fully decode one image as RGB and close the underlying file handle."""
        with Image.open(path) as im:
            return im.convert("RGB")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)

        target = self.targets[idx]

        return img, target

In [None]:
# Preprocessing shared by every split: one channel, img_size x img_size, and
# ToTensor's [0, 1] range remapped to [-1, 1] by Normalize(0.5, 0.5).
base_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Training only: light geometric and photometric augmentation, then the shared steps.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=(img_size, img_size), scale=(0.9, 1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(
        degrees=5,
        translate=(0.05, 0.05),
        scale=(0.9, 1.1),
        shear=(-3, 3),
    ),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    base_transforms,
])

In [None]:
# Both splits read the same label CSV and raw image root; only the split name and
# the transform pipeline (augmented for training) differ.
train_ds, val_ds = (
    FERPlusDataset(
        "data/fer2013new.csv",
        img_root=Path("data/ferplus_raw"),
        split=split_name,
        transform=pipeline,
    )
    for split_name, pipeline in (
        ("Training", train_transforms),
        ("PublicTest", base_transforms),
    )
)

In [None]:
# Hard label per training sample: the emotion with the largest vote share
# (rows without any votes fall to class 0).
targets = train_ds.targets.argmax(dim=1).cpu().numpy()
# minlength guarantees one count per class (shape [num_classes]) even if a class
# never wins a vote, so the weight vectors below stay aligned with train_ds.classes.
class_sample_count = np.bincount(targets, minlength=len(train_ds.classes))
# Inverse class frequency. A class without samples gets 0 instead of inf; inf would
# turn any weights normalised from this array into NaN.
class_weights = np.divide(
    1.0,
    class_sample_count,
    out=np.zeros(len(class_sample_count), dtype=np.float64),
    where=class_sample_count > 0,
)

# Per-sample weight = class weight tempered by OVERSAMPLING_STRENGTH
# (0 -> uniform sampling, 1 -> equal expected draws per class).
sample_weights = class_weights[targets] ** OVERSAMPLING_STRENGTH  # shape: [num_samples]
sample_weights = sample_weights / sample_weights.sum()
sample_weights = torch.as_tensor(sample_weights, dtype=torch.double)

# One epoch draws len(train_ds) samples with replacement according to the weights.
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

In [None]:
# Relative (un-normalised) sampling weight of each class.
(
    pd.Series(class_weights**OVERSAMPLING_STRENGTH, index=train_ds.classes)
    .plot.bar(figsize=(6, 3), title="Oversampling weight")
)

In [None]:
# Pinned host memory only helps CUDA transfers; on MPS PyTorch warns that it is
# unsupported, so enable it only when CUDA is present.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=sampler,  # class-rebalancing sampler (mutually exclusive with shuffle)
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    # persistent_workers / prefetch_factor require num_workers > 0.
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
num_classes = len(train_ds.classes)  # one output logit per FER+ emotion

# Quick summary of the training split.
(f"{len(train_ds)} samples", f"{num_classes} classes: {train_ds.classes}")

# Train model


In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with an identity skip connection.

    padding=1 keeps the spatial size and the channel count is unchanged, so the
    input can be added directly to the convolution output.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Creation order fixes how the global RNG initialises the weights.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        residual = self.conv2(F.relu(self.conv1(x)))
        return F.relu(residual + x)


class TinyCNN(nn.Module):
    """Small CNN for 1-channel face crops.

    Three conv stages plus a residual block, then global average pooling and a
    two-layer classifier head. The attribute names (conv1..conv3, res3, fc1, fc2)
    are the state_dict keys written to the checkpoint, so they must not change.
    """

    def __init__(self, num_classes: int = num_classes):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 -> 128 channels, "same" padding.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.res3 = ResidualBlock(128)

        # Head: its 128 inputs are the 128 channels left after global average pooling.
        self.fc1 = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Dropout(0.3))
        self.fc2 = nn.Linear(256, num_classes)
        self.float()

    def forward(self, x):
        # Two conv + 2x2 max-pool stages, each halving H and W (48 -> 24 -> 12).
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)

        # Third stage keeps the spatial size; the residual block refines its features.
        x = self.res3(F.relu(self.conv3(x)))

        # Global average pool: (B, 128, H, W) -> (B, 128).
        x = torch.flatten(F.adaptive_avg_pool2d(x, 1), start_dim=1)
        return self.fc2(self.fc1(x))

In [None]:
# Prefer Apple-silicon GPU, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def log_to_tensorboard(
    y_true, y_pred, train_loss, val_loss, writer: SummaryWriter, epoch: int
):
    """Write one epoch of classification metrics and losses to TensorBoard.

    Args:
        y_true, y_pred: hard class indices for the evaluated samples.
        train_loss, val_loss: epoch losses to log under ``Loss/*``.
        writer: destination SummaryWriter.
        epoch: global step used for every scalar.
    """
    scalars = {"Metrics/accuracy": accuracy_score(y_true, y_pred)}

    for averaging in ("micro", "macro"):
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average=averaging, zero_division=0
        )
        scalars[f"Metrics/{averaging}_precision"] = precision
        scalars[f"Metrics/{averaging}_recall"] = recall
        scalars[f"Metrics/{averaging}_f1"] = f1

    scalars["Loss/train"] = train_loss
    scalars["Loss/val"] = val_loss

    # Insertion order reproduces the original tag order.
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, epoch)

In [None]:
def focal_with_class_weights_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    class_weights: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2,
):
    """Sigmoid focal loss with a per-class weight, summed over classes and averaged over the batch.

    Args:
        logits: raw model outputs; assumed shape (B, C) — the sum runs over dim 1.
        targets: same shape as ``logits`` with values in [0, 1] (soft labels allowed).
        class_weights: per-class multipliers, broadcast against the last dim (C,).
        alpha: focal-loss balancing factor passed to torchvision; a negative value
            disables it. The default matches torchvision's own default.
        gamma: focusing exponent; 0 removes the down-weighting of easy examples.
            The default matches torchvision's own default.

    Returns:
        Scalar loss tensor.
    """
    focal = sigmoid_focal_loss(
        logits, targets, alpha=alpha, gamma=gamma, reduction="none"
    )
    return (focal * class_weights).sum(dim=1).mean()

In [None]:
# Seed the torch RNG used below (weight init, WeightedRandomSampler draws, random
# augmentations) so re-runs are reproducible as far as the backend kernels allow.
torch.manual_seed(seed)
np.random.seed(seed)

# TinyCNN casts its own parameters to float32 in __init__.
model = TinyCNN(num_classes=num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Per-class loss weights: inverse class frequency raised to LOSS_WEIGHT_STRENGTH
# (0 -> uniform, 1 -> plain inverse frequency), normalised to sum to 1.
loss_class_weights = class_weights**LOSS_WEIGHT_STRENGTH
loss_class_weights = loss_class_weights / loss_class_weights.sum()
# Build the device tensor once instead of on every criterion call.
loss_class_weights_tensor = torch.tensor(
    loss_class_weights, dtype=torch.float32, device=device
)


def criterion(logits, targets):
    """Class-weighted sigmoid focal loss; the focal term puts more weight on "harder" samples."""
    return focal_with_class_weights_loss(
        logits, targets, class_weights=loss_class_weights_tensor
    )


# No colons in the run directory name: they are invalid in Windows paths.
date_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
writer = SummaryWriter(log_dir=f"{TENSORBOARD_LOG_DIR}/{experiment_family}/{date_str}")
train_losses = []  # mean per-sample training loss, one entry per epoch
val_losses = []  # mean per-sample validation loss, one entry per epoch

try:
    for epoch in range(EPOCHS):
        # ---- train ----
        model.train()
        tr_loss, tr_seen = 0.0, 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device)

            opt.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * images.size(0)
            tr_seen += images.size(0)
        # Average over the samples actually drawn (the sampler decides how many).
        tr_loss = tr_loss / max(tr_seen, 1)

        # ---- validate ----
        model.eval()
        val_loss, val_seen = 0.0, 0
        y_true = []
        y_pred = []
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device)

                logits = model(images)
                loss = criterion(logits, labels)
                val_loss += loss.item() * images.size(0)
                val_seen += images.size(0)

                # Soft vote targets -> majority class for the classification metrics.
                y_true.append(labels.argmax(1).cpu())
                y_pred.append(logits.argmax(1).cpu())

        val_loss = val_loss / max(val_seen, 1)
        y_true = torch.cat(y_true).numpy()
        y_pred = torch.cat(y_pred).numpy()

        train_losses.append(tr_loss)
        val_losses.append(val_loss)
        log_to_tensorboard(y_true, y_pred, tr_loss, val_loss, writer, epoch)
finally:
    # Flush buffered events to disk and release the file, even if training is interrupted.
    writer.close()

In [None]:
# Normalised per-class weight applied inside the loss.
(
    pd.Series(loss_class_weights, index=train_ds.classes)
    .plot.bar(figsize=(6, 3), title="Loss weight")
)

## Training metrics (measured on the PublicTest validation split)


In [None]:
# Hard labels and predictions on the validation split with the final weights.
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        logits = model(images)
        y_true.append(labels.argmax(1).cpu())  # majority-vote label
        y_pred.append(logits.argmax(1).cpu())  # predicted label

y_true, y_pred = (torch.cat(batches).numpy() for batches in (y_true, y_pred))

In [None]:
# Overall scores. The macro average keeps sklearn's default label set
# (labels that occur in y_true or y_pred).
accuracy = accuracy_score(y_true, y_pred)
precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=0
)

# Per-class scores. `labels` pins exactly one entry per class, in train_ds.classes
# order. Without it sklearn returns entries only for labels that occur, so a class
# absent from both y_true and y_pred would shift every later entry. The metrics dict
# indexes these arrays by class position.
precision_class, recall_class, f1_class, _ = precision_recall_fscore_support(
    y_true,
    y_pred,
    labels=list(range(len(train_ds.classes))),
    average=None,
    zero_division=0,
)

In [None]:
# Overall scores under "macro", then one entry per emotion (arrays are index-aligned
# with train_ds.classes).
metrics = {
    "macro": dict(
        accuracy=accuracy,
        precision=precision_macro,
        recall=recall_macro,
        f1=f1_macro,
    )
}
metrics.update(
    {
        name: dict(precision=precision_class[i], recall=recall_class[i], f1=f1_class[i])
        for i, name in enumerate(train_ds.classes)
    }
)

pprint.pprint(metrics)

### Export data for DVC

In [None]:
# open() does not create parent directories; create them so a fresh checkout works.
for out_dir in ("metrics", "plots"):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

# Export metrics for DVC metrics
with open("metrics/train.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

# Export metrics for DVC plots (yes, different format 🤦): one record per key.
with open("plots/train_metrics.json", "w", encoding="utf-8") as f:
    json.dump(
        [{"emotion": emotion} | values for emotion, values in metrics.items()],
        f,
        indent=2,
    )

# Export per-sample actual vs. predicted class names
pd.DataFrame(
    {
        "actual": [train_ds.classes[idx] for idx in y_true],
        "predicted": [train_ds.classes[idx] for idx in y_pred],
    }
).to_csv("plots/train_classes.csv", index=False)

# Export per-epoch losses
pd.DataFrame({"train": train_losses, "validation": val_losses}).to_csv(
    "plots/losses.csv", index=False
)

## Export model


In [None]:
# Store the weights together with the class order, so the model can be rebuilt and
# its outputs decoded without the training dataset.
torch.save(
    dict(model_state=model.state_dict(), classes=train_ds.classes),
    "data/model.pt",
)

# Evaluate model


In [None]:
# NOTE(review): TEST_PATH is not referenced elsewhere in the visible notebook; the
# test split below is loaded through FERPlusDataset from data/ferplus_raw instead.
TEST_PATH = "data/ferplus/PrivateTest"

In [None]:
test_ds = FERPlusDataset(
    "data/fer2013new.csv",
    img_root=Path("data/ferplus_raw"),
    split="PrivateTest",
    transform=base_transforms,
)
# Same worker setting as the train/val loaders. Pinned memory only helps CUDA
# transfers (PyTorch warns that it is unsupported on MPS).
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# map_location="cpu": the checkpoint holds tensors from whichever device training ran
# on (mps/cuda/cpu). Loading to CPU works on any machine; the model is moved to
# `device` afterwards. weights_only=True refuses arbitrary pickled objects; this
# checkpoint only contains a state_dict and a list of class-name strings.
checkpoint = torch.load("data/model.pt", map_location="cpu", weights_only=True)
model = TinyCNN(num_classes=len(checkpoint["classes"]))
model.load_state_dict(checkpoint["model_state"])

In [None]:
# Hard labels and predictions on the held-out PrivateTest split.
model.eval()
model.to(device)
y_true, y_pred = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        logits = model(images)
        y_true.append(labels.argmax(1).cpu())  # majority-vote label
        y_pred.append(logits.argmax(1).cpu())  # predicted label

y_true, y_pred = (torch.cat(batches).numpy() for batches in (y_true, y_pred))

In [None]:
# Overall scores. The macro average keeps sklearn's default label set
# (labels that occur in y_true or y_pred).
accuracy = accuracy_score(y_true, y_pred)
precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=0
)

# Per-class scores. `labels` pins exactly one entry per class, in train_ds.classes
# order. Without it sklearn returns entries only for labels that occur, so a class
# absent from both y_true and y_pred would shift every later entry. The metrics dict
# indexes these arrays by class position.
precision_class, recall_class, f1_class, _ = precision_recall_fscore_support(
    y_true,
    y_pred,
    labels=list(range(len(train_ds.classes))),
    average=None,
    zero_division=0,
)

In [None]:
# Overall scores under "macro", then one entry per emotion (arrays are index-aligned
# with train_ds.classes).
metrics = {
    "macro": dict(
        accuracy=accuracy,
        precision=precision_macro,
        recall=recall_macro,
        f1=f1_macro,
    )
}
metrics.update(
    {
        name: dict(precision=precision_class[i], recall=recall_class[i], f1=f1_class[i])
        for i, name in enumerate(train_ds.classes)
    }
)

pprint.pprint(metrics)

In [None]:
# open() does not create parent directories; create them so a fresh checkout works.
for out_dir in ("metrics", "plots"):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

# Export metrics for DVC metrics
with open("metrics/test.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=2)

# Export metrics for DVC plots (yes, different format 🤦): one record per key.
with open("plots/test_metrics.json", "w", encoding="utf-8") as f:
    json.dump(
        [{"emotion": emotion} | values for emotion, values in metrics.items()],
        f,
        indent=2,
    )

# Export per-sample actual vs. predicted class names
pd.DataFrame(
    {
        "actual": [train_ds.classes[idx] for idx in y_true],
        "predicted": [train_ds.classes[idx] for idx in y_pred],
    }
).to_csv("plots/test_classes.csv", index=False)

# Export model to ONNX


In [None]:
model.train(False)  # inference mode (disables dropout) before export

In [None]:
class PreprocessingWrapper(nn.Module):
    """Bake inference preprocessing into the exported graph.

    Mirrors the steps of ``base_transforms`` on tensors: scale pixel values by 1/255,
    convert to one grayscale channel, resize to ``img_size`` x ``img_size``, and map
    [0, 1] to [-1, 1]. Then calls ``base_model``.
    Presumably expects RGB input of shape (B, 3, H, W) with values in 0-255; TODO
    confirm with the consumer of the ONNX file.
    """

    def __init__(self, base_model, img_size):
        super().__init__()
        self.base_model = base_model
        self.img_size = img_size

    def forward(self, x):
        side = self.img_size
        gray = TF.rgb_to_grayscale(x / 255.0, num_output_channels=1)
        gray = TF.resize(gray, [side, side])
        return self.base_model((gray - 0.5) / 0.5)  # normalize(mean=0.5, std=0.5)


wrapped_model = PreprocessingWrapper(model, img_size=img_size)
wrapped_model.eval()

In [None]:
# Export on whichever device the model currently lives on.
device = next(wrapped_model.parameters()).device
# Tracing only needs a correctly shaped input; forward has no data-dependent branches.
dummy = (torch.randn(1, 3, 64, 64) * 255).to(device)

torch.onnx.export(
    wrapped_model,
    dummy,
    "data/model.onnx",
    input_names=["input"],
    output_names=["logits"],
    # Batch and image size vary at inference time. The output batch axis is declared
    # dynamic too; otherwise the exported graph may advertise a fixed logits batch of 1.
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "logits": {0: "batch"},
    },
    external_data=False,
)

In [None]:
model.eval()
device = next(model.parameters()).device
# The bare model expects already-preprocessed input (1 channel, normalised); the
# dummy's values are irrelevant for tracing, only its shape matters.
dummy = (torch.randn(1, 1, img_size, img_size) * 255).to(device)

torch.onnx.export(
    model,
    dummy,
    "data/basemodel.onnx",
    input_names=["input"],
    output_names=["logits"],
    # Global average pooling makes any H x W valid; the logits batch axis must be
    # dynamic as well, or the graph may advertise a fixed output batch of 1.
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "logits": {0: "batch"},
    },
    external_data=False,
)

# Sanity checks on custom sample images and random noise


In [None]:
# Hand-collected sample images (ImageFolder layout: one sub-folder per label),
# preprocessed exactly like the validation/test splits.
me_ds = datasets.ImageFolder(root="data/samples_raw", transform=base_transforms)
me_loader = DataLoader(
    me_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
)

In [None]:
# Predict on the sample images. eval() is set explicitly so dropout is off no matter
# which cells ran before this one.
model.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for images, labels in me_loader:
        images = images.to(device)
        logits = model(images)
        y_true.append(labels)  # ImageFolder index into me_ds.classes
        y_pred.append(logits.argmax(1).cpu())

In [None]:
# Compare all batches, not just the first. ImageFolder labels index me_ds.classes
# (the sample folder names), not the FER+ class list, so each side is decoded with
# its own class list.
pd.DataFrame(
    {
        "file": [Path(path).name for path, _ in me_ds.samples],
        "actual": [me_ds.classes[i] for i in torch.cat(y_true).tolist()],
        "predicted": [test_ds.classes[i] for i in torch.cat(y_pred).tolist()],
    }
)

In [None]:
class NoiseDataset(Dataset):
    """Uniform-noise single-channel images, for checking what the model predicts on non-faces.

    Args:
        num_samples: dataset length.
        img_size: height and width of each generated image.
        low, high: value range of the noise. Defaults to [-1, 1], the range that
            ToTensor + Normalize(mean=0.5, std=0.5) produces for real inputs.
        seed: base seed; sample ``idx`` is drawn with seed ``seed + idx``, so the same
            index always yields the same image.
    """

    def __init__(self, num_samples, img_size, low=-1.0, high=1.0, seed=0):
        self.num_samples = num_samples
        self.img_size = img_size
        self.low = low
        self.high = high
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        gen = torch.Generator().manual_seed(self.seed + idx)
        noise = torch.rand(1, self.img_size, self.img_size, generator=gen)
        noise = self.low + (self.high - self.low) * noise
        label = 0  # dummy label; noise has no emotion
        return noise, label


# Twenty noise images, batched like the other evaluation loaders.
num_samples = 20
noise_ds = NoiseDataset(num_samples, img_size)
noise_loader = DataLoader(
    noise_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True
)

In [None]:
# Predict on the noise images. eval() is set explicitly so dropout is off no matter
# which cells ran before this one.
model.eval()
y_true = []
y_pred = []

with torch.no_grad():
    for images, labels in noise_loader:
        images = images.to(device)
        logits = model(images)
        y_true.append(labels)
        y_pred.append(logits.argmax(1).cpu())

In [None]:
# The labels are dummies (always 0), so show how the predictions on noise are
# distributed across the FER+ classes instead of raw index tensors.
pd.Series(
    [test_ds.classes[i] for i in torch.cat(y_pred).tolist()], name="predicted"
).value_counts()