In [1]:
import sys
sys.path.insert(0, '..')

In [2]:
import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np

In [3]:
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [4]:
resnet_attn = models.resnet50()

In [5]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [6]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(lst=[]):
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        for kw in lst:
            if not kw in k:
                obj.requires_grad = False
            else:
                obj.requires_grad = True

In [7]:
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 144)

Start training

In [8]:
# batch_size = 32
# batch_size = 64
batch_size = 256

In [9]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.ImageFolder(root='../data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

valset = torchvision.datasets.ImageFolder(root='../data/val', transform=transform)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [10]:
total_imgs = len(trainset.imgs)

In [11]:
resnet_attn = resnet_attn.cuda()

In [12]:
total_attn_params = 0
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if 'attn_weights' in k:
        total_attn_params += np.prod(obj.shape)
print("Total number of attention parameters", total_attn_params)

Total number of attention parameters 0


We want the attention parameters to diverge from 1, therefore we penalize element-wise square loss as $\lambda (1 \times \text{# params} - (x - 1)^2)$

But this is too big a number,
let's try: 
$- (x - 1)^2$ for now

In [13]:
_lambda = 1 #set default

In [14]:
def get_params_objs(name, net='resnet_attn'):
    res = []
    for k in formatted_keys:
        obj = eval(f'{net}.'+k)
        if name in k:
            res.append(obj)
    return res

In [15]:
def compute_attn_loss(n_params=26560):
    attns = get_params_objs('attn_weights')
#     penalty = sum([torch.abs(t - 1,2).mean() for t in attns])
    penalty = sum([torch.norm(t, p=1) for t in attns])/float(total_attn_params)
    return _lambda*(penalty)

In [16]:
print_every = 50

In [17]:
def score_top3(train=True, val=True, partial=True, frac=4):
    if train:
        correct_count = 0
        num_imgs = len(trainset)
        if partial:
            part = len(trainset)//frac
            total = 0
            num_imgs = part
        
        for inp, label in tqdm(iter(trainloader)):
            _, idx = resnet_attn(Variable(inp).cuda()).topk(3)
            lab = Variable(label).cuda()
            lab_expand = lab.unsqueeze(1).expand_as(idx)
            correct_count += int((idx == lab_expand).sum())
            
            if partial:
                total += batch_size
                if total >= part:
                    break
            
        print({'Train Accuracy': correct_count/num_imgs})
    
    if val:
        correct_count = 0
        for inp, label in tqdm(iter(valloader)):
            _, idx = resnet_attn(Variable(inp).cuda()).topk(3)
            lab = Variable(label).cuda()
            lab_expand = lab.unsqueeze(1).expand_as(idx)
            correct_count += int((idx == lab_expand).sum())
        print({'Val Accuracy': correct_count/len(valset)})

In [18]:
def plot_attn_hist():
    attns = get_params_objs('attn_weights')
    attns = torch.cat([attn.view(-1).squeeze() for attn in attns])
    attns_arr = attns.data.cpu().numpy()
    plt.hist(attns_arr)

In [19]:
def get_loss_opt():
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(resnet_attn.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
    return cls_criterion, optimizer

In [20]:
def train_one():
    trainloader, train_total = get_loader('train')

    running_cls_loss = 0.0
    running_attn_loss = 0.0
    top1_count = 0
    top3_count = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

        optimizer.zero_grad()
        outputs = resnet_attn(inputs)
        cls_loss = cls_criterion(outputs, labels)
        loss = cls_loss

        loss.backward()
        optimizer.step()

        running_cls_loss += cls_loss.data[0]
        
        top1_count += compute_correct(outputs, labels, 1)
        top3_count += compute_correct(outputs, labels, 3)

        if (i + 1) % print_every == 0:
            print_log(
                '{} iter, {} epoch, cls loss: {}'.format(
                    i + 1,
                    i * batch_size / total_imgs,
                    running_cls_loss / print_every))
            running_cls_loss = 0.0
            running_attn_loss = 0.0

    print_log("Begin Scoring")
    print_log({
        f'train_top1': top1_count / train_total,
        f'train_top3': top3_count / train_total
    })
    score('resnet_attn', batch_size=64, train=False)
    print_log("Done Scoring")

In [21]:
def get_loader(dirname, batch_size=32):
    trainset = torchvision.datasets.ImageFolder(root=f'../data/{dirname}', transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=3)
    return trainloader, len(trainset)

In [22]:
import logging

logging.basicConfig(format='%(asctime)s : %(message)s',
                    filename='{}'.format('from_scratch.log'),
                    level=logging.INFO,
                    filemode='w+'
                    )


def print_log(*string):
    print(*string)
    logging.info(str(string))

In [23]:
print_log('Begin')

Begin


In [24]:
def score_batch(inp, label, top, network):
    _, idx = eval(network)(Variable(inp).cuda()).topk(top)
    lab = Variable(label).cuda()
    lab_expand = lab.unsqueeze(1).expand_as(idx)
    return int((idx == lab_expand).sum())

def compute_correct(out, label, top):
    _, idx = out.topk(top)
    lab_expand = label.unsqueeze(1).expand_as(idx)
    return int((idx == lab_expand).sum())


def score_data(data_dir, network_name):
    trainloader, train_total = get_loader(data_dir, batch_size=64)
    top3_count = 0
    top1_count = 0
    for inp, label in iter(trainloader):
        top1_count += score_batch(inp, label, 1, network_name)
        top3_count += score_batch(inp, label, 3, network_name)
    print_log({
        f'{data_dir}_top1': top1_count / train_total,
        f'{data_dir}_top3': top3_count / train_total
    })


def score(network_name, train=True, val=True, batch_size=32):
    if train:
        score_data('train', network_name)
    if val:
        score_data('val', network_name)

In [25]:
cls_criterion, optimizer = get_loss_opt()
for iteration in range(50):
    train_one()

50 iter, 1.5370665359637299 epoch, cls loss: 9.780503072738647
100 iter, 3.105501776743046 epoch, cls loss: 4.381561708450318
150 iter, 4.673937017522363 epoch, cls loss: 4.3171054458618165
200 iter, 6.242372258301678 epoch, cls loss: 4.236970739364624
250 iter, 7.810807499080995 epoch, cls loss: 4.209580683708191
Begin Scoring
{'train_top1': 0.0654331577012621, 'train_top3': 0.16897439039333415}
{'val_top1': 0.07810718358038768, 'val_top3': 0.17103762827822122}
Done Scoring
50 iter, 1.5370665359637299 epoch, cls loss: 4.23497700214386
100 iter, 3.105501776743046 epoch, cls loss: 4.200983171463013
150 iter, 4.673937017522363 epoch, cls loss: 4.180145978927612
200 iter, 6.242372258301678 epoch, cls loss: 4.219748349189758
250 iter, 7.810807499080995 epoch, cls loss: 4.198048677444458
Begin Scoring
{'train_top1': 0.06972184781276804, 'train_top3': 0.1802475186864355}
{'val_top1': 0.07696693272519954, 'val_top3': 0.1750285062713797}
Done Scoring
50 iter, 1.5370665359637299 epoch, cls loss

Process Process-66:
Process Process-64:
Process Process-65:
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ubuntu/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ubuntu/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ubuntu/miniconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/ubuntu/miniconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/home/ubuntu/miniconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ubuntu/miniconda3/lib/python3.6/site-packages/torch/utils/dat

RuntimeError: DataLoader worker (pid 14747) exited unexpectedly with exit code 1.