# KAIST AI605 Assignment 3: Transformer


## Environment
You will only use Python 3.7 and PyTorch 1.9, which is already available on Colab:

In [1]:
import os
import re
import math
from copy import deepcopy
from typing import Iterable, Union
from platform import python_version
from datetime import datetime
from itertools import takewhile
from string import digits, ascii_lowercase, ascii_uppercase

from tqdm.auto import tqdm, trange
from IPython.display import clear_output

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

device = torch.device("cuda")

print("python", python_version())
print("torch", torch.__version__)

python 3.9.7
torch 1.9.1


In [2]:
try:
    import datasets
except:
    !pip install datasets
    clear_output()
    import datasets

## 1. Attention Layer

We will first start with going over a few concepts that you learned in your high school statistics class.
The variance of a random variable $X$, $\text{Var}(X)$ is defined as $\text{E}[(X-\mu)^2]$ where $\mu$ is the mean of $X$.
Furthermore, given two independent random variables $X$ and $Y$ and a constant $a$,
$$ \text{Var}(X+Y) = \text{Var}(X) + \text{Var}(Y),$$
$$ \text{Var}(aX) = a^2\text{Var}(X),$$
$$ \text{Var}(XY) = \text{E}[X^2]\,\text{E}[Y^2] - (\text{E}[X])^2(\text{E}[Y])^2.$$

> **Problem 1.1** *(3 points)*
  Suppose we are given two sets of $n$ random variables, $X_1 \dots X_n$ and $Y_1 \dots Y_n$,
  where all of these $2n$ variables are mutually independent and have a mean of $0$ and a variance of $1$.
  Prove that
  $$\text{Var}\left(\sum_i^n X_i Y_i\right) = n.$$

> **Solution 1.1**
  Given that $X_i$ and $Y_i$ are mutually independent and have a mean of 0 and a variance of 1, we can write
  $$\text{E}[X_i] = \text{E}[Y_i] = 0$$
  and
  $$\text{Var}(X_i) = \text{E}[(X_i - \mu_i)^2] = \text{E}[X_i^2] = 1,$$
  $$\text{Var}(Y_i) = \text{E}[(Y_i - \mu_i)^2] = \text{E}[Y_i^2] = 1,$$
  for all $i = 1, \dots, n$ where $\mu_i$ is the mean of $X_i$ and $Y_i$ respectively.
  Therefore,
  $$
    \begin{align*}
      \text{Var}\left(\sum_i^n X_i Y_i\right)
      &= \sum_i^n \text{Var}(X_iY_i) \\
      &= \sum_i^n \left(\text{E}[X_i^2]\,\text{E}[Y_i^2] - (\text{E}[X_i])^2(\text{E}[Y_i])^2\right) \\
      &= \sum_i^n (1 \cdot 1 - 0^2 \cdot 0^2) \\
      &= n.
    \end{align*}
  $$

In Lecture 11 and 12, we discussed how the attention is computed in Transformer via the following equation,
  $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V.$$

> **Problem 1.2** *(3 points)* 
  Suppose $Q$ and $K$ are matrices of independent variables each of which has a mean of $0$ and a variance of $1$.
  Using what you learned from Problem 1.1., show that
  $$\text{Var}\left(\frac{QK^\top}{\sqrt{d_k}}\right) = \mathbf{1}.$$

> **Solution 1.2**
  Let the matrix $Q \in \mathbb{R}^{n \times d_k}$ and the matrix $K \in \mathbb{R}^{m \times d_m}$.
  The $Q_i$ and the $K_j$ are the row vectors of $Q$ and $K$ respectively, and each of them contains $d_k$ random variables.
  Total $2d_k$ random variables in $Q_i$ and $K_j$ are mutually independent and have a mean of $0$ and a variance of $1$.
  Then, using the result of Problem 1.1, we can rewrite the equation as
  $$
    \begin{align*}
      \text{Var}\left(\frac{QK^\top}{\sqrt{d_k}}\right)
      &= \frac{1}{d_k} \text{Var}(QK^\top) \\
      &= \frac{1}{d_k} \text{Var}\left(\sum_k^{d_k} Q_{ik} K_{jk}\right) \\
      &= \frac{1}{d_k}
        \begin{bmatrix}
          d_k    & \cdots & d_k \\
          \vdots & \ddots & \vdots \\
          d_k    & \cdots & d_k
        \end{bmatrix} \\
      &= \begin{bmatrix}
          1      & \cdots & 1 \\
          \vdots & \ddots & \vdots \\
          1      & \cdots & 1
        \end{bmatrix} \\
      &= \mathbf{1}.
    \end{align*}
  $$
  Here, $\mathbf{1}$ represents the matrix of ones of size $n \times m$.

