<a href="https://colab.research.google.com/github/qmeng222/transformers-for-NLP/blob/main/implement-transformers-from-scratch/decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Challenges:
*   Train from scratch (no fine-tuning)
*   How do training and inference work?

In [1]:
import math # Python math module provides mathematical functions

import torch # PyTorch library (a popular DL framework)
import torch.nn as nn # (from PyTorch library) neural network module for building and training neural networks
import torch.nn.functional as F # (within nn module) functional submodule provides functions (such as activation functions, loss functions, and other operations) that are applied element-wise to tensors
from torch.utils.data import Dataset # Dataset class to customize datasets for training

import numpy as np # NumPy library for numerical operations in Python
import matplotlib.pyplot as plt # (from Matplotlib library) pyplot module for data visualization

In [2]:
class CausalSelfAttention(nn.Module):
  """Multi-head self-attention where each position sees only itself and earlier positions.

  Keys and values share the same per-head width (d_v = d_k).
  """

  def __init__(self, d_k, d_model, n_heads, max_len):
    """Create the Q/K/V/output projections and the causal mask buffer.

    Args:
      d_k: per-head width of the query, key and value vectors.
      d_model: width of the input vectors (and of the projected output).
      n_heads: number of attention heads.
      max_len: longest sequence length covered by the causal mask.
    """
    super().__init__() # let nn.Module set up its parameter/buffer bookkeeping first
    self.d_k = d_k
    self.n_heads = n_heads

    # all heads are packed side by side into one wide projection
    packed_width = d_k * n_heads

    # input projections (created in key/query/value order)
    self.key = nn.Linear(d_model, packed_width)
    self.query = nn.Linear(d_model, packed_width)
    self.value = nn.Linear(d_model, packed_width)

    # output projection back to the model width
    self.fc = nn.Linear(packed_width, d_model)

    # lower-triangular matrix of ones: entry (i, j) is 1 only when j <= i,
    # i.e. when position i is allowed to look at position j
    lower_tri = torch.tril(torch.ones(max_len, max_len))
    # stored as a buffer (not a parameter) so it follows the module across devices;
    # the two leading singleton axes broadcast over batch and heads
    self.register_buffer("causal_mask", lower_tri[None, None, :, :])

  def forward(self, q, k, v, pad_mask=None):
    """Apply causal multi-head attention.

    Args:
      q, k, v: N x T x d_model inputs for the query/key/value projections.
      pad_mask: optional N x T tensor; key positions where it equals 0 are
        hidden from every query.

    Returns:
      N x T x d_model tensor.
    """
    q = self.query(q) # N x T x (h*d_k)
    k = self.key(k)   # N x T x (h*d_k)
    v = self.value(v) # N x T x (h*d_k)
    batch, steps = q.size(0), q.size(1)
    heads, width = self.n_heads, self.d_k

    def split_heads(t):
      # (N, T, h*d_k) -> (N, T, h, d_k) -> (N, h, T, d_k) so matmul runs per head
      return t.view(batch, steps, heads, width).transpose(1, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # scaled dot-product scores: (N, h, T, d_k) x (N, h, d_k, T) -> (N, h, T, T)
    scores = q @ k.transpose(-2, -1) / math.sqrt(width)

    # the causal mask is ALWAYS applied (cropped to the batch's sequence length):
    # a query may only attend to itself and to past tokens
    blocked = self.causal_mask[:, :, :steps, :steps] == 0
    if pad_mask is not None:
      # additionally hide padded key positions: (N, T) -> (N, 1, 1, T)
      blocked = blocked | (pad_mask[:, None, None, :] == 0)

    # -inf scores become zero weight after the softmax
    scores = scores.masked_fill(blocked, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)

    # weighted sum of values: (N, h, T, T) x (N, h, T, d_k) -> (N, h, T, d_k)
    context = attn_weights @ v

    # merge the heads again: (N, h, T, d_k) -> (N, T, h, d_k) -> (N, T, h*d_k)
    merged = context.transpose(1, 2).reshape(batch, steps, heads * width)

    # final projection
    return self.fc(merged)

In [3]:
class TransformerBlock(nn.Module):
  """One decoder layer: causal self-attention followed by a position-wise feed-forward net.

  Each sub-layer is wrapped in a residual connection and then layer-normalized;
  dropout is applied to the block's final output.
  """

  def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
    """
    Args:
      d_k: per-head width passed to the attention layer.
      d_model: width of the block's input and output vectors.
      n_heads: number of attention heads.
      max_len: max sequence length passed to the attention layer.
      dropout_prob: dropout probability used inside the feed-forward net and on the output.
    """
    super().__init__()

    # the feed-forward net expands to 4x the model width and back
    hidden_width = 4 * d_model

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = CausalSelfAttention(d_k, d_model, n_heads, max_len)
    self.ann = nn.Sequential(
        nn.Linear(d_model, hidden_width),
        nn.GELU(),
        nn.Linear(hidden_width, d_model),
        nn.Dropout(dropout_prob),
    )
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x, pad_mask=None):
    """Run the block on x; pad_mask is forwarded unchanged to the attention layer."""
    # self-attention: the same tensor serves as query, key and value
    attended = self.mha(x, x, x, pad_mask)
    x = self.ln1(x + attended) # residual connection, then layer norm

    transformed = self.ann(x)
    x = self.ln2(x + transformed) # residual connection, then layer norm

    return self.dropout(x)

