# **Build your own Large Language Model (LLM) from scratch using PyTorch**
**A step-by-step guide to building and training an LLM called MalayGPT, which translates text from English to Malay.**

<a target="_blank" href="https://colab.research.google.com/github/tamangmilan/llm_from_scratch/medium2_build_llm_from_scratch_using_pytorch.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
# Install Hugging Face `datasets` (used to download OPUS-100 in Step 1) and
# `tokenizers` (used to train the BPE tokenizers in Step 2).
!pip install datasets
!pip install tokenizers

In [None]:
#Step1: Load the data and separate into train, validation and test data

import os
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from datasets import load_dataset
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

raw_train_dataset = load_dataset("Helsinki-NLP/opus-100", "en-ms", split='train')
raw_validation_dataset = load_dataset("Helsinki-NLP/opus-100", "en-ms", split='validation')
raw_test_dataset = load_dataset("Helsinki-NLP/opus-100", "en-ms", split='test')

# Keep only the first 500k training pairs to bound tokenizer and training cost.
# Slicing a `datasets.Dataset` returns a dict of columns, so afterwards
# raw_train_dataset["translation"] is a plain list of {"en": ..., "ms": ...} dicts.
raw_train_dataset = raw_train_dataset[:500000]

# Output folders for the text shards, model checkpoints and tokenizers.
# exist_ok=True lets this cell be re-run without raising FileExistsError.
for folder in ("./dataset-en", "./dataset-my", "./malaygpt", "./tokenizer_en", "./tokenizer_my"):
    os.makedirs(folder, exist_ok=True)

# Number of sentences written to each shard file used for tokenizer training.
LINES_PER_FILE = 50000


def _write_shard(lines, folder, index):
    """Write `lines` (one sentence per line) to {folder}/file{index}.txt."""
    with open(f'{folder}/file{index}.txt', 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(lines))


dataset_en = []
dataset_my = []
file_count = 1

for data in tqdm(raw_train_dataset["translation"]):
    # An embedded newline would split one sentence across two lines of the shard.
    dataset_en.append(data["en"].replace('\n', " "))
    dataset_my.append(data["ms"].replace('\n', " "))
    if len(dataset_en) == LINES_PER_FILE:
        _write_shard(dataset_en, './dataset-en', file_count)
        _write_shard(dataset_my, './dataset-my', file_count)
        dataset_en = []
        dataset_my = []
        file_count += 1

# Flush the final partial shard. Without this, any remainder smaller than
# LINES_PER_FILE would be silently left out of tokenizer training.
if dataset_en:
    _write_shard(dataset_en, './dataset-en', file_count)
    _write_shard(dataset_my, './dataset-my', file_count)
    dataset_en = []
    dataset_my = []
    file_count += 1

In [None]:
#Step2: Create tokenizers

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# Special tokens for both tokenizers. Both trainers receive them in the same order.
# The check further below enforces that they end up with identical ids, because encoder
# inputs (English) are built with [CLS]/[SEP]/[PAD] ids read from the Malay tokenizer.
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]


def _train_bpe_tokenizer(files, trainer, save_path):
    """Train a whitespace-pre-tokenized BPE tokenizer on `files`, save it to `save_path` and return it."""
    if not files:
        raise FileNotFoundError(f"No training text files found for {save_path}; run Step 1 first.")
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    # Without a pre-tokenizer, BPE could learn tokens that span several words (e.g. "there is").
    # Whitespace pre-tokenization guarantees no token is bigger than a single word.
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.train(files=files, trainer=trainer)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    tokenizer.save(save_path)
    return tokenizer


# Sorted so the trainer sees the shard files in the same order on every run.
path_en = sorted(str(file) for file in Path('./dataset-en').glob("**/*.txt"))
path_my = sorted(str(file) for file in Path('./dataset-my').glob("**/*.txt"))

# Create Source Tokenizer - English
trainer_en = BpeTrainer(min_frequency=2, special_tokens=list(SPECIAL_TOKENS))
tokenizer_en = _train_bpe_tokenizer(path_en, trainer_en, "./tokenizer_en/tokenizer_en.json")