> **Problem 1.3** *(4 points)*
  What would happen if the assumption that the variance of $Q$ and $K$ is $1$ does not hold?
  Consider each case of it being higher and lower than $1$ and conjecture what it implies, respectively.

> **Solution 1.3**
  If the variance of $Q$ and $K$ is higher than $1$, values go to the regions that the gradients of the softmax function are small.
  It means that the gradients becomes small and it affects negatively while the model is training.
  If the variance of $Q$ and $K$ is lower than $1$, the softmax output variance also becomes lower.
  It means that the attention values are not much different from each other, and the attention is not able to focus on the important parts of $V$.

## 2. Transformer

In this section, you will implement Transformer for a few tasks that are simpler than machine translation.
First, go through [Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html) and make sure you understand every block of the code.
Then, you will reuse these code where appropriate to create models for following three tasks.
Note that we do not provide a separate training or evaluation data, so it is your job to be able to create these in a reasonable manner.

### Model Architecture

In [3]:
class Transformer(nn.Module):
    def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
                 layer_norm_eps: float = 1e-5, **kwargs):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        # Encoder module
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, layer_norm_eps)
        encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        # Decoder module
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, layer_norm_eps)
        decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        # Reset parameters
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)
        return output

    def generate_square_subsequent_mask(self, sz: int):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

### Encoder and Decoder Stacks

In [4]:
class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = nn.ModuleList([deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        output = src
        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output

In [5]:
class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = nn.ModuleList([deepcopy(decoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None):
        output = tgt
        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask)
        if self.norm is not None:
            output = self.norm(output)
        return output

In [6]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5, **kwargs):
        super().__init__()
        # Attention modules
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Feed-forward modules
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Encoder modules
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Self-attention
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # Feed-forward
        src2 = self.linear1(src)
        src2 = F.relu(src2)
        src2 = self.dropout(src2)
        src2 = self.linear2(src2)
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

In [7]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5, **kwargs):
        super().__init__()
        # Attention modules
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Feed-forward modules
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # Decoder modules
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None):
        # Self-attention
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross-attention
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # Feed-forward
        tgt2 = self.linear1(tgt)
        tgt2 = F.relu(tgt2)
        tgt2 = self.dropout(tgt2)
        tgt2 = self.linear2(tgt2)
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

### Attention

