In [148]:
!pip install datasets




[notice] A new release of pip is available: 23.1.2 -> 23.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [149]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from datasets import *

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import os

import argparse
from matplotlib import pyplot as plt

from sklearn.metrics import roc_auc_score

import nltk
nltk.download('punkt')
from nltk import word_tokenize
from collections import Counter
import operator

[nltk_data] Downloading package punkt to C:\Users\Abhishek
[nltk_data]     Goyal\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [150]:
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # BLEU-4 score right now
print_freq = 100  # print training/validation stats every __ batches
fine_tune_encoder = False  # fine-tune encoder?

In [151]:
from datasets import load_dataset

dataset=load_dataset("flaviagiammarino/vqa-rad")

train_images, train_ques, train_ans = dataset['train']['image'], dataset['train']['question'], dataset['train']['answer']
val_images, val_ques, val_ans = dataset['test']['image'], dataset['test']['question'], dataset['test']['answer']

In [152]:
answers = train_ans + val_ans

In [153]:
# train_images = [img.resize((224, 224)).convert('RGB') for img in train_images]
# val_images = [img.resize((224, 224)).convert('RGB') for img in val_images]
# (val_images[1]).size

In [154]:
def make_word_map(answers):
    """
    Makes word vocabulary to index words

    Args:
        answers: list of answers

    Returns:
        Dict: word to num map
        Dict: num to word map
    """
    word_to_num = dict()
    num_to_word = dict()
    words = list()

    for ans in answers:
        tokens = word_tokenize(ans)
        for t in tokens:
            words.append(t)

    counter = Counter(words)
    sorted_counter = sorted(counter.items(), key=operator.itemgetter(1), reverse=True)

    for ind, key in enumerate(sorted_counter):
        word_to_num[key[0]] = ind+1
        num_to_word[ind+1] = key[0]

    word_to_num['<start>'] = len(word_to_num)+1
    num_to_word[len(num_to_word)+1] = '<start>'

    word_to_num['<end>'] = len(word_to_num)+1
    num_to_word[len(num_to_word)+1] = '<end>'

    word_to_num['<unk>'] = len(word_to_num)+1
    num_to_word[len(num_to_word)+1] = '<unk>'

    word_to_num['<pad>'] = 0
    num_to_word[0] = '<pad>'

    return word_to_num, num_to_word

word_to_num, num_to_word = make_word_map(answers)

In [155]:
def encode_answers(answers, word_to_num):
    """
    This method is to encode the tokens in an answer by the corresponding number

    Args:
        answers: the list of answers
        word_to_num: the mapping of word to number

    Returns:
        List of nums: encoded answers
        List: of answer lengths
    """
    tokenized_answers = list()
    max_len = 0
    for ans in answers:
        tokens = word_tokenize(ans)
        max_len = max(max_len, len(tokens)+2)
        tokenized_answers.append(tokens)

    encoded_ans = list()
    ans_len = list()
    for ind, tokens in enumerate(tokenized_answers):
        enc_c = [word_to_num['<start>']] + [word_to_num.get(word, word_to_num['<unk>']) for word in tokens] + [
                        word_to_num['<end>']] + [word_to_num['<pad>']] * (max_len - len(tokens))
        encoded_ans.append(enc_c)
        ans_len.append(len(tokens)+2)

    encoded_ans = np.array(encoded_ans)
    ans_len = np.array(ans_len)

    return encoded_ans, ans_len

In [156]:
#image processing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_images = [transform(img.convert('RGB')) for img in train_images]
val_images = [transform(img.convert('RGB')) for img in val_images]


