In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

import random as rd
from tqdm import tqdm
from typing import Tuple

In [2]:
class RotatedMNIST(Dataset):
    def __init__(self):
        self.mnist = MNIST('data', transform=transforms.ToTensor(), download=True)
        self.angles = [0, 90, 180, 270]
    
    def __len__(self):
        return len(self.mnist)
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        img, label = self.mnist[idx]
        angle_idx = torch.randint(low=0, high=len(self.angles) - 1, size=())
        rotated_img = transforms.functional.rotate(img, self.angles[angle_idx])
        return rotated_img, angle_idx, label

dataset = RotatedMNIST()
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

In [6]:
class RotationPredictor(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=4, out_channels=2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten()
        )
        
        self.rotation_classifier = nn.Sequential(
            nn.Linear(98, 4),
            nn.Softmax(dim=1)
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        x = self.rotation_classifier(x)
        return x

model = RotationPredictor()

In [7]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

In [8]:
for epoch in range(2):
    model.train()
    total_loss, total_acc = 0., 0.
    for X, y, _ in tqdm(dataloader):
        
        optimizer.zero_grad()
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        acc = (y_pred.argmax(-1) == y).float().mean()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_acc += acc.item()
    
    print(f"Epoch {epoch} | Loss: {total_loss/len(dataloader):.4f} | Acc: {total_acc/len(dataloader):.4f}")

100%|█████████████████████████████████████████| 600/600 [00:14<00:00, 42.18it/s]


Epoch 0 | Loss: 0.8975 | Acc: 0.8687


100%|█████████████████████████████████████████| 600/600 [00:14<00:00, 42.11it/s]

Epoch 1 | Loss: 0.7777 | Acc: 0.9695





In [13]:
backbone_frozen = model.backbone
# for param in backbone_frozen.parameters():
#     param.requires_grad = False
digit_classifier = nn.Sequential(
    backbone_frozen,
    nn.Linear(98, 10),
    nn.Softmax(dim=1)
)

In [14]:
optimizer = torch.optim.Adam(digit_classifier.parameters())

In [15]:
for epoch in range(2):
    digit_classifier.train()
    total_loss, total_acc = 0., 0.
    for X, _, y in tqdm(dataloader):
        
        optimizer.zero_grad()
        y_pred = digit_classifier(X)
        loss = loss_fn(y_pred, y)
        acc = (y_pred.argmax(-1) == y).float().mean()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_acc += acc.item()
    
    print(f"Epoch {epoch} | Loss: {total_loss/len(dataloader):.4f} | Acc: {total_acc/len(dataloader):.4f}")

100%|█████████████████████████████████████████| 600/600 [00:12<00:00, 47.41it/s]


Epoch 0 | Loss: 2.0346 | Acc: 0.4382


100%|█████████████████████████████████████████| 600/600 [00:13<00:00, 45.92it/s]

Epoch 1 | Loss: 1.9527 | Acc: 0.5213



