# Project GPT - Train

This notebook contains the code used to train a small language model using PyTorch from scratch. The model is inspired by the GPT architecture.


#### Hardware
- RTX3060 12GB VRAM
- AMD Ryzen 7 5800X 8-Core
- 32GB RAM
- Ubuntu 22.04 LTS

In [44]:
import os
CACHE_DIR = "/media/rob/RobsDisk/cache_data_llm"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, "datasets")
os.environ['HF_METRICS_CACHE'] = os.path.join(CACHE_DIR, "metrics")
os.environ['HF_MODULES_CACHE'] = os.path.join(CACHE_DIR, "modules")


In [45]:
import torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import random, math

from datasets import load_dataset,concatenate_datasets
import tiktoken
from tqdm import tqdm

## Datasets

### Common knowledge datasets:

##### English Wikipedia crawled dataset

In [46]:
# English Wikipedia crawled dataset
# The dataset cache is stored under CACHE_DIR (set in the first cell)
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR) 
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'])


English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
Asherton Independent School District was a public school district based in the community of Asherton, Texas (USA).

The community was along the Mexico–United States border, and was southwest of San Antonio.

The district existed until 1999, when it was ordered consolidated with Carrizo Springs schools by the Texas Education Agency (TEA) to form the Carrizo Springs Consolidated Independent School District.  The TEA determined that the Asherton ISD had not followed proper procedure in assessing school property taxes, the basis for its closure. TEA officials stated that "financial instability/insolvency" was the reason behind the district's dissolution.

The Carrizo Springs district moved all students in grades 7-12 into its existing junior high and high schools, but left the Asherton school open to serve grades PreK-6.

Dist

##### Simple stories dataset

In [47]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'])

Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
Quietly, the ship sailed through the vast emptiness of space. Samuel sat alone, gazing out at the endless dark. The stars were like diamonds, but they brought no comfort. He felt a chill in the air, a reminder of the warmth he had left behind. The hum of the engine filled the silence, but it felt like a lonely song. 

As he stared out, he caught the smell of old metal mixed with the faint scent of fuel. It made him think of his childhood, of the fresh smell of grass after rain. He longed for the sound of the wind in the trees, for the taste of fruit just picked. Everything felt so far away, hidden among the stars. 

One night, the ship drew near a glowing nebula. The colors swirled like a painting, vibrant and alive. Samuel's heart leapt. He could almost feel the warmth of the light, like a blanket wrapping around him. He want

##### FineWeb-Edu dataset

In [48]:
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT",  split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu is ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text']) 


FineWeb-Edu is ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
Over the past 10 years we have been witnessing a significant increase in a number of scientific studies which analyze day-to-day decision-making by using psychological tools and experimental designs. The purpose of this new approach, often called behavioral economics, is to enrich traditional economic conceptions of human decision-making employed by the currently dominant theory of rational choice. This new approach is providing descriptively more fitting accounts of human behavior by showing, for example, that when trying to follow our decisions we, as humans, are constrained by limited will-power.
Most of us want to go to the gym more often, stop eating junk food, quit smoking, reduce drinking, and live “better” in general. And so often we fail because our will power is limited, once we focus on one target, we fail even more in othe

In [49]:
# OpenWebText dataset (the Hub repo is Skylion007/openwebtext, not OpenWebText2; the owt2 name is kept)
owt2 = load_dataset("Skylion007/openwebtext", split="train", cache_dir=CACHE_DIR)
print("OpenWebText2 dataset loaded.")
print("Dataset size in GB:", owt2.dataset_size / (1024**3))
print("Number of entries:", len(owt2))
print("-"*50)
print("Example entry:")
print(owt2[random.randint(0, len(owt2)-1)]['text'])

OpenWebText2 dataset loaded.
Dataset size in GB: 37.03822539001703
Number of entries: 8013769
--------------------------------------------------
Example entry:
The view inside the Ultra High Vacuum Scanning Thermal Microscope, which was used to measure temperature fluxes at the nanoscale. Image credit: Joseph XuANN ARBOR—When heat travels between two objects that aren’t touching, it flows differently at the smallest scales—distances on the order of the diameter of DNA, or 1/50,000 of a human hair.

While researchers have been aware of this for decades, they haven’t understood the process. Heat flow often needs to be prevented or harnessed and the lack of an accurate way to predict it represents a bottleneck in nanotechnology development.

Now, in a unique ultra-low vibration lab at the University of Michigan, engineers have measured how heat radiates from one surface to another in a vacuum at distances down to 2 nanometers.

While the thermal energy still flows from the warmer place to

#### Some Q&A data to improve the model's ability to answer questions:

In [50]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'])

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
How did the emergence of the Entente Canada-Québec (ECQ) program contribute to the improvement of education systems in Quebec? 
 The ECQ program fostered collaboration and knowledge sharing between Francophone and Anglophone colleges in Quebec, focusing on minority and second language education. This program enabled both linguistic groups to explore innovative approaches to teaching and learning, thereby enhancing outcomes.


In [51]:
#euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500], "\n", reddit_instruct[index]['comment_text'][:500])

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
ELI5: What causes the color fluctuation in older (90’s) videos? (Link inside) So I was browsing YouTube and came across an older video and noticed the color fluctuating a lot. 

What causes this? [Link](https://youtu.be/VpLQlUa2G0s)

 
 Image is stored in VHS tapes not as a bunch of red green and blue pixels, but as analog luminance and chrominance. By now you must have read somewhere about how there are other ways to represent color. Chrominance would be Hue and Saturation, luminance would be the Value or Brightness in a HSV system.

Since everything in a VHS tape is analog, the VHS player needs analog circuitry to amplify the signals. Amplifiers such as that have a parameter called gain, which refers to by how much they incre

In [52]:
# tatsu-lab/alpaca ( for Q&A fine-tuning )
alpaca = load_dataset("tatsu-lab/alpaca", split='train')
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
Name at least 3 European countries. 
 Germany, France, and the United Kingdom.


## Data Preprocessing

#### Tokenizer setup

For this project I use tiktoken as the tokenizer library, since it is the tokenizer OpenAI uses for its models.

I use the "gpt2" encoding which is a byte pair encoding (BPE) tokenizer.

In [53]:
tokenizer_base = tiktoken.get_encoding("gpt2")

tokenizer = tiktoken.Encoding(
    name="rob-tokenizer",
    pat_str=tokenizer_base._pat_str,
    mergeable_ranks=tokenizer_base._mergeable_ranks,
    special_tokens={
        **tokenizer_base._special_tokens,
        "<|im_start|>": 50257,
        "<|im_end|>": 50258,
        "<|pad|>": 50259,
    }
)
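
The three added special tokens extend the GPT-2 vocabulary, so the model's embedding table has to cover their ids as well. The sketch below only does the arithmetic, assuming the base "gpt2" encoding occupies ids 0 to 50256; `BASE_VOCAB_SIZE` and `required_vocab_size` are illustrative names, not part of tiktoken.

```python
# Assumption: ids 0..50256 are taken by the base "gpt2" encoding.
BASE_VOCAB_SIZE = 50257

# Special tokens added in the cell above, with their ids.
added_special_tokens = {
    "<|im_start|>": 50257,
    "<|im_end|>": 50258,
    "<|pad|>": 50259,
}

def required_vocab_size(base_size, extra_tokens):
    """Smallest embedding-table size that covers every token id."""
    highest_id = max([base_size - 1, *extra_tokens.values()])
    return highest_id + 1

vocab_size = required_vocab_size(BASE_VOCAB_SIZE, added_special_tokens)
print(vocab_size)
```

The result is the value that `vocab_size` must take in the model configuration further down.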

#### Test of the byte pair encoding tokenizer 

In [54]:
# test of tokenizer on reddit_instruct
sample_text = reddit_instruct[0]['post_title'] + " " + reddit_instruct[0]['post_text'] + " " + reddit_instruct[0]['comment_text']
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens)) 
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")   

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 

### Formatting datasets functions

#### Merging datasets

In [55]:
from datasets import concatenate_datasets
combined_train_dataset = concatenate_datasets([  
    wiki_en,
    stories,
    fineweb_edu,
    owt2,  
])  

combined_finetune_dataset = concatenate_datasets([
    q_a1,
    reddit_instruct,
    alpaca,
])

# Shuffle the combined dataset
train_dataset = combined_train_dataset.shuffle(seed=42)
finetune_dataset = combined_finetune_dataset.shuffle(seed=42)
print(f"Train dataset size: {len(combined_train_dataset)}")
print(f"Finetune dataset size: {len(combined_finetune_dataset)}")

# Example

print("Example entry from train dataset:")
index = random.randint(0, len(train_dataset)-1)
print(train_dataset[index])   

Train dataset size: 26209380
Finetune dataset size: 257745
Example entry from train dataset:
{'id': '<urn:uuid:c543181b-bab4-40b9-b0ae-027eebc9705f>', 'url': 'https://codedocs.org/what-is/sqlite', 'title': None, 'text': '|Developer(s)||D. Richard Hipp|\n|Initial release||17 August 2000|\n|Stable release||3.35.5 (19 April 2021 )|\n.sqlite3, .sqlite, .db\n|Internet media type|\n|Open format?||yes (Public Domain)|\nSQLite (//,//) is a relational database management system (RDBMS) contained in a C library. In contrast to many other database management systems, SQLite is not a client–server database engine. Rather, it is embedded into the end program.\nSQLite generally follows PostgreSQL syntax. SQLite uses a dynamically and weakly typed SQL syntax that does not guarantee the domain integrity. This means that one can, for example, insert a string into a column defined as an integer. SQLite will attempt to convert data between formats where appropriate, the string "123" into an integer in th
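
To see how the pretraining mixture is balanced, the entry counts printed when loading each dataset can be turned into shares. This is a rough sketch: it counts entries, not tokens, so it only approximates what the model sees, and `mixture_shares` is an illustrative helper.

```python
# Entry counts copied from the dataset-loading outputs above.
pretrain_counts = {
    "wiki_en": 6_407_814,
    "stories": 2_115_696,
    "fineweb_edu": 9_672_101,
    "owt2": 8_013_769,
}

def mixture_shares(counts):
    """Return each dataset's fraction of the total number of entries."""
    total = sum(counts.values())
    return {name: n / total for name, n in counts.items()}

total_entries = sum(pretrain_counts.values())
shares = mixture_shares(pretrain_counts)
for name, share in shares.items():
    print(f"{name:12s} {share:6.1%}")
```

The total agrees with the train dataset size of 26209380 printed above.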

#### Custom Dataset class

Inspired by the dataloader from the "LLMs from scratch" repository, but adapted to datasets made of many separate text rows.

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/dataloader.ipynb

In [56]:
class GPTDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length):
        """
        Args:
            dataset: Combined Hugging Face dataset
            tokenizer: The tokenizer used to encode the text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.input_tokens = []
        self.target_tokens = []

        self.pad_token_id = 50259         # <|pad|>
        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Raw text
        
        # Data format handling
        entry = self.data[idx]
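        # Columns missing from a source dataset can be present as None after
        # concatenation, so check values rather than key membership.
        if entry.get('text') is not None:
            text = entry['text']
        elif entry.get('story') is not None:
            text = entry['story']
        elif entry.get('question') is not None and entry.get('answer') is not None:
            text = "User: " + entry['question'] + " Assistant:" + entry['answer']
        elif entry.get('post_title') is not None and entry.get('comment_text') is not None:
            # The post title and body form the user's turn; the comment is the reply.
            text = "User: " + entry['post_title'] + " " + (entry.get('post_text') or "") + " Assistant:" + entry['comment_text']
        elif entry.get('instruction') is not None and entry.get('output') is not None:
            text = "User: " + entry['instruction'] + " Assistant:" + entry['output']
        else:
            raise ValueError("Unknown data entry format")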
        
        text = str(text) # Ensure text is a string
        #print(text)

        # Adding Start and End tokens
        text = "<|im_start|>" + text + "<|im_end|>" 

        # Tokenization
        tokens = self.tokenizer.encode(text, allowed_special="all")

        # Truncation
        tokens = tokens[:self.max_length]  # Data past the context window is lost here (fix later with a sliding window)

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)  # All tokens except last
        labels = torch.tensor(tokens[1:], dtype=torch.long)      # All tokens except first


        #Padding 
        padding_length = self.max_length - len(tokens)
        if padding_length > 0:
            input_ids = torch.cat([input_ids, torch.full((padding_length,), self.pad_token_id)])
            labels = torch.cat([labels, torch.full((padding_length,), -100)])


        attention_mask = (input_ids != self.pad_token_id).long() # 1 for real tokens, 0 for padding

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
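
The truncation step in `__getitem__` discards every token past the context window. A minimal sketch of the sliding-window alternative mentioned in the comment, written on plain token lists; `sliding_window_chunks` and its `stride` argument are hypothetical and are not used anywhere in this notebook.

```python
def sliding_window_chunks(tokens, max_length, stride):
    """Split a token list into overlapping windows of at most max_length tokens.

    Consecutive windows start `stride` tokens apart, so they overlap by
    max_length - stride tokens when stride < max_length.
    """
    if max_length < 2 or stride < 1:
        raise ValueError("max_length must be >= 2 and stride >= 1")
    chunks = []
    start = 0
    while True:
        chunk = tokens[start:start + max_length]
        if len(chunk) >= 2:  # need at least one (input, target) pair
            chunks.append(chunk)
        if start + max_length >= len(tokens):
            break
        start += stride
    return chunks

example = list(range(10))
chunks = sliding_window_chunks(example, max_length=4, stride=3)
print(chunks)
```

With `stride = max_length - 1`, the last token of one window is the first token of the next, so each next-token pair appears in exactly one window.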

## GPT Model 

### GPT config 

This is the configuration for the GPT model I am going to train. It is a smaller version of the GPT-2 model.

- Context length: 512 tokens
- Embedding dimension: 512
- Number of attention heads: 8
- Number of layers: 8

In [57]:
GPT_CONFIG = {
    "vocab_size": 50260,
    "context_length": 512, # max i could fit on my gpu
    "emb_dim": 512,
    "number_heads": 8,
    "number_layers": 8,
    "drop_rate": 0.1,
}
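
A quick way to get a feel for this configuration is to estimate the parameter count by hand. The sketch below assumes the layer layout used by the model later in this notebook (learned positional embeddings, `nn.TransformerEncoderLayer` blocks with a feed-forward width of 4 × `emb_dim`, and an untied output layer without bias); `count_parameters` is an illustrative helper.

```python
def count_parameters(vocab_size, context_length, emb_dim, number_layers):
    """Estimate trainable parameters, assuming dim_feedforward = 4 * emb_dim."""
    token_emb = vocab_size * emb_dim
    pos_emb = context_length * emb_dim

    # One encoder layer with biases on every linear projection.
    attn_in = 3 * emb_dim * emb_dim + 3 * emb_dim   # fused Q, K, V projection
    attn_out = emb_dim * emb_dim + emb_dim
    ff = (emb_dim * 4 * emb_dim + 4 * emb_dim) + (4 * emb_dim * emb_dim + emb_dim)
    norms = 2 * (2 * emb_dim)                       # two LayerNorms (weight + bias)
    per_layer = attn_in + attn_out + ff + norms

    output_head = emb_dim * vocab_size              # no bias, not tied to token_emb
    return token_emb + pos_emb + number_layers * per_layer + output_head

total = count_parameters(vocab_size=50260, context_length=512, emb_dim=512, number_layers=8)
print(f"{total:,} parameters")
```

The number of attention heads does not change the count, and roughly two thirds of the parameters sit in the two vocabulary-sized matrices (token embedding and output layer).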

##### Test of an entry from dataloader

In [58]:
# Empty cuda cache and memory management
torch.cuda.empty_cache()

train_dataset = GPTDataset(combined_train_dataset, tokenizer, max_length=GPT_CONFIG["context_length"])
print(f"Train dataset size: {len(train_dataset)}")

batch_size = 12
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, prefetch_factor=2, persistent_workers=True)

Train dataset size: 26209380


In [59]:
print("##### Test of a entry from dataloader")
batch = next(iter(train_dataloader))
print(batch)    

##### Test of a entry from dataloader
{'input_ids': tensor([[50257, 18683, 11064,  ...,  2077,  1337,   286],
        [50257,  5956,  1285,  ...,   286,  2346,   960],
        [50257, 24857, 15428,  ..., 11004,  7401,   329],
        ...,
        [50257,   464,  1584,  ...,  1232,    13, 24690],
        [50257,  2061,   921,  ..., 50259, 50259, 50259],
        [50257,  7841,  9057,  ...,   198,   198,   464]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[18683, 11064,  1012,  ...,  1337,   286,   351],
        [ 5956,  1285,    11,  ...,  2346,   960,   392],
        [24857, 15428,  2626,  ...,  7401,   329,  6079],
        ...,
        [  464,  1584,  1906,  ...,    13, 24690,   422],
        [ 2061,   921, 10664,  ...,  -100,  -100,  -100],
        [ 7841,  9057,   308,  ...,
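
In the batch printed above, each `labels` row is the matching `input_ids` row shifted left by one token, and padded positions carry -100 so that the loss ignores them. A plain-Python sketch of that construction on a toy sequence (the token ids 11, 22, 33 are made up, and `make_example` is an illustrative helper that mirrors `GPTDataset.__getitem__` with lists instead of tensors):

```python
PAD_ID = 50259        # <|pad|>
IGNORE_INDEX = -100   # ignored by nn.CrossEntropyLoss(ignore_index=-100)

def make_example(tokens, max_length):
    """Build (input_ids, labels, attention_mask) lists of length max_length - 1."""
    tokens = tokens[:max_length]
    input_ids = tokens[:-1]          # all tokens except the last
    labels = tokens[1:]              # all tokens except the first
    pad = max_length - len(tokens)
    input_ids = input_ids + [PAD_ID] * pad
    labels = labels + [IGNORE_INDEX] * pad
    attention_mask = [int(t != PAD_ID) for t in input_ids]
    return input_ids, labels, attention_mask

input_ids, labels, attention_mask = make_example([50257, 11, 22, 33, 50258], max_length=8)
print(input_ids)
print(labels)
print(attention_mask)
```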

#### PyTorch model implementation

For this first implementation, I am using the transformer and embedding modules from PyTorch. Later I will try to implement the attention mechanism from scratch for better understanding.

https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [60]:
class GPTModel(nn.Module):
    """
    GPT-style model built on PyTorch's nn.TransformerEncoder with a causal mask
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True # pre-norm layout, for training stability
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, input_ids, attention_mask, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (1, seq_length, emb_dim), broadcast over the batch

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)

        # Avoid attending to padding tokens
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        
        x = self.transformer(embeddings, mask=causal_mask, src_key_padding_mask=key_padding_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None
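
The causal mask built with `torch.triu(..., diagonal=1)` holds `-inf` strictly above the diagonal and 0 elsewhere, so position i can only attend to positions 0 to i. A dependency-free sketch of the same pattern (`causal_mask` is an illustrative helper that returns nested lists instead of a tensor):

```python
import math

def causal_mask(seq_length):
    """Additive attention mask: -inf where column j > row i, else 0.0."""
    return [
        [-math.inf if j > i else 0.0 for j in range(seq_length)]
        for i in range(seq_length)
    ]

mask = causal_mask(4)
for row in mask:
    print(row)
```

Adding this mask to the attention scores before the softmax drives the weight of every future position to zero.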




### Model instantiation 

In [61]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(GPT_CONFIG).to(device)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3) # AdamW optimizer
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) # Cosine annealing learning rate scheduler
scaler = torch.amp.GradScaler('cuda')  # Gradient scaler for mixed precision (torch.cuda.amp.GradScaler is deprecated)

GPTModel(
  (embedding): Embedding(50260, 512)
  (positional_encoding): Embedding(512, 512)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_layer): Linear(in_features=512, out_features=50260, bias=False)
)
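
`CosineAnnealingLR` decays the learning rate over `T_max` scheduler steps, and the training loop in this notebook calls `scheduler.step()` once per optimizer step, so with `T_max=10` the rate reaches its minimum after only 10 updates. The sketch below shows how `T_max` could be derived, assuming the decay is meant to span the whole run. It uses the closed-form cosine formula rather than PyTorch itself, and `total_optimizer_steps` and `cosine_lr` are illustrative names.

```python
import math

def total_optimizer_steps(num_batches, accumulation_steps, num_epochs):
    """Number of optimizer updates when stepping every `accumulation_steps` batches."""
    return (num_batches // accumulation_steps) * num_epochs

def cosine_lr(step, t_max, lr_max, lr_min=0.0):
    """Closed-form cosine annealing from lr_max (step 0) to lr_min (step t_max)."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / t_max))

# Values taken from this notebook: 26,209,380 entries / batch size 12
# = 2,184,115 batches per epoch, 3 epochs, accumulation over 4 batches
# (the default of train_loop), peak learning rate 3e-3.
t_max = total_optimizer_steps(2_184_115, accumulation_steps=4, num_epochs=3)
print(t_max)
print(cosine_lr(0, t_max, 3e-3), cosine_lr(t_max // 2, t_max, 3e-3), cosine_lr(t_max, t_max, 3e-3))
```

The resulting value could then be passed as `T_max` when building the scheduler.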




## Training setup

In [62]:
def inference(model, tokenizer, prompt, max_length=256, device='cpu'):
    model.eval()
    input_ids = tokenizer.encode(prompt, allowed_special="all")
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_length)

    generated_ids = input_ids
    max_length = max_length - input_ids.shape[1]  # Remaining length for generation
    with torch.no_grad():
        for _ in range(max_length):
            attention_mask = torch.ones_like(generated_ids)  # All tokens are real (no padding)
            logits, _ = model(generated_ids, attention_mask)
            next_token_logits = logits[:, -1, :]  # (1, vocab_size)
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # (1, 1)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)  # Append to sequence

            if next_token_id.item() == 50258:  # Stop once <|im_end|> is generated
                break

    generated_text = tokenizer.decode(generated_ids.squeeze().tolist())
    return generated_text


In [63]:
accumulation_steps = 2  # Not passed to train_loop in the call below, so the function default of 4 applies

def train_loop(model, dataloader, optimizer, scheduler, device, num_epochs=3, accumulation_steps = 4, question_interval=500):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        step_count = 0

        for i, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

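            # Gradients are not cleared here: they accumulate across micro-batches
            # and are zeroed after each optimizer step below.
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):  # bf16 autocast, supported on Ampere GPUs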
                logits, loss = model(input_ids, attention_mask, labels)
                loss = loss / accumulation_steps
            scaler.scale(loss).backward()
            if (i + 1) % accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            epoch_loss += loss.item()*accumulation_steps
            progress_bar.set_postfix(loss=loss.item() * accumulation_steps)

            # Inference check
            step_count += 1
            if step_count % question_interval == 0:
                model.eval()
                with torch.no_grad():
                    prompt = "You are an AI being trained. How are you doing?"
                    generated_text = inference(model, tokenizer, prompt, max_length=50, device=device)
                    print(f"\n[Inference at step {step_count}]")
                    # Save the generated text to train_output.txt
                    with open("train_output.txt", "a") as f:
                        if generated_text.strip() != "":
                            f.write(f"\n[Inference at step {step_count}]: {generated_text}\n")  
                            f.write("-"*50 + "\n")
                        else:
                            f.write(f"\n[Inference at step {step_count}]: [No output generated]\n")
                            f.write("-"*50 + "\n")
                model.train()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")
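
Gradient accumulation adds up the gradients of several micro-batches and applies one optimizer update for the whole group; dividing each loss by `accumulation_steps` makes the summed gradient equal to the group average. A toy sketch with plain floats standing in for gradients (the function and its arguments are illustrative, not PyTorch API):

```python
def accumulate(micro_batch_grads, accumulation_steps):
    """Return the gradient applied at each update.

    The buffer is cleared only after an update, never between the
    micro-batches of one group.
    """
    applied = []
    grad_buffer = 0.0
    for i, grad in enumerate(micro_batch_grads):
        grad_buffer += grad / accumulation_steps   # mirrors (loss / accumulation_steps).backward()
        if (i + 1) % accumulation_steps == 0:
            applied.append(grad_buffer)            # mirrors optimizer.step()
            grad_buffer = 0.0                      # mirrors optimizer.zero_grad()
    return applied

updates = accumulate([1.0, 3.0, 5.0, 7.0, 2.0, 2.0, 2.0, 2.0], accumulation_steps=4)
print(updates)
```

Clearing the buffer on every micro-batch instead would leave only the last micro-batch's scaled gradient at update time. With a batch size of 12 and 4 accumulation steps, each update averages over 48 sequences.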


### First training on combined train dataset

In [None]:
#empty gpu memory 
torch.cuda.empty_cache()
torch.cuda.synchronize()
 

train_loop(model, train_dataloader, optimizer, scheduler, device, num_epochs=3)

Epoch 1/3:   0%|          | 500/2184115 [01:43<144:48:06,  4.19it/s, loss=18.4]


[Inference at step 500]


Epoch 1/3:   0%|          | 1000/2184115 [03:27<142:25:13,  4.26it/s, loss=10.1]


[Inference at step 1000]


Epoch 1/3:   0%|          | 1500/2184115 [05:10<142:45:47,  4.25it/s, loss=16.3]


[Inference at step 1500]


Epoch 1/3:   0%|          | 2000/2184115 [06:54<144:04:37,  4.21it/s, loss=7.99]


[Inference at step 2000]


Epoch 1/3:   0%|          | 2501/2184115 [08:36<127:34:31,  4.75it/s, loss=7.77]


[Inference at step 2500]


Epoch 1/3:   0%|          | 3001/2184115 [10:12<127:22:04,  4.76it/s, loss=7.77]


[Inference at step 3000]


Epoch 1/3:   0%|          | 3501/2184115 [11:48<127:28:13,  4.75it/s, loss=7.44]


[Inference at step 3500]


Epoch 1/3:   0%|          | 4001/2184115 [13:23<127:43:00,  4.74it/s, loss=7.56]


[Inference at step 4000]


Epoch 1/3:   0%|          | 4501/2184115 [14:59<127:29:02,  4.75it/s, loss=7.42]


[Inference at step 4500]


Epoch 1/3:   0%|          | 5001/2184115 [16:36<135:30:41,  4.47it/s, loss=7.39]


[Inference at step 5000]


Epoch 1/3:   0%|          | 5501/2184115 [18:18<135:10:24,  4.48it/s, loss=7.69]


[Inference at step 5500]


Epoch 1/3:   0%|          | 6000/2184115 [20:00<140:38:47,  4.30it/s, loss=7.43]


[Inference at step 6000]


Epoch 1/3:   0%|          | 6501/2184115 [21:38<134:37:24,  4.49it/s, loss=7.32]


[Inference at step 6500]


Epoch 1/3:   0%|          | 7000/2184115 [23:17<144:23:53,  4.19it/s, loss=7.38]


[Inference at step 7000]


Epoch 1/3:   0%|          | 7501/2184115 [24:59<134:29:07,  4.50it/s, loss=7.11]


[Inference at step 7500]


Epoch 1/3:   0%|          | 7833/2184115 [26:05<123:31:57,  4.89it/s, loss=6.98]

In [None]:
torch.cuda.empty_cache()
torch.cuda.synchronize()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
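
The message itself suggests rerunning with `CUDA_LAUNCH_BLOCKING=1` so that the stack trace points at the failing kernel. A device-side assert during training is often, though not always, caused by an index that falls outside an embedding table or the loss's class range; that cause is an assumption here, not a diagnosis of this run. A small check that can be run on CPU over a batch converted to lists (`find_out_of_range` is a hypothetical helper):

```python
import os

# Must be set before CUDA is initialized (e.g. at the very top of the notebook).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def find_out_of_range(ids, upper_bound, ignore_index=None):
    """Return (position, value) pairs for ids outside [0, upper_bound)."""
    return [
        (pos, value)
        for pos, value in enumerate(ids)
        if value != ignore_index and not (0 <= value < upper_bound)
    ]

# Toy rows; with real data, use e.g. batch['input_ids'].flatten().tolist().
bad_inputs = find_out_of_range([50257, 11, 50260, 50259], upper_bound=50260)
bad_labels = find_out_of_range([11, 50258, -100, -100], upper_bound=50260, ignore_index=-100)
print(bad_inputs, bad_labels)
```

The same helper can be pointed at position ids with `upper_bound` set to the context length.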


## Sources 

### Principal references: 
- https://arxiv.org/abs/2005.14165 (GPT-3 paper)
- https://arxiv.org/abs/1706.03762 (Attention Is All You Need paper)
- Build a Large Language Model (from scratch) by Sebastian Raschka

### About Padding tokens in Language Modeling
- https://arxiv.org/html/2510.01238v1 