# Create Target Tokenizer - Malay
trainer_my = BpeTrainer(min_frequency=2, special_tokens=list(SPECIAL_TOKENS))
tokenizer_my = _train_bpe_tokenizer(path_my, trainer_my, "./tokenizer_my/tokenizer_my.json")

# Reload both tokenizers from the saved JSON files.
tokenizer_en = Tokenizer.from_file("./tokenizer_en/tokenizer_en.json")
tokenizer_my = Tokenizer.from_file("./tokenizer_my/tokenizer_my.json")

for special_token in SPECIAL_TOKENS:
    if tokenizer_en.token_to_id(special_token) != tokenizer_my.token_to_id(special_token):
        raise ValueError(f"Special token {special_token} has different ids in the English and Malay tokenizers")

source_vocab_size = tokenizer_en.get_vocab_size()
target_vocab_size = tokenizer_my.get_vocab_size()

# 1-element id tensors for the special tokens, placed on the training device.
CLS_ID = torch.tensor([tokenizer_my.token_to_id("[CLS]")], dtype=torch.int64).to(device)
SEP_ID = torch.tensor([tokenizer_my.token_to_id("[SEP]")], dtype=torch.int64).to(device)
PAD_ID = torch.tensor([tokenizer_my.token_to_id("[PAD]")], dtype=torch.int64).to(device)

In [None]:
# Step3: Prepare dataset and dataloader

# Transform the raw dataset into the encoded dataset that can be processed by the model.
class EncodeDataset(Dataset):
    """Encode English→Malay sentence pairs into fixed-length tensors.

    Args:
        raw_dataset: indexable sequence of dicts with "en" (source) and "ms" (target) strings.
        max_seq_len: length that every encoder input, decoder input and target label is
            padded to. Sentences that are too long are truncated so the special tokens still fit.

    Each item is a dict with:
        encoder_input: [CLS] + source ids + [SEP] + [PAD]...   shape (max_seq_len,)
        decoder_input: [CLS] + target ids + [PAD]...           shape (max_seq_len,)
        target_label:  target ids + [SEP] + [PAD]...           shape (max_seq_len,)
        encoder_mask:  1 for real tokens, 0 for padding        shape (1, 1, max_seq_len)
        decoder_mask:  padding mask AND causal mask            shape (1, max_seq_len, max_seq_len)
        source_text / target_text: the raw strings.
    """

    def __init__(self, raw_dataset, max_seq_len):
        super().__init__()
        if max_seq_len < 2:
            raise ValueError(f"max_seq_len must be >= 2 to hold [CLS] and [SEP], got {max_seq_len}")
        self.raw_dataset = raw_dataset
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.raw_dataset)

    def __getitem__(self, index):
        # One translation pair: {"en": <English text>, "ms": <Malay text>}.
        raw_text = self.raw_dataset[index]
        source_text = raw_text["en"]
        target_text = raw_text["ms"]

        # Truncate over-long sentences. The encoder input needs room for [CLS] and [SEP].
        # The decoder input and target label each need room for one extra token.
        # Without this, the padding count goes negative and the item comes out longer
        # than max_seq_len, which makes batch collation fail.
        source_ids = tokenizer_en.encode(source_text).ids[: self.max_seq_len - 2]
        target_ids = tokenizer_my.encode(target_text).ids[: self.max_seq_len - 1]
        source_text_encoded = torch.tensor(source_ids, dtype=torch.int64, device=device)
        target_text_encoded = torch.tensor(target_ids, dtype=torch.int64, device=device)

        num_source_padding = self.max_seq_len - len(source_text_encoded) - 2
        num_target_padding = self.max_seq_len - len(target_text_encoded) - 1

        # CLS_ID / SEP_ID / PAD_ID are the module-level 1-element id tensors built from the
        # Malay tokenizer. The special-token ids are the same in both tokenizers.
        encoder_padding = PAD_ID.repeat(num_source_padding)
        decoder_padding = PAD_ID.repeat(num_target_padding)

        # Encoder input: [CLS] source [SEP] [PAD]...
        encoder_input = torch.cat([CLS_ID, source_text_encoded, SEP_ID, encoder_padding]).to(device)

        # Decoder input: [CLS] target [PAD]... (no [SEP]; the model must learn to predict it).
        decoder_input = torch.cat([CLS_ID, target_text_encoded, decoder_padding]).to(device)

        # Target label for the loss: the decoder input shifted left by one.
        # It has no [CLS] and ends with [SEP] followed by padding.
        target_label = torch.cat([target_text_encoded, SEP_ID, decoder_padding]).to(device)

        # Padding positions are masked out of the encoder's self-attention.
        encoder_mask = (encoder_input != PAD_ID).unsqueeze(0).unsqueeze(0).int().to(device)

        # The decoder must ignore padding AND must not look at future tokens (causal mask).
        decoder_mask = (decoder_input != PAD_ID).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)).to(device)

        return {
            'encoder_input': encoder_input,
            'decoder_input': decoder_input,
            'target_label': target_label,
            'encoder_mask': encoder_mask,
            'decoder_mask': decoder_mask,
            'source_text': source_text,
            'target_text': target_text
        }

