In [281]:
import matplotlib.pyplot as plt
import numpy as np
import copy

import torch
import torchvision

from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
mnist_train = torchvision.datasets.MNIST('/tmp/mnist', download=True, train=True)
mnist_test = torchvision.datasets.MNIST('/tmp/mnist', download=True, train=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw



In [9]:
x_train = mnist_train.data
y_train = mnist_train.targets

x_test = mnist_test.data
y_test = mnist_test.targets

t0_digit = 3
t1_digit = 8

# NOTE: x_train, x_test remains the same for both tasks

y_t0_train = torch.where(y_train == t0_digit, 1, 0)
y_t0_test = torch.where(y_test == t0_digit, 1, 0)
y_t1_train = torch.where(y_train == t1_digit, 1, 0)
y_t1_test = torch.where(y_test == t1_digit, 1, 0)

In [10]:
class MultiTaskDataset:
    def __init__(self, x, t0_y, t1_y):
        self.x = x
        self.t0_y = t0_y
        self.t1_y = t1_y
        
    def __getitem__(self, idx):
        return self.x[idx] / 255., self.t0_y[idx], self.t1_y[idx]
            
    def __len__(self):
        return len(self.x)

In [189]:
class MTL(nn.Module):
    """Generic multi-task learner."""

    def __init__(self, shared=True):
        super().__init__()
        
        self.shared = shared
    
        self.linear1 = nn.Linear(28 * 28, 256)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(256, 32)
        self.act2 = nn.ReLU()
        
#         if not self.shared:
#             self.backbone2 = nn.Sequential(
#                 nn.Linear(28 * 28, 256),
#                 nn.ReLU(),
#                 nn.Linear(256, 32),
#                 nn.ReLU()
#             )
                
        self.t0_head = nn.Linear(32, 1)
        self.t1_head = nn.Linear(32, 1)
    
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        t0_logits = self.t0_head(self.act2(self.linear2(self.act1(self.linear1(x)))))
        t1_logits = self.t1_head(self.act2(self.linear2(self.act1(self.linear1(x)))))
#         if not self.shared:
#             t1_logits = self.t1_head(self.backbone2(x))
#         else:
#             t1_logits = self.t1_head(self.backbone(x))
        
        return t0_logits, t1_logits
    
    def loss(self, x, t0_y, t1_y, t0_lambda=1.0, t1_lambda=1.0):
        t0_logits, t1_logits = self.forward(x)
        t0_logits = t0_logits.view(-1,)
        t1_logits = t1_logits.view(-1,)
        
        t0_loss = nn.functional.binary_cross_entropy_with_logits(t0_logits, t0_y)
        t1_loss = nn.functional.binary_cross_entropy_with_logits(t1_logits, t1_y)
        total_loss = t0_lambda * t0_loss + t1_lambda * t1_loss
        losses = {
            'loss/total': total_loss,
            'loss/t0': t0_loss,
            'loss/t1': t1_loss,
        }
        return losses

In [299]:
def train_mtl(
    *, 
    train_loader, 
    val_loader, 
    lr=1e-3, 
    epochs=5, 
):
    model = MTL()
    t0_model = MTL()
    t1_model = MTL()
    
    linear_lr = []
    shared_lr = []
    for k, v in model.named_parameters():
        if "linear" in k and "weight" in k:
            linear_lr.append([v])
        else:
            shared_lr.append(v)
#     parameters = [{"params": linear_lr}, {"params": shared_lr}]
    parameters = [{"params": linear_lr[0]}, {"params": linear_lr[1]}, {"params": shared_lr}]
    optimizer = torch.optim.Adam(parameters, lr=lr)
    t0_optimizer = torch.optim.Adam(t0_model.param_groups, lr = lr)
    t1_optimizer = torch.optim.Adam(t1_model.param_groups, lr = lr)
    
    train_metrics = {'loss/total': [], 
                     'loss/t0': [], 
                     'loss/t1': [],
                     'acc/t0': [],
                     'acc/t1': [],
                    }
    val_metrics = {'loss/total': [], 
                     'loss/t0': [], 
                     'loss/t1': [],
                     'acc/t0': [],
                     'acc/t1': [],
                    }

    for epoch in range(epochs):
        model.train()
        last_10_it = len(train_loader) - 1
        for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):    
            im, t0_labels, t1_labels = batch
            
            loss = model.loss(im, t0_labels.float(), t1_labels.float())
            total_loss = loss['loss/total']
            t0_loss = loss['loss/t0']
            t1_loss = loss['loss/t1']

            
            total_loss.backward(retain_graph=True)
            rgn = 1
