In [1]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np
import time
from tqdm import tqdm

In [2]:
resnet_pretrained = models.resnet50(pretrained=True)
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet50()
    
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)
# print(resnet_pretrained.state_dist())

In [3]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [4]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(lst=[]):
    turning_off_keys = [k for k in formatted_keys for l in lst if l in k]
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        if k in turning_off_keys:
            print(k)
            obj.requires_grad = True
        else:
            obj.requires_grad = False
    

def turn_on_layer4_conv_weight():
    for k in range(3):
        for j in range(1, 4):
            eval('resnet_attn.layer4[{}].conv{}.weight'.format(k, j)).requires_gred = True

In [5]:
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 144)

Start training

In [6]:
batch_size = 32

In [7]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.ImageFolder(root='./data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [8]:
total_imgs = len(trainset.imgs)

In [9]:
resnet_attn = resnet_attn.cuda()

In [10]:
total_attn_params = 0
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if 'attn_weights' in k:
        total_attn_params += np.prod(obj.shape)
print("Total number of attention parameters", total_attn_params)

Total number of attention parameters 26560


We want the attention parameters to diverge from 1, therefore we penalize element-wise square loss as $\lambda (1 \times \text{# params} - (x - 1)^2)$

But this is too big a number,
let's try: 
$- (x - 1)^2$ for now

In [11]:
_lambda = 1e-1 #set default

In [12]:
def get_params_objs(name, net='resnet_attn'):
    res = []
    for k in formatted_keys:
        obj = eval('{}.'.format(net)+k)
        if name in k:
            res.append(obj)
    return res

In [13]:
def compute_attn_loss(n_params=26560):
    attns = get_params_objs('attn_weights')
#     penality = sum([torch.pow(t - 1,2).mean() for t in attns])
    penality = (1000/n_params)*sum([torch.min(torch.pow(t-2, 2), torch.pow(t, 2)).sum() for t in attns])
    return (_lambda)*(penality)

In [14]:
len(get_params_objs('attn_weights'))

53

In [15]:
compute_attn_loss()

tensor(100.0000, device='cuda:0')

In [16]:
print_every = 20

In [26]:
def train_k_epoch(k, add_attn=True, score_epoch=False):
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))
    
    start = time.time()
    
    for epoch in range(k):
        running_loss = 0.0
        running_attn_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

            optimizer.zero_grad()
            outputs = resnet_attn(inputs)
            clf_loss = cls_criterion(outputs, labels)
                        
            attn_loss = compute_attn_loss()
            
            if add_attn:
                loss = clf_loss + attn_loss
            else:
                loss = clf_loss

            loss.backward()
            optimizer.step()

            running_loss += clf_loss.data[0]
            running_attn_loss += attn_loss.data[0]

            if i % print_every == 0:
                print('[%5d] iter, [%2f] epoch, classifer loss: %.3f, attn_loss: %.5f' %
                      (i + 1, i*batch_size/total_imgs, running_loss/print_every, running_attn_loss/print_every))
                running_loss = 0.0
                running_attn_loss = 0.0
        if score_epoch:
            dic = score(batch_size=32)
            e_time = time.time() - start
            dic['time_elapsed'] = '{:.0f}m {:.0f}s'.format(e_time // 60, e_time%60)
            print(dic)

In [24]:

def score(net=resnet_attn, batch_size=batch_size):
    trainset = torchvision.datasets.ImageFolder(root='./data/train', transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    valset = torchvision.datasets.ImageFolder(root='./data/val', transform=transform)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    train_correct = 0
    val_correct_top1 = 0
    val_correct_top3 = 0
    val_correct_top5 = 0
    
    for inp, label in tqdm(iter(valloader)):
        _, idx = net(Variable(inp).cuda()).topk(1)
        val_correct_top1 += int(sum([label[i] in idx.cpu().data[i] for i in range(len(idx))]))
        
        _, idx = net(Variable(inp).cuda()).topk(3)
        val_correct_top3 += int(sum([label[i] in idx.cpu().data[i] for i in range(len(idx))]))
        
        _, idx = net(Variable(inp).cuda()).topk(5)
        val_correct_top5 += int(sum([label[i] in idx.cpu().data[i] for i in range(len(idx))]))
        
    for inp, label in tqdm(iter(trainloader)):
        _, idx = net(Variable(inp).cuda()).topk(1)
        train_correct += int(sum([label[i] in idx.cpu().data[i] for i in range(len(idx))]))
    
    return {
        'train_accu': train_correct/len(trainset),
        'val_accu_top1': val_correct_top1/len(valset),
        'val_accu_top3': val_correct_top3/len(valset),
        'val_accu_top5': val_correct_top5/len(valset)
    }

Train a fresh fc layer. 
`turn_off_grad_except([])` turns off grads for all weights but the fc layer

In [None]:
turn_off_grad_except(['fc'])
turn_on_layer4_conv_weight()
resnet_attn.eval()
resnet_attn.layer4.train()
train_k_epoch(50,score_epoch=True, add_attn=False)

fc.weight
fc.bias




[    1] iter, [0.000000] epoch, classifer loss: 0.147, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 3.365, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 2.824, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 2.821, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 2.784, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 2.722, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 2.623, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 2.691, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 2.727, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 2.566, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 2.613, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 2.621, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 2.507, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.78it/s]

{'val_accu_top5': 0.6180159635119726, 'train_accu': 0.43389290528121555, 'time_elapsed': '2m 59s', 'val_accu_top3': 0.5273660205245154, 'val_accu_top1': 0.3403648802736602}





[    1] iter, [0.000000] epoch, classifer loss: 0.087, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 2.357, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 2.324, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 2.262, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 2.204, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 2.285, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 2.277, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 2.284, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 2.343, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 2.203, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 2.217, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 2.180, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 2.197, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.621436716077537, 'train_accu': 0.4673446881509619, 'time_elapsed': '5m 59s', 'val_accu_top3': 0.5347776510832383, 'val_accu_top1': 0.34207525655644244}





[    1] iter, [0.000000] epoch, classifer loss: 0.092, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 2.040, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 2.013, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 2.069, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.992, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.903, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.936, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.935, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.967, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 2.016, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 2.007, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 2.049, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 2.026, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.6225769669327252, 'train_accu': 0.4755544663644161, 'time_elapsed': '8m 59s', 'val_accu_top3': 0.5256556442417332, 'val_accu_top1': 0.30786773090079816}





