## DEFINING TRAINING PARAMETERS

In [19]:
from CatGPT_model import GPT, GPTConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import ByteLevelBPETokenizer
from time import time
from dataclasses import dataclass
from math import cos, pi
import os

In [20]:
# Define checkpoint path
checkpoint_path = "../models/checkpoint.pth"

# Function to save checkpoint
def save_checkpoint(model, optimizer, step, dataloader, checkpoint_path):
    """Serialize model, optimizer and data-loader state so training can resume later.

    Args:
        model: module whose ``state_dict()`` is saved as-is.
        optimizer: optimizer whose ``state_dict()`` is saved.
        step: number of completed optimizer steps; training resumes from this index.
        dataloader: object exposing ``file_pointer`` (an open text file) and ``tokens``
            (tensor of already-read but not yet consumed token ids).
        checkpoint_path: destination file; missing parent directories are created.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step': step,
        'dataloader_state': {
            # Opaque text-file position; only meaningful to seek() on the same file.
            'current_position': dataloader.file_pointer.tell(),
            # clone(): `tokens` may be a slice of a much larger read buffer, and torch.save
            # serializes a view's entire underlying storage, not just the visible slice.
            'tokens': dataloader.tokens.clone(),
        },
    }
    directory = os.path.dirname(checkpoint_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Write to a temporary file, then atomically swap it in, so an interruption
    # mid-save cannot corrupt the previous good checkpoint.
    tmp_path = checkpoint_path + ".tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, checkpoint_path)
    print(f"Checkpoint saved at step {step}")

# Function to load checkpoint
def load_checkpoint(model, optimizer, dataloader, checkpoint_path, device):
    """Restore state written by save_checkpoint and return the step to resume from.

    Args:
        model: module to load weights into (plain or torch.compile-wrapped).
        optimizer: optimizer to restore.
        dataloader: object exposing ``file_pointer``, ``tokens`` and ``current_position``.
        checkpoint_path: file produced by save_checkpoint.
        device: passed as ``map_location`` to ``torch.load``.

    Returns:
        The saved ``step`` value (number of completed optimizer steps).
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # torch.compile wraps the module and prefixes state_dict keys with "_orig_mod.".
    # Strip the prefix and load into the underlying module, so a checkpoint saved with
    # or without compilation restores into a model that is compiled either way.
    prefix = "_orig_mod."
    model_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['model_state_dict'].items()
    }
    getattr(model, "_orig_mod", model).load_state_dict(model_state)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    loader_state = checkpoint['dataloader_state']
    dataloader.current_position = loader_state['current_position']
    # map_location moved every tensor to `device`, but the data loader concatenates this
    # buffer with freshly tokenized CPU tensors, so it must stay on the CPU.
    dataloader.tokens = loader_state['tokens'].cpu()
    dataloader.file_pointer.seek(dataloader.current_position)
    step = checkpoint['step']
    print(f"Checkpoint loaded from step {step}")
    return step

In [21]:
# Create Data Loader class