# Causal (look-ahead) mask: position i may attend only to positions <= i. Where the mask is
# False, MultiHeadAttention fills the score with -1e9 before the softmax. That attention
# weight then becomes (nearly) zero, so a token cannot see the tokens that come after it.
def causal_mask(size):
    """Return a (1, size, size) bool tensor that is True on and below the main diagonal."""
    strictly_upper = torch.ones(1, size, size).triu(diagonal=1)
    return strictly_upper.int().eq(0)

In [None]:
# Find the longest tokenized sentence in the training data, separately for the
# source (English) and target (Malay) side.
max_seq_len_source = 0
max_seq_len_target = 0

for data in raw_train_dataset["translation"]:
    enc_ids = tokenizer_en.encode(data["en"]).ids
    dec_ids = tokenizer_my.encode(data["ms"]).ids
    max_seq_len_source = max(max_seq_len_source, len(enc_ids))
    max_seq_len_target = max(max_seq_len_target, len(dec_ids))

print(f'max_seqlen_source: {max_seq_len_source}')   # author's run: 530
print(f'max_seqlen_target: {max_seq_len_target}')   # author's run: 526

# Take the longer of the two sides and add 20 tokens of headroom for the special tokens
# ([CLS], [SEP]). For the values above this gives 550.
max_seq_len = max(max_seq_len_source, max_seq_len_target) + 20

# Instantiate EncodeDataset to create the encoded train and validation datasets.
train_dataset = EncodeDataset(raw_train_dataset["translation"], max_seq_len)
val_dataset = EncodeDataset(raw_validation_dataset["translation"], max_seq_len)

# DataLoaders used later for training and validation.
# No explicit generator is passed. The shuffling sampler draws its permutation and seed on
# the CPU, so a CUDA torch.Generator fails there unless the default device has been
# switched to CUDA, which this notebook does not do.
train_dataloader = DataLoader(train_dataset, batch_size = 10, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = 1, shuffle = True)

In [None]:
# Step 4: Input embedding and positional encoding

