In [129]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from datasets import *


import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import os

import argparse
from matplotlib import pyplot as plt

from sklearn.metrics import roc_auc_score

import nltk
nltk.download('punkt')
from nltk import word_tokenize
from collections import Counter
import operator

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\joyrb\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
attn_reg = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper

In [3]:
from datasets import load_dataset

train_dataset = load_dataset("flaviagiammarino/path-vqa", split='train[:5000]')
val_dataset = load_dataset("flaviagiammarino/path-vqa", split='validation[:1000]')

train_images, train_ques, train_ans = train_dataset['image'], train_dataset['question'], train_dataset['answer']
val_images, val_ques, val_ans = val_dataset['image'], val_dataset['question'], val_dataset['answer']

In [4]:
answers = train_ans + val_ans
train_images = [img.resize((224, 224)).convert('RGB') for img in train_images]
val_images = [img.resize((224, 224)).convert('RGB') for img in val_images]

In [5]:
(val_images[1]).size

(224, 224)

In [6]:
def map_words(answers):
    """
    Makes word vocabulary to index bag_of_words
    
    Args:
        answers: list of answers
    
    Returns:
        Dict: word to num map
        Dict: num to word map
    """
    word_to_int = dict()
    int_to_word = dict()
    bag_of_words = list()
    
    for ans in answers:
        tokens = word_tokenize(ans)
        for t in tokens:
            bag_of_words.append(t)
            
    counter = Counter(bag_of_words)
    sorted_counter = sorted(counter.items(), key=operator.itemgetter(1), reverse=True)
    
    for ind, key in enumerate(sorted_counter):
        word_to_int[key[0]] = ind+1
        int_to_word[ind+1] = key[0]
        
    word_to_int['<start>'] = len(word_to_int)+1
    int_to_word[len(int_to_word)+1] = '<start>'
    
    word_to_int['<end>'] = len(word_to_int)+1
    int_to_word[len(int_to_word)+1] = '<end>'
    
    word_to_int['<unk>'] = len(word_to_int)+1
    int_to_word[len(int_to_word)+1] = '<unk>'
    
    word_to_int['<pad>'] = 0
    int_to_word[0] = '<pad>'
    
    return word_to_int, int_to_word

word_to_int, int_to_word = map_words(answers)

In [7]:
def encode_answers(answers, word_to_int):
    """
    This method is to encode the tokens in an answer by the corresponding number
    
    Args:
        answers: the list of answers
        word_to_int: the mapping of word to number
        
    Returns:
        List of nums: encoded answers
        List: of answer lengths
    """
    tokenized_answers = list()
    max_len = 0
    for answer in answers:
        tokens = word_tokenize(answer)
        max_len = max(max_len, len(tokens)+2)
        tokenized_answers.append(tokens)
    
    encoded_ans = list()
    ans_len = list()
    for ind, tokens in enumerate(tokenized_answers):
        enc_c = [word_to_int['<start>']] + [word_to_int.get(word, word_to_int['<unk>']) for word in tokens] + [
                        word_to_int['<end>']] + [word_to_int['<pad>']] * (max_len - len(tokens))
        encoded_ans.append(enc_c)
        ans_len.append(len(tokens)+2)
        
    encoded_ans = np.array(encoded_ans)
    ans_len = np.array(ans_len)
        
    return encoded_ans, ans_len

In [90]:
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
import numpy as np
import h5py

def load_image_features(images, model_name='resnet152'):
    # Load the pre-trained model
    model = models.resnet152(pretrained=True)
    model = torch.nn.Sequential(*(list(model.children())[:-1])) # Remove the last layer
    model.eval()

    # Image transformations
    preprocess = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    fc7_features = []
    image_id_list = []

    # Iterate over images in the list
    for i, image in enumerate(images):
        input_tensor = preprocess(image)
        input_batch = input_tensor.unsqueeze(0) # Create a mini-batch as expected by the model

        # If you have a GPU, put everything on cuda
        if torch.cuda.is_available():
            input_batch = input_batch.to('cuda')
            model.to('cuda')

        with torch.no_grad():
            output = model(input_batch)
        
        output = output.squeeze(0).permute(1, 2, 0)
        
        # Append features and image id to lists
        fc7_features.append(output.cpu().numpy())
        image_id_list.append(f'image_{i}')
        
        if i%100 == 0:
            print(output.shape) 
        
    # Save features and image id list to h5 files
    with h5py.File(f'{model_name}_fc7.h5', 'w') as hf:
        hf.create_dataset("fc7_features",  data=fc7_features)
    with h5py.File(f'{model_name}_image_id_list.h5', 'w') as hf:
        hf.create_dataset("image_id_list",  data=image_id_list)

    return fc7_features, image_id_list

