In [1]:
from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [2]:
resnet_pretrained = models.resnet50(pretrained=True)
nn.Conv2d = Conv2d_Attn
resnet_attn = models.resnet50()
resnet_attn.load_state_dict(resnet_pretrained.state_dict(), strict=False)

In [3]:
# This block turns 'layer1.0.downsample.0.weight' to 'layer1[0].downsample[0].weight'
param_keys = list(resnet_attn.state_dict().keys())
formatted_keys = []
for k in param_keys:
    found = re.findall(r'\.[\d]{1,2}\.', k)
    if len(found):
        for f in found:
            k = k.replace(f, '[{}].'.format(f.strip('.')))
    formatted_keys.append(k)

In [4]:
# This block turn off gradient up for all params except attn_weights
def turn_off_grad_except(lst=[]):
    for k in formatted_keys:
        obj = eval('resnet_attn.'+k)
        for kw in lst:
            if not kw in k:
                obj.requires_grad = False
            else:
                obj.requires_grad = True

In [5]:
resnet_attn.fc = nn.Linear(resnet_attn.fc.in_features, 144)

Start training

In [6]:
batch_size = 32

# count number of instances for each class and use sampler for class imbalance
TRAIN_DIR = '/home/bdrad1/ryan/194/data/train'
VAL_DIR = '/home/bdrad1/ryan/194/data/val'
TEST_DIR = '/home/bdrad1/ryan/194/data/test'

classes = os.listdir(TRAIN_DIR)
classes.remove('.DS_Store')
class_counts = [len(os.listdir(os.path.join(TRAIN_DIR, c))) for c in classes]

c = 0
weights = []
for directory, _, files in os.walk(TRAIN_DIR):
    if not directory.endswith('train'):
        for f in files:
            weights.append(class_counts[c])       

weights = [1.0/i for i in weights]
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, 8161, replacement= True)