class EmbeddingLayer(nn.Module):
    """Token-id → vector lookup, scaled by sqrt(d_model) as in "Attention Is All You Need".

    Args:
        vocab_size: number of rows in the embedding table (the tokenizer's vocabulary size).
        d_model: embedding width.
    """

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, input):
        # Scaling keeps the embedding magnitude comparable to the positional encodings added next.
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(input)
        return vectors * scale


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        max_seq_len: number of positions in the precomputed table; inputs must not be longer.
        d_model: embedding width (odd values are supported).
        dropout_rate: dropout applied after adding the encodings.
    """

    def __init__(self, max_seq_len: int, d_model: int, dropout_rate: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        pe = torch.zeros(max_seq_len, d_model)

        pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)   # (max_seq_len, 1)
        # 1 / 10000^(2i/d_model), computed in log space. The whole product must be inside exp():
        # exponentiating arange(0, d_model, 2) first overflows float32 to inf, and the
        # encodings then become NaN.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(pos * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(pos * div_term[: d_model // 2])

        # Leading batch dimension so the table broadcasts over (batch, seq_len, d_model) inputs.
        pe = pe.unsqueeze(0)
        # A buffer: saved in state_dict and moved with .to(), but not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, input_embdding):
        # input_embdding is (batch, seq_len, d_model); add the encodings of the first seq_len positions.
        input_embdding = input_embdding + (self.pe[:, :input_embdding.shape[1], :]).requires_grad_(False)   # to prevent from calculating gradient
        return self.dropout(input_embdding)

In [None]:
# Step 5: Multihead Attention

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: embedding width; must be divisible by num_heads.
        num_heads: number of heads; each head works on d_k = d_model // num_heads features.
        dropout_rate: dropout applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_rate: float):
        super().__init__()
        # Dropout on the attention weights to reduce overfitting.
        self.dropout = nn.Dropout(dropout_rate)

        # Learnable projections for query, key, value and the final output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by number of heads"

        # Per-head feature dimension.
        self.d_k = d_model // num_heads

    def forward(self, q, k, v, encoder_mask=None):
        """Attend from the queries `q` to the keys `k` / values `v`.

        Args:
            q: (batch_size, q_len, d_model) queries.
            k, v: (batch_size, kv_len, d_model) keys and values.
            encoder_mask: optional mask broadcastable to (batch_size, num_heads, q_len, kv_len).
                Positions where it is 0 get (effectively) zero attention weight. Despite the
                name, the decoder also passes its causal mask here.

        Returns:
            (batch_size, q_len, d_model) tensor.
        """
        query = self.W_q(q)
        key = self.W_k(k)
        value = self.W_v(v)

        # Split into heads: (batch, len, d_model) -> (batch, len, num_heads, d_k) -> (batch, num_heads, len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.num_heads, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.num_heads, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.num_heads, self.d_k).transpose(1, 2)

        # Scaled similarity of every query with every key: (batch, num_heads, q_len, kv_len)
        attention_score = (query @ key.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Masked positions get a large negative score, so their softmax weight is ~0.
        if encoder_mask is not None:
            attention_score = attention_score.masked_fill(encoder_mask == 0, -1e9)

        # Each query's weights over the keys sum to 1.
        attention_weight = torch.softmax(attention_score, dim=-1)

        if self.dropout is not None:
            attention_weight = self.dropout(attention_weight)

        # Weighted sum of the values. This must use the normalized (and dropped-out) weights,
        # not the raw scores: (batch, num_heads, q_len, d_k)
        attention_output = attention_weight @ value

        # Concatenate heads: (batch, num_heads, q_len, d_k) -> (batch, q_len, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous().view(attention_output.shape[0], -1, self.num_heads * self.d_k)

        # Final output projection, same shape as the query input.
        return self.W_o(attention_output)

In [None]:
# Step 6: Feed-forward network, layer normalization and AddAndNorm
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear(d_model→d_ff) → ReLU → Dropout → Linear(d_ff→d_model)."""

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
        super().__init__()
        self.layer_1 = nn.Linear(d_model, d_ff)
        self.activation_1 = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_2 = nn.Linear(d_ff, d_model)

    def forward(self, input):
        hidden = self.layer_1(input)
        hidden = self.activation_1(hidden)
        hidden = self.dropout(hidden)
        return self.layer_2(hidden)

class LayerNorm(nn.Module):
    """Normalize over the last dimension, then scale by `gamma` and shift by `beta`.

    Unlike torch.nn.LayerNorm, `gamma` and `beta` are single learnable scalars shared by
    all features. `eps` is added to the (unbiased) standard deviation, not the variance.

    Args:
        eps: small constant that keeps the division finite when the std is 0.
    """

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, input):
        mu = input.mean(dim=-1, keepdim=True)
        sigma = input.std(dim=-1, keepdim=True)
        normalized = (input - mu) / (sigma + self.eps)
        return self.gamma * normalized + self.beta