In [8]:
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0., **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.head_dim_scale = math.sqrt(self.head_dim)
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # Projection
        self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim)))
        self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # Reset parameters
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.constant_(self.in_proj_bias, 0.)
        nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
        batch_size, tgt_len, _ = query.shape
        _, src_len, _ = key.shape

        w_q, w_k, w_v = self.in_proj_weight.chunk(3)
        b_q, b_k, b_v = self.in_proj_bias.chunk(3)
        query = F.linear(query.transpose(0, 1), w_q, b_q) \
                 .reshape(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
        key   = F.linear(key.transpose(0, 1), w_k, b_k) \
                 .reshape(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
        value = F.linear(value.transpose(0, 1), w_v, b_v) \
                 .reshape(-1, batch_size * self.num_heads, self.head_dim).transpose(0, 1)

        if key_padding_mask is not None:
            key_padding_mask = \
                key_padding_mask.view(batch_size, 1, 1, src_len) \
                                .expand(-1, self.num_heads, -1, -1) \
                                .reshape(batch_size * self.num_heads, 1, src_len)
            if attn_mask is None:
                attn_mask = key_padding_mask
            else:
                if attn_mask.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
                if attn_mask.dtype == torch.bool:
                    attn_mask = attn_mask.logical_or(key_padding_mask)
                else:
                    attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))

        if attn_mask is not None and attn_mask.dtype == torch.bool:
            attn_mask = torch.zeros_like(attn_mask, dtype=torch.float) \
                             .masked_fill_(attn_mask, float("-inf"))

        attn = torch.bmm(query, key.transpose(-2, -1)) / self.head_dim_scale
        attn += attn_mask if attn_mask is not None else 0
        attn = F.softmax(attn, dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        attn = torch.bmm(attn, value)
        attn = attn.transpose(0, 1).reshape(tgt_len, batch_size, self.embed_dim)
        attn = self.out_proj(attn).transpose(1, 0)
        return attn, None

### Positional Encoding

In [9]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 512, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, step=2) * -(math.log(10000) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        x = self.dropout(x)
        return x

### Full Model

In [10]:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512, nhead: int = 8,
                 num_encoder_layers: int = 6, num_decoder_layers: int = 6,
                 dim_feedforward: int = 2048, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model
        self.d_model_scale = math.sqrt(d_model)
        self.embedder = nn.Embedding(vocab_size, d_model)
        self.positional_encoder = PositionalEncoding(d_model, dropout, max_len)
        self.transformer = Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers,
                                       dim_feedforward, dropout, batch_first=True)
        self.out_linear = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        src = self.embedder(src) * self.d_model_scale
        tgt = self.embedder(tgt) * self.d_model_scale
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        out = self.transformer(
            src, tgt, tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        out = self.out_linear(out)
        return out

### Optimizer Schedule

In [11]:
class NoamLR(_LRScheduler):
    def __init__(self, optimizer, warmup_epochs: int, warmup_factor: float = 1.0):
        self.warmup_epochs = warmup_epochs
        self.warmup_factor = warmup_factor
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        scale = self.warmup_factor \
              * min(last_epoch ** (-0.5), last_epoch * self.warmup_epochs ** (-1.5))
        return [base_lr * scale for base_lr in self.base_lrs]


### Loss

In [12]:
class CrossEntropyLoss(_Loss):
    def __init__(self, ignore_index: int = -100, reduction: str = "mean", label_smoothing: float = 0.0):
        super().__init__(reduction=reduction)
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing

    def forward(self, input, target):
        is_ignore_index_enabled = (0 <= self.ignore_index < input.size(1))
        num_wrong_labels = input.size(1) - (2 if is_ignore_index_enabled else 1)
        smoothed_target = torch.full_like(input, self.label_smoothing / num_wrong_labels)
        smoothed_target.scatter_(1, target.unsqueeze(1), 1 - self.label_smoothing)
        if is_ignore_index_enabled:
            smoothed_target[:, self.ignore_index] = 0
        out = -torch.sum(F.log_softmax(input, dim=1) * smoothed_target, dim=1)
        if self.reduction == "mean":
            out = torch.mean(out)
        elif self.reduction == "sum":
            out = torch.sum(out)
        return out

### Trainer

In [13]:
class Trainer:
    def __init__(self, model, criterion, optimizer, scheduler, device="cuda"):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = torch.device(device)
        self.model.to(device)
        self.model_state_dict = self.model.state_dict
        self.model_load_state_dict = self.model.load_state_dict

    def run_epoch(self, dataloader, train=True):
        total_loss, total_corrects, total_samples = 0, 0, 0
        if train:
            self.model.train()
            torch.set_grad_enabled(True)
        else:
            self.model.eval()
            torch.set_grad_enabled(False)

        for src, tgt, src_mask, tgt_mask in dataloader:
            src, tgt = src.to(self.device), tgt.to(self.device)
            tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
            src_mask = None if src_mask is None else src_mask.to(self.device)
            tgt_mask = None if tgt_mask is None else tgt_mask[:, 1:].to(self.device)

            out = self.model(src, tgt_in, src_mask, tgt_mask)
            loss = self.criterion(out.permute(0, 2, 1), tgt_out)

            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            pred = torch.argmax(out, dim=2)
            corrects = torch.sum(torch.all(pred == tgt_out, dim=1)).item()

            total_loss += loss.detach().item() * src.size(0)
            total_corrects += corrects
            total_samples += src.size(0)

        epoch_loss = total_loss / total_samples
        epoch_acc = total_corrects / total_samples
        return epoch_loss, epoch_acc

    @torch.no_grad()
    def predict(self, src, src_mask, start_idx, end_idx):
        self.model.eval()
        pred = torch.full((src.size(0), 1), start_idx, dtype=torch.long, device=self.device)
        for _ in range(src.size(1) + 2):
            out_t = self.model(src, pred, src_mask, None)
            pred_t = torch.argmax(out_t[:, -1, :], dim=1)
            pred = torch.hstack((pred, pred_t.unsqueeze(1)))

        end_lambda = lambda x: x != end_idx
        results = []
        for i in range(src.size(0)):
            src_str = "".join(map(Vocab.to_vocab, takewhile(end_lambda, src[i, 1:].tolist())))
            pred_str = "".join(map(Vocab.to_vocab, takewhile(end_lambda, pred[i, 1:].tolist())))
            results.append((src_str, pred_str))
        return results

    def train(self, train_loader, valid_loader, num_epochs):
        best_loss, best_acc, best_state_dict = float("inf"), 0, None
        for e in trange(num_epochs, desc="Epoch", leave=False):
            train_loss, train_acc = self.run_epoch(train_loader, train=True)
            tqdm.write(f"Epoch {e+1:3d} Train Loss: {train_loss:.5f} Acc: {train_acc * 100:6.2f}")
            test_loss, test_acc = self.run_epoch(valid_loader, train=False)
            tqdm.write(f"          Test  Loss: {test_loss:.5f} Acc: {test_acc * 100:6.2f}")
            if self.scheduler is not None:
                self.scheduler.step()
            if test_acc > best_acc or (test_acc == best_acc and test_loss < best_loss):
                best_loss, best_acc, best_state_dict = test_loss, test_acc, self.model_state_dict()
                tqdm.write(f"          Updated Best Loss: {best_loss:.5f} Best Acc: {best_acc * 100:6.2f}")

        return best_loss, best_acc, best_state_dict

### Dataset

In [14]:
class Vocab:
    valid_vocab_re = re.compile(r"^[0-9a-zA-Z-' ]+$")
    numeric_vocab_re = re.compile(r"^[0-9-]+$")
    vocabs = ["_", "[", "]", *digits, *ascii_lowercase, "-", "'", " "]
    to_idx_dict = {vocab: idx for idx, vocab in enumerate(vocabs)}
    to_vocab_dict = {idx: vocab for idx, vocab in enumerate(vocabs)}
    char_vocabs = vocabs[3:]
    num_vocabs = len(to_idx_dict)
    num_control_vocabs = 3
    num_valid_vocabs = num_vocabs - num_control_vocabs
    num_char_vocabs = len(char_vocabs)
    num_digit_vocabs = 10

    @classmethod
    def to_idx(cls, char: str) -> int:
        return cls.to_idx_dict[char]

    @classmethod
    def to_vocab(cls, idx: int) -> str:
        return cls.to_vocab_dict[idx]

    @classmethod
    def sample_digits(cls, size: int = 1, generator=None) -> str:
        nums = torch.randint(0, 10, size=(size,), generator=generator).tolist()
        tokens = "".join([str(d) for d in nums])
        return tokens

    @classmethod
    def sample_chars(cls, size: int = 1, generator=None) -> str:
        idxs = torch.randint(3, 42, size=(size,), generator=generator).tolist()
        tokens = "".join([cls.to_vocab(i) for i in idxs])
        return tokens

    @classmethod
    def is_valid_vocabs(cls, vocabs: Union[str, Iterable[str]]) -> bool:
        return cls.valid_vocab_re.fullmatch(vocabs) is not None

    @classmethod
    def is_numeric_vocabs(cls, vocabs: Union[str, Iterable[str]]) -> bool:
        return cls.numeric_vocab_re.fullmatch(vocabs) is not None

In [15]:
def get_collate_fn(pad_tensor=None):
    if pad_tensor is None:
        def collate_fn(batch):
            source = torch.vstack([d[0].unsqueeze(dim=0) for d in batch])
            target = torch.vstack([d[1].unsqueeze(dim=0) for d in batch])
            source_mask = target_mask = None
            return source, target, source_mask, target_mask
    else:
        def collate_fn(batch):
            max_len_source = max(len(d[0]) for d in batch)
            max_len_target = max(len(d[1]) for d in batch)
            source = torch.full((len(batch), max_len_source), pad_tensor, dtype=torch.long)
            target = torch.full((len(batch), max_len_target), pad_tensor, dtype=torch.long)
            source_mask = torch.full((len(batch), max_len_source), True, dtype=torch.bool)
            target_mask = torch.full((len(batch), max_len_target), True, dtype=torch.bool)
            for i, d in enumerate(batch):
                source[i, :len(d[0])] = d[0]
                target[i, :len(d[1])] = d[1]
                source_mask[i, :len(d[0])] = False
                target_mask[i, :len(d[1])] = False
            return source, target, source_mask, target_mask
    return collate_fn

In [16]:
class NLPDataset(Dataset):
    pad_token = "_"
    start_token = "["
    end_token = "]"
    vocab_size = Vocab.num_vocabs
    collate_fn = get_collate_fn(pad_tensor=Vocab.to_idx(pad_token))

    def __init__(self, seed: int = 109):
        self.seed = seed
        self.rng = torch.Generator().manual_seed(seed)

    def tokens_to_idx(self, tokens: Union[str, Iterable[str]]):
        return torch.tensor([Vocab.to_idx(self.start_token)]
                          + [Vocab.to_idx(v) for v in tokens]
                          + [Vocab.to_idx(self.end_token)], dtype=torch.long)

    def idx_to_tokens(self, idx: torch.Tensor):
        start_idx = Vocab.to_idx(self.start_token)
        end_idx = Vocab.to_idx(self.end_token)
        tokens = []
        for i in (idx[1:] if idx[0] == start_idx else idx).tolist():
            if i == end_idx:
                break
            tokens.append(Vocab.to_vocab(i))
        return "".join(tokens)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx: int):
        return self.source[idx], self.target[idx]

