In [1]:
# from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [3]:
res = models.resnet50(pretrained=True)

In [4]:
blks = list(res.children())

In [5]:
batch_size = 32
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.ImageFolder(root='../data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [6]:
res = res.eval().cuda()

In [7]:
kernel_size = [56,56,28,14,7]

In [8]:
def linear_pre_hook(out):
    return out.view(out.size(0), -1)

In [9]:
# pre_fc_size = 64+256+512+1024+2048
pre_fc_size = 64+256+512+1024

In [10]:
class ResNetWrap(nn.Module):
    def __init__(self, resnet):
        super(ResNetWrap, self).__init__()
        self.blks = list(res.children())
        self.max_pool_seq = nn.Sequential(*blks[:4])
        self.seq_1 = blks[4]
        self.seq_2 = blks[5]
        self.seq_3 = blks[6]
        self.seq_4 = blks[7]
        self.avg_pool = blks[8]
        self.fc = nn.Linear(pre_fc_size, 144)
    
    def forward(self, inp):
        self.seq_0_out = self.max_pool_seq(inp)
        self.seq_1_out = self.seq_1(self.seq_0_out)
        self.seq_2_out = self.seq_2(self.seq_1_out)
        self.seq_3_out = self.seq_3(self.seq_2_out)
        self.seq_4_out = self.seq_4(self.seq_3_out)
        
        outs = [
            self.seq_0_out,
            self.seq_1_out,
            self.seq_2_out,
            self.seq_3_out,
#             self.seq_4_out,
        ]
        
        avg_pools = [
            torch.nn.functional.avg_pool2d(out, kernel_size[i])
            for i,out in enumerate(outs)
        ]
        self.avg_pools_out = torch.cat(avg_pools, dim=1)
        avg_pools = linear_pre_hook(self.avg_pools_out)
        return self.fc(avg_pools)

In [11]:
agg = ResNetWrap(res).eval().cuda()

In [12]:
def format_keys(network_name):
    # This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
    param_keys = list(eval(network_name).state_dict().keys())
    formatted_keys = []
    for k in param_keys:
        found = re.findall(r'\.[\d]{1,2}\.', k)
        if len(found):
            for f in found:
                k = k.replace(f, '[{}].'.format(f.strip('.')))
        formatted_keys.append(k)
    return formatted_keys

In [13]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(network_name, lst=[], verbose=False):
    formatted_keys = format_keys(network_name)
    for k in formatted_keys:
        obj = eval(f'{network_name}.'+k)
        for kw in lst:
            if not kw in k:
                obj.requires_grad = False
            else:
                if verbose:
                    print(k)
                obj.requires_grad = True

In [14]:
turn_off_grad_except('agg', ['fc'], True)

fc.weight
fc.bias


In [15]:
print_every = 30

In [16]:
total_imgs = len(trainset)

In [17]:
from tqdm import tqdm
def score(net_name, batch_size=batch_size):
    net = eval(net_name)
    trainset = torchvision.datasets.ImageFolder(root='../data/train', transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    valset = torchvision.datasets.ImageFolder(root='../data/val', transform=transform)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    train_correct = 0
    val_correct = 0
    
    for inp, label in tqdm(iter(trainloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        train_correct += int(sum(idx.cpu().data == label))
    
    for inp, label in tqdm(iter(valloader)):
        _, idx = net(Variable(inp).cuda()).max(1)
        val_correct += int(sum(idx.cpu().data == label))
    
    return {
        'train_accu': train_correct/len(trainset),
        'val_accu': val_correct/len(valset)
    }

def train_k_epoch(net_name, k, score_epoch=False):
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, eval(net_name).parameters()))
    
    for epoch in range(k):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

            optimizer.zero_grad()
            outputs = eval(net_name)(inputs)
            clf_loss = cls_criterion(outputs, labels)
            
            clf_loss.backward()
            optimizer.step()

            running_loss += clf_loss.data[0]

            if i % print_every == 0:
                print('[%5d] iter, [%2f] epoch, classifer loss: %.3f' %
                      (i + 1, i*batch_size/total_imgs, running_loss/print_every))
                running_loss = 0.0
                running_attn_loss = 0.0
        if score_epoch:
            print(score(net_name, batch_size=32))

In [18]:
train_k_epoch('agg', k=10, score_epoch=True)

[    1] iter, [0.000000] epoch, classifer loss: 0.165
[   31] iter, [0.117633] epoch, classifer loss: 4.395
[   61] iter, [0.235265] epoch, classifer loss: 4.284
[   91] iter, [0.352898] epoch, classifer loss: 4.213
[  121] iter, [0.470531] epoch, classifer loss: 4.137
[  151] iter, [0.588163] epoch, classifer loss: 4.144
[  181] iter, [0.705796] epoch, classifer loss: 4.146
[  211] iter, [0.823429] epoch, classifer loss: 4.021
[  241] iter, [0.941061] epoch, classifer loss: 4.063


100%|██████████| 256/256 [01:08<00:00,  3.73it/s]
100%|██████████| 55/55 [00:14<00:00,  3.70it/s]

{'train_accu': 0.09422864844994486, 'val_accu': 0.09407069555302167}





[    1] iter, [0.000000] epoch, classifer loss: 0.130
[   31] iter, [0.117633] epoch, classifer loss: 4.058
[   61] iter, [0.235265] epoch, classifer loss: 3.977
[   91] iter, [0.352898] epoch, classifer loss: 4.003
[  121] iter, [0.470531] epoch, classifer loss: 3.933
[  151] iter, [0.588163] epoch, classifer loss: 3.989
[  181] iter, [0.705796] epoch, classifer loss: 3.955
[  211] iter, [0.823429] epoch, classifer loss: 3.994
[  241] iter, [0.941061] epoch, classifer loss: 3.920


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.68it/s]

{'train_accu': 0.11751010905526284, 'val_accu': 0.11858608893956671}





[    1] iter, [0.000000] epoch, classifer loss: 0.112
[   31] iter, [0.117633] epoch, classifer loss: 3.892
[   61] iter, [0.235265] epoch, classifer loss: 3.886
[   91] iter, [0.352898] epoch, classifer loss: 3.867
[  121] iter, [0.470531] epoch, classifer loss: 3.891
[  151] iter, [0.588163] epoch, classifer loss: 3.793
[  181] iter, [0.705796] epoch, classifer loss: 3.813
[  211] iter, [0.823429] epoch, classifer loss: 3.805
[  241] iter, [0.941061] epoch, classifer loss: 3.803


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.68it/s]

