In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import math

In [2]:
@dataclass
class Config:
  """Hyperparameters shared by every component of the encoder-decoder transformer."""
  src_vocab_size: int = 32005   # rows in the source embedding table
  tgt_vocab_size: int = 32005   # rows in the target embedding table and width of the output projection
  src_seq_len: int = 1024       # number of source positions precomputed by the positional encoding
  tgt_seq_len: int = 1024       # number of target positions precomputed by the positional encoding
  d_model: int = 512            # width of the embeddings and of every sub-layer output
  n_heads: int = 8              # attention heads; d_model must be divisible by this
  n_layers: int = 6             # number of encoder blocks and of decoder blocks
  dropout: float = 0.1          # dropout probability used throughout the model
  d_ff: int = 2048              # hidden width of the feed-forward block

## InputEmbeddings

In [3]:
class InputEmbeddings(nn.Module):
  """Learned lookup table that maps token ids to dense vectors of width d_model."""

  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

  def forward(self, x):
    """Embed integer token ids: (B, T) -> (B, T, d_model)."""
    token_vectors = self.embedding(x)
    return token_vectors

## PositionalEncoding

In [4]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position information to token embeddings, then applies dropout.

  Args:
    d_model: width of the embeddings (odd values are supported).
    seq_len: number of positions to precompute; inputs may be shorter but not longer.
    dropout: dropout probability applied to the sum of embedding and encoding.
  """

  def __init__(self, d_model: int, seq_len: int, dropout: float):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    pe = torch.zeros(seq_len, d_model)    # (T, d_model)

    # one row per position
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)   # (T, 1)

    # one frequency per (sin, cos) column pair: 10000^(-2i / d_model), computed in log space
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).unsqueeze(0)   # (1, ceil(d_model/2))

    angles = position * div_term    # (T, ceil(d_model/2))
    pe[:, 0::2] = torch.sin(angles)
    # an odd d_model has one fewer cosine column than sine columns, so trim the angles to fit
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    pe = pe.unsqueeze(0)    # (1, T, d_model)
    # registered as a buffer: stored in the state dict under 'pe' but not a trainable parameter
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Add the first T positional rows to x and apply dropout: (B, T, d_model) -> (B, T, d_model)."""
    # the buffer does not require grad, so its slice can be added directly
    x = x + self.pe[:, :x.shape[1], :]
    return self.dropout(x)

## LayerNormalization

In [5]:
class LayerNormalization(nn.Module):
  """Normalizes the last dimension by its mean and standard deviation, then applies a learned scale and shift."""

  def __init__(self, d_model: int, eps: float=1e-6):
    super().__init__()
    self.eps = eps
    # learnable per-feature gain (starts at 1) and offset (starts at 0)
    self.alpha = nn.Parameter(torch.ones(d_model))
    self.bias = nn.Parameter(torch.zeros(d_model))

  def forward(self, x):
    """Normalize x over its last dimension; the output has the same shape as x."""
    centered = x - x.mean(dim=-1, keepdim=True)
    # eps is added to the standard deviation itself so the division never hits zero
    spread = x.std(dim=-1, keepdim=True) + self.eps
    return self.alpha * centered / spread + self.bias

## ResidualConnection

In [6]:
class ResidualConnection(nn.Module):
  """Pre-norm residual wrapper: x + dropout(sublayer(norm(x)))."""

  def __init__(self, config: Config):
    super().__init__()
    self.norm = LayerNormalization(config.d_model)
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, x, sublayer):
    """Run `sublayer` on the normalized input and add the result back onto x.

    x: (B, T, d_model); sublayer: callable applied to the normalized x.
    """
    normalized = self.norm(x)
    branch = self.dropout(sublayer(normalized))
    return x + branch

## FeedForwardBlock