In [7]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose(
    [
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

trainset = torchvision.datasets.ImageFolder(root=TRAIN_DIR, transform=transform)
valset = torchvision.datasets.ImageFolder(root=VAL_DIR, transform = transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=2, sampler = sampler)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [8]:
total_imgs = len(trainset.imgs)
print('number of training images', total_imgs)

number of training images 8161


In [9]:
resnet_attn = resnet_attn.cuda()

In [10]:
total_attn_params = 0
for k in formatted_keys:
    obj = eval('resnet_attn.'+k)
    if 'attn_weights' in k:
        total_attn_params += np.prod(obj.shape)
print("Total number of attention parameters", total_attn_params)

Total number of attention parameters 13385920


We want the attention parameters to diverge from 1, therefore we penalize element-wise square loss as $\lambda (1 \times \text{# params} - (x - 1)^2)$

But this is too big a number,
let's try: 
$- (x - 1)^2$ for now

In [11]:
_lambda = 1e-2 #set default

In [12]:
def get_params_objs(name, net='resnet_attn'):
    res = []
    for k in formatted_keys:
        obj = eval(f'{net}.'+k)
        if name in k:
            res.append(obj)
    return res

In [13]:
def compute_attn_loss(n_params=26560):
    attns = get_params_objs('attn_weights')
    penality = sum([torch.pow(t - 1,2).mean() for t in attns])
    return _lambda*(- penality)

In [15]:
def train_one_epoch(add_attn=True):
    cls_criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, resnet_attn.parameters()))
    
    running_loss = 0.0
    running_attn_loss = 0.0
    running_corrects = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
        
        optimizer.zero_grad()
        outputs = resnet_attn(inputs)
        

        loss = cls_criterion(outputs, labels)
        attn_loss = compute_attn_loss()
        if add_attn:
            loss += attn_loss
        
        loss.backward()
        optimizer.step()
        
        #calculate training acc
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels)
        
        running_loss += loss.data[0]
        running_attn_loss += attn_loss.data[0]
        
        if i % print_every == 0:
            print('[%5d] iter, [%2f] epoch, avg loss: %.3f, attn_loss: %.5f ' %
                  (i + 1, i*batch_size/total_imgs, running_loss/print_every, running_attn_loss/print_every))
            running_loss = 0.0
            running_attn_loss = 0.0
    training_acc = running_corrects.double() / total_imgs
    return training_acc, optimizer

In [16]:
from tqdm import tqdm
def score(net=resnet_attn, batch_size=batch_size):
    trainset = torchvision.datasets.ImageFolder(root=TRAIN_DIR, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    valset = torchvision.datasets.ImageFolder(root=VAL_DIR, transform=transform)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    
    train_correct = 0
    val_correct = 0
    
    for inp, label in tqdm(iter(trainloader)):
        _, idx = net(Variable(inp).cuda()).topk(3)
        train_correct += int(sum(idx.cpu().data == label))
    
    for inp, label in tqdm(iter(valloader)):
        _, idx = net(Variable(inp).cuda()).topk(3)
        val_correct += int(sum(idx.cpu().data == label))
    
    return {
        'train_accu': train_correct/len(trainset),
        'val_accu': val_correct/len(valset)
    }

Train a fresh fc layer. 
`turn_off_grad_except([])` turns off grads for all weights but the fc layer

In [17]:
def score_top3():
    correct_count = 0
    for inp, label in tqdm(iter(valloader)):
        _, idx = resnet_attn(Variable(inp).cuda()).topk(3)
        lab = Variable(label).cuda()
        lab_expand = lab.unsqueeze(1).expand_as(idx)
        correct_count += int((idx == lab_expand).sum())
    print(correct_count/len(valset))

In [36]:
#Training scheme
print_every = 50
val_accs = []
def train(seq = ['fc', 'att', 'fc', 'bn', 'att', 'att', 'bn', 'att', 'fc','att']):
    dir_name = ''
    for x in seq:
        dir_name += x[0]
    if not os.direxists('checkpoints/'+dir_name):
        os.mkdir('checkpoints/'+dir_name)
    highest_acc1 = 0
    for idx, s in enumerate(seq):
        print('======================= epoch:', idx,'layer:',s,"=========================")
        if s == 'fc':
            turn_off_grad_except(['fc'])
        elif s == 'att':
            turn_off_grad_except(['attn_weights'])
        elif s == 'bn':
            turn_off_grad_except(['bn'])
        training_acc,optimizer = train_one_epoch()
        
        correct_count1, correct_count3 = 0,0
        
        for inp, label in iter(valloader):
            _, idx1 = resnet_attn(Variable(inp).cuda()).topk(1)
            _, idx3 = resnet_attn(Variable(inp).cuda()).topk(3)
            
            lab = Variable(label).cuda()
            lab_expand1 = lab.unsqueeze(1).expand_as(idx1)
            lab_expand3 = lab.unsqueeze(1).expand_as(idx3)
            correct_count1 += int((idx1 == lab_expand1).sum())
            correct_count3 += int((idx3 == lab_expand3).sum())
            
        val_acc_1 = correct_count1/len(valset)
        val_acc_3 = correct_count3/len(valset)
        
        if val_acc_1 > highest_acc1:
            highest_acc1 = val_acc_1
            #Save best acc model state dict
            state = {
            'epoch': idx,
            'arch': seq,
            'state_dict': resnet_attn.state_dict(),
            'val_acc1': val_acc_1,
            'val_acc3': val_acc_3,
            'optimizer' : optimizer.state_dict()
            }
            save_checkpoint(state, 'epoch{}_val1_{:.3f}_val3_{:.3f}.pth'.format(idx, val_acc_1, val_acc_3), 'checkpoints/'+dir_name)
        print("top 1 val_acc: {} top 3 val_acc: {}".format(val_acc_1, val_acc_3))
        val_accs.append(val_acc_1)
    return highest_acc1, val_accs
    
def save_checkpoint(state, filename, cp_path):
    torch.save(state, os.path.join(cp_path, filename))
    print("saved model to {}".format(filename))

In [37]:
acc_t1, val_accs = train()

AttributeError: module 'os' has no attribute 'direxists'

In [10]:
checkpoint = torch.load('checkpoints/5-8/epoch9_val1_0.47263397947548463_val3_0.6835803876852907.pth')

In [13]:
resnet_attn.load_state_dict(checkpoint['state_dict'])
resnet_attn.cuda()
#test on best model
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        normalize
    ])


testset = torchvision.datasets.ImageFolder(root=TEST_DIR, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)
def test(model):
    model.eval()
    correct_count1, correct_count3 = 0, 0

    for inp, label in iter(testloader):
        _, idx1 = model(Variable(inp).cuda()).topk(1)
        _, idx3 = model(Variable(inp).cuda()).topk(3)

        lab = Variable(label).cuda()
        lab_expand1 = lab.unsqueeze(1).expand_as(idx1)
        lab_expand3 = lab.unsqueeze(1).expand_as(idx3)
        correct_count1 += int((idx1 == lab_expand1).sum())
        correct_count3 += int((idx3 == lab_expand3).sum())

    test_acc_1 = correct_count1/len(testset)
    test_acc_3 = correct_count3/len(testset)
    print("top 1 test acc:{}, top 3 test acc:{}".format(test_acc_1, test_acc_3))

In [14]:
test(resnet_attn)

top 1 test acc:0.4734010759115362, top 3 test acc:0.6826060968320382


In [24]:
def create_dir(x):
    


In [26]:
's'+'v'

'sv'