class DataLoaderLite:
    """Stream a text file, tokenize it lazily in chunks, and serve (x, y) training batches.

    The file is read ``buffer_size`` characters at a time, and each chunk is tokenized
    independently. Token ids not yet consumed are kept in ``self.tokens``. When the file
    is exhausted it is rewound, and reading restarts from the beginning.

    Args:
        file: path to the UTF-8 text corpus.
        B: number of rows per batch.
        T: number of tokens per row.
        buffer_size: number of characters read from the file per refill.
        vocab_file: ByteLevelBPE vocabulary file.
        merges_file: ByteLevelBPE merges file.
    """

    def __init__(self, file, B, T, buffer_size=100000000,
                 vocab_file='../tokenizer/vocab.json',
                 merges_file='../tokenizer/merges.txt'):
        self.file = file
        self.B = B
        self.T = T
        self.buffer_size = buffer_size
        self.current_position = 0
        self.tokenizer = ByteLevelBPETokenizer(vocab_file, merges_file)
        self.tokens = torch.tensor([], dtype=torch.long)
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) would
        # mis-decode non-ASCII characters in the corpus.
        self.file_pointer = open(self.file, 'r', encoding='utf-8')

    def _load_tokens(self):
        """Read the next chunk of the file and return its token ids as a 1-D long tensor.

        At end of file the reader wraps around to the beginning.

        Raises:
            ValueError: if the file yields no text even after rewinding (an empty file).
                Without this check, next_batch() would loop forever.
        """
        text = self.file_pointer.read(self.buffer_size)
        if not text:
            print(f"TOTAL EPOCH OF TEXT FINISHED")
            self.file_pointer.seek(0)  # Reset to the beginning if end is reached
            text = self.file_pointer.read(self.buffer_size)
            if not text:
                raise ValueError(f"Training file {self.file!r} contains no text")
        encoded = self.tokenizer.encode(text).ids
        return torch.tensor(encoded, dtype=torch.long)

    def next_batch(self):
        """Return the next (inputs, targets) pair, each a (B, T) long tensor.

        Targets are the inputs shifted left by one token. All B*T + 1 tokens used
        here are dropped from the buffer, so consecutive batches do not overlap.
        """
        B, T = self.B, self.T
        # Need B*T + 1 tokens: B*T inputs plus one extra token for the last target.
        while len(self.tokens) <= B * T:
            self.tokens = torch.cat((self.tokens, self._load_tokens()), dim=0)

        buf = self.tokens[:B * T + 1]
        self.tokens = self.tokens[B * T + 1:]  # Discard used tokens

        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets
        return x, y

    def save_state(self, path):
        """Save the file position and the unconsumed token buffer to ``path``."""
        state = {
            'current_position': self.file_pointer.tell(),
            # clone(): self.tokens is a slice of a larger buffer, and torch.save would
            # otherwise write that whole underlying storage.
            'tokens': self.tokens.clone(),
        }
        torch.save(state, path)
        print(f"DataLoaderLite state saved at {path}")

    def load_state(self, path):
        """Restore the file position and token buffer written by save_state."""
        # The token buffer is always concatenated with CPU tensors, so load it onto the CPU.
        state = torch.load(path, map_location='cpu')
        self.current_position = state['current_position']
        self.tokens = state['tokens']
        self.file_pointer.seek(self.current_position)
        print(f"DataLoaderLite state loaded from {path}")

    def close(self):
        """Close the underlying file handle."""
        self.file_pointer.close()

@dataclass
class CatGPT_training_config:
    """Hyperparameters for CatGPT training.

    Every attribute is type-annotated so that @dataclass registers it as a field.
    Defaults can be overridden per instance, e.g. ``CatGPT_training_config(B=4)``.
    Class-level access such as ``CatGPT_training_config.B`` still returns the default.
    """
    B: int = 2                              # sequences per micro-batch
    T: int = 1024                           # tokens per sequence
    total_batch_size: int = 524288          # tokens per optimizer step (2**19); must be divisible by B * T
    float_matmul_precision: str = 'high'    # passed to torch.set_float32_matmul_precision
    vocab_size: int = 32768
    max_lr: float = 6e-4
    min_lr: float = max_lr * 0.1            # derived from the *default* max_lr; not recomputed if max_lr is overridden
    warmup_steps: int = 35
    steps: int = 10000
    weight_decay: float = 0.1
    betas: tuple = (0.9, 0.95)
    eps: float = 1e-8
    compile_model: bool = True
    use_gpu: bool = False

CatGPT_basic_config = CatGPT_training_config()

# Gradient accumulation: each optimizer step should see total_batch_size tokens, but
# only B * T tokens go through the model per forward/backward pass.
assert (CatGPT_basic_config.total_batch_size % (CatGPT_basic_config.B * CatGPT_basic_config.T)) == 0, "make sure total_batch_size is divisible by B * T"
grad_accum_steps = CatGPT_basic_config.total_batch_size // (CatGPT_basic_config.B * CatGPT_basic_config.T)
print(f"total desired batch size: {CatGPT_basic_config.total_batch_size}")
print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")