> **Problem 2.1** *(4 points)*
  Create a model that takes a random set of input symbols from a vocabulary of digits (i.e. 0, 1, ... , 8, 9) as the input and generate back the same symbols.
  Instead of varying length, we fix the length to 32.
  Make sure to report that your model's accuracy (gives credit only if the entire output sequence is correct) goes above 90%.
  Note that a similar problem is also in Annotated Transformer, and copying code is allowed.

In [17]:
class DigitDataset(NLPDataset):
    pad_token = None
    vocab_size = Vocab.num_control_vocabs + Vocab.num_digit_vocabs
    collate_fn = get_collate_fn(pad_tensor=None)

    def __init__(self, num_data: int, token_len: int, seed: int = 109):
        super().__init__(seed=seed)
        data = [self.tokens_to_idx(Vocab.sample_digits(token_len, generator=self.rng))
                for _ in range(num_data)]
        self.source = self.target = data

In [18]:
torch.manual_seed(19)

train_dataset = DigitDataset(num_data=10000, token_len=32, seed=109)
test_dataset  = DigitDataset(num_data=1000,  token_len=32, seed=10)
collate_fn = DigitDataset.collate_fn
vocab_size = DigitDataset.vocab_size

train_loader = DataLoader(train_dataset, batch_size=100, collate_fn=collate_fn, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=100, collate_fn=collate_fn, shuffle=False)

