# Projet GPT - Train

This notebook contains the code used to train a small language model using PyTorch from scratch. The model is inspired by the GPT architecture.


#### Hardware
- RTX3060 12GB VRAM
- AMD Ryzen 7 5800X 8-Core
- 32GB RAM
- Ubuntu 22.04 LTS

In [9]:
import os
CACHE_DIR = "/media/rob/RobsDisk/cache_data_llm"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, "datasets")
os.environ['HF_METRICS_CACHE'] = os.path.join(CACHE_DIR, "metrics")
os.environ['HF_MODULES_CACHE'] = os.path.join(CACHE_DIR, "modules")


In [10]:
import torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import random, math

from datasets import load_dataset,concatenate_datasets
import tiktoken
from tqdm import tqdm
from datetime import datetime

## Datasets

### Common knowledge datasets:

##### English Wikipedia crawled dataset

In [11]:
# English Wikipedia crawled dataset
# path to store the dataset cache: /Volumes/RobertsDisk
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR) 
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'])


English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
is a Japanese anime television series that adapted several Agatha Christie stories about Hercule Poirot and Miss Marple. A new character named Maybelle West, Miss Marple's great-niece, who becomes Poirot's junior assistant, is used to connect the two detectives.

The series was broadcast from 4 July 2004 to 15 May 2005 on NHK, and continues to be shown in re-runs on NHK and other networks in Japan. The series was adapted as manga under the same title, which was released in 2004 and 2005.

Adaptation 

The TV series is a generally faithful adaptation of the original stories given the time constraints (typically one 25-minute episode for a short story, four episodes for a novel). Despite being a modern Japanese adaptation, the original (mainly English) locations and time period are retained. The most obvious story change is 

#### Simple stories dataset

In [12]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'])

Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
A radiant star shone brightly in the night sky over a sleepy town. Each night, people would look up and make wishes. One day, a jealous star saw the shining star and decided to cover it with dark clouds. The people became sad, unable to see their favorite star. They gathered and spoke of their loss. 

Then, a brave child shouted, "We will not let one cloud take our joy!" The townsfolk began to sing, their voices rising high. The clouds slowly parted, revealing the star. It sparkled brighter than ever. The jealous star learned that love and hope can break even the darkest clouds. The town celebrated their star, and their wishes flowed freely once again.


##### FineWeb-Edu dataset

In [13]:
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT",  split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu is ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text']) 


FineWeb-Edu is ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
June 4, 2013 — The chemical secrets of a concrete Roman breakwater that has spent the last 2,000 years submerged in the Mediterranean Sea have been uncovered by an international team of researchers led by Paulo Monteiro of the U.S. Department of Energy's Lawrence Berkeley National Laboratory (Berkeley Lab), a professor of civil and environmental engineering at the University of California, Berkeley.
Analysis of samples provided by team member Marie Jackson pinpointed why the best Roman concrete was superior to most modern concrete in durability, why its manufacture was less environmentally damaging -- and how these improvements could be adopted in the modern world.
"It's not that modern concrete isn't good -- it's so good we use 19 billion tons of it a year," says Monteiro. "The problem is that manufacturing Portland cement accounts f

In [14]:
# OpenWebText2 dataset
owt2 = load_dataset("Skylion007/openwebtext", split="train", cache_dir=CACHE_DIR)
print("OpenWebText2 dataset loaded.")
print("Dataset size in GB:", owt2.dataset_size / (1024**3))
print("Number of entries:", len(owt2))
print("-"*50)
print("Example entry:")
print(owt2[random.randint(0, len(owt2)-1)]['text'])

OpenWebText2 dataset loaded.
Dataset size in GB: 37.03822539001703
Number of entries: 8013769
--------------------------------------------------
Example entry:
THE skipper and navigation officer of the cargo ship Rena, which smashed into a reef, causing New Zealand's worst maritime environmental disaster, have both been jailed for seven months.

Captain Mauro Balomaga, 44, and navigation officer Leonil Relon, 37, both Filipinos, were today sentenced in Tauranga District Court, in the North Island, on a raft of charges laid after the 236-metre, 37,000-tonne cargo ship struck Astrolabe Reef off Tauranga in the early hours of October 5 last year.

It spilled about 360 tonnes of heavy fuel oil into the sea, which washed up on local beaches and killed wildlife.

Containers were washed overboard and clean-up crews are still picking up debris from the wreck, which broke apart in January after being pounded by heavy seas.

The disaster, which sparked a massive anti-pollution response, ruined t

#### Some Q&A data to improve the model's ability to answer questions:

In [15]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'])

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
How can organizations benchmark themselves against the practices of CESAI? 
 Organizations can benchmark themselves against the practices of CESAI by analyzing their marketing and business strategies, their word of mouth advertising strategy, and their customer-focused culture. This allows them to identify areas for improvement and alignment with the guidelines and philosophies of CESAI.