In [7]:
class FeedForwardBlock(nn.Module):
  """Two-layer MLP: expand to d_ff, apply ReLU and dropout, project back to d_model."""

  def __init__(self, config: Config):
    super().__init__()
    self.up_proj = nn.Linear(config.d_model, config.d_ff)
    self.down_proj = nn.Linear(config.d_ff, config.d_model)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, x):
    """(B, T, d_model) -> (B, T, d_ff) -> (B, T, d_model)."""
    hidden = self.up_proj(x)
    hidden = self.relu(hidden)
    hidden = self.dropout(hidden)
    return self.down_proj(hidden)

## MultiHeadAttentionBlock

In [8]:
class MultiHeadAttentionBlock(nn.Module):
  """Multi-head scaled dot-product attention, usable for both self- and cross-attention.

  The config must provide `d_model`, `n_heads` (a divisor of d_model) and `dropout`.
  """

  def __init__(self, config: "Config"):
    super().__init__()

    self.d_model = config.d_model
    self.n_heads = config.n_heads

    assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"

    self.head_size = self.d_model // self.n_heads
    self.w_q = nn.Linear(self.d_model, self.d_model, bias=False)
    self.w_k = nn.Linear(self.d_model, self.d_model, bias=False)
    self.w_v = nn.Linear(self.d_model, self.d_model, bias=False)
    self.w_o = nn.Linear(self.d_model, self.d_model, bias=False)
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, q, k, v, mask):
    """Attend from q over the (k, v) pairs.

    Args:
      q: (B, T_q, d_model) query source.
      k: (B, T_kv, d_model) key source.
      v: (B, T_kv, d_model) value source.
      mask: tensor broadcastable to (B, n_heads, T_q, T_kv); positions where
        it equals 0 are excluded from attention. None disables masking.

    Returns:
      Tensor of shape (B, T_q, d_model).
    """
    B, T_q = q.shape[0], q.shape[1]

    query = self.w_q(q)   # (B, T_q, d_model)
    key = self.w_k(k)     # (B, T_kv, d_model)
    value = self.w_v(v)   # (B, T_kv, d_model) -- values come from v, not from q

    # split d_model into heads; key/value keep their own length so cross-attention
    # works when the source and target sequence lengths differ
    query = query.view(B, T_q, self.n_heads, self.head_size).transpose(1, 2)         # (B, nH, T_q, head_size)
    key = key.view(B, k.shape[1], self.n_heads, self.head_size).transpose(1, 2)      # (B, nH, T_kv, head_size)
    value = value.view(B, v.shape[1], self.n_heads, self.head_size).transpose(1, 2)  # (B, nH, T_kv, head_size)

    attention_scores = query @ key.transpose(2, 3) / math.sqrt(self.head_size)       # (B, nH, T_q, T_kv)

    if mask is not None:
      # in-place is safe: attention_scores is a fresh temporary created by the division above
      attention_scores = attention_scores.masked_fill_(mask == 0, -1e9)

    attention_weights = attention_scores.softmax(dim=-1)    # (B, nH, T_q, T_kv)

    # dropout must follow the softmax: dropping raw scores would turn masked -1e9
    # entries into 0 and let masked positions receive attention
    attention_weights = self.dropout(attention_weights)

    z = attention_weights @ value    # (B, nH, T_q, head_size)

    # merge the heads back into d_model
    z = z.transpose(1, 2).contiguous().view(B, T_q, self.d_model)

    return self.w_o(z)

## EncoderBlock

