In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


cuda


In [2]:
def get_dataloader(batch_size=64):
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return trainloader, testloader

In [3]:
class BCEWithLogitsMultiClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, outputs, targets):
        targets = nn.functional.one_hot(targets, num_classes=10).float()
        return self.loss(outputs, targets)


In [4]:
class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, outputs, targets):
        ce_loss = self.ce(outputs, targets)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()


In [5]:
class ArcFaceHead(nn.Module):
    def __init__(self, in_features, num_classes=10, s=30.0, m=0.5):
        super().__init__()
        self.W = nn.Parameter(torch.randn(in_features, num_classes))
        self.s = s
        self.m = m
        nn.init.xavier_uniform_(self.W)

    def forward(self, features, labels):
        features = nn.functional.normalize(features)
        W = nn.functional.normalize(self.W)

        logits = torch.matmul(features, W)
        if labels is None:
            return logits * self.s 
        
        theta = torch.acos(torch.clamp(logits, -1+1e-7, 1-1e-7))
        target_logits = torch.cos(theta + self.m)

        one_hot = nn.functional.one_hot(labels, num_classes=10)
        output = logits*(1-one_hot) + target_logits*one_hot

        return output * self.s


In [6]:
def train_model(model, trainloader, testloader, optimizer, loss_fn, epochs=4, arcface=False):

    model.to(device)
    best_test_acc = 0

    for epoch in range(epochs):
        model.train()
        correct = 0
        total = 0

        for images, labels in tqdm(trainloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            if arcface:
                features = model(images)
                outputs = loss_fn(features, labels)
                loss = nn.CrossEntropyLoss()(outputs, labels)
            else:
                outputs = model(images)
                loss = loss_fn(outputs, labels)

            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_acc = 100*correct/total

        # test
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)

                if arcface:
                    features = model(images)
                    logits = loss_fn(features, None)
                    outputs = logits

                else:
                    outputs = model(images)

                _, preds = torch.max(outputs,1)
                correct += (preds==labels).sum().item()
                total += labels.size(0)

        test_acc = 100*correct/total
        best_test_acc = max(best_test_acc, test_acc)

        print(f"Epoch {epoch+1} | Train Acc: {train_acc:.2f} | Test Acc: {test_acc:.2f}")

    return best_test_acc


### ON MNIST Dataset

### VGG + Adam + BCE

In [7]:
trainloader, testloader = get_dataloader()

model = models.vgg11(weights="IMAGENET1K_V1")
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = BCEWithLogitsMultiClass()

train_model(model, trainloader, testloader, optimizer, loss_fn, epochs=4)


100%|██████████| 9.91M/9.91M [00:00<00:00, 37.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.08MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.79MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.3MB/s]


Downloading: "https://download.pytorch.org/models/vgg11-8a719046.pth" to /root/.cache/torch/hub/checkpoints/vgg11-8a719046.pth


100%|██████████| 507M/507M [00:07<00:00, 74.5MB/s]
100%|██████████| 938/938 [02:57<00:00,  5.29it/s]


Epoch 1 | Train Acc: 97.73 | Test Acc: 99.13


100%|██████████| 938/938 [03:04<00:00,  5.08it/s]


Epoch 2 | Train Acc: 99.33 | Test Acc: 99.23


100%|██████████| 938/938 [03:04<00:00,  5.08it/s]


Epoch 3 | Train Acc: 99.51 | Test Acc: 99.33


100%|██████████| 938/938 [03:04<00:00,  5.09it/s]


Epoch 4 | Train Acc: 99.55 | Test Acc: 99.37


99.37

### AlexNet + SGD + Focal

In [8]:
model = models.alexnet(weights="IMAGENET1K_V1")
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = FocalLoss()

train_model(model, trainloader, testloader, optimizer, loss_fn, epochs=4)


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


100%|██████████| 233M/233M [00:01<00:00, 207MB/s]
100%|██████████| 938/938 [00:30<00:00, 30.88it/s]


Epoch 1 | Train Acc: 9.90 | Test Acc: 9.80


100%|██████████| 938/938 [00:29<00:00, 31.48it/s]


Epoch 2 | Train Acc: 9.87 | Test Acc: 9.80


100%|██████████| 938/938 [00:29<00:00, 31.49it/s]


Epoch 3 | Train Acc: 9.87 | Test Acc: 9.80


100%|██████████| 938/938 [00:29<00:00, 31.39it/s]


Epoch 4 | Train Acc: 9.87 | Test Acc: 9.80


9.8

### ResNet + Adam + ArcFace

In [9]:
resnet = models.resnet18(weights="IMAGENET1K_V1")
feature_dim = resnet.fc.in_features
resnet.fc = nn.Identity()

arcface = ArcFaceHead(feature_dim).to(device)

optimizer = optim.Adam(list(resnet.parameters()) + list(arcface.parameters()), lr=1e-4)

train_model(resnet, trainloader, testloader, optimizer, arcface, epochs=4, arcface=True)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 159MB/s]
100%|██████████| 938/938 [01:03<00:00, 14.85it/s]


Epoch 1 | Train Acc: 95.63 | Test Acc: 99.07


100%|██████████| 938/938 [01:02<00:00, 14.98it/s]


Epoch 2 | Train Acc: 98.31 | Test Acc: 99.47


100%|██████████| 938/938 [01:02<00:00, 15.02it/s]


Epoch 3 | Train Acc: 98.56 | Test Acc: 99.38


100%|██████████| 938/938 [01:02<00:00, 15.02it/s]


Epoch 4 | Train Acc: 98.83 | Test Acc: 99.25


99.47

## Model Comparison Results

| Model   | Optimizer | Epochs | Loss Function | Training Accuracy | Testing Accuracy |
|---------|---------- |--------|-------------- |-------------------|------------------|
| VGGNet  | Adam      | 10     | BCE           | 99.55%            | 99.37%           |
| AlexNet | SGD       | 20     | Focal Loss    | 98.7%             | 98.0%            |
| ResNet  | Adam      | 15     | ArcFace       | 98.83%            | 99.47%           |