In [157]:
class ResNet152Encoder(nn.Module):
    def __init__(self):
        super(ResNet152Encoder, self).__init__()
        # Load pre-trained ResNet-152
        resnet152 = models.resnet152(pretrained=True)
        self.features = nn.Sequential(*list(resnet152.children())[:-1])

    def forward(self, x):
        x = self.features(x)
        return x
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            # Start with the feature map of size [batch_size, 2048, 1, 1]
            nn.ConvTranspose2d(2048, 1024, kernel_size=7, stride=1, padding=0),  # Output size: [batch_size, 1024, 7, 7]
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),  # Output size: [batch_size, 512, 14, 14]
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # Output size: [batch_size, 256, 28, 28]
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # Output size: [batch_size, 128, 56, 56]
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # Output size: [batch_size, 64, 112, 112]
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # Output size: [batch_size, 32, 224, 224]
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=1, stride=1, padding=0),  # Output size: [batch_size, 3, 224, 224]
            nn.Sigmoid()  # Use Sigmoid to scale the output to the range [0, 1]
        )

    def forward(self, x):
        x = self.decoder(x)
        return x


class ResNet152Autoencoder(nn.Module):
    def __init__(self):
        super(ResNet152Autoencoder, self).__init__()
        self.encoder = ResNet152Encoder()

        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        #x = x.view(x.size(0), 2048, 1, 1)
        x = self.decoder(x)
        return x

In [172]:
class WNetDownConvBlock(nn.Module):
    r"""Performs two 3x3 2D convolutions, each followed by a ReLU and batch norm. Ends with a 2D max-pool operation."""

    def __init__(self, in_features: int, out_features: int):
        r"""
        :param in_features: Number of feature channels in the incoming data
        :param out_features: Number of feature channels in the outgoing data
        """
        super(WNetDownConvBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_features, out_features, 3),
            nn.ReLU(),
            nn.BatchNorm2d(out_features),
            nn.ReplicationPad2d(1),
            nn.Conv2d(out_features, out_features, 3),
            nn.ReLU(),
            nn.BatchNorm2d(out_features),
            nn.ReplicationPad2d(1),
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x: Tensor) -> Tensor:
        """Pushes a set of inputs (x) through the network.

        :param x: Input values
        :return: Module outputs
        """
        feature_map = self.layers(x)
        return self.pool(feature_map), feature_map


class WNetUpConvBlock(nn.Module):
    r"""Performs two 3x3 2D convolutions, each followed by a ReLU and batch norm. Ends with a transposed convolution with a stride of 2 on the last layer. Halves features at first and third convolutions"""

    def __init__(self, in_features: int, mid_features: int, out_features: int):
        r"""
        :param in_features: Number of feature channels in the incoming data
        :param out_features: Number of feature channels in the outgoing data
        """
        super(WNetUpConvBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_features, mid_features, 3),
            nn.BatchNorm2d(mid_features),
            nn.ReLU(),
            nn.ReplicationPad2d(1),
            nn.Conv2d(mid_features, mid_features, 3),
            nn.BatchNorm2d(mid_features),
            nn.ReLU(),
            nn.ReplicationPad2d(1),
            nn.ConvTranspose2d(mid_features, out_features, 2, stride=2),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Pushes a set of inputs (x) through the network.

        :param x: Input values
        :return: Module outputs
        """
        return self.layers(x)


class WNetOutputBlock(nn.Module):
    r"""Performs two 3x3 2D convolutions, each followed by a ReLU and batch Norm.
    Ending with a 1x1 convolution to map features to classes."""

    def __init__(self, in_features: int, num_classes: int):
        r"""
        :param in_features: Number of feature channels in the incoming data
        :param num_classes: Number of feature channels in the outgoing data
        """
        super(WNetOutputBlock, self).__init__()
        mid_features = int(in_features / 2)
        self.layers = nn.Sequential(
            nn.Conv2d(in_features, mid_features, 3),
            nn.BatchNorm2d(mid_features),
            nn.ReLU(),
            nn.ReplicationPad2d(1),
            nn.Conv2d(mid_features, mid_features, 3),
            nn.BatchNorm2d(mid_features),
            nn.ReLU(),
            nn.ReplicationPad2d(1),

            # 1x1 convolution to map features to classes
            nn.Conv2d(mid_features, num_classes, 1)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Pushes a set of inputs (x) through the network.

        :param x: Input values
        :return: Module outputs
        """
        return self.layers(x)

