# Projet GPT - Train

This notebook contains the code used to train a small language model using PyTorch from scratch. The model is inspired by the GPT architecture.


#### Hardware
- RTX3060 12GB VRAM
- AMD Ryzen 7 5800X 8-Core
- 32GB RAM
- Ubuntu 22.04 LTS

In [5]:
import os
CACHE_DIR = "/media/rob/RobsDisk/cache_data_llm"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_DIR, "datasets")
os.environ['HF_METRICS_CACHE'] = os.path.join(CACHE_DIR, "metrics")
os.environ['HF_MODULES_CACHE'] = os.path.join(CACHE_DIR, "modules")


In [6]:
import torch, torch.nn as nn, torch.optim as optim
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import random, math

from datasets import load_dataset,concatenate_datasets
import tiktoken
from tqdm import tqdm

## Datasets

### Common knowledge datasets:

##### English Wikipedia crawled dataset

In [7]:
# English Wikipedia crawled dataset
# path to store the dataset cache: /Volumes/RobertsDisk
wiki_en = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', cache_dir=CACHE_DIR) 
print("English Wikipedia dataset loaded.")
print("dataset size in gb:", wiki_en.dataset_size / (1024**3))
print("Number of entries:", len(wiki_en))
print("-"*50)
print("Example entry:")
print(wiki_en[random.randint(0, len(wiki_en)-1)]['text'])


English Wikipedia dataset loaded.
dataset size in gb: 18.812774107791483
Number of entries: 6407814
--------------------------------------------------
Example entry:
Wakefield Scott Stornetta (born June 1959) is an American physicist and scientific researcher. His 1991 paper "How to Time-Stamp a Digital Document”, co-authored with Stuart Haber, won the 1992 Discover Award for Computer Software and is considered to be one of the most important papers in the development of cryptocurrencies.

Stornetta is currently a fellow at the Creative Destruction Lab, a science and technology-based startup accelerator at the Rotman School of Management at the University of Toronto. He is also a founding partner and chief scientist of Yugen Partners, a blockchain-focused venture capital firm that counsels investors on blockchain startup opportunities and governments on blockchain policy, as well as the director of the board of advisors for the American Blockchain PAC.

Career 
In 1989, Stornetta began

#### Simple stories dataset

In [8]:
# Simple stories dataset
stories = load_dataset("SimpleStories/SimpleStories", split='train', cache_dir=CACHE_DIR)
print("Simple stories dataset loaded.")
print("dataset size in mb:", stories.dataset_size / (1024**2))
print("Number of entries:", len(stories))
print("-"*50)
print("Example entry:")
print(stories[random.randint(0, len(stories)-1)]['story'])

Simple stories dataset loaded.
dataset size in mb: 3030.012650489807
Number of entries: 2115696
--------------------------------------------------
Example entry:
Yonder in a bright land, there was a wise old man. He had seen many things and learned many lessons. The village around him was full of simple folk. They often gathered by the big tree to hear him speak. One day, he began a tale of a blue dragon that lived in the mountains. "You see," he said, "this dragon was not what it seemed." 

The blue dragon looked fierce, with shiny scales and sharp teeth. But deep inside, it was kind and gentle. Many villagers feared it. They thought it would eat them if they got too close. But one day, a brave girl decided to visit the dragon. She thought, "Maybe it is not bad." As she walked to the mountain, she felt nervous but also excited. 

When she reached the cave, she found the dragon weeping. The girl asked, "Why do you cry, great dragon?" The dragon replied, "I am lonely. Everyone fears me 

##### FineWeb-Edu dataset

In [9]:
fineweb_edu = load_dataset("HuggingFaceFW/fineweb-edu", "sample-10BT",  split='train', cache_dir=CACHE_DIR)

print("FineWeb-Edu is ready.")
print("dataset size in gb:", fineweb_edu.dataset_size / (1024**3))
print("Number of entries:", len(fineweb_edu))
print("-"*50)
print("Example entry:")
print(fineweb_edu[random.randint(0, len(fineweb_edu)-1)]['text']) 


FineWeb-Edu is ready.
dataset size in gb: 45.730818568728864
Number of entries: 9672101
--------------------------------------------------
Example entry:
Curcuma longa, turmeric in English and haridra in Sanskrit (meaning the special bond with the divine Lord) has been honored for centuries as both a medicine and food, and has recently surged in popularity within the alternative health and holistic nutrition communities in the West.
How it grows
The turmeric plant grows to about three feet in height and produces both a flower and a rhizome, a horizontal stem that grows underground. India has been growing it since ancient times and is currently the largest producer together with other countries like China, Pakistan, Jamaica, Peru, Thailand and Taiwan.
This is thanks to “curcumin”, a compound found in turmeric which is a healing substance known to have the following health benefits:
• Promotes skin health
• Regulates blood sugar
• Reduces internal inflammation
• Balances cholesterol
• Bo

