In [7]:
import torch
import torch.nn as nn
import math
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
import tqdm
import pandas as pd
import altair as alt
import numpy as np

In [28]:
class InputEmbeddings(nn.Module):
    """Token-embedding lookup scaled by sqrt(d_model).

    Args:
        d_model: embedding dimension.
        vocab_size: number of rows in the embedding table.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # The attribute name is part of the state_dict key; keep it stable for checkpoints.
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Embed the ids in ``x`` and multiply by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale

In [56]:
class PositionalEncodings(nn.Module):
    """Adds fixed sinusoidal position encodings to its input, then applies dropout.

    The table is precomputed for ``seq_len`` positions and registered as the
    non-trainable buffer ``pe`` of shape (1, seq_len, d_model): even feature
    columns hold sin(pos * freq), odd columns hold cos(pos * freq).

    Args:
        d_model: feature size of the encodings.
        seq_len: number of positions precomputed.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # (seq_len, 1) positions times a row of frequencies -> one angle per (position, pair).
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = positions * freqs

        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Leading batch axis so the table broadcasts over the batch dimension of the input.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.shape[1]`` encodings to ``x`` and apply dropout."""
        encodings = self.pe[:, : x.shape[1], :].detach()
        return self.dropout(x + encodings)


In [30]:
class LayerNormalization(nn.Module):
    """Layer norm over the last dimension with a single learned scale and shift.

    Computes ``alpha * (x - mean) / (std + eps) + bias``. ``alpha`` and ``bias``
    are scalar parameters of shape (1,), and ``std`` is ``Tensor.std``'s default
    (unbiased) estimate.

    Args:
        eps: added to the standard deviation to avoid division by zero.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias

In [31]:
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input/output feature size.
        d_ff: hidden feature size.
        dropout: dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # torch.relu applies the activation. nn.ReLU(...) only constructs a module
        # (treating the tensor as its `inplace` flag) and never applies it.
        hidden = torch.relu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))

In [59]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects q/k/v with separate linear layers, splits each projection into
    ``h`` heads of size ``d_k = d_model // h``, attends per head, concatenates
    the heads and applies the output projection ``w_o``.

    After every forward pass the (post-dropout) attention weights are kept,
    detached, in ``self.attention_scores``. Their shape is
    (batch, h, q_len, k_len), so they can be inspected or visualised.
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.h = h
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        assert d_model % h == 0, "d_model must be divisible by h"

        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        # Populated by forward(); None until the first pass.
        self.attention_scores = None

    @staticmethod
    def attention(q, k, v, mask, dropout: nn.Dropout):
        """Scaled dot-product attention.

        Args:
            q, k, v: tensors whose last two dims are (seq, d_k). In this block
                they are (batch, h, seq, d_k).
            mask: optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` receive (numerically) zero attention.
            dropout: optional dropout applied to the attention weights.

        Returns:
            (weights @ v, weights)
        """
        d_k = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # A large negative fill drives masked positions to ~0 after softmax.
            # (-1e-9 is effectively 0 and would not mask anything.)
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return weights.matmul(v), weights

    def _split_heads(self, x):
        """Reshape (batch, seq, d_model) -> (batch, h, seq, d_k)."""
        return x.view(x.shape[0], x.shape[1], self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        x, weights = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        # Detached so keeping the weights around does not retain the autograd graph.
        self.attention_scores = weights.detach()

        # Merge heads back: (batch, h, seq, d_k) -> (batch, seq, h * d_k).
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        return self.w_o(x)


In [33]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: returns ``x + dropout(sublayer(norm(x)))``.

    Args:
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (any callable) to the normalised input and add it back to ``x``."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

In [34]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention, then feed-forward.

    Each sublayer is wrapped in a pre-norm ``ResidualConnection``.
    """

    def __init__(self, self_att_block: MultiHeadAttentionBlock, feed_frwd_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_att_block = self_att_block
        self.feed_frwd_block = feed_frwd_block
        self.residual_connection = nn.ModuleList(
            ResidualConnection(dropout) for _ in range(2)
        )

    def forward(self, x, src_mask):
        """Run self-attention (masked by ``src_mask``) and the feed-forward block on ``x``."""
        attend, feed_forward = self.residual_connection
        x = attend(x, lambda h: self.self_att_block(h, h, h, src_mask))
        return feed_forward(x, self.feed_frwd_block)



In [35]:
class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalisation.

    Args:
        layers: the encoder layers, applied in order. Each is called as ``layer(x, mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [36]:
class DecoderBlock(nn.Module):
    """One decoder layer made of three sublayers, each in a pre-norm ``ResidualConnection``.

    The sublayers run in this order:
      1. self-attention over the target, masked by ``tgt_mask``;
      2. cross-attention with queries from the target and keys/values from
         ``enc_output``, masked by ``src_mask``;
      3. the feed-forward block.
    """

    def __init__(self,
                 self_att_block: MultiHeadAttentionBlock,
                 cross_att_block: MultiHeadAttentionBlock,
                 feed_frwd_block: FeedForwardBlock,
                 dropout: float) -> None:
        super().__init__()
        self.self_att_block = self_att_block
        self.cross_att_block = cross_att_block
        self.feed_frwd_block = feed_frwd_block
        self.residual_connection = nn.ModuleList(
            ResidualConnection(dropout) for _ in range(3)
        )

    def forward(self, x, enc_output, src_mask, tgt_mask):
        self_attend, cross_attend, feed_forward = self.residual_connection
        x = self_attend(x, lambda h: self.self_att_block(h, h, h, tgt_mask))
        x = cross_attend(x, lambda h: self.cross_att_block(h, enc_output, enc_output, src_mask))
        return feed_forward(x, self.feed_frwd_block)


In [37]:
class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalisation.

    Args:
        layers: decoder layers applied in order. Each is called as
            ``layer(x, enc_output, src_mask, tgt_mask)``.
    """

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, enc_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, enc_output, src_mask, tgt_mask)
        return self.norm(hidden)


