In [1]:
import sys
sys.path.insert(0, '..')

from AttentionModule import Conv2d_Attn

import torch
from torch import nn
from torchvision import models, datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

import re
import numpy as np

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

In [2]:
# Load a ResNet-101 with torchvision's pretrained (ImageNet) weights.
# NOTE(review): the variable is named `incep` but holds a ResNet-101, not an
# Inception model; the name is kept because later cells refer to it.
incep = models.resnet101(pretrained=True)

In [3]:
# Replace the classifier with a freshly initialised linear head producing 144
# logits (presumably one per class folder under ../data/train — TODO confirm).
incep.fc = nn.Linear(incep.fc.in_features, 144)

In [4]:
# This block turns 'layer1.0.downsample.0.weight' into 'layer1[0].downsample[0].weight'
def get_formatted_keys(network_name):
    """Return the model's ``state_dict`` keys rewritten as indexable attribute paths.

    Every purely numeric path segment is converted from attribute form
    (``.0``) to subscript form (``[0]``), so that ``f'{name}.' + key`` can be
    evaluated to reach the tensor, e.g.
    ``'layer1.0.downsample.0.weight'`` -> ``'layer1[0].downsample[0].weight'``.

    Args:
        network_name: Name of a global model variable (resolved with ``eval``),
            or the ``nn.Module`` itself.

    Returns:
        List of formatted keys, in ``state_dict`` order.
    """
    # eval is acceptable only because the name is written by the notebook
    # author; never pass externally supplied strings here.
    network = eval(network_name) if isinstance(network_name, str) else network_name
    # A numeric segment is '.<digits>' followed by another segment or the end
    # of the key. The look-ahead does not consume the following '.', so
    # directly nested indices ('.0.1.') are both converted, and indices of any
    # length ('.100.') are handled.
    return [re.sub(r'\.(\d+)(?=\.|$)', r'[\1]', key)
            for key in network.state_dict().keys()]
    
# This block freezes every parameter except those matching a keyword (e.g. ['fc'] or ['attn_weights'])
def turn_off_grad_except(network_name, lst=()):
    """Set ``requires_grad`` so that only parameters matching a keyword stay trainable.

    A parameter stays trainable when *any* keyword in ``lst`` is a substring
    of its name, in either dotted form (``'layer1.0.conv1.weight'``) or
    subscript form (``'layer1[0].conv1.weight'``). Every other parameter is
    frozen; with an empty ``lst`` all parameters are frozen.

    Only parameters are touched. ``state_dict`` buffers (e.g. BatchNorm
    running statistics and the integer ``num_batches_tracked`` counter) are
    left alone: they are not optimised, and integer tensors cannot require
    gradients.

    Args:
        network_name: Name of a global model variable (resolved with ``eval``),
            or the ``nn.Module`` itself.
        lst: Iterable of substrings identifying the parameters to keep trainable.
    """
    # eval is acceptable only because the name is written by the notebook author.
    network = eval(network_name) if isinstance(network_name, str) else network_name
    keywords = list(lst)
    for name, param in network.named_parameters():
        # Subscript form kept for backward compatibility with keywords written
        # against the formatted keys ('layer1[0]').
        indexed_name = re.sub(r'\.(\d+)(?=\.|$)', r'[\1]', name)
        # any(): matching one keyword is enough. The previous per-keyword
        # overwrite let only the *last* keyword decide.
        param.requires_grad = any(kw in name or kw in indexed_name
                                  for kw in keywords)

In [5]:
# Freeze the pretrained backbone: only parameters whose names contain 'fc'
# (the new classifier head) remain trainable.
turn_off_grad_except('incep', ['fc'])

In [6]:
# Mini-batch size for the DataLoaders; get_loader also reads this global.
batch_size = 32

In [7]:
# Standard torchvision ImageNet channel mean/std, matching the pretrained weights.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# NOTE(review): no Resize/crop is applied, so default batching only works if
# every image under ../data already has the same size — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

trainset = torchvision.datasets.ImageFolder(root='../data/train', transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

# Validation data is only evaluated, never trained on, so it is not shuffled:
# a fixed order keeps evaluation runs reproducible.
valset = torchvision.datasets.ImageFolder(root='../data/val', transform=transform)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                        shuffle=False, num_workers=2)

