# Projet GPT - Train

This notebook contains the code used to train a small language model using PyTorch from scratch. The model is inspired by the GPT architecture.


#### Hardware
- RTX3060 12GB VRAM
- AMD Ryzen 7 5800X 
- 32GB RAM
- Ubuntu 22.04 LTS

In [26]:
import os
CACHE_DIR = "/media/rob/RobsDisk/cache_data_llm"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, "datasets")
os.environ['HF_METRICS_CACHE'] = os.path.join(CACHE_DIR, "metrics")
os.environ['HF_MODULES_CACHE'] = os.path.join(CACHE_DIR, "modules")


In [27]:
import torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.cuda.amp import GradScaler, autocast
import random, math

from datasets import load_dataset,concatenate_datasets
import tiktoken
from tqdm import tqdm
from datetime import datetime

## Datasets

### Common knowledge datasets:

##### English Wikipedia crawled dataset

In [28]:
# English Wikipedia crawled dataset
# path to store the dataset cache: /Volumes/RobertsDisk
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR) 
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'])


English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
Carss Bush Park is a  nature reserve and urban park located at 74 Carwar Avenue, in the Sydney suburb of Carss Park, Georges River Council, New South Wales, Australia.

History 
Carss Bush Park is situated on a land grant made to Jonathan Croft of  on 28 January 1853. Within ten months Croft sold the land to William Barton on 17 October 1853 for A£352. This land speculation was to continue for another two years with sales in June 1854 to John Chappellow, for A£538 and in September 1855 to Lewis Gordon possibly in default of a mortgage.

Gordon sold the  to William Carss on 7 January 1863 for A£540. Carss was one of fifty tradesmen (stonemasons and carpenters) who had been recruited in Glasgow by Dr John Dunmore Lang. Carss arrived in Sydney in 1831 accompanied by his wife Helen Turnball. A cabinet maker by trade he found w

#### Simple stories dataset

In [29]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'])

Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
Sky was bright that morning. A bird flew high above. It was a parrot named Coco. He loved to talk and sing. "Look at me! I can fly!" he shouted. Below, a turtle named Shelly moved slowly. "You may fly, but I can walk the earth," she said. The two animals were very different.

Coco liked to tease Shelly. "You are so slow! I am the best!" he chirped. Shelly laughed softly. "You may be fast, but I can see more," she replied. Coco did not understand. He thought he was better just because he could fly. "What can you see that I cannot?" he asked. 

Then, Shelly said, "I can see the flowers. I can see the big trees. I can see the world from my own way." Coco felt a little sad. He realized he only saw the sky. "Maybe you are right," he said. "I only see clouds." 

Suddenly, a storm came. The wind blew hard, and rain poured down. Coco 

##### FineWeb-Edu dataset

In [30]:
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT",  split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu is ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text']) 


FineWeb-Edu is ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
You can easily overcome space constraints and grow vegetables in your home garden through hydroponic gardening — a practice of growing vegetables in water without the use of soil. Roots are submerged in a water-based nutrient solution, while the upper part is supported above water level. Hydroponically grown vegetables are considered healthier because the process eliminates weeds, bacteria and soil-borne pests. You can use hydroponic systems indoors and grow fresh vegetables year-round.
Start the vegetable plant seeds in an inert growing medium such as rock wool cubes. Place the cubes in a small container filled with 1 inch of water so they remain moist and the seeds sprout successfully. When the seedlings reach a height of 2 to 3 inches and roots start showing through the sides, they are ready for transplantation to the hydroponic co

In [31]:
# OpenWebText2 dataset
owt2 = load_dataset("Skylion007/openwebtext", split="train", cache_dir=CACHE_DIR)
print("OpenWebText2 dataset loaded.")
print("Dataset size in GB:", owt2.dataset_size / (1024**3))
print("Number of entries:", len(owt2))
print("-"*50)
print("Example entry:")
print(owt2[random.randint(0, len(owt2)-1)]['text'])

OpenWebText2 dataset loaded.
Dataset size in GB: 37.03822539001703
Number of entries: 8013769
--------------------------------------------------
Example entry:
SANTA CLARA, Calif., Sept. 26, 2012 – Today Intel Corporation issued a statement in response to unsubstantiated news reports about comments made by Intel CEO Paul Otellini in a meeting with employees.

Intel has a long and successful heritage working with Microsoft on the release of Windows platforms, delivering devices that provide exciting experiences, stunning performance, and superior compatibility. Intel fully expects this to continue with Windows 8.

Intel, Microsoft and our partners have been working closely together on testing and validation to ensure delivery of a high-quality experience across the nearly 200 Intel-based designs that will start launching in October. Intel CEO Paul Otellini is on record as saying “Windows 8 is one of the best things that ever happened to Intel,” citing the importance of the touch interfa

