In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = torchvision.datasets.MNIST('~/.torch/models/mnist', train=True, download=True, 
                                           transform=transform)
test_dataset = torchvision.datasets.MNIST('~/.torch/models/mnist', train=False, download=True, 
                                           transform=transform)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True)

In [3]:
def assert_shape(x, shape):
    assert tuple(x.shape[-2:]) == tuple(shape), f'Expected shape ending {shape}, got {x.shape}'

In [13]:
class Encoder(nn.Module):
    def __init__(self, size: (int, int)):
        super().__init__()
        self.size = size
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=25, kernel_size=12, padding=0, stride=2)
        self.conv2 = nn.Conv2d(in_channels=25, out_channels=64, kernel_size=5, padding=2)

    
    def forward(self, x):
        assert_shape(x, self.size)
        
        x = F.relu(self.conv1(x))
        assert_shape(x, (9, 9)) # Should this be the same size?

        x = F.relu(self.conv2(x))
        assert_shape(x, (9, 9))
        
        x = F.max_pool2d(x, 2)
        assert_shape(x, (4, 4))
        
        return x
    
class Decoder(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features=4*4*64, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)
        
    def forward(self, x):
        assert_shape(x, (4, 4))
        
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
class Model(nn.Module):
    def __init__(self, size: (int, int), num_classes: int):
        super().__init__()
        self.size = size
        self.encoder = Encoder(size)
        self.decoder1 = Decoder(num_classes=3)
        self.decoder2 = Decoder(num_classes=10)
        
        self.weight1 = nn.Parameter(torch.tensor([1.0]))
        self.weight2 = nn.Parameter(torch.tensor([1.0]))
        
    def forward(self, x):
        assert_shape(x, self.size)
        
        x = self.encoder(x)
        x1 = self.decoder1(x)
        x2 = self.decoder2(x)
        return x1, x2
        
model = Model((28,28), 10)
model = model.cuda()
for image, labels in train_dataloader:
    image = image.cuda()
    image /= 255
    model(image)
    break

In [14]:
def labels_to_1(labels):
    """Task 1 is 3, 7 or other"""
    # 2 is class other.
    converted = torch.full_like(labels, 2)
    converted[labels == 3] = 0
    converted[labels == 7] = 1
    return converted

labels_to_1(torch.tensor([1,2,3,4,5,6,7,8,9,0]))

tensor([2, 2, 0, 2, 2, 2, 1, 2, 2, 2])

In [15]:
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

enable1 = True
enable2 = True
learn_weights = True

In [18]:
for epoch in range(10):
    epoch_loss = 0
    for i, data in enumerate(train_dataloader):
        image, labels = data
        
        image = image.cuda()
        labels = labels.cuda()
        
        image /= 255
        
        optimizer.zero_grad()
        
        output1, output2 = model(image)
        
        if enable1:
            loss1 = criterion1(output1, labels_to_1(labels))
        else:
            loss1 = 0
            
        if enable2:
            loss2 = criterion2(output2, labels)
        else:
            loss2 = 0
        
        if learn_weights:
            loss = (torch.exp(-model.weight1) * loss1 + 0.5 * model.weight1 
                    + torch.exp(-model.weight2) * loss2 + 0.5 * model.weight2)
        else:
            loss = loss1 + loss2
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        
    print(f'Epoch {epoch}: {epoch_loss/i:.3} ({model.weight1.item()}, {model.weight2.item()})')

Epoch 0: -1.44 (-2.5295610427856445, -1.9533275365829468)
Epoch 1: -1.86 (-3.035012722015381, -2.3334403038024902)
Epoch 2: -2.14 (-3.371398687362671, -2.6007096767425537)
Epoch 3: -2.4 (-3.6361169815063477, -2.827561855316162)
Epoch 4: -2.61 (-3.863328456878662, -3.0023446083068848)
Epoch 5: -2.86 (-4.093160629272461, -3.2284069061279297)
Epoch 6: -3.0 (-4.2603936195373535, -3.388504981994629)
Epoch 7: -3.18 (-4.3957905769348145, -3.574328660964966)
Epoch 8: -3.32 (-4.539775371551514, -3.7055885791778564)
Epoch 9: -3.57 (-4.725478649139404, -3.9001665115356445)


In [19]:
with torch.no_grad():
    correct1 = 0
    correct2 = 0
    total = 0
    
    for i, data in enumerate(test_dataloader):
        image, labels = data
        
        image = image.cuda()
        labels = labels.cuda()
        
        output1, output2 = model(image)
        preds1 = output1.argmax(dim=1)
        preds2 = output2.argmax(dim=1)
        
        correct1 += (labels_to_1(labels) == preds1).sum().item()
        correct2 += (labels == preds2).sum().item()
        total += preds1.shape[0]
    
    print(f'Accuracy 1: {correct1}/{total} ({100 * correct1/total:.2f}%)')
    print(f'Accuracy 2: {correct2}/{total} ({100 * correct2/total:.2f}%)')

Accuracy 1: 9947/10000 (99.47%)
Accuracy 2: 9853/10000 (98.53%)


Just 1:
Accuracy 1: 9868/10000 (98.68%)
Accuracy 2: 1352/10000 (13.52%)

Just 2:
Accuracy 1: 1270/10000 (12.70%)
Accuracy 2: 9434/10000 (94.34%)

Both (equal weights):
Accuracy 1: 9840/10000 (98.40%)
Accuracy 2: 9457/10000 (94.57%)

Both (learned weights):
Accuracy 1: 9897/10000 (98.97%)
Accuracy 2: 9735/10000 (97.35%)