In [10]:
# OpenWebText2 dataset
owt2 = load_dataset("Skylion007/openwebtext", split="train", cache_dir=CACHE_DIR)
print("OpenWebText2 dataset loaded.")
print("Dataset size in GB:", owt2.dataset_size / (1024**3))
print("Number of entries:", len(owt2))
print("-"*50)
print("Example entry:")
print(owt2[random.randint(0, len(owt2)-1)]['text'])

OpenWebText2 dataset loaded.
Dataset size in GB: 37.03822539001703
Number of entries: 8013769
--------------------------------------------------
Example entry:
Jobs at Adelaide's Coca-Cola factory in doubt as company cancels shifts, union says

Posted

Workers at Adelaide's Coca-Cola production factory have had their shifts cancelled and a meeting has been called for Wednesday morning leaving about 150 jobs in doubt, the union has said.

Staff have been told to attend a meeting at the Port Road business at 8:00am for an announcement, which will coincide with the release of the company's 2016 full-year financial results.

The union, United Voice, said it understood some casual workers at the plant were stood down today.

United Voice spokesperson Carolyn Smart said it had not been advised of any other job losses.

"We would be extremely disappointed if workers were advised of any job losses tomorrow, especially given the consultation clauses contained in their enterprise agreement," Ms 

#### Some Q&A data to improve the model's ability to answer questions:

In [11]:
q_a1 = load_dataset("agentlans/text-sft-questions-answers-only", split='train', cache_dir=CACHE_DIR)
print("Q&A dataset loaded.")
print("dataset size in mb:", q_a1.dataset_size / (1024**2))
print("Number of entries:", len(q_a1))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(q_a1)-1)
print(q_a1[index]['question'][:500], "\n", q_a1[index]['answer'])

Q&A dataset loaded.
dataset size in mb: 46.480509757995605
Number of entries: 120959
--------------------------------------------------
Example entry:
Why is durability a critical factor in product design and marketing? 
 Durability is crucial because it ensures a product can withstand wear and stress, meets user satisfaction criteria, and leads to repeat purchases and a positive brand image.


In [12]:
#euclaise/reddit-instruct
reddit_instruct = load_dataset("euclaise/reddit-instruct", split='train', cache_dir=CACHE_DIR)
# reddit_instruct = load_dataset("Felladrin/ChatML-reddit-instruct-curated", split='train', cache_dir=CACHE_DIR)
print("Reddit Instruct dataset loaded.")
print("dataset size in gb:", reddit_instruct.dataset_size / (1024**3))
print("Number of entries:", len(reddit_instruct))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(reddit_instruct)-1)
print(reddit_instruct[index]['post_title'][:500], reddit_instruct[index]['post_text'][:500]), "\n", reddit_instruct[index]['comment_text'][:500]

Reddit Instruct dataset loaded.
dataset size in gb: 0.09901080373674631
Number of entries: 84784
--------------------------------------------------
Example entry:
[Magic: The Gathering] What exactly is "The Roil" on Zendikar? 


(None,
 '\n',
 "It's a traveling volitality of extreme climate and geographical upheaval that moves across the plane as a result of its unique mana. Weather rapidly changes, land rises and falls, and oceans fill and dry in its wake. It's the primary reason why Zendikar remains so wild and untamed; you can't even reliably map the place for long, let alone settle it in any but the most fleeting sense. \n\nIt's also why Zendikar is such a great prison for the Eldrazi.")

In [13]:
# tatsu-lab/alpaca ( for Q&A fine-tuning )
alpaca = load_dataset("tatsu-lab/alpaca", split='train')
print("Alpaca dataset loaded.")
print("dataset size in mb:", alpaca.dataset_size / (1024**2))
print("Number of entries:", len(alpaca))
print("-"*50)
print("Example entry:")
index = random.randint(0, len(alpaca)-1)
print(alpaca[index]['instruction'][:500], "\n", alpaca[index]['output'][:500])

Alpaca dataset loaded.
dataset size in mb: 44.06797695159912
Number of entries: 52002
--------------------------------------------------
Example entry:
Give me five examples of artificial intelligence technologies and how they are being used in the world today. 
 1. Machine Learning: a branch of AI that enables systems to learn from data and improve their capabilities without being explicitly programmed. Machine Learning is used to analyse data, recognize patterns, and make predictions.