#### Some Q&A data to improve the model's ability to answer questions:

In [32]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'])

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
What is the role of national governments in the reintegration process, and why is it crucial to have conversations between them and NGOs? 
 National governments play a crucial role in the reintegration process because they possess the necessary resources and infrastructure to support initiatives. Collaborating with other relevant actors, governments can pool resources, share best practices, and coordinate responses to common challenges faced by returnees.


In [33]:
#euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500]), "\n", reddit_instruct[index]['comment_text'][:500]

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
[Hitman] Why does anyone try to make an enemy of Agent 47? Seriously, there are lots of people, including (SPOILERS AHEAD) plenty of people in the ICA who have tried and failed to kill Agent 47. Considering that he is not only well-known as the ICA's best asset and top assassin, but is also regarded as a legend across the globe who never fails a contract, why does anyone think that they'll ever succeed against him?


(None,
 '\n',
 "Because at the end of the day he is only human.  An exceptionally skilled, intelligent, and fit human, yes, but still only human.  The ICA, the Franchise, and other underworld elements tend not to be the type to invent bogeymen for themselves to fear.  If Agent 47 represents a thorn in their side, then they're going to eliminate him. The fact that they haven't succeeded so far does not mean, to them, that 47 is some sort of superhuman ubermensch, it just means he's been lucky enough to survive, ")

In [34]:
# tatsu-lab/alpaca ( for Q&A fine-tuning )
alpaca = load_dataset("tatsu-lab/alpaca", split='train')
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
Rewrite the following sentence to express the same meaning in a negative form: "She likes to play soccer". 
 She does not dislike playing soccer.


## Data Preprocessing

#### Tokenizer setup

For this project i use tiktoken for the tokenizer, as it is the same tokenizer used by OpenAI for their models.

I use the "gpt2" encoding which is a byte pair encoding (BPE) tokenizer.

In [35]:
tokenizer_base = tiktoken.get_encoding("gpt2")

tokenizer = tiktoken.Encoding(
    name="rob-tokenizer",
    pat_str=tokenizer_base._pat_str,
    mergeable_ranks=tokenizer_base._mergeable_ranks,
    special_tokens={
        **tokenizer_base._special_tokens,
        "<|im_start|>": 50257,
        "<|im_end|>": 50258,
        "<|pad|>": 50259,
    }
)

#### Test of the byte pair encoding tokenizer 

In [36]:
# test of tokenizer on reddit_instruct
sample_text = reddit_instruct[0]['post_title'] + " " + reddit_instruct[0]['post_text'] + " " + reddit_instruct[0]['comment_text']
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens)) 
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")   

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 

### Formatting datasets functions

#### Merging datasets

In [37]:
from datasets import concatenate_datasets
combined_train_dataset = concatenate_datasets([  
    wiki_en,
    stories,
    fineweb_edu,
    owt2,  
])  

combined_finetune_dataset = concatenate_datasets([
    q_a1,
    reddit_instruct,
    alpaca,
])

# Shuffle the combined dataset
train_dataset = combined_train_dataset.shuffle(seed=42)
finetune_dataset = combined_finetune_dataset.shuffle(seed=42)
print(f"Train dataset size: {len(combined_train_dataset)}")
print(f"Finetune dataset size: {len(combined_finetune_dataset)}")

# Exemple 

print("Example entry from train dataset:")
index = random.randint(0, len(train_dataset)-1)
print(train_dataset[index])   