class AddAndNorm(nn.Module):
    """Pre-norm residual unit: input + dropout(sub_layer(layer_norm(input)))."""

    def __init__(self, dropout_rate: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = LayerNorm()

    def forward(self, input, sub_layer):
        """Apply the callable `sub_layer` to the normalized input and add the skip connection."""
        normalized = self.layer_norm(input)
        residual = self.dropout(sub_layer(normalized))
        return input + residual

In [None]:
#Step 7: Encoder block and Encoder

class EncoderBlock(nn.Module):
    """One encoder layer: a self-attention sub-layer, then a feed-forward sub-layer.

    Each sub-layer is wrapped in its own AddAndNorm residual unit.
    """

    def __init__(self, multihead_attention: MultiHeadAttention, feed_forward: FeedForward, dropout_rate: float):
        super().__init__()
        self.multihead_attention = multihead_attention
        self.feed_forward = feed_forward
        self.add_and_norm_list = nn.ModuleList([AddAndNorm(dropout_rate), AddAndNorm(dropout_rate)])

    def forward(self, encoder_input, encoder_mask):
        attention_unit, feed_forward_unit = self.add_and_norm_list

        def self_attention(x):
            # Queries, keys and values all come from the same (normalized) sequence.
            return self.multihead_attention(x, x, x, encoder_mask)

        hidden = attention_unit(encoder_input, self_attention)
        return feed_forward_unit(hidden, self.feed_forward)

class Encoder(nn.Module):
    """Run the input through every encoder block in order, then apply a final LayerNorm."""

    def __init__(self, encoderblocklist: nn.ModuleList):
        super().__init__()
        self.encoderblocklist = encoderblocklist
        self.layer_norm = LayerNorm()

    def forward(self, encoder_input, encoder_mask):
        hidden = encoder_input
        for block in self.encoderblocklist:
            hidden = block(hidden, encoder_mask)
        # The normalized output is what the decoder's cross-attention uses as keys/values.
        return self.layer_norm(hidden)

In [None]:
#Step 8: Decoder block and decoder and the projection

class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then cross-attention, then feed-forward.

    Each sub-layer is wrapped in its own AddAndNorm residual unit.
    """

    def __init__(self, masked_multihead_attention: MultiHeadAttention,multihead_attention: MultiHeadAttention, feed_forward: FeedForward, dropout_rate: float):
        super().__init__()
        self.masked_multihead_attention = masked_multihead_attention
        self.multihead_attention = multihead_attention
        self.feed_forward = feed_forward
        self.add_and_norm_list = nn.ModuleList([AddAndNorm(dropout_rate), AddAndNorm(dropout_rate), AddAndNorm(dropout_rate)])

    def forward(self, decoder_input, decoder_mask, encoder_output, encoder_mask):
        self_attention_unit, cross_attention_unit, feed_forward_unit = self.add_and_norm_list

        def masked_self_attention(x):
            # decoder_mask hides padding and future positions.
            return self.masked_multihead_attention(x, x, x, decoder_mask)

        def cross_attention(x):
            # Queries come from the decoder; keys and values come from the encoder output.
            return self.multihead_attention(x, encoder_output, encoder_output, encoder_mask)

        hidden = self_attention_unit(decoder_input, masked_self_attention)
        hidden = cross_attention_unit(hidden, cross_attention)
        return feed_forward_unit(hidden, self.feed_forward)

class Decoder(nn.Module):
    """Run the input through every decoder block in order, then apply a final LayerNorm."""

    def __init__(self,decoderblocklist: nn.ModuleList):
        super().__init__()
        self.decoderblocklist = decoderblocklist
        self.layer_norm = LayerNorm()

    def forward(self, decoder_input, decoder_mask, encoder_output, encoder_mask):
        hidden = decoder_input
        for block in self.decoderblocklist:
            hidden = block(hidden, decoder_mask, encoder_output, encoder_mask)
        return self.layer_norm(hidden)

class ProjectionLayer(nn.Module):
    """Map decoder states to log-probabilities over the target vocabulary.

    log_softmax is idempotent. Passing these log-probabilities to nn.CrossEntropyLoss
    (which applies log_softmax itself) therefore gives the same loss as passing raw logits.
    """

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.projection_layer = nn.Linear(d_model, vocab_size)

    def forward(self, decoder_output):
        # (..., d_model) -> (..., vocab_size) logits, then normalize over the vocabulary.
        logits = self.projection_layer(decoder_output)
        return logits.log_softmax(dim=-1)

In [None]:
#Step 9: Create and build the Transformer

class Transformer(nn.Module):
    """Encoder–decoder Transformer assembled from pre-built components.

    encode/decode/project only use source_embed, target_embed, positional_encoding,
    encoder, decoder and projection_layer. multihead_attention, masked_multihead_attention,
    feed_forward and dropout are registered as submodules (so they appear in state_dict)
    but are not called here. The attribute order below fixes the state_dict / parameter
    order, so keep it stable.
    """

    def __init__(self, source_embed: EmbeddingLayer, target_embed: EmbeddingLayer, positional_encoding: PositionalEncoding, multihead_attention: MultiHeadAttention, masked_multihead_attention: MultiHeadAttention, feed_forward: FeedForward, encoder: Encoder, decoder: Decoder, projection_layer: ProjectionLayer, dropout_rate: float):
        super().__init__()
        self.source_embed = source_embed
        self.target_embed = target_embed
        self.positional_encoding = positional_encoding
        self.multihead_attention = multihead_attention
        self.masked_multihead_attention = masked_multihead_attention
        self.feed_forward = feed_forward
        self.encoder = encoder
        self.decoder = decoder
        self.projection_layer = projection_layer
        self.dropout = nn.Dropout(dropout_rate)

    def encode(self, encoder_input, encoder_mask):
        """Embed and position-encode the source ids, then run the encoder stack."""
        embedded = self.positional_encoding(self.source_embed(encoder_input))
        return self.encoder(embedded, encoder_mask)

    def decode(self, decoder_input, decoder_mask, encoder_output, encoder_mask):
        """Embed and position-encode the target ids, then run the decoder stack against the encoder output."""
        embedded = self.positional_encoding(self.target_embed(decoder_input))
        return self.decoder(embedded, decoder_mask, encoder_output, encoder_mask)

    def project(self, decoder_output):
        """Turn decoder states into log-probabilities over the target vocabulary."""
        return self.projection_layer(decoder_output)

In [None]:
def build_model(source_vocab_size, target_vocab_size, max_seq_len=1135, d_model=512, d_ff=2048, num_heads=8, num_blocks=6, dropout_rate=0.1):
    """Build an encoder–decoder Transformer with Xavier-initialized weight matrices.

    Args:
        source_vocab_size: vocabulary size of the source (English) tokenizer.
        target_vocab_size: vocabulary size of the target (Malay) tokenizer.
        max_seq_len: rows in the positional-encoding table; must be >= any input length.
        d_model: model width.
        d_ff: hidden width of the feed-forward sub-layers.
        num_heads: attention heads per attention module.
        num_blocks: number of encoder blocks and number of decoder blocks (must be >= 1).
        dropout_rate: dropout probability used throughout.

    Returns:
        A Transformer whose parameters are on PyTorch's default device; move it with .to(device).

    Raises:
        ValueError: if num_blocks < 1.
    """
    if num_blocks < 1:
        raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

    source_embed = EmbeddingLayer(source_vocab_size, d_model)
    target_embed = EmbeddingLayer(target_vocab_size, d_model)
    # The sinusoidal table has no parameters, so one instance can serve both sides.
    positional_encoding = PositionalEncoding(max_seq_len, d_model, dropout_rate)
    projection_layer = ProjectionLayer(target_vocab_size, d_model)

    # Every block gets its OWN attention and feed-forward modules. Appending one block
    # instance num_blocks times would make all layers share a single set of weights.
    # It would also reuse the encoder's attention and feed-forward modules inside the decoder.
    encoderblocklist = nn.ModuleList([
        EncoderBlock(
            MultiHeadAttention(d_model, num_heads, dropout_rate),
            FeedForward(d_model, d_ff, dropout_rate),
            dropout_rate,
        )
        for _ in range(num_blocks)
    ])
    decoderblocklist = nn.ModuleList([
        DecoderBlock(
            MultiHeadAttention(d_model, num_heads, dropout_rate),   # masked self-attention
            MultiHeadAttention(d_model, num_heads, dropout_rate),   # cross-attention
            FeedForward(d_model, d_ff, dropout_rate),
            dropout_rate,
        )
        for _ in range(num_blocks)
    ])

    encoder = Encoder(encoderblocklist)
    decoder = Decoder(decoderblocklist)

    # The Transformer constructor takes these three modules but encode/decode/project never
    # call them. Pass references to the first decoder block's modules so no unused
    # parameters are created and the state_dict key names stay unchanged.
    first_decoder_block = decoderblocklist[0]
    model = Transformer(
        source_embed,
        target_embed,
        positional_encoding,
        first_decoder_block.multihead_attention,
        first_decoder_block.masked_multihead_attention,
        first_decoder_block.feed_forward,
        encoder,
        decoder,
        projection_layer,
        dropout_rate,
    )

    # Xavier-uniform init for every weight matrix / embedding table (dim > 1).
    # Biases and the scalar LayerNorm parameters keep their defaults.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)

    return model

In [None]:
# Build the model and move it to the same device that EncodeDataset puts its tensors on.
model = build_model(source_vocab_size, target_vocab_size).to(device)

In [None]:
#Step 10: Training and Validation of malayGPT

def training_model(preload_epoch=None):
    """Train the global `model` on `train_dataloader`, greedily decoding `val_dataloader` after each epoch.

    After every epoch, a checkpoint with the model weights, optimizer state, epoch and
    global step is written to ./malaygpt/model_{epoch}.pt.

    Args:
        preload_epoch: if not None, resume from ./malaygpt/model_{preload_epoch}.pt
            (model weights, optimizer state and step counter) and continue with the next epoch.
    """
    # The entire training/validation cycle runs for 20 epochs.
    EPOCHS = 20
    initial_epoch = 0
    global_step = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    if preload_epoch is not None:
        model_filename = f"./malaygpt/model_{preload_epoch}.pt"
        # map_location lets a checkpoint saved on one device be resumed on another.
        state = torch.load(model_filename, map_location=device)
        # Restore the weights too; restoring only the optimizer would resume from the random init.
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']

    # Target labels and generated tokens are Malay, so use the Malay tokenizer's ids.
    pad_id = tokenizer_my.token_to_id("[PAD]")
    cls_id = tokenizer_my.token_to_id("[CLS]")
    sep_id = tokenizer_my.token_to_id("[SEP]")

    # Loss between projection output and target label; padding positions are ignored.
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, EPOCHS):

        # Training block
        model.train()
        # Accumulated on-device to avoid a host sync on every step.
        total_loss = torch.zeros((), device=device)
        num_batches = 0

        for batch in tqdm(train_dataloader):
            encoder_input = batch['encoder_input'].to(device)   # (batch_size, seq_len)
            decoder_input = batch['decoder_input'].to(device)   # (batch_size, seq_len)
            target_label = batch['target_label'].to(device)     # (batch_size, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(decoder_input, decoder_mask, encoder_output, encoder_mask)
            projection_output = model.project(decoder_output)   # (batch_size, seq_len, vocab_size)

            loss = loss_fn(projection_output.view(-1, projection_output.shape[-1]), target_label.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

            total_loss += loss.detach()
            num_batches += 1

        mean_loss = total_loss.item() / max(num_batches, 1)
        print(f'Epoch [{epoch+1}/{EPOCHS}]: Train Loss: {mean_loss:.2f}')

        # Save the state of the model after every epoch.
        model_filename = f"./malaygpt/model_{epoch}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

        # Validation block: greedy decoding. This assumes batch_size == 1, which is how
        # val_dataloader is built.
        model.eval()
        with torch.inference_mode():
            for batch in tqdm(val_dataloader):
                encoder_input = batch['encoder_input'].to(device)   # (1, seq_len)
                encoder_mask = batch['encoder_mask'].to(device)
                source_text = batch['source_text'][0]
                target_text = batch['target_text'][0]

                encoder_output = model.encode(encoder_input, encoder_mask)
                # Decoding starts from the [CLS] token.
                decoder_input = torch.full((1, 1), cls_id, dtype=encoder_input.dtype, device=device)

                # Append the most probable next token until [SEP] appears or max_seq_len is reached.
                while decoder_input.size(1) < max_seq_len:
                    # Rebuild the causal mask for the current decoder length.
                    decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)
                    decoder_output = model.decode(decoder_input, decoder_mask, encoder_output, encoder_mask)
                    # Only the last position predicts the next token.
                    projection = model.project(decoder_output[:, -1])
                    new_token = projection.argmax(dim=1, keepdim=True).type_as(encoder_input)
                    decoder_input = torch.cat([decoder_input, new_token], dim=1)
                    if new_token.item() == sep_id:
                        break

                # Tokenizer.decode takes a list of Python ints.
                predicted_ids = decoder_input.squeeze(0).detach().cpu().tolist()
                model_predicted_text = tokenizer_my.decode(predicted_ids)

                print(f'SOURCE TEXT: {source_text}')
                print(f'TARGET TEXT: {target_text}')
                print(f'PREDICTED TEXT: {model_predicted_text}')

# This function runs the training and validation for 20 epochs
training_model(preload_epoch=None)

In [None]:
#Step 11: Finally, test MalayGPT on new English input text

def malaygpt(user_input_text):
    """Translate one English sentence to Malay with the global `model`, using greedy decoding.

    The input is tokenized with `tokenizer_en`, wrapped as [CLS] + tokens + [SEP], padded
    to `max_seq_len` and given a batch dimension of 1. Tokens are then generated one at a
    time until [SEP] is produced or `max_seq_len` tokens exist.

    Args:
        user_input_text: English text. Leading/trailing whitespace is stripped, and inputs
            longer than max_seq_len - 2 tokens are truncated.

    Returns:
        The decoded Malay text.
    """
    model.eval()
    with torch.inference_mode():
        user_input_text = user_input_text.strip()
        # Leave room for [CLS] and [SEP]. Otherwise the padding count goes negative and the
        # sequence would exceed max_seq_len.
        token_ids = tokenizer_en.encode(user_input_text).ids[: max_seq_len - 2]
        user_input_text_encoded = torch.tensor(token_ids, dtype=torch.int64).to(device)

        num_source_padding = max_seq_len - len(user_input_text_encoded) - 2
        encoder_padding = PAD_ID.repeat(num_source_padding)
        # unsqueeze(0) adds the batch dimension the model expects: (1, max_seq_len). Without it,
        # positional encoding would treat the embedding dimension as the sequence length.
        encoder_input = torch.cat([CLS_ID, user_input_text_encoded, SEP_ID, encoder_padding]).unsqueeze(0).to(device)
        # (1, 1, 1, max_seq_len): the same layout as a batch of EncodeDataset masks.
        encoder_mask = (encoder_input != PAD_ID).unsqueeze(0).unsqueeze(0).int().to(device)

        # Compute the encoder output once; it is reused at every decoding step.
        encoder_output = model.encode(encoder_input, encoder_mask)

        cls_id = tokenizer_my.token_to_id('[CLS]')
        sep_id = tokenizer_my.token_to_id('[SEP]')
        # Decoding starts from the [CLS] token.
        decoder_input = torch.full((1, 1), cls_id, dtype=encoder_input.dtype, device=device)

        # Append the most probable next token until [SEP] appears or max_seq_len is reached.
        while decoder_input.size(1) < max_seq_len:
            # Rebuild the causal mask for the current decoder length.
            decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)
            decoder_output = model.decode(decoder_input, decoder_mask, encoder_output, encoder_mask)
            # Only the last position predicts the next token.
            projection = model.project(decoder_output[:, -1])
            new_token = projection.argmax(dim=1, keepdim=True).type_as(encoder_input)
            decoder_input = torch.cat([decoder_input, new_token], dim=1)
            if new_token.item() == sep_id:
                break

        # Tokenizer.decode takes a list of Python ints.
        predicted_ids = decoder_input.squeeze(0).detach().cpu().tolist()
        return tokenizer_my.decode(predicted_ids)

In [None]:
# Test 1: translate a short greeting with MalayGPT.
user_input_text = "Good Morning"
transalated_text = malaygpt(user_input_text)

print(f"User input (in English): {user_input_text}", f"Translation (in Malay): {transalated_text}", sep="\n")

User input (in English): Good Morning
Translation (in Malay): Selamat Pagi


In [None]:
# Test 2: translate a question with MalayGPT.
user_input_text = "What are you talking about?"
transalated_text = malaygpt(user_input_text)

print(f"User input (in English): {user_input_text}", f"Translation (in Malay): {transalated_text}", sep="\n")

User input (in English): What are you talking about?
Translation (in Malay): Apa yang kamu merepek ni?