In [38]:
class ProjectionLayer(nn.Module):
    """Linear map from ``d_model`` to vocabulary logits, returned as log-probabilities.

    Args:
        d_model: input feature size.
        vocab_size: number of output classes.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

In [39]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    The components are used through three entry points:
      * ``encode``: embeds and position-encodes the source, then runs the encoder.
      * ``decode``: embeds and position-encodes the target, then runs the decoder.
      * ``project``: maps decoder output to vocabulary log-probabilities.

    Submodules are assigned in a fixed order. That order defines ``parameters()``
    and ``state_dict()`` ordering, which saved optimizer state depends on.
    """

    def __init__(
            self,
            encoder: Encoder,
            decoder: Decoder,
            src_embed: InputEmbeddings,
            tgt_embed: InputEmbeddings,
            src_pos: PositionalEncodings,
            tgt_pos: PositionalEncodings,
            proj_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.proj_layer = proj_layer

    def encode(self, src, src_mask):
        """Embed + position-encode ``src`` and run the encoder stack."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, enc_output, src_mask, tgt, tgt_mask):
        """Embed + position-encode ``tgt`` and run the decoder against ``enc_output``."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, enc_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder features to log-probabilities over the target vocabulary."""
        return self.proj_layer(x)

In [40]:
def build_transformer(
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_seq_len: int,
        tgt_seq_len: int,
        d_model: int = 512,
        N: int = 6,
        h: int = 8,
        dropout: float = 0.1,
        d_ff: int = 2048) -> Transformer:
    """Construct an encoder-decoder Transformer with N encoder and N decoder layers.

    Every parameter with more than one dimension (weight matrices, embedding
    tables) is re-initialised with Xavier-uniform. 1-D parameters keep their
    constructor initialisation.

    Args:
        src_vocab_size / tgt_vocab_size: vocabulary sizes for embeddings and the output projection.
        src_seq_len / tgt_seq_len: number of positions precomputed for each positional encoding.
        d_model: model width.
        N: number of encoder layers and of decoder layers.
        h: attention heads per attention block.
        dropout: dropout probability used throughout.
        d_ff: hidden size of the feed-forward blocks.
    """
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    src_pos = PositionalEncodings(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncodings(d_model, tgt_seq_len, dropout)

    def make_attention():
        return MultiHeadAttentionBlock(d_model, h, dropout)

    def make_feed_forward():
        return FeedForwardBlock(d_model, d_ff, dropout)

    # Construction order (attention before feed-forward, encoder before decoder)
    # matches the RNG consumption order of the layer initialisers.
    encoder = Encoder(nn.ModuleList([
        EncoderBlock(make_attention(), make_feed_forward(), dropout)
        for _ in range(N)
    ]))
    decoder = Decoder(nn.ModuleList([
        DecoderBlock(make_attention(), make_attention(), make_feed_forward(), dropout)
        for _ in range(N)
    ]))
    proj_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, proj_layer)

    for param in filter(lambda p: p.dim() > 1, transformer.parameters()):
        nn.init.xavier_uniform_(param)
    return transformer

In [42]:
def get_all_sentences(dataset, lang):
    """Lazily yield ``item["translation"][lang]`` for every item in ``dataset``."""
    yield from (entry["translation"][lang] for entry in dataset)

In [43]:
def get_or_build_tokenizer(config, dataset, lang):
    """Load the word-level tokenizer for ``lang``, training and saving it first if it is missing.

    Args:
        config: dict whose ``"tokenizer_file"`` is a format string taking the
            language code (e.g. "tokenizer_{0}.json").
        dataset: iterable of items with ``item["translation"][lang]`` sentences.
            It is only consumed when the tokenizer has to be trained.
        lang: language code substituted into the file name and used to pick sentences.

    Returns:
        A ``tokenizers.Tokenizer`` with special tokens [UNK], [PAD], [SOS], [EOS].
    """
    tokenizer_path = Path(config["tokenizer_file"].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    # min_frequency=2: words seen fewer than twice are left out of the vocabulary
    # and therefore encode as [UNK].
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(dataset, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

In [54]:
class BilingualDataset(Dataset):
    """Map-style dataset yielding fixed-length, padded translation pairs.

    Sentences are tokenized on access. Each item is a dict with:
      - encoder_input: [SOS] src [EOS] [PAD]...            shape (seq_len,)
      - decoder_input: [SOS] tgt [PAD]...                  shape (seq_len,)
      - label:         tgt [EOS] [PAD]...                  shape (seq_len,)
      - encoder_mask:  1 at non-pad source positions       shape (1, 1, seq_len)
      - decoder_mask:  non-pad AND causal (no look-ahead)  shape (1, seq_len, seq_len)
      - src_text / tgt_text: the raw sentences

    Source-side special tokens come from ``tokenizer_src``; target-side ones
    (decoder input, label, decoder mask) come from ``tokenizer_tgt``.

    Raises:
        ValueError: from ``__getitem__`` when a sentence does not fit in ``seq_len``.
    """

    def __init__(self, dataset, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.dataset = dataset
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.lang_src = lang_src
        self.lang_tgt = lang_tgt

        # Source-side special tokens (original attribute names kept for existing callers).
        self.sos_token = self._special_token(tokenizer_src, "[SOS]")
        self.eos_token = self._special_token(tokenizer_src, "[EOS]")
        self.pad_token = self._special_token(tokenizer_src, "[PAD]")
        # Target-side special tokens. The two tokenizers are separate vocabularies,
        # so decoder-side sequences must use the target ids.
        self.sos_token_tgt = self._special_token(tokenizer_tgt, "[SOS]")
        self.eos_token_tgt = self._special_token(tokenizer_tgt, "[EOS]")
        self.pad_token_tgt = self._special_token(tokenizer_tgt, "[PAD]")

    @staticmethod
    def _special_token(tokenizer, token):
        """Return ``token``'s id in ``tokenizer`` as a 1-element int64 tensor."""
        return torch.tensor([tokenizer.token_to_id(token)], dtype=torch.int64)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        src_text = item["translation"][self.lang_src]
        tgt_text = item["translation"][self.lang_tgt]
        src = torch.tensor(self.tokenizer_src.encode(src_text).ids, dtype=torch.int64)
        tgt = torch.tensor(self.tokenizer_tgt.encode(tgt_text).ids, dtype=torch.int64)

        # Encoder input gets both [SOS] and [EOS]; decoder input/label get one each.
        enc_num_pad_tokens = self.seq_len - src.shape[0] - 2
        dec_num_pad_tokens = self.seq_len - tgt.shape[0] - 1

        if enc_num_pad_tokens < 0 or dec_num_pad_tokens < 0:
            raise ValueError("Sentence too long")

        encoder_input = torch.cat([
            self.sos_token,
            src,
            self.eos_token,
            self.pad_token.repeat(enc_num_pad_tokens)
        ])
        decoder_input = torch.cat([
            self.sos_token_tgt,
            tgt,
            self.pad_token_tgt.repeat(dec_num_pad_tokens)
        ])
        label = torch.cat([
            tgt,
            self.eos_token_tgt,
            self.pad_token_tgt.repeat(dec_num_pad_tokens)
        ])

        assert encoder_input.shape[0] == self.seq_len
        assert decoder_input.shape[0] == self.seq_len
        assert label.shape[0] == self.seq_len

        # 1 on and below the diagonal: position i may only attend to positions <= i.
        causal_mask = (torch.triu(torch.ones(1, self.seq_len, self.seq_len), diagonal=1) == 0).int()

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            "decoder_mask": (decoder_input != self.pad_token_tgt).unsqueeze(0).unsqueeze(0).int() & causal_mask,
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
            }


In [55]:
def get_dataset(config):
    """Load OPUS-100 for the configured language pair and build train/validation loaders.

    Tokenizers are loaded (or trained) on the full training split, which is then
    split 90/10 at random into train/validation. The longest tokenized sentence
    on each side is printed, which helps when choosing ``config["seq_len"]``.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt). The
        validation loader uses batch size 1 and no shuffling.
    """
    src_lang = config["lang_src"]
    tgt_lang = config["lang_tgt"]
    dataset_raw = load_dataset("Helsinki-NLP/opus-100", f"{config['lang_src']}-{config['lang_tgt']}", split="train")

    tokenizer_src = get_or_build_tokenizer(config, dataset_raw, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, dataset_raw, tgt_lang)

    n_total = len(dataset_raw)
    n_train = int(0.9 * n_total)
    train_ds_raw, val_ds_raw = random_split(dataset_raw, [n_train, n_total - n_train])

    def wrap(subset):
        return BilingualDataset(subset, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config["seq_len"])

    train_ds = wrap(train_ds_raw)
    val_ds = wrap(val_ds_raw)

    longest_src = longest_tgt = 0
    for pair in (row["translation"] for row in dataset_raw):
        longest_src = max(longest_src, len(tokenizer_src.encode(pair[src_lang]).ids))
        longest_tgt = max(longest_tgt, len(tokenizer_tgt.encode(pair[tgt_lang]).ids))

    print(f"Max length src: {longest_src}")
    print(f"Max length tgt: {longest_tgt}")

    return (
        DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True),
        DataLoader(val_ds, batch_size=1, shuffle=False),
        tokenizer_src,
        tokenizer_tgt,
    )


In [46]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the Transformer described by ``config``.

    Source and target both use ``config["seq_len"]`` positions. The model width
    comes from ``config["d_model"]`` when present, falling back to 512, the
    ``build_transformer`` default. Previously the config value was silently ignored.
    """
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        config["seq_len"],
        config["seq_len"],
        d_model=config.get("d_model", 512),
    )

In [3]:
def get_config():
    """Return a fresh dict with the default training configuration.

    Covers the language pair, batch size and sequence length, tokenizer and
    checkpoint paths, epoch count, learning rate, optional ``preload`` epoch,
    TensorBoard run directory and model width.
    """
    config = dict(
        lang_src="en",
        lang_tgt="id",
        batch_size=8,
        seq_len=600,
        tokenizer_file="tokenizer_{0}.json",
        num_epochs=20,
        lr=0.0001,
        model_folder="weights",
        model_basename="tmodel_",
        preload=None,
        experiment_name="runs/tmodel",
        d_model=512,
    )
    return config

In [49]:
def get_weights_file_path(config, epochs):
    """Return the checkpoint path ``<model_folder>/<model_basename><epochs>.pt`` (relative to cwd) as a string."""
    folder = config["model_folder"]
    filename = f"{config['model_basename']}{epochs}.pt"
    return str(Path(".").joinpath(folder, filename))

In [1]:
def greedy_decode(model, src, src_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily decode one source sequence.

    Encodes ``src`` once, then starts from [SOS] and repeats two steps: run the
    decoder on the tokens produced so far, and append the highest-scoring next
    token. Decoding stops right after [EOS] is appended, or once ``max_len``
    tokens (including [SOS]) exist.

    Uses ``.item()`` on the chosen token, so it assumes a batch of one.
    ``tokenizer_src`` is accepted but not used.

    Returns:
        1-D tensor of target token ids, starting with [SOS].
    """
    sos_id = tokenizer_tgt.token_to_id("[SOS]")
    eos_id = tokenizer_tgt.token_to_id("[EOS]")

    memory = model.encode(src, src_mask)
    generated = torch.empty(1, 1).fill_(sos_id).type_as(src).to(device)
    while generated.shape[1] < max_len:
        step = generated.shape[1]
        # Causal mask: True on and below the diagonal, so position i sees positions <= i.
        causal_mask = (torch.triu(torch.ones(1, step, step), diagonal=1) == 0).type_as(src_mask).to(device)
        out = model.decode(memory, src_mask, generated, causal_mask)
        scores = model.project(out[:, -1])
        next_id = scores.max(dim=1).indices
        next_col = torch.empty(1, 1).type_as(src).fill_(next_id.item()).to(device)
        generated = torch.cat([generated, next_col], dim=1)
        if next_id.item() == eos_id:
            break

    return generated.squeeze(0)

In [2]:
def run_validation(
        model,
        val_dataset,
        tokenizer_src,
        tokenizer_tgt,
        max_len,
        device,
        print_msg,
        global_state,
        writer,
        num_examples = 2):
    """Greedy-decode the first ``num_examples`` validation items and print source/expected/predicted.

    Puts ``model`` in eval mode (it is not switched back) and runs without
    gradients. Every item must have batch size 1. ``global_state`` and
    ``writer`` are accepted but currently unused.

    Args:
        print_msg: callable used for every output line (e.g. tqdm's ``write``).
        num_examples: stop after this many items.
    """
    model.eval()
    separator = "-" * 80

    with torch.no_grad():
        for shown, item in enumerate(val_dataset, start=1):
            encoder_input = item["encoder_input"].to(device)
            encoder_mask = item["encoder_mask"].to(device)

            assert encoder_input.shape[0] == 1, "Batch size must be 1 for validation"
            decoded_ids = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            src_text = item["src_text"][0]
            expected = item["tgt_text"][0]
            predicted = tokenizer_tgt.decode(decoded_ids.detach().cpu().numpy())

            for line in (separator, f"Source: {src_text}", f"Expected: {expected}", f"Predicted: {predicted}"):
                print_msg(line)

            if shown == num_examples:
                break

In [60]:
def train_model(config):
    """Train the translation Transformer described by ``config``.

    Steps:
      1. Build the tokenizers and data loaders via ``get_dataset``.
      2. If ``config["preload"]`` is set, resume model weights, optimizer state,
         epoch and global step from that epoch's checkpoint.
      3. For each remaining epoch:
         - train over the training loader, logging the loss to TensorBoard;
         - print a few greedy validation translations;
         - save a checkpoint named after the epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device {device}")

    Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_dataset(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
    # The loss compares against target-side labels, so ignore the *target* padding id.
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1).to(device)

    writer = SummaryWriter(config["experiment_name"])

    initial_epoch = 0
    global_step = 0
    if config["preload"] is not None:
        model_filename = get_weights_file_path(config, config["preload"])
        print(f"Preloading model from {model_filename}")
        # map_location lets a checkpoint saved on GPU be resumed on a CPU-only machine.
        state = torch.load(model_filename, map_location=device)
        # Restore the weights as well. Without this, training "resumes" from a fresh model.
        model.load_state_dict(state["model_state_dict"])
        initial_epoch = state["epoch"] + 1
        optimizer.load_state_dict(state["optimizer_state_dict"])
        global_step = state["global_step"]

    vocab_size_tgt = tokenizer_tgt.get_vocab_size()
    try:
        for epoch in range(initial_epoch, config["num_epochs"]):
            # `tqdm` is imported as a module in this notebook, so the class is tqdm.tqdm.
            batch_iterator = tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch}")
            for batch in batch_iterator:
                # run_validation switches to eval mode, so re-enable training mode every batch.
                model.train()
                encoder_input = batch["encoder_input"].to(device)
                decoder_input = batch["decoder_input"].to(device)
                encoder_mask = batch["encoder_mask"].to(device)
                decoder_mask = batch["decoder_mask"].to(device)

                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
                proj_output = model.project(decoder_output)

                label = batch["label"].to(device)
                # Flatten every position into one row so CrossEntropyLoss sees (N, vocab) vs (N,).
                loss = criterion(proj_output.view(-1, vocab_size_tgt), label.view(-1))
                batch_iterator.set_postfix({"loss": f"{loss.item(): 6.3f}"})

                writer.add_scalar("Loss/train", loss.item(), global_step)
                writer.flush()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                global_step += 1

            run_validation(
                model,
                val_dataloader,
                tokenizer_src,
                tokenizer_tgt,
                config["seq_len"],
                device,
                lambda msg: batch_iterator.write(msg),
                global_step,
                writer)

            model_filename = get_weights_file_path(config, epoch)
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "global_step": global_step
            }, model_filename)
    finally:
        # Release the TensorBoard event file even if training is interrupted.
        writer.close()

In [None]:
# Train with the default configuration (resumes from config["preload"] if it is set).
config = get_config()
train_model(config=config)

In [None]:
# Rebuild the data pipeline and model, then load the final epoch's checkpoint for inspection.
config = get_config()
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_dataset(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
# Checkpoints are saved once per epoch (0-based), so the last one is num_epochs - 1 (19 by default).
model_filename = get_weights_file_path(config, config["num_epochs"] - 1)
# map_location: the checkpoint may have been written on a GPU; load onto whatever device is available here.
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state["model_state_dict"])
# Inference only from here on: disable dropout.
model.eval();

In [12]:
from json import decoder


def load_nex_batch():
    """Fetch the first validation batch, greedy-decode it, and return its tokens for plotting.

    Relies on the notebook globals ``model``, ``val_dataloader``,
    ``tokenizer_src``, ``tokenizer_tgt``, ``config`` and ``device``.
    ``val_dataloader`` is built with ``shuffle=False``, so every call returns
    the same first batch.

    Greedy decoding is run only for its effect on the model's attention blocks,
    which the attention-map plotting cells read. The decoded ids are discarded.

    Returns:
        (batch, encoder_input_token, decoder_input_token): the raw batch dict and
        the first example's encoder/decoder input ids mapped back to token strings.
    """
    # Eval mode disables dropout so the attention maps are deterministic; no_grad avoids building a graph.
    model.eval()
    batch = next(iter(val_dataloader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_input = batch["decoder_input"].to(device)

    encoder_input_token = [tokenizer_src.id_to_token(x) for x in encoder_input[0].cpu().tolist()]
    decoder_input_token = [tokenizer_tgt.id_to_token(x) for x in decoder_input[0].cpu().tolist()]

    with torch.no_grad():
        greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, config["seq_len"], device)

    return batch, encoder_input_token, decoder_input_token

In [13]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left ``max_row`` x ``max_col`` corner of matrix ``m`` into a long DataFrame.

    Produces one row per (r, c) cell, row-major, with columns
    row, column, value, row_token, col_token. Token labels have the form
    "<zero-padded index> <token>", using "<blank>" when the token list is
    shorter than the index.
    """
    def label(idx, tokens):
        return "%.3d %s" % (idx, tokens[idx] if len(tokens) > idx else "<blank>")

    rows = [r for r in range(m.shape[0]) if r < max_row]
    cols = [c for c in range(m.shape[1]) if c < max_col]
    records = [
        (r, c, float(m[r, c]), label(r, row_tokens), label(c, col_tokens))
        for r in rows
        for c in cols
    ]
    return pd.DataFrame(records, columns=["row", "column", "value", "row_token", "col_token"])

In [14]:
def get_attn_map(attn_type: str, layer: int, head: int):
    """Return one head's attention matrix from the global ``model``.

    Args:
        attn_type: one of
            - "encoder": encoder self-attention;
            - "decoder": decoder self-attention;
            - "encoder-decoder": decoder cross-attention.
        layer: index into the encoder/decoder layer stack.
        head: attention head index.

    Returns:
        ``attention_scores[0, head].data`` from the selected attention block.

    Raises:
        ValueError: ``attn_type`` is not one of the three names above.
        RuntimeError: the block has no ``attention_scores`` stored (for example,
            no forward pass has run yet).
    """
    # Attribute names must match DecoderBlock/EncoderBlock: self_att_block / cross_att_block.
    if attn_type == "encoder":
        block = model.encoder.layers[layer].self_att_block
    elif attn_type == "decoder":
        block = model.decoder.layers[layer].self_att_block
    elif attn_type == "encoder-decoder":
        block = model.decoder.layers[layer].cross_att_block
    else:
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )

    attn = getattr(block, "attention_scores", None)
    if attn is None:
        raise RuntimeError(
            f"No attention scores stored for {attn_type} layer {layer}; run a forward pass first"
        )
    return attn[0, head].data

In [15]:
def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Interactive Altair heat map of one attention head.

    The map is cropped to ``max_sentence_len`` x ``max_sentence_len`` and the
    axes are labelled with the given tokens.
    """
    scores = get_attn_map(attn_type, layer, head)
    df = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    chart = alt.Chart(data=df).mark_rect().encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    sized = chart.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return sized.interactive()

In [16]:
def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    """Grid of attention heat maps: one row per layer, one column per head."""
    grid_rows = [
        alt.hconcat(*[
            attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
            for head in heads
        ])
        for layer in layers
    ]
    return alt.vconcat(*grid_rows)

In [None]:
# Run the model on one validation example and show the sentence pair being visualised.
batch, encoder_input_token, decoder_input_token = load_nex_batch()
print(f"Source: {batch['src_text'][0]}")
print(f"Expected: {batch['tgt_text'][0]}")
# Source length up to the first [PAD]. A sentence that fills seq_len has no [PAD],
# and list.index would raise ValueError, so fall back to the full length.
sentence_len = (
    encoder_input_token.index("[PAD]")
    if "[PAD]" in encoder_input_token
    else len(encoder_input_token)
)

In [None]:
# Layers and heads to plot (build_transformer's default h is 8 heads per block).
layers = list(range(3))
heads = list(range(8))

get_all_attention_maps("encoder", layers, heads, encoder_input_token, encoder_input_token, min(20, sentence_len))

In [None]:
# Decoder self-attention: target tokens on both axes.
get_all_attention_maps(
    "decoder",
    layers,
    heads,
    row_tokens=decoder_input_token,
    col_tokens=decoder_input_token,
    max_sentence_len=min(20, sentence_len),
)

In [None]:
# Cross-attention queries come from the decoder and keys from the encoder output
# (DecoderBlock calls cross_att_block(x, enc_output, enc_output, ...)). Matrix rows
# are therefore target tokens and columns are source tokens.
get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_token, encoder_input_token, min(20, sentence_len))