model = TransformerModel(vocab_size, d_model=64, dim_feedforward=256,
                         num_encoder_layers=2, num_decoder_layers=2)
criterion = CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

trainer1 = Trainer(model, criterion, optimizer, scheduler=None, device=device)
best_loss, best_acc, state_dict = trainer1.train(train_loader, test_loader, num_epochs=70)

os.makedirs("./ckpt/", exist_ok=True)
torch.save(state_dict, "./ckpt/digit.pt")

Epoch:   0%|          | 0/70 [00:00<?, ?it/s]

Epoch   1 Train Loss: 2.30001 Acc:   0.00
          Test  Loss: 2.16563 Acc:   0.00
          Updated Best Loss: 2.16563 Best Acc:   0.00
Epoch   2 Train Loss: 2.15899 Acc:   0.00
          Test  Loss: 2.06764 Acc:   0.00
          Updated Best Loss: 2.06764 Best Acc:   0.00
Epoch   3 Train Loss: 2.09246 Acc:   0.00
          Test  Loss: 1.99632 Acc:   0.00
          Updated Best Loss: 1.99632 Best Acc:   0.00
Epoch   4 Train Loss: 2.04703 Acc:   0.00
          Test  Loss: 1.94798 Acc:   0.00
          Updated Best Loss: 1.94798 Best Acc:   0.00
Epoch   5 Train Loss: 2.01505 Acc:   0.00
          Test  Loss: 1.90460 Acc:   0.00
          Updated Best Loss: 1.90460 Best Acc:   0.00
Epoch   6 Train Loss: 1.99106 Acc:   0.00
          Test  Loss: 1.86787 Acc:   0.00
          Updated Best Loss: 1.86787 Best Acc:   0.00
Epoch   7 Train Loss: 1.96579 Acc:   0.00
          Test  Loss: 1.83490 Acc:   0.00
          Updated Best Loss: 1.83490 Best Acc:   0.00
Epoch   8 Train Loss: 1.94658 Acc:

> **Result 2.1**