Train dataset size: 26209380
Finetune dataset size: 257745
Example entry from train dataset:
{'id': '22475636', 'url': 'https://en.wikipedia.org/wiki/List%20of%20dermatologists', 'title': 'List of dermatologists', 'text': 'This is a list of dermatologists who have made notable contributions to the field of dermatology.\n\nDermatologists in popular culture\n Dr. Sandra Lee, presenter of the TLC TV series Dr. Pimple Popper\n\nFictional dermatologists\n Dr. Archibald Newlands (Martin Donovan) in the television series Law & Order: Special Victims Unit\n Dr. Sara Sitarides (Marcia Cross) in the television sitcom Seinfeld\n Dr. Emily Sweeney (Laura Spencer) in the television series The Big Bang Theory\n\nReferences\n\n \nLists of health professionals\nLists of physicians', 'story': None, 'topic': None, 'theme': None, 'style': None, 'feature': None, 'grammar': None, 'persona': None, 'initial_word_type': None, 'initial_letter': None, 'word_count': None, 'character_count': None, 'num_paragraphs

#### Custom Dataset class

Inspired by the dataloader from the "LLMs from scratch" repository. But adapted for multi-row text arrays.

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/dataloader.ipynb

In [38]:
class GPTDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        """
        Args:
            dataset: Dataset of the combined hugginface entries
            tokenizer: the initiatokenizer to process text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.input_tokens = []
        self.target_tokens = []

        self.pad_token_id = 50259         # <|pad|>
        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Raw text
        
        # Data format handling
        entry = self.data[idx]
        if 'text' in entry:
            text = entry['text']
        elif 'story' in entry:
            text = entry['story']
        elif 'question' in entry and 'answer' in entry: 
            text = "User: " + entry['question'] + " Assistant:" + entry['answer']
        elif 'post_title' in entry and 'post_text' in entry and 'comment_text' in entry:
            text = "User: " + entry['post_title'] + " Assistant:" + entry['post_text'] + " " + entry['comment_text']
        elif 'instruction' in entry and 'output' in entry:
            text = "User: " + entry['instruction'] + " Assistant:" + entry['output']
        else:
            raise ValueError("Unknown data entry format")
        
        text = str(text) # Ensure text is a string
        #print(text)

        # Adding Start and End tokens
        text = "<|im_start|>" + text + "<|im_end|>" 

        # Tokenization
        tokens = self.tokenizer.encode(text, allowed_special="all")

        # Truncation
        tokens = tokens[:self.max_length] #Data is loost here ( fix later with sliding window )

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)  # All tokens except last
        labels = torch.tensor(tokens[1:], dtype=torch.long)      # All tokens except first

        


        attention_mask = (input_ids != self.pad_token_id).long() # 1 for real tokens, 0 for padding

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

In [39]:
class GPTstreamingDataset(IterableDataset):
    def __init__(self, dataset, tokenizer, max_length):
        """
        Args:
            dataset: Dataset of the combined hugginface entries
            tokenizer: the initiatokenizer to process text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __iter__(self):
        buffer = []
        for entry in self.data:
            # Data format 
            if entry.get('text') is not None:
                text = entry['text']
            elif entry.get('story') is not None:
                text = entry['story']
            elif entry.get('question') is not None and entry.get('answer') is not None: 
                text = f"User: {entry['question']} Assistant: {entry['answer']}"
            elif entry.get('post_title') is not None:
                title = entry.get('post_title', "")
                post = entry.get('post_text', "")
                comment = entry.get('comment_text', "")
                text = f"User: {title} Assistant: {post} {comment}"
            elif entry.get('instruction') is not None and entry.get('output') is not None:
                text = f"User: {entry['instruction']} Assistant: {entry['output']}"
            if text is None:
                continue
                
            text = str(text) 

            # Start and End tokens
            text = "<|im_start|>" + text + "<|im_end|>" 

            # Tokenization
            tokens = self.tokenizer.encode(text, allowed_special="all")

            buffer.extend(tokens) #pile tokens into buffer


            while len(buffer) >= self.max_length +1:
                chunk = buffer[:self.max_length + 1]
                buffer = buffer[self.max_length+ 1:] # remove used tokens

                input_ids = torch.tensor(chunk[:-1], dtype=torch.long)  
                labels = torch.tensor(chunk[1:], dtype=torch.long)      

                yield {
                    'input_ids': input_ids,
                    'labels': labels
                }

## GPT Model 

### GPT config 

This is the configuration for the GPT model i am going to train. It is a smaller version of the GPT-2 model. 

- Context length: 512 tokens
- Embedding dimension: 512
- Number of attention heads: 8
- Number of layers: 8

In [40]:
GPT_CONFIG = {
    "vocab_size": 50260,
    "context_length": 256, # max i could fit on my gpu
    "emb_dim": 384,
    "number_heads": 6,
    "number_layers": 6,
    "drop_rate": 0.1,
}

##### Test of a entry from dataloader

In [41]:
# Empty cuda cache and memory management
torch.cuda.empty_cache()

train_dataset = GPTstreamingDataset(train_dataset, tokenizer, GPT_CONFIG['context_length'])

batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=6, pin_memory=True, prefetch_factor=2, persistent_workers=True)

In [42]:
print("##### Test of a entry from dataloader")
batch = next(iter(train_dataloader))
print(batch)    