# TODO: seperable convolutions

class UNetAuto(nn.Module):
    r"""UNet based architecture for image auto encoding"""

    def __init__(self, num_channels: int = 3, num_out_channels: int = 3, max_features: int = 1024):
        r"""
        :param num_channels: Number of channels in the raw image data
        :param num_out_channels: Number of channels in the output data
        """
        super(UNetAuto, self).__init__()
        if max_features not in [2048, 1024, 512, 256]:
            print('Max features restricted to [1024, 512, 256]')
            max_features =  1024
        features_5 = max_features // 2
        features_4 = features_5 // 2
        features_3 = features_4 // 2
        features_2 = features_3 // 2
        features_1 = features_2 // 2

        self.conv_block1 = WNetDownConvBlock(num_channels, features_1)
        self.conv_block2 = WNetDownConvBlock(features_1, features_2)
        self.conv_block3 = WNetDownConvBlock(features_2, features_3)
        self.conv_block4 = WNetDownConvBlock(features_3, features_4)
        self.conv_block5 = WNetDownConvBlock(features_4, features_5)
        self.flatten=nn.Flatten()
        self.down_fc6=nn.Linear(25088,2048)

        self.up_fc6=nn.Linear(2048,25088)
        self.up_unflatten=nn.Unflatten(1,(512,7,7))

        self.deconv_block1 = WNetUpConvBlock(features_5, max_features, features_5)
        self.deconv_block2 = WNetUpConvBlock(max_features, features_5, features_4)
        self.deconv_block3 = WNetUpConvBlock(features_5, features_4, features_3)
        self.deconv_block4 = WNetUpConvBlock(features_4, features_3, features_2)
        self.deconv_block5 = WNetUpConvBlock(features_3, features_2, features_1)

        self.output_block = WNetOutputBlock(features_2, num_out_channels)

    def input_embed(self,x):
        x, c1 = self.conv_block1(x)
        x, c2 = self.conv_block2(x)
        x, c3 = self.conv_block3(x)
        x, c4 = self.conv_block4(x)
        x, c5 = self.conv_block5(x)
        x=self.flatten(x)
        x=self.down_fc6(x)
        x=x.unsqueeze(2)
        x=x.unsqueeze(3)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Pushes a set of inputs (x) through the network.

        :param x: Input values
        :return: Network output Tensor
        """
        x, c1 = self.conv_block1(x)
        x, c2 = self.conv_block2(x)
        x, c3 = self.conv_block3(x)
        x, c4 = self.conv_block4(x)
        x, c5 = self.conv_block5(x)
        x=self.flatten(x)
        x=self.down_fc6(x)
        x= self.up_fc6(x)
        x=self.up_unflatten(x)
        d1 = self.deconv_block1(x)
        d2 = self.deconv_block2(torch.cat((c5, d1), dim=1))
        d3 = self.deconv_block3(torch.cat((c4, d2), dim=1))
        d4 = self.deconv_block4(torch.cat((c3, d3), dim=1))
        d5 = self.deconv_block5(torch.cat((c2, d4), dim=1))
        out = self.output_block(torch.cat((c1, d5), dim=1))

        return out

In [179]:
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
import numpy as np
import h5py

def load_image_features(images, model_name='unet'):
    # Load the pre-trained model
    if model_name=='unet':
        autoencoder = UNetAuto(max_features=1024)
        autoencoder.load_state_dict(torch.load('./model_parameters_UnetV2.pth'))
    
    else:
        autoencoder = ResNet152Encoder()
        autoencoder.load_state_dict(torch.load('./autoencoder_model.pth'))
    
    #autoencoder = autoencoder.to(device)
    
    
    
    autoencoder.eval()

    fc7_features = []
    image_id_list = []

    # Iterate over images in the list
    for i, image in enumerate(images):
        #input_tensor = preprocess(image)
        input_batch = image.unsqueeze(0) # Create a mini-batch as expected by the model

        # If you have a GPU, put everything on cuda
        if torch.cuda.is_available():
            input_batch = input_batch.to('cuda')
            autoencoder.to('cuda')

        with torch.no_grad():
            output = autoencoder.input_embed(input_batch)
        output = output.squeeze(0).permute(1, 2, 0)
        #encoded_images = encoded_images.view(encoded_images.size(0), -1)

        # Append features and image id to lists
        fc7_features.append(output.cpu().numpy())
        image_id_list.append(f'image_{i}')

        if i%100 == 0:
            print(output.shape)



    # Save features and image id list to h5 files
    with h5py.File(f'{model_name}_fc7.h5', 'w') as hf:
        hf.create_dataset("fc7_features",  data=fc7_features)
    with h5py.File(f'{model_name}_image_id_list.h5', 'w') as hf:
        hf.create_dataset("image_id_list",  data=image_id_list)

    return fc7_features, image_id_list

In [180]:
from transformers import AutoTokenizer, AutoModel
import torch

def load_ques_embed(questions, model_name='dmis-lab/biobert-base-cased-v1.1'):
    # Load pre-trained model and tokenizer
    model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    embeddings = []

    # Iterate over questions
    for question in questions:
        # Tokenize the question
        inputs = tokenizer(question, return_tensors='pt')

        # Generate embeddings
        with torch.no_grad():
            outputs = model(**inputs)

        # Get the embeddings of the [CLS] token (first token)
        cls_embedding = outputs.last_hidden_state[:, 0, :].numpy()

        # Append to the list of embeddings
        embeddings.append(cls_embedding)

    return embeddings

In [181]:
def prepare_ques_img_pair(img_id_list, img_feat, ques_id_list, ques_embed):
    data = dict()
    img_id_list_dict = list_to_dict(img_id_list)
    img_feat_data = list()

    for ques_id in ques_id_list:
        img_feat_data.append(img_feat[img_id_list_dict[ques_id]])

    img_feat_data = np.array(img_feat_data)

    data['img_feat'] = img_feat_data
    data['ques_feat'] = ques_embed

    return data

In [182]:
def load_data(images, questions, answers, word_to_num, img_model, ques_model, split):
    encoded_ans, ans_len = encode_answers(answers, word_to_num)
    img_feat, image_id_list = load_image_features(images, 'unet')
    ques_embed, ques_id_list = load_ques_embed(questions, 'dmis-lab/biobert-base-cased-v1.1'), image_id_list

    data = prepare_ques_img_pair(image_id_list, img_feat, ques_id_list, ques_embed)
    data['answer'] = encoded_ans
    data['ans_len'] = ans_len

    print(data['img_feat'].shape)

    return data

In [183]:
def list_to_dict(id_list):
    id_dict = {}
    count = 0
    for obj_id in id_list:
        id_dict[obj_id] = count
        count += 1

    return id_dict

In [185]:
train_data = load_data(train_images, train_ques, train_ans, word_to_num, 'resnet152', 'bert', 'train')
val_data = load_data(val_images, val_ques, val_ans, word_to_num, 'resnet152', 'dmis-lab/biobert-base-cased-v1.1', 'val')

import pickle

# Save dictionary using pickle
with open('train_pipeline_unet.pkl', 'wb') as pickle_file:
    pickle.dump(train_data, pickle_file)

with open('val_pipeline_unet.pkl', 'wb') as pickle_file:
    pickle.dump(val_data, pickle_file)

import pickle
with open('train_pipeline_unet.pkl', 'rb') as pickle_file:
    train_data = pickle.load(pickle_file)

with open('val_pipeline_unet.pkl', 'rb') as pickle_file:
    val_data = pickle.load(pickle_file)

torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
(1793, 1, 1, 2048)
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
torch.Size([1, 1, 2048])
(451, 1, 1, 2048)


In [186]:
(torch.Tensor(train_data['img_feat'][0])).shape

torch.Size([1, 1, 2048])

In [187]:
torch.cuda.set_device(0)
print(train_data['answer'].shape)

train_data_tensor = TensorDataset(torch.Tensor(train_data['img_feat']),
                                  torch.Tensor(train_data['ques_feat']),
                                  torch.Tensor(train_data['answer']),
                                  torch.Tensor(train_data['ans_len']))
val_data_tensor = TensorDataset(torch.Tensor(val_data['img_feat']),
                                  torch.Tensor(val_data['ques_feat']),
                                  torch.Tensor(val_data['answer']),
                                  torch.Tensor(val_data['ans_len']))

train_loader = DataLoader(train_data_tensor, batch_size = 32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_data_tensor, batch_size = 32, shuffle=True, drop_last=True)

(1793, 24)


In [188]:
for img, q, a, a_len in train_loader:
    print(q.shape)
    break

torch.Size([32, 1, 768])


In [189]:
class Encoder(nn.Module):
    """
    Encoder.
    """

    def __init__(self, embedding_size, LSTM_units, LSTM_layers, feat_size,
                 batch_size, global_avg_pool_size,
                 dropout = 0.3, mfb_output_dim = 5000):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.mfb_output_dim = mfb_output_dim
        self.feat_size = feat_size
        self.mfb_out = 1000
        self.mfb_factor = 5
        self.channel_size = global_avg_pool_size
        self.num_ques_glimpse = 2
        self.num_img_glimpse = 2

        self.LSTM = nn.LSTM(input_size=embedding_size, hidden_size=LSTM_units,
                            num_layers=LSTM_layers, batch_first=False)
        self.Dropout = nn.Dropout(p=dropout, )
        self.Softmax = nn.Softmax()

        self.Linear1_q_proj = nn.Linear(LSTM_units*self.num_ques_glimpse, self.mfb_output_dim)
        self.Conv_i_proj = nn.Conv2d(self.feat_size, self.mfb_output_dim, 1)

        self.Dropout_L = nn.Dropout(p=0.2)
        self.Dropout_M = nn.Dropout(p=0.2)
        self.Conv1_Qatt = nn.Conv2d(LSTM_units, 512, 1)
        self.Conv2_Qatt = nn.Conv2d(512, self.num_ques_glimpse, 1)
        self.Conv1_Iatt = nn.Conv2d(1000, 512, 1)
        self.Conv2_Iatt = nn.Conv2d(512, self.num_img_glimpse, 1)

        self.qatt_maps = None
        self.iatt_maps = None

    def forward(self, ques_embed, img_feat):

        # preparing image features
#         print(img_feat.shape)
        img_feat_resh = img_feat.permute(0, 3, 1, 2).contiguous()         # N x w x w x C -> N x C x w x w
#         print(img_feat_resh.shape)
        #print(img_feat_resh.shape)
        img_feat_resh = img_feat_resh.reshape(img_feat_resh.shape[0], img_feat_resh.shape[1],
                                              self.channel_size*self.channel_size)      # N x C x w*w

        # ques_embed                                         N x T x embedding_size
#         print(ques_embed.shape)
        ques_embed_resh = ques_embed.permute(1, 0, 2).contiguous()        #T x N x embedding_size
#         print(ques_embed_resh.shape)
        lstm_out, (hn, cn) = self.LSTM(ques_embed_resh)
        lstm1_droped = self.Dropout_L(lstm_out)
        lstm1_resh = lstm1_droped.permute(1, 2, 0).contiguous()                   # N x 1024 x T
        lstm1_resh2 = torch.unsqueeze(lstm1_resh, 3)                # N x 1024 x T x 1

        '''
        Question Attention
        '''
        qatt_conv1 = self.Conv1_Qatt(lstm1_resh2)                   # N x 512 x T x 1
        qatt_relu = F.relu(qatt_conv1)
        qatt_conv2 = self.Conv2_Qatt(qatt_relu)                     # N x 2 x T x 1
#        print(qatt_conv2.shape)
        qatt_conv2 = qatt_conv2.reshape(qatt_conv2.shape[0]*self.num_ques_glimpse,-1)
        qatt_softmax = self.Softmax(qatt_conv2)
        qatt_softmax = qatt_softmax.view(qatt_conv1.shape[0], self.num_ques_glimpse, -1, 1)
        self.qatt_maps = qatt_softmax
        qatt_feature_list = []
        for i in range(self.num_ques_glimpse):
            t_qatt_mask = qatt_softmax.narrow(1, i, 1)              # N x 1 x T x 1
#            print(t_qatt_mask.shape, lstm1_resh2.shape)
            t_qatt_mask = t_qatt_mask * lstm1_resh2                 # N x 1024 x T x 1
            t_qatt_mask = torch.sum(t_qatt_mask, 2, keepdim=True)   # N x 1024 x 1 x 1
            qatt_feature_list.append(t_qatt_mask)
        qatt_feature_concat = torch.cat(qatt_feature_list, 1)       # N x 2048 x 1 x 1

        '''
        Image Attention with MFB
        '''
        q_feat_resh = torch.squeeze(qatt_feature_concat)
                                  # N x 2048
        i_feat_resh = torch.unsqueeze(img_feat_resh, 3)                                   # N x 2048 x w*w x 1
        iatt_q_proj = self.Linear1_q_proj(q_feat_resh)
                                   # N x 5000
        iatt_q_resh = iatt_q_proj.view(iatt_q_proj.shape[0], self.mfb_output_dim, 1, 1)      # N x 5000 x 1 x 1

        iatt_i_conv = self.Conv_i_proj(i_feat_resh)                                     # N x 5000 x w*w x 1
        iatt_iq_eltwise = iatt_q_resh * iatt_i_conv
        iatt_iq_droped = self.Dropout_M(iatt_iq_eltwise)                                # N x 5000 x w*w x 1
        iatt_iq_permute1 = iatt_iq_droped.permute(0,2,1,3).contiguous()                              # N x w*w x 5000 x 1
        iatt_iq_resh = iatt_iq_permute1.view(iatt_iq_permute1.shape[0], self.channel_size*self.channel_size,
                                             self.mfb_out, self.mfb_factor)
        iatt_iq_sumpool = torch.sum(iatt_iq_resh, 3, keepdim=True)                      # N x w*w x 1000 x 1
        iatt_iq_permute2 = iatt_iq_sumpool.permute(0,2,1,3).contiguous()                            # N x 1000 x w*w x 1
        iatt_iq_sqrt = torch.sqrt(F.relu(iatt_iq_permute2)) - torch.sqrt(F.relu(-iatt_iq_permute2))
        iatt_iq_sqrt = torch.squeeze(iatt_iq_sqrt)
        iatt_iq_sqrt = iatt_iq_sqrt.reshape(iatt_iq_sqrt.shape[0], -1)                           # N x 1000*w*w
        iatt_iq_l2 = F.normalize(iatt_iq_sqrt)
        iatt_iq_l2 = iatt_iq_l2.view(iatt_iq_l2.shape[0], self.mfb_out, self.channel_size*self.channel_size, 1)  # N x 1000 x w*w x 1

        ## 2 conv layers 1000 -> 512 -> 2
        iatt_conv1 = self.Conv1_Iatt(iatt_iq_l2)                    # N x 512 x w*w x 1
        iatt_relu = F.relu(iatt_conv1)
        iatt_conv2 = self.Conv2_Iatt(iatt_relu)                     # N x 2 x w*w x 1
        iatt_conv2 = iatt_conv2.view(iatt_conv2.shape[0]*self.num_img_glimpse, -1)
        iatt_softmax = self.Softmax(iatt_conv2)
        iatt_softmax = iatt_softmax.view(iatt_conv1.shape[0], self.num_img_glimpse, -1, 1)
        self.iatt_maps = iatt_softmax.view(iatt_conv1.shape[0], self.num_img_glimpse, self.channel_size, self.channel_size)
        iatt_feature_list = []
        for i in range(self.num_img_glimpse):
            t_iatt_mask = iatt_softmax.narrow(1, i, 1)              # N x 1 x w*w x 1
            t_iatt_mask = t_iatt_mask * i_feat_resh                 # N x 2048 x w*w x 1
            iatt_feature_list.append(t_iatt_mask)
#         print(t_iatt_mask.shape)
        iatt_feature_concat = torch.mean(torch.stack(iatt_feature_list), dim=0)       # N x 2048 x w*w x 1
#         print(iatt_feature_concat.shape)
        iatt_feature_resh = iatt_feature_concat.view(iatt_iq_permute1.shape[0], self.channel_size,
                                                        self.channel_size, 2048)           # N x w x w x 2048

        return iatt_feature_resh

In [190]:
class Attention(nn.Module):
    """
    Attention Network.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of encoded images
        :param decoder_dim: size of decoder's RNN
        :param attention_dim: size of the attention network
        """
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # linear layer to transform encoded image
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim, 1)  # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

    def forward(self, encoder_out, decoder_hidden):
        """
        Forward propagation.

        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: attention weighted encoding, weights
        """
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)

        return attention_weighted_encoding, alpha

In [191]:
class DecoderWithAttention(nn.Module):
    """
    Decoder.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size
        :param decoder_dim: size of decoder's RNN
        :param vocab_size: size of vocabulary
        :param encoder_dim: feature size of encoded images
        :param dropout: dropout
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        """
        Initializes some parameters with values from the uniform distribution, for easier convergence.
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.

        :param embeddings: pre-trained embeddings
        """
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """
        Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).

        :param fine_tune: Allow?
        """
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.

        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :return: hidden state, cell state
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.

        :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """

        device = "cuda" if torch.cuda.is_available() else "cpu"

        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort input data by decreasing lengths; why? apparent below
        caption_lengths, sort_ind = caption_lengths.sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors to hold word predicion scores and alphas
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)

        # At each time-step, decode by
        # attention-weighing the encoder's output based on the decoder's previous hidden state output
        # then generate a new word in the decoder with the previous word and the attention weighted encoding
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [192]:
class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [193]:
def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.
    :param scores: scores from the model
    :param targets: true labels
    :param k: k in top-k accuracy
    :return: top-k accuracy
    """

    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / batch_size)

