# Import Libraries

In [None]:
# !pip install -qqq pytorch_lightning torchmetrics wandb tokenizers janome jieba

In [1]:
# Import built-in Python libs
import random
import sys
import heapq
from pathlib import Path
from typing import List
from tqdm import tqdm

# Import data science libs
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

# Import deep learning libs
import pytorch_lightning as pl
import torchmetrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import weights & bias
import wandb

# Import data preprocessing libs
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

%matplotlib inline

In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "False"

In [3]:
utils_path = Path.cwd().parent / "utils"
sys.path.append(str(utils_path))
from custom_tokenizer import load_jieba_tokenizer, load_janome_tokenizer

# Job Selection

In [4]:
job = 1  # 0 - sentencepiece, 1 - language specific

job_name = ["rnn_sentencepiece_ch2jp", "rnn_language_specific_ch2jp"]

tokenizer_job = ["sentencepiece", "language_specific"]
ch_tokenizer_job = ["ch_tokenizer.json", "jieba_tokenizer.json"]
jp_tokenizer_job = ["jp_tokenizer.json", "janome_tokenizer.json"]
embedding_job = ["sentencepiece_embedding", "language_specific_embedding"]

In [5]:
method = 2
method_name = ["semantic", "phonetic", "meta", "concat"]

ch_embedding_method = [
    "ch_embedding.npy",
    "chp_embedding.npy",
    "ch_meta_embedding.npy",
    "ch_concat_embedding.npy",
]

jp_embedding_method = [
    "jp_embedding.npy",
    "jpp_embedding.npy",
    "jp_meta_embedding.npy",
    "jp_concat_embedding.npy",
]


# Config and WandB

In [6]:
config = {
    "enc_emb_dim": (300 if method != 3 else 600),
    "dec_emb_dim": (300 if method != 3 else 600),
    "enc_hid_dim": 512,
    "dec_hid_dim": 512,
    "enc_dropout": 0.1,
    "dec_dropout": 0.1,
    "lr": 7e-4,
    "batch_size": 64,
    "num_workers": 1,
    "precision": 16,
}

In [7]:
run = wandb.init(
    project="phonetic-translation",
    entity="windsuzu",
    group="experiments",
    job_type=job_name[job] + "-" + method_name[method],
    config=config,
    reinit=True,
)