{'train_accu': 0.1539027080014704, 'val_accu': 0.15450399087799316}





[    1] iter, [0.000000] epoch, classifer loss: 0.134
[   31] iter, [0.117633] epoch, classifer loss: 3.737
[   61] iter, [0.235265] epoch, classifer loss: 3.726
[   91] iter, [0.352898] epoch, classifer loss: 3.757
[  121] iter, [0.470531] epoch, classifer loss: 3.674
[  151] iter, [0.588163] epoch, classifer loss: 3.686
[  181] iter, [0.705796] epoch, classifer loss: 3.790
[  211] iter, [0.823429] epoch, classifer loss: 3.759
[  241] iter, [0.941061] epoch, classifer loss: 3.729


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.68it/s]

{'train_accu': 0.1784095086386472, 'val_accu': 0.16875712656784492}





[    1] iter, [0.000000] epoch, classifer loss: 0.119
[   31] iter, [0.117633] epoch, classifer loss: 3.655
[   61] iter, [0.235265] epoch, classifer loss: 3.626
[   91] iter, [0.352898] epoch, classifer loss: 3.707
[  121] iter, [0.470531] epoch, classifer loss: 3.621
[  151] iter, [0.588163] epoch, classifer loss: 3.661
[  181] iter, [0.705796] epoch, classifer loss: 3.615
[  211] iter, [0.823429] epoch, classifer loss: 3.609
[  241] iter, [0.941061] epoch, classifer loss: 3.707


