# Task 1: Rotated MNIST

In this task you will use the MNIST dataset. The images are 28x28 pixels and each one is **rotated** by a random angle (in degrees) drawn uniformly from the range (-100, 100).  
You are given a pipeline that trains a multi-head convolutional neural network on this dataset. The first head of the model performs a digit classification and the second head tries to predict the angle that the digit was rotated by.

Your task is to:
1. **(5 pts)** Implement CNN with classification and regression heads.
2. **(3 pts)** Implement the model's loss - select an appropriate loss function for each head and combine them into the final loss of the model.
3. **(3 pts)** Check the model's predictions on the test data and find the classes for which the model achieves the best/worst performance, both for classification and for regression. Then write a short explanation of the observed model behavior (why does the model struggle with particular classes?).

Hints:
- You don't need to create a very sophisticated model - a few convolutions for the CNN and a few linear layers for the heads should be enough. After a few epochs the model should achieve classification accuracy above 90% and regression MAE below 18 degrees on the test dataset.
- When training multi-head models it is usually good to scale losses from particular heads so that they have similar contribution towards the final loss.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from PIL import Image

In [None]:
class Net(nn.Module):
    """Multi-head CNN for rotated MNIST.

    A shared convolutional trunk feeds two heads:

    * ``classification_head`` -> log-probabilities over the 10 digits, and
    * ``regression_head`` -> the predicted rotation angle in degrees.

    The fully connected layers are sized for 1x28x28 inputs.
    """

    # The regression head is trained to emit roughly unit-scale values; this
    # factor maps them back to degrees (the dataset draws angles from (-100, 100)),
    # so the final linear layer does not have to learn very large weights.
    ANGLE_SCALE = 100.0

    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            # 1x28x28 -> 32x14x14
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # 32x14x14 -> 64x7x7
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # 64x7x7 -> 128x3x3 (max-pooling floors the odd size)
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
        )
        num_features = 128 * 3 * 3
        self.classification_head = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(128, 10),
            nn.LogSoftmax(dim=1),
        )
        self.regression_head = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return ``(log_probs, angle)`` for a batch ``x`` of shape (batch, 1, 28, 28).

        ``log_probs`` has shape (batch, 10); ``angle`` has shape (batch,) and is
        expressed in degrees.
        """
        features = self.conv_blocks(x)
        log_probs = self.classification_head(features)
        # Squeeze (batch, 1) -> (batch,) so it lines up with the (batch,) angle
        # targets; (batch, 1) - (batch,) would silently broadcast to (batch, batch).
        angle = self.regression_head(features).squeeze(1) * self.ANGLE_SCALE
        return log_probs, angle


def train(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    for batch_idx, (data, (target_digit, target_angle)) in enumerate(train_loader):
        data = data.to(device)
        target_digit, target_angle = target_digit.to(device), target_angle.to(device)
        optimizer.zero_grad()
        log_probs, angle = model(data)
        ################## TODO ####################
        # Subtask 2: Implement classification and regression loss, then combine them into
        # final model loss
        classification_loss = # TODO
        regression_loss = # TODO
        loss = # TODO
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()
            ))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    abs_error = 0
    with torch.no_grad():
        for data, (target_digit, target_angle) in test_loader:
            data = data.to(device)
            target_digit, target_angle = target_digit.to(device), target_angle.to(device)
            log_probs, angle = model(data)
            ################## TODO ####################
            # Subtask 2: Implement classification and regression loss, then combine them into
            # final model loss (use the same loss as for training)
            # Hint: pass reduction='sum' to loss functions to output loss correctly when logging
            classification_loss = # TODO
            regression_loss = # TODO
            test_loss += # TODO
            pred_digit = log_probs.argmax(dim=1, keepdim=True)
            correct += pred_digit.eq(target_digit.view_as(pred_digit)).sum().item()
            abs_error += (target_angle - angle).abs().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Classification accuracy: {}/{} ({:.0f}%), Regression MAE: {:.2f}\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset),
        abs_error / len(test_loader.dataset)
    ))

In [None]:
# Optimisation settings
lr = 3e-3
epochs = 5

# Batch sizes for the training and evaluation loaders
batch_size = 256
test_batch_size = 1000

# Reproducibility and logging
seed = 1
log_interval = 50  # print a training-loss line every this many batches

# Use the GPU whenever PyTorch can see one
use_cuda = torch.cuda.is_available()

In [None]:
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

# Shuffle the training set on every device. Previously 'shuffle' was only set
# in the CUDA-only kwargs, so CPU runs trained on the data in a fixed order.
# The test set is read in order: shuffling cannot change its aggregate metrics.
train_kwargs = {'batch_size': batch_size, 'shuffle': True}
test_kwargs = {'batch_size': test_batch_size}
if use_cuda:
    cuda_kwargs = {
        'num_workers': 1,
        'pin_memory': True,
        # Keep the worker process (and its copy of the dataset) alive between
        # epochs, so images cached by MNISTWithRotations.__getitem__ inside the
        # worker are reused instead of being recomputed every epoch.
        'persistent_workers': True,
    }
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [None]:
class MNISTWithRotations(datasets.MNIST):
    """MNIST where every image is rotated by its own fixed random angle.

    Items are returned as ``(image, (digit, angle))``:

    * ``image`` is a normalized float tensor of shape (1, H, W).
    * ``digit`` is the class label as an ``int``.
    * ``angle`` is the rotation in degrees as a 0-dim tensor.

    Angles are drawn uniformly from (-100, 100) once, in ``__init__``, from the
    global torch RNG, so they are reproducible under ``torch.manual_seed``.
    Transformed images are cached on first access. When the dataset is read
    through DataLoader worker processes, each worker fills its own copy of
    that cache.

    ``transform`` and ``target_transform`` are accepted so the signature stays
    compatible with ``datasets.MNIST``, but they are deliberately not forwarded.
    The fixed pipeline ToTensor -> rotate -> Normalize is always used.
    """

    def __init__(self, *args, transform=None, target_transform=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Uniform in (-100, 100) degrees, one angle per image.
        self.rotation_angles = (torch.rand(len(self.data)) - 0.5) * 2 * 100
        # Lazy per-item cache: the flag says whether transformed_data[idx] is filled.
        self.is_img_transforemed = [False] * len(self.data)
        self.transformed_data = torch.zeros(*self.data.shape)
        # Stateless pipeline stages, built once instead of on every cache miss.
        self._to_tensor = transforms.ToTensor()
        self._normalize = transforms.Normalize((0.1307,), (0.3081,))

    def __getitem__(self, idx):
        if not self.is_img_transforemed[idx]:
            # A 2-D uint8 array is converted with mode "L" automatically; passing
            # mode= explicitly is deprecated in recent Pillow releases.
            img = Image.fromarray(self.data[idx].numpy())
            tensor = self._to_tensor(img)
            # Rotate before normalizing so the padded corners are 0 (black),
            # matching the background of the raw image.
            tensor = transforms.functional.rotate(tensor, self.rotation_angles[idx].item())
            self.transformed_data[idx] = self._normalize(tensor)
            self.is_img_transforemed[idx] = True

        img = self.transformed_data[idx].unsqueeze(0)
        target_digit = int(self.targets[idx])
        target_angle = self.rotation_angles[idx]
        return img, (target_digit, target_angle)

In [None]:
# Build the rotated splits. Each dataset samples its rotation angles from the
# global RNG when it is constructed (train first, then test). Downloading with
# the training split fetches the files that the test split reads from the same root.
train_dataset = MNISTWithRotations(root='../data', train=True, download=True)
test_dataset = MNISTWithRotations(root='../data', train=False)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, **test_kwargs)

In [None]:
# Instantiate the network on the selected device and optimise it with Adam.
model = Net()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Each epoch: one pass over the training split, then an evaluation on the test split.
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch=epoch, log_interval=log_interval)
    test(model, device, test_loader=test_loader)

# Analysis of model performance for particular classes

In [None]:
# Subtask 3: per-class analysis of the trained model on the test set.


def per_class_report(model, device, loader, num_classes=10):
    """Evaluate ``model`` on ``loader`` and break the results down by true digit.

    Returns:
        accuracy: (num_classes,) fraction of each digit's samples classified correctly.
        mae: (num_classes,) mean absolute angle error in degrees for each digit.
        confusion: (num_classes, num_classes) counts; rows are true digits,
            columns are predicted digits.
    A digit with no samples in ``loader`` gets NaN accuracy and MAE.
    """
    model.eval()
    digits, preds, errors = [], [], []
    with torch.no_grad():
        for data, (target_digit, target_angle) in loader:
            log_probs, angle = model(data.to(device))
            digits.append(target_digit.cpu())
            preds.append(log_probs.argmax(dim=1).cpu())
            # reshape(-1) guards against a (batch, 1) angle broadcasting to (batch, batch).
            errors.append((angle.reshape(-1).cpu() - target_angle.reshape(-1).cpu()).abs())
    digits, preds, errors = torch.cat(digits), torch.cat(preds), torch.cat(errors)

    accuracy = torch.stack(
        [(preds[digits == c] == c).float().mean() for c in range(num_classes)]
    )
    mae = torch.stack([errors[digits == c].mean() for c in range(num_classes)])
    confusion = torch.bincount(
        digits * num_classes + preds, minlength=num_classes * num_classes
    ).view(num_classes, num_classes)
    return accuracy, mae, confusion


accuracy, mae, confusion = per_class_report(model, device, test_loader)

print('digit  accuracy  angle MAE')
for digit in range(len(accuracy)):
    print(f'{digit:>5}  {accuracy[digit].item():>8.2%}  {mae[digit].item():>9.2f}')

print(f'\nClassification: best digit = {accuracy.argmax().item()}, '
      f'worst digit = {accuracy.argmin().item()}')
print(f'Regression:     best digit = {mae.argmin().item()}, '
      f'worst digit = {mae.argmax().item()}')

# Most frequent misclassifications (largest off-diagonal confusion entries).
mistakes = confusion.clone()
mistakes.fill_diagonal_(0)
n_cls = confusion.shape[0]
top = torch.topk(mistakes.flatten(), k=5)
print('\nMost common confusions:')
for count, flat_idx in zip(top.values.tolist(), top.indices.tolist()):
    print(f'  true {flat_idx // n_cls} -> predicted {flat_idx % n_cls}: {count} times')

# Explanation (hypotheses; confirm them against the numbers printed above):
# - Classification: a 9 is roughly a 6 turned by 180 degrees. With angles up to
#   +/-100, a 6 rotated by about +90 looks like a 9 rotated by about -90, and
#   vice versa. So 6 and 9 are expected to be among the worst classes and to
#   form the most common confusion pair.
# - Regression: digits that look similar after a 180-degree turn (0, 1, 8) give
#   an ambiguous "up" direction. Near the +/-100 limits, theta and theta-180 are
#   both valid explanations. A round 0 also offers few orientation cues at any
#   angle. These digits are expected to have the highest MAE. Digits with a
#   clear asymmetric shape (e.g. 2, 4, 7) are expected to have the lowest MAE.
# - The two heads share one trunk, so a 6/9 read "upside down" may also receive
#   an angle prediction that is far from the truth.
# - MNIST digits carry natural handwriting slant, and the target angle is only
#   measured relative to each writer's original orientation. This puts a floor
#   on the achievable MAE for every class.