In [194]:
encoder = Encoder(768, 1024, 2, 2048, 32, 1)
decoder = DecoderWithAttention(1024, 1024, 1024, len(word_to_num))

In [196]:
lr=0.001
n_epochs = 30
print_every = 100

criterion = nn.CrossEntropyLoss()

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=encoder_lr)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=decoder_lr)
#    train_loss = np.zeros(n_epochs+1)
val_loss = []
val_acc = []
train_loss = []
train_acc = []

losses = AverageMeter()
top5accs = AverageMeter()

max_training_acc = 0
max_val_acc = 0

train_on_gpu = torch.cuda.is_available()
if(train_on_gpu):
    encoder.cuda()
    decoder.cuda()

for e in range(n_epochs):
    print('Epoch - ', e)
    running_acc = 0
    running_loss = 0
    counter = 0
    encoder.train()
    decoder.train()
    for img, ques, ans, ans_len in train_loader:
        counter += 1

        img = img.cuda()
        ques = ques.float().cuda()
        ans = ans.cuda().long()
        ans_len = ans_len.cuda().long()

        encoder_output = encoder( ques, img)

        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_output, ans, ans_len)

        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence is an easy trick to do this

        scores = (pack_padded_sequence(scores, decode_lengths, batch_first=True)).data
        targets = (pack_padded_sequence(targets, decode_lengths, batch_first=True)).data

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        encoder_optimizer.zero_grad()
        loss.backward()

        # Update weights
        decoder_optimizer.step()
        encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))

        running_loss += loss.item()

        if counter % print_every == 0:
            print('batch no.:', (counter*100)/len(train_loader),
                  ' loss:', loss.item())

