In [None]:
import pprint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Create dataset


In [None]:
# Dataset roots, read with torchvision ImageFolder (one sub-folder per class).
TRAIN_PATH = "data/ferplus/Training"
VALIDATION_PATH = "data/ferplus/PublicTest"
batch_size = 64
img_size = 48  # images are resized to img_size x img_size before the CNN
num_workers = 4  # DataLoader worker processes
learning_rate = 1e-3  # Adam learning rate
val_ratio = 0.2  # NOTE(review): not referenced in the code shown; validation uses VALIDATION_PATH
seed = 42  # RNG seed — NOTE(review): confirm it is passed to torch.manual_seed before model init
epochs = 10

In [None]:
# Deterministic preprocessing (no augmentation). It is reused for the
# validation and test splits, and PreprocessingWrapper re-implements the same
# steps with tensor ops for ONNX export.
preprocessing_steps = [
    transforms.Grayscale(num_output_channels=1),  # 1-channel input for conv1
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # [0, 1] -> [-1, 1]
]
train_transforms = transforms.Compose(preprocessing_steps)

In [None]:
train_ds = datasets.ImageFolder(root=TRAIN_PATH, transform=train_transforms)
# Validation reuses the same deterministic preprocessing (there is no augmentation).
val_ds = datasets.ImageFolder(root=VALIDATION_PATH, transform=train_transforms)

# ImageFolder assigns label indices from the sorted sub-folder names of each
# root. Both splits must expose exactly the same class folders, otherwise the
# validation labels would silently point at different classes than the model's
# output indices.
assert val_ds.classes == train_ds.classes, (
    f"class mismatch: train={train_ds.classes} val={val_ds.classes}"
)

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

# The validation loader must iterate the held-out split, not train_ds;
# otherwise "validation" loss/metrics just measure training-set fit.
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Number of model outputs = number of class sub-folders found in TRAIN_PATH.
num_classes = len(train_ds.classes)

sample_summary = f"{len(train_ds)} samples"
class_summary = f"{num_classes} classes: {train_ds.classes}"
sample_summary, class_summary

# Train model