##### Test of a entry from dataloader
{'input_ids': tensor([[50257,   818,   257,  ...,  4403,  5866, 10062],
        [  290,  7103, 28717,  ..., 23312, 33408,   341],
        [  262,  4403,  5866,  ..., 24337,  1080,   373],
        ...,
        [  897,   276,  4034,  ...,  7184,   290,   262],
        [  286,  3292,    13,  ...,   991,  9389,  1917],
        [ 4525,  4890,    11,  ...,   284,  2074, 48837]]), 'labels': tensor([[  818,   257,  6016,  ...,  5866, 10062,  2838],
        [ 7103, 28717,  1989,  ..., 33408,   341,   287],
        [ 4403,  5866, 10062,  ...,  1080,   373,   973],
        ...,
        [  276,  4034,    13,  ...,   290,   262,  5236],
        [ 3292,    13,   198,  ...,  9389,  1917,    13],
        [ 4890,    11,  8098,  ...,  2074, 48837, 14317]])}


#### Pytroch model implementation

For this first implementation, i am using the transformer and embedding modules from PyTorch. Later i will try to implement the attention mechanism from scratch for better understanding.

https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [43]:
class GPTModel(nn.Module):
    """
    Gpt model class using transformer library
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True # stabilityy 
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)



    def forward(self, input_ids, attention_mask, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (batch_size, seq_length, emb_dim)

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)

        # voiding to pay attatention padding tokens
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        
        x = self.transformer(embeddings, mask=causal_mask, src_key_padding_mask=key_padding_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None




In [44]:
class GPTModelNoPad(nn.Module):
    """
    Gpt model class using transformer library adapted for streaming dataset without padding mask
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True # stabilityy 
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)
        # Weight Tying ( input and output embeddings share weights)
        self.output_layer.weight = self.embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, input_ids, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (batch_size, seq_length, emb_dim)

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)
        
        x = self.transformer(embeddings, mask=causal_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None 


### Model instantiation 

In [45]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModelNoPad(GPT_CONFIG).to(device)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=10000) # Linear learning rate scheduler
scaler = torch.cuda.amp.GradScaler()  # Gradient scaler for mixed precision



GPTModelNoPad(
  (embedding): Embedding(50260, 384)
  (positional_encoding): Embedding(256, 384)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
        )
        (linear1): Linear(in_features=384, out_features=1536, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1536, out_features=384, bias=True)
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_layer): Linear(in_features=384, out_features=50260, bias=False)
)


  scaler = torch.cuda.amp.GradScaler()  # Gradient scaler for mixed precision


### Number of parameters calculation

In [46]:
# Calculate number of parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)    

count_params = count_parameters(model)  
print(f"Number of trainable parameters: {count_params}")

Number of trainable parameters: 30044928


## Training setup

In [47]:
def inference(model, tokenizer, prompt, max_length=256, device='cpu'):
    model.eval()
    input_ids = tokenizer.encode(prompt, allowed_special="all")
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_length)

    generated_ids = input_ids
    max_length = max_length - input_ids.shape[1]  # Remaining length for generation
    with torch.no_grad():
        for _ in range(max_length):
            #attention_mask = torch.ones_like(generated_ids)  # All tokens are real (no padding)
            logits, _ = model(generated_ids) # model(generated_ids, attention_mask)
            next_token_logits = logits[:, -1, :]  #next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # (1, 1) #
            probs = torch.softmax(next_token_logits, dim=-1)
            #print(next_token_logits)
            next_token_id = torch.multinomial(probs, num_samples=1)  # mult
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)  # Append to sequence

            if next_token_id.item() == 50258:  # Stop if eot token is generated
                break

    generated_text = tokenizer.decode(generated_ids.squeeze().tolist())
    return generated_text


In [48]:
accumulation_steps= 1  # Number of steps to accumulate gradients

def train_loop(model, dataloader, optimizer, scheduler, device, num_epochs=3, accumulation_steps = 4, question_interval=500, saving_interval=5000):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        step_count = 0

        for i, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            #attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):  # bf16 optimized for ampere archi
                logits, loss = model(input_ids, labels) # model(input_ids, attention_mask, labels)
                loss = loss / accumulation_steps
            scaler.scale(loss).backward()
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            epoch_loss += loss.item()*accumulation_steps
            progress_bar.set_postfix(loss=loss.item() * accumulation_steps)

            # Inference check
            step_count += 1
            if step_count % question_interval == 0:
                model.eval()
                with torch.no_grad():
                    prompt = "AI is technology that enables computers and machines to"
                    generated_text = inference(model, tokenizer, prompt, max_length=50, device=device)
                    print(f"\n[Inference at step {step_count}]")
                    #save generated text to file train_ouput.txt
                    with open("outputs/train_output.txt", "a") as f:
                        if generated_text.strip() != "":
                            f.write(f"\n[Inference at step {step_count}]: {generated_text}\n")  
                            f.write("-"*50 + "\n")
                        else:
                            f.write(f"\n[Inference at step {step_count}]: [No output generated]\n")
                            f.write("-"*50 + "\n")
                model.train()

            if step_count % saving_interval == 0:
                # Save model checkpoint
                date = datetime.now().strftime("%Y%m%d")
                # Create models directory if it doesn't exist
                Path(f"models/{date}").mkdir(parents=True, exist_ok=True)
                checkpoint_path = f"models/{date}/weights_step{step_count}.pt"
                torch.save(model.state_dict(), checkpoint_path)
                print(f"\nModel checkpoint saved at step {step_count} to {checkpoint_path}")

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")


