In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import import_ipynb
from utils.dynamic_rnn import DynamicRNN

importing Jupyter notebook from /home/abc/Desktop/VisDial3/visdial/utils/dynamic_rnn.ipynb


In [3]:
class LateFusionEncoder(nn.Module):
    """Late-fusion encoder for visual dialog with question-guided image attention.

    Every round of every dialog is encoded independently:

    * the question and the dialog history are embedded with a shared word
      embedding and run through their own LSTMs (wrapped in ``DynamicRNN``);
    * the image proposal features are pooled with attention weights computed
      from the question encoding;
    * the attended image, question and history vectors are concatenated and
      projected to ``lstm_hidden_size`` with a ``tanh`` non-linearity.

    Args:
        config: dict with keys ``"word_embedding_size"``, ``"lstm_hidden_size"``,
            ``"lstm_num_layers"``, ``"dropout"`` and ``"img_feature_size"``.
        vocabulary: object supporting ``len()`` and exposing ``PAD_INDEX``.
    """

    def __init__(self, config, vocabulary):
        super().__init__()
        self.config = config

        # Shared word embedding for question and dialog-history tokens.
        self.word_embed = nn.Embedding(
            len(vocabulary),
            config["word_embedding_size"],
            padding_idx=vocabulary.PAD_INDEX,
        )

        # Separate LSTMs for the history and the question, each wrapped in
        # DynamicRNN (presumably handles padded variable-length input — see
        # utils/dynamic_rnn; TODO confirm). forward() relies on the wrapper
        # returning (outputs, (h_n, c_n)) with h_n shaped (N, lstm_hidden_size).
        # Creation order (history first) matches the original so parameter
        # initialisation and state_dict keys are unchanged.
        self.hist_rnn = DynamicRNN(self._build_lstm(config))
        self.ques_rnn = DynamicRNN(self._build_lstm(config))

        # Regularisation on the question*image product and on the fused vector.
        self.dropout = nn.Dropout(p=config["dropout"])

        # Project every image proposal vector (img_feature_size) to
        # lstm_hidden_size so it can be multiplied element-wise with the
        # question encoding.
        self.image_features_projection = nn.Linear(
            config["img_feature_size"], config["lstm_hidden_size"]
        )

        # One score per (question * proposal) vector; a softmax over proposals
        # turns the scores into attention weights.
        self.attention_proj = nn.Linear(config["lstm_hidden_size"], 1)

        # Fusion input is [attended image | question | history], i.e.
        # img_feature_size + 2 * lstm_hidden_size features.
        fusion_size = config["img_feature_size"] + config["lstm_hidden_size"] * 2
        self.fusion = nn.Linear(fusion_size, config["lstm_hidden_size"])

        nn.init.kaiming_uniform_(self.image_features_projection.weight)
        nn.init.constant_(self.image_features_projection.bias, 0)
        nn.init.kaiming_uniform_(self.fusion.weight)
        nn.init.constant_(self.fusion.bias, 0)

    @staticmethod
    def _build_lstm(config):
        """Create a batch-first LSTM from word embeddings to lstm_hidden_size."""
        return nn.LSTM(
            config["word_embedding_size"],
            config["lstm_hidden_size"],
            config["lstm_num_layers"],
            batch_first=True,
            dropout=config["dropout"],
        )

    def _encode_tokens(self, rnn, tokens, lengths):
        """Embed 2-D token ids and return the final hidden state produced by ``rnn``."""
        _, (final_hidden, _) = rnn(self.word_embed(tokens), lengths)
        return final_hidden

    def _attend_image(self, img, ques_embed, batch_size, num_rounds):
        """Pool image proposals with question-conditioned attention.

        Args:
            img: image features viewed as (batch_size, num_proposals, img_feature_size).
            ques_embed: (batch_size * num_rounds, lstm_hidden_size) question encodings.

        Returns:
            (batch_size * num_rounds, img_feature_size) attended image features.
        """
        hidden_size = self.config["lstm_hidden_size"]
        img_feature_size = self.config["img_feature_size"]
        num_rows = batch_size * num_rounds

        # Each round of an example attends over the same image, so the
        # per-example features are replicated across rounds:
        # (batch_size, P, H) -> (batch_size * num_rounds, P, H).
        projected_image_features = (
            self.image_features_projection(img)
            .view(batch_size, 1, -1, hidden_size)
            .expand(-1, num_rounds, -1, -1)
            .reshape(num_rows, -1, hidden_size)
        )

        # Broadcast the question over proposals instead of materialising a copy.
        ques_image = self.dropout(ques_embed.unsqueeze(1) * projected_image_features)

        # squeeze(-1) only: a bare squeeze() would also drop the proposal axis
        # when num_proposals == 1 and the softmax would then normalise across
        # rows instead of proposals.
        attention_weights = F.softmax(
            self.attention_proj(ques_image).squeeze(-1), dim=-1
        )

        # Raw (unprojected) features replicated across rounds:
        # (batch_size * num_rounds, P, img_feature_size).
        img_per_round = (
            img.view(batch_size, 1, -1, img_feature_size)
            .expand(-1, num_rounds, -1, -1)
            .reshape(num_rows, -1, img_feature_size)
        )

        # Weighted sum over proposals; weights broadcast over the feature axis.
        return (attention_weights.unsqueeze(-1) * img_per_round).sum(1)

    def forward(self, batch):
        """Encode every round of every dialog in ``batch``.

        Args:
            batch: dict with
                ``"img_feat"`` – image features (batch_size, num_proposals, img_feature_size);
                ``"ques"`` – question token ids (batch_size, num_rounds, max_sequence_length);
                ``"hist"`` – history token ids, flattened here to one row per
                round (any per-round length; previously it had to be exactly
                ``max_sequence_length * 20``);
                ``"ques_len"`` / ``"hist_len"`` – lengths forwarded unchanged to
                ``DynamicRNN`` (expected layout defined there — TODO confirm).

        Returns:
            (batch_size, num_rounds, lstm_hidden_size) fused embeddings in [-1, 1].
        """
        img = batch["img_feat"]
        ques = batch["ques"]
        hist = batch["hist"]

        batch_size, num_rounds, max_sequence_length = ques.size()
        num_rows = batch_size * num_rounds

        # One row per (example, round): (batch_size * num_rounds, lstm_hidden_size).
        ques_embed = self._encode_tokens(
            self.ques_rnn,
            ques.reshape(num_rows, max_sequence_length),
            batch["ques_len"],
        )

        attended_img = self._attend_image(img, ques_embed, batch_size, num_rounds)

        # -1 infers the history length rather than hard-coding it.
        hist_embed = self._encode_tokens(
            self.hist_rnn,
            hist.reshape(num_rows, -1),
            batch["hist_len"],
        )

        fused_vector = self.dropout(
            torch.cat((attended_img, ques_embed, hist_embed), 1)
        )
        fused_embedding = torch.tanh(self.fusion(fused_vector))
        return fused_embedding.view(batch_size, num_rounds, -1)