Code largely adapted from: https://github.com/SaoYan/LearnToPayAttention

In [1]:
import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.utils as utils
import torchvision.transforms as transforms

from torch.utils.tensorboard import SummaryWriter

In [2]:
def weights_init_xavierUniform(module):
    """Initialise, in place, every supported layer found in ``module.modules()``.

    * ``nn.Conv2d`` / ``nn.Linear``: Xavier-uniform weights with gain sqrt(2);
      the bias (when the layer has one) is set to zero.
    * ``nn.BatchNorm2d``: weights drawn uniformly from [0, 1]; bias set to zero.

    Any other module type is left untouched. Returns ``None``.
    """
    gain = np.sqrt(2)
    for layer in module.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(layer.weight, gain=gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
        elif isinstance(layer, nn.BatchNorm2d):
            nn.init.uniform_(layer.weight, a=0, b=1)
            nn.init.zeros_(layer.bias)

In [3]:
class ConvBlock(nn.Module):
    """Stack of ``num_conv`` units of 3x3 conv (padding 1) -> BatchNorm -> ReLU.

    The first unit maps ``in_features`` to ``out_features`` channels; later units
    keep ``out_features``. With ``pool=True`` a 2x2/stride-2 max-pool follows
    *every* unit, so the block downsamples spatially by ``2 ** num_conv``.
    All layers live in ``self.op`` (an ``nn.Sequential``).
    """

    def __init__(self, in_features, out_features, num_conv, pool=False):
        super().__init__()
        layers = []
        channels_in = in_features
        for _ in range(num_conv):
            layers += [
                nn.Conv2d(in_channels=channels_in, out_channels=out_features, kernel_size=3, padding=1, bias=True),
                nn.BatchNorm2d(num_features=out_features, affine=True, track_running_stats=True),
                nn.ReLU(),
            ]
            if pool:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
            channels_in = out_features
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv stack to ``x``."""
        return self.op(x)

class ProjectorBlock(nn.Module):
    """Bias-free 1x1 convolution mapping ``in_features`` channels to ``out_features``.

    Spatial size is unchanged. The convolution is stored as ``self.op``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        projection = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0, bias=False)
        self.op = projection

    def forward(self, inputs):
        """Project ``inputs`` channel-wise with the 1x1 convolution."""
        projected = self.op(inputs)
        return projected

class LinearAttentionBlock(nn.Module):
    """Attention pooling of a local feature map guided by a global feature.

    A bias-free 1x1 conv (``self.op``) scores every spatial position of
    ``l + g``. The scores become attention weights: softmax over all positions
    when ``normalize_attn`` is true, element-wise sigmoid otherwise. They then
    weight ``l`` and are pooled to one vector per sample. Pooling is a
    weighted sum in softmax mode and a spatial mean in sigmoid mode.
    """

    def __init__(self, in_features, normalize_attn=True):
        super().__init__()
        self.normalize_attn = normalize_attn
        self.op = nn.Conv2d(in_channels=in_features, out_channels=1, kernel_size=1, padding=0, bias=False)

    def forward(self, l, g):
        """Attend over ``l`` using ``g``.

        Args:
            l: local features, ``(N, C, H, W)``.
            g: global features; must broadcast against ``l`` for ``l + g``
               (presumably ``(N, C, 1, 1)`` — confirm at call sites).

        Returns:
            ``(scores, pooled)``: the raw (pre-softmax/sigmoid) scores shaped
            ``(N, 1, H, W)`` and the attended features shaped ``(N, C)``.
        """
        batch, channels, height, width = l.size()
        scores = self.op(l + g)
        if self.normalize_attn:
            weights = F.softmax(scores.view(batch, 1, -1), dim=2).view(batch, 1, height, width)
            attended = weights.expand_as(l) * l
            pooled = attended.view(batch, channels, -1).sum(dim=2)
        else:
            weights = torch.sigmoid(scores)
            attended = weights.expand_as(l) * l
            pooled = F.adaptive_avg_pool2d(attended, (1, 1)).view(batch, channels)
        return scores.view(batch, 1, height, width), pooled


