In [None]:
import math
from pathlib import Path
from typing import Tuple, List
from collections import defaultdict

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import Tensor
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from utils import read_corpus, Vocab, collate_fn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%load_ext autoreload
%autoreload 2

In [None]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder for padded source sentences.

    Token embeddings go through a width-2 1-D convolution, then a single-layer
    bidirectional LSTM over the packed (length-aware) sequence. The last forward and
    backward hidden/cell states are concatenated and projected to ``hidden_size`` to
    initialise the decoder.

    Args:
        input_size: source vocabulary size.
        embed_size: embedding (and convolution channel) size.
        hidden_size: LSTM hidden size per direction.
        dropout_rate: rate for ``self.dropout`` (currently not applied in ``forward``).
    """

    def __init__(self, input_size, embed_size=1024, hidden_size=768, dropout_rate=0.2):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embedding_dim=embed_size)
        # kernel_size=2 with padding="same" needs one step of asymmetric padding, so each
        # output position mixes two neighbouring tokens.
        self.post_embed_cnn = nn.Conv1d(embed_size, embed_size, kernel_size=2, padding="same")
        self.encoder = nn.LSTM(embed_size, hidden_size, bidirectional=True)
        # Both projections map the concatenated [forward; backward] state (2*hidden) to hidden.
        self.h_projection = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.c_projection = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        # NOTE(review): defined but never applied in forward(); kept so the attribute exists.
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src: Tensor, source_lengths: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Encode a padded batch.

        Args:
            src: (batch_size, src_len) padded source token ids.
            source_lengths: true length of each sentence, as a list of ints or a 1-D tensor
                on any device; every length must be > 0.

        Returns:
            enc_hidden: (batch_size, max(source_lengths), 2*hidden_size), zero past each length.
            dec_init_state: (h_0, c_0), each (batch_size, hidden_size).
        """
        x = self.embedding(src)  # (batch_size, src_len, embed_size)
        # Conv1d expects channels first: (batch_size, embed_size, src_len) and back again.
        x = self.post_embed_cnn(x.transpose(1, 2)).transpose(1, 2)  # (batch_size, src_len, embed_size)

        # pack_padded_sequence requires the lengths on the CPU, even when x is on a GPU.
        lengths = torch.as_tensor(source_lengths, dtype=torch.int64).cpu()
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        packed_out, (last_hidden, last_cell) = self.encoder(packed)
        # last_hidden / last_cell: (2, batch_size, hidden_size); index 0 is the forward
        # direction, index 1 the backward direction (single layer).
        enc_hidden, _ = pad_packed_sequence(packed_out, batch_first=True)
        # enc_hidden: (batch_size, max(source_lengths), 2*hidden_size)

        init_decoder_hidden = self.h_projection(torch.cat((last_hidden[0], last_hidden[1]), dim=1))
        init_decoder_cell = self.c_projection(torch.cat((last_cell[0], last_cell[1]), dim=1))
        # both: (batch_size, hidden_size)
        return enc_hidden, (init_decoder_hidden, init_decoder_cell)