In [49]:
def load_model_weights(model, checkpoint_path, device):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    print(f"Model loaded from {checkpoint_path}")

In [72]:
def save(model, checkpoint_path):
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved to {checkpoint_path}")


### First training on combined train dataset

In [50]:
#empty gpu memory 
torch.cuda.empty_cache()
torch.cuda.synchronize()

load_model_weights(model, "models/starter/weights_step10000.pt", device)

train_loop(model, train_dataloader, optimizer, scheduler, device, num_epochs=1)

  model.load_state_dict(torch.load(checkpoint_path, map_location=device))


Model loaded from models/starter/weights_step10000.pt


Epoch 1/1: 501it [01:14,  5.54it/s, loss=5.62]


[Inference at step 500]


Epoch 1/1: 1001it [02:29,  6.10it/s, loss=5.58]


[Inference at step 1000]


Epoch 1/1: 1501it [03:43,  6.07it/s, loss=5.9] 


[Inference at step 1500]


Epoch 1/1: 2001it [04:57,  6.13it/s, loss=5.73]


[Inference at step 2000]


Epoch 1/1: 2501it [06:10,  6.11it/s, loss=5.54]


[Inference at step 2500]


Epoch 1/1: 3001it [07:24,  6.10it/s, loss=5.96]


[Inference at step 3000]


Epoch 1/1: 3501it [08:38,  6.09it/s, loss=5.64]


[Inference at step 3500]


Epoch 1/1: 4001it [09:51,  6.17it/s, loss=5.93]


[Inference at step 4000]


Epoch 1/1: 4501it [11:03,  6.15it/s, loss=6.06]


[Inference at step 4500]


Epoch 1/1: 5000it [12:16,  4.72it/s, loss=5.27]


[Inference at step 5000]

Model checkpoint saved at step 5000 to models/20260205/weights_step5000.pt


Epoch 1/1: 5501it [13:30,  6.43it/s, loss=5.49]


[Inference at step 5500]


Epoch 1/1: 6001it [14:44,  6.13it/s, loss=5.59]


[Inference at step 6000]


Epoch 1/1: 6501it [15:58,  6.06it/s, loss=5.41]


[Inference at step 6500]


Epoch 1/1: 7001it [17:14,  5.86it/s, loss=5.42]


[Inference at step 7000]


Epoch 1/1: 7501it [18:27,  6.19it/s, loss=5.53]


[Inference at step 7500]


Epoch 1/1: 8001it [19:40,  6.16it/s, loss=5.12]


[Inference at step 8000]


Epoch 1/1: 8501it [20:53,  6.19it/s, loss=5.19]


[Inference at step 8500]


Epoch 1/1: 9001it [22:06,  6.18it/s, loss=6.18]


[Inference at step 9000]


Epoch 1/1: 9501it [23:16,  6.65it/s, loss=5.49]


[Inference at step 9500]


Epoch 1/1: 10000it [24:24,  5.13it/s, loss=5.64]


[Inference at step 10000]

Model checkpoint saved at step 10000 to models/20260205/weights_step10000.pt


Epoch 1/1: 10501it [25:32,  6.66it/s, loss=5.37]


[Inference at step 10500]


Epoch 1/1: 11001it [26:39,  6.64it/s, loss=5.3] 


[Inference at step 11000]


Epoch 1/1: 11501it [27:47,  6.66it/s, loss=5.52]


[Inference at step 11500]


Epoch 1/1: 12001it [28:54,  6.67it/s, loss=5.34]


[Inference at step 12000]


Epoch 1/1: 12501it [30:02,  6.23it/s, loss=5.21]


[Inference at step 12500]


Epoch 1/1: 13001it [31:15,  6.22it/s, loss=5.69]


[Inference at step 13000]


Epoch 1/1: 13501it [32:27,  6.27it/s, loss=5.15]


[Inference at step 13500]


Epoch 1/1: 14001it [33:39,  6.28it/s, loss=5.4] 


[Inference at step 14000]


Epoch 1/1: 14501it [34:52,  6.28it/s, loss=5.65]


