In [1]:
import sys
sys.path.insert(0, '..')

In [2]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np

In [3]:
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [4]:
# resnet_attn = torch.load('new_attn.pkl')

In [5]:
resnet_pretrained = models.resnet50(pretrained=True)
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet50()
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)

In [6]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [7]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(lst=[]):
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        for kw in lst:
            if not kw in k:
                obj.requires_grad = False
            else:
                obj.requires_grad = True

In [8]:
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 144)

Start training

In [9]:
batch_size = 32
# batch_size = 64

In [10]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.ImageFolder(root='../data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

valset = torchvision.datasets.ImageFolder(root='../data/val', transform=transform)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [11]:
total_imgs = len(trainset.imgs)

In [12]:
resnet_attn = resnet_attn.cuda()

In [13]:
total_attn_params = 0
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if 'attn_weights' in k:
        total_attn_params += np.prod(obj.shape)
print("Total number of attention parameters", total_attn_params)

Total number of attention parameters 13385920


We want the attention parameters to diverge from 1, therefore we penalize element-wise square loss as $\lambda (1 \times \text{# params} - (x - 1)^2)$

But this is too big a number,
let's try: 
$- (x - 1)^2$ for now

In [14]:
_lambda = 1e-2 #set default

In [15]:
def get_params_objs(name, net='resnet_attn'):
    res = []
    for k in formatted_keys:
        obj = eval(f'{net}.'+k)
        if name in k:
            res.append(obj)
    return res

In [16]:
def compute_attn_loss(n_params=26560):
    attns = get_params_objs('attn_weights')
    penality = sum([torch.pow(t - 1,2).mean() for t in attns])
    return _lambda*(- penality)

In [17]:
print_every = 50

In [18]:
def score_top3(train=True, val=True):
    if train:
        correct_count = 0
        for inp, label in tqdm(iter(trainloader)):
            _, idx = resnet_attn(Variable(inp).cuda()).topk(3)
            lab = Variable(label).cuda()
            lab_expand = lab.unsqueeze(1).expand_as(idx)
            correct_count += int((idx == lab_expand).sum())
        print(correct_count/len(trainset))
    
    if val:
        correct_count = 0
        for inp, label in tqdm(iter(valloader)):
            _, idx = resnet_attn(Variable(inp).cuda()).topk(3)
            lab = Variable(label).cuda()
            lab_expand = lab.unsqueeze(1).expand_as(idx)
            correct_count += int((idx == lab_expand).sum())
        print(correct_count/len(valset))

In [19]:
def plot_attn_hist():
    attns = get_params_objs('attn_weights')
    attns = torch.cat([attn.squeeze() for attn in attns])
    attns_arr = attns.data.cpu().numpy()
    plt.hist(attns_arr)

In [20]:
def train(k=1, add_attn=True, score=True, plot_hist=False):
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))
    
    for j in range(k):
        running_loss = 0.0
        running_attn_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

            optimizer.zero_grad()
            outputs = resnet_attn(inputs)
            loss = cls_criterion(outputs, labels)
            attn_loss = compute_attn_loss()
            if add_attn:
                loss += attn_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.data[0]
            running_attn_loss += attn_loss.data[0]

            if i % print_every == 0:
                print('[%5d] iter, [%2f] epoch, avg loss: %.3f, attn_loss: %.5f ' %
                      (i + 1, i*batch_size/total_imgs, running_loss/print_every, running_attn_loss/print_every))
                running_loss = 0.0
                running_attn_loss = 0.0
        if score:
            score_top3()
            print(score_top1(batch_size=32))
        if plot_hist:
            plot_attn_hist()
            plt.show()

In [21]:
from tqdm import tqdm
def score_top1(net=resnet_attn, batch_size=batch_size):    
    train_correct = 0
    val_correct = 0
    
#     for inp, label in tqdm(iter(trainloader)):
#         _, idx = net(Variable(inp).cuda()).max(1)
#         train_correct += int(sum(idx.cpu().data == label))
    
    for inp, label in tqdm(iter(valloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        val_correct += int(sum(idx.cpu().data == label))
    
    return {
#         'train_accu': train_correct/len(trainset),
        'val_accu': val_correct/len(valset)
    }

In [22]:
turn_off_grad_except(['fc'])
resnet_attn.eval() # Turn on batchnorm
train(1, add_attn=False, plot_hist=False)

[    1] iter, [0.000000] epoch, avg loss: 0.099, attn_loss: 0.00000 
[   51] iter, [0.196054] epoch, avg loss: 4.192, attn_loss: 0.00000 
[  101] iter, [0.392109] epoch, avg loss: 3.507, attn_loss: 0.00000 
[  151] iter, [0.588163] epoch, avg loss: 3.218, attn_loss: 0.00000 
[  201] iter, [0.784218] epoch, avg loss: 3.111, attn_loss: 0.00000 
[  251] iter, [0.980272] epoch, avg loss: 3.036, attn_loss: 0.00000 


100%|██████████| 256/256 [01:04<00:00,  3.99it/s]

0.5542212964097537



100%|██████████| 55/55 [00:13<00:00,  3.96it/s]

0.4925883694412771



100%|██████████| 55/55 [00:13<00:00,  3.94it/s]

{'val_accu': 0.30444697833523376}





In [23]:
turn_off_grad_except(['attn_weights'])
resnet_attn.eval() # Turn on batchnorm
_lambda=1
train(5, add_attn=False, plot_hist=False)

[    1] iter, [0.000000] epoch, avg loss: 0.040, attn_loss: 0.00000 
[   51] iter, [0.196054] epoch, avg loss: 2.321, attn_loss: -0.00144 
[  101] iter, [0.392109] epoch, avg loss: 2.125, attn_loss: -0.00378 
[  151] iter, [0.588163] epoch, avg loss: 1.980, attn_loss: -0.00601 
[  201] iter, [0.784218] epoch, avg loss: 1.983, attn_loss: -0.00821 
[  251] iter, [0.980272] epoch, avg loss: 2.009, attn_loss: -0.01034 


100%|██████████| 256/256 [01:04<00:00,  3.97it/s]

0.7838500183801005



100%|██████████| 55/55 [00:14<00:00,  3.92it/s]

0.6299885974914481



100%|██████████| 55/55 [00:13<00:00,  3.95it/s]

{'val_accu': 0.4241733181299886}





[    1] iter, [0.000000] epoch, avg loss: 0.029, attn_loss: -0.00023 
[   51] iter, [0.196054] epoch, avg loss: 1.585, attn_loss: -0.01341 
[  101] iter, [0.392109] epoch, avg loss: 1.552, attn_loss: -0.01680 
[  151] iter, [0.588163] epoch, avg loss: 1.575, attn_loss: -0.02022 
[  201] iter, [0.784218] epoch, avg loss: 1.435, attn_loss: -0.02373 
[  251] iter, [0.980272] epoch, avg loss: 1.491, attn_loss: -0.02709 


100%|██████████| 256/256 [01:04<00:00,  3.97it/s]

0.8566352162725156



100%|██████████| 55/55 [00:13<00:00,  3.94it/s]

0.6459521094640821



100%|██████████| 55/55 [00:13<00:00,  3.93it/s]

{'val_accu': 0.4395667046750285}





[    1] iter, [0.000000] epoch, avg loss: 0.026, attn_loss: -0.00058 
[   51] iter, [0.196054] epoch, avg loss: 1.170, attn_loss: -0.03128 
[  101] iter, [0.392109] epoch, avg loss: 1.107, attn_loss: -0.03567 
[  151] iter, [0.588163] epoch, avg loss: 1.086, attn_loss: -0.04002 
[  201] iter, [0.784218] epoch, avg loss: 1.121, attn_loss: -0.04420 
[  251] iter, [0.980272] epoch, avg loss: 1.151, attn_loss: -0.04847 


100%|██████████| 256/256 [01:04<00:00,  3.97it/s]

0.9225585099865212



100%|██████████| 55/55 [00:13<00:00,  3.94it/s]

0.661345496009122



100%|██████████| 55/55 [00:13<00:00,  3.95it/s]

{'val_accu': 0.43671607753705816}





[    1] iter, [0.000000] epoch, avg loss: 0.016, attn_loss: -0.00102 
[   51] iter, [0.196054] epoch, avg loss: 0.925, attn_loss: -0.05448 
[  101] iter, [0.392109] epoch, avg loss: 0.790, attn_loss: -0.05919 
[  151] iter, [0.588163] epoch, avg loss: 0.737, attn_loss: -0.06345 
[  201] iter, [0.784218] epoch, avg loss: 0.755, attn_loss: -0.06767 
[  251] iter, [0.980272] epoch, avg loss: 0.780, attn_loss: -0.07209 


100%|██████████| 256/256 [01:04<00:00,  3.96it/s]

0.9596863129518441



100%|██████████| 55/55 [00:13<00:00,  3.94it/s]

0.6556442417331813



100%|██████████| 55/55 [00:13<00:00,  3.94it/s]

{'val_accu': 0.4395667046750285}





[    1] iter, [0.000000] epoch, avg loss: 0.009, attn_loss: -0.00150 
[   51] iter, [0.196054] epoch, avg loss: 0.526, attn_loss: -0.07761 
[  101] iter, [0.392109] epoch, avg loss: 0.496, attn_loss: -0.08194 
[  151] iter, [0.588163] epoch, avg loss: 0.510, attn_loss: -0.08611 
[  201] iter, [0.784218] epoch, avg loss: 0.460, attn_loss: -0.09033 
[  251] iter, [0.980272] epoch, avg loss: 0.491, attn_loss: -0.09470 


100%|██████████| 256/256 [01:04<00:00,  3.97it/s]

0.9779438794265408



100%|██████████| 55/55 [00:13<00:00,  3.93it/s]

0.6795895096921323



100%|██████████| 55/55 [00:14<00:00,  3.93it/s]

{'val_accu': 0.45210946408209807}





In [24]:
score(batch_size=32)

NameError: name 'score' is not defined

In [None]:
turn_off_grad_except(['attn_weights'])
resnet_attn.eval() # Turn on batchnorm
_lambda=1
train(2, add_attn=False, plot_hist=False)

In [None]:
torch.save(resnet_attn, 'new_attn.pkl')

In [None]:
resnet_attn