[34m[1mwandb[0m: Currently logged in as: [33mwindsuzu[0m (use `wandb login --relogin` to force relogin)


# Download Datasets, Tokenizers, Embedding, DataModule

## Raw Data

In [8]:
train_data_art = run.use_artifact("sampled_train:latest")
train_data_dir = train_data_art.download()

dev_data_art = run.use_artifact("dev:latest")
dev_data_dir = dev_data_art.download()

test_data_art = run.use_artifact("test:latest")
test_data_dir = test_data_art.download()

data_dir = {
    "train": train_data_dir,
    "dev": dev_data_dir,
    "test": test_data_dir,
}

## Tokenizer

In [9]:
tokenizer_art = run.use_artifact(f"{tokenizer_job[job]}:latest")
tokenizer_dir = tokenizer_art.download()

src_tokenizer_dir = Path(tokenizer_dir) / ch_tokenizer_job[job]
trg_tokenizer_dir = Path(tokenizer_dir) / jp_tokenizer_job[job]

## Pretrained Embedding

How to load pretrained embedding ?

> https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding.from_pretrained

In [10]:
embedding_art = run.use_artifact(f"{embedding_job[job]}:latest")
embedding_dir = embedding_art.download()

ch_embedding_dir = Path(embedding_dir) / ch_embedding_method[method]
jp_embedding_dir = Path(embedding_dir) / jp_embedding_method[method]

[34m[1mwandb[0m: Downloading large artifact language_specific_embedding:latest, 732.42MB. 8 files... Done. 0:0:0


In [11]:
src_embedding = np.load(Path(ch_embedding_dir))
trg_embedding = np.load(Path(jp_embedding_dir))

src_embedding = torch.FloatTensor(src_embedding)
trg_embedding = torch.FloatTensor(trg_embedding)

In [12]:
print(src_embedding.shape)
print(trg_embedding.shape)

torch.Size([32000, 300])
torch.Size([32000, 300])


## DataModule

In [13]:
class SentencePieceDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir,
        src_tokenizer_dir,
        trg_tokenizer_dir,
        batch_size=128,
        num_workers=8,
        pin_memory=True,
        job=0,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.src_tokenizer_dir = src_tokenizer_dir
        self.trg_tokenizer_dir = trg_tokenizer_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.job = job

    def setup(self, stage=None):
        self.src_tokenizer = self._load_tokenizer(self.src_tokenizer_dir)
        self.trg_tokenizer = self._load_tokenizer(self.trg_tokenizer_dir)

        if stage == "fit":
            self.train_set = self._data_preprocess(self.data_dir["train"])
            self.val_set = self._data_preprocess(self.data_dir["dev"])

        if stage == "test":
            self.test_set = self._data_preprocess(self.data_dir["test"])

    def train_dataloader(self):
        return DataLoader(
            self.train_set,
            self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self._data_batching_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set,
            self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self._data_batching_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_set,
            self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=self._data_batching_fn,
        )

    def _read_data_array(self, data_dir):
        with open(data_dir, encoding="utf8") as f:
            arr = f.readlines()
        return arr

    def _load_tokenizer(self, tokenizer_dir, lang="ch"):
        if self.job == 0:
            return Tokenizer.from_file(str(tokenizer_dir))
        else:
            return (
                load_jieba_tokenizer(tokenizer_dir)
                if "jieba" in str(tokenizer_dir)
                else load_janome_tokenizer(tokenizer_dir)
            )

    def _data_preprocess(self, data_dir):
        src_txt = self._read_data_array(Path(data_dir) / "ch.txt")
        trg_txt = self._read_data_array(Path(data_dir) / "jp.txt")
        parallel_txt = np.array(list(zip(src_txt, trg_txt)))
        return parallel_txt

    def _data_batching_fn(self, data_batch):
        data_batch = np.array(data_batch)  # shape=(batch_size, 2=src+trg)

        src_batch = data_batch[:, 0]  # shape=(batch_size, )
        trg_batch = data_batch[:, 1]  # shape=(batch_size, )

        # src_batch=(batch_size, longest_sentence)
        # trg_batch=(batch_size, longest_sentence)
        src_batch = self.src_tokenizer.encode_batch(src_batch)
        trg_batch = self.trg_tokenizer.encode_batch(trg_batch)

        # We have to sort the batch by their non-padded lengths in descending order,
        # because the descending order can help in `nn.utils.rnn.pack_padded_sequence()`,
        # which it will help us ignoring the <pad> in training rnn.
        # https://meetonfriday.com/posts/4d6a906a
        src_batch, trg_batch = zip(
            *sorted(
                zip(src_batch, trg_batch),
                key=lambda x: sum(x[0].attention_mask),
                reverse=True,
            )
        )

        return src_batch, trg_batch

In [14]:
dm = SentencePieceDataModule(
    data_dir,
    src_tokenizer_dir,
    trg_tokenizer_dir,
    config["batch_size"],
    config["num_workers"],
    job=job,
)

### Test DataModule

In [15]:
dm.setup("test")

In [16]:
input_dim = dm.src_tokenizer.get_vocab_size()
output_dim = dm.trg_tokenizer.get_vocab_size()
print(input_dim, output_dim)

src_pad_idx = dm.src_tokenizer.token_to_id("[PAD]")
print(src_pad_idx)

32000 32000
3


In [18]:
for src, trg in dm.test_dataloader():
    print(len(src), src[0], src[0].tokens)
    print(len(trg), trg[0], trg[0].tokens)
    break

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.604 seconds.
Prefix dict has been built successfully.


64 Encoding(num_tokens=105, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]) ['[BOS]', '日本', '从', '世界', '各国', '进口', '酱油', '等', '的', '各种', '调', '料', ',', '但', '公', '知', '的', '是', '在', '制造', '过程', '中', '致癌性', '物质', '的', '氯', '丙', '醇', '类', '作为', '副产品', '生成', '3', '-', '氯', '-', '1', ',', '2', '-', '丙', '二醇', '(', '3', '-', 'MCP', 'D', ')', '、', '1', ',', '3', '-', '二氯', '-', '2', '-', '丙', '醇', '(', '1', ',', '3', '-', 'D', 'CP', ')', '、', '2', '-', '氯', '-', '1', ',', '3', '-', '丙', '二醇', '(', '2', '-', 'MCP', 'D', ')', '以及', '2', ',', '3', '-', '二氯', '-', '1', '丙', '醇', '(', '2', ',', '3', '-', 'D', 'CP', ')', '。', '\n', '[EOS]']
64 Encoding(num_tokens=114, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]) ['[BOS]', '日本', 'へ', 'は', '世界', '各国', 'から', '醤油', 'など', 'の', '種々', 'の', '調味', '料', 'が', '輸入', 'さ', 'れ', 'て', 'いる', 'が', ',', '製造', '過程', '中', 'で', '発癌', '性', '物質', 'の', 'クロロ', 'プロパノール', '類', 'が

# Build Lightning Model

## Encoder

In [19]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(src_embedding)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_len):
        # src     = [batch_size, src_len]
        # src_len = [batch_size]

        embedded = self.dropout(self.embedding(src))
        # embedded = [batch_size, src_len, emb_dim]
        
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, src_len.to("cpu"), batch_first=True)
        packed_outputs, hidden = self.rnn(packed_embedded)
        # packed_outputs is a packed sequence containing all hidden states
        # hidden is now from the final non-padded element in the batch

        enc_outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)
        # enc_outputs is now a non-packed sequence

        # enc_outputs = [batch_size, src_len, enc_hid_dim*num_directions]
        #             = [forward_n + backward_n]
        #             = [last layer]

        # hidden  = [n_layers*num_directions, batch_size, enc_hid_dim]
        #         = [forward_1, backward_1, forward_2, backword_2, ...]

        # hidden[-2, :, : ] is the last of the forwards RNN
        # hidden[-1, :, : ] is the last of the backwards RNN

        last_hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        init_dec_hidden = torch.tanh(self.fc(last_hidden))

        # enc_outputs     = [batch_size, src_len, enc_hid_dim*2]  (we only have 1 layer)
        # init_dec_hidden = [batch_size, dec_hid_dim]

        return enc_outputs, init_dec_hidden

## Attention

In [20]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask):

        # hidden = [batch_size, dec_hid_dim]
        # encoder_outputs = [batch_size, src_len, enc_hid_dim * 2]

        src_len = encoder_outputs.shape[1]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
        # hidden = [batch_size, 1, dec_hid_dim]        (unsqueeze 1)
        #        = [batch_size, src_len, dec_hid_dim]  (repeat)
        
        stacked_hidden = torch.cat((hidden, encoder_outputs), dim=2)
        # stacked_hidden = [batch_size, src_len, dec_hid_dim + enc_hid_dim * 2]

        energy = torch.tanh(self.attn(stacked_hidden))
        # energy = [batch_size, src_len, dec_hid_dim]

        attention = self.v(energy).squeeze(2)
        # attention = [batch_size, src_len, 1]   (v)
        #           = [batch_size, src_len]      (squeeze)

        attention = attention.masked_fill(mask == 0, -(2**15))

        return F.softmax(attention, dim=1)

## Decoder

In [21]:
class Decoder(nn.Module):
    def __init__(
        self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention
    ):
        super().__init__()

        self.attention = attention
        self.embedding = nn.Embedding.from_pretrained(trg_embedding)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim, batch_first=True)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp, hidden, encoder_outputs, mask):
        # encoder_outputs = [batch_size, src_len, enc_hid_dim*2]
        # hidden = [batch_size, dec_hid_dim]

        inp = inp.unsqueeze(1)
        # inp = [batch_size]
        #     = [batch_size, 1]  (unsqueeze 1)

        # embedded = [batch_size, 1, emb_dim]
        embedded = self.dropout(self.embedding(inp))
        
        # a = [batch_size, src_len]
        #   = [batch_size, 1, src_len]  (unsqueeze 1)
        a = self.attention(hidden, encoder_outputs, mask)
        a = a.unsqueeze(1)
        
        # weighted = [batch_size, 1, enc_hid_dim*2]
        weighted = torch.bmm(a, encoder_outputs)

        # rnn_input = [batch_size, 1, emb_dim + enc_hid_dim*2]
        rnn_input = torch.cat((embedded, weighted), dim=2)
        
        # hidden = [1, batch_size, dec_hid_dim]  (unsqueeze 0)
        hidden = hidden.unsqueeze(0)
        
        # output = [batch_size, 1, dec_hid_dim]
        # hidden = [1, batch_size, dec_hid_dim]
        output, hidden = self.rnn(rnn_input, hidden)
        
        # embedded = [batch_size, emb_dim]        (squeeze 0)
        # output   = [batch_size, dec_hid_dim]    (squeeze 0)
        # weighted = [batch_size, enc_hid_dim*2]  (squeeze 0)
        # hidden = [batch_size, dec_hid_dim]      (squeeze 0)
        embedded = embedded.squeeze(1)
        output = output.squeeze(1)
        weighted = weighted.squeeze(1)
        hidden = hidden.squeeze(0)
        
        assert (output == hidden).all()
        
        predict_input = torch.cat((output, weighted, embedded), dim=1)

        # prediction = [batch_size, output_dim]
        prediction = self.fc_out(predict_input)

        # a = [batch_size, src_len]  (squeeze 1)
        a = a.squeeze(1)

        return prediction, hidden, a

## Full Seq2Seq Model

In [22]:
class Seq2SeqModel(pl.LightningModule):
    def __init__(self, input_dim, output_dim, trg_tokenizer, config):
        super().__init__()
        self.trg_tokenizer = trg_tokenizer
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.encoder = Encoder(
            input_dim,
            config["enc_emb_dim"],
            config["enc_hid_dim"],
            config["dec_hid_dim"],
            config["enc_dropout"],
        )

        attn = Attention(config["enc_hid_dim"], config["dec_hid_dim"])

        self.decoder = Decoder(
            output_dim,
            config["dec_emb_dim"],
            config["enc_hid_dim"],
            config["dec_hid_dim"],
            config["dec_dropout"],
            attn,
        )

        self.lr = config["lr"]
        self.apply(self.init_weights)
    
    
    def init_weights(self, m):
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.normal_(param.data, mean=0, std=0.01)
            else:
                nn.init.constant_(param.data, 0)
    
    
    # Training
    # Use only when training and validation
    def _forward(self, src, trg, teacher_forcing_ratio=0.5):
        # teacher_forcing_ratio is probability to use teacher forcing
        # e.g., if teacher_forcing_ratio is 0.5 we use teacher forcing 50% of the time

        # src = list of Encoding([ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
        # trg = list of Encoding([ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

        # src_batch = [batch_size, src_len]
        # src_mask  = [batch_size, src_len]
        # src_len   = [batch_size]
        src_batch = torch.tensor([e.ids for e in src], device=self.device)
        src_mask = torch.tensor([e.attention_mask for e in src], device=self.device)
        src_len = torch.sum(src_mask, axis=1)

        # trg_batch = [batch_size, trg_len]
        trg_batch = torch.tensor([e.ids for e in trg], device=self.device)

        batch_size = src_batch.shape[0]
        trg_len = trg_batch.shape[1]
        trg_vocab_size = self.output_dim

        # create a tensor for storing all decoder outputs
        preds = torch.zeros(batch_size, trg_len, trg_vocab_size, device=self.device)

        # encoder_outputs is all hidden states of the input sequence, back and forwards
        # hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src_batch, src_len)
        
        # first input to the decoder = [BOS] tokens
        # inp = [batch_size]
        inp = trg_batch[:, 0]

        for t in range(1, trg_len):
            # pred   = [batch_size, output_dim]
            # hidden = [batch_size, dec_hid_dim]
            pred, hidden, _ = self.decoder(inp, hidden, encoder_outputs, src_mask)

            # store predictions in a tensor holding predictions for each token
            preds[:, t, :] = pred
            
            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio

            # top1 = [batch_size]
            # get the highest predicted token from our predictions
            top1 = pred.argmax(1)

            # inp = [batch_size]
            # if teacher forcing, use actual next token as next input
            # if not, use predicted token
            inp = trg_batch[:, t] if teacher_force else top1
        
        return preds
    
    
    # Inference
    # * Let you use the pl model as a pytorch model.
    # * 
    # * pl_model.eval()
    # * pl_model(X)
    # *
    def forward(self, src, max_len=200):
        src_batch = torch.tensor([e.ids for e in src], device=self.device)
        src_mask = torch.tensor([e.attention_mask for e in src], device=self.device)
        src_len = torch.sum(src_mask, axis=1)  # actual src_len without [PAD]
        
        batch_size = src_batch.shape[0]
        src_size = src_batch.shape[1]  # src_len with [PAD]
        trg_len = max_len
        trg_vocab_size = self.output_dim
        
        preds = torch.zeros(batch_size, trg_len, trg_vocab_size, device=self.device)
        encoder_outputs, hidden = self.encoder(src_batch, src_len)
        
        # create a tensor for storing all attention matrices
        attns = torch.zeros(batch_size, trg_len, src_size, device=self.device)
        
        # first input to the decoder = [BOS] tokens
        # inp = [batch_size]
        inp = torch.tensor([self.trg_tokenizer.token_to_id("[BOS]")], device=self.device).repeat(batch_size)
        
        for t in range(1, trg_len):
            
            # attn = [batch_size, src_len]
            pred, hidden, attn = self.decoder(inp, hidden, encoder_outputs, src_mask)
            
            preds[:, t, :] = pred
            top1 = pred.argmax(1)
            inp = top1
            
            # store attention sequences in a tensor holding attention value for each token
            attns[:, t, :] = attn
            
        return preds, attns, src_len


    def training_step(self, batch, batch_idx):
        # both are lists of encodings
        src, trg = batch
        
        # y    = [batch_size, trg_len]
        # pred = [batch_size, trg_len, output_dim]
        y = torch.tensor([e.ids for e in trg], device=self.device)
        preds = self._forward(src, trg)
        output_dim = preds.shape[-1]
        
        # y    = [batch_size * (trg_len-1)]
        # pred = [batch_size * (trg_len-1), output_dim]
        y = y[:, 1:].reshape(-1)
        preds = preds[:, 1:, :].reshape(-1, output_dim)
        
        loss = F.cross_entropy(preds, y, ignore_index=self.trg_tokenizer.token_to_id("[PAD]"))
        self.log("train_loss", loss)

        perplexity = torch.exp(loss)
        self.log("train_ppl", perplexity)
        
        if self.global_step % 50 == 0:
            torch.cuda.empty_cache()
            
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        src, trg = batch
        y = torch.tensor([e.ids for e in trg], device=self.device)
        preds = self._forward(src, trg)
        
        output_dim = preds.shape[-1]
        y = y[:, 1:].reshape(-1)
        preds = preds[:, 1:, :].reshape(-1, output_dim)
        
        loss = F.cross_entropy(preds, y, ignore_index=self.trg_tokenizer.token_to_id("[PAD]"))
        self.log("valid_loss", loss, sync_dist=True)
        
        perplexity = torch.exp(loss)
        self.log("valid_ppl", perplexity, sync_dist=True)
        
        
    def test_step(self, batch, batch_idx):
        src, trg = batch
        preds, attn_matrix, real_src_len = self(src)
        
        # attn_matrix = [batch_size, trg_len, src_len]
        # preds       = [batch_size, trg_len, output_dim]
        #             = [batch_size, trg_len]             (argmax 2)
        preds = preds.argmax(2)
        
        # convert `preds` tensor to list of real sentences (tokens)
        # meaning to cut the sentence by [EOS] and remove the [PAD] tokens
        
        # eos_pos = dict(sentence_idx: first_pad_position)
        #
        # e.g., {0: 32, 2: 55} 
        # Meaning that we have 32 tokens (include [EOS]) in the first predicted sentence
        # and `max_len` tokens (no [EOS]) in the second predicted setence
        # and 55 tokens (include [EOS]) in the third predicted sentence
        eos_pos = dict(reversed((preds == self.trg_tokenizer.token_to_id("[EOS]")).nonzero().tolist()))
    
        pred_sentences, attn_matrices = [], []
        for idx, (sentence, attention, src_len) in enumerate(zip(preds, attn_matrix, real_src_len)):
            
            # sentence  = [trg_len_with_pad]
            #           = [real_trg_len]
            pred_sentences.append(sentence[:eos_pos.get(idx)+1 if eos_pos.get(idx) else None])
            
            # attention = [trg_len_with_pad, src_len_with_pad]
            #           = [real_trg_len, real_src_len]
            attn_matrices.append(attention[:eos_pos.get(idx)+1 if eos_pos.get(idx) else None, :src_len])
        
        # source sentences for displaying attention matrix 
        src = [[token for token in e.tokens if token != "[PAD]"] for e in src]
        
        # target sentences for calculating BLEU scores
        trg = [[token for token in e.tokens if token != "[PAD]"] for e in trg]
        
        return pred_sentences, attn_matrices, src, trg
        
    
    def test_epoch_end(self, test_outputs):
        outputs = []
        for (pred_sent_list, attn_list, src_list, trg_list) in test_outputs:
            for pred_sent, attn, src, trg in list(zip(pred_sent_list, attn_list, src_list, trg_list)):
                pred_sent = list(map(self.trg_tokenizer.id_to_token, pred_sent))
                outputs.append((pred_sent, attn, src, trg))
        
        # outputs = list of predictions of testsets, each has a tuple of (pred_sentence, attn_matrix, src_sentence, trg_sentence)
        # pred_sentence = [trg_len]
        # attn_matrix   = [trg_len, src_len]
        # src_sentence  = [src_len]
        # trg_sentence  = [trg_len]
        self.test_outputs = outputs

    def configure_optimizers(self):
        return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr)
    
    
    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        optimizer.zero_grad(set_to_none=True)
        

