In [None]:
from typing import Tuple, List

import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

In [None]:
class Encoder(nn.Module):
    """Embedding -> width-2 convolution -> bidirectional LSTM encoder.

    Besides the per-position encoder states, ``forward`` returns an initial
    decoder state obtained by concatenating the final forward/backward LSTM
    states and projecting them down to ``hidden_size``.

    Args:
        input_size: Number of rows in the source embedding table.
        embed_size: Embedding dimension (also the conv channel count).
        hidden_size: Hidden size of each LSTM direction.
        dropout_rate: Probability used for ``self.dropout``.
    """

    def __init__(self, input_size, embed_size=1024, hidden_size=768, dropout_rate=0.2):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embedding_dim=embed_size)
        self.post_embed_cnn = nn.Conv1d(embed_size, embed_size, kernel_size=2, padding="same")
        self.encoder = nn.LSTM(embed_size, hidden_size, bidirectional=True)
        self.h_projection = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.c_projection = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        # NOTE(review): registered but never applied inside forward().
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src: Tensor, source_lengths: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Encode a padded batch of source token ids.

        Args:
            src: Token ids, shape (batch_size, src_len).
            source_lengths: True length of every sentence in ``src``; a 1-D
                tensor or a plain list of ints is accepted.

        Returns:
            A tuple ``(enc_hidden, (init_hidden, init_cell))`` where
            ``enc_hidden`` has shape (batch_size, max(source_lengths), 2*hidden_size)
            and is zero at padded positions, and both initial decoder states
            have shape (batch_size, hidden_size).
        """
        lengths = torch.as_tensor(source_lengths, dtype=torch.long)

        x = self.embedding(src)  # (batch_size, src_len, embed_size)

        # The width-2 convolution mixes neighbouring positions, so the embedding
        # of a padding token could otherwise influence the last real token and
        # make a sentence's encoding depend on what it was padded with. Zeroing
        # the padded positions makes them look like the conv's own zero padding.
        positions = torch.arange(src.size(1), device=x.device)
        is_pad = positions.unsqueeze(0) >= lengths.to(x.device).unsqueeze(1)  # (batch_size, src_len)
        x = x.masked_fill(is_pad.unsqueeze(-1), 0.0)

        x = self.post_embed_cnn(x.permute(0, 2, 1))  # (batch_size, embed_size, src_len)
        packed = pack_padded_sequence(
            x.permute(0, 2, 1),  # back to (batch_size, src_len, embed_size)
            lengths.cpu(),  # pack_padded_sequence needs the lengths on the CPU
            batch_first=True,
            enforce_sorted=False,
        )
        packed_hidden, (last_hidden, last_cell) = self.encoder(packed)
        # last_hidden / last_cell: (2, batch_size, hidden_size), index 0 = forward, 1 = backward
        enc_hidden, _ = pad_packed_sequence(packed_hidden, batch_first=True)
        # enc_hidden: (batch_size, max(source_lengths), 2*hidden_size)

        init_decoder_hidden = self.h_projection(
            torch.cat((last_hidden[0], last_hidden[1]), dim=1)  # (batch_size, 2*hidden_size)
        )  # (batch_size, hidden_size)
        init_decoder_cell = self.c_projection(
            torch.cat((last_cell[0], last_cell[1]), dim=1)  # (batch_size, 2*hidden_size)
        )  # (batch_size, hidden_size)
        return enc_hidden, (init_decoder_hidden, init_decoder_cell)


class Decoder(nn.Module):
    """LSTMCell decoder with multiplicative attention over the encoder states.

    At every target step the previous combined output ``o_{t-1}`` is fed back
    in alongside the current target embedding (input feeding).

    Args:
        target_size: Target vocabulary size (embedding rows and output logits).
        embed_size: Target embedding dimension.
        hidden_size: Decoder hidden size; encoder states are expected to be
            ``2 * hidden_size`` wide.
        dropout_rate: Dropout probability applied to each combined output.
    """

    def __init__(self, target_size, embed_size=1024, hidden_size=768, dropout_rate=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(target_size, embedding_dim=embed_size)
        self.decoder = nn.LSTMCell(embed_size + hidden_size, hidden_size)
        self.att_projection = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.combined_output_projection = nn.Linear(3 * hidden_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.target_vocab_projection = nn.Linear(hidden_size, target_size, bias=False)

    def forward(
        self,
        enc_hidden: torch.Tensor,
        enc_masks: torch.Tensor,
        dec_init_state: Tuple[torch.Tensor, torch.Tensor],
        target_padded: torch.Tensor,
    ) -> torch.Tensor:
        """Run the decoder over a whole padded target batch (teacher forcing).

        Args:
            enc_hidden: Encoder states, shape (batch_size, src_len, 2*hidden_size).
            enc_masks: Shape (batch_size, src_len); non-zero marks source
                padding positions. May be ``None`` to disable masking.
            dec_init_state: Initial ``(hidden, cell)``, each (batch_size, hidden_size).
            target_padded: Target token ids, shape (batch_size, tgt_len).

        Returns:
            Vocabulary logits of shape (batch_size, tgt_len - 1, target_size);
            the last target position is never fed as an input.
        """
        # Chop off the <END> token for max length sentences.
        target_padded = target_padded[:, :-1]
        dec_state = dec_init_state

        # Convert the mask once so every step can reuse it without a copy.
        if enc_masks is not None:
            enc_masks = enc_masks.to(device=enc_hidden.device, dtype=torch.bool)

        # Previous combined output o_{t-1} starts as zeros.
        batch_size = enc_hidden.size(0)
        o_prev = torch.zeros(
            batch_size, self.hidden_size, device=enc_hidden.device, dtype=enc_hidden.dtype
        )  # (batch_size, hidden_size)

        combined_outputs = []
        # The attention projection does not depend on the step, so compute it once.
        enc_hidden_proj = self.att_projection(enc_hidden)  # (batch_size, src_len, hidden_size)
        y = self.embedding(target_padded)  # (batch_size, tgt_len - 1, embed_size)
        for y_t in torch.split(y, split_size_or_sections=1, dim=1):
            y_t = y_t.squeeze(dim=1)  # (batch_size, embed_size)
            ybar_t = torch.cat((y_t, o_prev), dim=1)  # (batch_size, embed_size+hidden_size)
            o_t, dec_state = self.step(
                ybar_t=ybar_t,
                dec_state=dec_state,
                enc_hidden=enc_hidden,
                enc_hidden_proj=enc_hidden_proj,
                enc_masks=enc_masks,
            )
            combined_outputs.append(o_t)
            o_prev = o_t
        # Stack directly along the time axis: (batch_size, tgt_len - 1, hidden_size)
        combined_outputs = torch.stack(combined_outputs, dim=1)
        return self.target_vocab_projection(combined_outputs)  # (batch_size, tgt_len - 1, target_size)

    def step(
        self,
        ybar_t: torch.Tensor,
        dec_state: Tuple[torch.Tensor, torch.Tensor],
        enc_hidden: torch.Tensor,
        enc_hidden_proj: torch.Tensor,
        enc_masks: torch.Tensor,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Advance the decoder by one target position.

        Args:
            ybar_t: Step input, shape (batch_size, embed_size + hidden_size).
            dec_state: Previous ``(hidden, cell)`` of the LSTM cell.
            enc_hidden: Encoder states, (batch_size, src_len, 2*hidden_size).
            enc_hidden_proj: ``att_projection(enc_hidden)``,
                (batch_size, src_len, hidden_size).
            enc_masks: (batch_size, src_len), non-zero at source padding, or ``None``.

        Returns:
            ``(o_t, dec_state)``: the combined output of shape
            (batch_size, hidden_size) and the updated LSTM state.
        """
        dec_state = self.decoder(ybar_t, dec_state)
        dec_hidden, _ = dec_state  # (batch_size, hidden_size)
        # Attention scores: dot product of the decoder state with every projected encoder state.
        e_t = torch.einsum("bsh,bh->bs", enc_hidden_proj, dec_hidden)  # (batch_size, src_len)

        # Padding positions get -inf so softmax assigns them zero weight.
        # Out-of-place masked_fill keeps the operation visible to autograd
        # (the legacy `.data` in-place write bypassed it).
        if enc_masks is not None:
            pad_positions = enc_masks.to(device=e_t.device, dtype=torch.bool)
            e_t = e_t.masked_fill(pad_positions, float("-inf"))

        alpha_t = torch.softmax(e_t, dim=1)  # (batch_size, src_len)
        # Attention-weighted sum of the encoder states.
        a_t = torch.einsum("bs,bsh->bh", alpha_t, enc_hidden)  # (batch_size, 2*hidden_size)
        u_t = torch.cat((dec_hidden, a_t), dim=1)  # (batch_size, 3*hidden_size)
        v_t = self.combined_output_projection(u_t)  # (batch_size, hidden_size)
        o_t = self.dropout(torch.tanh(v_t))  # (batch_size, hidden_size)
        return o_t, dec_state