In [4]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position information to a batch of embeddings, then applies dropout.

  Even feature columns hold sin(pos * freq) and odd columns hold cos(pos * freq),
  with freq = 10000^(-2i / d_model) for column pair i.
  """

  def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
    """
    Args:
      d_model: width of the embeddings (may be even or odd).
      max_len: number of positions to precompute encodings for.
      dropout_prob: dropout probability applied after adding the encodings.
    """
    super().__init__()
    self.dropout = nn.Dropout(p=dropout_prob)

    position = torch.arange(max_len).unsqueeze(1)  # max_len x 1
    exp_term = torch.arange(0, d_model, 2)         # 0, 2, 4, ... -> ceil(d_model / 2) entries
    div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
    angles = position * div_term                   # max_len x ceil(d_model / 2), computed once

    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(angles)
    # when d_model is odd there is one fewer odd column than even column, so
    # drop the last frequency for the cosine half (a no-op for even d_model)
    pe[0, :, 1::2] = torch.cos(angles[:, :d_model // 2])
    # buffer rather than parameter: it is not trained but moves with the module across devices
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Return dropout(x + positional encodings for the first T positions).

    Args:
      x: N x T x D tensor, with D equal to d_model and T at most max_len.
    """
    # crop the table to the batch's sequence length; broadcasts over the batch axis
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

In [6]:
class Decoder(nn.Module):
  """Decoder-only transformer: embedding + positional encoding, a stack of
  transformer blocks, a final layer norm, and a projection onto the vocabulary."""

  def __init__(self,
               vocab_size,
               max_len,
               d_k,
               d_model,
               n_heads,
               n_layers,
               dropout_prob):
    """
    Args:
      vocab_size: number of distinct tokens (embedding rows and output logits).
      max_len: max sequence length, passed to the positional encoding and every block.
      d_k: per-head width passed to every block.
      d_model: embedding / hidden width.
      n_heads: number of attention heads per block.
      n_layers: number of stacked transformer blocks.
      dropout_prob: dropout probability passed to the positional encoding and every block.
    """
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)

    # build the identical layers one at a time
    layers = []
    for _ in range(n_layers):
      layers.append(
          TransformerBlock(d_k, d_model, n_heads, max_len, dropout_prob))
    # nn.Sequential is used only as a registered container here; forward()
    # iterates over it manually because every block also needs the pad mask
    self.transformer_blocks = nn.Sequential(*layers)

    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, vocab_size)

  def forward(self, x, pad_mask=None):
    """Map token ids to per-position vocabulary logits.

    Args:
      x: tensor of token ids fed to the embedding layer.
      pad_mask: optional mask handed unchanged to every transformer block.

    Returns:
      One vocab_size-wide logit vector per input position (many-to-many).
    """
    hidden = self.pos_encoding(self.embedding(x))
    for layer in self.transformer_blocks:
      hidden = layer(hidden, pad_mask)
    hidden = self.ln(hidden)
    return self.fc(hidden)

In [7]:
# smoke test: build a small decoder (arguments spelled out by name for readability)
model = Decoder(
    vocab_size=20_000,
    max_len=1024,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1,
)

In [8]:
# use the first GPU when one is visible, otherwise stay on the CPU
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if use_gpu else torch.device("cpu")
print(device)
# move the model's parameters and buffers to the chosen device
model.to(device)

cpu


Decoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05,

In [9]:
# random token ids: 8 sequences of 512 tokens drawn from [0, 20_000)
x = np.random.randint(low=0, high=20_000, size=(8, 512))
x_t = torch.tensor(x)
x_t = x_t.to(device)

In [14]:
# forward pass without a padding mask (the causal mask inside the attention layers still applies)
y = model(x_t, pad_mask=None)
y.shape

torch.Size([8, 512, 20000])

👆 8 samples, each with a sequence length of 512; at every position the model outputs one score (logit) for each of the 20,000 possible tokens.

In [12]:
# padding mask: 1 = real token, 0 = padding; the second half of every sequence is marked as padding
n_samples, seq_len = 8, 512
mask = np.ones((n_samples, seq_len))
mask[:, seq_len // 2:] = 0
mask_t = torch.tensor(mask).to(device)

In [15]:
# forward pass with the padding mask; the output shape is the same as without it
y = model(x_t, pad_mask=mask_t)
y.shape

torch.Size([8, 512, 20000])