[Inference at step 14500]


Epoch 1/1: 15000it [36:04,  4.76it/s, loss=4.94]


[Inference at step 15000]

Model checkpoint saved at step 15000 to models/20260205/weights_step15000.pt


Epoch 1/1: 15501it [37:17,  6.18it/s, loss=4.95]


[Inference at step 15500]


Epoch 1/1: 16001it [38:25,  6.73it/s, loss=5.09]


[Inference at step 16000]


Epoch 1/1: 16501it [39:32,  6.73it/s, loss=5.74]


[Inference at step 16500]


Epoch 1/1: 17001it [40:39,  6.73it/s, loss=5.19]


[Inference at step 17000]


Epoch 1/1: 17501it [41:46,  6.75it/s, loss=5.05]


[Inference at step 17500]


Epoch 1/1: 18001it [42:53,  6.74it/s, loss=5.34]


[Inference at step 18000]


Epoch 1/1: 18501it [44:00,  6.73it/s, loss=5.16]


[Inference at step 18500]


Epoch 1/1: 19001it [45:07,  6.72it/s, loss=5.32]


[Inference at step 19000]


Epoch 1/1: 19501it [46:14,  6.74it/s, loss=5.12]


[Inference at step 19500]


Epoch 1/1: 20000it [47:21,  5.11it/s, loss=5.07]


[Inference at step 20000]

Model checkpoint saved at step 20000 to models/20260205/weights_step20000.pt


Epoch 1/1: 20501it [48:28,  6.73it/s, loss=5.53]


[Inference at step 20500]


Epoch 1/1: 21001it [49:35,  6.72it/s, loss=4.92]


[Inference at step 21000]


Epoch 1/1: 21501it [50:42,  6.72it/s, loss=5.13]


[Inference at step 21500]


Epoch 1/1: 22001it [51:49,  6.71it/s, loss=4.95]


[Inference at step 22000]


Epoch 1/1: 22501it [52:57,  6.67it/s, loss=4.39]


[Inference at step 22500]


Epoch 1/1: 23001it [54:09,  6.29it/s, loss=5.27]


[Inference at step 23000]


Epoch 1/1: 23501it [55:19,  6.20it/s, loss=4.87]


[Inference at step 23500]


Epoch 1/1: 24001it [56:32,  6.13it/s, loss=5.48]


[Inference at step 24000]


Epoch 1/1: 24501it [57:46,  6.16it/s, loss=5.04]


[Inference at step 24500]


Epoch 1/1: 25000it [58:59,  4.66it/s, loss=4.84]


[Inference at step 25000]

Model checkpoint saved at step 25000 to models/20260205/weights_step25000.pt


Epoch 1/1: 25501it [1:00:11,  6.28it/s, loss=5.56]


[Inference at step 25500]


Epoch 1/1: 26001it [1:01:22,  6.61it/s, loss=4.94]


[Inference at step 26000]


Epoch 1/1: 26501it [1:02:32,  6.26it/s, loss=4.83]


[Inference at step 26500]


Epoch 1/1: 27001it [1:03:43,  6.20it/s, loss=4.7] 


[Inference at step 27000]


Epoch 1/1: 27501it [1:04:56,  6.33it/s, loss=5.06]


[Inference at step 27500]


Epoch 1/1: 28001it [1:06:09,  5.97it/s, loss=4.91]


[Inference at step 28000]


Epoch 1/1: 28501it [1:07:22,  6.25it/s, loss=4.46]


[Inference at step 28500]


Epoch 1/1: 29001it [1:08:34,  6.26it/s, loss=5.18]


[Inference at step 29000]


Epoch 1/1: 29501it [1:09:47,  6.17it/s, loss=4.99]


[Inference at step 29500]


Epoch 1/1: 30000it [1:10:57,  5.13it/s, loss=5.08]


[Inference at step 30000]

Model checkpoint saved at step 30000 to models/20260205/weights_step30000.pt


Epoch 1/1: 30501it [1:12:04,  6.67it/s, loss=5.1] 


[Inference at step 30500]


Epoch 1/1: 31001it [1:13:12,  6.66it/s, loss=5.16]


[Inference at step 31000]


Epoch 1/1: 31501it [1:14:19,  6.66it/s, loss=4.94]


[Inference at step 31500]


Epoch 1/1: 32001it [1:15:29,  6.29it/s, loss=4.74]


[Inference at step 32000]


Epoch 1/1: 32501it [1:16:41,  6.22it/s, loss=4.79]


[Inference at step 32500]


Epoch 1/1: 33001it [1:17:54,  6.27it/s, loss=5.12]


[Inference at step 33000]