In [69]:
from transformers import AutoTokenizer, AutoModel
import torch

def load_ques_embed(questions, model_name='bert-base-cased'):
    # Load pre-trained model and tokenizer
    model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    embeddings = []

    # Iterate over questions
    for question in questions:
        # Tokenize the question
        inputs = tokenizer(question, return_tensors='pt')

        # Generate embeddings
        with torch.no_grad():
            outputs = model(**inputs)

        # Get the embeddings of the [CLS] token (first token)
        cls_embedding = outputs.last_hidden_state[:, 0, :].numpy()

        # Append to the list of embeddings
        embeddings.append(cls_embedding)

    return embeddings

In [70]:
def make_ques_img_pair(img_id_list, img_feat, ques_id_list, ques_embed):
    data = dict()
    img_id_list_dict = list_to_dict(img_id_list)
    img_feat_data = list()
    
    for ques_id in ques_id_list:
        img_feat_data.append(img_feat[img_id_list_dict[ques_id]])
        
    img_feat_data = np.array(img_feat_data)
    
    data['img_feat'] = img_feat_data
    data['ques_feat'] = ques_embed
    
    return data

In [71]:
def load_data(images, questions, answers, word_to_int, img_model, ques_model, split):
    encoded_ans, ans_len = encode_answers(answers, word_to_int)
    img_feat, image_id_list = load_image_features(images, 'resnet152')
    ques_embed, ques_id_list = load_ques_embed(questions, 'bert-base-cased'), image_id_list
    
    data = make_ques_img_pair(image_id_list, img_feat, ques_id_list, ques_embed)
    data['answer'] = encoded_ans
    data['ans_len'] = ans_len
    
    print(data['img_feat'].shape)
    
    return data

In [72]:
def list_to_dict(id_list):
    id_dict = {}
    count = 0
    for obj_id in id_list:
        id_dict[obj_id] = count
        count += 1
    
    return id_dict

In [91]:
train_data = load_data(train_images, train_ques, train_ans, word_to_int, 'resnet152', 'bert', 'train')
val_data = load_data(val_images, val_ques, val_ans, word_to_int, 'resnet152', 'bert-base-cased', 'val')

torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])


In [92]:
val_data.keys()

dict_keys(['img_feat', 'ques_feat', 'answer', 'ans_len'])

In [93]:
(torch.Tensor(train_data['img_feat'][0])).shape

torch.Size([1, 1, 2048])

In [94]:
torch.cuda.set_device(0)
print(train_data['answer'].shape)
    
train_data_tensor = TensorDataset(torch.Tensor(train_data['img_feat']),
                                  torch.Tensor(train_data['ques_feat']),
                                  torch.Tensor(train_data['answer']),
                                  torch.Tensor(train_data['ans_len']))
val_data_tensor = TensorDataset(torch.Tensor(val_data['img_feat']),
                                  torch.Tensor(val_data['ques_feat']),
                                  torch.Tensor(val_data['answer']),
                                  torch.Tensor(val_data['ans_len']))

train_loader = DataLoader(train_data_tensor, batch_size = 32, shuffle=True)
val_loader = DataLoader(val_data_tensor, batch_size = 32, shuffle=True)

(5000, 38)


In [126]:
for img, q, a, a_len in train_loader:
    print(q.shape)
    break

torch.Size([32, 1, 768])


