In [3]:
import torch
from torch import nn

import import_ipynb
from utils.dynamic_rnn import DynamicRNN

In [2]:
class DiscriminativeDecoder(nn.Module):
    """Score a fixed set of candidate answer options against an encoder output.

    Each option is embedded, encoded by an LSTM, and scored by the dot product
    of its final hidden state with the corresponding encoder vector.
    """

    def __init__(self, config, vocabulary):
        """Build the option embedding and option encoder.

        Parameters
        ----------
        config : dict
            Reads ``"word_embedding_size"``, ``"lstm_hidden_size"``,
            ``"lstm_num_layers"`` and ``"dropout"``.
        vocabulary
            Needs ``len(vocabulary)`` and ``vocabulary.PAD_INDEX``.
        """
        super().__init__()
        self.config = config

        # Embeds the tokens of the answer options.
        self.word_embed = nn.Embedding(
            len(vocabulary),
            config["word_embedding_size"],
            padding_idx=vocabulary.PAD_INDEX,
        )

        # Encodes each answer option into a single hidden vector.
        self.option_rnn = nn.LSTM(
            config["word_embedding_size"],
            config["lstm_hidden_size"],
            config["lstm_num_layers"],
            batch_first=True,
            dropout=config["dropout"],
        )

        # Options are padded sequences of different lengths; DynamicRNN takes
        # the true lengths so padding does not leak into the final state.
        # NOTE: both attributes share the same LSTM; the raw one is kept so
        # existing checkpoints (state_dict keys) still load.
        self.options_rnn = DynamicRNN(self.option_rnn)

    def forward(self, encoder_output, batch):
        """Predict a score for every candidate answer option.

        Parameters
        ----------
        encoder_output : torch.Tensor
            Either one vector per round, i.e. ``batch_size * num_rounds *
            lstm_hidden_size`` elements (e.g. shape ``(batch_size, num_rounds,
            lstm_hidden_size)``), which is shared by all options of that round;
            or one vector per option, i.e. ``batch_size * num_rounds *
            num_options * lstm_hidden_size`` elements.
        batch : dict
            ``"opt"``: ``(batch_size, num_rounds, num_options,
            max_sequence_length)`` tensor of option token ids.
            ``"opt_len"``: tensor with one length per option
            (``batch_size * num_rounds * num_options`` elements).

        Returns
        -------
        torch.Tensor
            ``(batch_size, num_rounds, num_options)`` dot-product scores.
            Options of length zero keep an all-zero embedding and score 0.
        """
        options = batch["opt"]
        batch_size, num_rounds, num_options, max_sequence_length = options.size()
        total_options = batch_size * num_rounds * num_options
        hidden_size = self.config["lstm_hidden_size"]

        # Flatten to one row per option so all options are encoded together.
        options = options.reshape(total_options, max_sequence_length)
        options_length = batch["opt_len"].reshape(total_options)

        # Only non-empty options go through the RNN (presumably DynamicRNN
        # packs the sequences, and packing rejects zero lengths). squeeze(1)
        # instead of squeeze() keeps a 1-D index even when exactly one option
        # is non-empty.
        nonzero_options_length_indices = options_length.nonzero().squeeze(1)

        options_embed = encoder_output.new_zeros(total_options, hidden_size)
        if nonzero_options_length_indices.numel() > 0:
            nonzero_options_length = options_length[nonzero_options_length_indices]
            nonzero_options = options[nonzero_options_length_indices]

            nonzero_options_embed = self.word_embed(nonzero_options)

            # Use the DynamicRNN wrapper, which takes the lengths; the raw
            # nn.LSTM would misread them as an initial (h_0, c_0) state.
            # Assumes the wrapper returns the last layer's final hidden state
            # as (num_nonzero, lstm_hidden_size) -- TODO confirm in DynamicRNN.
            _, (nonzero_options_embed, _) = self.options_rnn(
                nonzero_options_embed, nonzero_options_length
            )

            # Put the encoded options back at their original positions.
            options_embed[nonzero_options_length_indices] = nonzero_options_embed

        # One encoder vector per round: repeat it for every option of the round.
        if encoder_output.numel() == batch_size * num_rounds * hidden_size:
            encoder_output = encoder_output.reshape(
                batch_size, num_rounds, 1, hidden_size
            ).expand(-1, -1, num_options, -1)
        encoder_output = encoder_output.reshape(total_options, hidden_size)

        # Dot-product similarity between each option and its encoder vector.
        scores = torch.sum(options_embed * encoder_output, 1)
        return scores.view(batch_size, num_rounds, num_options)