# Settings are read from the CatGPT_basic_config instance everywhere (never from the
# class), so overrides made on the instance are honoured consistently.
device = "cpu"

if CatGPT_basic_config.use_gpu:
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
print(f"Using device: {device}")


# Create DataLoader
train_loader = DataLoaderLite("../data/tiny_corpus.txt", B=CatGPT_basic_config.B, T=CatGPT_basic_config.T)

# Trade float32 matmul precision for speed (see torch.set_float32_matmul_precision)
torch.set_float32_matmul_precision(CatGPT_basic_config.float_matmul_precision)

# Create model and optimizer
model = GPT(GPTConfig(vocab_size=CatGPT_basic_config.vocab_size))
model.to(device)

if CatGPT_basic_config.compile_model:
    model = torch.compile(model)

# The learning rate given here is only the initial value; the training loop
# overwrites it every step with get_lr(step).
optimizer = model.configure_optimizers(weight_decay=CatGPT_basic_config.weight_decay, learning_rate=CatGPT_basic_config.max_lr, device=device)

# Resume from the last checkpoint if one exists; otherwise start at step 0.
start_step = 0
if os.path.exists(checkpoint_path):
    start_step = load_checkpoint(model, optimizer, train_loader, checkpoint_path, device)

# Warmup + cosine decay learning rate schedule

def get_lr(it, config=None):
    """Return the learning rate for optimizer step ``it``.

    Linear warmup to ``max_lr`` over ``warmup_steps`` steps, then cosine decay down to
    ``min_lr`` at step ``steps``, then constant ``min_lr``.

    Args:
        it: zero-based optimizer step index.
        config: object with ``warmup_steps``, ``steps``, ``max_lr`` and ``min_lr``;
            defaults to the notebook-level ``CatGPT_basic_config``.

    Returns:
        The learning rate as a float.
    """
    cfg = CatGPT_basic_config if config is None else config
    # 1) linear warmup: max_lr / warmup_steps at it=0, up to max_lr at it=warmup_steps-1
    if it < cfg.warmup_steps:
        return cfg.max_lr * (it + 1) / cfg.warmup_steps
    # 2) at or past the end of the schedule, hold at min_lr. `>=` gives the same value at
    #    it == steps (cosine has reached min_lr) and avoids 0/0 when warmup_steps == steps.
    if it >= cfg.steps:
        return cfg.min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - cfg.warmup_steps) / (cfg.steps - cfg.warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + cos(pi * decay_ratio))  # coeff starts at 1 and goes to 0
    return cfg.min_lr + coeff * (cfg.max_lr - cfg.min_lr)

total desired batch size: 524288
=> calculated gradient accumulation steps: 256
num decayed parameter tensors: 50, with 110886912 parameters
num non-decayed parameter tensors: 98, with 121344 parameters
using fused AdamW: False
Checkpoint loaded from step 1470


## TRAINING THE MODEL

In [None]:
for i in range(start_step, CatGPT_basic_config.steps):
    initial_time = time()
    optimizer.zero_grad()
    loss_accum = 0.0
    # Accumulate gradients over grad_accum_steps micro-batches, so each optimizer
    # step effectively sees total_batch_size tokens.
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        if device == "cuda":
            # Mixed precision: the forward pass runs under bfloat16 autocast on CUDA only.
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, loss = model(x, y)
        else:
            logits, loss = model(x, y)
        # Scale so the accumulated gradient is the mean over micro-batches, not the sum.
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    # Clip the global gradient norm to 1.0; the returned (pre-clip) norm is logged below.
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Update the learning rate (warmup + cosine schedule) before stepping.
    lr = get_lr(i)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device == "cuda":
        # CUDA kernels run asynchronously; wait for them so dt measures the real step time.
        torch.cuda.synchronize()
    dt = time() - initial_time
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps
    tokens_per_second = tokens_processed / dt
    print(f"Step {i} | Loss: {loss_accum.item()} | Norm: {norm.item():.4f} | Time: {dt} | Tokens/s: {tokens_per_second} | LR: {lr}")

    # Save a checkpoint every 250 steps and after the final step. This way the end of
    # training is not lost when `steps` is not a multiple of 250.
    if (i + 1) % 250 == 0 or i + 1 == CatGPT_basic_config.steps:
        save_checkpoint(model, optimizer, i + 1, train_loader, checkpoint_path)

