In [1]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np
from tqdm import tqdm
from datetime import datetime

In [2]:
resnet_pretrained = models.resnet50(pretrained=True)
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet50()
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 144)

In [5]:
torch.save(resnet_attn, 'fresh_resnet_attn.pkl')

In [7]:
resnet_attn = torch.load('fresh_resnet_attn.pkl')

In [6]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [7]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(lst=[]):
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        for kw in lst:
            if not kw in k:
                obj.requires_grad = False
            else:
                obj.requires_grad = True

Start training

In [8]:
batch_size = 32

In [9]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.ImageFolder(root='./data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [10]:
total_imgs = len(trainset.imgs)

In [11]:
resnet_attn = resnet_attn.cuda()

In [12]:
total_attn_params = 0
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if 'attn_weights' in k:
        total_attn_params += np.prod(obj.shape)
print("Total number of attention parameters", total_attn_params)

Total number of attention parameters 26560


We want the attention parameters to diverge from 1, therefore we penalize element-wise square loss as $\lambda (1 \times \text{# params} - (x - 1)^2)$

But this is too big a number,
let's try: 
$- (x - 1)^2$ for now

In [13]:
_lambda = 1e-2 #set default

In [14]:
def get_params_objs(name, net='resnet_attn'):
    res = []
    for k in formatted_keys:
        obj = eval(f'{net}.'+k)
        if name in k:
            res.append(obj)
    return res

In [15]:
def compute_attn_loss(n_params=26560):
    attns = get_params_objs('attn_weights')
    penality = sum([torch.pow(t - 1,2).mean() for t in attns])
    return _lambda*(- penality)

In [16]:
print_every = 5

In [17]:
def train_k_epoch(epoch, add_attn=True):
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))
    
    running_loss = 0.0
    running_attn_loss = 0.0
    
    for k in range(epoch):
        for i, data in enumerate(trainloader, 0):
            i += int(len(trainset)*k/batch_size)
            
            inputs, labels = data
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

            optimizer.zero_grad()
            outputs = resnet_attn(inputs)
            loss = cls_criterion(outputs, labels)
            attn_loss = compute_attn_loss()
            if add_attn:
                loss += attn_loss

            loss.backward()
            optimizer.step()


            running_loss += loss.data[0]
            running_attn_loss += attn_loss.data[0]

            if i % print_every == 0:
                print('[%5d] iter, [%2f] epoch, avg loss: %.3f, attn_loss: %.5f ' %
                      (i + 1, 
                       i*batch_size/total_imgs, 
                       running_loss/print_every, 
                       running_attn_loss/print_every))
                running_loss = 0.0
                running_attn_loss = 0.0
        print(score(resnet_attn))

In [18]:
from tqdm import tqdm
def score(net, folder='val', batch_size=batch_size):
    valset = torchvision.datasets.ImageFolder(root='./data/{}'.format(folder), transform=transform)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    val_correct = 0
    
    for inp, label in tqdm(iter(valloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        val_correct += int(sum(idx.cpu().data == label))
    
    return {
        'val_accu': val_correct/len(valset)
    }

Now start with a new model
- Trian fc, train attn

In [20]:
resnet_attn = torch.load('fresh_resnet_attn.pkl')

In [19]:
print_every = 10

In [20]:
def train_fc():
    turn_off_grad_except(['fc'])
    resnet_attn.eval() 
    train_k_epoch(1, add_attn=False)

In [21]:
def train_attn():
    turn_off_grad_except(['attn_weights'])
    resnet_attn.eval() 
    _lambda = 1
    train_k_epoch(1, add_attn=True)

In [22]:
train_fc()
train_attn()

[    1] iter, [0.000000] epoch, avg loss: 0.505, attn_loss: 0.00000 
[   11] iter, [0.039211] epoch, avg loss: 4.777, attn_loss: 0.00000 
[   21] iter, [0.078422] epoch, avg loss: 4.251, attn_loss: 0.00000 
[   31] iter, [0.117633] epoch, avg loss: 4.373, attn_loss: 0.00000 
[   41] iter, [0.156844] epoch, avg loss: 4.044, attn_loss: 0.00000 
[   51] iter, [0.196054] epoch, avg loss: 3.723, attn_loss: 0.00000 
[   61] iter, [0.235265] epoch, avg loss: 3.707, attn_loss: 0.00000 
[   71] iter, [0.274476] epoch, avg loss: 3.712, attn_loss: 0.00000 
[   81] iter, [0.313687] epoch, avg loss: 3.551, attn_loss: 0.00000 
[   91] iter, [0.352898] epoch, avg loss: 3.507, attn_loss: 0.00000 
[  101] iter, [0.392109] epoch, avg loss: 3.530, attn_loss: 0.00000 
[  111] iter, [0.431320] epoch, avg loss: 3.140, attn_loss: 0.00000 
[  121] iter, [0.470531] epoch, avg loss: 3.263, attn_loss: 0.00000 
[  131] iter, [0.509741] epoch, avg loss: 3.357, attn_loss: 0.00000 
[  141] iter, [0.548952] epoch, av

100%|██████████| 55/55 [00:13<00:00,  4.01it/s]

{'val_accu': 0.3363740022805017}





[    1] iter, [0.000000] epoch, avg loss: 0.229, attn_loss: 0.00000 
[   11] iter, [0.039211] epoch, avg loss: 2.292, attn_loss: -0.00000 
[   21] iter, [0.078422] epoch, avg loss: 2.514, attn_loss: -0.00002 
[   31] iter, [0.117633] epoch, avg loss: 2.449, attn_loss: -0.00003 
[   41] iter, [0.156844] epoch, avg loss: 2.404, attn_loss: -0.00004 
[   51] iter, [0.196054] epoch, avg loss: 2.201, attn_loss: -0.00005 
[   61] iter, [0.235265] epoch, avg loss: 2.535, attn_loss: -0.00006 
[   71] iter, [0.274476] epoch, avg loss: 2.430, attn_loss: -0.00007 
[   81] iter, [0.313687] epoch, avg loss: 2.199, attn_loss: -0.00009 
[   91] iter, [0.352898] epoch, avg loss: 1.990, attn_loss: -0.00010 
[  101] iter, [0.392109] epoch, avg loss: 2.253, attn_loss: -0.00012 
[  111] iter, [0.431320] epoch, avg loss: 2.047, attn_loss: -0.00014 
[  121] iter, [0.470531] epoch, avg loss: 2.287, attn_loss: -0.00015 
[  131] iter, [0.509741] epoch, avg loss: 2.283, attn_loss: -0.00016 
[  141] iter, [0.5489

100%|██████████| 55/55 [00:13<00:00,  3.97it/s]

{'val_accu': 0.395096921322691}





In [35]:
PRUNE_PORTION = 1/4

In [25]:
attn_weights = get_params_objs('attn_weights')
attn_masks = get_params_objs('attn_mask')

In [29]:
t = attn_weights[0].data

In [32]:
_, rank = torch.sort(t, dim=0)

In [39]:
prune_idxs = rank[:int(len(rank)*PRUNE_PORTION)]

In [42]:
attn_masks[0].data[prune_idxs.squeeze()] = 0

In [43]:
attn_masks[0]

Parameter containing:
(0 ,0 ,.,.) = 
  0

(1 ,0 ,.,.) = 
  0

(2 ,0 ,.,.) = 
  1

(3 ,0 ,.,.) = 
  0

(4 ,0 ,.,.) = 
  0

(5 ,0 ,.,.) = 
  1

(6 ,0 ,.,.) = 
  0

(7 ,0 ,.,.) = 
  1

(8 ,0 ,.,.) = 
  0

(9 ,0 ,.,.) = 
  1

(10,0 ,.,.) = 
  1

(11,0 ,.,.) = 
  1

(12,0 ,.,.) = 
  1

(13,0 ,.,.) = 
  1

(14,0 ,.,.) = 
  1

(15,0 ,.,.) = 
  1

(16,0 ,.,.) = 
  1

(17,0 ,.,.) = 
  0

(18,0 ,.,.) = 
  1

(19,0 ,.,.) = 
  1

(20,0 ,.,.) = 
  1

(21,0 ,.,.) = 
  0

(22,0 ,.,.) = 
  1

(23,0 ,.,.) = 
  0

(24,0 ,.,.) = 
  1

(25,0 ,.,.) = 
  1

(26,0 ,.,.) = 
  1

(27,0 ,.,.) = 
  0

(28,0 ,.,.) = 
  1

(29,0 ,.,.) = 
  0

(30,0 ,.,.) = 
  1

(31,0 ,.,.) = 
  1

(32,0 ,.,.) = 
  1

(33,0 ,.,.) = 
  1

(34,0 ,.,.) = 
  1

(35,0 ,.,.) = 
  1

(36,0 ,.,.) = 
  1

(37,0 ,.,.) = 
  1

(38,0 ,.,.) = 
  1

(39,0 ,.,.) = 
  1

(40,0 ,.,.) = 
  0

(41,0 ,.,.) = 
  1

(42,0 ,.,.) = 
  1

(43,0 ,.,.) = 
  1

(44,0 ,.,.) = 
  1

(45,0 ,.,.) = 
  1

(46,0 ,.,.) = 
  0

(47,0 ,.,.) = 
  1

(48,0 ,.,.) = 
  1