In [8]:
def get_loader(dirname):
    """Build a shuffled DataLoader over the image folder ``../data/<dirname>``.

    Relies on the notebook-level ``transform`` and ``batch_size`` globals.

    Returns:
        ``(loader, n_images)``: the DataLoader and the number of images found.
    """
    dataset = torchvision.datasets.ImageFolder(
        root=f'../data/{dirname}',
        transform=transform,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    return loader, len(dataset)

In [9]:
def score_batch(inp, label, top, network):
    """Count the samples in a batch whose label is among the top-k predictions.

    Args:
        inp: Batch of input tensors accepted by the model.
        label: Integer class labels, one per sample in ``inp``.
        top: ``k`` for top-k accuracy (1 for top-1, 3 for top-3, ...).
        network: Name of a global model variable (resolved with ``eval``),
            or the ``nn.Module`` itself.

    Returns:
        Number (``int``) of samples whose label is among the ``top`` largest logits.
    """
    net = eval(network) if isinstance(network, str) else network
    # Use whatever device the model lives on (CUDA in this notebook) instead of
    # hard-coding .cuda(), so the helper also works for CPU models.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else inp.device
    # Scoring never back-propagates; without no_grad every call would build an
    # autograd graph through the trainable fc layer and waste memory.
    with torch.no_grad():
        _, idx = net(inp.to(device)).topk(top)
        lab = label.to(device)
        hits = idx == lab.unsqueeze(1).expand_as(idx)
    return int(hits.sum())

In [10]:
def score_data(data_dir, network_name):
    """Log top-1 and top-3 accuracy of a model on the images in ``../data/<data_dir>``.

    Args:
        data_dir: Sub-directory of ``../data`` to evaluate (e.g. 'train', 'val').
        network_name: Model reference forwarded to ``score_batch``.
    """
    loader, total = get_loader(data_dir)
    if total == 0:
        # An empty split would otherwise raise ZeroDivisionError below.
        logging.warning(f'{data_dir}: no images found, skipping scoring')
        return
    top1_count = 0
    top3_count = 0
    # Evaluation only: disable autograd so no graph is kept for any batch.
    # NOTE(review): each batch is forwarded twice (once per k); top-1 hits are a
    # subset of top-3 hits, so a single topk(3) pass could serve both.
    with torch.no_grad():
        for inp, label in loader:
            top1_count += score_batch(inp, label, 1, network_name)
            top3_count += score_batch(inp, label, 3, network_name)
    logging.info({
        f'{data_dir}_top1': top1_count/total,
        f'{data_dir}_top3': top3_count/total
    })

In [11]:
def score(network_name, train=True, val=True, batch_size=32):
    """Log accuracy on the selected splits: 'train' first, then 'val'.

    Args:
        network_name: Model reference forwarded to ``score_data``.
        train: Score the ``../data/train`` split when true.
        val: Score the ``../data/val`` split when true.
        batch_size: NOTE(review): accepted but currently unused. Loaders are
            built by ``get_loader`` with the global ``batch_size``.
    """
    for split, enabled in (('train', train), ('val', val)):
        if enabled:
            score_data(split, network_name)

In [12]:
# Cross-entropy over the classifier logits; Adam updates only the parameters
# that turn_off_grad_except left trainable.
# NOTE(review): the optimizer is built before the model is moved to CUDA (next
# cell); PyTorch docs recommend moving the model first — confirm the optimizer
# still references the moved parameters.
cls_criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([p for p in incep.parameters() if p.requires_grad])

In [13]:
# Put the model in eval mode and move it to the GPU. The training loop below
# never calls .train(), so any BatchNorm layers keep using their stored running
# statistics while the fc head is fine-tuned.
incep = incep.eval().cuda()

In [14]:
# Number of training images, used to report progress in fractional epochs.
total_imgs = len(trainset)

In [15]:
import logging
import os

# FileHandler does not create parent directories; without this, a missing
# ../logs directory raises FileNotFoundError.
os.makedirs('../logs', exist_ok=True)
# filemode='w' truncates the log left by any previous run.
# NOTE: basicConfig does nothing if the root logger already has handlers
# (e.g. when this cell is re-run); restart the kernel to reconfigure.
logging.basicConfig(format='%(asctime)s : %(message)s',
                    filename='../logs/resnet101-fc.log',
                    level=logging.INFO,
                    filemode='w'
                   )

In [16]:
# Log the average training loss every 30 mini-batches.
print_every = 30

In [17]:
# Header line identifying this run in the log file.
logging.info("Retraining Resnet101 FC layer only")

In [18]:
# Number of full passes over trainloader (each logged "Iteration" is one epoch).
num_iter = 10

In [19]:
# Train the fc head for num_iter epochs, logging the running loss and scoring
# the model on train/val after every epoch.
for j in range(1, num_iter+1):
    logging.info(f"Iteration {j}/{num_iter}")

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()

        optimizer.zero_grad()
        outputs = incep(inputs)
        loss = cls_criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # float() turns the single-element loss into a Python number and drops
        # the reference to the graph; loss.data[0] raises on the 0-dim loss
        # tensors returned by modern PyTorch.
        running_loss += float(loss)

        if (i+1) % print_every == 0:
            # Cumulative epochs: completed epochs plus the fraction of images
            # seen so far in this one (capped at 1 for a partial last batch).
            epochs_done = (j - 1) + min((i + 1) * batch_size, total_imgs) / total_imgs
            logging.info(
                '{} iter, {} epoch, avg loss: {} '.format(
                    i + 1,
                    epochs_done,
                    running_loss/print_every))
            running_loss = 0.0

    logging.info("Begin Scoring")
    # NOTE(review): score() does not currently use its batch_size argument.
    score('incep', batch_size=64)
    logging.info("Done Scoring")

In [20]:
# Pickle the whole module (architecture + weights) to the current directory.
# NOTE(review): loading this file with torch.load needs the same class
# definitions importable; saving incep.state_dict() is the more portable option.
torch.save(incep, 'resnet101-fc.pth')