#             for k, v in model.named_parameters():
#                 grad = torch.autograd.grad(t0_loss, v, allow_unused=True, retain_graph=True)
#                 if grad[0] is not None:
#                     t0_rgn[k] = torch.linalg.norm(grad[0]/v)
#             t0_loss.backward(retain_graph=True)
#             t1_loss.backward()
            if i == last_it:
                t0_model.load_state_dict(model.state_dict())
                t0_optimizer.load_state_dict(optimizer.state_dict())
                t1_model.load_state_dict(model.state_dict())
                t1_optimizer.load_state_dict(optimizer.state_dict())
            if i >= last_it:
                for g in t0_optimizer.param_groups:
                    grad = torch.autograd.grad(t0_loss, g["params"][0], allow_unused=True, retain_graph=True)
                    if grad[0] is not None:
                        rgn = torch.linalg.norm(grad[0]/torch.linalg.norm(g["params"][0]))
                    g["lr"] = rgn*g["lr"]
                for g in t1_optimizer.param_groups:
                    grad = torch.autograd.grad(t1_loss, g["params"][0], allow_unused=True, retain_graph=True)
                    if grad[0] is not None:
                        rgn = torch.linalg.norm(grad[0]/torch.linalg.norm(g["params"][0]))
                    g["lr"] = rgn*g["lr"]
                t0_optimizer.step()
                t1_optimizer.step()
                
                t0_optimizer.zero_grad()
                t1_optimizer.zero_grad()
            else:
                optimizer.step()
                optimizer.zero_grad()

            # loss metrics
            for k in train_metrics.keys():
                if k in loss:
                    train_metrics[k].append(loss[k].item())

            # accuracy metrics
            t0_logits, t1_logits = model.forward(im)
            t0_pred = nn.functional.sigmoid(t0_logits) > 0.5
            t1_pred = nn.functional.sigmoid(t1_logits) > 0.5
            train_metrics['acc/t0'].append(accuracy(t0_pred[:, 0], t0_labels.bool()))
            train_metrics['acc/t1'].append(accuracy(t1_pred[:, 0], t1_labels.bool()))


        t0_val_loss = 0
        t1_val_loss = 0
        t0_val_acc = 0
        t1_val_acc = 0
        val_count = 0

        model.eval()
        for batch in val_loader:
            im, t0_labels, t1_labels = batch
            val_loss = model.loss(im, t0_labels.float(), t1_labels.float())

            val_count += len(im)        
            t0_val_loss += val_loss['loss/t0'].item() * len(im)
            t1_val_loss += val_loss['loss/t1'].item() * len(im)

            # accuracy metrics
            t0_logits, t1_logits = model.forward(im)
            t0_pred = torch.sigmoid(t0_logits) > 0.5
            t1_pred = torch.sigmoid(t1_logits) > 0.5
            t0_val_acc += accuracy(t0_pred[:, 0], t0_labels.bool()).item() * len(im)
            t1_val_acc += accuracy(t1_pred[:, 0], t1_labels.bool()).item() * len(im)

        val_metrics['loss/t0'].append(t0_val_loss / val_count)
        val_metrics['loss/t1'].append(t1_val_loss / val_count)
        val_metrics['acc/t0'].append(t0_val_acc / val_count)
        val_metrics['acc/t1'].append(t1_val_acc / val_count)
    
    return train_metrics, val_metrics

In [300]:
def accuracy(y_hat, y):
    assert y_hat.shape == y.shape
    n_classes = y_hat.shape[-1]
    y_hat = y_hat.view(-1, n_classes)
    y = y.view(-1, n_classes)
    n_correct = torch.sum(y_hat == y)
    return n_correct / y_hat.numel()

In [301]:
EPOCHS = 5
LR = 3e-4
BATCH_SIZE = 512

train_ds = MultiTaskDataset(x_train, y_t0_train, y_t1_train)
val_ds = MultiTaskDataset(x_test, y_t0_test, y_t1_test)
train_iter = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_iter = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)

In [302]:
_, shared_metrics = train_mtl(
    train_loader=train_iter, 
    val_loader=val_iter,
    lr=LR, 
    epochs=EPOCHS,
)

100%|████████████████████████████████████████| 118/118 [00:00<00:00, 141.35it/s]
100%|████████████████████████████████████████| 118/118 [00:00<00:00, 148.69it/s]
100%|████████████████████████████████████████| 118/118 [00:00<00:00, 138.12it/s]
100%|████████████████████████████████████████| 118/118 [00:00<00:00, 146.88it/s]
100%|████████████████████████████████████████| 118/118 [00:00<00:00, 140.55it/s]


In [303]:
print(shared_metrics)

{'loss/total': [], 'loss/t0': [0.15632205335299174, 0.15356650929450988, 0.15347934050957363, 0.15347577875057855, 0.15347576359510423], 'loss/t1': [0.23573486652374268, 0.23160629379749298, 0.2314805144548416, 0.2314755568186442, 0.23147552610238392], 'acc/t0': [0.9552333333333334, 0.9554, 0.9554166666666667, 0.9554166666666667, 0.9554166666666667], 'acc/t1': [0.9178666666348775, 0.9202666666666667, 0.9204, 0.9204, 0.9204]}