In [16]:
#euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500]), "\n", reddit_instruct[index]['comment_text'][:500]

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
[40K] So, Horrus stabs the Emperor, the Emperor kills him.. What are the events immediately proceeding this? He's standing there in a battlefield bleeding and what not I guess, Horrus is dead, so his men just sorta... stop fighting? So who suddenly took charge and got him back to the throne (which I assume was nowhere near the battlefield, being at the very bottom of a palace the size of Europe...) Who started setting up contingency plans in the case that at that moment he *died*.




(None,
 '\n',
 "They're on board Horus's battleship, the Emperor has just wiped Horus's soul from existence and is lying, nearly dead, in the midst of horrible carnage. The bodies of Ollanius Pius and Sanguinus lie near him, and the Emperor weeps for instant of sanity he saw in his son's eyes before he killed him.\n\nRogal Dorn enters, having split away from the Emperor shortly after they teleported aboard the battleship. The Emperor, realizing that he is dying, and knowing that he needs to be alive to guide the ")

In [17]:
# tatsu-lab/alpaca ( for Q&A fine-tuning )
alpaca = load_dataset("tatsu-lab/alpaca", split='train')
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
Summarize an article titled "Advantages of Owning a Pet" 
 Owning a pet can bring a multitude of benefits to an individual, both physically and mentally. From physical health benefits such as reduced cholesterol levels to mental wellbeing improvements such as increased levels of happiness, the advantages of owning a pet are significant.


## Data Preprocessing

#### Tokenizer setup

For this project i use tiktoken for the tokenizer, as it is the same tokenizer used by OpenAI for their models.

I use the "gpt2" encoding which is a byte pair encoding (BPE) tokenizer.

In [18]:
tokenizer_base = tiktoken.get_encoding("gpt2")

tokenizer = tiktoken.Encoding(
    name="rob-tokenizer",
    pat_str=tokenizer_base._pat_str,
    mergeable_ranks=tokenizer_base._mergeable_ranks,
    special_tokens={
        **tokenizer_base._special_tokens,
        "<|im_start|>": 50257,
        "<|im_end|>": 50258,
        "<|pad|>": 50259,
    }
)

#### Test of the byte pair encoding tokenizer 

In [19]:
# test of tokenizer on reddit_instruct
sample_text = reddit_instruct[0]['post_title'] + " " + reddit_instruct[0]['post_text'] + " " + reddit_instruct[0]['comment_text']
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens)) 
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")   

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 

### Formatting datasets functions

#### Merging datasets

In [20]:
from datasets import concatenate_datasets
combined_train_dataset = concatenate_datasets([  
    wiki_en,
    stories,
    fineweb_edu,
    owt2,  
])  

combined_finetune_dataset = concatenate_datasets([
    q_a1,
    reddit_instruct,
    alpaca,
])

# Shuffle the combined dataset
train_dataset = combined_train_dataset.shuffle(seed=42)
finetune_dataset = combined_finetune_dataset.shuffle(seed=42)
print(f"Train dataset size: {len(combined_train_dataset)}")
print(f"Finetune dataset size: {len(combined_finetune_dataset)}")

# Exemple 

print("Example entry from train dataset:")
index = random.randint(0, len(train_dataset)-1)
print(train_dataset[index])   