100%|██████████| 256/256 [01:08<00:00,  3.71it/s]
100%|██████████| 55/55 [00:14<00:00,  3.69it/s]

{'train_accu': 0.1632152922435976, 'val_accu': 0.16647662485746864}





[    1] iter, [0.000000] epoch, classifer loss: 0.121
[   31] iter, [0.117633] epoch, classifer loss: 3.593
[   61] iter, [0.235265] epoch, classifer loss: 3.554
[   91] iter, [0.352898] epoch, classifer loss: 3.607
[  121] iter, [0.470531] epoch, classifer loss: 3.609
[  151] iter, [0.588163] epoch, classifer loss: 3.648
[  181] iter, [0.705796] epoch, classifer loss: 3.510
[  211] iter, [0.823429] epoch, classifer loss: 3.603
[  241] iter, [0.941061] epoch, classifer loss: 3.490


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.68it/s]

{'train_accu': 0.1608871461830658, 'val_accu': 0.161345496009122}





[    1] iter, [0.000000] epoch, classifer loss: 0.116
[   31] iter, [0.117633] epoch, classifer loss: 3.465
[   61] iter, [0.235265] epoch, classifer loss: 3.535
[   91] iter, [0.352898] epoch, classifer loss: 3.520
[  121] iter, [0.470531] epoch, classifer loss: 3.485
[  151] iter, [0.588163] epoch, classifer loss: 3.508
[  181] iter, [0.705796] epoch, classifer loss: 3.495
[  211] iter, [0.823429] epoch, classifer loss: 3.532
[  241] iter, [0.941061] epoch, classifer loss: 3.454


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.69it/s]

{'train_accu': 0.21443450557529714, 'val_accu': 0.19384264538198404}





[    1] iter, [0.000000] epoch, classifer loss: 0.132
[   31] iter, [0.117633] epoch, classifer loss: 3.501
[   61] iter, [0.235265] epoch, classifer loss: 3.469
[   91] iter, [0.352898] epoch, classifer loss: 3.489
[  121] iter, [0.470531] epoch, classifer loss: 3.327
[  151] iter, [0.588163] epoch, classifer loss: 3.397
[  181] iter, [0.705796] epoch, classifer loss: 3.376
[  211] iter, [0.823429] epoch, classifer loss: 3.519
[  241] iter, [0.941061] epoch, classifer loss: 3.442


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.68it/s]

{'train_accu': 0.20058816321529224, 'val_accu': 0.16875712656784492}





[    1] iter, [0.000000] epoch, classifer loss: 0.099
[   31] iter, [0.117633] epoch, classifer loss: 3.446
[   61] iter, [0.235265] epoch, classifer loss: 3.396
[   91] iter, [0.352898] epoch, classifer loss: 3.334
[  121] iter, [0.470531] epoch, classifer loss: 3.412
[  151] iter, [0.588163] epoch, classifer loss: 3.367
[  181] iter, [0.705796] epoch, classifer loss: 3.372
[  211] iter, [0.823429] epoch, classifer loss: 3.427
[  241] iter, [0.941061] epoch, classifer loss: 3.393


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.70it/s]

{'train_accu': 0.21994853571866194, 'val_accu': 0.19954389965792474}





[    1] iter, [0.000000] epoch, classifer loss: 0.110
[   31] iter, [0.117633] epoch, classifer loss: 3.400
[   61] iter, [0.235265] epoch, classifer loss: 3.319
[   91] iter, [0.352898] epoch, classifer loss: 3.336
[  121] iter, [0.470531] epoch, classifer loss: 3.412
[  151] iter, [0.588163] epoch, classifer loss: 3.320
[  181] iter, [0.705796] epoch, classifer loss: 3.309
[  211] iter, [0.823429] epoch, classifer loss: 3.320
[  241] iter, [0.941061] epoch, classifer loss: 3.320


100%|██████████| 256/256 [01:08<00:00,  3.72it/s]
100%|██████████| 55/55 [00:14<00:00,  3.70it/s]

{'train_accu': 0.20904300943511825, 'val_accu': 0.18700114025085518}



