<a href="https://colab.research.google.com/github/ricefan-tech/Transformer/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
import torch
import torch.nn as nn
import math

import torch.nn.functional as F
import pdb


# Transformer skeleton

In [None]:
class MyEmbedding(nn.Module):
  """Trainable token-embedding table mapping ids in [0, vocab_size) to d_model-dim vectors."""
  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, input):
    # Plain table lookup; output gains a trailing d_model axis.
    embedded = self.embedding(input)
    return embedded

class MyPositionalEncoding(nn.Module):
  """Learned absolute positional embeddings (one trainable vector per position).

  ``forward`` only uses ``input.size(1)`` and ``input.device``; it returns a
  (1, seq_len, d_model) tensor that broadcasts over the batch when added to
  token embeddings.
  """
  def __init__(self, max_seq_len, d_model):
    super().__init__()
    self.positionalEncoding = nn.Embedding(max_seq_len, d_model)

  def forward(self, input):
    # Position ids 0..seq_len-1, created on the same device as the input ids.
    positions = torch.arange(input.size(1), device=input.device)
    per_position = self.positionalEncoding(positions)  # (seq_len, d_model)
    return per_position[None, :, :]  # leading batch axis of size 1 for broadcasting

class MyMultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with Q/K/V projections (no output projection).

  Args:
    d_model: model width; must be divisible by ``num_heads``.
    num_heads: number of attention heads; each head has width d_model // num_heads.
  """
  def __init__(self, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.Q_proj = nn.Linear(d_model, d_model, bias=False)
    self.K_proj = nn.Linear(d_model, d_model, bias=False)
    self.V_proj = nn.Linear(d_model, d_model, bias=False)
    self.num_heads = num_heads
    # Stored on the module so forward() does not depend on a global `d_model`.
    self.d_model = d_model
    self.head_dim = d_model // num_heads

  def _split_heads(self, x):
    """Reshape (batch, seq, d_model) -> (batch, num_heads, seq, head_dim)."""
    batch_size, seq_len, _ = x.size()
    return x.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

  def forward(self, q_input, v_input, k_input, padding_mask=None, causal_mask=None, kv_cache=None):
    """Attend from ``q_input`` to ``k_input``/``v_input``.

    Args:
      q_input: (batch, query_len, d_model) queries.
      v_input: (batch, key_len, d_model) values (note the argument order: q, v, k).
      k_input: (batch, key_len, d_model) keys.
      padding_mask: optional (batch, total_key_len) mask over keys; entries equal to 0
        are masked out. With a kv_cache, total_key_len includes cached positions.
      causal_mask: optional (query_len, total_key_len) mask; True/non-zero entries are masked.
      kv_cache: optional dict. If it already holds "key"/"value" tensors of shape
        (batch, num_heads, past_len, head_dim), the new keys/values are appended along
        the sequence axis; the dict is updated in place with the concatenated tensors.

    Returns:
      (output, kv_cache) where output is (batch, query_len, d_model).
    """
    batch_size, query_len, _ = q_input.size()
    Q = self._split_heads(self.Q_proj(q_input))
    K = self._split_heads(self.K_proj(k_input))
    V = self._split_heads(self.V_proj(v_input))

    if kv_cache is not None:
      # Concatenate along the sequence axis (dim 2 after the head split), not the head axis.
      if kv_cache.get("key") is not None:
        K = torch.cat([kv_cache["key"], K], dim=2)
        V = torch.cat([kv_cache["value"], V], dim=2)
      kv_cache["key"] = K
      kv_cache["value"] = V

    attention_scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)

    if padding_mask is not None:
      # (batch, key_len) -> (batch, 1, 1, key_len); broadcasting covers heads and queries.
      key_is_pad = (padding_mask == 0).unsqueeze(1).unsqueeze(2)
      attention_scores = attention_scores.masked_fill(key_is_pad, float("-inf"))

    if causal_mask is not None:
      # (query_len, key_len) -> (1, 1, query_len, key_len), as bool on the scores' device.
      causal_mask = causal_mask.to(device=attention_scores.device, dtype=torch.bool)
      attention_scores = attention_scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), float("-inf"))

    context = F.softmax(attention_scores, dim=-1) @ V  # (batch, num_heads, query_len, head_dim)
    # Merge heads: move the head axis next to head_dim before flattening; reshaping
    # (batch, heads, seq, head_dim) directly would interleave different query positions.
    context = context.transpose(1, 2).reshape(batch_size, query_len, self.d_model)
    return context, kv_cache


class MyFeedForwardNetwork(nn.Module):
  """Position-wise feed-forward block: Linear(d_model -> ff_hidden), ReLU, Linear(ff_hidden -> d_model)."""
  def __init__(self, d_model, ff_hidden):
    super().__init__()
    self.layer1 = nn.Linear(d_model, ff_hidden)
    self.layer2 = nn.Linear(ff_hidden, d_model)
    self.relu = nn.ReLU()

  def forward(self, input):
    hidden = self.layer1(input)
    activated = self.relu(hidden)
    return self.layer2(activated)


class MyLayerNorm(nn.Module):
  """Layer normalisation over the last dimension with learnable scale and shift.

  Matches ``torch.nn.LayerNorm``, which normalises with the biased (population) variance.

  Args:
    d_model: size of the normalised (last) dimension.
    eps: added to the variance for numerical stability.
  """
  def __init__(self, d_model, eps=1e-6):
    super().__init__()
    self.eps = eps  # stabiliser for the division
    self.gamma = nn.Parameter(torch.ones(d_model))  # learnable scale
    self.beta = nn.Parameter(torch.zeros(d_model))  # learnable shift

  def forward(self, input):
    mean = input.mean(dim=-1, keepdim=True)
    # unbiased=False: layer norm uses the population variance. torch.var's default
    # (Bessel-corrected) estimate is larger, which slightly shrinks the normalised
    # activations; the effect is noticeable for small d_model.
    var = input.var(dim=-1, unbiased=False, keepdim=True)
    normalised_input = (input - mean) / torch.sqrt(var + self.eps)
    return self.gamma * normalised_input + self.beta


class MyEncoderLayer(nn.Module):
  """Post-norm encoder layer: self-attention + residual + norm, then FFN + residual + norm.

  ``vocab_size`` is accepted for signature compatibility but is not used.
  """
  def __init__(self, vocab_size, d_model, num_heads, ff_hidden):
    super().__init__()
    self.multiheadattention = MyMultiHeadAttention(d_model, num_heads)
    self.layer_norm = MyLayerNorm(d_model)
    self.layer_norm2 = MyLayerNorm(d_model)
    self.ff_network = MyFeedForwardNetwork(d_model, ff_hidden)

  def forward(self, input, padding_mask):
    # Attention returns (output, kv_cache); the encoder keeps no cache.
    attended, _ = self.multiheadattention(input, input, input, padding_mask=padding_mask, causal_mask=None)
    hidden = self.layer_norm(input + attended)
    return self.layer_norm2(hidden + self.ff_network(hidden))


class MyEncoder(nn.Module):
  """Token + learned positional embeddings followed by a stack of encoder layers."""
  def __init__(self, vocab_size, max_seq_len, d_model, num_heads, ff_hidden, num_encoder_layers):
    super().__init__()
    self.embedding_layer = MyEmbedding(d_model, vocab_size)
    self.positional_encoding = MyPositionalEncoding(max_seq_len, d_model)
    self.encoders_layers = nn.ModuleList(
        [MyEncoderLayer(vocab_size, d_model, num_heads, ff_hidden) for _ in range(num_encoder_layers)]
    )

  def forward(self, input, padding_mask):
    token_ids = input.long()
    # Positional output has a size-1 batch axis and broadcasts over the batch.
    hidden = self.embedding_layer(token_ids) + self.positional_encoding(token_ids)
    for encoder_layer in self.encoders_layers:
      hidden = encoder_layer(hidden, padding_mask)
    return hidden


class MyOutputLayer(nn.Module):
  """Affine projection from d_model features to vocabulary logits (with bias)."""
  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.W_output = nn.Linear(d_model, vocab_size)

  def forward(self, input):
    logits = self.W_output(input)
    return logits

class MyDecoderLayer(nn.Module):
  """Post-norm decoder layer.

  Order: causal self-attention + residual + norm, FFN + residual + norm,
  cross-attention over the encoder output + residual + norm, FFN + residual + norm.
  NOTE(review): the FFN between self- and cross-attention is not in the standard
  Transformer decoder (self-attn -> cross-attn -> FFN); kept to preserve the parameters.
  """
  def __init__(self, d_model, num_heads, ff_hidden):
    super().__init__()

    self.multiheadattention = MyMultiHeadAttention(d_model, num_heads)
    self.layer_norm = MyLayerNorm(d_model)
    self.layer_norm2 = MyLayerNorm(d_model)
    self.layer_norm3 = MyLayerNorm(d_model)
    self.layer_norm4 = MyLayerNorm(d_model)
    self.ff_network = MyFeedForwardNetwork(d_model, ff_hidden)
    self.ff_network2 = MyFeedForwardNetwork(d_model, ff_hidden)
    self.cross_multiheadattention = MyMultiHeadAttention(d_model, num_heads)

  def forward(self, input, encoder_output, encoder_mask, padding_mask):
    """Decode one layer.

    Args:
      input: (batch, target_len, d_model) decoder hidden states.
      encoder_output: (batch, source_len, d_model) encoder states.
      encoder_mask: source padding mask, used by cross-attention.
      padding_mask: target padding mask, used by self-attention.
    """
    seq_len = input.size(1)
    # True above the diagonal: position i must not attend to positions j > i.
    # Built directly as bool on the input's device (masked_fill fills True positions).
    causal_mask = torch.ones(seq_len, seq_len, device=input.device).triu(diagonal=1).bool()

    # Self-attention uses the DECODER padding mask; attention returns (output, kv_cache).
    self_attended, _ = self.multiheadattention(input, input, input, padding_mask=padding_mask, causal_mask=causal_mask)
    hidden = self.layer_norm(input + self_attended)
    hidden = self.layer_norm2(hidden + self.ff_network(hidden))

    # Cross-attention uses the ENCODER padding mask.
    cross_attended, _ = self.cross_multiheadattention(hidden, encoder_output, encoder_output, padding_mask=encoder_mask)
    hidden = self.layer_norm3(hidden + cross_attended)
    return self.layer_norm4(hidden + self.ff_network2(hidden))


class MyDecoder(nn.Module):
  """Token + positional embeddings, a stack of decoder layers, and a vocabulary projection.

  ``forward`` returns raw logits with a trailing vocab_size axis.
  """
  def __init__(self, vocab_size, max_seq_len, d_model, num_heads, ff_hidden, num_decoder_layers):
    super().__init__()
    self.embedding_layer = MyEmbedding(d_model, vocab_size)
    self.positional_encoding = MyPositionalEncoding(max_seq_len, d_model)
    self.decoder_layers = nn.ModuleList(
        [MyDecoderLayer(d_model, num_heads, ff_hidden) for _ in range(num_decoder_layers)]
    )
    self.output_layer = MyOutputLayer(d_model, vocab_size)

  def forward(self, input, encoder_output, encoder_mask, padding_mask):
    token_ids = input.long()
    hidden = self.embedding_layer(token_ids) + self.positional_encoding(token_ids)
    for decoder_layer in self.decoder_layers:
      hidden = decoder_layer(hidden, encoder_output, encoder_mask, padding_mask)
    return self.output_layer(hidden)

class MyTransformer(nn.Module):
  """Encoder-decoder Transformer built from MyEncoder and MyDecoder.

  ``forward`` returns raw decoder logits (teacher forcing); ``generate`` does greedy decoding.
  """
  def __init__(self, vocab_size, max_seq_len, d_model, num_heads, ff_hidden, num_encoder_layers, num_decoder_layers):
    super().__init__()
    self.encoder = MyEncoder(vocab_size, max_seq_len, d_model, num_heads, ff_hidden, num_encoder_layers)
    self.decoder = MyDecoder(vocab_size, max_seq_len, d_model, num_heads, ff_hidden, num_decoder_layers)

  def forward(self, source, target, source_padding_mask, target_padding_mask):
    """Encode ``source`` and decode ``target``.

    Args:
      source: (batch, source_len) token ids.
      target: (batch, target_len) decoder input ids.
      source_padding_mask: (batch, source_len); entries equal to 0 mark padding. Used by
        encoder self-attention and decoder cross-attention.
      target_padding_mask: (batch, target_len); entries equal to 0 mark padding. Used by
        decoder self-attention.

    Returns:
      Raw logits; F.cross_entropy applies log-softmax itself.
    """
    encoder_output = self.encoder(source, source_padding_mask)
    return self.decoder(target, encoder_output, source_padding_mask, target_padding_mask)

  def generate(self, bos_token, source_inpt, source_padding_mask, max_steps, kv_cache=None, eos_token=None):
    """Greedy decoding.

    Args:
      bos_token: integer id used as the first decoder token for every sequence.
      source_inpt: (batch, source_len) source token ids.
      source_padding_mask: (batch, source_len) source padding mask (0 = padding).
      max_steps: maximum number of tokens to generate. It is capped so the decoder length
        never exceeds the positional-embedding table.
      kv_cache: accepted for backward compatibility and ignored. The decoder layers do not
        thread a cache, so the whole prefix is re-decoded at each step.
      eos_token: optional id; decoding stops early once every sequence has produced it.

    Returns:
      (batch, 1 + generated_len) LongTensor of token ids, starting with ``bos_token``.
    """
    device = source_inpt.device
    batch_size = source_inpt.size(0)
    # The learned positional table bounds how many decoder positions exist.
    max_len = self.decoder.positional_encoding.positionalEncoding.num_embeddings
    num_steps = min(max_steps, max_len - 1)

    decoder_input_ids = torch.full((batch_size, 1), bos_token, dtype=torch.long, device=device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    with torch.no_grad():
      encoder_output = self.encoder(source_inpt, source_padding_mask)
      for _ in range(num_steps):
        # No decoder padding: every generated position is a real token.
        logits = self.decoder(decoder_input_ids, encoder_output, source_padding_mask, None)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch, 1)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
        if eos_token is not None:
          finished |= next_token.squeeze(1) == eos_token
          if bool(finished.all()):
            break
    return decoder_input_ids

In [None]:
def loss_function(target, model_output, ignore_index=-100):
  """Mean token-level cross-entropy over every (batch, position) pair.

  Args:
    target: (batch_size, seq_len) integer class ids, one label per position (not embedded).
    model_output: (batch_size, seq_len, vocab_size) raw logits.
    ignore_index: label value excluded from the loss and from the average. The default
      -100 is F.cross_entropy's own default; pass the pad id to skip padding positions.

  Returns:
    Scalar loss tensor.
  """
  vocab_size = model_output.size(-1)
  # Flatten to (N, vocab_size) logits and (N,) labels; reshape also handles
  # non-contiguous label slices such as target[:, 1:].
  return F.cross_entropy(
      model_output.reshape(-1, vocab_size),
      target.reshape(-1),
      ignore_index=ignore_index,
  )


def generate(model, max_seq_len):
  """Placeholder for greedy inference with a trained model; not implemented.

  Returns ``None`` for any arguments.
  """
  # TODO: implement greedy decoding (argmax over the last position at each step).
  return None

# Training on Hugging Face data

## Loading WMT16 (de-en) through the Hugging Face `datasets` interface

In [None]:
import torch
# Check whether a CUDA GPU is visible to this runtime (the notebook displays the result).
torch.cuda.is_available()

True

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

In [None]:
# Reuse the pretrained T5-small tokenizer for both the German source and the English target.
# Only the tokenizer is loaded here, not any T5 model weights.
checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# ! pip install -U datasets

In [None]:
# Load only the first 1% of the WMT16 German-English training split to keep experiments small.
dataset = load_dataset('wmt16','de-en', split="train[:1%]")


README.md:   0%|          | 0.00/11.1k [00:00<?, ?B/s]

train-00000-of-00003.parquet:   0%|          | 0.00/282M [00:00<?, ?B/s]

train-00001-of-00003.parquet:   0%|          | 0.00/267M [00:00<?, ?B/s]

train-00002-of-00003.parquet:   0%|          | 0.00/277M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/343k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/475k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4548885 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2169 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2999 [00:00<?, ? examples/s]

In [None]:
# Display the dataset summary (features and row count).
dataset

Dataset({
    features: ['translation'],
    num_rows: 45489
})

In [None]:
# Load the full WMT16 German-English test split for evaluation and display its summary.
test_dataset = load_dataset('wmt16', 'de-en', split='test')
test_dataset

Dataset({
    features: ['translation'],
    num_rows: 2999
})

In [None]:

# Tokenize German sources and English targets, each padded/truncated to 128 tokens.
# NOTE: the default `tokenizer=tokenizer` binds the tokenizer that exists when this cell runs.
def preprocess_function(dataset_sample, tokenizer=tokenizer):
  """Tokenize one flattened WMT16 example.

  The full target encoding (its input_ids and attention_mask) is stored under "labels";
  the training loop reads it back as batch["labels"]["input_ids"] / ["attention_mask"].
  """
  source_seq = tokenizer(dataset_sample["translation.de"], max_length=128, truncation=True, padding="max_length")
  target_seq = tokenizer(dataset_sample["translation.en"], max_length=128, truncation=True, padding="max_length")
  source_seq["labels"] = target_seq
  return source_seq
# flatten() exposes the nested "translation" field as "translation.de" / "translation.en" columns.
dataset_flat = dataset.flatten()
# NOTE(review): `batch_size` only takes effect when `batched=True` is passed to `map`, so
# examples are tokenized one at a time here. Switching to batched mode would need care,
# because "labels" holds a nested encoding rather than a flat column.
preprocessed_dataset = dataset_flat.map(preprocess_function, batch_size=1000, num_proc=4, remove_columns=["translation.de", "translation.en"])

In [None]:
# Return these columns as PyTorch tensors when the dataset is indexed / batched.
preprocessed_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [None]:
# Apply the same flattening, tokenization and tensor formatting to the test split.
test_dataset = test_dataset.flatten()
preprocessed_test_dataset = test_dataset.map(preprocess_function, batch_size=1000, num_proc=4, remove_columns=["translation.de", "translation.en"])
preprocessed_test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of 10 examples for training.
train_dataloader = DataLoader(preprocessed_dataset, batch_size=10, shuffle=True)

In [None]:
# Mini-batches of 10 test examples.
# NOTE(review): shuffling is not needed for evaluation; shuffle=False would make the
# batch picked for the generation check below reproducible.
test_dataloader = DataLoader(preprocessed_test_dataset, batch_size=10, shuffle=True)

In [None]:
# Use the GPU when available, otherwise the CPU; the cell displays the chosen device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
from torch.optim import Adam

# Hyper-parameter grid; every combination is trained for num_epochs.
config_params = {
    "d_model": [20],
    "num_heads": [4],
    "ff_hidden": [10],
    "num_encoder_layers" : [1],
    "num_decoder_layers": [1],
    "learning_rate": [1e-4]
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.vocab_size
# Size of the learned positional table; must be at least the longest sequence fed to the
# model (preprocessing pads/truncates to 128 tokens).
max_seq_len = 200
num_epochs = 10
# Mean training loss per epoch, appended across every configuration in the grid.
epoch_loss = []
# Label value skipped by F.cross_entropy (its default ignore_index).
IGNORE_INDEX = -100
for d_model in config_params["d_model"]:
  for num_heads in config_params["num_heads"]:
    for ff_hidden in config_params["ff_hidden"]:
      for num_encoder_layers in config_params["num_encoder_layers"]:
        for num_decoder_layers in config_params["num_decoder_layers"]:
          for lr in config_params["learning_rate"]:
            transformer_model = MyTransformer(vocab_size, max_seq_len, d_model, num_heads, ff_hidden, num_encoder_layers, num_decoder_layers)
            transformer_model = transformer_model.to(device)
            optim = Adam(transformer_model.parameters(), lr=lr)

            for epoch in range(num_epochs):

              transformer_model.train()
              total_loss = 0

              for batch in train_dataloader:
                source_inpt = batch["input_ids"].long().to(device)
                source_mask = batch["attention_mask"].long().to(device)
                target = batch["labels"]["input_ids"].long().to(device)
                target_attention = batch["labels"]["attention_mask"].to(device)

                # Teacher forcing: the decoder reads tokens [0, T-1) and predicts tokens [1, T).
                decoder_input = target[:, :-1]
                target_mask = target_attention[:, :-1]  # padding mask matching decoder_input
                # Padded label positions are set to IGNORE_INDEX so padding does not
                # contribute to the loss or its average.
                loss_fct_target = target[:, 1:].masked_fill(target_attention[:, 1:] == 0, IGNORE_INDEX)

                optim.zero_grad()  # reset the gradients
                logits = transformer_model(source_inpt, decoder_input, source_mask, target_mask)
                loss = loss_function(loss_fct_target, logits)
                loss.backward()
                optim.step()
                total_loss += loss.item()
              epoch_loss.append(total_loss / len(train_dataloader))
              print(f"d_model={d_model} heads={num_heads} lr={lr} epoch {epoch}: mean loss {epoch_loss[-1]:.4f}")

In [None]:
# Shape of the last training batch's shifted labels: (batch, seq_len - 1).
loss_fct_target.size()

torch.Size([9, 127])

In [None]:
# Shape of the last training batch's logits: (batch, seq_len - 1, vocab_size).
logits.size()

torch.Size([9, 127, 32100])

In [None]:
# Grab a single test batch (the first one yielded) for a quick generation check.
for idx, batch in enumerate(test_dataloader):
  source_inpt = batch["input_ids"]
  source_mask = batch["attention_mask"].long()
  target = batch["labels"]["input_ids"].long()
  break

In [None]:
# List the names of the special-token attributes the tokenizer class supports.
tokenizer.SPECIAL_TOKENS_ATTRIBUTES

['bos_token',
 'eos_token',
 'unk_token',
 'sep_token',
 'pad_token',
 'cls_token',
 'mask_token',
 'additional_special_tokens']

In [None]:
# NOTE(review): this is the BOS token *string*, not an id, and the T5 tokenizer may define
# no BOS token at all (value None) — confirm before using it as a decoder start token.
bos = tokenizer.bos_token

In [None]:
# generate: greedy decoding for the test batch captured above (`source_inpt`, `source_mask`).
# NOTE(review): if the tokenizer has no BOS token, the pad id is used as the decoder start
# token. The training loop feeds target[:, :-1] without a start token, so the first decoding
# step is not something the model was explicitly trained on — confirm the intended scheme.
transformer_model.eval()
start_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.pad_token_id
gen_source = source_inpt.long().to(device)
gen_source_mask = source_mask.to(device)
decoder_ids = torch.full((gen_source.size(0), 1), start_id, dtype=torch.long, device=device)
with torch.no_grad():
  # The decoder's positional table holds max_seq_len positions.
  for _ in range(max_seq_len - 1):
    decoder_mask = torch.ones_like(decoder_ids)  # generated tokens are never padding
    logits = transformer_model(gen_source, decoder_ids, gen_source_mask, decoder_mask)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    decoder_ids = torch.cat([decoder_ids, next_token], dim=1)
    # Stop once every sequence has produced an EOS token.
    if (decoder_ids == tokenizer.eos_token_id).any(dim=1).all():
      break
outputs = tokenizer.batch_decode(decoder_ids, skip_special_tokens=True)

In [None]:
# Display the module tree of the last trained model.
transformer_model

MyTransformer(
  (encoder): MyEncoder(
    (embedding_layer): MyEmbedding(
      (embedding): Embedding(32100, 20)
    )
    (positional_encoding): MyPositionalEncoding(
      (positionalEncoding): Embedding(200, 20)
    )
    (encoders_layers): ModuleList(
      (0): MyEncoderLayer(
        (multiheadattention): MyMultiHeadAttention(
          (Q_proj): Linear(in_features=20, out_features=20, bias=False)
          (K_proj): Linear(in_features=20, out_features=20, bias=False)
          (V_proj): Linear(in_features=20, out_features=20, bias=False)
        )
        (layer_norm): MyLayerNorm()
        (layer_norm2): MyLayerNorm()
        (ff_network): MyFeedForwardNetwork(
          (layer1): Linear(in_features=20, out_features=10, bias=True)
          (layer2): Linear(in_features=10, out_features=20, bias=True)
          (relu): ReLU()
        )
      )
    )
  )
  (decoder): MyDecoder(
    (embedding_layer): MyEmbedding(
      (embedding): Embedding(32100, 20)
    )
    (positional_enc