Train dataset size: 26209380
Finetune dataset size: 257745
Example entry from train dataset:
{'id': None, 'url': None, 'title': None, 'text': 'Something about garbage trucks are simply beastly. Well, maybe a few things — their size, the stench, and the way they totally crush that garbage. So, electric garbage trucks bring some nice beastly power to the electric vehicle community. And I hope it explains why I got so excited when I saw news about such garbage trucks (and keep hearing their crushing power in my head).\n\nWord on the street is that Motiv Power Systems will soon provide the first U.S. all-electric Class 8 refuse truck. And it has just announced some of its vendor partners. “The garbage truck body will be a Loadmaster 20 cubic yard Excel-S series, provided by RNOW Inc under municipal contract with the City of Chicago. The chassis, manufactured by Crane Carrier, will be furnished by Cumberland Service Center another City of Chicago contracted dealer,” a press release regardin

#### Custom Dataset class

Inspired by the dataloader from the "LLMs from scratch" repository. But adapted for multi-row text arrays.

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/dataloader.ipynb

In [21]:
class GPTDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        """
        Args:
            dataset: Dataset of the combined hugginface entries
            tokenizer: the initiatokenizer to process text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.input_tokens = []
        self.target_tokens = []

        self.pad_token_id = 50259         # <|pad|>
        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Raw text
        
        # Data format handling
        entry = self.data[idx]
        if 'text' in entry:
            text = entry['text']
        elif 'story' in entry:
            text = entry['story']
        elif 'question' in entry and 'answer' in entry: 
            text = "User: " + entry['question'] + " Assistant:" + entry['answer']
        elif 'post_title' in entry and 'post_text' in entry and 'comment_text' in entry:
            text = "User: " + entry['post_title'] + " Assistant:" + entry['post_text'] + " " + entry['comment_text']
        elif 'instruction' in entry and 'output' in entry:
            text = "User: " + entry['instruction'] + " Assistant:" + entry['output']
        else:
            raise ValueError("Unknown data entry format")
        
        text = str(text) # Ensure text is a string
        #print(text)

        # Adding Start and End tokens
        text = "<|im_start|>" + text + "<|im_end|>" 

        # Tokenization
        tokens = self.tokenizer.encode(text, allowed_special="all")

        # Truncation
        tokens = tokens[:self.max_length] #Data is loost here ( fix later with sliding window )

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)  # All tokens except last
        labels = torch.tensor(tokens[1:], dtype=torch.long)      # All tokens except first


        #Padding 
        padding_length = self.max_length - len(tokens)
        if padding_length > 0:
            input_ids = torch.cat([input_ids, torch.full((padding_length,), self.pad_token_id)])
            labels = torch.cat([labels, torch.full((padding_length,), -100)])


        attention_mask = (input_ids != self.pad_token_id).long() # 1 for real tokens, 0 for padding

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

## GPT Model 

### GPT config 

This is the configuration for the GPT model i am going to train. It is a smaller version of the GPT-2 model. 

- Context length: 512 tokens
- Embedding dimension: 512
- Number of attention heads: 8
- Number of layers: 8

In [22]:
GPT_CONFIG = {
    "vocab_size": 50260,
    "context_length": 512, # max i could fit on my gpu
    "emb_dim": 512,
    "number_heads": 8,
    "number_layers": 8,
    "drop_rate": 0.1,
}

##### Test of a entry from dataloader

In [23]:
# Empty cuda cache and memory management
torch.cuda.empty_cache()

train_dataset = GPTDataset(combined_train_dataset, tokenizer, max_length=GPT_CONFIG["context_length"])
print(f"Train dataset size: {len(train_dataset)}")

batch_size = 12
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, prefetch_factor=2, persistent_workers=True)

Train dataset size: 26209380


In [None]:
print("##### Test of a entry from dataloader")
batch = next(iter(train_dataloader))
print(batch)    

##### Test of a entry from dataloader
{'input_ids': tensor([[50257,  2484,   789,  ...,   340,    13, 24006],
        [50257, 14202, 50259,  ..., 50259, 50259, 50259],
        [50257,    12, 36829,  ..., 50259, 50259, 50259],
        ...,
        [50257,  2061,   717,  ...,   428,   198,  5235],
        [50257, 14202, 50259,  ..., 50259, 50259, 50259],
        [50257, 13603,   272,  ..., 42116,   422,   262]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[ 2484,   789, 35375,  ...,    13, 24006,   329],
        [14202, 50258,  -100,  ...,  -100,  -100,  -100],
        [   12, 36829,  6763,  ...,  -100,  -100,  -100],
        ...,
        [ 2061,   717,  3181,  ...,   198,  5235, 16877],
        [14202, 50258,  -100,  ...,  -100,  -100,  -100],
        [13603,   272, 10322,  ...,

#### Pytroch model implementation

For this first implementation, i am using the transformer and embedding modules from PyTorch. Later i will try to implement the attention mechanism from scratch for better understanding.

https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [24]:
class GPTModel(nn.Module):
    """
    Gpt model class using transformer library
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True # stabilityy 
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, input_ids, attention_mask, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (batch_size, seq_length, emb_dim)

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)

        # voiding to pay attatention padding tokens
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        
        x = self.transformer(embeddings, mask=causal_mask, src_key_padding_mask=key_padding_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None




### Model instantiation 

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(GPT_CONFIG).to(device)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) # Cosine annealing learning rate scheduler
scaler = torch.cuda.amp.GradScaler()  # Gradient scaler for mixed precision



GPTModel(
  (embedding): Embedding(50260, 512)
  (positional_encoding): Embedding(512, 512)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_layer): Linear(in_features=512, out_features=50260, bias=False)
)


  scaler = torch.cuda.amp.GradScaler()  # Gradient scaler for mixed precision


### Number of parameters calculation

In [26]:
# Calculate number of parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)    

count_params = count_parameters(model)  
print(f"Number of trainable parameters: {count_params}")

Number of trainable parameters: 76947456


## Training setup

