In [1]:
# NOTE(review): presumably moves into the repo root so the local `datasets` /
# `model` imports in a later cell resolve. This is a relative cd, so it is not
# idempotent: every re-run of this cell climbs one more directory.
cd ../

/Users/gautam/Desktop/workbench/cs330-final


In [25]:
# Enable IPython's autoreload so edits to imported modules (e.g. the local
# `datasets` / `model` modules) are picked up without restarting the kernel.
# Mode 2 reloads all modules before executing each cell.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
import itertools

import matplotlib.pyplot as plt
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import datasets
from model import Shareable

In [111]:
bs = 512    # batch size for each task's loader
lr = 3e-4   # Adam learning rate
seed = 0    # global torch seed

# Seed torch's global RNG before any model/loader is built so weight init and
# DataLoader shuffling are reproducible under Restart & Run All.
torch.manual_seed(seed)

# Presumably binary "is this digit n?" datasets, one per task (train/val splits).
is_two_train = datasets.IsNumber(n=2, train=True)
is_five_train = datasets.IsNumber(n=5, train=True)

is_two_val = datasets.IsNumber(n=2, train=False)
is_five_val = datasets.IsNumber(n=5, train=False)

In [112]:
def _repeat_epochs(loader):
    """Yield batches from `loader` forever, starting a fresh pass each epoch.

    Unlike itertools.cycle, which stores every item from the first pass and
    replays that copy, this re-iterates the DataLoader on every epoch. As a
    result, shuffle=True reshuffles each epoch and batches are not cached.
    Returns immediately if `loader` yields nothing, which mirrors
    itertools.cycle instead of spinning forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


is_two_train_loader = DataLoader(is_two_train, batch_size=bs, shuffle=True, drop_last=True)
is_five_train_loader = DataLoader(is_five_train, batch_size=bs, shuffle=True, drop_last=True)

# Infinite batch iterators: the training loop draws one batch per task per step.
is2_iter = _repeat_epochs(is_two_train_loader)
is5_iter = _repeat_epochs(is_five_train_loader)

In [152]:
def grad_dict(loss, model):
    """Map each parameter name of `model` to d(loss)/d(param).

    Uses ``torch.autograd.grad``, so the parameters' ``.grad`` attributes are
    left untouched. ``retain_graph=True`` keeps the graph alive so `loss` can
    be differentiated again. ``allow_unused=True`` makes parameters that `loss`
    does not depend on map to None instead of raising.
    """
    param_names, param_tensors = zip(*model.named_parameters())
    per_param = torch.autograd.grad(
        loss,
        param_tensors,
        allow_unused=True,
        retain_graph=True,
    )
    return {name: g for name, g in zip(param_names, per_param)}


def clone_grads(model):
    """Snapshot every parameter's current ``.grad`` as {param_name: tensor}.

    Each gradient is cloned, so later ``backward()`` calls (which accumulate
    into ``.grad``) do not alter the snapshot. Parameters without a gradient
    yet map to None. A module with no parameters yields ``{}`` rather than
    failing on an empty unpack.
    """
    return {
        name: None if param.grad is None else param.grad.clone()
        for name, param in model.named_parameters()
    }


def sub_state_dicts(a, b):
    """Key-wise difference ``a[k] - b[k]`` of two dicts with identical key sets.

    ``None`` entries (e.g. a parameter whose ``.grad`` was never populated)
    are treated as 0, so ``None - None`` yields the int 0.

    Raises:
        AssertionError: if the two dicts do not have the same keys.
    """
    assert a.keys() == b.keys()

    def _or_zero(value):
        return 0 if value is None else value

    # Look b up by key instead of zipping a.values() with b.values():
    # keys() equality is set equality, so the dicts may iterate in different
    # orders and positional pairing would subtract mismatched entries.
    return {k: _or_zero(v) - _or_zero(b[k]) for k, v in a.items()}


def stack_grad(grad_dict_list, task_name, param_name, flatten=True):
    """Stack one task's gradient for one parameter across training steps.

    Args:
        grad_dict_list: list with one entry per step, each shaped like
            {task_name: {param_name: grad_tensor}}.
        task_name: task key to select from each step.
        param_name: parameter key to select within that task.
        flatten: if True (default), return shape (steps, numel). If False,
            keep the parameter's own shape, i.e. (steps, *param.shape).

    Returns:
        torch.Tensor of the stacked per-step gradients.
    """
    grads = [step_grads[task_name][param_name] for step_grads in grad_dict_list]
    stacked = torch.stack(grads)
    if not flatten:
        return stacked
    return stacked.view(len(grads), -1)


def low_pass_filter(x, filter_size=25):
    """Moving-average (box) filter of a 1-D signal in 'valid' mode.

    Args:
        x: 1-D sequence, numpy array, or tensor of length T. Integer input is
            promoted to torch's default float dtype.
        filter_size: window length k (must be <= T).

    Returns:
        numpy array of shape (1, T - k + 1), holding the mean of each length-k
        window.
    """
    x = torch.as_tensor(x)
    if not torch.is_floating_point(x):
        x = x.to(torch.get_default_dtype())
    # Build the averaging kernel in x's dtype/device. A fixed float32 kernel
    # makes conv1d fail on float64 input (e.g. numpy arrays).
    kernel = torch.ones(1, 1, filter_size, dtype=x.dtype, device=x.device) / filter_size
    # x[None] is (1, T): conv1d treats it as an unbatched single-channel signal.
    x_smooth = F.conv1d(x[None], kernel)
    # detach/cpu so .numpy() also works for tensors that require grad or live on GPU.
    return x_smooth.detach().cpu().numpy()

In [180]:
class ConvBackbone(nn.Module):
    """Three (3x3 conv -> ReLU -> 2x2 average-pool) stages, flattened per sample.

    For a (B, 1, 28, 28) input the spatial size goes 28 -> 26 -> 13 -> 11 -> 5
    -> 3 -> 1, so the output is (B, 256).
    """

    def __init__(self):
        super().__init__()
        stages = []
        for in_ch, out_ch in ((1, 64), (64, 128), (128, 256)):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=(3, 3)),
                nn.ReLU(),
                nn.AvgPool2d(2, 2),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        # expects x shaped (B, 1, 28, 28); keep the batch dim, flatten the rest
        batch_size = x.shape[0]
        features = self.net(x)
        return features.view(batch_size, -1)
        
        
class LinearBackbone(nn.Module):
    """MLP backbone: flattened 28*28 pixels -> 512 -> 512 -> 256 features.

    ReLU sits between the linear layers; the final layer has no activation.
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28, 512, 512, 256]
        layers = []
        for idx, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # expects x shaped (B, 1, 28, 28); flatten to (B, 784) before the MLP
        flat = x.view(x.shape[0], -1)
        return self.net(flat).view(flat.shape[0], -1)
    

