In [1]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np

In [2]:
resnet_pretrained = models.resnet50(pretrained=True)

In [3]:
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet50()

In [4]:
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)

In [5]:
# Change batchnorm behavior
resnet_attn = resnet_attn.eval() 

In [6]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [18]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except_attn():
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        if not 'attn_weights' in k:
            obj.requires_grad = False
        else:
            obj.requires_grad = True

In [8]:
def turn_off_al_grad():
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        obj.requires_grad = False

In [9]:
turn_off_al_grad()
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 10)

Start training

In [10]:
transform = transforms.Compose(
    [transforms.Resize(224),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))

In [12]:
resnet_attn = resnet_attn.cuda()

In [13]:
# reg_lambda = 1e-4

Just train the FC layer

In [14]:
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
    # get the inputs
    inputs, labels = data

    # wrap them in Variable
    inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = resnet_attn(inputs)
    loss = criterion(outputs, labels)
    
#     for k in formatted_keys:
#         if 'attn_weights' in k:
#             obj = eval('resnet_attn.'+k)
#             loss += reg_lambda*torch.norm(obj, p=1)
    
    loss.backward()
    optimizer.step()
    
    
    # print statistics
    running_loss += loss.data[0]
    if i % 200 == 199:    # print every 2000 mini-batches
        attn_mean = []
        attn_std = []

        for k in formatted_keys:
            if 'attn_weights' in k:
                obj = eval('resnet_attn.'+k)
                attn_mean.append(obj.data.mean())
                attn_std.append(obj.data.std())
            
        print('[%5d] loss: %.3f, attn_mean: %.3f, attn_std: %.3f' %
              (i + 1, running_loss / 200, np.mean(attn_mean), np.std(attn_std)))
        running_loss = 0.0

[  200] loss: 1.303, attn_mean: 1.000, attn_std: 0.000
[  400] loss: 0.769, attn_mean: 1.000, attn_std: 0.000
[  600] loss: 0.752, attn_mean: 1.000, attn_std: 0.000
[  800] loss: 0.723, attn_mean: 1.000, attn_std: 0.000
[ 1000] loss: 0.741, attn_mean: 1.000, attn_std: 0.000
[ 1200] loss: 0.652, attn_mean: 1.000, attn_std: 0.000
[ 1400] loss: 0.626, attn_mean: 1.000, attn_std: 0.000
[ 1600] loss: 0.695, attn_mean: 1.000, attn_std: 0.000
[ 1800] loss: 0.588, attn_mean: 1.000, attn_std: 0.000
[ 2000] loss: 0.611, attn_mean: 1.000, attn_std: 0.000
[ 2200] loss: 0.592, attn_mean: 1.000, attn_std: 0.000
[ 2400] loss: 0.681, attn_mean: 1.000, attn_std: 0.000
[ 2600] loss: 0.609, attn_mean: 1.000, attn_std: 0.000
[ 2800] loss: 0.639, attn_mean: 1.000, attn_std: 0.000
[ 3000] loss: 0.599, attn_mean: 1.000, attn_std: 0.000
[ 3200] loss: 0.632, attn_mean: 1.000, attn_std: 0.000
[ 3400] loss: 0.600, attn_mean: 1.000, attn_std: 0.000
[ 3600] loss: 0.608, attn_mean: 1.000, attn_std: 0.000
[ 3800] lo

Test

In [15]:
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

In [16]:
correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = resnet_attn(Variable(images).cuda())
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted.cpu() == labels).sum()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 82 %


In [19]:
turn_off_grad_except_attn()

In [22]:
running_loss = 0.0
reg_lambda = 5e-7

optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))

for i, data in enumerate(trainloader, 0):
    # get the inputs
    inputs, labels = data

    # wrap them in Variable
    inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = resnet_attn(inputs)
    loss = criterion(outputs, labels)
    
    for k in formatted_keys:
        if 'attn_weights' in k:
            obj = eval('resnet_attn.'+k)
            loss += reg_lambda*torch.norm(obj, p=1)
    
    loss.backward()
    optimizer.step()
    
    
    # print statistics
    running_loss += loss.data[0]
    if i % 200 == 199:    # print every 2000 mini-batches
        attn_mean = []
        attn_std = []

        for k in formatted_keys:
            if 'attn_weights' in k:
                obj = eval('resnet_attn.'+k)
                attn_mean.append(obj.data.mean())
                attn_std.append(obj.data.std())
            
        print('[%5d] loss: %.3f, attn_mean: %.3f, attn_std: %.3f' %
              (i + 1, running_loss / 200, np.mean(attn_mean), np.std(attn_std)))
        running_loss = 0.0

[  200] loss: 6.002, attn_mean: 0.871, attn_std: 0.092
[  400] loss: 5.954, attn_mean: 0.860, attn_std: 0.109
[  600] loss: 5.870, attn_mean: 0.856, attn_std: 0.110
[  800] loss: 5.811, attn_mean: 0.853, attn_std: 0.110
[ 1000] loss: 5.707, attn_mean: 0.850, attn_std: 0.111
[ 1200] loss: 5.668, attn_mean: 0.847, attn_std: 0.111
[ 1400] loss: 5.646, attn_mean: 0.844, attn_std: 0.111
[ 1600] loss: 5.603, attn_mean: 0.841, attn_std: 0.111
[ 1800] loss: 5.598, attn_mean: 0.838, attn_std: 0.111
[ 2000] loss: 5.523, attn_mean: 0.836, attn_std: 0.110
[ 2200] loss: 5.496, attn_mean: 0.833, attn_std: 0.110
[ 2400] loss: 5.467, attn_mean: 0.831, attn_std: 0.110
[ 2600] loss: 5.444, attn_mean: 0.828, attn_std: 0.110
[ 2800] loss: 5.393, attn_mean: 0.826, attn_std: 0.110
[ 3000] loss: 5.386, attn_mean: 0.823, attn_std: 0.110


Process Process-10:
  File "/home/seg-image/anaconda3/envs/simon/lib/python3.6/multiprocessing/queues.py", line 341, in get
    with self._rlock:
Process Process-9:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/seg-image/anaconda3/envs/simon/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/seg-image/anaconda3/envs/simon/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/seg-image/anaconda3/envs/simon/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/seg-image/anaconda3/envs/simon/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/seg-image/anaconda3/envs/simon/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/home/seg-image/anaconda3/envs/simon/lib/pytho

KeyboardInterrupt: 

In [23]:
correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = resnet_attn(Variable(images).cuda())
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted.cpu() == labels).sum()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 84 %


Training run @ April 10, 2018 – 21:27:
- We need to add regularization to move the attention around. 
- Otherwise, the attention will just stay where they are and be happy with it. 
- Todo:
    - Figure out the correct attn size for a conv tensor of size (channel_out, channel_in, H, W)
    - Figure out the correct initialization + penalization scheme. It's 1 for now. 