Changing the loss

In [1]:
import sys
sys.path.insert(0, '..')

In [2]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np

In [3]:
resnet_pretrained = models.resnet101(pretrained=True)
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet101()
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)

In [4]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [5]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(lst=[]):
    turning_off_keys = [k for k in formatted_keys for l in lst if l in k]
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        if k in turning_off_keys:
            obj.requires_grad = True
        else:
            obj.requires_grad = False

In [6]:
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 144)

Start training

In [7]:
batch_size = 32

In [8]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.ImageFolder(root='../data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [9]:
total_imgs = len(trainset.imgs)

In [10]:
resnet_attn = resnet_attn.cuda()

In [11]:
total_attn_params = 0
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if 'attn_weights' in k:
        total_attn_params += np.prod(obj.shape)
print("Total number of attention parameters", total_attn_params)

Total number of attention parameters 52672


We want the attention parameters to diverge from 1, therefore we penalize element-wise square loss as $\lambda (1 \times \text{# params} - (x - 1)^2)$

But this is too big a number,
let's try: 
$- (x - 1)^2$ for now

In [12]:
_lambda = 1e-1 #set default

In [13]:
def get_params_objs(name, net='resnet_attn'):
    res = []
    for k in formatted_keys:
        obj = eval(f'{net}.'+k)
        if name in k:
            res.append(obj)
    return res

In [14]:
def compute_attn_loss(n_params=26560):
    attns = get_params_objs('attn_weights')
    penality = sum([torch.pow(t - 1,2).mean() for t in attns])
#     penality = (1000/n_params)*sum([torch.min(torch.pow(t-2, 2), torch.pow(t, 2)).sum() for t in attns])
    return (_lambda)*(-penality)

In [15]:
len(get_params_objs('attn_weights'))

104

In [16]:
compute_attn_loss()

Variable containing:
-0
[torch.cuda.FloatTensor of size 1 (GPU 0)]

In [17]:
print_every = 5

In [18]:
def train_k_epoch(k, add_attn=True, score_epoch=False):
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))
    
    for epoch in range(k):
        running_loss = 0.0
        running_attn_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

            optimizer.zero_grad()
            outputs = resnet_attn(inputs)
            clf_loss = cls_criterion(outputs, labels)
                        
            attn_loss = compute_attn_loss()
            
            if add_attn:
                loss = clf_loss + attn_loss
            else:
                loss = clf_loss

            loss.backward()
            optimizer.step()

            running_loss += clf_loss.data[0]
            running_attn_loss += attn_loss.data[0]

            if i % print_every == 0:
                print('[%5d] iter, [%2f] epoch, classifer loss: %.3f, attn_loss: %.5f ' %
                      (i + 1, i*batch_size/total_imgs, running_loss/print_every, running_attn_loss/print_every))
                running_loss = 0.0
                running_attn_loss = 0.0
        if score_epoch:
            print(score(batch_size=32))

In [19]:
from tqdm import tqdm
def score(net=resnet_attn, batch_size=batch_size):
    trainset = torchvision.datasets.ImageFolder(root='../data/train', transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    valset = torchvision.datasets.ImageFolder(root='../data/val', transform=transform)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    train_correct = 0
    val_correct = 0
    
    for inp, label in tqdm(iter(trainloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        train_correct += int(sum(idx.cpu().data == label))
    
    for inp, label in tqdm(iter(valloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        val_correct += int(sum(idx.cpu().data == label))
    
    return {
        'train_accu': train_correct/len(trainset),
        'val_accu': val_correct/len(valset)
    }

Train a fresh fc layer. 
`turn_off_grad_except([])` turns off grads for all weights but the fc layer

In [20]:
print_every = 50

In [21]:
turn_off_grad_except(['fc'])
resnet_attn.eval()
train_k_epoch(1,score_epoch=True)

[    1] iter, [0.000000] epoch, classifer loss: 0.103, attn_loss: 0.00000 
[   51] iter, [0.196054] epoch, classifer loss: 4.217, attn_loss: 0.00000 
[  101] iter, [0.392109] epoch, classifer loss: 3.421, attn_loss: 0.00000 
[  151] iter, [0.588163] epoch, classifer loss: 3.163, attn_loss: 0.00000 
[  201] iter, [0.784218] epoch, classifer loss: 3.010, attn_loss: 0.00000 
[  251] iter, [0.980272] epoch, classifer loss: 2.986, attn_loss: 0.00000 


100%|██████████| 256/256 [01:51<00:00,  2.29it/s]
100%|██████████| 55/55 [00:24<00:00,  2.28it/s]

{'train_accu': 0.3851243720132337, 'val_accu': 0.31755986316989737}





In [22]:
_lambda = 1
turn_off_grad_except(['attn_weights'])
resnet_attn.eval()
train_k_epoch(2,score_epoch=True)

[    1] iter, [0.000000] epoch, classifer loss: 0.063, attn_loss: 0.00000 
[   51] iter, [0.196054] epoch, classifer loss: 2.338, attn_loss: -0.00941 
[  101] iter, [0.392109] epoch, classifer loss: 2.213, attn_loss: -0.06377 
[  151] iter, [0.588163] epoch, classifer loss: 2.107, attn_loss: -0.20924 
[  251] iter, [0.980272] epoch, classifer loss: 2.027, attn_loss: -0.90647 


100%|██████████| 256/256 [01:51<00:00,  2.29it/s]
100%|██████████| 55/55 [00:24<00:00,  2.28it/s]

{'train_accu': 0.4810684965077809, 'val_accu': 0.3831242873432155}





[    1] iter, [0.000000] epoch, classifer loss: 0.041, attn_loss: -0.02496 
[   51] iter, [0.196054] epoch, classifer loss: 1.988, attn_loss: -1.61095 
[  101] iter, [0.392109] epoch, classifer loss: 1.954, attn_loss: -2.48404 
[  151] iter, [0.588163] epoch, classifer loss: 1.962, attn_loss: -3.62258 
[  201] iter, [0.784218] epoch, classifer loss: 1.968, attn_loss: -5.06184 
[  251] iter, [0.980272] epoch, classifer loss: 1.918, attn_loss: -6.83983 


100%|██████████| 256/256 [01:51<00:00,  2.29it/s]
100%|██████████| 55/55 [00:24<00:00,  2.28it/s]

{'train_accu': 0.5224849895846098, 'val_accu': 0.41163055872291904}





In [23]:
_lambda = 0.5
turn_off_grad_except(['attn_weights'])
resnet_attn.eval()
train_k_epoch(2,score_epoch=True)

[    1] iter, [0.000000] epoch, classifer loss: 0.034, attn_loss: -0.08089 
[   51] iter, [0.196054] epoch, classifer loss: 1.846, attn_loss: -4.36156 
[  101] iter, [0.392109] epoch, classifer loss: 1.807, attn_loss: -5.03179 
[  151] iter, [0.588163] epoch, classifer loss: 1.798, attn_loss: -5.78239 
[  201] iter, [0.784218] epoch, classifer loss: 1.821, attn_loss: -6.62259 
[  251] iter, [0.980272] epoch, classifer loss: 1.829, attn_loss: -7.55782 


100%|██████████| 256/256 [01:51<00:00,  2.29it/s]
100%|██████████| 55/55 [00:24<00:00,  2.28it/s]

{'train_accu': 0.5479720622472736, 'val_accu': 0.4122006841505131}





[    1] iter, [0.000000] epoch, classifer loss: 0.027, attn_loss: -0.16350 
[   51] iter, [0.196054] epoch, classifer loss: 1.675, attn_loss: -8.72544 
[  101] iter, [0.392109] epoch, classifer loss: 1.698, attn_loss: -9.88018 
[  151] iter, [0.588163] epoch, classifer loss: 1.727, attn_loss: -11.15379 
[  201] iter, [0.784218] epoch, classifer loss: 1.791, attn_loss: -12.54730 
[  251] iter, [0.980272] epoch, classifer loss: 1.810, attn_loss: -14.06458 


100%|██████████| 256/256 [01:51<00:00,  2.29it/s]
100%|██████████| 55/55 [00:24<00:00,  2.27it/s]

{'train_accu': 0.5629212106359515, 'val_accu': 0.4161915621436716}





In [24]:
_lambda = 0.1
turn_off_grad_except(['attn_weights'])
resnet_attn.eval()
train_k_epoch(5,score_epoch=True)

[    1] iter, [0.000000] epoch, classifer loss: 0.035, attn_loss: -0.06019 
[   51] iter, [0.196054] epoch, classifer loss: 1.717, attn_loss: -3.10602 
[  101] iter, [0.392109] epoch, classifer loss: 1.728, attn_loss: -3.30089 
[  151] iter, [0.588163] epoch, classifer loss: 1.612, attn_loss: -3.50671 
[  201] iter, [0.784218] epoch, classifer loss: 1.636, attn_loss: -3.72410 
[  251] iter, [0.980272] epoch, classifer loss: 1.690, attn_loss: -3.95235 


100%|██████████| 256/256 [01:52<00:00,  2.28it/s]
100%|██████████| 55/55 [00:24<00:00,  2.27it/s]

{'train_accu': 0.5797083690724176, 'val_accu': 0.4184720638540479}





[    1] iter, [0.000000] epoch, classifer loss: 0.027, attn_loss: -0.08192 
[   51] iter, [0.196054] epoch, classifer loss: 1.557, attn_loss: -4.21855 
[  101] iter, [0.392109] epoch, classifer loss: 1.552, attn_loss: -4.46772 
[  151] iter, [0.588163] epoch, classifer loss: 1.622, attn_loss: -4.72873 
[  201] iter, [0.784218] epoch, classifer loss: 1.627, attn_loss: -5.00201 
[  251] iter, [0.980272] epoch, classifer loss: 1.711, attn_loss: -5.28589 


100%|██████████| 256/256 [01:52<00:00,  2.28it/s]
100%|██████████| 55/55 [00:24<00:00,  2.27it/s]

{'train_accu': 0.5933096434260507, 'val_accu': 0.41733181299885974}





[    1] iter, [0.000000] epoch, classifer loss: 0.033, attn_loss: -0.10929 
[   51] iter, [0.196054] epoch, classifer loss: 1.620, attn_loss: -5.61851 
[  101] iter, [0.392109] epoch, classifer loss: 1.546, attn_loss: -5.92873 
[  151] iter, [0.588163] epoch, classifer loss: 1.496, attn_loss: -6.25151 
[  201] iter, [0.784218] epoch, classifer loss: 1.556, attn_loss: -6.58654 
[  251] iter, [0.980272] epoch, classifer loss: 1.578, attn_loss: -6.93368 


100%|██████████| 256/256 [01:52<00:00,  2.28it/s]
100%|██████████| 55/55 [00:24<00:00,  2.28it/s]

{'train_accu': 0.6085038598211003, 'val_accu': 0.42360319270239455}





[    1] iter, [0.000000] epoch, classifer loss: 0.032, attn_loss: -0.14302 
[   51] iter, [0.196054] epoch, classifer loss: 1.507, attn_loss: -7.33715 
[  101] iter, [0.392109] epoch, classifer loss: 1.533, attn_loss: -7.70959 
[  151] iter, [0.588163] epoch, classifer loss: 1.514, attn_loss: -8.09485 
[  201] iter, [0.784218] epoch, classifer loss: 1.556, attn_loss: -8.49392 
[  251] iter, [0.980272] epoch, classifer loss: 1.458, attn_loss: -8.90481 


100%|██████████| 256/256 [01:52<00:00,  2.28it/s]
100%|██████████| 55/55 [00:24<00:00,  2.27it/s]

{'train_accu': 0.615243229996324, 'val_accu': 0.41961231470923605}





[    1] iter, [0.000000] epoch, classifer loss: 0.028, attn_loss: -0.18323 
[   51] iter, [0.196054] epoch, classifer loss: 1.492, attn_loss: -9.38066 
[  101] iter, [0.392109] epoch, classifer loss: 1.517, attn_loss: -9.81832 
[  151] iter, [0.588163] epoch, classifer loss: 1.463, attn_loss: -10.26987 
[  201] iter, [0.784218] epoch, classifer loss: 1.435, attn_loss: -10.73474 
[  251] iter, [0.980272] epoch, classifer loss: 1.478, attn_loss: -11.21244 


100%|██████████| 256/256 [01:52<00:00,  2.28it/s]
100%|██████████| 55/55 [00:24<00:00,  2.27it/s]

{'train_accu': 0.6328881264550913, 'val_accu': 0.411060433295325}





Looks promising, let me train a baseline first

In [25]:
turn_off_grad_except(['bn'])
resnet_attn.eval()
train_k_epoch(1,score_epoch=True,add_attn=False)

[    1] iter, [0.000000] epoch, classifer loss: 0.033, attn_loss: -0.23019 
[   51] iter, [0.196054] epoch, classifer loss: 2.036, attn_loss: -11.50928 
[  101] iter, [0.392109] epoch, classifer loss: 1.875, attn_loss: -11.50928 
[  151] iter, [0.588163] epoch, classifer loss: 1.925, attn_loss: -11.50928 
[  201] iter, [0.784218] epoch, classifer loss: 1.932, attn_loss: -11.50928 
[  251] iter, [0.980272] epoch, classifer loss: 1.892, attn_loss: -11.50928 


100%|██████████| 256/256 [01:51<00:00,  2.29it/s]
100%|██████████| 55/55 [00:24<00:00,  2.28it/s]

{'train_accu': 0.508516113221419, 'val_accu': 0.37058152793614596}





In [26]:
turn_off_grad_except(['fc'])
resnet_attn.eval()
train_k_epoch(1,score_epoch=True,add_attn=False)

[    1] iter, [0.000000] epoch, classifer loss: 0.030, attn_loss: -0.23019 
[   51] iter, [0.196054] epoch, classifer loss: 2.406, attn_loss: -11.50928 
[  101] iter, [0.392109] epoch, classifer loss: 2.001, attn_loss: -11.50928 
[  151] iter, [0.588163] epoch, classifer loss: 1.896, attn_loss: -11.50928 
[  201] iter, [0.784218] epoch, classifer loss: 1.982, attn_loss: -11.50928 
[  251] iter, [0.980272] epoch, classifer loss: 1.890, attn_loss: -11.50928 


100%|██████████| 256/256 [01:51<00:00,  2.29it/s]
100%|██████████| 55/55 [00:24<00:00,  2.28it/s]

{'train_accu': 0.5950251194706531, 'val_accu': 0.4070695553021665}