In [19]:
def inference(model, tokenizer, prompt, max_length=256, device='cpu'):
    model.eval()
    input_ids = tokenizer.encode(prompt, allowed_special="all")
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_length)

    generated_ids = input_ids
    max_length = max_length - input_ids.shape[1]  # Remaining length for generation
    with torch.no_grad():
        for _ in range(max_length):
            attention_mask = torch.ones_like(generated_ids)  # All tokens are real (no padding)
            logits, _ = model(generated_ids, attention_mask)
            next_token_logits = logits[:, -1, :]  # (1, vocab_size)
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # (1, 1)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)  # Append to sequence

            if next_token_id.item() == 50258:  # Stop if eot token is generated
                break

    generated_text = tokenizer.decode(generated_ids.squeeze().tolist())
    return generated_text


In [None]:
accumulation_steps= 1  # Number of steps to accumulate gradients

def train_loop(model, dataloader, optimizer, scheduler, device, num_epochs=3, accumulation_steps = 4, question_interval=500, saving_interval=5000):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        step_count = 0

        for i, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):  # bf16 optimized for ampere archi
                logits, loss = model(input_ids, attention_mask, labels)
                loss = loss / accumulation_steps
            scaler.scale(loss).backward()
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            epoch_loss += loss.item()*accumulation_steps
            progress_bar.set_postfix(loss=loss.item() * accumulation_steps)

            # Inference check
            step_count += 1
            if step_count % question_interval == 0:
                model.eval()
                with torch.no_grad():
                    prompt = "You are an AI being trained. How are you doing?"
                    generated_text = inference(model, tokenizer, prompt, max_length=50, device=device)
                    print(f"\n[Inference at step {step_count}]")
                    #save generated text to file train_ouput.txt
                    with open("outputs/train_output.txt", "a") as f:
                        if generated_text.strip() != "":
                            f.write(f"\n[Inference at step {step_count}]: {generated_text}\n")  
                            f.write("-"*50 + "\n")
                        else:
                            f.write(f"\n[Inference at step {step_count}]: [No output generated]\n")
                            f.write("-"*50 + "\n")
                model.train()

            if step_count % saving_interval == 0:
                # Save model checkpoint
                date_hour = datetime.now().strftime("%Y%m%d_%H%M%S")
                # Create models directory if it doesn't exist
                Path(f"models/{date_hour}").mkdir(parents=True, exist_ok=True)
                checkpoint_path = f"models/{date_hour}/weights_step{step_count}.pt"
                torch.save(model.state_dict(), checkpoint_path)
                print(f"\nModel checkpoint saved at step {step_count} to {checkpoint_path}")

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")


### First training on combined train dataset

In [21]:
#empty gpu memory 
torch.cuda.empty_cache()
torch.cuda.synchronize()
 

train_loop(model, train_dataloader, optimizer, scheduler, device, num_epochs=1)

Epoch 1/1:   0%|          | 501/2184115 [01:45<146:10:01,  4.15it/s, loss=7.26]


[Inference at step 500]


Epoch 1/1:   0%|          | 1000/2184115 [03:25<138:05:52,  4.39it/s, loss=6.77]


[Inference at step 1000]


Epoch 1/1:   0%|          | 1501/2184115 [05:08<136:23:39,  4.45it/s, loss=6.76]


[Inference at step 1500]


Epoch 1/1:   0%|          | 2000/2184115 [06:52<144:55:01,  4.18it/s, loss=6.53]


[Inference at step 2000]


Epoch 1/1:   0%|          | 2500/2184115 [08:36<144:40:57,  4.19it/s, loss=6.65]


[Inference at step 2500]


Epoch 1/1:   0%|          | 3000/2184115 [10:20<143:54:17,  4.21it/s, loss=6.37]


[Inference at step 3000]


Epoch 1/1:   0%|          | 3501/2184115 [12:03<127:32:01,  4.75it/s, loss=6.43]


[Inference at step 3500]


Epoch 1/1:   0%|          | 4001/2184115 [13:39<126:52:30,  4.77it/s, loss=6.39]


[Inference at step 4000]


Epoch 1/1:   0%|          | 4501/2184115 [15:14<127:12:28,  4.76it/s, loss=6.19]


[Inference at step 4500]


Epoch 1/1:   0%|          | 4999/2184115 [16:49<122:17:21,  4.95it/s, loss=6.27]



[Inference at step 5000]


RuntimeError: Parent directory models/20260202_181015 does not exist.

In [None]:
torch.cuda.empty_cache()
torch.cuda.synchronize()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## Sources 

### Principal references: 
- https://arxiv.org/abs/2005.14165 (GPT-3 paper)
- https://arxiv.org/abs/2002.05709 (Attention is all you need paper)
- Build a Large Language Model (from scratch) by Sebastian Raschka

### About Padding tokens in Language Modeling
- https://arxiv.org/html/2510.01238v1 