[    1] iter, [0.000000] epoch, classifer loss: 0.063, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.899, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.821, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.867, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.775, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.804, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.836, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.790, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.864, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.859, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.873, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.680, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.891, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.6510832383124288, 'train_accu': 0.5772576890086999, 'time_elapsed': '11m 59s', 'val_accu_top3': 0.5513112884834663, 'val_accu_top1': 0.3443557582668187}





[    1] iter, [0.000000] epoch, classifer loss: 0.068, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.661, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.560, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.517, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.715, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.616, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.697, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.763, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.778, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.704, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.814, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.694, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.745, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.6459521094640821, 'train_accu': 0.5710084548462199, 'time_elapsed': '14m 58s', 'val_accu_top3': 0.5433295324971493, 'val_accu_top1': 0.3289623717217788}





[    1] iter, [0.000000] epoch, classifer loss: 0.082, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.782, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.503, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.461, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.459, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.624, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.589, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.672, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.550, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.542, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.697, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.609, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.637, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.6436716077537058, 'train_accu': 0.6047053057223379, 'time_elapsed': '17m 58s', 'val_accu_top3': 0.5456100342075256, 'val_accu_top1': 0.3289623717217788}





[    1] iter, [0.000000] epoch, classifer loss: 0.064, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.583, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.468, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.470, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.516, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.472, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.473, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.464, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.505, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.496, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.363, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.525, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.499, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.80it/s]