In [19]:
ex_batch = next(iter(test_loader))
ex_src, ex_src_mask = ex_batch[0][:5].to(device), None
start_idx = Vocab.to_idx(test_loader.dataset.start_token)
end_idx   = Vocab.to_idx(test_loader.dataset.end_token)
trainer1.model_load_state_dict(torch.load("./ckpt/digit.pt"))
results = trainer1.predict(ex_src, ex_src_mask, start_idx, end_idx)
print(f"Best Test Loss: {best_loss:.5f} Test Acc: {best_acc * 100:6.2f}%")
for i, (src_str, pred_str) in enumerate(results):
    print(f"Ex {i+1} {src_str}\n  => {pred_str}")

Best Test Loss: 0.00013 Test Acc: 100.00%
Ex 1 75272572156310634062892099449445
  => 75272572156310634062892099449445
Ex 2 40893093795296628969192129729449
  => 40893093795296628969192129729449
Ex 3 44831826149189749644665636026822
  => 44831826149189749644665636026822
Ex 4 95075203081688426351440359569079
  => 95075203081688426351440359569079
Ex 5 38679382662739589748458598384869
  => 38679382662739589748458598384869



> **Problem 2.2** *(6 points)*
  Now, we will implement a bit more useful function, so-called spelling error correction.
  Your job is to create a model whose input is a word with spelling errors, and the output is the spelling-corrected word.
  Here, your vocabulary will be character instead of word.
  You can create your own training data by using an existing text corpus as the target and inject noise into it to use it as the input.
  You are free to use whichever text corpus you like.
  If you can't think of one, please use context data in SQuAD Dataset (see Assignment 2).
  Report accuracy in your own evaluation data (you will receive full credit as long as both the evaluation data and the accuracy are reasonable),
  and also show 5 examples where it succeeds at correcting spelling.

In [20]:
class TypoDataset(NLPDataset):
    def __init__(self, num_data: int, split: str = "train", level: str = "word",
                 noise_rate: float = 0.9, seed: int = 109):
        super().__init__(seed=seed)
        self.split = split
        self.level = level
        self.noise_rate = noise_rate
        squad = datasets.load_dataset("squad", split=("validation" if split == "valid" else split))
        sentences = []
        for context in sorted(set(squad["context"])):
            sentences.extend([s.lower() for s in (context + " ").split(". ")
                              if Vocab.is_valid_vocabs(s) and len(s) > 10])
        if level == "word":
            words = []
            for sentence in sentences:
                words.extend([w for w in sentence.split(" ")
                              if not Vocab.is_numeric_vocabs(w) and len(w) > 1])
            words = sorted(set(words))
            idxs = torch.randint(len(words), size=(num_data,), generator=self.rng)
            self.raw_data = [words[i] for i in idxs]
            should_perterb = (torch.rand((num_data,), generator=self.rng) < noise_rate).tolist()
            self.source = [self.tokens_to_idx(self._perturb(s) if p else s)
                           for s, p in zip(self.raw_data, should_perterb)]
            self.target = [self.tokens_to_idx(s) for s in self.raw_data]
        elif level == "sentence":
            sentences = [s for s in sentences if len(s) <= 48]
            idxs = torch.randint(len(sentences), size=(num_data,), generator=self.rng)
            self.raw_data = [sentences[i] for i in idxs]
            should_perterb = (torch.rand((num_data,), generator=self.rng) < noise_rate).tolist()
            self.source = [self.tokens_to_idx(self._perturb(s) if p else s)
                           for s, p in zip(self.raw_data, should_perterb)]
            self.target = [self.tokens_to_idx(s) for s in self.raw_data]
        else:
            raise ValueError(f"Invalid level: {level}")

    def _perturb(self, tokens: str):
        token_len = len(tokens)
        while True:
            t = torch.randint(12, size=(1,), generator=self.rng).item()
            i = torch.randint(token_len - 1, size=(1,), generator=self.rng).item()
            if t < 3:  # Swap
                return tokens[:i] + tokens[i + 1] + tokens[i] + tokens[i + 2:]
            elif t < 6:  # Insert
                v = Vocab.sample_chars(generator=self.rng)
                return tokens[:i] + v + tokens[i:]
            elif not tokens[i].isdigit():
                if t < 9:  # Delete
                    return tokens[:i] + tokens[i + 1:]
                else:  # Replace
                    v = Vocab.sample_chars(generator=self.rng)
                    return tokens[:i] + v + tokens[i + 1:]

In [21]:
torch.manual_seed(19)

