# GPT Project - Train

This notebook contains the code used to train a small language model using PyTorch from scratch. The model is inspired by the GPT architecture.


#### Hardware
- RTX3060 12GB VRAM
- AMD Ryzen 7 5800X 8-Core
- 32GB RAM
- Ubuntu 22.04 LTS

In [1]:
import os
CACHE_DIR = "/media/rob/RobsDisk/cache_data_llm"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, "datasets")
os.environ['HF_METRICS_CACHE'] = os.path.join(CACHE_DIR, "metrics")
os.environ['HF_MODULES_CACHE'] = os.path.join(CACHE_DIR, "modules")


In [2]:
import torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torch.cuda.amp import GradScaler, autocast
import random, math

from datasets import load_dataset,concatenate_datasets
import tiktoken
from tqdm import tqdm
from datetime import datetime



## Datasets

### Common knowledge datasets:

##### English Wikipedia crawled dataset

In [3]:
# English Wikipedia crawled dataset
# path to store the dataset cache: /Volumes/RobertsDisk
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR) 
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'])


English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
Kamel Habri (; born March 5, 1976, in Tlemcen) is a retired Algerian international football player. He spent the majority of his career with his hometown club of WA Tlemcen. He also had 7 caps for the Algeria National Team and was a member of the team at the 1998 African Cup of Nations in Burkina Faso.

Club career
 1994-2000 WA Tlemcen 
 2000-2003 JSM Béjaïa 
 2003-2006 JS Kabylie 
 2006-2008 JSM Béjaïa 
 2008-2011 WA Tlemcen

Honours
 Won the Arab Champions League once with WA Tlemcen in 1998
 Won the Algerian Cup once with WA Tlemcen in 1998
 Won the Algerian League once with JS Kabylie in 2004
 Played in the 1998 African Cup of Nations in Burkina Faso
 Has 7 caps for the Algerian National Team

References

1976 births
1998 African Cup of Nations players
Algerian men's footballers
Algeria men's international footballers

##### Simple stories dataset

In [4]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'])



Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
Beneath the waves, a girl named Lily swam happily. She loved the ocean and all the colorful fish. One day, she found an old treasure chest on the sea floor. Curious, she opened it. Inside was a beautiful necklace that glowed with a bright light. "What is this?" she wondered.

Suddenly, a figure burst from the water. It was Aqua Knight, a hero of the sea! "That necklace holds the power of the ocean," he explained. "But beware, for the Dark Tide seeks it for evil." Just as he spoke, a shadow appeared. It was the anti-hero, the Deep Whisper. He was known for stealing treasures and causing chaos.

"Hand over the necklace!" Deep Whisper shouted. Aqua Knight stood firm. "You will not take it!" he declared. Lily felt a mix of fear and bravery. She wanted to protect the treasure and help Aqua Knight. "We must outsmart him!" she said, 

##### FineWeb-Edu dataset

In [5]:
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT",  split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu is ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text']) 


FineWeb-Edu is ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
Deb thru-hiked the Appalachian Trail and is a Search & Rescue volunteer and writer living in Flagstaff, AZ.
See Geology at its Best in the Paria Canyon-Vermilion Cliffs Wilderness
Until a friend invited me on a hiking trip in northern Arizona and southern Utah, I'd not heard of the Wave. But I realized when we arrived that I had seen photos of this amazing geological formation, now an internationally known and very popular wilderness destination. Which is why daily visitation to the Wave has been limited by the Bureau of Land Management, to prevent excessive damage, crowding, and overuse of the area.
If you're willing to take your chances with the permit process, which I'll describe below, and can hike six, mostly easy to moderately difficult (round-trip) miles in desert terrain, I highly recommend a visit to the Wave.
Here, I'll show

In [6]:
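# OpenWebText dataset (Skylion007/openwebtext). This is the original OpenWebText corpus, not OpenWebText2; the `owt2` name is kept as-is.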
owt2 = load_dataset("Skylion007/openwebtext", split="train", cache_dir=CACHE_DIR)
print("OpenWebText2 dataset loaded.")
print("Dataset size in GB:", owt2.dataset_size / (1024**3))
print("Number of entries:", len(owt2))
print("-"*50)
print("Example entry:")
print(owt2[random.randint(0, len(owt2)-1)]['text'])

OpenWebText2 dataset loaded.
Dataset size in GB: 37.03822539001703
Number of entries: 8013769
--------------------------------------------------
Example entry:
Mission Not Accomplished

The war in Libya is a good war — or at least, it should and could be. But it is certainly not a smart war and may well turn into a debacle. Bringing down Col. Muammar al-Qaddafi’s tyranny would be a major strategic and humanitarian victory in the Middle East. That achievement would be even more stunning if a democratic government, brought to power by Libyans themselves, replaced Qaddafi. Although the Libyan rebels will undoubtedly need Western help — and are rightly receiving it — the credit will be theirs: The American Revolutionaries needed French arms to defeat the British, but French help did not tarnish their victory.


A protracted civil war in Libya could have effects beyond its borders. It could lead competing outside powers — France, Turkey, or even China — to back different Libyan factions. U.

### Some Q&A data to improve the model's ability to answer questions:

In [7]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'])

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
Why is proper curing important in automotive painting? 
 Proper curing is crucial for allowing paint to fully bond with the underlying surface, preventing premature cracking, peeling, or issues that impact the final product's longevity and overall quality.


In [8]:
#euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500], "\n", reddit_instruct[index]['comment_text'][:500])

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
ELI5: Howcome when I drink salty sea-water I vomit but drinking water while eating heavily salted popcorn is fine?  
 There is much much much more salt in salt water.  When you eat popcorn then drink water you are essentially diluting the salt content.  When drinking salt water you aren't diluting anything,  but you are increasing your body's salt content. (edit to fix auto correct - lol) 

In [9]:
# tatsu-lab/alpaca ( for Q&A fine-tuning )
alpaca = load_dataset("tatsu-lab/alpaca", split='train')
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
What is the impact of deforestation on wildlife? 
 Deforestation has a range of consequences for wildlife, including habitat loss, disruption of ecosystems and food webs, displacement of species, increased competition, and increased vulnerability to predation. Additionally, deforestation can increase the risk of climate change, resulting in altered temperatures and unpredictable weather patterns that further disrupt wildlife habitats.


## Data Preprocessing

#### Tokenizer setup

For this project I use tiktoken, the tokenizer library that OpenAI uses for its models.

I use the "gpt2" encoding, which is a byte pair encoding (BPE) tokenizer.
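
To give an intuition for what "byte pair encoding" means, here is a toy sketch of the merge step. It works on characters of a made-up string and learns its own merges, whereas the real "gpt2" encoding works on bytes with a fixed, pre-trained merge table. The function names are hypothetical and are not part of tiktoken.

```python
from collections import Counter

def most_frequent_pair(tokens):
    """Return the most common adjacent pair of symbols."""
    pairs = Counter(zip(tokens, tokens[1:]))
    return pairs.most_common(1)[0][0]

def merge_pair(tokens, pair):
    """Replace every non-overlapping occurrence of `pair` with one merged symbol."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

# Start from single characters and apply two merge rounds
tokens = list("low lower lowest")
for _ in range(2):
    tokens = merge_pair(tokens, most_frequent_pair(tokens))

print(tokens)
```

After two rounds the frequent fragment "low" has become a single symbol, which is the same idea that lets the GPT-2 vocabulary represent common words with one token id.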

In [10]:
tokenizer_base = tiktoken.get_encoding("gpt2")

tokenizer = tiktoken.Encoding(
    name="rob-tokenizer",
    pat_str=tokenizer_base._pat_str,
    mergeable_ranks=tokenizer_base._mergeable_ranks,
    special_tokens={
        **tokenizer_base._special_tokens,
        "<|im_start|>": 50257,
        "<|im_end|>": 50258,
        "<|pad|>": 50259,
    }
)
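
The three added special tokens must continue directly after the base GPT-2 vocabulary (ids 0 to 50256), and the model's `vocab_size` has to cover them. A small sanity check of that arithmetic, written as a sketch in plain Python (the helper name `required_vocab_size` is hypothetical):

```python
BASE_VOCAB_SIZE = 50257  # GPT-2 encoding: ids 0..50256, <|endoftext|> is 50256

extra_special_tokens = {
    "<|im_start|>": 50257,
    "<|im_end|>": 50258,
    "<|pad|>": 50259,
}

def required_vocab_size(base_size, extra_tokens):
    """Check that the new ids are contiguous after the base vocab and return the total size."""
    ids = sorted(extra_tokens.values())
    expected = list(range(base_size, base_size + len(ids)))
    if ids != expected:
        raise ValueError(f"special token ids {ids} do not continue the base vocabulary")
    return base_size + len(ids)

vocab_size = required_vocab_size(BASE_VOCAB_SIZE, extra_special_tokens)
print(vocab_size)  # 50260
```

The result, 50260, is the value the embedding table needs so that every token id has a row.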

#### Test of the byte pair encoding tokenizer 

In [11]:
# test of tokenizer on reddit_instruct
sample_text = reddit_instruct[0]['post_title'] + " " + reddit_instruct[0]['post_text'] + " " + reddit_instruct[0]['comment_text']
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens)) 
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")   

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 

### Formatting datasets functions

#### Merging datasets

In [12]:
from datasets import concatenate_datasets
combined_train_dataset = concatenate_datasets([  
    wiki_en,
    stories,
    fineweb_edu,
    owt2,  
])  

combined_finetune_dataset = concatenate_datasets([
    q_a1,
    reddit_instruct,
    alpaca,
])

# Shuffle the combined dataset
train_dataset = combined_train_dataset.shuffle(seed=42)
finetune_dataset = combined_finetune_dataset.shuffle(seed=42)
print(f"Train dataset size: {len(combined_train_dataset)}")
print(f"Finetune dataset size: {len(combined_finetune_dataset)}")

# Example

print("Example entry from train dataset:")
index = random.randint(0, len(train_dataset)-1)
print(train_dataset[index])   

Train dataset size: 26209380
Finetune dataset size: 257745
Example entry from train dataset:


#### Custom Dataset class

Inspired by the dataloader from the "LLMs from scratch" repository, but adapted for multi-row text arrays.

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/dataloader.ipynb

In [13]:
class GPTDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        """
        Args:
            dataset: Dataset of the combined Hugging Face entries
            tokenizer: The initialized tokenizer used to process text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.pad_token_id = 50259    # <|pad|>
        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Data format handling. Values are tested rather than keys because a row of
        # the concatenated dataset may carry every column, with None where it has no data.
        entry = self.data[idx]
        if entry.get('text') is not None:
            text = entry['text']
        elif entry.get('story') is not None:
            text = entry['story']
        elif entry.get('question') is not None and entry.get('answer') is not None:
            text = "User: " + entry['question'] + " Assistant: " + entry['answer']
        elif entry.get('post_title') is not None and entry.get('comment_text') is not None:
            # The post body belongs to the user turn; the comment is the answer
            text = "User: " + entry['post_title'] + " " + (entry.get('post_text') or "") + " Assistant: " + entry['comment_text']
        elif entry.get('instruction') is not None and entry.get('output') is not None:
            text = "User: " + entry['instruction'] + " Assistant: " + entry['output']
        else:
            raise ValueError("Unknown data entry format")

        text = str(text)  # Ensure text is a string

        # Adding start and end tokens
        text = "<|im_start|>" + text + "<|im_end|>"

        # Tokenization
        tokens = self.tokenizer.encode(text, allowed_special="all")

        # Truncation: data is lost here (fix later with a sliding window).
        # One extra token is kept so inputs and labels can both span max_length.
        tokens = tokens[:self.max_length + 1]
        input_ids = tokens[:-1]  # All tokens except last
        labels = tokens[1:]      # All tokens except first

        # Pad to a fixed length so the default collate function can stack a batch.
        # Padded label positions are -100, the ignore_index used by the loss.
        pad = self.max_length - len(input_ids)
        input_ids = torch.tensor(input_ids + [self.pad_token_id] * pad, dtype=torch.long)
        labels = torch.tensor(labels + [-100] * pad, dtype=torch.long)

        attention_mask = (input_ids != self.pad_token_id).long()  # 1 for real tokens, 0 for padding

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
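
The truncation step above discards everything past the context window. A sliding window is one way to keep that text: split a long token list into several (input, label) pairs instead of one. The sketch below is a minimal plain-Python version of that idea; `sliding_windows` and its `stride` parameter are hypothetical names, not part of the class above.

```python
def sliding_windows(tokens, max_length, stride):
    """Split a token list into (input, label) pairs of at most max_length tokens each.

    Labels are the inputs shifted by one position, as in next-token prediction.
    """
    pairs = []
    # The last usable start index is len(tokens) - 2, so every window has >= 2 tokens
    for start in range(0, max(len(tokens) - 1, 0), stride):
        window = tokens[start:start + max_length + 1]
        pairs.append((window[:-1], window[1:]))
    return pairs

# Toy document of 10 "tokens" with a context window of 4
pairs = sliding_windows(list(range(10)), max_length=4, stride=4)
for inputs, labels in pairs:
    print(inputs, "->", labels)
```

With `stride == max_length` each token after the first is predicted exactly once; a smaller stride would create overlapping windows at the cost of more samples.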

In [14]:
class GPTstreamingDataset(IterableDataset):
    def __init__(self, dataset, tokenizer, max_length):
        """
        Args:
            dataset: Dataset of the combined Hugging Face entries
            tokenizer: The initialized tokenizer used to process text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __iter__(self):
        buffer = []
        for entry in self.data:
            # Data format
            text = None  # Reset so an unmatched entry is skipped instead of reusing the previous text
            if entry.get('text') is not None:
                text = entry['text']
            elif entry.get('story') is not None:
                text = entry['story']
            elif entry.get('question') is not None and entry.get('answer') is not None: 
                text = f"User: {entry['question']} Assistant: {entry['answer']}"
            elif entry.get('post_title') is not None:
                # `or ""` also covers keys that exist with a None value
                title = entry.get('post_title') or ""
                post = entry.get('post_text') or ""
                comment = entry.get('comment_text') or ""
                # The post body belongs to the user turn; the comment is the answer
                text = f"User: {title} {post} Assistant: {comment}"
            elif entry.get('instruction') is not None and entry.get('output') is not None:
                text = f"User: {entry['instruction']} Assistant: {entry['output']}"
            if text is None:
                continue
                
            text = str(text) 

            # Start and End tokens
            text = "<|im_start|>" + text + "<|im_end|>" 

            # Tokenization
            tokens = self.tokenizer.encode(text, allowed_special="all")

            buffer.extend(tokens) #pile tokens into buffer


            while len(buffer) >= self.max_length +1:
                chunk = buffer[:self.max_length + 1]
                buffer = buffer[self.max_length+ 1:] # remove used tokens

                input_ids = torch.tensor(chunk[:-1], dtype=torch.long)  
                labels = torch.tensor(chunk[1:], dtype=torch.long)      

                yield {
                    'input_ids': input_ids,
                    'labels': labels
                }
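
The buffer logic in `__iter__` can be easier to follow on a toy input. This sketch reproduces the same packing in plain Python, with small integer lists standing in for tokenized documents (`pack_documents` is a hypothetical name used only here):

```python
def pack_documents(token_docs, max_length):
    """Concatenate tokenized documents and cut them into fixed-size training samples."""
    buffer = []
    for tokens in token_docs:
        buffer.extend(tokens)  # pile tokens into the buffer
        while len(buffer) >= max_length + 1:
            chunk = buffer[:max_length + 1]
            buffer = buffer[max_length + 1:]  # remove used tokens
            yield chunk[:-1], chunk[1:]       # inputs, labels shifted by one

docs = [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12]]
samples = list(pack_documents(docs, max_length=4))
for inputs, labels in samples:
    print(inputs, "->", labels)
```

Two details are visible here: samples can span document boundaries, and whatever remains in the buffer at the end (tokens 11 and 12 in this example) is dropped because it does not fill a full window.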

## GPT Model 

### GPT config 

This is the configuration for the GPT model I am going to train. It is a smaller version of the GPT-2 model.

- Context length: 256 tokens
- Embedding dimension: 384
- Number of attention heads: 6
- Number of layers: 6

In [15]:
GPT_CONFIG = {
    "vocab_size": 50260,
    "context_length": 256, # max I could fit on my GPU
    "emb_dim": 384,
    "number_heads": 6,
    "number_layers": 6,
    "drop_rate": 0.1,
}

##### Test of an entry from the dataloader

In [None]:
# Empty cuda cache and memory management
torch.cuda.empty_cache()

train_dataset = GPTstreamingDataset(train_dataset, tokenizer, GPT_CONFIG['context_length'])

batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=6, pin_memory=True, prefetch_factor=2, persistent_workers=True)
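
One point to keep in mind with an `IterableDataset` and `num_workers > 0`: each worker process receives its own copy of the dataset object, so without extra care every worker iterates over the same entries and the batches contain duplicates. A common remedy is to give each worker every n-th entry. The sketch below shows only the sharding logic in plain Python; `shard_for_worker` is a hypothetical helper, and in `__iter__` the worker id and count would come from `torch.utils.data.get_worker_info()` (which returns `None` in the main process).

```python
from itertools import islice

def shard_for_worker(entries, worker_id, num_workers):
    """Yield only the entries assigned to one worker: worker_id, worker_id + n, ..."""
    return islice(entries, worker_id, None, num_workers)

# Simulate 3 workers reading the same stream of 10 entries
num_workers = 3
shards = [list(shard_for_worker(range(10), wid, num_workers)) for wid in range(num_workers)]
for wid, shard in enumerate(shards):
    print(f"worker {wid}: {shard}")

# Inside GPTstreamingDataset.__iter__ this could be wired in roughly as follows (assumption):
#   info = torch.utils.data.get_worker_info()
#   data = self.data if info is None else shard_for_worker(self.data, info.id, info.num_workers)
```

Every entry lands in exactly one shard, so the union of what the workers yield matches a single pass over the dataset.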

In [17]:
print("##### Test of a entry from dataloader")
batch = next(iter(train_dataloader))
print(batch)    

##### Test of a entry from dataloader
{'input_ids': tensor([[50257,   818,   257,  ...,  4403,  5866, 10062],
        [  290,  7103, 28717,  ..., 23312, 33408,   341],
        [  262,  4403,  5866,  ..., 24337,  1080,   373],
        ...,
        [  897,   276,  4034,  ...,  7184,   290,   262],
        [  286,  3292,    13,  ...,   991,  9389,  1917],
        [ 4525,  4890,    11,  ...,   284,  2074, 48837]]), 'labels': tensor([[  818,   257,  6016,  ...,  5866, 10062,  2838],
        [ 7103, 28717,  1989,  ..., 33408,   341,   287],
        [ 4403,  5866, 10062,  ...,  1080,   373,   973],
        ...,
        [  276,  4034,    13,  ...,   290,   262,  5236],
        [ 3292,    13,   198,  ...,  9389,  1917,    13],
        [ 4890,    11,  8098,  ...,  2074, 48837, 14317]])}


#### PyTorch model implementation

For this first implementation, I am using the transformer and embedding modules from PyTorch. Later I will try to implement the attention mechanism from scratch for better understanding.

https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [18]:
class GPTModel(nn.Module):
    """
    GPT model built on PyTorch's nn.TransformerEncoder, with a key padding mask
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True  # pre-norm layout, for training stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)



    def forward(self, input_ids, attention_mask, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (1, seq_length, emb_dim), broadcast over the batch

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)

        # Avoid attending to padding tokens
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        
        x = self.transformer(embeddings, mask=causal_mask, src_key_padding_mask=key_padding_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None
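
To make the effect of the causal mask concrete, here is a small plain-Python sketch (no PyTorch needed). It builds the same layout as the `torch.triu(..., diagonal=1)` mask, 0 on and below the diagonal and `-inf` above it, adds it to some made-up attention scores, and applies a softmax row by row. The helper names are hypothetical.

```python
import math

def causal_mask(seq_length):
    """0.0 on and below the diagonal, -inf above it."""
    return [[0.0 if j <= i else float("-inf") for j in range(seq_length)]
            for i in range(seq_length)]

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

# Same raw scores for every query position, to isolate the effect of the mask
scores = [[1.0, 2.0, 3.0] for _ in range(3)]
mask = causal_mask(3)
weights = [softmax([s + m for s, m in zip(score_row, mask_row)])
           for score_row, mask_row in zip(scores, mask)]

for i, row in enumerate(weights):
    print(f"position {i}:", [round(w, 3) for w in row])
```

Position 0 can only attend to itself, position 1 to positions 0 and 1, and so on: the weights for future positions are exactly zero.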




In [19]:
class GPTModelNoPad(nn.Module):
    """
    GPT model built on PyTorch's nn.TransformerEncoder, adapted for the streaming dataset (no padding mask)
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True  # pre-norm layout, for training stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)
        # Weight Tying ( input and output embeddings share weights)
        self.output_layer.weight = self.embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, input_ids, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (1, seq_length, emb_dim), broadcast over the batch

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)
        
        x = self.transformer(embeddings, mask=causal_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None 


### Model instantiation 

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModelNoPad(GPT_CONFIG).to(device)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=10000) # Linear learning rate scheduler
scaler = torch.amp.GradScaler('cuda')  # Gradient scaler for mixed precision (replaces the deprecated torch.cuda.amp.GradScaler)



GPTModelNoPad(
  (embedding): Embedding(50260, 384)
  (positional_encoding): Embedding(256, 384)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
        )
        (linear1): Linear(in_features=384, out_features=1536, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1536, out_features=384, bias=True)
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_layer): Linear(in_features=384, out_features=50260, bias=False)
)
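
For reference, the learning rate that the `LinearLR` scheduler above produces can be written in closed form: the factor moves linearly from `start_factor` to `end_factor` over `total_iters` scheduler steps and then stays constant. The sketch below reimplements that formula in plain Python with the values from this notebook; `linear_lr` is a hypothetical helper, not a PyTorch function.

```python
def linear_lr(step, base_lr=3e-4, start_factor=1.0, end_factor=0.1, total_iters=10000):
    """Learning rate after `step` calls to scheduler.step()."""
    progress = min(step, total_iters) / total_iters
    factor = start_factor + (end_factor - start_factor) * progress
    return base_lr * factor

for step in (0, 2500, 5000, 10000, 20000):
    print(f"step {step:>6}: lr = {linear_lr(step):.2e}")
```

Since `scheduler.step()` is called once per optimizer update, `step` here counts optimizer updates rather than batches when gradient accumulation is used.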




### Number of parameters calculation

In [21]:
# Calculate number of parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)    

count_params = count_parameters(model)  
print(f"Number of trainable parameters: {count_params}")

Number of trainable parameters: 30044928
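
The reported number can be checked by hand from the configuration. The sketch below redoes the count in plain Python, assuming the layer layout printed above (attention with a fused Q/K/V projection plus output projection, two feed-forward linears with bias, two LayerNorms per layer) and tied input/output embeddings. `count_gpt_parameters` is a hypothetical helper.

```python
def count_gpt_parameters(vocab_size, context_length, emb_dim, number_layers, tied=True):
    """Parameter count for the architecture printed above."""
    token_emb = vocab_size * emb_dim
    pos_emb = context_length * emb_dim

    attn = 3 * emb_dim * emb_dim + 3 * emb_dim   # fused Q, K, V projection (weight + bias)
    attn += emb_dim * emb_dim + emb_dim          # attention output projection
    ff = emb_dim * 4 * emb_dim + 4 * emb_dim     # linear1: emb_dim -> 4 * emb_dim
    ff += 4 * emb_dim * emb_dim + emb_dim        # linear2: 4 * emb_dim -> emb_dim
    norms = 2 * (2 * emb_dim)                    # two LayerNorms, weight + bias each
    per_layer = attn + ff + norms

    # With weight tying the output layer reuses the token embedding matrix
    output = 0 if tied else vocab_size * emb_dim
    return token_emb + pos_emb + number_layers * per_layer + output

tied_total = count_gpt_parameters(50260, 256, 384, 6, tied=True)
untied_total = count_gpt_parameters(50260, 256, 384, 6, tied=False)
print(tied_total)    # 30044928
print(untied_total)  # 49344768
```

About 19.3M of the 30M parameters sit in the token embedding table, which is why tying the output layer to it saves so much at this model size.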


## Training setup

In [22]:
def inference(model, tokenizer, prompt, max_length=256, device='cpu'):
    model.eval()
    input_ids = tokenizer.encode(prompt, allowed_special="all")
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_length)

    generated_ids = input_ids
    max_length = max_length - input_ids.shape[1]  # Remaining length for generation
    with torch.no_grad():
        for _ in range(max_length):
            logits, _ = model(generated_ids)  # For GPTModel, also pass attention_mask = torch.ones_like(generated_ids)
            next_token_logits = logits[:, -1, :]  # Logits for the last position only
            # Greedy alternative: next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)  # Sample one token id, shape (1, 1)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)  # Append to sequence

            if next_token_id.item() == 50258:  # Stop once <|im_end|> is generated
                break

    generated_text = tokenizer.decode(generated_ids.squeeze().tolist())
    return generated_text
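
The `inference` function samples the next token from the softmax distribution instead of always taking the most likely one. The difference between the two strategies can be seen on a toy example in plain Python; the three logits below are made up, and `greedy_pick` / `sample_pick` are hypothetical helper names.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_pick(logits):
    """Always return the index with the highest logit (argmax)."""
    return max(range(len(logits)), key=lambda i: logits[i])

def sample_pick(logits, rng):
    """Draw an index at random, weighted by the softmax probabilities."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.1]
rng = random.Random(0)  # fixed seed so the run is repeatable

counts = [0, 0, 0]
for _ in range(1000):
    counts[sample_pick(logits, rng)] += 1

print("probabilities:", [round(p, 3) for p in softmax(logits)])
print("greedy choice:", greedy_pick(logits))
print("sample counts:", counts)
```

Greedy decoding returns index 0 every time, while sampling also picks the less likely indices in proportion to their probability, which gives more varied text.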


In [23]:
accumulation_steps = 1  # Number of steps to accumulate gradients (train_loop uses its own default of 4 unless this is passed in)

def train_loop(model, dataloader, optimizer, scheduler, device, num_epochs=3, accumulation_steps=4, question_interval=500, saving_interval=5000):
    # Note: `scaler` and `tokenizer` are the notebook-level objects defined in earlier cells
    model.train()
    Path("outputs").mkdir(parents=True, exist_ok=True)  # Folder for the inference log
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        step_count = 0
        optimizer.zero_grad()  # Start the epoch with clean gradients

        for i, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            #attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # No zero_grad() here: gradients must accumulate over `accumulation_steps` micro-batches
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):  # bf16 is supported on the Ampere architecture
                logits, loss = model(input_ids, labels) # model(input_ids, attention_mask, labels)
                loss = loss / accumulation_steps
            scaler.scale(loss).backward()
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            epoch_loss += loss.item() * accumulation_steps
            progress_bar.set_postfix(loss=loss.item() * accumulation_steps)

            # Inference check
            step_count += 1
            if step_count % question_interval == 0:
                model.eval()
                with torch.no_grad():
                    prompt = "AI is technology that enables computers and machines to"
                    generated_text = inference(model, tokenizer, prompt, max_length=50, device=device)
                    print(f"\n[Inference at step {step_count}]")
                    # Save generated text to outputs/train_output.txt
                    with open("outputs/train_output.txt", "a") as f:
                        if generated_text.strip() != "":
                            f.write(f"\n[Inference at step {step_count}]: {generated_text}\n")
                            f.write("-"*50 + "\n")
                        else:
                            f.write(f"\n[Inference at step {step_count}]: [No output generated]\n")
                            f.write("-"*50 + "\n")
                model.train()

            if step_count % saving_interval == 0:
                # Save model checkpoint
                date = datetime.now().strftime("%Y%m%d")
                # Create models directory if it doesn't exist
                Path(f"models/{date}").mkdir(parents=True, exist_ok=True)
                checkpoint_path = f"models/{date}/weights_step{step_count}.pt"
                torch.save(model.state_dict(), checkpoint_path)
                print(f"\nModel checkpoint saved at step {step_count} to {checkpoint_path}")

        # A DataLoader over an IterableDataset has no len(), so average over the steps actually run
        avg_loss = epoch_loss / max(step_count, 1)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
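
Why divide the loss by `accumulation_steps`? Summing the gradients of `loss / N` over N micro-batches gives the same gradient as one large batch with the mean loss. The sketch below checks this on a one-parameter toy model in plain Python, with made-up data and a hand-written gradient; it also shows that zeroing the gradient before every micro-batch would keep only the last contribution.

```python
def grad(w, x, y):
    """Derivative with respect to w of the squared error (w * x - y) ** 2."""
    return 2 * (w * x - y) * x

w = 0.5
batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 9.0)]  # (x, y) pairs
accumulation_steps = len(batch)

# Reference: gradient of the mean loss over the whole batch
full_grad = sum(grad(w, x, y) for x, y in batch) / len(batch)

# Accumulation: each micro-batch adds the gradient of (loss / accumulation_steps)
accumulated = 0.0
for x, y in batch:
    accumulated += grad(w, x, y) / accumulation_steps

# Zeroing before every micro-batch: only the last contribution survives
zeroed_each_step = 0.0
for x, y in batch:
    zeroed_each_step = 0.0
    zeroed_each_step += grad(w, x, y) / accumulation_steps

print(full_grad, accumulated, zeroed_each_step)
```

The first two values match, while the third is only the last micro-batch's share, which is why the gradient should be reset only after an optimizer step.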


In [24]:
def load_model_weights(model, checkpoint_path, device):
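    # weights_only=True restricts unpickling to tensors, which is all a state_dict contains
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))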
    model.to(device)
    print(f"Model loaded from {checkpoint_path}")
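
Checkpoints are saved as `weights_step{N}.pt`, so sorting the file names alphabetically would put `weights_step10000.pt` before `weights_step5000.pt`. When resuming, it can help to pick the latest file by its numeric step. A minimal sketch, assuming all checkpoints of one run sit in a single folder; `latest_checkpoint` is a hypothetical helper and the demo uses empty placeholder files in a temporary directory.

```python
import re
import tempfile
from pathlib import Path

def latest_checkpoint(run_dir):
    """Return the weights_step{N}.pt path with the highest N in run_dir, or None."""
    pattern = re.compile(r"^weights_step(\d+)\.pt$")
    best_path, best_step = None, -1
    for path in Path(run_dir).glob("weights_step*.pt"):
        match = pattern.match(path.name)
        if match and int(match.group(1)) > best_step:
            best_path, best_step = path, int(match.group(1))
    return best_path

# Demo with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    for step in (5000, 10000, 15000):
        (Path(tmp) / f"weights_step{step}.pt").touch()
    latest_name = latest_checkpoint(tmp).name

print(latest_name)  # weights_step15000.pt
```

The returned path could then be passed to `load_model_weights` in place of a hard-coded file name.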

### First training on combined train dataset

In [25]:
#empty gpu memory 
torch.cuda.empty_cache()
torch.cuda.synchronize()

load_model_weights(model, "models/starter/weights_step10000.pt", device)

train_loop(model, train_dataloader, optimizer, scheduler, device, num_epochs=1)



Model loaded from models/starter/weights_step10000.pt


Epoch 1/1: 15it [00:02,  5.72it/s, loss=5.82]


KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()
torch.cuda.synchronize()

## Sources 

### Principal references: 
- https://arxiv.org/abs/2005.14165 (GPT-3 paper)
- https://arxiv.org/abs/1706.03762 (Attention Is All You Need paper)
- Build a Large Language Model (from scratch) by Sebastian Raschka