In [None]:
class TinyCNN(nn.Module):
    """Small three-block CNN classifier for single-channel square images.

    Each block is conv3x3 (padding=1, so the spatial size is kept) -> ReLU ->
    2x2 max-pool. The three pools shrink the side length to ``side // 8``
    before the fully connected head (256 hidden units, dropout 0.3).

    Args:
        num_classes: Number of output logits. Defaults to the notebook-level
            ``num_classes`` captured when this cell is executed.
        input_size: Side length of the square input images. ``None`` (the
            default) falls back to the notebook-level ``img_size`` at
            construction time, which matches the previous behaviour. Pass it
            explicitly to build the model without relying on that global.
    """

    def __init__(self, num_classes=num_classes, input_size=None):
        super().__init__()
        side = img_size if input_size is None else input_size
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Three floor-halvings of the side length == side // 8.
        self.fc1 = nn.Sequential(
            nn.Linear((side // 8) * (side // 8) * 128, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )
        self.fc2 = nn.Linear(256, num_classes)
        # Force float32 parameters regardless of torch's global default dtype.
        self.float()

    def forward(self, x):
        """Map a (batch, 1, side, side) tensor to (batch, num_classes) raw logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        x = x.flatten(1)
        return self.fc2(self.fc1(x))

In [None]:
# Pick the fastest available backend: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Seed right before building the model so weight init and the shuffle order
# (drawn from torch's default generator) are reproducible on every re-run.
torch.manual_seed(seed)

model = TinyCNN(num_classes=num_classes).to(device).to(torch.float32)
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

for ep in range(epochs):
    # --- training pass ---
    model.train()
    tr_loss, tr_count = 0.0, 0
    for images, labels in train_loader:
        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device)
        opt.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        opt.step()
        # The loss is a batch mean. Re-weight it by batch size so the epoch
        # figure is a per-sample average even when the last batch is smaller.
        tr_loss += loss.item() * images.size(0)
        tr_count += images.size(0)

    # --- validation pass (no gradients; eval mode disables dropout) ---
    model.eval()
    val_loss, val_count = 0.0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            val_loss += loss.item() * images.size(0)
            val_count += images.size(0)

    # Average over the samples actually seen (not len(some_dataset)), so the
    # numbers stay correct whichever dataset each loader wraps. Epochs are
    # reported 1-based.
    print(
        f"Epoch {ep + 1}/{epochs}"
        f" | Train Loss: {tr_loss / max(tr_count, 1):.4f}"
        f" | Val Loss: {val_loss / max(val_count, 1):.4f}"
    )

## Validation metrics


In [None]:
# Collect ground-truth labels and arg-max predictions over the whole
# validation loader as CPU NumPy arrays for scikit-learn.
model.eval()
target_batches, prediction_batches = [], []

with torch.no_grad():
    for images, labels in val_loader:
        predicted = model(images.to(device)).argmax(1)
        target_batches.append(labels)
        prediction_batches.append(predicted.cpu())

y_true, y_pred = (
    torch.cat(batches).numpy() for batches in (target_batches, prediction_batches)
)

In [None]:
# Overall accuracy plus macro- and micro-averaged precision/recall/F1.
# zero_division=0 reports 0 instead of warning when a denominator is empty.
accuracy = accuracy_score(y_true, y_pred)
precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=0
)
precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
    y_true, y_pred, average="micro", zero_division=0
)

# Pin the label set to every model output index so the per-class arrays have
# exactly one entry per class in train_ds.classes, in index order. Without
# `labels`, sklearn only reports labels observed in y_true/y_pred, which
# misaligns the per-class report (or raises IndexError) when a class is absent.
precision_class, recall_class, f1_class, _ = precision_recall_fscore_support(
    y_true,
    y_pred,
    labels=list(range(len(train_ds.classes))),
    average=None,
    zero_division=0,
)

In [None]:
# Assemble a nested report: global scores first, then one entry per class
# (class i in train_ds.classes pairs with position i of the per-class arrays).
global_metrics = {
    "accuracy": accuracy,
    "micro": {
        "precision": precision_micro,
        "recall": recall_micro,
        "f1-score": f1_micro,
    },
    "macro": {
        "precision": precision_macro,
        "recall": recall_macro,
        "f1-score": f1_macro,
    },
}
per_class_metrics = {
    class_name: {
        "precision": precision_class[idx],
        "recall": recall_class[idx],
        "f1-score": f1_class[idx],
    }
    for idx, class_name in enumerate(train_ds.classes)
}
metrics = {"global": global_metrics, "per_class": per_class_metrics}

pprint.pprint(metrics)

## Save model checkpoint


In [None]:
# Store the class-name list next to the weights; the evaluation section
# rebuilds the model with len(checkpoint["classes"]) outputs.
checkpoint_contents = {
    "model_state": model.state_dict(),
    "classes": train_ds.classes,
}
torch.save(checkpoint_contents, "data/model.pt")

# Evaluate model on the held-out test split


In [None]:
# Held-out split for the final evaluation; it is not used for training or validation above.
TEST_PATH = "data/ferplus/PrivateTest"

In [None]:
test_ds = datasets.ImageFolder(root=TEST_PATH, transform=train_transforms)

# ImageFolder derives label indices from the sorted sub-folder names of each
# root. A missing or extra class folder would silently shift every test label
# relative to the model's output indices, so fail loudly instead.
assert test_ds.classes == train_ds.classes, (
    f"class mismatch: train={train_ds.classes} test={test_ds.classes}"
)

test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# The saved state dict may hold tensors from the training device (MPS/CUDA).
# Map them to CPU so loading works on any machine; the next cell moves the
# model to `device`. The checkpoint only contains a state dict and a list of
# class names, so weights_only=True (restricted unpickling) is sufficient.
checkpoint = torch.load("data/model.pt", map_location="cpu", weights_only=True)
model = TinyCNN(num_classes=len(checkpoint["classes"]))
model.load_state_dict(checkpoint["model_state"])

In [None]:
# Run the reloaded model over the test split and gather labels/predictions
# as CPU NumPy arrays for scikit-learn.
model.to(device)
model.eval()

true_chunks = []
pred_chunks = []

with torch.no_grad():
    for batch_x, batch_y in test_loader:
        scores = model(batch_x.to(device))
        true_chunks.append(batch_y)
        pred_chunks.append(scores.argmax(1).cpu())

y_true = torch.cat(true_chunks).numpy()
y_pred = torch.cat(pred_chunks).numpy()

In [None]:
# Test-split accuracy plus macro- and micro-averaged precision/recall/F1.
# zero_division=0 reports 0 instead of warning when a denominator is empty.
accuracy = accuracy_score(y_true, y_pred)
precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=0
)
precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
    y_true, y_pred, average="micro", zero_division=0
)

# Pin the label set to every output index of the reloaded model so the
# per-class arrays have one entry per checkpoint class, in index order.
# Without `labels`, sklearn only reports labels observed in y_true/y_pred,
# which misaligns the per-class report when a class is absent.
precision_class, recall_class, f1_class, _ = precision_recall_fscore_support(
    y_true,
    y_pred,
    labels=list(range(len(checkpoint["classes"]))),
    average=None,
    zero_division=0,
)

In [None]:
# Per-class names come from the checkpoint, because that list defines the
# output order of the reloaded model. This keeps the evaluation section
# consistent with the weights it actually loaded rather than with train_ds.
metrics = {
    "global": {
        "accuracy": accuracy,
        "micro": {
            "precision": precision_micro,
            "recall": recall_micro,
            "f1-score": f1_micro,
        },
        "macro": {
            "precision": precision_macro,
            "recall": recall_macro,
            "f1-score": f1_macro,
        },
    },
    "per_class": {
        class_name: {
            "precision": precision_class[i],
            "recall": recall_class[i],
            "f1-score": f1_class[i],
        }
        for i, class_name in enumerate(checkpoint["classes"])
    },
}

pprint.pprint(metrics)

# Export model to ONNX


In [None]:
# Ensure dropout is disabled before the model is wrapped and traced for
# export (`train(False)` is exactly what `eval()` does; it returns the model).
model.train(False)

In [None]:
class PreprocessingWrapper(nn.Module):
    """Fuse the training preprocessing and a classifier into one module for export.

    Mirrors ``train_transforms`` with tensor ops: scale to [0, 1], convert to
    grayscale, resize to ``img_size`` x ``img_size``, normalize with
    mean=0.5/std=0.5. Small numeric differences from the PIL-based training
    pipeline (rounding, interpolation) are possible.

    Args:
        base_model: Classifier applied to the preprocessed tensor.
        img_size: Side length the input is resized to (square).

    NOTE(review): ``forward`` presumably expects a float RGB batch shaped
    (N, 3, H, W) with pixel values in [0, 255] (cf. the ONNX dummy input).
    Confirm before serving.
    """

    def __init__(self, base_model, img_size):
        super().__init__()
        self.base_model = base_model
        self.img_size = img_size

    def forward(self, x):
        unit_range = x / 255.0
        gray = TF.rgb_to_grayscale(unit_range, num_output_channels=1)
        resized = TF.resize(gray, [self.img_size, self.img_size])
        normalized = (resized - 0.5) / 0.5  # same as Normalize(mean=[0.5], std=[0.5])
        return self.base_model(normalized)


wrapped_model = PreprocessingWrapper(model, img_size=img_size)
wrapped_model.eval()

In [None]:
class PreprocessingWrapper___(PreprocessingWrapper):
    """``PreprocessingWrapper`` plus convenience inference helpers.

    Construction and ``forward`` are inherited unchanged from
    ``PreprocessingWrapper``, so an ONNX graph traced from this wrapper is the
    same as one traced from the base wrapper. The extra methods are for eager
    (PyTorch) use only and are not part of the traced graph.
    """

    def logits(self, x):
        """Return raw class logits of shape (batch, num_classes); alias of ``forward``."""
        return self.forward(x)

    def predict_proba(self, x):
        """Return per-class probabilities (softmax of the logits over the class axis)."""
        # torch.softmax has no default `dim`; dim=1 is the class axis of the
        # (batch, num_classes) logits produced by the base model.
        return torch.softmax(self.logits(x), dim=1)

    def predict_label(self, x):
        """Return the index of the most likely class for each input in the batch."""
        # argmax of the logits equals argmax of the softmax, so skip the softmax.
        return self.logits(x).argmax(dim=1)


wrapped_model = PreprocessingWrapper___(model, img_size=img_size)
wrapped_model.eval()

In [None]:
# Trace on whichever device the wrapped model's weights live on. A separate
# name avoids silently overwriting the notebook-level `device`.
export_device = next(wrapped_model.parameters()).device

# Example input for tracing: one RGB image with pixel values in [0, 255).
# The forward pass has no data-dependent branches, so only shape and dtype
# affect the graph, but a realistic pixel range keeps the example honest.
dummy = (torch.rand(1, 3, 64, 64) * 255).to(export_device)

torch.onnx.export(
    wrapped_model,
    dummy,
    "data/model.onnx",
    input_names=["input"],
    output_names=["logits"],
    # The output's batch axis must be marked dynamic too; otherwise the graph
    # records the traced batch size (1) as a fixed output dimension.
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "logits": {0: "batch"},
    },
    external_data=False,
)