<a href="https://colab.research.google.com/github/ricefan-tech/Transformer/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

# Transformer skeleton

In [34]:
class MyEmbedding(nn.Module):
  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, input):
    return self.embedding(input)

class MyPositionalEncoding(nn.Module):
  def __init__(self, seq_len, d_model):
    super().__init__()
    self.positionalEncoding = nn.Linear(seq_len, d_model, bias=False)

  def forward(self, input):
    return input + self.positionalEncoding

class MyMultiHeadAttention(nn.Module):
  def __init__(self, d_model, num_heads):
    super().__init__()
    self.Q_proj = nn.Linear(d_model, d_model, bias=False)
    self.K_proj = nn.Linear(d_model, d_model, bias=False)
    self.V_proj = nn.Linear(d_model, d_model, bias=False)
    self.num_heads = num_heads

  def forward(self, q_input, v_input, k_input, padding_mask=None, causal_mask=None):
    Q = self.Q_matrix(q_input) # now has shape (batch_size, seq_len, d_model)
    K = self.K_proj(k_input)
    V = self.V_matrix(v_input)

    # before attn score calculation need to reshape into multi head
    batch_size, seq_len, d_model = Q.size()
    Q = Q.reshape(batch_size, seq_len, self.num_heads, d_model//self.num_heads).transpose(1, 2) #d_model must be divisible by num_heads
    K = K.reshape(batch_size, seq_len, self.num_heads, d_model//self.num_heads).transpose(1, 2)
    V = V.reshape(batch_size, seq_len, self.num_heads, d_model//self.num_heads).transpose(1, 2)

    if padding_mask:
      # padding mask is of shape (batch_size, seq_len), needs to be broadcasted to match Q@K.T which is (batch_size, num_head, seq_len, seq_len)
      padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
      attention_scores = Q @ K.transpose(-2, -1) / math.sqrt(d_model//self.num_heads)
      attention_scores = attention_scores.masked_fill(padding_mask, float("-inf"))

    else:
      attention_scors = Q @ K.transpose(-2, -1) / math.sqrt(d_model//self.num_heads)

    if causal_mask:
      # causal_mask is lower triangular matrix of shape target_seq_len, target_seq_len
      causal_mask = causal_mask.unsqueeze(1).unsqueeze(2)
      attention_scores = attention_scores.masked_fill(causal_mask, float("-inf"))

    attention_weights = torch.nn.softmax(attention_scores, dim=-1) @ V # is shape (batch_size, num_head, seq_len, head_dim)
    return attention_weights.reshape(batch_size, seq_len, d_model)


class MyFeedForwardNetwork(nn.Module):
  def __init__(self, d_model, ff_hidden):
    super().__init__()
    self.layer1 = nn.Linear(d_model, ff_hidden)
    self.layer2 = nn.Linear(ff_hidden, d_model)

  def forward(self, input):
    return self.layer2(nn.ReLU(self.layer1(input)))


class MyLayerNorm(nn.Module):
  def __init__(self, d_model, eps=1e-6):
    super().__init__()
    self.eps = eps #stabuliser for division

    self.gamma = nn.Parameter(torch.ones(d_model))
    self.beta = nn.Parameter(torch.zeros(d_model))

  def forward(self, input):
    mean = input.mean(dim=-1, keepdim=True)
    var = input.var(dim=-1, keepdim=True) #keeps broadcasted shape
    normalised_input = (input - mean)/ math.sqrt(var)
    return self.gamma * normalised_input + self.beta


class MyEncoderLayer(nn.Module):
  def __init__(self, vocab_size, seq_len, d_model, num_heads, ff_hidden):
    super().__init__()
    self.embedding_layer = MyEmbedding(d_model, vocab_size)
    self.positional_encoding = MyPositionalEncoding(seq_len, d_model)
    self.multiheadattention = MyMultiHeadAttention(d_model, num_heads)
    self.layer_norm = MyLayerNorm(d_model)
    self.layer_norm2 = MyLayerNorm(d_model)
    self.ff_network = MyFeedForwardNetwork(d_model, ff_hidden)

  def forward(self, input, padding_mask):
    embedded_input = self.embedding_layer(input)
    total_input = self.positional_encoding(embedded_input)
    attention_weights = self.multiheadattention(total_input, total_input, total_input, padding_mask=padding_mask, causal_mask=None)
    resid_conn = total_input+attention_weights
    noramlised_resid_conn = self.layer_norm(resid_conn)
    ff_network = self.ff_network(noramlised_resid_conn)
    resid_conn2 = noramlised_resid_conn+ff_network
    normalised_resid_conn = self.layer_norm2(resid_conn2)
    return normalised_resid_conn


class MyEncoder(nn.Module):
  def __init__(self, vocab_size, seq_len, d_model, num_heads, ff_hidden, num_encoder_layers):
    super().__init__()
    self.encoders = nn.ModuleList([MyEncoderLayer(vocab_size, seq_len, d_model, num_heads, ff_hidden) for _ in range(num_encoder_layers)])

  def forward(self, input, padding_mask):
    return self.encoders(input, padding_mask)


class MyOutputLayer(nn.Module):
  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.W_output = nn.Linear(d_model, vocab_size)

  def forward(self, input):
    return self.W_output(input)

class MyDecoderLayer(nn.Module):
  def __init__(self, vocab_size, seq_len, d_model, num_heads, ff_hidden):
    super().__init__()
    self.embedding_layer = MyEmbedding(d_model, vocab_size)
    self.positional_encoding = MyPositionalEncoding(seq_len, d_model)
    self.multiheadattention = MyMultiHeadAttention(d_model, num_heads)
    self.layer_norm = MyLayerNorm(d_model)
    self.layer_norm2 = MyLayerNorm(d_model)
    self.layer_norm3 = MyLayerNorm(d_model)
    self.layer_norm4 = MyLayerNorm(d_model)
    self.ff_network = MyFeedForwardNetwork(d_model, ff_hidden)
    self.ff_network2 = MyFeedForwardNetwork(d_model, ff_hidden)
    self.cross_multiheadattention = MyMultiHeadAttention(d_model, num_heads)
    self.output_layer = MyOutputLayer(d_model, vocab_size)

  def forward(self, input, encoder_output, padding_mask):
    embedded_input = self.embedding_layer(input)
    total_input = self.positional_encoding(embedded_input)
    # self attention first
    seq_len = input.size(0)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    # ocnvert to boolean for the masked_fill() function
    causal_mask = causal_mask == 0 # masked_Fill fills the positions with -inf that are True
    attention_weights = self.multiheadattention(total_input, total_input, total_input, padding_mask=self.padding_mask, causal_mask=causal_mask)
    resid_conn = total_input+attention_weights
    noramlised_resid_conn = self.layer_norm(resid_conn)
    ff_network = self.ff_network(noramlised_resid_conn)
    resid_conn2 = noramlised_resid_conn+ff_network
    normalised_resid_conn = self.layer_norm2(resid_conn2)

    crossattention_weights = self.cross_multiheadattention(normalised_resid_conn, encoder_output, encoder_output, padding_mask=self.padding_mask)
    resid_conn = normalised_resid_conn+crossattention_weights
    noramlised_resid_conn = self.layer_norm3(resid_conn)
    ff_network = self.ff_network2(noramlised_resid_conn)
    resid_conn2 = noramlised_resid_conn+ff_network
    normalised_resid_conn = self.layer_norm4(resid_conn2)
    return self.output_layer(normalised_resid_conn)


class MyDecoder(nn.Module):
  def __init__(self, vocab_size, seq_len, d_model, num_heads, ff_hidden, num_decoder_layers):
    super().__init__()
    self.decoder_layers = nn.ModuleList([MyDecoderLayer(vocab_size, seq_len, d_model, num_heads, ff_hidden) for _ in range(num_decoder_layers)])

  def forward(self, input, padding_mask):
    return self.decoder_layers(input, padding_mask)


class MyTransformer(nn.Module):
  def __init__(self, vocab_size, seq_len, d_model, num_heads, ff_hidden, num_encoder_layers, num_decoder_layers):
    super().__init__()
    self.encoder = MyEncoder(vocab_size, seq_len, d_model, num_heads, ff_hidden, num_encoder_layers)
    self.decoder = MyDecoder(vocab_size, seq_len, d_model, num_heads, ff_hidden, num_decoder_layers)

  def forward(self, source, target, padding_mask, max_seq_len):
    encoder_output = self.encoder(source, padding_mask)
    decoder_output = self.decoder(target, encoder_output, padding_mask) #raw logits, F.CrossEntropy already does softmax

    return decoder_output

In [5]:
def loss_function(target, model_output):
  # target is (batch_size, seq_len), no embedding, essnetially  label per sample (sample being one of seq_len*batch), ie all positions
  # model output is (batch_size, seq_len, vocab_size)
  vocab_size = model_output.size(-1)
  return F.cross_entropy(target.reshape(-1,), model_output.reshape(-1, vocab_size))


def generate(model, max_seq_len):
  # inference from ready trained model with greedy appraoch
  return

# Training on Huggingface data

## Laoding in WMT16 from Huggingface interface

In [6]:
import torch
torch.cuda.is_available()

True

In [11]:
from datasets import load_dataset
from transformers import AutoTokenizer

In [None]:
checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
dataset = load_dataset('wmt16','de-en', split="train")


In [None]:
def preprocess_function(dataset_sample, tokenizer=tokenizer):
  source_seq = [sample["de"] for sample in dataset_sample["translation"]]
  target_seq = [sample["en"] for sample in dataset_sample["translation"]]
  output_tokens = tokenizer(source_seq, target_seq, max_length=128, truncation=True, padding=True)
preprocessed_dataset = dataset.map(preprocess_function, batched=True)

In [35]:
from torch.optim import Adam
from torch.utils.data import DataLoader

train_dataloader = DataLoader(preprocessed_dataset["train"], batch_size = 14)
config_params = {
    "d_model": [512],
    "num_heads": [8],
    "ff_hidden": [1024],
    "num_encoder_layers" : [2, 4],
    "num_decoder_layers": [2,4],
    "learning_rate": [1e-4]
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vocab_size = tokenizer.vocab_size
seq_len = 128 # as set in the preprocess fct
num_epochs = 10

for d_model in config_params["d_model"]:
  for num_heads in config_params["num_heads"]:
    for ff_hidden in config_params["ff_hidden"]:
      for num_encoder_layers in config_params["num_encoder_layers"]:
        for num_decoder_layers in config_params["num_decoder_layers"]:
          for lr in config_params["learning_rate"]:
            transformer_model = MyTransformer(vocab_size, seq_len, d_model, num_heads, ff_hidden, num_encoder_layers, num_decoder_layers)
            transformer_model = transformer_model.to(device)
            optim = Adam(transformer_model.parameters(), lr=lr)

            for epoch in range(num_epochs):
              transformer_model.train()
              total_loss = 0
              for batch in dataset:
                source_inpt, source_padding, target, target_padding = batch
                source_inpt = source_inpt.to(device)
                source_padding = source_padding.to(device)

                optim.zero_grad() # reset the gradients
                logits = transformer_model(source_inpt, source_padding, seq_len)
                loss = loss_function(target, logits)
                loss.backward()
                optim.step()
                total_loss += loss.item()
              if epoch % 2 ==0:
                print(total_loss//dataset.size(0)) # loss per sample in batch

ValueError: not enough values to unpack (expected 4, got 1)