train_dataset = TypoDataset(num_data=51200, level="word", split="train", seed=109)
test_dataset  = TypoDataset(num_data=10240, level="word", split="valid", seed=10)
collate_fn = TypoDataset.collate_fn
vocab_size = TypoDataset.vocab_size
clear_output()

train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=collate_fn,
                          num_workers=2, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=256, collate_fn=collate_fn,
                          num_workers=2, shuffle=False)

model = TransformerModel(vocab_size, d_model=256, dim_feedforward=1024,
                         num_encoder_layers=2, num_decoder_layers=2)
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-5)
scheduler = NoamLR(optimizer, warmup_epochs=30)

trainer2 = Trainer(model, criterion, optimizer, scheduler, device)
best_loss, best_acc, state_dict = trainer2.train(train_loader, test_loader, num_epochs=150)

os.makedirs("./ckpt/", exist_ok=True)
torch.save(state_dict, "./ckpt/typo_word.pt")

Epoch:   0%|          | 0/150 [00:00<?, ?it/s]

Epoch   1 Train Loss: 1.69818 Acc:   0.06
          Test  Loss: 1.38475 Acc:   0.13
          Updated Best Loss: 1.38475 Best Acc:   0.13
Epoch   2 Train Loss: 1.35688 Acc:   0.54
          Test  Loss: 1.24622 Acc:   1.62
          Updated Best Loss: 1.24622 Best Acc:   1.62
Epoch   3 Train Loss: 1.25990 Acc:   1.21
          Test  Loss: 1.17731 Acc:   2.81
          Updated Best Loss: 1.17731 Best Acc:   2.81
Epoch   4 Train Loss: 1.20120 Acc:   1.87
          Test  Loss: 1.12798 Acc:   4.38
          Updated Best Loss: 1.12798 Best Acc:   4.38
Epoch   5 Train Loss: 1.16701 Acc:   2.52
          Test  Loss: 1.10524 Acc:   5.24
          Updated Best Loss: 1.10524 Best Acc:   5.24
Epoch   6 Train Loss: 1.14137 Acc:   3.03
          Test  Loss: 1.07963 Acc:   6.48
          Updated Best Loss: 1.07963 Best Acc:   6.48
Epoch   7 Train Loss: 1.12128 Acc:   3.51
          Test  Loss: 1.05379 Acc:   7.56
          Updated Best Loss: 1.05379 Best Acc:   7.56
Epoch   8 Train Loss: 1.10912 Acc:

> **Result 2.2**

In [22]:
ex_batch = next(iter(test_loader))
ex_idx = [9, 35, 47, 69, 70, 76, 78, 95]
ex_src = ex_batch[0][ex_idx].to(device)
ex_src_mask = ex_batch[2][ex_idx].to(device)
start_idx = Vocab.to_idx(test_loader.dataset.start_token)
end_idx   = Vocab.to_idx(test_loader.dataset.end_token)
trainer2.model_load_state_dict(torch.load("./ckpt/typo_word.pt"))
results = trainer2.predict(ex_src, ex_src_mask, start_idx, end_idx)
print(f"Best Test Loss: {best_loss:.5f} Test Acc: {best_acc * 100:6.2f}%")
for i, (src_str, pred_str) in enumerate(results):
    print(f"Ex {i+1:2d}: {src_str} => {pred_str}")

Best Test Loss: 0.80783 Test Acc:  50.74%
Ex  1: takigg => taking
Ex  2: injruy => injury
Ex  3: offic5ial => official
Ex  4: suqare => square
Ex  5: actdor => actor
Ex  6: prkpagation => propagation
Ex  7: obstacel => obstacle
Ex  8: autsralia => australia


> **Problem 2.3 (bonus)** *(3 points)*
  Extend this word-level spelling correction model to sentence-level.
  You do not have to report accuracy, but find 3 examples where the word-level model fails and sentence-level model correctly predicts.

In [23]:
torch.manual_seed(19)

train_dataset = TypoDataset(num_data=51200, level="sentence", split="train", seed=109)
test_dataset  = TypoDataset(num_data=10240, level="sentence", split="valid", seed=10)
collate_fn = TypoDataset.collate_fn
vocab_size = TypoDataset.vocab_size
clear_output()

train_loader = DataLoader(train_dataset, batch_size=256, collate_fn=collate_fn,
                          num_workers=2, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=256, collate_fn=collate_fn,
                          num_workers=2, shuffle=False)