def generate_sent_masks(enc_hidden: torch.Tensor, source_lengths: List[int]) -> torch.Tensor:
    """Build the source padding mask used by the attention.

    Args:
        enc_hidden: Encoder states of shape (batch_size, src_len, ...); only
            its second dimension and its device are used.
        source_lengths: True length of each source sentence (list of ints or
            a 1-D integer tensor).

    Returns:
        An int64 tensor of shape (batch_size, src_len), created on
        ``enc_hidden``'s device, holding 1 at padding positions
        (index >= sentence length) and 0 at real tokens.
    """
    device = enc_hidden.device
    # Create everything on the encoder's device so the mask can be combined
    # with attention scores without a CPU/GPU mismatch.
    positions = torch.arange(enc_hidden.size(1), device=device)  # (src_len,)
    lengths = torch.as_tensor(source_lengths, device=device).unsqueeze(1)  # (batch_size, 1)
    return (positions >= lengths).long()


class NMT(nn.Module):
    """Encoder/decoder wrapper that scores gold target sentences.

    Args:
        encoder: Module called as ``encoder(src=..., source_lengths=...)`` and
            returning ``(enc_hidden, dec_init_state)``.
        decoder: Module called with ``enc_hidden``, ``enc_masks``,
            ``dec_init_state`` and ``target_padded`` and returning vocabulary
            logits of shape (batch_size, tgt_len - 1, vocab_size).
        pad_idx: Target token id treated as padding and excluded from the
            score. Defaults to 0.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, pad_idx: int = 0) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src: Tensor, tgt: Tensor, source_lengths: List[int]) -> Tensor:
        """Compute the log-likelihood of each gold target sentence.

        Args:
            src: Source token ids, shape (batch_size, src_len).
            tgt: Target token ids, shape (batch_size, tgt_len).
            source_lengths: True length of each source sentence.

        Returns:
            Tensor of shape (batch_size,): for every sentence, the sum of the
            log-probabilities of its non-padding gold tokens ``tgt[:, 1:]``.
        """
        encoder_hidden, dec_init_state = self.encoder(src=src, source_lengths=source_lengths)
        enc_masks = generate_sent_masks(encoder_hidden, source_lengths)
        logits = self.decoder(
            enc_hidden=encoder_hidden, enc_masks=enc_masks, dec_init_state=dec_init_state, target_padded=tgt
        )  # (batch_size, tgt_len - 1, vocab_size)

        # Position t of the logits predicts target token t + 1.
        gold_words = tgt[:, 1:]  # (batch_size, tgt_len - 1)
        log_probs = F.log_softmax(logits, dim=-1)
        gold_log_probs = torch.gather(log_probs, index=gold_words.unsqueeze(-1), dim=-1).squeeze(-1)
        gold_mask = (gold_words != self.pad_idx).to(gold_log_probs.dtype)

        # Everything is batch-first, so the time axis is dim 1; summing over
        # dim 0 would add up different sentences instead of scoring each one.
        return (gold_log_probs * gold_mask).sum(dim=1)

In [None]:
# Smoke test: push one random batch through a freshly initialised model.
torch.manual_seed(0)  # fix the random batch and the weight initialisation for reproducible reruns

input_size = 100
target_size = 200
batch_size = 32
src_len = 10
src = torch.randint(0, input_size, (batch_size, src_len))
# Target ids are random too; any 0 drawn here is treated as padding by NMT.forward.
tgt = torch.randint(0, target_size, (batch_size, src_len))
# randint's upper bound is exclusive: use src_len + 1 so full-length sentences can occur.
source_lengths = torch.randint(1, src_len + 1, (batch_size,)).tolist()

encoder = Encoder(input_size)
decoder = Decoder(target_size)
nmt = NMT(encoder, decoder)
scores = nmt(src, tgt, source_lengths)
scores.shape