# Chapter 05: Pre-Training and Text Generation

This notebook accompanies Chapter 05 of the **Fazendo um LLM do Zero** series.

In this notebook we will teach GPTMini to learn language.

🎯 **Objectives of this notebook:**
- How to compute the probabilistic (cross-entropy) loss
- How the training loop works
- How to monitor learning
- How to generate text with different strategies (Greedy, Temperature, Top-K, Top-P)
- How to save and load models


## 1. Setup and Configuration

In [None]:
# ============================================================
# Repository setup on Colab
# ============================================================
import os
import sys

REPO_URL = "https://github.com/vongrossi/fazendo-um-llm-do-zero.git"
REPO_DIR = "fazendo-um-llm-do-zero"

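# Clone and enter the repository only once, so re-running this cell
# does not create a nested clone inside the repository.
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.exists(REPO_DIR):
        !git clone {REPO_URL}
    os.chdir(REPO_DIR)

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
print("Current directory:", os.getcwd())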


In [None]:
!pip -q install -r 05-pre-treinamento/requirements.txt

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
import random
import numpy as np
from lib.gptmini import GPTConfig, GPTMini

device = "cuda" if torch.cuda.is_available() else "cpu"
print("✅ Device:", device)

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)


## 2. Dataset Preparation

We will use a small, simple text dataset so we can watch the model learn word-to-word transitions.
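
To make the next-token setup concrete, the sketch below pairs each context window with the same window shifted one position to the right. It is a plain-Python illustration under stated assumptions: the helper name `shifted_pairs` and the seven-word sample are hypothetical, chosen only to show the shape of the data.

```python
# Toy illustration of next-token training pairs.
# `shifted_pairs` is a hypothetical helper, not part of the notebook's library.
def shifted_pairs(tokens, context_size):
    pairs = []
    for i in range(len(tokens) - context_size):
        x = tokens[i : i + context_size]          # model input
        y = tokens[i + 1 : i + context_size + 1]  # same window, shifted by one
        pairs.append((x, y))
    return pairs

sample = "o gato subiu no telhado o cachorro".split()
pairs = shifted_pairs(sample, context_size=5)
for x, y in pairs:
    print(x, "->", y)
```

Because the corpus is split on whitespace, line breaks disappear and the windows run across sentence boundaries, so the model also sees transitions such as `telhado -> o`.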

In [None]:
text = """
o gato subiu no telhado
o cachorro subiu no sofa
o gato dormiu no sofa
o cachorro dormiu no tapete
o gato pulou no muro
""".strip().lower()

tokens = text.split()
vocab = sorted(set(tokens))

stoi = {t:i for i,t in enumerate(vocab)}
itos = {i:t for t,i in stoi.items()}

encoded = [stoi[t] for t in tokens]

def build_dataset(token_ids, context_size):
    X, Y = [], []
    for i in range(len(token_ids) - context_size):
        x = token_ids[i : i + context_size]
        y = token_ids[i + 1 : i + context_size + 1]
        X.append(x)
        Y.append(y)
    return torch.tensor(X, dtype=torch.long), torch.tensor(Y, dtype=torch.long)

context_size = 5
X, Y = build_dataset(encoded, context_size)

# Train / validation split
N = X.size(0)
perm = torch.randperm(N)
split = int(0.85 * N)
train_idx = perm[:split]
val_idx = perm[split:]

X_train, Y_train = X[train_idx].to(device), Y[train_idx].to(device)
X_val, Y_val = X[val_idx].to(device), Y[val_idx].to(device)


## 3. Optimization and Training

Now we instantiate the model and train it with **cross entropy** computed over every position in the sequence.
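
As a rough guide to reading the loss values, the sketch below computes cross entropy for a single position by hand, in plain Python rather than PyTorch. The helper name `cross_entropy` and the toy logits are assumptions made for illustration; `V = 11` is the number of distinct words in this notebook's corpus.

```python
import math

def cross_entropy(logits, target):
    """Negative log-probability of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

V = 11  # distinct words in the toy corpus

# A model with no preference spreads probability evenly over the vocabulary.
uniform_loss = cross_entropy([0.0] * V, target=3)

# Hypothetical logits that strongly favor the correct token.
confident = [0.0] * V
confident[3] = 5.0
confident_loss = cross_entropy(confident, target=3)

print(f"uniform guess        : {uniform_loss:.4f} (ln 11 = {math.log(V):.4f})")
print(f"confident and correct: {confident_loss:.4f}")
```

A model that has learned nothing is expected to start near ln 11 ≈ 2.40, and the loss moves toward zero as more probability lands on the correct token. With its default settings, `F.cross_entropy` averages this quantity over all positions it receives.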

In [None]:
config = GPTConfig(
    vocab_size=len(vocab),
    context_size=context_size,
    d_model=64,
    n_heads=4,
    n_layers=2
)

model = GPTMini(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
vocab_size = len(vocab)

train_loss_history = []
val_loss_history = []

@torch.no_grad()
def eval_val_loss():
    model.eval()
    logits, _ = model(X_val)
    B, T, V = logits.shape
    loss = F.cross_entropy(logits.view(-1, V), Y_val.view(-1))
    model.train()
    return loss.item()

print("🚀 Starting training...")
model.train()
for step in range(601):
    idx = torch.randint(0, X_train.size(0), (16,), device=device)
    xb, yb = X_train[idx], Y_train[idx]

    logits, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_loss_history.append(loss.item())

    if step % 50 == 0:
        vloss = eval_val_loss()
        val_loss_history.append((step, vloss))
        print(f"Step {step:03d} | Train Loss: {loss.item():.4f} | Val Loss: {vloss:.4f}")


## 4. Performance Visualization

Let's look at how the loss fell during training, with the validation loss sampled every 50 steps.
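
Two small helpers can make a noisy loss curve easier to read: a moving average, and perplexity (the exponential of the mean cross entropy). The sketch below is plain Python under stated assumptions: the function names and the loss values are made up for illustration and are not results from this notebook.

```python
import math

def moving_average(values, window):
    """Average each value with up to `window - 1` values before it."""
    smoothed = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1) : i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed

def perplexity(mean_loss):
    """exp(mean cross entropy): the effective number of equally likely choices."""
    return math.exp(mean_loss)

toy_losses = [2.4, 2.0, 2.2, 1.5, 1.7, 1.0, 1.2, 0.8]  # made-up values
smooth = moving_average(toy_losses, window=4)
print("smoothed:", [round(v, 3) for v in smooth])
print("perplexity at loss ln(11):", round(perplexity(math.log(11)), 2))
```

In the notebook, something like `moving_average(train_loss_history, 20)` could be plotted on top of the raw training curve.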

In [None]:
# Plot the training loss, with validation loss as red points
plt.figure(figsize=(10, 4))
plt.plot(train_loss_history, label="Train Loss", alpha=0.6)
if val_loss_history:
    steps_v, losses_v = zip(*val_loss_history)
    plt.scatter(steps_v, losses_v, color='red', label="Val Loss")
plt.title("Training Progress")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.legend()
plt.show()


## 5. Generation Strategies (Decoding)

This section defines the functions the model uses to choose each next token: greedy, temperature sampling, top-k, and top-p.
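
The sampling strategies share one knob, the temperature, which divides the logits before the softmax. The sketch below shows its effect on a toy distribution in plain Python. The helper name and the four logits are hypothetical; the `1e-6` floor mirrors the guard used in this notebook's sampling functions.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    t = max(temperature, 1e-6)  # avoid division by zero
    scaled = [z / t for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

toy_logits = [2.0, 1.0, 0.5, 0.0]  # hypothetical scores for four tokens
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(toy_logits, t)
    print(f"T={t}:", [round(p, 3) for p in probs])
```

Temperatures below 1 sharpen the distribution toward the top token, approaching greedy decoding, while temperatures above 1 flatten it and make unlikely tokens more probable.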

In [None]:
def encode_text(s):
    return [stoi[t] for t in s.lower().split() if t in stoi]

def decode(ids):
    return " ".join(itos[int(i)] for i in ids)

@torch.no_grad()
def generate_greedy(start_tokens, max_new_tokens=10):
    model.eval()
    idx = torch.tensor(start_tokens, dtype=torch.long, device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits, _ = model(idx[:, -context_size:])
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        idx = torch.cat([idx, next_id], dim=1)
    return idx.squeeze(0).tolist()

@torch.no_grad()
def generate_temperature(start_tokens, max_new_tokens=10, temperature=1.0):
    model.eval()
    idx = torch.tensor(start_tokens, dtype=torch.long, device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits, _ = model(idx[:, -context_size:])
        logits = logits[:, -1, :] / max(temperature, 1e-6)
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx.squeeze(0).tolist()

@torch.no_grad()
def generate_top_k(start_tokens, max_new_tokens=10, temperature=1.0, k=5):
    model.eval()
    idx = torch.tensor(start_tokens, dtype=torch.long, device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits, _ = model(idx[:, -context_size:])
        logits = logits[:, -1, :] / max(temperature, 1e-6)
        v, _ = torch.topk(logits, min(k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx.squeeze(0).tolist()

@torch.no_grad()
def generate_top_p(start_tokens, max_new_tokens=10, temperature=1.0, p=0.9):
    model.eval()
    idx = torch.tensor(start_tokens, dtype=torch.long, device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits, _ = model(idx[:, -context_size:])
        logits = logits[:, -1, :] / max(temperature, 1e-6)
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > p
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = False
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float('Inf')
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx.squeeze(0).tolist()


## 6. Comparative Generation Test

Let's see how each strategy behaves on the same input.
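
When reading the outputs, it can help to know which candidates each filter keeps. The sketch below applies top-k and top-p to a toy probability list in plain Python. The helper names and the probabilities are assumptions for illustration; the top-p rule follows the same convention as `generate_top_p`, keeping the token that pushes the cumulative mass past `p`.

```python
def top_k_filter(probs, k):
    """Keep the k most likely tokens and renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = order[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

def top_p_filter(probs, p):
    """Keep the most likely tokens until their cumulative mass exceeds p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative > p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}

toy_probs = [0.5, 0.25, 0.125, 0.0625, 0.0625]  # hypothetical distribution
kept_k = top_k_filter(toy_probs, k=2)
kept_p = top_p_filter(toy_probs, p=0.9)
print("top-k (k=2)  keeps tokens:", sorted(kept_k))
print("top-p (p=0.9) keeps tokens:", sorted(kept_p))
```

Top-k always keeps a fixed number of candidates, whereas top-p keeps fewer when the model is confident and more when the distribution is flat.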

In [None]:
start = encode_text("o gato")
print("Input:", decode(start))
print("-" * 30)
print("Greedy        :", decode(generate_greedy(start, max_new_tokens=8)))
print("Temperature 0.8:", decode(generate_temperature(start, max_new_tokens=8, temperature=0.8)))
print("Top-k (k=5)    :", decode(generate_top_k(start, max_new_tokens=8, temperature=1.0, k=5)))
print("Top-p (p=0.9)  :", decode(generate_top_p(start, max_new_tokens=8, temperature=1.0, p=0.9)))


## 7. Persistence: Checkpoints

Saving the model lets you reuse it in other notebooks (such as the fine-tuning ones).
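
Loading is the mirror image of saving: rebuild the architecture, then restore its weights. The sketch below shows the round trip with a small stand-in model so that it runs on its own. `TinyModel`, the `hparams` dictionary, and the file name are hypothetical and not part of this repository. Storing hyperparameters as a plain dictionary, rather than as a config object, is an assumption made here so that `torch.load(..., weights_only=True)` can read the file.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for GPTMini so the sketch is self-contained.
class TinyModel(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):
        return self.head(self.emb(idx))

hparams = {"vocab_size": 11, "d_model": 8}  # plain dict, not a config object
model_a = TinyModel(**hparams)

path = os.path.join(tempfile.mkdtemp(), "tiny_checkpoint.pt")
torch.save({"state_dict": model_a.state_dict(), "hparams": hparams}, path)

# weights_only=True restricts unpickling to tensors and basic containers.
loaded = torch.load(path, map_location="cpu", weights_only=True)
model_b = TinyModel(**loaded["hparams"])
model_b.load_state_dict(loaded["state_dict"])
model_b.eval()

x = torch.tensor([[0, 3, 5]])
with torch.no_grad():
    same = torch.equal(model_a(x), model_b(x))
print("Outputs match after reload:", same)
```

The same pattern should carry over to GPTMini: rebuild the configuration, construct the model, then call `load_state_dict` with the saved weights.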

In [None]:
ckpt = {
    "state_dict": model.state_dict(),
    "stoi": stoi,
    "itos": itos,
    "config": config
}
torch.save(ckpt, "gpt_checkpoint.pt")
print("✅ Checkpoint saved successfully!")


## 8. Conclusion

You have just taught a GPT to learn language.

You saw:
- How to compute cross entropy over sequences
- How the iterative training loop works
- How to monitor learning with loss plots
- How to control the creativity of generation with Temperature and Nucleus Sampling

In the next chapter, we will take this model to the 