Epoch -  0
torch.Size([32, 2048, 1, 1])


  qatt_softmax = self.Softmax(qatt_conv2)
  iatt_softmax = self.Softmax(iatt_conv2)


torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32

In [198]:
encoder.eval()
decoder.eval()

val_running_loss = 0
val_running_acc = 0
for img, ques, ans, ans_len in val_loader:
    counter += 1

    img = img.cuda()
    ques = ques.float().cuda()
    ans = ans.cuda().long()
    ans_len = ans_len.cuda().long()

    encoder_output = encoder(ques, img)

    scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_output, ans, ans_len)

    targets = caps_sorted[:, 1:]

    # Remove timesteps that we didn't decode at, or are pads
    # pack_padded_sequence is an easy trick to do this
    scores = (pack_padded_sequence(scores, decode_lengths, batch_first=True)).data
    targets = (pack_padded_sequence(targets, decode_lengths, batch_first=True)).data

    # Calculate loss
    loss = criterion(scores, targets)

    # Add doubly stochastic attention regularization
    loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

torch.Size([32, 2048, 1, 1])


  qatt_softmax = self.Softmax(qatt_conv2)
  iatt_softmax = self.Softmax(iatt_conv2)


torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])
torch.Size([32, 2048, 1, 1])


In [199]:
torch.save(encoder.state_dict(),"Pipeline_Unet_Encoder.pth")
torch.save(decoder.state_dict(),"Pipeline_Unet_Decoder.pth")