In [9]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention, then feed-forward, each wrapped in its own residual connection."""

  def __init__(self, config: Config, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    # index 0 wraps self-attention, index 1 wraps the feed-forward block
    self.residual_connections = nn.ModuleList([ResidualConnection(config) for _ in range(2)])

  def forward(self, x, mask):
    """Apply both sub-layers to x.

    x: (B, T, d_model); mask is forwarded unchanged to the self-attention block.
    Returns a tensor with the same shape as x.
    """
    x = self.residual_connections[0](x, lambda h: self.self_attention_block(h, h, h, mask))
    # each sub-layer has its own residual connection so normalization parameters are not shared
    x = self.residual_connections[1](x, self.feed_forward_block)
    return x

## DecoderBlock

In [10]:
class DecoderBlock(nn.Module):
  """One decoder layer: self-attention, cross-attention over the encoder output, then feed-forward."""

  def __init__(self, config: Config, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block
    # index 0 wraps self-attention, index 1 cross-attention, index 2 the feed-forward block
    self.residual_connections = nn.ModuleList([ResidualConnection(config) for _ in range(3)])

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    """Apply the three sub-layers to x.

    x: decoder-side activations; encoder_output: keys/values for cross-attention.
    tgt_mask is passed to self-attention, src_mask to cross-attention.
    Returns a tensor with the same shape as x.
    """
    x = self.residual_connections[0](x, lambda h: self.self_attention_block(h, h, h, tgt_mask))
    # every sub-layer has its own residual connection so normalization parameters are not shared
    x = self.residual_connections[1](x, lambda h: self.cross_attention_block(h, encoder_output, encoder_output, src_mask))
    x = self.residual_connections[2](x, self.feed_forward_block)
    return x


## Encoder

In [11]:
class Encoder(nn.Module):
  """Stack of encoder blocks followed by a final layer normalization."""

  def __init__(self, config: Config, layers: nn.ModuleList):
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization(config.d_model)

  def forward(self, x, mask):
    """Pass x through every layer in order, using the same mask, then normalize."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, mask)
    return self.norm(hidden)

## Decoder