## EVALUATING THE MODEL

In [48]:
def generate_text(input_text = 'La intel·ligència artificial tindrà la capacitat de', num_return_sequences = 1, max_length = 100, seed = 42):
    """Sample continuations of ``input_text`` from the global ``model`` and print them.

    Tokens are drawn by plain multinomial sampling from the softmax of the last
    position's logits (no temperature and no top-k).

    Args:
        input_text: prompt to continue; must encode to at least one token.
        num_return_sequences: number of independent samples to draw and print.
        max_length: total length in tokens (prompt plus generated tokens). If the
            prompt is already at least this long, it is raised to prompt length + 25.
        seed: value passed to torch.manual_seed before sampling, for reproducibility.

    Raises:
        ValueError: if ``input_text`` encodes to zero tokens.
    """
    enc = ByteLevelBPETokenizer(
        '../tokenizer/vocab.json',
        '../tokenizer/merges.txt'
    )

    # Encode the input text
    tokens = torch.tensor(enc.encode(input_text).ids, dtype=torch.long)  # (prompt_len,)
    if len(tokens) == 0:
        raise ValueError("input_text must encode to at least one token")

    # `>=` so that a prompt exactly max_length tokens long still gets new tokens.
    if len(tokens) >= max_length:
        max_length = len(tokens) + 25
        print(f"Max length set to {max_length} as input text is not shorter than the previous max length")
    # One copy of the prompt per requested sample: (num_return_sequences, prompt_len)
    x = tokens.unsqueeze(0).repeat(num_return_sequences, 1).to(device)

    # Set manual seed for reproducibility
    torch.manual_seed(seed)

    # Sample in eval mode (disables dropout, if the model has any), then restore the
    # caller's mode so a later training cell is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _step in range(max_length - x.size(1)):
                logits, _loss = model(x)
                logits = logits[:, -1, :]  # logits for the last position only
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                x = torch.cat((x, next_token), dim=1)
    finally:
        model.train(was_training)

    # Decode and print the generated sequences
    for i in range(num_return_sequences):
        decoded = enc.decode(x[i].tolist())
        print(f"Sample {i+1}: {decoded}")

In [59]:
generate_text(
    input_text="Estic fart de la situació actual de",
    num_return_sequences=5,
    max_length=2,  # shorter than the prompt, so generate_text raises it to prompt length + 25
)

Max length set to 33 as input text is longer than the previous max length
Sample 1: Estic fart de la situació actual de la nostra gran societat. No paro d'Uniy Afo, encara que no trobes verd mana i decidim anar cap
Sample 2: Estic fart de la situació actual de la nostra feina de seguretat i salut com més seguirem. Aquesta professora em dec el massatge amb la resta d'ingredients i
Sample 3: Estic fart de la situació actual de l'habitatge se't deixarà seduir seguint les instruccions senzilles. Aquestes instruccions són precises pels centres d'informació i catalogació que permeten
Sample 4: Estic fart de la situació actual de la topota d'aquesta escola. • Noorientar les bones sensacions amb els amics del centre, potenciar el vincle amb la
Sample 5: Estic fart de la situació actual de contaminació provocada pel medi ambient. Avui hem arranjat el blog cap a les consumicions de combustibles fòssils que hem vist. Bona