class SharedMTL(nn.Module):
    """Hard-parameter-sharing multi-task model.

    One shared backbone feeds a separate 1-logit linear head per task.

    Args:
        task_keys: iterable of task names (strings); one head is created per name.
        backbone: module mapping a batch of inputs to (B, feat_dim) features.
            Defaults to a fresh ConvBackbone(), matching the previous behaviour.
        feat_dim: output width of `backbone` (256 for both backbones defined here).
    """

    def __init__(self, task_keys, backbone=None, feat_dim=256):
        super().__init__()
        # Backbone is built before the heads, preserving the original RNG/init order.
        self.backbone = ConvBackbone() if backbone is None else backbone
        self.heads = nn.ModuleDict({task: nn.Linear(feat_dim, 1) for task in task_keys})

    def forward(self, x, task):
        """Return the (B, 1) logits of head `task` on the shared features of `x`."""
        return self.heads[task](self.backbone(x))

In [181]:
steps = 200


def _bce_logits_loss(yh, y):
    """BCE-with-logits of head output `yh` against labels `y`.

    y[..., None] adds a trailing dim so the labels match the (B, 1) head output.
    NOTE(review): assumes y is a float tensor of shape (B,); BCE needs float targets.
    """
    return F.binary_cross_entropy_with_logits(yh, y[..., None])


# Per task: an infinite batch iterator ('data') and a loss on (logits, labels).
tasks = {
    't0': {'data': is2_iter, 'loss': _bce_logits_loss},
    't1': {'data': is5_iter, 'loss': _bce_logits_loss},
}
mtl = SharedMTL(tasks.keys())
opt = torch.optim.Adam(mtl.parameters(), lr=lr)