{'val_accu_top5': 0.6493728620296465, 'train_accu': 0.6339909324837643, 'time_elapsed': '20m 57s', 'val_accu_top3': 0.556442417331813, 'val_accu_top1': 0.3563283922462942}





[    1] iter, [0.000000] epoch, classifer loss: 0.063, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.444, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.400, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.262, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.227, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.289, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.385, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.449, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.471, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.410, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.405, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.536, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.363, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.79it/s]

{'val_accu_top5': 0.6499429874572406, 'train_accu': 0.6553118490381081, 'time_elapsed': '23m 57s', 'val_accu_top3': 0.5615735461801596, 'val_accu_top1': 0.34321550741163054}





[    1] iter, [0.000000] epoch, classifer loss: 0.061, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.411, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.302, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.244, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.324, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.249, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.258, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.309, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.387, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.306, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.412, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.387, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.388, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.28it/s]
100%|██████████| 256/256 [01:07<00:00,  3.79it/s]

{'val_accu_top5': 0.636259977194983, 'train_accu': 0.6456316627864233, 'time_elapsed': '26m 55s', 'val_accu_top3': 0.5530216647662486, 'val_accu_top1': 0.3443557582668187}





[    1] iter, [0.000000] epoch, classifer loss: 0.053, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.210, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.250, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.250, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.261, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.287, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.211, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.238, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.221, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.175, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.393, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.372, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.348, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.79it/s]

{'val_accu_top5': 0.644811858608894, 'train_accu': 0.6717314054650165, 'time_elapsed': '29m 54s', 'val_accu_top3': 0.556442417331813, 'val_accu_top1': 0.34378563283922464}





[    1] iter, [0.000000] epoch, classifer loss: 0.061, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.348, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.270, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.183, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.213, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.261, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.218, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.161, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.208, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.231, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.243, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.281, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.388, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.79it/s]

{'val_accu_top5': 0.644811858608894, 'train_accu': 0.6740595515255483, 'time_elapsed': '32m 54s', 'val_accu_top3': 0.556442417331813, 'val_accu_top1': 0.34321550741163054}





[    1] iter, [0.000000] epoch, classifer loss: 0.058, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.246, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.058, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.152, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.140, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.159, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.179, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.223, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.185, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.053, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.258, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.209, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.267, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.78it/s]

{'val_accu_top5': 0.6488027366020525, 'train_accu': 0.6885185639014827, 'time_elapsed': '35m 53s', 'val_accu_top3': 0.5450399087799316, 'val_accu_top1': 0.330672748004561}





[    1] iter, [0.000000] epoch, classifer loss: 0.051, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.238, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.149, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.003, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.097, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.074, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.094, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.126, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.171, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.160, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.146, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.223, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.132, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.78it/s]

{'val_accu_top5': 0.6545039908779932, 'train_accu': 0.6931748560225462, 'time_elapsed': '38m 52s', 'val_accu_top3': 0.563854047890536, 'val_accu_top1': 0.3580387685290764}





[    1] iter, [0.000000] epoch, classifer loss: 0.080, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.219, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 1.147, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.029, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.011, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.054, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 0.997, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.116, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.112, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.130, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.031, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.150, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.170, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.78it/s]

{'val_accu_top5': 0.6482326111744584, 'train_accu': 0.6988114201690969, 'time_elapsed': '41m 51s', 'val_accu_top3': 0.5627137970353477, 'val_accu_top1': 0.34150513112884834}





[    1] iter, [0.000000] epoch, classifer loss: 0.047, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.048, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 0.991, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 0.921, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.013, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.095, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 1.022, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 0.953, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.065, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.045, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 1.076, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.114, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 1.112, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.78it/s]

{'val_accu_top5': 0.6470923603192702, 'train_accu': 0.6968508761181228, 'time_elapsed': '44m 51s', 'val_accu_top3': 0.5513112884834663, 'val_accu_top1': 0.33808437856328394}