class Decoder(nn.Module):
    """Attentional LSTM-cell decoder run with teacher forcing (input feeding).

    Each step feeds ``[y_t; o_{t-1}]`` to an ``LSTMCell``, scores encoder states with
    multiplicative attention, and mixes the attention context with the decoder hidden
    state into the combined output ``o_t``.

    Args:
        target_size: target vocabulary size.
        embed_size: target embedding size.
        hidden_size: decoder hidden size; encoder states must have ``2 * hidden_size``
            features (see ``att_projection``).
        dropout_rate: dropout applied to each combined output ``o_t``.
    """

    def __init__(self, target_size, embed_size=1024, hidden_size=768, dropout_rate=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(target_size, embedding_dim=embed_size)
        # Input feeding: the cell consumes the word embedding concatenated with o_{t-1}.
        self.decoder = nn.LSTMCell(embed_size + hidden_size, hidden_size)
        # Projects encoder states (2*hidden_size) into the decoder space for scoring.
        self.att_projection = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        # Maps [dec_hidden (hidden); a_t (2*hidden)] down to hidden_size.
        self.combined_output_projection = nn.Linear(3 * hidden_size, hidden_size, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.target_vocab_projection = nn.Linear(hidden_size, target_size, bias=False)

    def forward(
        self,
        enc_hidden: torch.Tensor,
        enc_masks: torch.Tensor,
        dec_init_state: Tuple[torch.Tensor, torch.Tensor],
        target_padded: torch.Tensor,
    ) -> torch.Tensor:
        """Run the decoder over a whole padded target batch.

        Args:
            enc_hidden: (batch_size, src_len, 2*hidden_size) encoder outputs.
            enc_masks: (batch_size, src_len), nonzero at source padding positions, or None.
            dec_init_state: (h_0, c_0), each (batch_size, hidden_size).
            target_padded: (batch_size, tgt_len) target ids; the last column is never fed in.

        Returns:
            (batch_size, tgt_len - 1, target_size) unnormalised vocabulary scores.
        """
        # The final target token is only ever predicted, never used as an input.
        target_padded = target_padded[:, :-1]
        dec_state = dec_init_state

        batch_size = enc_hidden.size(0)
        # o_{-1} is all zeros.
        o_prev = torch.zeros(batch_size, self.hidden_size, device=enc_hidden.device)

        # Attention keys do not depend on the step, so project them once.
        enc_hidden_proj = self.att_projection(enc_hidden)  # (batch_size, src_len, hidden_size)
        y = self.embedding(target_padded)  # (batch_size, tgt_len - 1, embed_size)

        combined_outputs = []
        for y_t in y.unbind(dim=1):  # each (batch_size, embed_size)
            ybar_t = torch.cat((y_t, o_prev), dim=1)  # (batch_size, embed_size + hidden_size)
            o_prev, dec_state = self.step(
                ybar_t=ybar_t,
                dec_state=dec_state,
                enc_hidden=enc_hidden,
                enc_hidden_proj=enc_hidden_proj,
                enc_masks=enc_masks,
            )
            combined_outputs.append(o_prev)

        if combined_outputs:
            stacked = torch.stack(combined_outputs, dim=1)  # (batch_size, tgt_len - 1, hidden_size)
        else:
            # Length-1 targets leave nothing to decode; torch.stack([]) would raise.
            stacked = o_prev.new_zeros(batch_size, 0, self.hidden_size)
        return self.target_vocab_projection(stacked)  # (batch_size, tgt_len - 1, target_size)

    def step(
        self,
        ybar_t: torch.Tensor,
        dec_state: Tuple[torch.Tensor, torch.Tensor],
        enc_hidden: torch.Tensor,
        enc_hidden_proj: torch.Tensor,
        enc_masks: torch.Tensor,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Advance the decoder one time step.

        Args:
            ybar_t: (batch_size, embed_size + hidden_size) cell input.
            dec_state: previous (h, c), each (batch_size, hidden_size).
            enc_hidden: (batch_size, src_len, 2*hidden_size) attention values.
            enc_hidden_proj: (batch_size, src_len, hidden_size) attention keys.
            enc_masks: (batch_size, src_len), nonzero at padding, or None for no masking.

        Returns:
            o_t: (batch_size, hidden_size) combined output after dropout.
            dec_state: the new (h, c).
        """
        dec_state = self.decoder(ybar_t, dec_state)
        dec_hidden, _ = dec_state  # (batch_size, hidden_size)
        e_t = torch.einsum("bsh,bh->bs", enc_hidden_proj, dec_hidden)  # (batch_size, src_len)

        if enc_masks is not None:
            # Out-of-place so autograd tracks it; padding gets zero attention weight.
            pad = enc_masks.to(device=e_t.device, dtype=torch.bool)
            e_t = e_t.masked_fill(pad, float("-inf"))

        alpha_t = torch.softmax(e_t, dim=1)  # (batch_size, src_len)
        a_t = torch.einsum("bs,bsh->bh", alpha_t, enc_hidden)  # (batch_size, 2*hidden_size)
        u_t = torch.cat((dec_hidden, a_t), dim=1)  # (batch_size, 3*hidden_size)
        v_t = self.combined_output_projection(u_t)  # (batch_size, hidden_size)
        o_t = self.dropout(torch.tanh(v_t))  # (batch_size, hidden_size)
        return o_t, dec_state


def generate_sent_masks(enc_hidden: torch.Tensor, source_lengths: List[int]) -> torch.Tensor:
    """Build the source padding mask for attention.

    Args:
        enc_hidden: encoder outputs. Only ``size(1)`` (the padded source length) and the
            device are used.
        source_lengths: true length of each source sentence (list of ints or 1-D tensor).

    Returns:
        int64 tensor of shape (len(source_lengths), enc_hidden.size(1)) on
        ``enc_hidden.device``. It is 1 at padding positions (index >= length) and 0 at
        real tokens.
    """
    # Build on enc_hidden's device so the decoder can apply the mask to GPU scores directly.
    lengths = torch.as_tensor(source_lengths, device=enc_hidden.device)
    positions = torch.arange(enc_hidden.size(1), device=enc_hidden.device)
    return (positions.unsqueeze(0) >= lengths.unsqueeze(1)).long()


class NMT(nn.Module):
    """Encoder-decoder model that scores target sentences given source sentences.

    Args:
        encoder: module called as ``encoder(src=..., source_lengths=...)`` and returning
            ``(enc_hidden, dec_init_state)``.
        decoder: module returning per-step vocabulary scores of shape
            (batch_size, tgt_len - 1, vocab_size).
        pad_idx: target token id treated as padding and excluded from the score.
            Defaults to 0; TODO confirm this matches the Vocab's pad id.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, pad_idx: int = 0) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx

    def forward(self, src: Tensor, tgt: Tensor, source_lengths: List[int]) -> Tensor:
        """Return the log-likelihood of each target sentence.

        Args:
            src: (batch_size, src_len) padded source ids.
            tgt: (batch_size, tgt_len) padded target ids. Column 0 is only an input, so
                the predicted tokens are ``tgt[:, 1:]``.
            source_lengths: true source lengths.

        Returns:
            (batch_size,) tensor holding, per sentence, the sum of
            log p(tgt[:, t] | tgt[:, :t], src) over non-pad positions t >= 1.
        """
        encoder_hidden, dec_init_state = self.encoder(src=src, source_lengths=source_lengths)
        enc_masks = generate_sent_masks(encoder_hidden, source_lengths)
        logits = self.decoder(
            enc_hidden=encoder_hidden, enc_masks=enc_masks, dec_init_state=dec_init_state, target_padded=tgt
        )  # (batch_size, tgt_len - 1, vocab_size)
        log_probs = F.log_softmax(logits, dim=-1)

        gold = tgt[:, 1:]  # (batch_size, tgt_len - 1)
        gold_log_probs = torch.gather(log_probs, index=gold.unsqueeze(-1), dim=-1).squeeze(-1)
        target_masks = (gold != self.pad_idx).to(gold_log_probs.dtype)
        # Batch-first layout: time is dim 1, so summing it gives one score per sentence.
        return (gold_log_probs * target_masks).sum(dim=1)

In [None]:
# Smoke test: push one random batch through an untrained model on the CPU to check that
# shapes line up and attention masking does not produce NaN/inf.
input_size = 100
target_size = 200
batch_size = 32
src_len = 10
src = torch.randint(0, input_size, (batch_size, src_len))
tgt = torch.randint(0, target_size, (batch_size, src_len))
# randint's upper bound is exclusive; use src_len + 1 so unpadded, full-length
# sentences are exercised too.
source_lengths = torch.randint(1, src_len + 1, (batch_size,)).tolist()

encoder = Encoder(input_size)
decoder = Decoder(target_size)
nmt = NMT(encoder, decoder)
scores = nmt(src, tgt, source_lengths)
assert torch.isfinite(scores).all(), "masked attention produced non-finite scores"

In [None]:
batch_size = 8
shuffle = True

data_path = Path("../code")
corpus_dir = data_path / "zh_en_data"


def load_parallel_corpus(src_file: str, tgt_file: str):
    """Read one split: `src_file` with the source model and `tgt_file` with the target model."""
    src_sents = read_corpus(corpus_dir / src_file, data_path / "src.model")
    tgt_sents = read_corpus(corpus_dir / tgt_file, data_path / "tgt.model")
    return src_sents, tgt_sents


# Every split translates zh -> en: Chinese is always the source, English the target.
train_data_src, train_data_tgt = load_parallel_corpus("train_debug.zh", "train_debug.en")
valid_data_src, valid_data_tgt = load_parallel_corpus("dev.zh", "dev.en")

vocab = Vocab.load(data_path / "vocab.json")
print(vocab)

# Map tokenised sentences to vocabulary ids.
train_data_indices_src = vocab.src.words2indices(train_data_src)
train_data_indices_tgt = vocab.tgt.words2indices(train_data_tgt)

valid_data_indices_src = vocab.src.words2indices(valid_data_src)
valid_data_indices_tgt = vocab.tgt.words2indices(valid_data_tgt)

In [None]:
train_data = list(zip(train_data_indices_src, train_data_indices_tgt))
valid_data = list(zip(valid_data_indices_src, valid_data_indices_tgt))

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
# Validation order carries no training signal; keep it fixed so evaluation runs are comparable.
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Peek at one training batch to sanity-check padding and detokenisation.
src_sentences, tgt_sentences, _ = next(iter(train_loader))
print(f"src shape: {src_sentences.shape}, tgt shape: {tgt_sentences.shape}")

print("src and tgt sentences:")
for src_sent, tgt_sent in zip(src_sentences, tgt_sentences):
    src_sent = "".join(vocab.src.indices2words(src_sent.tolist()))
    tgt_sent = "".join(vocab.tgt.indices2words(tgt_sent.tolist()))
    print("-" * 100)
    print(src_sent.replace("<pad>", "").replace("▁", " "))
    print(tgt_sent.replace("<pad>", "").replace("▁", " ").replace("</s>", "").replace("<s>", ""))
    print("-" * 100)
    print()

In [None]:
def train_epoch(dataloader, model, optimizer):
    """Run one optimisation pass over ``dataloader``; return the mean per-batch loss.

    Each batch is ``(src, tgt, src_lengths)``. The loss is the negated sum of the
    per-sentence scores returned by ``model(src, tgt, src_lengths)``.

    Returns:
        Mean loss over batches, or 0.0 if the dataloader yields no batches.
    """
    # Inputs must live where the model's parameters live (GPU when `.to(device)` was used).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()

    total_loss = 0.0
    num_batches = 0
    for input_tensor, target_tensor, input_lengths in dataloader:
        # Lengths are left alone: pack_padded_sequence needs them on the CPU.
        input_tensor = input_tensor.to(target_device)
        target_tensor = target_tensor.to(target_device)

        optimizer.zero_grad()
        loss = -model(input_tensor, target_tensor, input_lengths).sum()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0


def train(train_dataloader, model, n_epochs, learning_rate=0.001):
    """Train ``model`` with Adam for ``n_epochs`` epochs, printing progress each epoch.

    Returns:
        ``defaultdict(list)`` with per-epoch ``"loss"`` (the value from ``train_epoch``)
        and ``"moving_loss"`` (exponential moving average, decay 0.95). The moving
        average is the value printed.
    """
    # Width for zero-padded epoch numbers; unlike log10 this is safe for n_epochs <= 0.
    num_digits = len(str(n_epochs))
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    moving_loss = None
    losses = defaultdict(list)
    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, model, optimizer)
        moving_loss = loss if moving_loss is None else 0.95 * moving_loss + 0.05 * loss
        losses["loss"].append(loss)
        losses["moving_loss"].append(moving_loss)
        # `epoch` is already 1-based, so print it as is.
        print(f"Epoch {epoch:0{num_digits}d}/{n_epochs} Loss: {moving_loss:.4f}")
    return losses

In [None]:
input_size = len(vocab.src)
target_size = len(vocab.tgt)
hidden_size = 128
batch_size = 32
N_EPOCHS = 200

encoder = Encoder(input_size, hidden_size=hidden_size)
decoder = Decoder(target_size, hidden_size=hidden_size)
# Module.to() moves registered sub-modules in place, so encoder/decoder end up on `device` too.
nmt = NMT(encoder, decoder).to(device)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
# Validation order carries no training signal; keep it fixed.
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

losses = train(train_loader, nmt, N_EPOCHS)

In [None]:
# One panel per tracked series; squeeze=False keeps `axs` a 2-D array even for a single
# series, so the loop below also works when only one key is present.
fig, axs = plt.subplots(nrows=max(len(losses), 1), figsize=(10, 10), sharex=True, squeeze=False)

for ax, (name, values) in zip(axs[:, 0], losses.items()):
    # Epochs are numbered from 1, matching the training log.
    ax.plot(range(1, len(values) + 1), values)
    ax.set_title(name)
    ax.set_ylabel(name)
    ax.grid()
axs[-1, 0].set_xlabel("epoch")
plt.tight_layout()
plt.show()