In [4]:
class AttnVGG(nn.Module):
    """VGG-style classifier with three linear-attention heads over intermediate features.

    Three local feature maps (outputs of ``conv_block3``, ``conv_block4`` and
    ``conv_block5``) are each compared with a global descriptor ``g`` produced
    by ``dense`` at the end of the trunk. Each ``LinearAttentionBlock`` pools
    its local map into a 512-d vector. The three vectors are concatenated and
    fed to a linear classifier.

    Args:
        im_size: input image side length in pixels; must be >= 32. The trunk
            downsamples by 32 in total (three max-pools in ``forward`` plus the
            two pools inside ``conv_block6``). ``dense`` therefore uses an
            ``im_size // 32`` kernel, which is square, so inputs are presumably
            square.
        num_classes: number of output logits.
        normalize_attn: forwarded to every ``LinearAttentionBlock``.

    Raises:
        ValueError: if ``im_size`` < 32, which would give ``dense`` a
            zero-sized kernel.
    """
    def __init__(self, im_size, num_classes, normalize_attn=True):
        super(AttnVGG, self).__init__()
        if im_size < 32:
            raise ValueError("AttnVGG requires im_size >= 32, got %r" % (im_size,))
        # conv blocks
        self.conv_block1 = ConvBlock(3, 64, 2)
        self.conv_block2 = ConvBlock(64, 128, 2)
        self.conv_block3 = ConvBlock(128, 256, 3)
        self.conv_block4 = ConvBlock(256, 512, 3)
        self.conv_block5 = ConvBlock(512, 512, 3)
        self.conv_block6 = ConvBlock(512, 512, 2, pool=True)
        # Collapses the /32 feature map into the global descriptor g.
        # Integer floor division (same value as int(im_size / 32) for im_size >= 0).
        self.dense = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=int(im_size) // 32, padding=0, bias=True)

        # Projectors & Compatibility functions
        # l1 has 256 channels; project it to 512 so it can be added to g.
        self.projector = ProjectorBlock(256, 512)
        self.attn1 = LinearAttentionBlock(in_features=512, normalize_attn=normalize_attn)
        self.attn2 = LinearAttentionBlock(in_features=512, normalize_attn=normalize_attn)
        self.attn3 = LinearAttentionBlock(in_features=512, normalize_attn=normalize_attn)

        # final classification layer, using the combination of local features and attention map
        self.classify = nn.Linear(in_features=512*3, out_features=num_classes, bias=True)

        # initialize
        weights_init_xavierUniform(self)

    def forward(self, x):
        """Return ``[logits, c1, c2, c3]``.

        ``logits`` is ``(N, num_classes)``. ``c1``, ``c2`` and ``c3`` are the
        raw attention scores returned by the three attention blocks, each
        ``(N, 1, h, w)`` at the resolution of its local feature map.
        """
        # feed forward
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        l1 = self.conv_block3(x) # full input resolution, 256 channels
        x = F.max_pool2d(l1, kernel_size=2, stride=2, padding=0) # /2
        l2 = self.conv_block4(x) # /2, 512 channels
        x = F.max_pool2d(l2, kernel_size=2, stride=2, padding=0) # /4
        l3 = self.conv_block5(x) # /4, 512 channels
        x = F.max_pool2d(l3, kernel_size=2, stride=2, padding=0) # /8
        x = self.conv_block6(x) # /32 (conv_block6 pools after each of its 2 convs)
        g = self.dense(x) # batch_size x 512 x 1 x 1 when im_size is a multiple of 32
        # pay attention
        c1, g1 = self.attn1(self.projector(l1), g)
        c2, g2 = self.attn2(l2, g)
        c3, g3 = self.attn3(l3, g)
        g = torch.cat((g1, g2, g3), dim=1) # batch_size x (3*512)
        # classification layer
        x = self.classify(g) # batch_size x num_classes

        return [x, c1, c2, c3]

In [5]:
# NOTE(review): num_aug is not used by the active training cell.
num_aug = 3
im_size = 32

# Per-channel normalisation (mean, std) shared by the train and test pipelines.
_norm_mean = (0.5071, 0.4867, 0.4408)
_norm_std = (0.2675, 0.2565, 0.2761)
_data_root = '../../../data'

