In [1]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models

In [2]:
import re

In [3]:
import numpy as np

In [4]:
from torchvision import datasets, transforms

In [5]:
resnet_pretrained = models.resnet50(pretrained=True)

In [6]:
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet50()

In [7]:
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)

In [8]:
# Change batchnorm behavior
resnet_attn = resnet_attn.eval() 

In [9]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [10]:
# This block turn off gradient up for all params except attn_weights
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if not 'attn_weights' in k:
        obj.requires_grad = False

In [11]:
out_class = 10
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 10)

Start training

In [12]:
import torch
import torchvision
import torchvision.transforms as transforms

In [13]:
transform = transforms.Compose(
    [transforms.Scale(224),
        transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

  "please use transforms.Resize instead.")


Files already downloaded and verified


In [14]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

In [15]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, resnet_attn.parameters()), lr=0.001, momentum=0.9)

In [16]:
from torch.autograd import Variable

In [None]:
resnet_attn = resnet_attn.cuda()

In [None]:
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
    # get the inputs
    inputs, labels = data

    # wrap them in Variable
    inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = resnet_attn(inputs)
    loss = criterion(outputs, labels)
    
    attn_summary = []
    
    for k in formatted_keys:
        if 'attn_weights' in k:
            obj = eval('resnet_attn.'+k)
#             loss += torch.mean(torch.pow(obj - 0.5, 2))
            attn_summary.append(obj.data.mean())
    
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.data[0]
    if i % 200 == 199:    # print every 2000 mini-batches
        print('[%5d] loss: %.3f, attn_mean: %.3f, attn_std: %.3f' %
              (i + 1, running_loss / 2000, np.mean(attn_summary), np.std(attn_summary)))
        running_loss = 0.0


[  200] loss: 0.126, attn_mean: 1.001, attn_std: 0.001
[  400] loss: 0.060, attn_mean: 1.001, attn_std: 0.002
[  600] loss: 0.046, attn_mean: 1.002, attn_std: 0.003
[  800] loss: 0.045, attn_mean: 1.002, attn_std: 0.004
[ 1000] loss: 0.033, attn_mean: 1.002, attn_std: 0.004
[ 1200] loss: 0.035, attn_mean: 1.002, attn_std: 0.004
[ 1400] loss: 0.035, attn_mean: 1.002, attn_std: 0.005
[ 1600] loss: 0.037, attn_mean: 1.003, attn_std: 0.005
[ 1800] loss: 0.029, attn_mean: 1.003, attn_std: 0.006
[ 2000] loss: 0.033, attn_mean: 1.003, attn_std: 0.006
[ 2200] loss: 0.029, attn_mean: 1.003, attn_std: 0.006
[ 2400] loss: 0.035, attn_mean: 1.003, attn_std: 0.006
[ 2600] loss: 0.035, attn_mean: 1.003, attn_std: 0.006
[ 2800] loss: 0.032, attn_mean: 1.003, attn_std: 0.007
[ 3000] loss: 0.026, attn_mean: 1.003, attn_std: 0.007
[ 3200] loss: 0.026, attn_mean: 1.003, attn_std: 0.007
[ 3400] loss: 0.032, attn_mean: 1.003, attn_std: 0.007
[ 3600] loss: 0.026, attn_mean: 1.003, attn_std: 0.007
[ 3800] lo

Training run @ April 10, 2018 – 21:27:
- We need to add regularization to move the attention around. 
- Otherwise, the attention will just stay where they are and be happy with it. 
- Todo:
    - Figure out the correct attn size for a conv tensor of size (channel_out, channel_in, H, W)
    - Figure out the correct initialization + penalization scheme. It's 1 for now. 