In [1]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np

In [2]:
resnet_pretrained = models.resnet50(pretrained=True)

In [3]:
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet50()

In [4]:
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)

In [5]:
# Change batchnorm behavior
# resnet_attn = resnet_attn.eval() 
# Don't want to do that because bn needs to be re-trained as well

In [6]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [7]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(lst=[]):
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        for kw in lst:
            if not kw in k:
                obj.requires_grad = False
            else:
                obj.requires_grad = True

In [8]:
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 144)

Start training

In [9]:
batch_size = 32

In [10]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.ImageFolder(root='./data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [11]:
total_imgs = len(trainset.imgs)

In [12]:
resnet_attn = resnet_attn.cuda()

In [13]:
total_attn_params = 0
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if 'attn_weights' in k:
        total_attn_params += np.prod(obj.shape)
print("Total number of attention parameters", total_attn_params)

Total number of attention parameters 26560


We want the attention parameters to diverge from 1, therefore we penalize element-wise square loss as $\lambda (1 \times \text{# params} - (x - 1)^2)$

In [14]:
_lambda = 1e-2

In [15]:
def get_params_objs(name, net='resnet_attn'):
    res = []
    for k in formatted_keys:
        obj = eval(f'{net}.'+k)
        if name in k:
            res.append(obj)
    return res

In [16]:
def compute_attn_loss(n_params=26560):
    attns = get_params_objs('attn_weights')
    penality = sum([torch.pow(t - 1,2).mean() for t in attns])
    return _lambda*(n_params - penality)

In [18]:
print_every = 5

In [19]:
def train_one_epoch(add_attn=True):
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))
    
    running_loss = 0.0
    running_attn_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

        optimizer.zero_grad()
        outputs = resnet_attn(inputs)
        loss = cls_criterion(outputs, labels)
        attn_loss = compute_attn_loss()
        if add_attn:
            loss += attn_loss

        loss.backward()
        optimizer.step()


        running_loss += loss.data[0]
        running_attn_loss += attn_loss.data[0]

        if i % print_every == 0:
            print('[%5d] iter, [%2f] epoch, avg loss: %.3f, attn_loss: %.5f ' %
                  (i + 1, i*batch_size/total_imgs, running_loss/print_every, running_attn_loss/print_every))
            running_loss = 0.0
            running_attn_loss = 0.0

Train a fresh fc layer. 
`turn_off_grad_except([])` turns off grads for all weights but the fc layer

In [20]:
turn_off_grad_except(['fc'])
resnet_attn.eval() # Turn off batchnorm
train_one_epoch(add_attn=False)

[    1] iter, [0.000000] epoch, avg loss: 0.995, attn_loss: 53.12000 
[    6] iter, [0.019605] epoch, avg loss: 4.829, attn_loss: 265.60001 
[   11] iter, [0.039211] epoch, avg loss: 4.714, attn_loss: 265.60001 
[   16] iter, [0.058816] epoch, avg loss: 4.777, attn_loss: 265.60001 
[   21] iter, [0.078422] epoch, avg loss: 4.605, attn_loss: 265.60001 
[   26] iter, [0.098027] epoch, avg loss: 4.198, attn_loss: 265.60001 
[   31] iter, [0.117633] epoch, avg loss: 3.953, attn_loss: 265.60001 
[   36] iter, [0.137238] epoch, avg loss: 4.067, attn_loss: 265.60001 
[   41] iter, [0.156844] epoch, avg loss: 3.808, attn_loss: 265.60001 
[   46] iter, [0.176449] epoch, avg loss: 3.970, attn_loss: 265.60001 
[   51] iter, [0.196054] epoch, avg loss: 3.776, attn_loss: 265.60001 
[   56] iter, [0.215660] epoch, avg loss: 3.797, attn_loss: 265.60001 
[   61] iter, [0.235265] epoch, avg loss: 3.523, attn_loss: 265.60001 
[   66] iter, [0.254871] epoch, avg loss: 3.797, attn_loss: 265.60001 
[   71]

In [21]:
from tqdm import tqdm

In [27]:
def score(net=resnet_attn, batch_size=batch_size):
    trainset = torchvision.datasets.ImageFolder(root='./data/train', transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    valset = torchvision.datasets.ImageFolder(root='./data/val', transform=transform)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    train_correct = 0
    val_correct = 0
    
    for inp, label in tqdm(iter(trainloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        train_correct += int(sum(idx.cpu().data == label))
    
    for inp, label in tqdm(iter(valloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        val_correct += int(sum(idx.cpu().data == label))
    
    return {
        'train_accu': train_correct/len(trainset),
        'val_accu': val_correct/len(valset)
    }

In [28]:
score(batch_size=64)


  0%|          | 0/128 [00:00<?, ?it/s][A
Exception ignored in: <bound method DataLoaderIter.__del__ of <torch.utils.data.dataloader.DataLoaderIter object at 0x7f728e4da6d8>>
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 333, in __del__
    self._shutdown_workers()
  File "/home/ubuntu/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 319, in _shutdown_workers
    self.data_queue.get()
  File "/home/ubuntu/miniconda3/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/home/ubuntu/miniconda3/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/ubuntu/miniconda3/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/ubuntu/miniconda3/lib/python3.6/multipro

{'train_accu': 0.34603602499693664, 'val_accu': 0.2913340935005701}

In [29]:
turn_off_grad_except(['attn_weights','bn'])
resnet_attn = resnet_attn.train()
_lambda = 1e-4
train_one_epoch(add_attn=True)

[    1] iter, [0.000000] epoch, avg loss: 1.202, attn_loss: 0.53120 
[    6] iter, [0.019605] epoch, avg loss: 5.620, attn_loss: 2.65600 
[   11] iter, [0.039211] epoch, avg loss: 5.551, attn_loss: 2.65600 
[   16] iter, [0.058816] epoch, avg loss: 5.662, attn_loss: 2.65600 
[   21] iter, [0.078422] epoch, avg loss: 5.460, attn_loss: 2.65600 
[   26] iter, [0.098027] epoch, avg loss: 5.253, attn_loss: 2.65600 
[   31] iter, [0.117633] epoch, avg loss: 5.532, attn_loss: 2.65600 
[   36] iter, [0.137238] epoch, avg loss: 5.282, attn_loss: 2.65600 
[   41] iter, [0.156844] epoch, avg loss: 5.439, attn_loss: 2.65600 
[   46] iter, [0.176449] epoch, avg loss: 5.473, attn_loss: 2.65600 
[   51] iter, [0.196054] epoch, avg loss: 5.185, attn_loss: 2.65600 
[   56] iter, [0.215660] epoch, avg loss: 5.185, attn_loss: 2.65600 
[   61] iter, [0.235265] epoch, avg loss: 5.482, attn_loss: 2.65600 
[   66] iter, [0.254871] epoch, avg loss: 5.303, attn_loss: 2.65600 
[   71] iter, [0.274476] epoch, av

In [30]:
score(batch_size=64)

100%|██████████| 128/128 [01:05<00:00,  1.97it/s]
100%|██████████| 28/28 [00:14<00:00,  1.96it/s]


{'train_accu': 0.47714740840583264, 'val_accu': 0.40364880273660203}