# Training: random 4px-padded crop + horizontal flip, then tensor + normalise.
transform_train = transforms.Compose([
    transforms.RandomCrop(im_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])
# Test: deterministic, tensor + normalise only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

train_set = torchvision.datasets.CIFAR100(root=_data_root, train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=8)
test_set = torchvision.datasets.CIFAR100(root=_data_root, train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, num_workers=5)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Sigmoid-gated attention variant, moved to the GPU (Module.cuda() returns the module itself).
model = AttnVGG(im_size=im_size, num_classes=100, normalize_attn=False).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


def lr_lambda(epoch):
    """Multiplicative LR factor for LambdaLR: halve the base rate every 25 epochs."""
    return np.power(0.5, epoch // 25)


scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [7]:
step = 0
log_freq = 30
epochs = 100
save_freq = 10
checkpoint_dir = "./checkpoints"
# torch.save does not create missing directories; ensure the checkpoint folder
# exists up front instead of crashing at the first save.
os.makedirs(checkpoint_dir, exist_ok=True)
writer = SummaryWriter('./runs/attempt_7')
try:
    for epoch in range(epochs):
        # This is the LR the epoch trains with: the scheduler is only stepped
        # after the epoch's optimizer steps (see below).
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/learning_rate', current_lr, epoch)
        print("\nepoch %d learning rate %f\n" % (epoch, current_lr))

        # run for one epoch
        model.train()
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()

            # Clear gradients from the previous step. The optimizer was built
            # from model.parameters(), so a separate model.zero_grad() is redundant.
            optimizer.zero_grad()

            # forward
            pred, _, _, _ = model(inputs)

            # backward
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()

            # display results
            if i % log_freq == 0:
                # Re-score the current batch with the just-updated weights in
                # eval mode; no autograd graph is needed for this pass.
                model.eval()
                with torch.no_grad():
                    pred, __, __, __ = model(inputs)
                model.train()
                predict = torch.argmax(pred, 1)
                total = labels.size(0)
                correct = torch.eq(predict, labels).sum().double().item()
                accuracy = correct / total
                writer.add_scalar('train/loss', loss.item(), step)
                writer.add_scalar('train/accuracy', accuracy, step)

                print("[epoch %d][%d/%d] loss %.4f accuracy %.2f%%"
                      % (epoch, i, len(train_loader)-1, loss.item(), (100*accuracy)))

            step += 1

        # Since PyTorch 1.1 the LR scheduler must be stepped after
        # optimizer.step(). Stepping it at the start of the epoch warned and
        # made every LambdaLR halving happen one epoch early.
        scheduler.step()

        if epoch % save_freq == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Lets a resumed run continue with the correct learning rate.
                    'scheduler_state_dict': scheduler.state_dict(),
                    # Last mini-batch loss, detached and on CPU so the checkpoint
                    # carries no autograd history or CUDA storage.
                    'loss': loss.detach().cpu(),
                },
                os.path.join(checkpoint_dir, "attn-net_epoch_%s.pth" % epoch)
            )

        print('-'*40)
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            # accumulate test accuracy over the whole test set
            for i, data in enumerate(test_loader):
                images_test, labels_test = data
                images_test, labels_test = images_test.cuda(), labels_test.cuda()
                pred_test, _, _, _ = model(images_test)
                predict = torch.argmax(pred_test, 1)
                total += labels_test.size(0)
                correct += torch.eq(predict, labels_test).sum().double().item()
            writer.add_scalar('test/accuracy', correct/total, epoch)
            print("\n[epoch %d] accuracy on test data: %.2f%%\n" % (epoch, 100*correct/total))

        print('-'*40)
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

[60/390] loss 0.1156 accuracy 98.44%
[epoch 71][90/390] loss 0.1252 accuracy 100.00%
[epoch 71][120/390] loss 0.1499 accuracy 98.44%
[epoch 71][150/390] loss 0.1691 accuracy 99.22%
[epoch 71][180/390] loss 0.1174 accuracy 99.22%
[epoch 71][210/390] loss 0.1216 accuracy 96.88%
[epoch 71][240/390] loss 0.1416 accuracy 96.09%
[epoch 71][270/390] loss 0.2021 accuracy 97.66%
[epoch 71][300/390] loss 0.1426 accuracy 94.53%
[epoch 71][330/390] loss 0.3314 accuracy 88.28%
[epoch 71][360/390] loss 0.2231 accuracy 97.66%
[epoch 71][390/390] loss 0.1460 accuracy 98.75%
----------------------------------------

[epoch 71] accuracy on test data: 67.49%

----------------------------------------

epoch 72 learning rate 0.025000

[epoch 72][0/390] loss 0.1722 accuracy 96.09%
[epoch 72][30/390] loss 0.1870 accuracy 96.09%
[epoch 72][60/390] loss 0.0925 accuracy 90.62%
[epoch 72][90/390] loss 0.1590 accuracy 97.66%
[epoch 72][120/390] loss 0.2195 accuracy 98.44%
[epoch 72][150/390] loss 0.0955 accuracy 

KeyboardInterrupt: 

In [None]:
# step = 0
# log_freq = 10
# writer = SummaryWriter('./runs/attempt_1')
# for epoch in range(300):
#     # images_disp = []
#     # adjust learning rate
#     scheduler.step()
#     writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], epoch)
#     print("\nepoch %d learning rate %f\n" % (epoch, optimizer.param_groups[0]['lr']))
#     # run for one epoch
#     for i, data in enumerate(train_loader):
#         inputs, labels = data
#         inputs, labels = inputs.cuda(), labels.cuda()