2. Natural Language Processing: a branch of AI that focuses on understanding and generating human languages. NLP is used by many businesses to automate customer service and automated chat bots.

3. Computer Vision: a branch of AI which enables systems to interpret, analy


## Data Preprocessing

#### Tokenizer setup

For this project i use tiktoken for the tokenizer, as it is the same tokenizer used by OpenAI for their models.

I use the "gpt2" encoding which is a byte pair encoding (BPE) tokenizer.

In [14]:
tokenizer_base = tiktoken.get_encoding("gpt2")

tokenizer = tiktoken.Encoding(
    name="rob-tokenizer",
    pat_str=tokenizer_base._pat_str,
    mergeable_ranks=tokenizer_base._mergeable_ranks,
    special_tokens={
        **tokenizer_base._special_tokens,
        "<|im_start|>": 50257,
        "<|im_end|>": 50258,
        "<|pad|>": 50259,
    }
)

#### Test of the byte pair encoding tokenizer 

In [15]:
# test of tokenizer on reddit_instruct
sample_text = reddit_instruct[0]['post_title'] + " " + reddit_instruct[0]['post_text'] + " " + reddit_instruct[0]['comment_text']
tokens = tokenizer.encode(sample_text)
print(tokens)
print("Decoded text:")
print(tokenizer.decode(tokens)) 
print(f"Sample text length in characters: {len(sample_text)}")
print(f"Sample text length in tokens: {len(tokens)}")   

[2061, 318, 24207, 1616, 2587, 30, 314, 2342, 257, 7684, 286, 1097, 5861, 290, 484, 1561, 546, 275, 32512, 7021, 290, 884, 11, 1312, 373, 11263, 644, 275, 32512, 318, 290, 1312, 18548, 1064, 597, 2562, 7468, 284, 644, 340, 318, 24207, 1616, 318, 655, 262, 1438, 329, 257, 16058, 286, 6147, 13, 554, 262, 29393, 995, 340, 338, 1690, 973, 355, 257, 1790, 1021, 329, 3354, 326, 547, 3235, 1389, 503, 286, 257, 1263, 2512, 286, 2587, 11, 355, 6886, 284, 11721, 3350, 654, 810, 44030, 6147, 318, 19036, 656, 257, 15936, 12070, 503, 286, 9629, 6147, 13, 7080, 3191, 318, 517, 5789, 329, 1588, 17794, 475, 340, 460, 779, 1365, 3081, 286, 21782, 290, 318, 4577, 284, 787, 329, 4833, 17794, 588, 3234, 3354, 13]
Decoded text:
What is Billet material? I watch a bunch of car videos and they talk about billet blocks and such, i was wondering what billet is and i cant find any easy explanation to what it is Billet is just the name for a chunk of metal. In the automotive world it's often used as a short hand 

### Formatting datasets functions

#### Merging datasets

In [16]:
from datasets import concatenate_datasets
combined_train_dataset = concatenate_datasets([  
    wiki_en,
    stories,
    fineweb_edu,
    owt2,  
])  

combined_finetune_dataset = concatenate_datasets([
    q_a1,
    reddit_instruct,
    alpaca,
])

# Shuffle the combined dataset
train_dataset = combined_train_dataset.shuffle(seed=42)
finetune_dataset = combined_finetune_dataset.shuffle(seed=42)
print(f"Train dataset size: {len(combined_train_dataset)}")
print(f"Finetune dataset size: {len(combined_finetune_dataset)}")

# Exemple 

print("Example entry from train dataset:")
index = random.randint(0, len(train_dataset)-1)
print(train_dataset[index])   