In [12]:
class Decoder(nn.Module):
  """Stack of decoder blocks followed by a final layer normalization."""

  def __init__(self, config: Config, layers: nn.ModuleList):
    super().__init__()
    self.norm = LayerNormalization(config.d_model)
    self.layers = layers

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    """Pass x through every layer in order with the same encoder output and masks, then normalize."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, encoder_output, src_mask, tgt_mask)
    return self.norm(hidden)

## ProjectionLayer

In [13]:
class ProjectionLayer(nn.Module):
  """Linear map from model width to one score per vocabulary entry."""

  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.proj = nn.Linear(in_features=d_model, out_features=vocab_size)

  def forward(self, x):
    """(B, T, d_model) -> (B, T, vocab_size)."""
    scores = self.proj(x)
    return scores

## Transformer

In [14]:
class Transformer(nn.Module):
  """Container that wires embeddings, positional encodings, encoder, decoder and output projection together.

  There is no `forward`; callers invoke `encode`, `decode` and `project` in turn.
  """

  def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.projection_layer = projection_layer

  def encode(self, x, mask):
    """Embed source ids (B, T), add positions and run the encoder -> (B, T, d_model)."""
    embedded = self.src_pos(self.src_embed(x))
    return self.encoder(embedded, mask)

  def decode(self, tgt, encoder_output, src_mask, tgt_mask):
    """Embed target ids (B, T), add positions and run the decoder against the encoder output."""
    embedded = self.tgt_pos(self.tgt_embed(tgt))
    return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

  def project(self, x):
    """Map decoder activations to vocabulary scores with the projection layer."""
    return self.projection_layer(x)



### Function to build the transformer model

In [15]:
def build_transformer(config: Config):
  """Assemble a full encoder-decoder Transformer from `config` and Xavier-initialize its weight matrices."""
  src_embed = InputEmbeddings(config.d_model, config.src_vocab_size)
  tgt_embed = InputEmbeddings(config.d_model, config.tgt_vocab_size)

  src_pos = PositionalEncoding(config.d_model, config.src_seq_len, config.dropout)
  tgt_pos = PositionalEncoding(config.d_model, config.tgt_seq_len, config.dropout)

  # call arguments are evaluated left to right, so each layer creates its
  # attention block(s) first and its feed-forward block last
  encoder_blocks = [
      EncoderBlock(config, MultiHeadAttentionBlock(config), FeedForwardBlock(config))
      for _ in range(config.n_layers)
  ]
  decoder_blocks = [
      DecoderBlock(config, MultiHeadAttentionBlock(config), MultiHeadAttentionBlock(config), FeedForwardBlock(config))
      for _ in range(config.n_layers)
  ]

  encoder = Encoder(config, nn.ModuleList(encoder_blocks))
  decoder = Decoder(config, nn.ModuleList(decoder_blocks))

  projection_layer = ProjectionLayer(config.d_model, config.tgt_vocab_size)

  transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

  # only parameters with more than one dimension are re-initialized; 1-D ones keep their defaults
  for weight in transformer.parameters():
    if weight.dim() <= 1:
      continue
    nn.init.xavier_uniform_(weight)

  return transformer

In [16]:
# instantiate the default hyperparameters; the bare name on the last line displays them
config = Config()
config

Config(src_vocab_size=32005, tgt_vocab_size=32005, src_seq_len=1024, tgt_seq_len=1024, d_model=512, n_heads=8, n_layers=6, dropout=0.1, d_ff=2048)

In [17]:
# build the full encoder-decoder model from the configuration
model = build_transformer(config)

In [18]:
# display the module tree to check the architecture
model

Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (self_attention_block): MultiHeadAttentionBlock(
          (w_q): Linear(in_features=512, out_features=512, bias=False)
          (w_k): Linear(in_features=512, out_features=512, bias=False)
          (w_v): Linear(in_features=512, out_features=512, bias=False)
          (w_o): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward_block): FeedForwardBlock(
          (up_proj): Linear(in_features=512, out_features=2048, bias=True)
          (down_proj): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (residual_connections): ModuleList(
          (0-1): 2 x ResidualConnection(
            (norm): LayerNormalization()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (

## Prepare Dataset

In [19]:
!pip install datasets -q

In [20]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, random_split

In [21]:
# load the 'train' split of the 'en-fr' configuration of Helsinki-NLP/opus_books
dataset = load_dataset('Helsinki-NLP/opus_books', 'en-fr', split='train')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [22]:
import random

# show both sides of one randomly chosen sentence pair
idx = random.randint(0, len(dataset) - 1)

for heading, lang_key in (('English: ', 'en'), ('French: ', 'fr')):
  print(heading)
  print(dataset[idx]['translation'][lang_key])

English: 
"Let her be then, old man! It's the Piolaine young lady," cried Maheude to the grandfather, recognizing Cécile, whose veil had been torn off by one of the women.
French: 
—Laissez-la donc, vieux! c'est la demoiselle de la Piolaine! cria la Maheude au grand-pere, en reconnaissant Cécile, dont une femme avait déchiré la voilette.


### Tokenizer

Let's use the Llama 2 tokenizer (`meta-llama/Llama-2-7b-hf`).

In [23]:
import os
from huggingface_hub import login
from getpass import getpass
# ask for the token interactively instead of hardcoding it in the notebook
hf_token = getpass("Enter your Hugging Face token: ")
login(hf_token)
print("Logged in successfully!")

Enter your Hugging Face token: ··········
Logged in successfully!


In [24]:
from transformers import AutoTokenizer

In [25]:
# load the tokenizer that belongs to the Llama 2 checkpoint; the last line displays its settings
tokenizer_id = 'meta-llama/Llama-2-7b-hf'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer

LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-7b-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [26]:
# the special_tokens listing of the loaded tokenizer has no pad token, so register '<pad>' for padding
tokenizer.add_special_tokens({'pad_token': '<pad>'})

1

In [27]:
# display the tokenizer again to confirm that '<pad>' is now registered
tokenizer

LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-7b-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [28]:
# [1:] drops the first id returned by encode (presumably the '<s>' token — TODO confirm)
tokenizer.encode('Hello, how are you?')[1:]

[15043, 29892, 920, 526, 366, 29973]

In [29]:
# decode the ids 1, 2 and 32000 to see which special tokens they stand for
tokenizer.decode(1), tokenizer.decode(2), tokenizer.decode(32000)

('<s>', '</s>', '<pad>')

### Create DataLoader

In [30]:
class BilingualDataset(Dataset):
  """Wraps English-French translation pairs as fixed-length id tensors plus attention masks."""

  def __init__(self, ds, tokenizer, seq_len):
    super().__init__()
    self.seq_len = seq_len
    self.ds = ds
    self.tokenizer = tokenizer
    self.bos_token_id = tokenizer.bos_token_id
    self.eos_token_id = tokenizer.eos_token_id
    self.pad_token_id = tokenizer.pad_token_id

  def __len__(self):
    """Number of sentence pairs in the wrapped dataset."""
    return len(self.ds)

  def __getitem__(self, idx):
    """Build the training example for pair `idx`.

    Returns a dict with `encoder_input`, `decoder_input` and `label` (each of
    length seq_len), `encoder_mask` (1, 1, seq_len), `decoder_mask`
    (1, seq_len, seq_len) and the raw `src_text` / `tgt_text`.

    Raises:
      ValueError: if either sentence does not fit into seq_len.
    """
    pair = self.ds[idx]['translation']
    src_text, tgt_text = pair['en'], pair['fr']

    # the first id returned by encode() is sliced off; special tokens are added explicitly below
    src_tokens = self.tokenizer.encode(src_text)[1:]
    tgt_tokens = self.tokenizer.encode(tgt_text)[1:]

    # the source gets both <s> and </s>; decoder input and label get one special token each
    src_pad_count = self.seq_len - len(src_tokens) - 2
    tgt_pad_count = self.seq_len - len(tgt_tokens) - 1

    if min(src_pad_count, tgt_pad_count) < 0:
      raise ValueError(f"Sentence with id: {idx} is too long")

    bos = torch.tensor([self.bos_token_id])
    eos = torch.tensor([self.eos_token_id])
    src_ids = torch.tensor(src_tokens, dtype=torch.int64)
    tgt_ids = torch.tensor(tgt_tokens, dtype=torch.int64)
    src_padding = torch.tensor([self.pad_token_id] * src_pad_count, dtype=torch.int64)
    tgt_padding = torch.tensor([self.pad_token_id] * tgt_pad_count, dtype=torch.int64)

    # <s> source </s> <pad>...
    encoder_input = torch.cat([bos, src_ids, eos, src_padding], dim=0)
    # <s> target <pad>...
    decoder_input = torch.cat([bos, tgt_ids, tgt_padding], dim=0)
    # target </s> <pad>...  (the decoder input shifted left by one position)
    label = torch.cat([tgt_ids, eos, tgt_padding], dim=0)

    for sequence in (encoder_input, decoder_input, label):
      assert sequence.size(0) == self.seq_len

    encoder_keep = (encoder_input != self.pad_token_id).unsqueeze(0).unsqueeze(0).int()    # (1, 1, seq_len)
    decoder_keep = (decoder_input != self.pad_token_id).unsqueeze(0).unsqueeze(0).int()    # (1, 1, seq_len)

    return {
        "encoder_input": encoder_input,   # (seq_len)
        "decoder_input": decoder_input,   # (seq_len)
        "encoder_mask": encoder_keep,     # (1, 1, seq_len)
        "decoder_mask": decoder_keep & causal_mask(decoder_input.size(0)),    # (1, 1, seq_len) & (1, seq_len, seq_len)
        "label": label,  # (seq_len)
        "src_text": src_text,
        "tgt_text": tgt_text,
    }

def causal_mask(size):
  """Return a (1, size, size) boolean mask that is True where column <= row."""
  lower_triangle = torch.tril(torch.ones(1, size, size))
  return lower_triangle.type(torch.bool)


In [31]:
BATCH_SIZE = 4
# 90% of the sentence pairs for training, the remainder for validation
train_ds_size = int(0.9 * len(dataset))
val_ds_size = len(dataset) - train_ds_size

# NOTE(review): no generator/seed is passed, so the split presumably differs between runs — confirm
train_ds, val_ds = random_split(dataset, [train_ds_size, val_ds_size])

# wrap each split so that items come out tokenized and padded to seq_len
train_ds = BilingualDataset(train_ds, tokenizer, seq_len=1024)
val_ds = BilingualDataset(val_ds, tokenizer, seq_len=1024)

# training batches are shuffled; validation yields one example at a time
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=1)

In [32]:
# pull a single batch from the training loader for inspection
data = next(iter(train_dataloader))

In [33]:
# list the fields contained in the batch
data.keys()

dict_keys(['encoder_input', 'decoder_input', 'encoder_mask', 'decoder_mask', 'label', 'src_text', 'tgt_text'])

In [34]:
# print the shape of every tensor field in the batch
for field in ('encoder_input', 'decoder_input', 'encoder_mask', 'decoder_mask', 'label'):
  print(f'{field}: ', data[field].shape)

encoder_input:  torch.Size([4, 1024])
decoder_input:  torch.Size([4, 1024])
encoder_mask:  torch.Size([4, 1, 1, 1024])
decoder_mask:  torch.Size([4, 1, 1024, 1024])
label:  torch.Size([4, 1024])


## Training

In [35]:
# use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [36]:
# move the model's parameters and buffers to the selected device
model.to(device)

Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (self_attention_block): MultiHeadAttentionBlock(
          (w_q): Linear(in_features=512, out_features=512, bias=False)
          (w_k): Linear(in_features=512, out_features=512, bias=False)
          (w_v): Linear(in_features=512, out_features=512, bias=False)
          (w_o): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward_block): FeedForwardBlock(
          (up_proj): Linear(in_features=512, out_features=2048, bias=True)
          (down_proj): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (residual_connections): ModuleList(
          (0-1): 2 x ResidualConnection(
            (norm): LayerNormalization()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (

In [37]:
# Adam over all model parameters with a fixed learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [38]:
# label positions equal to the pad token id are ignored by the loss
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id).to(device)

##### Sneak peek into CrossEntropyLoss

In [39]:
# nn.CrossEntropyLoss applies log-softmax itself, so it is fed raw,
# unnormalized scores; applying softmax first would normalize twice
B = 2
Classes = 10
pred = torch.rand(B, Classes)    # (B, Classes) unnormalized scores
label = torch.tensor([2, 5], dtype=torch.int64)   # (B) target class indices
loss = criterion(pred, label).to(device)
loss

tensor(2.3052, device='cuda:0')

### Training epoch

In [40]:
from tqdm import tqdm

In [None]:
# Train for EPOCHS passes over train_dataloader; the running loss is shown in the progress bar.
EPOCHS = 10
for epoch in range(EPOCHS):
  model.train()   # put the model in training mode
  batch_iterator = tqdm(train_dataloader, desc=f"Processing epoch {epoch:02d}")
  for batch in batch_iterator:
    # move the batch tensors to the training device
    encoder_input = batch['encoder_input'].to(device)    # (B, seq_len)
    decoder_input = batch['decoder_input'].to(device)    # (B, seq_len)
    encoder_mask = batch['encoder_mask'].to(device)      # (B, 1, 1, seq_len)
    decoder_mask = batch['decoder_mask'].to(device)      # (B, 1, seq_len, seq_len)

    # forward pass: encode the source, decode against it, project to vocabulary scores
    encoder_output = model.encode(encoder_input, encoder_mask)      # (B, seq_len, d_model)
    decoder_output = model.decode(decoder_input, encoder_output, encoder_mask, decoder_mask)     # (B, seq_len, d_model)
    proj_output = model.project(decoder_output)     # (B, seq_len, tgt_vocab_size)

    label = batch['label'].to(device)    # (B, seq_len)
    # flatten batch and time so every position is scored as one classification
    loss = criterion(proj_output.view(-1, proj_output.shape[-1]), label.view(-1))
    batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

    # compute gradients
    loss.backward()

    # update the weights, then clear the gradients for the next batch
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)


Processing epoch 00:   0%|          | 60/28594 [00:49<6:29:19,  1.22it/s, loss=7.712]