#         # warm up
#         model.train()
#         model.zero_grad()
#         optimizer.zero_grad()

#         # if (aug == 0) and (i == 0): # archive images in order to save to logs
#         #     images_disp.append(inputs[0:36,:,:,:])

#         # forward
#         pred, _, _, _ = model(inputs)

#         # backward
#         loss = criterion(pred, labels)
#         loss.backward()
#         optimizer.step()

#         # display results
#         if i % log_freq == 0:
#             model.eval()
#             pred, __, __, __ = model(inputs)
#             predict = torch.argmax(pred, 1)
#             total = labels.size(0)
#             correct = torch.eq(predict, labels).sum().double().item()
#             accuracy = correct / total
#             writer.add_scalar('train/loss', loss.item(), step)
#             writer.add_scalar('train/accuracy', accuracy, step)

#             print("[epoch %d][aug %d/%d][%d/%d] loss %.4f accuracy %.2f%% running avg accuracy"
#                 % (epoch, aug, num_aug-1, i, len(trainloader)-1, loss.item(), (100*accuracy)))
        
#         step += 1

#     # the end of each epoch: test & log
#     print('-'*20)
#     # torch.save(model.state_dict(), os.path.join(opt.outf, 'net.pth'))
#     # if epoch == opt.epochs / 2:
#     #     torch.save(model.state_dict(), os.path.join(opt.outf, 'net%d.pth' % epoch))
    
#     model.eval()
#     total = 0
#     correct = 0
#     with torch.no_grad():
#         # log scalars
#         for i, data in enumerate(test_loader):
#             images_test, labels_test = data
#             images_test, labels_test = images_test.cuda(), labels_test.cuda()
#             # if i == 0: # archive images in order to save to logs
#             #     images_disp.append(inputs[0:36,:,:,:])
#             pred_test, _, _, _ = model(images_test)
#             predict = torch.argmax(pred_test, 1)
#             total += labels_test.size(0)
#             correct += torch.eq(predict, labels_test).sum().double().item()
#         writer.add_scalar('test/accuracy', correct/total, epoch)
#         print("\n[epoch %d] accuracy on test data: %.2f%%\n" % (epoch, 100*correct/total))

#         # log images
#         # if opt.log_images:
#         #     print('\nlog images ...\n')
#         #     I_train = utils.make_grid(images_disp[0], nrow=6, normalize=True, scale_each=True)
#         #     writer.add_image('train/image', I_train, epoch)
#         #     if epoch == 0:
#         #         I_test = utils.make_grid(images_disp[1], nrow=6, normalize=True, scale_each=True)
#         #         writer.add_image('test/image', I_test, epoch)
#         # if opt.log_images and (not opt.no_attention):
#         #     print('\nlog attention maps ...\n')
#         #     # base factor
#         #     if opt.attn_mode == 'before':
#         #         min_up_factor = 1
#         #     else:
#         #         min_up_factor = 2
#         #     # sigmoid or softmax
#         #     if opt.normalize_attn:
#         #         vis_fun = visualize_attn_softmax
#         #     else:
#         #         vis_fun = visualize_attn_sigmoid
#         #     # training data
#         #     __, c1, c2, c3 = model(images_disp[0])
#         #     if c1 is not None:
#         #         attn1 = vis_fun(I_train, c1, up_factor=min_up_factor, nrow=6)
#         #         writer.add_image('train/attention_map_1', attn1, epoch)
#         #     if c2 is not None:
#         #         attn2 = vis_fun(I_train, c2, up_factor=min_up_factor*2, nrow=6)
#         #         writer.add_image('train/attention_map_2', attn2, epoch)
#         #     if c3 is not None:
#         #         attn3 = vis_fun(I_train, c3, up_factor=min_up_factor*4, nrow=6)
#         #         writer.add_image('train/attention_map_3', attn3, epoch)
#         #     # test data
#         #     __, c1, c2, c3 = model(images_disp[1])
#         #     if c1 is not None:
#         #         attn1 = vis_fun(I_test, c1, up_factor=min_up_factor, nrow=6)
#         #         writer.add_image('test/attention_map_1', attn1, epoch)
#         #     if c2 is not None:
#         #         attn2 = vis_fun(I_test, c2, up_factor=min_up_factor*2, nrow=6)
#         #         writer.add_image('test/attention_map_2', attn2, epoch)
#         #     if c3 is not None:
#         #         attn3 = vis_fun(I_test, c3, up_factor=min_up_factor*4, nrow=6)
#         #         writer.add_image('test/attention_map_3', attn3, epoch)