Train dataset size: 26209380
Finetune dataset size: 257745
Example entry from train dataset:
{'id': None, 'url': None, 'title': None, 'text': "btcx\n\nSr. Member\n\nOffline\n\nActivity: 302\n\nMerit: 253\n\nVIPSr. MemberActivity: 302Merit: 253 [ANN] OGRR.COM adds 20,000 users, HIRING Full-Time PHP Developer, Designer November 13, 2012, 10:55:03 PM\n\nLast edit: November 17, 2012, 03:52:42 AM by btcx #1\n\nCurrently, we're in search of a full-time Senior Web Developer to continue expanding and refining the features on the forum, and help grow it to one of the largest communities on the Internet.\n\nRequirements:\n\nLinux, Apache, MySQL, PHP\n\nJavascript, Node.js\n\nCSS3, HTML5\n\n~$50/hr negotiable, equity possibilities\n\nAlso, we're looking for someone reliable with some CSS skills to create a few alternative themes for the forums.\n\nIf this sounds like something you might be interested in, PM me here or over at Ogrr.com has recently completed its merger with mmoexchange.net, adding

#### Custom Dataset class

Inspired by the dataloader from the "LLMs from scratch" repository. But adapted for multi-row text arrays.

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/dataloader.ipynb

In [17]:
class GPTDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=512):
        """
        Args:
            dataset: Dataset of the combined hugginface entries
            tokenizer: the initiatokenizer to process text
            max_length: Context window size
        """
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.input_tokens = []
        self.target_tokens = []

        self.pad_token_id = 50259         # <|pad|>
        self.bos_token_id = 50257    # <|im_start|>
        self.eos_token_id = 50258    # <|im_end|>

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Raw text
        
        # Data format handling
        entry = self.data[idx]
        if 'text' in entry:
            text = entry['text']
        elif 'story' in entry:
            text = entry['story']
        elif 'question' in entry and 'answer' in entry: 
            text = "User: " + entry['question'] + " Assistant:" + entry['answer']
        elif 'post_title' in entry and 'post_text' in entry and 'comment_text' in entry:
            text = "User: " + entry['post_title'] + " Assistant:" + entry['post_text'] + " " + entry['comment_text']
        elif 'instruction' in entry and 'output' in entry:
            text = "User: " + entry['instruction'] + " Assistant:" + entry['output']
        else:
            raise ValueError("Unknown data entry format")
        
        text = str(text) # Ensure text is a string
        #print(text)

        # Adding Start and End tokens
        text = "<|im_start|>" + text + "<|im_end|>" 

        # Tokenization
        tokens = self.tokenizer.encode(text, allowed_special="all")

        # Truncation
        tokens = tokens[:self.max_length] #Data is loost here ( fix later with sliding window )

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)  # All tokens except last
        labels = torch.tensor(tokens[1:], dtype=torch.long)      # All tokens except first


        #Padding 
        padding_length = self.max_length - len(tokens)
        if padding_length > 0:
            input_ids = torch.cat([input_ids, torch.full((padding_length,), self.pad_token_id)])
            labels = torch.cat([labels, torch.full((padding_length,), -100)])


        attention_mask = (input_ids != self.pad_token_id).long() # 1 for real tokens, 0 for padding

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

### DataLoader setup

In [42]:
# Empty cuda cache and memory management
torch.cuda.empty_cache()

train_dataset = GPTDataset(combined_train_dataset, tokenizer, max_length=512)
print(f"Train dataset size: {len(train_dataset)}")

batch_size = 10
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Train dataset size: 26209380


##### Test of a entry from dataloader

In [19]:
print("##### Test of a entry from dataloader")
batch = next(iter(train_dataloader))
print(batch)    