model = TransformerModel(vocab_size, d_model=256, dim_feedforward=1024,
                         num_encoder_layers=6, num_decoder_layers=6)
criterion = CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)
scheduler = NoamLR(optimizer, warmup_epochs=50)

trainer3 = Trainer(model, criterion, optimizer, scheduler, device)
best_loss, best_acc, state_dict = trainer3.train(train_loader, test_loader, num_epochs=170)

os.makedirs("./ckpt/", exist_ok=True)
torch.save(state_dict, "./ckpt/typo_sentence.pt")

Epoch:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch   1 Train Loss: 2.42568 Acc:   0.00
          Test  Loss: 2.09924 Acc:   0.00
          Updated Best Loss: 2.09924 Best Acc:   0.00
Epoch   2 Train Loss: 2.08601 Acc:   0.00
          Test  Loss: 1.97111 Acc:   0.00
          Updated Best Loss: 1.97111 Best Acc:   0.00
Epoch   3 Train Loss: 1.95894 Acc:   0.00
          Test  Loss: 1.84959 Acc:   0.00
          Updated Best Loss: 1.84959 Best Acc:   0.00
Epoch   4 Train Loss: 1.83685 Acc:   0.00
          Test  Loss: 1.75349 Acc:   0.00
          Updated Best Loss: 1.75349 Best Acc:   0.00
Epoch   5 Train Loss: 1.73483 Acc:   0.00
          Test  Loss: 1.68858 Acc:   0.00
          Updated Best Loss: 1.68858 Best Acc:   0.00
Epoch   6 Train Loss: 1.65211 Acc:   0.01
          Test  Loss: 1.63354 Acc:   0.00
          Updated Best Loss: 1.63354 Best Acc:   0.00
Epoch   7 Train Loss: 1.58170 Acc:   0.04
          Test  Loss: 1.61131 Acc:   0.00
          Updated Best Loss: 1.61131 Best Acc:   0.00
Epoch   8 Train Loss: 1.51686 Acc:

> **Result 2.3**

In [24]:
ex_batch = next(iter(test_loader))
ex_idx = [7, 29, 31, 47, 54, 88, 98, 121]
ex_src = ex_batch[0][ex_idx].to(device)
ex_src_mask = ex_batch[2][ex_idx].to(device)
start_idx = Vocab.to_idx(test_loader.dataset.start_token)
end_idx   = Vocab.to_idx(test_loader.dataset.end_token)
trainer3.model_load_state_dict(torch.load("./ckpt/typo_sentence.pt"))
results = trainer3.predict(ex_src, ex_src_mask, start_idx, end_idx)
print(f"Best Test Loss: {best_loss:.5f} Test Acc: {best_acc * 100:6.2f}%")
for i, (src_str, pred_str) in enumerate(results):
    print(f"Ex {i+1:2d}: {src_str}\n    => {pred_str}")

Best Test Loss: 0.75597 Test Acc:  42.94%
Ex  1: he died inn 1259 without a successor
    => he died in 1259 without a successor
Ex  2: mongol rule wtas cosmopolitan under kublai khan
    => mongol rule was cosmopolitan under kublai khan
Ex  3: regional variations and dishes alsu exist
    => regional variations and dishes also exist
Ex  4: he came back5 to lahore in 1908
    => he came back to lahore in 1908
Ex  5: johns river divides the czty
    => johns river divides the city
Ex  6: theey claimed the law infringed article 34
    => they claimed the law infringed article 34
Ex  7: sweden rationed gasoline and hetaing oil
    => sweden rationed gasoline and heating oil
Ex  8: thits was derogatory in the consumers' eyes
    => this was derogatory in the consumers' eyes


In [25]:
src = [test_dataset.tokens_to_idx(s) for s in [
    "inn", "wtas", "alsu", "back5", "czty", "theey", "hetaing", "thits",
]]
ex_src, _, ex_src_mask, _ = collate_fn(list(zip(src, src)))
results = trainer2.predict(ex_src.to(device), ex_src_mask.to(device), start_idx, end_idx)
for i, (src_str, pred_str) in enumerate(results):
    print(f"Ex {i+1:2d}: {src_str} => {pred_str}")

Ex  1: inn => inn
Ex  2: wtas => wats
Ex  3: alsu => alus
Ex  4: back5 => backl
Ex  5: czty => coty
Ex  6: theey => they
Ex  7: hetaing => hettaing
Ex  8: thits => thirts
