# Project GPT - Train

This notebook contains the code used to train a small language model from scratch with PyTorch. The model is inspired by the GPT architecture.


#### Hardware
- RTX3060 12GB VRAM
- AMD Ryzen 7 5800X 8-Core
- 32GB RAM
- Ubuntu 22.04 LTS

In [1]:
import os
CACHE_DIR = "/media/rob/RobsDisk/cache_data_llm"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, "datasets")
os.environ['HF_METRICS_CACHE'] = os.path.join(CACHE_DIR, "metrics")
os.environ['HF_MODULES_CACHE'] = os.path.join(CACHE_DIR, "modules")


In [2]:
import torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import random, math

from datasets import load_dataset,concatenate_datasets
import tiktoken
from tqdm import tqdm



## Datasets

### Common knowledge datasets:

#### English Wikipedia crawled dataset

In [3]:
# English Wikipedia crawled dataset
# The dataset cache is stored under CACHE_DIR
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR) 
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'])




English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
T-ara's Best of Best 2009-2012: Korean ver. (stylized as T-ARA's Best of Best 2009-2012 ～KOREAN ver.～) is the first greatest hits album by South Korean girl group T-ara. It was released on October 10, 2012 by EMI Music Japan to commemorate the one-year anniversary of the group's Japanese debut. The album contains all of T-ara's singles from Absolute First Album (2009) up to Funky Town (2012), including their 2010 FIFA World Cup digital single "We Are the One". A limited ultra-deluxe edition of the album, which includes a 72-page photobook and 120-minute documentary of T-ara's "Free Time in Europe" trip, was released on October 17, 2012.

Track listing

Charts

Oricon

Release history

References

2012 greatest hits albums
T-ara albums
EMI Records compilation albums


#### Simple stories dataset

In [4]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'])

Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
Beneath the old oak tree, a father sat with his daughter, looking at an ancient map. They had heard tales of a treasure hidden in the woods, and today, they were on a hunt. The daughter, with wide eyes and a big smile, wanted to find it so badly. The father, however, felt a strange weight in his heart. Would they truly find treasure, or would it lead to something else?

They followed the path, which twisted and turned like a snake. Birds chirped above them, and the wind whispered through the leaves. Each step brought them closer to the X on the map, but the father was lost in thought. He remembered his own father, who had taken him on a similar adventure. They had not found gold then, but laughter and stories.

Suddenly, they came across a shiny object buried in the dirt. The daughter rushed over, her heart racing. "Is this it

#### FineWeb-Edu dataset

In [5]:
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT",  split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu is ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text']) 


FineWeb-Edu is ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
Art Therapy is the act of using art as a therapuetic tool to explore and express thoughts and feelings, resolve issues and heal past traumas. You do not need to be an artist or have artistic "talent" to benefit from Art Therapy. Art Therapy can be used in conjunction with other forms of therapy and is often used to enhance a subject being explored through other forms of therapy. Some common artistic materials used in Art Therapy are colored pencils, markers, paints, fabric, pastels, collage, and more. Clients of all ages benefit from using Art Therapy.
Eye Movement Desensitization and Reprocessing is a therapy used to heal the negative beliefs that we have about ourselves. EMDR is a powerful therapy often used to treat trauma, but it is equally effective for those who have not been through traumatic events. EMDR uses eye movements (or

In [6]:
# OpenWebText dataset (Skylion007/openwebtext is the original OpenWebText corpus, not OpenWebText2)
owt2 = load_dataset("Skylion007/openwebtext", split="train", cache_dir=CACHE_DIR)
print("OpenWebText dataset loaded.")
print("Dataset size in GB:", owt2.dataset_size / (1024**3))
print("Number of entries:", len(owt2))
print("-"*50)
print("Example entry:")
print(owt2[random.randint(0, len(owt2)-1)]['text'])

OpenWebText dataset loaded.
Dataset size in GB: 37.03822539001703
Number of entries: 8013769
--------------------------------------------------
Example entry:
About This Project

This envision project’s goal is to show that the Warehouse District, in spite of its many challenges, can still become a vibrant, 24/7 urban neighborhood and destination through connectivity, increased sense of community and a variety of mixed uses. It uses its location close to sports and entertainment venues not as the sole defining element, but as one of many key pieces that can be leveraged to form a strong community and sense of place as a strong, urban neighborhood that serves at the southern front door to downtown. By creating this strong neighborhood, the foundation for a place that invites visitation will be formed.

The main focus is area is Jackson – Grant, 7th Street – 4th Avenue.

Download the full Warehouse District Envision Project PDF here

PART I: CONNECTIVITY

Currently, The Warehouse Distri

#### Some Q&A data to improve the model's ability to answer questions:

In [7]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'])

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
Why did the overall volume of purchases in Q3 2012 remain weaker compared to pre-2008 levels? 
 The overall volume of purchases in Q3 2012 remained weaker compared to pre-2008 levels because even with the positive growth in the latest two quarters, it still lags behind the pre-2008 levels, indicating a lingering impact of the 2008 GDP contraction.


In [8]:
#euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
# All three fields go inside the print call (a misplaced parenthesis would turn the cell result into a tuple)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500], "\n", reddit_instruct[index]['comment_text'][:500])

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
[HALO] My younger brother wanted to know why UNSC Marines don't just take the armor off of Spartan bodies when they die? I was wondering this myself, and a few years ago I think I saw a video that showed what would happen if a Marine wore Mjolnir armor, and it ended up killing him. If anyone has that clip, I'd greatly appreciate that too! (Unless it was just a fever dream) 
 There are lots of reasons:

- Spartans are much taller than average, over 7 ft, so it'd be unlikely to fit.

- Mjolnir connects directly to Spartans brains through neural implants, which low ranking soldiers wouldn't normally have.

- Spartans' bodies are stronger and faster then a normal human's, which allows them to use the suit without being crushed. (Like that video your thinking of.)

- And finally, Mjolnir actually take a whole team of people to put on and take off (or a specialized robot 

In [9]:
# tatsu-lab/alpaca ( for Q&A fine-tuning )
alpaca = load_dataset("tatsu-lab/alpaca", split='train')
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
Describe three tips for investing in a low-risk asset class. 
 1. Consider your investment horizon – longer-term investments tend to be less risky than short-term ones.
2. Understand and diversify your portfolio – diversify across asset classes such as stocks, bonds, and cash to lower your risk.
3. Utilize asset allocation – allocating your portfolio across different asset classes helps better manage risk and optimize returns over different time frames.


## Data Preprocessing

#### Tokenizer setup

For this project I use tiktoken, the tokenizer library OpenAI uses for its models.

I use the "gpt2" encoding, which is a byte pair encoding (BPE) tokenizer, and extend it with three special tokens.

In [10]:
tokenizer_base = tiktoken.get_encoding("gpt2")

tokenizer = tiktoken.Encoding(
    name="rob-tokenizer",
    pat_str=tokenizer_base._pat_str,
    mergeable_ranks=tokenizer_base._mergeable_ranks,
    special_tokens={
        **tokenizer_base._special_tokens,
        "<|im_start|>": 50257,
        "<|im_end|>": 50258,
        "<|pad|>": 50259,
    }
)
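
The three extra special tokens are appended directly after the base vocabulary, which is why the model config further down uses a vocabulary size of 50260. A minimal sketch of that arithmetic, assuming the base `gpt2` encoding occupies ids 0 to 50256 (with `<|endoftext|>` as the last id). The names `BASE_VOCAB_SIZE`, `EXTRA_SPECIAL_TOKENS` and `required_vocab_size` are illustrative and not part of tiktoken:

```python
# Assumption: the base "gpt2" encoding uses ids 0..50256 (50257 entries,
# with <|endoftext|> as the last one).
BASE_VOCAB_SIZE = 50257

# Extra special tokens registered in the tokenizer cell above.
EXTRA_SPECIAL_TOKENS = {
    "<|im_start|>": 50257,
    "<|im_end|>": 50258,
    "<|pad|>": 50259,
}


def required_vocab_size(base_size, extra_tokens):
    """Smallest embedding table that can index every token id."""
    ids = sorted(extra_tokens.values())
    # The new ids must continue the base range without gaps or collisions.
    if ids != list(range(base_size, base_size + len(ids))):
        raise ValueError("special token ids must directly follow the base vocabulary")
    return base_size + len(ids)


vocab_size = required_vocab_size(BASE_VOCAB_SIZE, EXTRA_SPECIAL_TOKENS)
print(vocab_size)
```

If a special token id were ever larger than the embedding table, the embedding lookup on the GPU would fail with an out-of-range index, so it is worth keeping this check next to the config.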

#### Test of the byte pair encoding tokenizer 

In [11]:
# test of tokenizer on reddit_instruct
sample_text = reddit_instruct[0]['post_title'] + " " + reddit_instruct[0]['post_text'] + " " + reddit_instruct[0]['comment_text']
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens)) 
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")   

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 

### Formatting datasets functions

#### Merging datasets

In [12]:
from datasets import concatenate_datasets
combined_train_dataset = concatenate_datasets([  
    wiki_en,
    stories,
    fineweb_edu,
    owt2,  
])  

combined_finetune_dataset = concatenate_datasets([
    q_a1,
    reddit_instruct,
    alpaca,
])

# Shuffle the combined dataset
train_dataset = combined_train_dataset.shuffle(seed=42)
finetune_dataset = combined_finetune_dataset.shuffle(seed=42)
print(f"Train dataset size: {len(combined_train_dataset)}")
print(f"Finetune dataset size: {len(combined_finetune_dataset)}")

# Example

print("Example entry from train dataset:")
index = random.randint(0, len(train_dataset)-1)
print(train_dataset[index])   

Train dataset size: 26209380
Finetune dataset size: 257745
Example entry from train dataset:
{'id': '11451537', 'url': 'https://en.wikipedia.org/wiki/Galley%20Down%20Wood', 'title': 'Galley Down Wood', 'text': "Galley Down Wood is a  biological Site of Special Scientific Interest north-east of Bishop's Waltham in Hampshire.\n\nThis wood, which was planted with beech trees in around 1930, has a well developed beech flora. Flowering plants include bird's-nest orchid, white helleborine, greater butterfly-orchid, common spotted orchid and the nationally rare long-leaved helleborine.\n\nReferences\n\nSites of Special Scientific Interest in Hampshire\nBishop's Waltham", 'story': None, 'topic': None, 'theme': None, 'style': None, 'feature': None, 'grammar': None, 'persona': None, 'initial_word_type': None, 'initial_letter': None, 'word_count': None, 'character_count': None, 'num_paragraphs': None, 'avg_word_length': None, 'avg_sentence_length': None, 'flesch_reading_ease': None, 'flesch_kinca
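
As the example row shows, concatenating datasets with different schemas leaves every column in every row, filled with `None` where the source dataset had no such column. Any code that picks the text field therefore has to test the value, not only whether the key exists. A minimal sketch with hypothetical rows (the helper name `pick_text` and the sample strings are made up for illustration):

```python
# Hypothetical rows shaped like the merged-dataset output above: every column
# is present, but only the columns of the source dataset hold a value.
row_from_stories = {"text": None, "story": "Once upon a time...", "topic": None}
row_from_wiki = {"text": "Galley Down Wood is ...", "story": None, "topic": None}


def pick_text(entry, candidates=("text", "story")):
    """Return the first candidate column whose value is not None."""
    for key in candidates:
        value = entry.get(key)
        if value is not None:
            return value
    raise ValueError("no text-like column holds a value")


# A plain membership test is True for both rows, so it cannot tell them apart.
print("text" in row_from_stories)
print(pick_text(row_from_stories))
print(pick_text(row_from_wiki))
```

With a key-presence check, a row from the stories dataset would yield `None` for `text`, and `str(None)` would silently become the training text `"None"`.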

#### Custom Dataset class

Inspired by the dataloader from the "LLMs from scratch" repository, but adapted for multi-row text arrays.

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/dataloader.ipynb

In [13]:
class GPTDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        """
        Args:
            dataset: Dataset of the combined Hugging Face entries
            tokenizer: the tokenizer used to process text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.input_tokens = []
        self.target_tokens = []

        self.pad_token_id = 50259         # <|pad|>
        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Raw text
        
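        # Data format handling
        # After concatenate_datasets every row carries every column (missing ones are None),
        # so test for a non-None value instead of key presence.
        entry = self.data[idx]
        if entry.get('text') is not None:
            text = entry['text']
        elif entry.get('story') is not None:
            text = entry['story']
        elif entry.get('question') is not None and entry.get('answer') is not None:
            text = "User: " + entry['question'] + " Assistant:" + entry['answer']
        elif entry.get('post_title') is not None and entry.get('comment_text') is not None:
            # The post body belongs to the user turn; the comment is the reply
            post_text = entry.get('post_text') or ""
            text = "User: " + entry['post_title'] + " " + post_text + " Assistant:" + entry['comment_text']
        elif entry.get('instruction') is not None and entry.get('output') is not None:
            text = "User: " + entry['instruction'] + " Assistant:" + entry['output']
        else:
            raise ValueError("Unknown data entry format")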
        
        text = str(text) # Ensure text is a string
        #print(text)

        # Adding Start and End tokens
        text = "<|im_start|>" + text + "<|im_end|>" 

        # Tokenization
        tokens = self.tokenizer.encode(text, allowed_special="all")

        # Truncation
        tokens = tokens[:self.max_length]  # Data is lost here (fix later with a sliding window)

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)  # All tokens except last
        labels = torch.tensor(tokens[1:], dtype=torch.long)      # All tokens except first


        #Padding 
        padding_length = self.max_length - len(tokens)
        if padding_length > 0:
            input_ids = torch.cat([input_ids, torch.full((padding_length,), self.pad_token_id)])
            labels = torch.cat([labels, torch.full((padding_length,), -100)])


        attention_mask = (input_ids != self.pad_token_id).long() # 1 for real tokens, 0 for padding

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
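
The truncation step above drops everything past `max_length` tokens, and its comment mentions a sliding window as a later fix. A minimal sketch of what that could look like on plain Python lists, assuming the same pad id (50259) and ignore index (-100) as the class above. The function name `sliding_windows` and the `stride` parameter are hypothetical and not part of the notebook:

```python
PAD_ID = 50259       # <|pad|> id used in this notebook
IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss


def sliding_windows(tokens, max_length, stride):
    """Split a token list into (input_ids, labels) pairs of length max_length."""
    pairs = []
    for start in range(0, max(len(tokens) - 1, 0), stride):
        # Take one extra token so the labels can be shifted by one position.
        chunk = tokens[start:start + max_length + 1]
        input_ids = chunk[:-1]
        labels = chunk[1:]
        # Right-pad the last, shorter window.
        pad = max_length - len(input_ids)
        input_ids = input_ids + [PAD_ID] * pad
        labels = labels + [IGNORE_INDEX] * pad
        pairs.append((input_ids, labels))
        if start + max_length + 1 >= len(tokens):
            break  # this window already reached the last token
    return pairs


pairs = sliding_windows(list(range(10)), max_length=4, stride=4)
for input_ids, labels in pairs:
    print(input_ids, labels)
```

With `stride == max_length` every token is used as a target exactly once; a smaller stride produces overlapping windows at the cost of more samples. Turning this into a `Dataset` would also require an index that maps a sample number to a (document, window) pair, which is left out here.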

## GPT Model 

### GPT config 

This is the configuration of the GPT model I am going to train. It is a smaller version of the GPT-2 model.

- Context length: 512 tokens
- Embedding dimension: 512
- Number of attention heads: 8
- Number of layers: 8

In [14]:
GPT_CONFIG = {
    "vocab_size": 50260,
    "context_length": 512, # max i could fit on my gpu
    "emb_dim": 512,
    "number_heads": 8,
    "number_layers": 8,
    "drop_rate": 0.1,
}
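
To get a feeling for the size of this configuration, the parameter count can be estimated by hand. The sketch below assumes the usual layout of a PyTorch transformer encoder layer (fused query/key/value projection with bias, output projection, two feed-forward linears with a 4x hidden size, two LayerNorms), learned positional embeddings, no final norm, and an untied output head without bias. The function name `estimate_parameters` is illustrative:

```python
config = {
    "vocab_size": 50260,
    "context_length": 512,
    "emb_dim": 512,
    "number_layers": 8,
}


def estimate_parameters(cfg):
    """Rough parameter count under the layout assumptions stated above."""
    d = cfg["emb_dim"]
    ff = 4 * d
    token_emb = cfg["vocab_size"] * d
    pos_emb = cfg["context_length"] * d
    attention = (3 * d * d + 3 * d) + (d * d + d)   # fused QKV + output projection
    feed_forward = (d * ff + ff) + (ff * d + d)     # two linear layers with bias
    norms = 2 * (2 * d)                             # two LayerNorms (weight + bias)
    per_layer = attention + feed_forward + norms
    out_head = d * cfg["vocab_size"]                # no bias
    total = token_emb + pos_emb + cfg["number_layers"] * per_layer + out_head
    return {
        "token_emb": token_emb,
        "pos_emb": pos_emb,
        "per_layer": per_layer,
        "out_head": out_head,
        "total": total,
    }


counts = estimate_parameters(config)
for name, value in counts.items():
    print(f"{name:>10}: {value:,}")
```

Under these assumptions the total is about 77 million parameters, of which roughly two thirds sit in the token embedding and the output head. Summing `p.numel()` over `model.parameters()` is the way to confirm the figure on the real model.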

##### Test of an entry from the dataloader

In [15]:
# Empty cuda cache and memory management
torch.cuda.empty_cache()

train_dataset = GPTDataset(combined_train_dataset, tokenizer, max_length=GPT_CONFIG["context_length"])
print(f"Train dataset size: {len(train_dataset)}")

batch_size = 12
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, prefetch_factor=2, persistent_workers=True)

Train dataset size: 26209380
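
The dataset size and batch size fix the number of steps per epoch, which gives a quick estimate of how long one epoch takes. A small sketch of that arithmetic; the 4.77 it/s rate is read off the progress bar of the training run later in this notebook, and the helper name `hours_per_epoch` is illustrative:

```python
import math

num_samples = 26_209_380   # len(train_dataset) printed above
batch_size = 12
positions_per_sample = 511  # inputs are one token shorter than context_length=512

steps_per_epoch = math.ceil(num_samples / batch_size)
positions_per_step = batch_size * positions_per_sample  # includes padding positions


def hours_per_epoch(steps, iterations_per_second):
    """Wall-clock estimate for one epoch at a constant iteration rate."""
    return steps / iterations_per_second / 3600


estimate = hours_per_epoch(steps_per_epoch, iterations_per_second=4.77)
print("steps per epoch:", steps_per_epoch)
print("positions per step:", positions_per_step)
print("hours per epoch:", round(estimate, 1))
```

At that rate a single epoch takes more than five days on this hardware, so training on a subset or capping the number of steps is a reasonable option.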


In [16]:
print("##### Test of a entry from dataloader")
batch = next(iter(train_dataloader))
print(batch)    

##### Test of a entry from dataloader
{'input_ids': tensor([[50257, 13414,  1229,  ...,  1119,   423,   587],
        [50257, 20191, 26113,  ..., 50259, 50259, 50259],
        [50257,   464, 16431,  ..., 50259, 50259, 50259],
        ...,
        [50257, 22697,   311,  ..., 50259, 50259, 50259],
        [50257,    32, 11287,  ..., 50259, 50259, 50259],
        [50257,    40,  1842,  ..., 50259, 50259, 50259]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[13414,  1229, 40900,  ...,   423,   587,  5000],
        [20191, 26113,  3811,  ...,  -100,  -100,  -100],
        [  464, 16431, 26004,  ...,  -100,  -100,  -100],
        ...,
        [22697,   311,    13,  ...,  -100,  -100,  -100],
        [   32, 11287,  9002,  ...,  -100,  -100,  -100],
        [   40,  1842, 17252,  ...,

#### PyTorch model implementation

For this first implementation, I am using the transformer and embedding modules from PyTorch. Later I will try to implement the attention mechanism from scratch for better understanding.

https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [17]:
class GPTModel(nn.Module):
    """
    GPT model class built on PyTorch's transformer modules
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True # pre-norm for training stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, input_ids, attention_mask, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (1, seq_length, emb_dim), broadcast over the batch

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)

        # Avoid attending to padding tokens
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        
        x = self.transformer(embeddings, mask=causal_mask, src_key_padding_mask=key_padding_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None
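
The causal mask built in `forward` holds `-inf` above the diagonal and `0` elsewhere. Added to the attention scores before the softmax, it forces the weight of every future position to zero. A dependency-free sketch of that effect on a tiny example (the helper names `causal_mask` and `masked_softmax` are illustrative, and the scores are made up):

```python
import math


def causal_mask(seq_length):
    """-inf above the diagonal (future positions), 0.0 on and below it."""
    return [
        [float("-inf") if j > i else 0.0 for j in range(seq_length)]
        for i in range(seq_length)
    ]


def masked_softmax(scores, mask_row):
    """Softmax over one row of attention scores after adding the mask."""
    shifted = [s + m for s, m in zip(scores, mask_row)]
    peak = max(shifted)  # finite, because the diagonal is never masked
    exps = [math.exp(v - peak) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(3)
scores = [1.0, 2.0, 3.0]  # made-up attention scores for one query
weights = [masked_softmax(scores, row) for row in mask]
for row in weights:
    print([round(w, 3) for w in row])
```

The first query can only attend to itself, while the last one spreads its weight over all three positions. The padding mask works the same way along the key axis.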




### Model instantiation 

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(GPT_CONFIG).to(device)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3) # AdamW optimizer
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) # Cosine annealing learning rate scheduler
scaler = torch.amp.GradScaler('cuda')  # Gradient scaler for mixed precision (replaces the deprecated torch.cuda.amp.GradScaler)



GPTModel(
  (embedding): Embedding(50260, 512)
  (positional_encoding): Embedding(512, 512)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_layer): Linear(in_features=512, out_features=50260, bias=False)
)
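
The scheduler above uses `T_max=10`, and the training loop steps it once per optimizer step. Assuming the scheduler follows the closed-form cosine annealing formula from the PyTorch documentation, also beyond `T_max`, the learning rate falls to zero after 10 steps and climbs back to its peak after 20, repeating for the whole run. A sketch of that behaviour in plain Python (the function name and the `total_steps` value are illustrative):

```python
import math


def cosine_annealing_lr(step, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing value after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


base_lr = 3e-3

# T_max=10 stepped every optimizer step: a 20-step cycle.
short_cycle = [cosine_annealing_lr(s, base_lr, t_max=10) for s in range(0, 21, 5)]
print(short_cycle)

# Hypothetical alternative: one single decay over the planned number of optimizer steps.
total_steps = 100_000
single_decay = [cosine_annealing_lr(s, base_lr, t_max=total_steps)
                for s in (0, total_steps // 2, total_steps)]
print(single_decay)
```

Setting `T_max` to the planned number of optimizer steps gives one smooth decay. A peak of 3e-3 is also on the high side for a transformer trained with AdamW, so a lower peak with a short warmup may be worth trying if the loss diverges.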




## Training setup

In [19]:
def inference(model, tokenizer, prompt, max_length=256, device='cpu'):
    model.eval()
    input_ids = tokenizer.encode(prompt, allowed_special="all")
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_length)

    generated_ids = input_ids
    max_length = max_length - input_ids.shape[1]  # Remaining length for generation
    with torch.no_grad():
        for _ in range(max_length):
            attention_mask = torch.ones_like(generated_ids)  # All tokens are real (no padding)
            logits, _ = model(generated_ids, attention_mask)
            next_token_logits = logits[:, -1, :]  # (1, vocab_size)
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # (1, 1)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)  # Append to sequence

            if next_token_id.item() == 50258:  # Stop once <|im_end|> is generated
                break

    generated_text = tokenizer.decode(generated_ids[0].tolist())  # [0] keeps a list even for a single token
    return generated_text
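
The `inference` function feeds the whole growing sequence back into the model. That is fine for short generations, but the positional embedding only has `context_length` rows, so longer sequences would need to be cropped to the most recent tokens. A minimal sketch of greedy decoding with such a crop, using a toy stand-in for the model (`greedy_generate` and `toy_next_token` are hypothetical names):

```python
def greedy_generate(next_token_fn, prompt_ids, max_new_tokens, context_length, stop_id):
    """Greedy decoding that never feeds more than context_length tokens."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        window = ids[-context_length:]  # keep only the most recent tokens
        next_id = next_token_fn(window)
        ids.append(next_id)
        if next_id == stop_id:
            break
    return ids


def toy_next_token(window):
    """Stand-in for the model: always predicts the last id plus one."""
    assert len(window) <= 2  # matches context_length used below
    return window[-1] + 1


generated = greedy_generate(toy_next_token, [1, 2, 3], max_new_tokens=10,
                            context_length=2, stop_id=6)
print(generated)
```

In the real function the same crop would be a slice on the last dimension of `generated_ids` before the forward pass.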


In [20]:
accumulation_steps = 1  # Unused below: train_loop is called without this argument and falls back to its default of 4

def train_loop(model, dataloader, optimizer, scheduler, device, num_epochs=3, accumulation_steps = 4, question_interval=500):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        step_count = 0

        for i, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

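            # No zero_grad() here: clearing gradients on every iteration would discard what
            # earlier micro-batches accumulated. They are cleared after each optimizer step below.
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):  # bf16 autocast, well supported on Ampere GPUs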
                logits, loss = model(input_ids, attention_mask, labels)
                loss = loss / accumulation_steps
            scaler.scale(loss).backward()
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            epoch_loss += loss.item()*accumulation_steps
            progress_bar.set_postfix(loss=loss.item() * accumulation_steps)

            # Inference check
            step_count += 1
            if step_count % question_interval == 0:
                model.eval()
                with torch.no_grad():
                    prompt = "You are an AI being trained. How are you doing?"
                    generated_text = inference(model, tokenizer, prompt, max_length=50, device=device)
                    print(f"\n[Inference at step {step_count}]")
                    # Save generated text to the file train_output.txt
                    with open("train_output.txt", "a") as f:
                        if generated_text.strip() != "":
                            f.write(f"\n[Inference at step {step_count}]: {generated_text}\n")  
                            f.write("-"*50 + "\n")
                        else:
                            f.write(f"\n[Inference at step {step_count}]: [No output generated]\n")
                            f.write("-"*50 + "\n")
                model.train()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
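
Gradient accumulation only works if gradients are left in place between micro-batches and cleared once per optimizer step. Dividing each micro-batch loss by `accumulation_steps` then makes the accumulated gradient equal to the gradient of one large batch. A toy check with a one-parameter linear model (the data and the helper `grad_mse` are made up for illustration):

```python
def grad_mse(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 4.5), (3.0, 5.5), (4.0, 8.5)]
w = 0.5
accumulation_steps = 2
micro_batches = [data[0:2], data[2:4]]

# Correct pattern: keep summing, clear only after the optimizer step.
accumulated = 0.0
for micro in micro_batches:
    accumulated += grad_mse(w, micro) / accumulation_steps

# Faulty pattern: clearing before every micro-batch keeps only the last one.
last_only = grad_mse(w, micro_batches[-1]) / accumulation_steps

full_batch = grad_mse(w, data)
print(accumulated, full_batch, last_only)
```

The accumulated value matches the full-batch gradient, while the reset variant keeps only a scaled fraction of the last micro-batch.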


### First training on combined train dataset

In [21]:
#empty gpu memory 
torch.cuda.empty_cache()
torch.cuda.synchronize()
 

train_loop(model, train_dataloader, optimizer, scheduler, device, num_epochs=3)

Epoch 1/3:   0%|          | 501/2184115 [01:45<150:17:45,  4.04it/s, loss=10]  


[Inference at step 500]


Epoch 1/3:   0%|          | 1001/2184115 [03:23<131:18:23,  4.62it/s, loss=7.17]


[Inference at step 1000]


Epoch 1/3:   0%|          | 1501/2184115 [05:04<132:51:42,  4.56it/s, loss=6.89]


[Inference at step 1500]


Epoch 1/3:   0%|          | 2001/2184115 [06:44<132:15:37,  4.58it/s, loss=6.56]


[Inference at step 2000]


Epoch 1/3:   0%|          | 2501/2184115 [08:25<133:02:42,  4.55it/s, loss=6.68]


[Inference at step 2500]


Epoch 1/3:   0%|          | 3001/2184115 [10:04<135:25:05,  4.47it/s, loss=6.54]


[Inference at step 3000]


Epoch 1/3:   0%|          | 3501/2184115 [11:48<134:11:55,  4.51it/s, loss=6.36]


[Inference at step 3500]


Epoch 1/3:   0%|          | 4000/2184115 [13:30<143:06:32,  4.23it/s, loss=6.31]


[Inference at step 4000]


Epoch 1/3:   0%|          | 4501/2184115 [15:13<135:27:44,  4.47it/s, loss=6.2] 


[Inference at step 4500]


Epoch 1/3:   0%|          | 5001/2184115 [16:55<134:39:50,  4.49it/s, loss=6.37]


[Inference at step 5000]


Epoch 1/3:   0%|          | 5501/2184115 [18:34<126:46:08,  4.77it/s, loss=6.06]


[Inference at step 5500]


Epoch 1/3:   0%|          | 6001/2184115 [20:09<127:09:01,  4.76it/s, loss=5.98]


[Inference at step 6000]


Epoch 1/3:   0%|          | 6501/2184115 [21:45<126:50:17,  4.77it/s, loss=6.03]


[Inference at step 6500]


Epoch 1/3:   0%|          | 7001/2184115 [23:20<127:30:53,  4.74it/s, loss=5.89]


[Inference at step 7000]


Epoch 1/3:   0%|          | 7501/2184115 [24:55<126:53:08,  4.77it/s, loss=6.06]


[Inference at step 7500]


Epoch 1/3:   0%|          | 8001/2184115 [26:31<126:57:38,  4.76it/s, loss=5.98]


[Inference at step 8000]


Epoch 1/3:   0%|          | 8501/2184115 [28:06<126:43:52,  4.77it/s, loss=5.6] 


[Inference at step 8500]


Epoch 1/3:   0%|          | 9001/2184115 [29:41<126:52:54,  4.76it/s, loss=5.87]


[Inference at step 9000]


Epoch 1/3:   0%|          | 9501/2184115 [31:17<126:38:08,  4.77it/s, loss=5.9] 


[Inference at step 9500]


Epoch 1/3:   0%|          | 10001/2184115 [32:52<126:41:35,  4.77it/s, loss=5.88]


[Inference at step 10000]


Epoch 1/3:   0%|          | 10501/2184115 [34:28<126:41:28,  4.77it/s, loss=5.94]


[Inference at step 10500]


Epoch 1/3:   1%|          | 11001/2184115 [36:03<126:29:25,  4.77it/s, loss=5.73]


[Inference at step 11000]


Epoch 1/3:   1%|          | 11501/2184115 [37:38<126:33:47,  4.77it/s, loss=5.82]


[Inference at step 11500]


Epoch 1/3:   1%|          | 12001/2184115 [39:14<126:27:02,  4.77it/s, loss=5.74]


[Inference at step 12000]


Epoch 1/3:   1%|          | 12501/2184115 [40:49<126:24:16,  4.77it/s, loss=5.43]


[Inference at step 12500]


Epoch 1/3:   1%|          | 13001/2184115 [42:24<126:30:21,  4.77it/s, loss=5.41]


[Inference at step 13000]


Epoch 1/3:   1%|          | 13501/2184115 [44:00<126:24:12,  4.77it/s, loss=5.73]


[Inference at step 13500]


Epoch 1/3:   1%|          | 14001/2184115 [45:38<126:36:59,  4.76it/s, loss=5.67]


[Inference at step 14000]


Epoch 1/3:   1%|          | 14501/2184115 [47:14<126:24:08,  4.77it/s, loss=5.62]


[Inference at step 14500]


Epoch 1/3:   1%|          | 15001/2184115 [48:49<126:24:53,  4.77it/s, loss=5.66]


[Inference at step 15000]


Epoch 1/3:   1%|          | 15501/2184115 [50:24<126:12:21,  4.77it/s, loss=5.66]


[Inference at step 15500]


Epoch 1/3:   1%|          | 16001/2184115 [52:00<126:14:13,  4.77it/s, loss=5.33]


[Inference at step 16000]


Epoch 1/3:   1%|          | 16501/2184115 [53:35<126:17:04,  4.77it/s, loss=5.41]


[Inference at step 16500]


Epoch 1/3:   1%|          | 17001/2184115 [55:11<126:15:14,  4.77it/s, loss=5.36]


[Inference at step 17000]


Epoch 1/3:   1%|          | 17501/2184115 [56:46<126:02:04,  4.78it/s, loss=5.49]


[Inference at step 17500]


Epoch 1/3:   1%|          | 18001/2184115 [58:23<129:01:10,  4.66it/s, loss=5.31]


[Inference at step 18000]


Epoch 1/3:   1%|          | 18500/2184115 [1:00:05<144:45:02,  4.16it/s, loss=5.46]


[Inference at step 18500]


Epoch 1/3:   1%|          | 19000/2184115 [1:01:50<144:32:21,  4.16it/s, loss=5.22]


[Inference at step 19000]


Epoch 1/3:   1%|          | 19500/2184115 [1:03:38<146:58:44,  4.09it/s, loss=5.18]


[Inference at step 19500]


Epoch 1/3:   1%|          | 20001/2184115 [1:05:23<129:29:07,  4.64it/s, loss=5.41]


[Inference at step 20000]


Epoch 1/3:   1%|          | 20500/2184115 [1:07:08<149:10:08,  4.03it/s, loss=5.44]


[Inference at step 20500]


Epoch 1/3:   1%|          | 21001/2184115 [1:08:51<127:03:21,  4.73it/s, loss=5.48]


[Inference at step 21000]


Epoch 1/3:   1%|          | 21501/2184115 [1:10:27<127:07:59,  4.73it/s, loss=5.3] 


[Inference at step 21500]


Epoch 1/3:   1%|          | 22001/2184115 [1:12:03<134:24:20,  4.47it/s, loss=5.43]


[Inference at step 22000]


Epoch 1/3:   1%|          | 22501/2184115 [1:13:45<131:01:20,  4.58it/s, loss=5.38]


[Inference at step 22500]


Epoch 1/3:   1%|          | 23000/2184115 [1:15:24<143:26:51,  4.18it/s, loss=5.32]


[Inference at step 23000]


Epoch 1/3:   1%|          | 23501/2184115 [1:17:07<133:57:31,  4.48it/s, loss=5.7] 


[Inference at step 23500]


Epoch 1/3:   1%|          | 24001/2184115 [1:18:48<129:56:40,  4.62it/s, loss=5.43]


[Inference at step 24000]


Epoch 1/3:   1%|          | 24501/2184115 [1:20:25<126:03:15,  4.76it/s, loss=5.3] 


[Inference at step 24500]


Epoch 1/3:   1%|          | 25001/2184115 [1:22:04<132:06:27,  4.54it/s, loss=5.19]


[Inference at step 25000]


Epoch 1/3:   1%|          | 25501/2184115 [1:23:43<126:17:02,  4.75it/s, loss=115]    


[Inference at step 25500]


Epoch 1/3:   1%|          | 26000/2184115 [1:25:25<148:15:32,  4.04it/s, loss=22.3]


[Inference at step 26000]


Epoch 1/3:   1%|          | 26501/2184115 [1:27:07<131:12:49,  4.57it/s, loss=9.73]


[Inference at step 26500]


Epoch 1/3:   1%|          | 27000/2184115 [1:28:51<144:36:48,  4.14it/s, loss=12]  


[Inference at step 27000]


Epoch 1/3:   1%|▏         | 27501/2184115 [1:30:36<128:10:06,  4.67it/s, loss=7.85]


[Inference at step 27500]


Epoch 1/3:   1%|▏         | 28000/2184115 [1:32:21<139:55:45,  4.28it/s, loss=7.82]


[Inference at step 28000]


Epoch 1/3:   1%|▏         | 28500/2184115 [1:34:04<143:12:45,  4.18it/s, loss=7.3] 


[Inference at step 28500]


Epoch 1/3:   1%|▏         | 28512/2184115 [1:34:07<118:35:44,  5.05it/s, loss=7.39]


KeyboardInterrupt: 
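
Two reference points help when reading the log above. A model that predicts a uniform distribution over the vocabulary has a cross-entropy of `ln(vocab_size)`, and the exponential of the loss (the perplexity) can be read as the effective number of tokens the model is choosing between. A short sketch using loss values taken from the progress bar:

```python
import math

vocab_size = 50260
uniform_loss = math.log(vocab_size)  # loss of a uniform guess over the vocabulary


def perplexity(loss):
    """Perplexity for a cross-entropy loss measured in nats."""
    return math.exp(loss)


print("uniform baseline:", round(uniform_loss, 2))
for step, loss in [(1001, 7.17), (25001, 5.19), (28512, 7.39)]:
    print(f"step {step}: loss={loss}, perplexity={perplexity(loss):.0f}")
```

The first logged loss of about 10 is close to the uniform baseline, as expected at initialization. The jump to 115 around step 25500 is far above that baseline, so the run diverged rather than merely plateaued, and the loss had not returned to its earlier level when the run was interrupted.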

In [None]:
torch.cuda.empty_cache()
torch.cuda.synchronize()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## Sources 

### Principal references: 
- https://arxiv.org/abs/2005.14165 (GPT-3 paper)
- https://arxiv.org/abs/1706.03762 (Attention Is All You Need paper)
- Build a Large Language Model (from scratch) by Sebastian Raschka

### About Padding tokens in Language Modeling
- https://arxiv.org/html/2510.01238v1 