In [228]:
class Encoder(nn.Module):
    """
    Encoder.
    """

    def __init__(self, embedding_size, LSTM_units, LSTM_layers, feat_size,
                 batch_size, global_avg_pool_size, 
                 dropout = 0.3, mfb_output_dim = 5000):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.mfb_output_dim = mfb_output_dim
        self.feat_size = feat_size
        self.mfb_out = 1000
        self.mfb_factor = 5
        self.channel_size = global_avg_pool_size
        self.ques_glimse_num = 2
        self.img_glimse_num = 2
        
        self.LSTM = nn.LSTM(input_size=embedding_size, hidden_size=LSTM_units, 
                            num_layers=LSTM_layers, batch_first=False)
        self.Dropout = nn.Dropout(p=dropout, )
        self.Softmax = nn.Softmax()
        
        self.fc1_q_proj = nn.Linear(LSTM_units*self.ques_glimse_num, self.mfb_output_dim)
        self.Conv_i_proj = nn.Conv2d(self.feat_size, self.mfb_output_dim, 1)

        self.Dropout_L = nn.Dropout(p=0.2)
        self.Dropout_M = nn.Dropout(p=0.2)
        self.conv1_q_attn = nn.Conv2d(LSTM_units, 512, 1)
        self.conv2_q_attn = nn.Conv2d(512, self.ques_glimse_num, 1)
        self.conv1_i_attn = nn.Conv2d(1000, 512, 1)
        self.conv2_i_attn = nn.Conv2d(512, self.img_glimse_num, 1)
        
        self.q_attn_maps = None
        self.i_attn_maps = None
        
    def forward(self, ques_embed, img_feat):
        
        reshaped_img_feat = img_feat.permute(0, 3, 1, 2).contiguous()         # N x w x w x C -> N x C x w x w
        reshaped_img_feat = reshaped_img_feat.reshape(reshaped_img_feat.shape[0], reshaped_img_feat.shape[1], 
                                              self.channel_size*self.channel_size)      # N x C x w*w
        
        # ques_embed                                         N x T x embedding_size
        reshaped_ques_feat = ques_embed.permute(1, 0, 2).contiguous()        #T x N x embedding_size
        lstm_out, (hn, cn) = self.LSTM(reshaped_ques_feat)
        lstm1_droped = self.Dropout_L(lstm_out)
        lstm1_resh = lstm1_droped.permute(1, 2, 0).contiguous()                   # N x 1024 x T
        lstm1_resh2 = torch.unsqueeze(lstm1_resh, 3)                # N x 1024 x T x 1
        
        '''
        Question Attention
        '''        
        q_attn_conv1 = self.conv1_q_attn(lstm1_resh2)                   # N x 512 x T x 1
        q_attn_relu = F.relu(qatt_conv1)
        q_attn_conv2 = self.conv2_q_attn(q_attn_relu)                     # N x 2 x T x 1
        q_attn_conv2 = q_attn_conv2.reshape(q_attn_conv2.shape[0]*self.ques_glimse_num,-1)
        q_attn_softmax = self.Softmax(q_attn_conv2)
        q_attn_softmax = q_attn_softmax.view(qatt_conv1.shape[0], self.ques_glimse_num, -1, 1)
        self.q_attn_maps = q_attn_softmax
        q_attn_feat_list = []
        for i in range(self.ques_glimse_num):
            t_qatt_mask = q_attn_softmax.narrow(1, i, 1)              # N x 1 x T x 1
            t_qatt_mask = t_qatt_mask * lstm1_resh2                 # N x 1024 x T x 1
            t_qatt_mask = torch.sum(t_qatt_mask, 2, keepdim=True)   # N x 1024 x 1 x 1
            q_attn_feat_list.append(t_qatt_mask)
        qatt_feature_concat = torch.cat(q_attn_feat_list, 1)       # N x 2048 x 1 x 1
        
        '''
        Image Attention with MFB
        '''
        q_feat_resh = torch.squeeze(qatt_feature_concat)                                # N x 2048
        i_feat_resh = torch.unsqueeze(reshaped_img_feat, 3)                                   # N x 2048 x w*w x 1
        i_attn_q_proj = self.fc1_q_proj(q_feat_resh)                                  # N x 5000
        i_attn_q_reshape = i_attn_q_proj.view(i_attn_q_proj.shape[0], self.mfb_output_dim, 1, 1)      # N x 5000 x 1 x 1
        i_attn_i_conv = self.Conv_i_proj(i_feat_resh)                                     # N x 5000 x w*w x 1
        i_attn_q_conv = i_attn_q_reshape * i_attn_i_conv
        i_attn_q_conv_dropped = self.Dropout_M(i_attn_q_conv)                                # N x 5000 x w*w x 1
        i_attn_q_conv_permute1 = i_attn_q_conv_dropped.permute(0,2,1,3).contiguous()                              # N x w*w x 5000 x 1
        i_attn_q_conv_reshape = i_attn_q_conv_permute1.view(i_attn_q_conv_permute1.shape[0], self.channel_size*self.channel_size, 
                                             self.mfb_out, self.mfb_factor)
        i_attn_q_conv_sums = torch.sum(i_attn_q_conv_reshape, 3, keepdim=True)                      # N x w*w x 1000 x 1 
        i_attn_q_conv_permute2 = i_attn_q_conv_sums.permute(0,2,1,3).contiguous()                            # N x 1000 x w*w x 1
        i_attn_q_conv_sqrt = torch.sqrt(F.relu(i_attn_q_conv_permute2)) - torch.sqrt(F.relu(-i_attn_q_conv_permute2))
        i_attn_q_conv_sqrt = torch.squeeze(i_attn_q_conv_sqrt)
        i_attn_q_conv_sqrt = i_attn_q_conv_sqrt.reshape(i_attn_q_conv_sqrt.shape[0], -1)                           # N x 1000*w*w
        i_attn_q_conv_l2 = F.normalize(i_attn_q_conv_sqrt)
        i_attn_q_conv_l2 = i_attn_q_conv_l2.view(i_attn_q_conv_l2.shape[0], self.mfb_out, self.channel_size*self.channel_size, 1)  # N x 1000 x w*w x 1
        
        ## 2 conv layers 1000 -> 512 -> 2
        i_attn_conv1 = self.conv1_i_attn(i_attn_q_conv_l2)                    # N x 512 x w*w x 1
        i_attn_relu = F.relu(i_attn_conv1)
        i_attn_conv2 = self.conv2_i_attn(i_attn_relu)                     # N x 2 x w*w x 1
        i_attn_conv2 = i_attn_conv2.view(i_attn_conv2.shape[0]*self.img_glimse_num, -1)
        i_attn_softmax = self.Softmax(i_attn_conv2)
        i_attn_softmax = i_attn_softmax.view(i_attn_conv1.shape[0], self.img_glimse_num, -1, 1)
        self.i_attn_maps = i_attn_softmax.view(i_attn_conv1.shape[0], self.img_glimse_num, self.channel_size, self.channel_size)
        iatt_feature_list = []
        for i in range(self.img_glimse_num):
            t_iatt_mask = i_attn_softmax.narrow(1, i, 1)              # N x 1 x w*w x 1
            t_iatt_mask = t_iatt_mask * i_feat_resh                 # N x 2048 x w*w x 1
            iatt_feature_list.append(t_iatt_mask)
        iatt_feature_concat = torch.mean(torch.stack(iatt_feature_list), dim=0)       # N x 2048 x w*w x 1
        iatt_feature_resh = iatt_feature_concat.view(i_attn_q_conv_permute1.shape[0], self.channel_size, 
                                                        self.channel_size, 2048)           # N x w x w x 2048
        
        return iatt_feature_resh