Epoch 1/1: 33501it [1:19:07,  6.30it/s, loss=4.77]


[Inference at step 33500]


Epoch 1/1: 34001it [1:20:17,  6.03it/s, loss=4.76]


[Inference at step 34000]


Epoch 1/1: 34501it [1:21:29,  6.32it/s, loss=4.82]


[Inference at step 34500]


Epoch 1/1: 35000it [1:22:41,  4.68it/s, loss=4.99]


[Inference at step 35000]

Model checkpoint saved at step 35000 to models/20260205/weights_step35000.pt


Epoch 1/1: 35501it [1:23:53,  6.31it/s, loss=4.85]


[Inference at step 35500]


Epoch 1/1: 36001it [1:25:03,  6.15it/s, loss=4.82]


[Inference at step 36000]


Epoch 1/1: 36501it [1:26:15,  6.20it/s, loss=5.08]


[Inference at step 36500]


Epoch 1/1: 37001it [1:27:28,  6.14it/s, loss=4.69]


[Inference at step 37000]


Epoch 1/1: 37501it [1:28:41,  6.13it/s, loss=4.55]


[Inference at step 37500]


Epoch 1/1: 38001it [1:29:55,  6.54it/s, loss=5.14]


[Inference at step 38000]


Epoch 1/1: 38501it [1:31:07,  5.70it/s, loss=4.89]


[Inference at step 38500]


Epoch 1/1: 39001it [1:32:23,  5.58it/s, loss=4.76]


[Inference at step 39000]


Epoch 1/1: 39501it [1:33:37,  6.23it/s, loss=4.84]


[Inference at step 39500]


Epoch 1/1: 40000it [1:34:50,  4.70it/s, loss=5.07]


[Inference at step 40000]

Model checkpoint saved at step 40000 to models/20260205/weights_step40000.pt


Epoch 1/1: 40501it [1:36:06,  6.07it/s, loss=4.88]


[Inference at step 40500]


Epoch 1/1: 41001it [1:37:21,  5.95it/s, loss=4.74]


[Inference at step 41000]


Epoch 1/1: 41501it [1:38:37,  5.96it/s, loss=4.78]


[Inference at step 41500]


Epoch 1/1: 42001it [1:39:53,  5.90it/s, loss=5.03]


[Inference at step 42000]


Epoch 1/1: 42501it [1:41:08,  6.00it/s, loss=4.81]


[Inference at step 42500]


Epoch 1/1: 43001it [1:42:23,  6.18it/s, loss=4.89]


[Inference at step 43000]


Epoch 1/1: 43501it [1:43:38,  6.03it/s, loss=5.07]


[Inference at step 43500]


Epoch 1/1: 44001it [1:44:53,  5.95it/s, loss=4.81]


[Inference at step 44000]


Epoch 1/1: 44501it [1:46:07,  6.11it/s, loss=4.65]


[Inference at step 44500]


Epoch 1/1: 45000it [1:47:22,  4.65it/s, loss=4.68]


[Inference at step 45000]

Model checkpoint saved at step 45000 to models/20260205/weights_step45000.pt


Epoch 1/1: 45501it [1:48:36,  6.15it/s, loss=4.64]


[Inference at step 45500]


Epoch 1/1: 46001it [1:49:49,  6.15it/s, loss=4.38]


[Inference at step 46000]


Epoch 1/1: 46501it [1:51:02,  6.36it/s, loss=4.77]


[Inference at step 46500]


Epoch 1/1: 47001it [1:52:17,  6.03it/s, loss=6.94]


[Inference at step 47000]


Epoch 1/1: 47501it [1:53:29,  6.12it/s, loss=4.92]


[Inference at step 47500]


Epoch 1/1: 48001it [1:54:42,  6.47it/s, loss=4.62]


[Inference at step 48000]


Epoch 1/1: 48501it [1:55:53,  6.23it/s, loss=4.97]


[Inference at step 48500]


Epoch 1/1: 49001it [1:57:04,  6.19it/s, loss=4.85]


[Inference at step 49000]


Epoch 1/1: 49501it [1:58:16,  6.26it/s, loss=5.23]


[Inference at step 49500]


Epoch 1/1: 50000it [1:59:31,  4.58it/s, loss=5.07]


[Inference at step 50000]

Model checkpoint saved at step 50000 to models/20260205/weights_step50000.pt


Epoch 1/1: 50501it [2:00:47,  5.97it/s, loss=4.97]


[Inference at step 50500]


Epoch 1/1: 51001it [2:02:03,  5.97it/s, loss=4.65]


[Inference at step 51000]


Epoch 1/1: 51501it [2:03:14,  6.00it/s, loss=4.59]