[    1] iter, [0.000000] epoch, classifer loss: 0.051, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 1.253, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 0.998, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 1.064, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 1.004, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 1.053, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 0.936, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 1.017, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 1.061, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 1.106, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 0.996, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 1.039, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.28it/s]
100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.644811858608894, 'train_accu': 0.8207327533390516, 'time_elapsed': '116m 23s', 'val_accu_top3': 0.5627137970353477, 'val_accu_top1': 0.34378563283922464}





[    1] iter, [0.000000] epoch, classifer loss: 0.037, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 0.699, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 0.625, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 0.546, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 0.610, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 0.681, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 0.534, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 0.610, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 0.613, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 0.703, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 0.589, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 0.748, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 0.634, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.6488027366020525, 'train_accu': 0.8385001838010048, 'time_elapsed': '119m 22s', 'val_accu_top3': 0.5627137970353477, 'val_accu_top1': 0.3563283922462942}





[    1] iter, [0.000000] epoch, classifer loss: 0.030, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 0.710, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 0.637, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 0.553, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 0.622, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 0.571, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 0.576, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 0.616, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 0.670, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 0.569, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 0.690, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 0.680, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 0.620, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.637400228050171, 'train_accu': 0.7918147285871829, 'time_elapsed': '122m 21s', 'val_accu_top3': 0.5507411630558723, 'val_accu_top1': 0.3255416191562144}





[    1] iter, [0.000000] epoch, classifer loss: 0.033, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 0.791, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 0.700, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 0.588, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 0.604, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 0.616, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 0.637, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 0.559, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 0.568, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 0.620, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 0.618, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 0.740, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 0.622, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.645381984036488, 'train_accu': 0.8273495895110893, 'time_elapsed': '125m 21s', 'val_accu_top3': 0.5672748004561003, 'val_accu_top1': 0.35290763968072975}





[    1] iter, [0.000000] epoch, classifer loss: 0.027, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 0.830, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 0.703, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 0.539, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 0.603, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 0.676, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 0.600, attn_loss: 100.00001
[  141] iter, [0.548952] epoch, classifer loss: 0.575, attn_loss: 100.00001
[  161] iter, [0.627374] epoch, classifer loss: 0.627, attn_loss: 100.00001
[  181] iter, [0.705796] epoch, classifer loss: 0.595, attn_loss: 100.00001
[  201] iter, [0.784218] epoch, classifer loss: 0.573, attn_loss: 100.00001
[  221] iter, [0.862639] epoch, classifer loss: 0.598, attn_loss: 100.00001
[  241] iter, [0.941061] epoch, classifer loss: 0.660, attn_loss: 100.00001


100%|██████████| 55/55 [00:43<00:00,  1.27it/s]
100%|██████████| 256/256 [01:07<00:00,  3.77it/s]

{'val_accu_top5': 0.6619156214367161, 'train_accu': 0.847812768043132, 'time_elapsed': '128m 20s', 'val_accu_top3': 0.5689851767388826, 'val_accu_top1': 0.34207525655644244}





[    1] iter, [0.000000] epoch, classifer loss: 0.023, attn_loss: 5.00000
[   21] iter, [0.078422] epoch, classifer loss: 0.756, attn_loss: 100.00001
[   41] iter, [0.156844] epoch, classifer loss: 0.559, attn_loss: 100.00001
[   61] iter, [0.235265] epoch, classifer loss: 0.640, attn_loss: 100.00001
[   81] iter, [0.313687] epoch, classifer loss: 0.537, attn_loss: 100.00001
[  101] iter, [0.392109] epoch, classifer loss: 0.615, attn_loss: 100.00001
[  121] iter, [0.470531] epoch, classifer loss: 0.558, attn_loss: 100.00001


100%|██████████| 55/55 [00:14<00:00,  3.79it/s]


{'train_accu': 0.0, 'val_accu': 0.5872291904218928}

In [None]:
compute_attn_loss()

In [None]:
_lambda = 0.01

In [None]:
turn_off_grad_except(['attn_weights','fc', 'layer4'])
resnet_attn.eval()
train_k_epoch(5,score_epoch=True)

In [28]:
3

3