In [23]:
class EmbeddingFineTuning(pl.callbacks.BaseFinetuning):
    def __init__(self, unfreeze_at_epoch=2):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # freeze any module you want
        self.freeze(pl_module.encoder.embedding)
        self.freeze(pl_module.decoder.embedding)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        # When `current_epoch` is hit, embedding will start training.
        if current_epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=[
                    pl_module.encoder.embedding,
                    pl_module.decoder.embedding,
                ],
                optimizer=optimizer,
            )
            
embedding_finetune = EmbeddingFineTuning(unfreeze_at_epoch=1)

In [24]:
wandb_logger = pl.loggers.WandbLogger()

model = Seq2SeqModel(
    input_dim,
    output_dim,
    dm.trg_tokenizer,
    config,
)

In [25]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
model

The model has 65,420,032 trainable parameters


Seq2SeqModel(
  (encoder): Encoder(
    (embedding): Embedding(32000, 300)
    (rnn): GRU(300, 512, batch_first=True, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(32000, 300)
    (rnn): GRU(1324, 512, batch_first=True)
    (fc_out): Linear(in_features=1836, out_features=32000, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

# Training

In [26]:
ckpt_dir = Path("checkpoints")
checkpoint = pl.callbacks.ModelCheckpoint(dirpath=ckpt_dir,  # path for saving checkpoints
                                          filename=f"{job_name[job]}-{method_name[method]}-" + "{epoch}-{valid_loss:.3f}",
                                          monitor="valid_loss",
                                          mode="min",
                                          save_weights_only=True,
                                          save_top_k=20,
                                         )



In [27]:
trainer = pl.Trainer(
    fast_dev_run=False,
    logger=wandb_logger,
    gpus=1,
    max_epochs=20,
    gradient_clip_val=1,
    precision=config["precision"],
    callbacks=[checkpoint, embedding_finetune],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [None]:
trainer.fit(model, datamodule=dm)

# Testing (BLEU Scores)

In [28]:
# model.load_state_dict(torch.load(ckpt_dir/"transformer_sentencepiece_ch2jp-phonetic-epoch=5-valid_loss=5.579.ckpt")["state_dict"])

In [29]:
trainer.test(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 1.002 seconds.
Prefix dict has been built successfully.



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


[{}]

In [30]:
def calculate_corpus_bleu(preds: List[str], refs: List[List[str]], n_gram=4):
    # arg example:
    # preds: ["机器人行业在环境问题上的措施", "松下生产科技公司也以环境先进企业为目标"]
    # refs: [["机器人在环境上的改变", "對於机器人在环境上的措施"],  ["松下科技公司的首要目标是解决环境问题"]]
    preds = list(map(list, preds))
    refs = [[list(sen) for sen in ref] for ref in refs]
    return torchmetrics.functional.nlp.bleu_score(preds, refs, n_gram=n_gram)

In [31]:
preds = [dm.trg_tokenizer.decode(list(map(dm.trg_tokenizer.token_to_id, output[0]))) for output in model.test_outputs]
refs = [[dm.trg_tokenizer.decode(list(map(dm.trg_tokenizer.token_to_id, output[3])))] for output in model.test_outputs]

bleu_score = calculate_corpus_bleu(preds, refs, n_gram=4)
print(bleu_score)

tensor(0.2318)


# Case Study and Attention Matrix

In [38]:
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK TC']
plt.rcParams['axes.unicode_minus'] = False

def sentence_bleu(pred_token, trg_token):
    trg  = dm.trg_tokenizer.decode(list(map(dm.trg_tokenizer.token_to_id, trg_token)))
    pred = dm.trg_tokenizer.decode(list(map(dm.trg_tokenizer.token_to_id, pred_token)))
    return calculate_corpus_bleu([trg], [[pred]])


def case_study(pred_token, src_token, trg_token, attn_matrix):
    src  = dm.src_tokenizer.decode(list(map(dm.src_tokenizer.token_to_id, src_token)))
    trg  = dm.trg_tokenizer.decode(list(map(dm.trg_tokenizer.token_to_id, trg_token)))
    pred = dm.trg_tokenizer.decode(list(map(dm.trg_tokenizer.token_to_id, pred_token)))
    bleu = calculate_corpus_bleu([trg], [[pred]])
    
    print(f"SOURCE: \n{src}\n {'-'*100}")
    print(f"TARGET: \n{trg}\n {'-'*100}")
    print(f"PREDICTION: \n{pred}\n {'-'*100}")
    print(f"BLEU SCORE: {bleu}")
    
    plt.figure(figsize=(30, 30))
    for i in range(6):
        plt.subplot(3, 2, i+1)
        ax = sns.heatmap(attn_matrix[i], xticklabels=src_token, yticklabels=pred_token)
        ax.xaxis.set_ticks_position('top')

In [None]:
scores = []

for i in tqdm(range(len(model.test_outputs))):
    bleu = sentence_bleu(
        model.test_outputs[i][0],
        model.test_outputs[i][3],
    )
    scores.append(bleu)

In [None]:
heapq.nlargest(50, range(len(scores)), np.array(scores).take)[10:20]

In [None]:
i=1975
case_study(model.test_outputs[i][0],
           model.test_outputs[i][2],
           model.test_outputs[i][3],
           model.test_outputs[i][1].cpu().numpy())