# Projet GPT - Train

This notebook contains the code used to train a small language model from scratch with PyTorch. The model is inspired by the GPT architecture.


#### Hardware
- RTX3060 12GB VRAM
- AMD Ryzen 7 5800X 8-Core
- 32GB RAM
- Ubuntu 22.04 LTS

In [4]:
import os
#CACHE_DIR = "/media/rob/RobertsDisk/data_cache"
CACHE_DIR = "/home/rob/projet_gpt/cache_huggingface"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, "datasets")
os.environ['HF_METRICS_CACHE'] = os.path.join(CACHE_DIR, "metrics")
os.environ['HF_MODULES_CACHE'] = os.path.join(CACHE_DIR, "modules")


In [5]:
import torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import random, math

from datasets import load_dataset, concatenate_datasets
import tiktoken
from tqdm import tqdm

## Datasets

### Common knowledge datasets:

##### English Wikipedia crawled dataset

In [6]:
# English Wikipedia crawled dataset
# path to store the dataset cache: /Volumes/RobertsDisk
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR) 
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'][:500])


English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
Hernán Ismael Galíndez (30 March 1987) is an Argentine-Ecuadorian professional footballer who plays as a goalkeeper for Ecuadorian Serie A club Aucas.

Born in Argentina, he began his career in Argentina with Rosario Central before settling in Ecuador with Universidad Católica, where he made over 300 appearances in a nine-year spell. In 2022, he joined Universidad de Chile but left the club after six months, citing harassment from the club's fanbase. He returned to Ecuador with Aucas, helping th


##### SimpleStories dataset

In [7]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'][:500])

Generating train split: 100%|██████████| 2115696/2115696 [00:05<00:00, 396051.19 examples/s]
Generating test split: 100%|██████████| 21371/21371 [00:00<00:00, 424350.81 examples/s]


Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
Ovens crackled as Rita prepared for the festival. She had received a letter from her friend, explaining how their town celebrated by baking bread for the community. The letter made Rita think of her own town's traditions. She remembered how food brought people together and decided to host a baking event. 

Rita invited neighbors to join her in making bread. They laughed and shared recipes, blending flavors from their cultures. As the bread baked, the smell wafted through the streets, drawing peo


##### FineWeb-Edu dataset

In [8]:
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT",  split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu is ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text'][:500]) 


Generating train split: 100%|██████████| 9672101/9672101 [02:42<00:00, 59365.68 examples/s]


FineWeb-Edu is ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
“The typical American diet is too high in calories, saturated fat, sodium, and added sugars, and does not have enough fruits, vegetables, whole grains, calcium, and fiber. Such that contributes to some of the leading causes of death and increases the risk of lots of fatal diseases“– Paul Ebeling
The risks associated with poor eating habits have long been known, and yet a new study revealed just how detrimental the dangers of a poor diet can actually be to our long-term health.
In a recent study 


### Q&A datasets to improve the model's ability to answer questions:

In [9]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'][:500])

Generating train split: 100%|██████████| 120959/120959 [00:00<00:00, 802562.08 examples/s]
Generating validation split: 100%|██████████| 30240/30240 [00:00<00:00, 1534742.85 examples/s]

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
What was the first song Marlboro created, and what was its Portuguese equivalent of? 
 Marlboro's first song was 'Melô da Mulher Feia', a Portuguese version of 2 Live Crew's 'Do Wah Diddy', which achieved significant success on the radio.





In [10]:
#euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500], "\n", reddit_instruct[index]['comment_text'][:500])

Generating train split: 100%|██████████| 84784/84784 [00:00<00:00, 538655.93 examples/s]
Generating test split: 100%|██████████| 2000/2000 [00:00<00:00, 243684.87 examples/s]

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
If I put two things in the microwave instead of one, are they both going to be heated slower? The scenario is the following: I have a microwave that operates at 1000 W and I have two cups of water.

I noticed that for the same amount of time spent in the microwave, a single cup of water will get warmer if it's in there alone. If I put both cups in there, I also habe to spend more time.

So is the power being "shared" by the cups, and if so, do they share this power based on volume / mass / position?

Do both cups of water receive 500 W each? 
 A fixed amount of radiation is bouncing around in there. Not all of it will manage to be absorbed by the item. 

So there is some loss and some of it transfers to the heat you want. So multiple items are indeed sharing the same amount of energy. But the exact proportions of heating, how much loss are a complex matter of the shapes of these items, their position the exact design of your microwave and things like that. It probably won’t work out that 2 cups takes exactly twice as long as 1 cup.

B

In [11]:
# tatsu-lab/alpaca ( for Q&A fine-tuning )
alpaca = load_dataset("tatsu-lab/alpaca", split='train')
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Generating train split: 100%|██████████| 52002/52002 [00:00<00:00, 1103159.06 examples/s]

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
Generate an HTML code for a 3Cols table that also has a header. Output the code directly. 
 <table>
  <tr>
    <th>Header 1</th>
    <th>Header 2</th>
    <th>Header 3</th>
  </tr>
  <tr>
   <td>Column 1</td>
   <td>Column 2</td>
   <td>Column 3</td>
  </tr>
</table>





## Data Preprocessing

#### Tokenizer setup

For this project I use tiktoken as the tokenizer library, since it is the one OpenAI uses for its models.

I use the "gpt2" encoding which is a byte pair encoding (BPE) tokenizer.

In [18]:
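tokenizer_base = tiktoken.get_encoding("gpt2")

# Extend the GPT-2 encoding with chat-style special tokens.
# Note: the gpt2 base vocabulary occupies IDs 0..50256, so
#   - IDs 100264/100265 leave a large gap: the embedding table needs 100266 rows;
#   - ID 0 for <|pad|> is shared with an ordinary token ("!"), so padded positions
#     must be masked out (see attention_mask and labels in TextDataset).
tokenizer = tiktoken.Encoding(
    name="rob-tokenizer",
    pat_str=tokenizer_base._pat_str,
    mergeable_ranks=tokenizer_base._mergeable_ranks,
    special_tokens={
        **tokenizer_base._special_tokens,
        "<|im_start|>": 100264,
        "<|im_end|>": 100265,
        "<|pad|>": 0,
    }
)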
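
The choice of special-token IDs determines how many rows the model's embedding table needs and whether a special token shares its ID with an ordinary one. The sketch below checks both with plain Python. It assumes the usual GPT-2 layout of 50,257 tokens with IDs 0 to 50256; `check_special_tokens` and the `compact` layout are hypothetical, written for illustration, and are not part of tiktoken or of this notebook.

```python
# Assumption: the "gpt2" encoding occupies IDs 0..50256 (50,257 tokens),
# with <|endoftext|> at 50256.
GPT2_VOCAB_SIZE = 50257

def check_special_tokens(special_tokens, base_vocab_size=GPT2_VOCAB_SIZE):
    """Hypothetical helper: report ID collisions and the embedding rows needed."""
    # Special tokens whose ID falls inside the range used by ordinary tokens.
    collisions = sorted(
        name for name, idx in special_tokens.items() if idx < base_vocab_size
    )
    # The embedding table needs one row per ID up to the largest one in use.
    required_rows = max(max(special_tokens.values()) + 1, base_vocab_size)
    return collisions, required_rows

# IDs as defined in the cell above (<|endoftext|> belongs to the base encoding).
current = {"<|im_start|>": 100264, "<|im_end|>": 100265, "<|pad|>": 0}
# A hypothetical alternative that appends the new tokens after the base vocabulary.
compact = {"<|im_start|>": 50257, "<|im_end|>": 50258, "<|pad|>": 50259}

current_report = check_special_tokens(current)
compact_report = check_special_tokens(compact)
print(current_report)
print(compact_report)
```

With the IDs used in this notebook the check reports a collision for `<|pad|>` and 100,266 embedding rows; the compact layout would need 50,260 rows. Switching layouts would also mean updating the IDs hard-coded in `TextDataset`.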

#### Test of the byte pair encoding tokenizer 

In [13]:
# test of tokenizer on reddit_instruct
sample_text = reddit_instruct[0]['post_title'] + " " + reddit_instruct[0]['post_text'] + " " + reddit_instruct[0]['comment_text']
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens)) 
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")   

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 

### Formatting datasets functions

In [None]:
def format_plain_text(batch):
    # Handle both 'text' and 'story' 
    keys = batch.keys()
    if 'text' in keys:
        sources = batch['text']
    elif 'story' in keys:
        sources = batch['story']
    else:
        # Fallback if neither exists 
        return {'text': [''] * len(list(batch.values())[0])}
        
    return {'text': [f"<|im_start|>\n{t}\n<|im_end|>" for t in sources]}

def format_qa(batch):
    return {'text': [f"<|im_start|>\nQuestion: {q}\nAnswer: {a}\n<|im_end|>" 
                     for q, a in zip(batch['question'], batch['answer'])]}

def format_reddit(batch):
    # The post title and body together form the question; the comment is the answer.
    return {'text': [f"<|im_start|>\nQuestion: {pt}\n{ptext}\nAnswer: {ct}\n<|im_end|>" 
                     for pt, ptext, ct in zip(batch['post_title'], batch['post_text'], batch['comment_text'])]}

def format_alpaca(batch):
    return {'text': [f"<|im_start|>\nQuestion: {inst}\nAnswer: {out}\n<|im_end|>" 
                     for inst, out in zip(batch['instruction'], batch['output'])]}
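
To make the templates concrete, here is a minimal sketch of what a batched formatting function returns. With `batched=True`, `map` passes each batch as a mapping from column name to a list of values, so a plain dict is enough to try the logic. The two Q&A rows are invented for illustration, and `format_qa_demo` simply mirrors `format_qa`.

```python
def format_qa_demo(batch):
    """Mirror of format_qa: wrap each question/answer pair in chat markers."""
    return {
        "text": [
            f"<|im_start|>\nQuestion: {q}\nAnswer: {a}\n<|im_end|>"
            for q, a in zip(batch["question"], batch["answer"])
        ]
    }

# A hand-made batch in the dict-of-lists layout used with batched=True.
toy_batch = {
    "question": ["What is 2 + 2?", "What colour is the sky?"],
    "answer": ["4", "Blue"],
}

formatted = format_qa_demo(toy_batch)
for sample in formatted["text"]:
    # repr() makes the newlines around the markers visible.
    print(repr(sample))
```

Each input row becomes one string in the single `text` column, which is why every formatted dataset can later be concatenated with the others.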

In [15]:
print("Formatting Wiki...")
wiki_formatted = wiki_en.map(format_plain_text, remove_columns=wiki_en.column_names, batched=True)

print("Formatting Stories...")
stories_formatted = stories.map(format_plain_text, remove_columns=stories.column_names, batched=True)

print("Formatting FineWeb...")
fineweb_formatted = fineweb_edu.map(format_plain_text, remove_columns=fineweb_edu.column_names, batched=True)

print("Formatting QA...")
qa_formatted = q_a1.map(format_qa, remove_columns=q_a1.column_names, batched=True)

print("Formatting Reddit...")
reddit_formatted = reddit_instruct.map(format_reddit, remove_columns=reddit_instruct.column_names, batched=True)

print("Formatting Alpaca...")
alpaca_formatted = alpaca.map(format_alpaca, remove_columns=alpaca.column_names, batched=True)

Formatting Wiki...


Map: 100%|██████████| 6407814/6407814 [01:23<00:00, 76374.26 examples/s] 


Formatting Stories...


Map: 100%|██████████| 2115696/2115696 [00:09<00:00, 212877.96 examples/s]


Formatting FineWeb...


Map: 100%|██████████| 9672101/9672101 [03:17<00:00, 48906.32 examples/s]


Formatting QA...


Map: 100%|██████████| 120959/120959 [00:00<00:00, 461960.44 examples/s]


Formatting Reddit...


Map: 100%|██████████| 84784/84784 [00:00<00:00, 197938.66 examples/s]


Formatting Alpaca...


Map: 100%|██████████| 52002/52002 [00:00<00:00, 319196.68 examples/s]


#### Merging datasets

In [16]:
from datasets import concatenate_datasets
combined_hf_dataset = concatenate_datasets([
    wiki_formatted, 
    stories_formatted, 
    fineweb_formatted, 
    qa_formatted, 
    reddit_formatted, 
    alpaca_formatted
])

# Shuffle the combined dataset
combined_hf_dataset = combined_hf_dataset.shuffle(seed=42)

print(f"Combined dataset size: {len(combined_hf_dataset)}")

Combined dataset size: 18453356
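
As a quick sanity check on the mixture, the share of each source can be computed from the entry counts printed by the loading cells. This is only a sketch: it measures shares by number of entries, not by tokens, and the dictionary keys are labels chosen here for readability.

```python
# Entry counts copied from the outputs of the loading cells.
counts = {
    "wikipedia": 6_407_814,
    "simple_stories": 2_115_696,
    "fineweb_edu": 9_672_101,
    "qa": 120_959,
    "reddit_instruct": 84_784,
    "alpaca": 52_002,
}

total = sum(counts.values())
shares = {name: n / total for name, n in counts.items()}

# Print the sources from largest to smallest share.
for name, share in sorted(shares.items(), key=lambda kv: -kv[1]):
    print(f"{name:16s} {share:7.2%}")

# Combined share of the three question-answer style datasets.
qa_share = shares["qa"] + shares["reddit_instruct"] + shares["alpaca"]
print(f"Q&A-style total  {qa_share:7.2%}")
```

The counts add up to the combined size of 18,453,356, and the three Q&A-style datasets together account for roughly 1.4% of the entries. Shares by token count will differ because document lengths vary a lot between sources.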


#### Custom Dataset class

In [25]:
class TextDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_length=512):
        """
        Args:
            hf_dataset: The Hugging Face dataset object with a 'text' column (e.g. combined_hf_dataset)
            tokenizer: The tokenizer to process text
            max_length: Context window size
        """
        self.data = hf_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.pad_token_id = 0       # <|pad|>
        self.eos_token_id = 100265  # <|im_end|>

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Raw text
        text = self.data[idx]['text']
        
        # Tokenization
        tokens = self.tokenizer.encode(text, allowed_special="all")

        # Truncation
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
            tokens[-1] = self.eos_token_id  # ensure the last token is EOS

        input_ids = torch.tensor(tokens, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        # Padding
        padding_length = self.max_length - input_ids.size(0)
        if padding_length > 0:
            # Create padding tensors
            pad_ids = torch.full((padding_length,), self.pad_token_id, dtype=torch.long)
            pad_mask = torch.zeros((padding_length,), dtype=torch.long)  # 0 = ignore

            # Concatenate
            input_ids = torch.cat([input_ids, pad_ids])
            attention_mask = torch.cat([attention_mask, pad_mask])

        labels = input_ids.clone()

        # Set padding positions to the ignore index so they do not contribute to the loss
        if padding_length > 0:
            # The padding is always at the end
            labels[-padding_length:] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
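
The truncation, padding and label-masking steps of `__getitem__` can be traced on a short sequence. The sketch below re-implements them with plain Python lists instead of tensors, using a toy `max_length` of 6 and made-up token IDs. `pad_and_mask` is a hypothetical helper written for illustration; the value -100 matches the default `ignore_index` of `torch.nn.CrossEntropyLoss`.

```python
PAD_ID = 0           # <|pad|>, as in TextDataset
EOS_ID = 100265      # <|im_end|>, as in TextDataset
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def pad_and_mask(tokens, max_length):
    """Hypothetical list-based version of the TextDataset.__getitem__ logic."""
    tokens = list(tokens)
    # Truncate long sequences and force the last kept token to be EOS.
    if len(tokens) > max_length:
        tokens = tokens[:max_length]
        tokens[-1] = EOS_ID
    n_real = len(tokens)
    n_pad = max_length - n_real
    input_ids = tokens + [PAD_ID] * n_pad
    attention_mask = [1] * n_real + [0] * n_pad
    # Padding positions are excluded from the loss via the ignore index.
    labels = tokens + [IGNORE_INDEX] * n_pad
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# A sequence shorter than max_length gets padded ...
short = pad_and_mask([11, 12, 13, EOS_ID], max_length=6)
# ... and a longer one gets truncated, with EOS written into the last slot.
long = pad_and_mask([11, 12, 13, 14, 15, 16, 17, EOS_ID], max_length=6)
print(short)
print(long)
```

Note that a truncated sample loses its final real token, which is overwritten by the EOS marker, so every sample ends with `<|im_end|>` whether or not the original text fit in the context window.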

## Sources 

### Principal references: 
- https://arxiv.org/abs/2005.14165 (GPT-3 paper)
- https://arxiv.org/abs/1706.03762 (Attention Is All You Need paper)
- Build a Large Language Model (from scratch) by Sebastian Raschka

### About Padding tokens in Language Modeling
- https://arxiv.org/html/2510.01238v1 