##### Test of a entry from dataloader
{'input_ids': tensor([[50257,  1525, 17704,  ...,  1099,   290, 21825],
        [50257, 38202, 50173,  ..., 50259, 50259, 50259],
        [50257, 14202, 50259,  ..., 50259, 50259, 50259],
        ...,
        [50257,  5159,  6634,  ..., 50259, 50259, 50259],
        [50257,  5886,  2063,  ...,  1751,  2299, 16856],
        [50257, 31160, 18024,  ...,  1440,  5852,    13]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 0,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[ 1525, 17704,   367,  ...,   290, 21825,   416],
        [38202, 50173,  2634,  ...,  -100,  -100,  -100],
        [14202, 50258,  -100,  ...,  -100,  -100,  -100],
        ...,
        [ 5159,  6634,  3203,  ...,  -100,  -100,  -100],
        [ 5886,  2063,   286,  ...,  2299, 16856,  1243],
        [31160, 18024,   349,  ...,

## GPT Model 

### GPT config 

This is the configuration for the GPT model i am going to train. It is a smaller version of the GPT-2 model. 

- Context length: 512 tokens
- Embedding dimension: 512
- Number of attention heads: 8
- Number of layers: 8

In [20]:
GPT_CONFIG = {
    "vocab_size": 50260,
    "context_length": 512,
    "emb_dim": 512,
    "number_heads": 8,
    "number_layers": 8,
    "drop_rate": 0.1,
}

#### Pytroch model implementation

For this first implementation, i am using the transformer and embedding modules from PyTorch. Later i will try to implement the attention mechanism from scratch for better understanding.

https://docs.pytorch.org/docs/stable/generated/torch.nn.Transformer.html

In [27]:
class GPTModel(nn.Module):
    """
    Gpt model class using transformer library
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Network components 
        ## Embedding layers
        self.embedding = nn.Embedding(config['vocab_size'], config['emb_dim'])
        self.positional_encoding = nn.Embedding(config['context_length'], config['emb_dim'])
        ## Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config['emb_dim'],
            nhead=config['number_heads'],
            dim_feedforward=4 * config['emb_dim'],
            dropout=config['drop_rate'],
            activation='gelu',
            batch_first=True,
            norm_first=True # stabilityy 
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=config['number_layers'])
        ## Output layer
        self.output_layer = nn.Linear(config['emb_dim'], config['vocab_size'], bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, input_ids, attention_mask, label_ids=None):   
        batch_size, seq_length = input_ids.shape

        # Embedding
        token_embeddings = self.embedding(input_ids)  
        pos_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        position_embeddings = self.positional_encoding(pos_ids)  # (batch_size, seq_length, emb_dim)

        embeddings = token_embeddings + position_embeddings  # (batch_size, seq_length, emb_dim)

        # Prevent attending to future tokens
        causal_mask = torch.triu(torch.full((seq_length, seq_length), float('-inf'), device=input_ids.device), diagonal=1)

        # voiding to pay attatention padding tokens
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        
        x = self.transformer(embeddings, mask=causal_mask, src_key_padding_mask=key_padding_mask, is_causal=True)
        logits = self.output_layer(x)

        # Computing loss 
        if label_ids is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.config['vocab_size']), label_ids.view(-1)) # Applies loss to predictions
            return logits, loss

        return logits, None




### Model instantiation 

In [31]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTModel(GPT_CONFIG).to(device)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4) # AdamW optimizer
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) # Cosine annealing learning rate scheduler

GPTModel(
  (embedding): Embedding(50260, 512)
  (positional_encoding): Embedding(512, 512)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_layer): Linear(in_features=512, out_features=50260, bias=False)
)


## Training setup

In [23]:
def inference(model, tokenizer, prompt, max_length=512, device='cpu'):
    model.eval()
    input_ids = tokenizer.encode(prompt, allowed_special="all")
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)  # (1, seq_length)

    generated_ids = input_ids
    max_length = max_length - input_ids.shape[1]  # Remaining length for generation
    with torch.no_grad():
        for _ in range(max_length):
            attention_mask = torch.ones_like(generated_ids)  # All tokens are real (no padding)
            logits, _ = model(generated_ids, attention_mask)
            next_token_logits = logits[:, -1, :]  # (1, vocab_size)
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # (1, 1)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)  # Append to sequence

            if next_token_id.item() == 100265:  # Stop if <|im_end|> token is generated
                break

    generated_text = tokenizer.decode(generated_ids.squeeze().tolist())
    return generated_text


In [None]:
def train_loop(model, dataloader, optimizer, scheduler, device, num_epochs=3, question_interval=500):
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        step_count = 0

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits, loss = model(input_ids, attention_mask, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()  # Step scheduler every batch

            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

            # Inference check
            step_count += 1
            if step_count % question_interval == 0:
                model.eval()
                with torch.no_grad():
                    prompt = "You are an AI being trained. How are you doing? Please answer briefly."
                    generated_text = inference(model, tokenizer, prompt, max_length=50, device=device)
                    print(f"\n[Inference at step {step_count}]")
                model.train()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")


### First training on combined train dataset

In [38]:
#empty gpu memory 
torch.cuda.empty_cache()
torch.cuda.synchronize()
 

train_loop(model, train_dataloader, optimizer, scheduler, device, num_epochs=3)

Epoch 1/3:   0%|          | 500/3276173 [02:47<340:48:45,  2.67it/s, loss=6.61]


[Inference at step 500]: You are an AI being trained. How are you doing? Please answer briefly.





































Epoch 1/3:   0%|          | 1000/3276173 [05:34<343:31:09,  2.65it/s, loss=6.58]


[Inference at step 1000]: You are an AI being trained. How are you doing? Please answer briefly.





































Epoch 1/3:   0%|          | 1396/3276173 [07:46<303:56:14,  2.99it/s, loss=6.4] 


KeyboardInterrupt: 

In [41]:
torch.cuda.empty_cache()
torch.cuda.synchronize()

## Sources 

### Principal references: 
- https://arxiv.org/abs/2005.14165 (GPT-3 paper)
- https://arxiv.org/abs/2002.05709 (Attention is all you need paper)
- Build a Large Language Model (from scratch) by Sebastian Raschka

### About Padding tokens in Language Modeling
- https://arxiv.org/html/2510.01238v1 