[Inference at step 51500]


Epoch 1/1: 52001it [2:04:25,  5.95it/s, loss=4.67]


[Inference at step 52000]


Epoch 1/1: 52501it [2:05:41,  5.83it/s, loss=4.76]


[Inference at step 52500]


Epoch 1/1: 53001it [2:06:55,  5.94it/s, loss=4.74]


[Inference at step 53000]


Epoch 1/1: 53501it [2:08:09,  6.06it/s, loss=4.74]


[Inference at step 53500]


Epoch 1/1: 54001it [2:09:24,  6.28it/s, loss=4.63]


[Inference at step 54000]


Epoch 1/1: 54501it [2:10:38,  5.99it/s, loss=4.69]


[Inference at step 54500]


Epoch 1/1: 55000it [2:11:54,  4.55it/s, loss=4.79]


[Inference at step 55000]

Model checkpoint saved at step 55000 to models/20260205/weights_step55000.pt


Epoch 1/1: 55501it [2:13:09,  6.47it/s, loss=5.02]


[Inference at step 55500]


Epoch 1/1: 56001it [2:14:23,  6.01it/s, loss=4.57]


[Inference at step 56000]


Epoch 1/1: 56501it [2:15:38,  6.13it/s, loss=4.72]


[Inference at step 56500]


Epoch 1/1: 57001it [2:16:49,  6.10it/s, loss=4.69]


[Inference at step 57000]


Epoch 1/1: 57501it [2:18:04,  6.09it/s, loss=4.86]


[Inference at step 57500]


Epoch 1/1: 58001it [2:19:18,  6.30it/s, loss=4.77]


[Inference at step 58000]


Epoch 1/1: 58501it [2:20:29,  6.57it/s, loss=4.6] 


[Inference at step 58500]


Epoch 1/1: 59001it [2:21:40,  6.05it/s, loss=4.8] 


[Inference at step 59000]


Epoch 1/1: 59501it [2:22:52,  6.33it/s, loss=4.96]


[Inference at step 59500]


Epoch 1/1: 60000it [2:24:04,  4.69it/s, loss=4.77]


[Inference at step 60000]

Model checkpoint saved at step 60000 to models/20260205/weights_step60000.pt


Epoch 1/1: 60501it [2:25:19,  6.04it/s, loss=4.67]


[Inference at step 60500]


Epoch 1/1: 61001it [2:26:33,  6.05it/s, loss=4.73]


[Inference at step 61000]


Epoch 1/1: 61501it [2:27:48,  5.98it/s, loss=4.91]


[Inference at step 61500]


Epoch 1/1: 62001it [2:29:03,  6.06it/s, loss=4.78]


[Inference at step 62000]


Epoch 1/1: 62501it [2:30:16,  6.28it/s, loss=4.64]


[Inference at step 62500]


Epoch 1/1: 63001it [2:31:28,  6.48it/s, loss=4.45]


[Inference at step 63000]


Epoch 1/1: 63501it [2:32:38,  6.21it/s, loss=4.79]


[Inference at step 63500]


Epoch 1/1: 64001it [2:33:50,  6.24it/s, loss=4.96]


[Inference at step 64000]


Epoch 1/1: 64225it [2:34:21,  6.93it/s, loss=4.86]


KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()
torch.cuda.synchronize()

## Test the trained model

In [73]:
save(model, "models/final_model.pt")

Model saved to models/final_model.pt


In [130]:
print("Testing the trained model:")
prompt = " AI is technology that enables computers and machines to"
generated_text = inference(model, tokenizer, prompt, max_length=250, device=device)
print("Generated text:")
print(generated_text)


Testing the trained model:
Generated text:
 AI is technology that enables computers and machines to convert data into all devices. Searchive causes and imagery can help the networks. Users could ignore the isolation, review, breaking down detailing data and hit down the security with a wireless technology such as commands and legitimate transfers.
A new model by Google technology software team refers accordingly to the issues: Open up 30 people, transmitting data on systems outside the symmetry and using anonymous communication systems, including the “obooters,” cloudnote, “space, storing data”, and x DNA, or the internet that radiatin content inside the camera of an tram, like cryptography and other detectors; here’s the question.
Once we begin looking at exposures, we can maintain the industrial model, and to accurately examine accurate data, and otherwise earned the maximum relative recurrence and 3.3 signal is divided into middle or middle-aged subjects responsible for smaller embe

## Sources 

### Principal references: 
- https://arxiv.org/abs/2005.14165 (GPT-3 paper)
- https://arxiv.org/abs/2002.05709 (Attention is all you need paper)
- Build a Large Language Model (from scratch) by Sebastian Raschka