In [229]:
class Attention(nn.Module):
    """
    Attention Network.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        encoder_dim: feature size of encoded images
        decoder_dim: size of decoder's RNN
        attention_dim: size of the attention network
        """
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # linear layer to transform encoded image
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim, 1)  # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

    def forward(self, encoder_out, decoder_hidden):
        """
        Forward propagation.

        encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        attention weighted encoding, weights
        """
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)

        return attention_weighted_encoding, alpha

In [230]:
class DecoderWithAttention(nn.Module):
    """
    Decoder.
    """
    
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size
        :param decoder_dim: size of decoder's RNN
        :param vocab_size: size of vocabulary
        :param encoder_dim: feature size of encoded images
        :param dropout: dropout
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        """
        Initializes some parameters with values from the uniform distribution, for easier convergence.
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.

        embeddings: pre-trained embeddings
        """
        self.embedding.weight = nn.Parameter(embeddings)

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.

        encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)

        return: hidden state, cell state
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.

        encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        
        return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """
        
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort input data by decreasing lengths
        caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors to hold word predicion scores and alphas
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)

        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [231]:
class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [232]:
def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.
    scores: scores from the model
    targets: true labels
    k: k in top-k accuracy
    
    return: top-k accuracy
    """

    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / batch_size)

In [233]:
encoder = Encoder(768, 1024, 2, 2048, 32, 1)
decoder = DecoderWithAttention(1024, 1024, 1024, len(word_to_int))