losses = []  # per step: {task_name: float loss}
grads = []   # per step: {task_name: {param_name: gradient contributed by that task}}
for step in range(steps):
    task_losses = {}
    task_grads = {}

    opt.zero_grad()
    for task_name, task_cfg in tasks.items():
        x, y = next(task_cfg['data'])
        # insert a channel dim at axis 1; presumably x is (B, 28, 28) -> (B, 1, 28, 28)
        pred = mtl(x[:, None, ...], task_name)
        loss = task_cfg['loss'](pred, y)

        # backward() accumulates into .grad across tasks, so isolate this task's
        # contribution as (grads after backward) - (grads before backward).
        running_grads = clone_grads(mtl)
        loss.backward()
        task_grads[task_name] = sub_state_dicts(clone_grads(mtl), running_grads)
        task_losses[task_name] = loss.item()

    # One optimizer step on the summed gradient of all tasks.
    opt.step()

    losses.append(task_losses)
    grads.append(task_grads)

    if step % 10 == 0:
        print('step', step, task_losses)

step 0 {'t0': 0.7179786562919617, 't1': 0.6881766319274902}
step 10 {'t0': 0.31588083505630493, 't1': 0.31994783878326416}
step 20 {'t0': 0.28396108746528625, 't1': 0.43509432673454285}
step 30 {'t0': 0.3228101432323456, 't1': 0.2658112347126007}
step 40 {'t0': 0.3149857223033905, 't1': 0.2961196005344391}
step 50 {'t0': 0.2839863896369934, 't1': 0.3070862293243408}
step 60 {'t0': 0.298997700214386, 't1': 0.29681938886642456}
step 70 {'t0': 0.27741920948028564, 't1': 0.28228747844696045}
step 80 {'t0': 0.2513234615325928, 't1': 0.2829279899597168}
step 90 {'t0': 0.22837910056114197, 't1': 0.25727617740631104}
step 100 {'t0': 0.23665127158164978, 't1': 0.25955528020858765}
step 110 {'t0': 0.2073015570640564, 't1': 0.21563036739826202}
step 120 {'t0': 0.16826166212558746, 't1': 0.26743876934051514}
step 130 {'t0': 0.1875883936882019, 't1': 0.24544239044189453}
step 140 {'t0': 0.12857620418071747, 't1': 0.23376533389091492}
step 150 {'t0': 0.11570951342582703, 't1': 0.1712069809436798}
st

In [184]:
# Gradients are recorded only for parameters (clone_grads iterates
# named_parameters), so take the keys from there. state_dict() would also list
# buffers, which have no entry in `grads`.
param_keys = [name for name, _ in mtl.named_parameters() if name.startswith('backbone.')]

plot_cosine = False  # set True to plot raw vs. smoothed cosine per parameter

# dir_s: mean of the smoothed per-step cosine similarity between the t0 and t1
# gradients. var: variance of the raw per-step cosine.
print('param: \t\t\tdir_s\tvar')

values = []

for key in param_keys:
    g0 = stack_grad(grads, 't0', key)
    g1 = stack_grad(grads, 't1', key)

    # Per-step cosine similarity between the two tasks' gradients on this param.
    cosine = torch.sum(F.normalize(g0, dim=-1) * F.normalize(g1, dim=-1), dim=-1)
    smooth_cos = low_pass_filter(cosine, filter_size=10)[0]

    print(f'{key}: \t{smooth_cos.mean():.4f}\t{cosine.var():.4f}')
    values.append(smooth_cos.mean())

    if plot_cosine:
        fig, ax = plt.subplots()
        ax.set(title=key, ylim=(-1.1, 1.1), xlabel='step', ylabel='cosine similarity')
        ax.plot(cosine, label='raw')
        ax.plot(smooth_cos, label='smoothed')
        ax.legend()
        plt.show()

param: 			dir_s	var
backbone.net.0.weight: 	0.4217	0.1997
backbone.net.0.bias: 	0.2994	0.3110
backbone.net.3.weight: 	0.1899	0.1822
backbone.net.3.bias: 	0.1528	0.2950
backbone.net.6.weight: 	0.0307	0.0240
backbone.net.6.bias: 	0.0387	0.0388


In [185]:
# Per-parameter mean smoothed cosine, sorted largest first and normalized by
# the sum (numpy is imported in the top imports cell).
v = sorted(values, reverse=True)
np.asarray(v) / np.sum(v)

array([0.37213403, 0.26421025, 0.16755864, 0.13485323, 0.03414306,
       0.02710086], dtype=float32)