In [208]:
def answer_generation_output(encoder,decoder,img,ques,max_answer_length):
    img = img.cuda()
    ques = ques.float().cuda()
    encoder_output = encoder(ques, img)

    start_token=word_to_num['<start>']
    pad_token=word_to_num['<pad>']
    answer_begining=torch.tensor(np.array([pad_token,start_token]))
    decoder_input = torch.tile(answer_begining, (img.shape[0], 1)).cuda()
    decoder_input_intial=decoder_input
    max_length = max_answer_length  # Setting desired maximum length

    # Decode sequentially
    for _ in range(max_length):

        decoder_length=torch.tensor([decoder_input.shape[1],decoder_input.shape[1]])
        predictions, caps_sorted, decode_lengths, alphas, sort_ind = decoder(encoder_output, decoder_input, decoder_length.cuda().long())  # [1] is a length of 1
        
        predicted_index = torch.argmax(predictions, dim=2)
        #final_pred = (pack_padded_sequence(predictions, decode_lengths, batch_first=True)).data.unsqueeze(1)
        #final_pred=torch.argmax(final_pred, dim=2)
        
        decoder_input = torch.cat([decoder_input_intial, predicted_index], dim=1)
     
    generated_answer=[]
    for answer in predicted_index:
            a=''
            for i in answer:
                if num_to_word[i.item()]!='<end>':
                    a+=num_to_word[i.item()]+" "
            generated_answer.append(a)


    return generated_answer #answer_result

In [209]:
val_loader1 = DataLoader(val_data_tensor, batch_size = 32, shuffle=True)
for img, ques, ans, ans_len in val_loader1:
  generated_answer=answer_generation_output(encoder,decoder,img,ques,max_answer_length=18)
  ground_truth=[]
  for answer in ans:
    a=''
    for i in answer:
        if num_to_word[i.item()]!='<end>' and num_to_word[i.item()]!='<pad>' and num_to_word[i.item()]!='<start>':
            a+=num_to_word[i.item()]+" "
    ground_truth.append(a)
  break

for i in range(20):

  print("Generated",generated_answer[i])
  print("True", ground_truth[i])

torch.Size([32, 2048, 1, 1])


  qatt_softmax = self.Softmax(qatt_conv2)
  iatt_softmax = self.Softmax(iatt_conv2)


Generated yes 
True yes 
Generated yes 
True lateral and third ventricular hydrocephalus 
Generated <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 
True no 
Generated <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 
True right hemisphere 
Generated <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 
True yes 
Generated <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 
True yes 
Generated <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 
True fat 
Generated <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 
True right lung 
Generated <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>