# GPT Project - Training

This notebook contains the code used to train a small language model from scratch with PyTorch. The model is inspired by the GPT architecture.


#### Hardware
- NVIDIA GeForce RTX 3060, 12 GB VRAM
- AMD Ryzen 7 5800X 8-Core
- 32 GB RAM
- Ubuntu 22.04 LTS

In [12]:
import math
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset
import tiktoken

CACHE_DIR = "/Volumes/RobertsDisk/datasets"

## Datasets

### Common knowledge datasets

#### English Wikipedia dataset

In [3]:
# English Wikipedia dataset (2023-11-01 dump), cached under CACHE_DIR
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR)
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'][:500])


English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
Yusuf Mohamed Ismail "Bari-Bari" (, , born 15 July 1958 – died 27 March 2015) was a Somali diplomat and politician. He was the Ambassador of Somalia to the United Nations Human Rights Office in Geneva.

Personal life
Ismail was born in 1960 in Rome, Italy to an aristocratic Somali family. He belonged to the Majeerteen Harti Darod clan. His family originally hailed from Garowe, the administrative capital of the northeastern Puntland regional state of Somalia.

For his tertiary education, Ismail e


#### Simple stories dataset

In [4]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'][:500])

Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
A park bench sat under a big oak tree. Lena often came here to write. She loved the quiet and the sounds of nature. One afternoon, while writing her story, she noticed a shadow moving near the tree. Her curiosity sparked, and she looked up. A small bird was hopping on the ground, looking for food.

The bird was bright and cheerful. It danced around, pecking at the grass. Lena smiled, feeling inspired by its joy. She watched as it picked up a small crumb. In that moment, she thought of how simple


#### FineWeb-Edu dataset

In [5]:
# Note: streaming=True is not set, so this downloads and materializes the full
# sample-10BT split under CACHE_DIR; the status message below is only a label.
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT", split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu stream ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text'][:500])


Generating train split: 100%|██████████| 9672101/9672101 [20:58<00:00, 7687.81 examples/s]


FineWeb-Edu stream ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
You Cannot Not Have Conceptual Understanding
In education, we seem to take some delight in shoveling a confused mix of folksy connotations into sciencey-shelled words and phrases. Some of my colleagues would call the result edujargon, though I think that word allows us to feel too smug about our own obtuseness—as though the problem is that the field of education is so darned technical.
Anyway, I’ve been itching to pick on one such phrase lately, conceptual understanding, and I think I’ll start h
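
The three cells above repeat the same load-and-inspect pattern. As a minimal sketch, that pattern could be factored into one helper. The name `describe_dataset` and its parameters are hypothetical (they do not appear in the original notebook), and the helper only assumes the dataset supports `len()` and integer indexing, which Hugging Face `Dataset` objects and plain lists both do.

```python
import random


def describe_dataset(name, dataset, text_fields, size_bytes=None, max_chars=500, seed=None):
    """Return a short text summary: entry count, optional size, one random example.

    Hypothetical helper, not part of the original notebook.
    `dataset` must support len() and integer indexing, and each entry must be
    a mapping that contains every key listed in `text_fields`.
    """
    rng = random.Random(seed)  # local RNG so a seed gives a reproducible example
    entry = dataset[rng.randrange(len(dataset))]
    lines = [f"{name} loaded.", f"Number of entries: {len(dataset)}"]
    if size_bytes is not None:
        # 1024**3 bytes per GiB, the same divisor the cells above use
        lines.append(f"dataset size in GiB: {size_bytes / 1024**3:.2f}")
    lines.append("-" * 50)
    lines.append("Example entry:")
    for field in text_fields:
        lines.append(str(entry[field])[:max_chars])
    return "\n".join(lines)


# Tiny in-memory stand-in for a real dataset
toy = [{"text": "hello world"}, {"text": "second entry"}]
print(describe_dataset("Toy dataset", toy, ["text"], seed=0))
```

With the datasets loaded above, a call would look like `print(describe_dataset("Simple stories", stories, ["story"], stories.dataset_size))`.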


### Q&A datasets to improve the model's ability to answer questions

In [8]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'][:500])

Generating train split: 100%|██████████| 120959/120959 [00:01<00:00, 114568.59 examples/s]
Generating validation split: 100%|██████████| 30240/30240 [00:00<00:00, 116389.13 examples/s]

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
What are the primary goals of diagnostic radiology in healthcare? 
 Diagnostic radiology helps healthcare professionals visualize structures inside the body to diagnose the cause of symptoms, monitor treatment responses, and screen for different illnesses, such as cancer or heart disease.
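
Before tokenization, each question/answer pair has to become a single string. The notebook does not show how this is done, so the following is only a sketch under an assumed template: the `Question:` / `Answer:` labels and the function name `format_qa` are hypothetical choices, not the author's format.

```python
def format_qa(question, answer, max_chars=None):
    """Join a question/answer pair into one training string.

    The "Question:/Answer:" template is an assumption made for illustration;
    the original notebook does not specify a prompt format.
    """
    text = f"Question: {question.strip()}\nAnswer: {answer.strip()}"
    # Optional truncation, mirroring the [:500] slices used in the cells above
    return text if max_chars is None else text[:max_chars]


example = format_qa(
    "What are the primary goals of diagnostic radiology in healthcare? ",
    " Diagnostic radiology helps healthcare professionals visualize structures inside the body.",
)
print(example)
```

Keeping the template in one function means the same format can be reused for every Q&A source, whatever its column names are.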





In [None]:
# euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500], "\n", reddit_instruct[index]['comment_text'][:500])

Generating train split: 100%|██████████| 84784/84784 [00:02<00:00, 33580.29 examples/s]
Generating test split: 100%|██████████| 2000/2000 [00:00<00:00, 33807.86 examples/s]

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
ELI5: Why are North American Christians in particular so against evolution, while peers in Europe and elsewhere are not?  
 Because schools here teach biology and evolution like any other subject - that's how it works, deal with it.

Making debates about creationism and allowing people to choose if their kids are lied to, or taught the truth is counterproductive - it legitimizes ignorance.

Now if we could deal with homeopathy the same way.

In [9]:
# tatsu-lab/alpaca (for Q&A fine-tuning)
alpaca = load_dataset("tatsu-lab/alpaca", split='train', cache_dir=CACHE_DIR)
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Generating train split: 100%|██████████| 52002/52002 [00:00<00:00, 478765.59 examples/s]

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
Create a shopping list for someone who wants to make a cheesecake. 
 Shopping list for cheesecake:
- Cream cheese
- Graham crackers
- Heavy cream
- Unsalted butter 
- Brown sugar
- White sugar
- Salt 
- Vanilla extract 
- Eggs
- Sour cream





In [11]:
# GPT-2 byte pair encoding (BPE) tokenizer from tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

#### Testing the byte pair encoding (BPE) tokenizer
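
For intuition about what byte pair encoding does, here is a toy sketch of the merge procedure: start from raw UTF-8 bytes and repeatedly replace the most frequent adjacent pair with a new token id. This is not how `tiktoken` is implemented and it does not reproduce the GPT-2 vocabulary. The function names (`most_frequent_pair`, `merge_pair`, `toy_bpe`) are hypothetical and exist only for this illustration.

```python
from collections import Counter


def most_frequent_pair(ids):
    """Return the most common adjacent pair of ids, or None if there is no pair."""
    pairs = Counter(zip(ids, ids[1:]))
    return max(pairs, key=pairs.get) if pairs else None


def merge_pair(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` (left to right) with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def toy_bpe(text, num_merges):
    """Learn up to `num_merges` merges on `text`; return (token ids, merge table)."""
    ids = list(text.encode("utf-8"))  # base vocabulary: the 256 byte values
    merges = {}
    for step in range(num_merges):
        pair = most_frequent_pair(ids)
        if pair is None:
            break
        new_id = 256 + step  # new ids start right after the byte range
        ids = merge_pair(ids, pair, new_id)
        merges[pair] = new_id
    return ids, merges


ids, merges = toy_bpe("aaabdaaabac", num_merges=3)
print("bytes:", len("aaabdaaabac".encode("utf-8")), "tokens after merging:", len(ids))
```

Each merge shortens the sequence, which is why the token count reported in the test below is much smaller than the character count.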

In [15]:
# Test the tokenizer on the first reddit_instruct entry
entry = reddit_instruct[0]
sample_text = " ".join([entry['post_title'], entry['post_text'], entry['comment_text']])
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens))
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 