In [234]:
lr=0.001
n_epochs = 100
print_every = 100

criterion = nn.CrossEntropyLoss()   

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=encoder_lr)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=decoder_lr)

val_loss = []
val_acc = []
train_loss = []
train_acc = []

losses = AverageMeter()
top5accs = AverageMeter()

max_training_acc = 0
max_val_acc = 0

train_on_gpu = torch.cuda.is_available()
if(train_on_gpu):
    encoder.cuda()
    decoder.cuda()

for e in range(n_epochs):
    print('Epoch - ', e)
    running_acc = 0
    running_loss = 0
    counter = 0
    encoder.train()
    decoder.train()
    for img, ques, ans, ans_len in train_loader:
        counter += 1

        img = img.cuda()
        ques = ques.float().cuda()
        ans = ans.cuda().long()
        ans_len = ans_len.cuda().long()

        encoder_output = encoder( ques, img)

        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_output, ans, ans_len)

        targets = caps_sorted[:, 1:]
        
        scores = (pack_padded_sequence(scores, decode_lengths, batch_first=True)).data
        targets = (pack_padded_sequence(targets, decode_lengths, batch_first=True)).data

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += attn_reg * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        encoder_optimizer.zero_grad()
        loss.backward()

        # Update weights
        decoder_optimizer.step()
        encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))

        running_loss += loss.item()

        if counter % print_every == 0:
            print('batch no.:', (counter*100)/len(train_loader), 
                  ' loss:', loss.item())

Epoch -  0
batch no.: 63.69426751592356  loss: 17.62575340270996
Epoch -  1
batch no.: 63.69426751592356  loss: 10.851142883300781
Epoch -  2


KeyboardInterrupt: 

In [237]:
encoder.eval()
decoder.eval()

val_running_loss = 0
val_running_acc = 0
for img, ques, ans, ans_len in val_loader:
    counter += 1

    img = img.cuda()
    ques = ques.float().cuda()
    ans = ans.cuda().long()
    ans_len = ans_len.cuda().long()

    encoder_output = encoder(ques, img)

    scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_output, ans, ans_len)

    targets = caps_sorted[:, 1:]

    scores = (pack_padded_sequence(scores, decode_lengths, batch_first=True)).data
    targets = (pack_padded_sequence(targets, decode_lengths, batch_first=True)).data

    # Calculate loss
    loss = criterion(scores, targets)

    loss += attn_reg * ((1. - alphas.sum(dim=1)) ** 2).mean()

In [None]:
def answer_generation_output(encoder,decoder,img,ques,max_answer_length):
    img = img.cuda()
    ques = ques.float().cuda()
    encoder_output = encoder(ques, img)

    start_token=word_to_int['<start>']
    pad_token=word_to_int['<pad>']
    answer_begining=torch.tensor(np.array([pad_token,start_token]))
    decoder_input = torch.tile(answer_begining, (img.shape[0], 1)).cuda()
    decoder_input_intial=decoder_input
    max_length = max_answer_length  # Setting desired maximum length

    # Decode sequentially
    for _ in range(max_length):

        decoder_length=torch.tensor([decoder_input.shape[1],decoder_input.shape[1]])
        predictions, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_output, decoder_input, decoder_length.cuda().long())  # [1] is a length of 1
        
        predicted_index = torch.argmax(predictions, dim=2)
        
        decoder_input = torch.cat([decoder_input_intial, predicted_index], dim=1)
     
    generated_answer=[]
    for answer in predicted_index:
            a=''
            for i in answer:
                if int_to_word[i.item()]!='<end>':
                    a+=int_to_word[i.item()]+" "
            generated_answer.append(a)


    return generated_answer #answer_result

In [None]:
val_loader1 = DataLoader(val_data_tensor, batch_size = 32, shuffle=True)
for img, ques, ans, ans_len in val_loader1:
  generated_answer=answer_generation_output(encoder,decoder,img,ques,max_answer_length=18)
  ground_truth=[]
  for answer in ans:
    a=''
    for i in answer:
        if int_to_word[i.item()]!='<end>' and int_to_word[i.item()]!='<pad>' and int_to_word[i.item()]!='<start>':
            a+=int_to_word[i.item()]+" "
    ground_truth.append(a)
  break

for i in range(20):

  print("Generated",generated_answer[i])
  print("True", ground_truth[i])