# NanoChat GPT: Training a Small Chat Model

In this notebook you will:
- Inspect NanoChat's GPT architecture
- Run a forward pass
- Compute a loss-masked next-token prediction loss (only the positions flagged by the tokenizer's loss mask are scored)
- Train a tiny chat model for a few steps

In [34]:
import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer

# Shared tokenizer: later cells use it for the vocabulary size, for rendering
# the chat into token ids, and for decoding generated tokens.
tokenizer = get_tokenizer()

In [35]:
# Fix PyTorch's global RNG seed before the model is built so that any random
# weight initialisation is the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# A deliberately small configuration: 4 layers, 4 heads (n_kv_head set equal
# to n_head) and an embedding width of 256. The vocabulary size is read from
# the tokenizer so the model and the tokenizer always agree on it.
config = GPTConfig(
    vocab_size=tokenizer.get_vocab_size(),
    n_layer=4,
    n_head=4,
    n_kv_head=4,
    n_embd=256,
)
model = GPT(config)

In [36]:
# Total number of scalar parameters across every tensor in the model,
# reported in millions and rounded to two decimals.
n_params = sum(p.numel() for p in model.parameters())
n_params_millions = round(n_params /1e6, 2)
print(f"Nanochat has almost {n_params_millions} MM parameters")

Nanochat has almost 36.7 MM parameters


<img src = "LLM_Size.png">

In [17]:
# A minimal two-turn chat: one user question followed by the assistant reply.
user_turn = {"role": "user", "content": "What is a transformer?"}
assistant_turn = {
    "role": "assistant",
    "content": "A transformer is a neural network based on attention.",
}
conversation = {"messages": [user_turn, assistant_turn]}

In [18]:
# Turn the chat into token ids plus a parallel loss mask, then report how many
# tokens there are and how many of them are flagged in the mask.
ids, loss_mask = tokenizer.render_conversation(conversation)

n_tokens = len(ids)
n_loss_tokens = sum(loss_mask)
print("Number of tokens:", n_tokens)
print("Loss tokens:", n_loss_tokens)

Number of tokens: 20
Loss tokens: 11


In [19]:
import torch

# Add a leading batch axis so the token list becomes a (1, T) tensor.
input_ids = torch.tensor(ids)[None, :]

# Forward pass without targets; the cell displays the shape of the result.
logits = model(input_ids)

logits.shape

torch.Size([1, 20, 65536])

In [20]:
# Next-token loss averaged over the positions flagged in loss_mask.
# The logits at position t are scored against the token at position t + 1,
# so the logits drop their last position while the targets and the mask drop
# their first.
#
# `logits` is deliberately left untouched: rebinding it to its own slice would
# make this cell fail on a second run, because the logits would shrink again
# while the targets would not.
shift_logits = logits[:, :-1, :]
targets = input_ids[:, 1:]

# Rebuilt from the raw list every time, so re-running the cell is safe.
loss_mask_t = torch.tensor(loss_mask).unsqueeze(0)[:, 1:]

log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
# Pick out the log-probability assigned to each actual next token.
target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

# Masked mean of the negative log-likelihood.
loss = -(target_log_probs * loss_mask_t).sum() / loss_mask_t.sum()
loss

tensor(11.4196, grad_fn=<DivBackward0>)

In [21]:
# Repeatedly fit the model to the single example conversation as an end-to-end
# check of the training step (forward, masked loss, backward, optimizer update).
NUM_STEPS = 100  # optimisation steps on the one example
LOG_EVERY = 10   # print the loss every this many steps

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Loop-invariant tensors, built once. The mask is derived from the raw
# `loss_mask` list instead of reusing `loss_mask_t`, so this cell does not
# depend on another cell having already shifted that tensor.
targets = input_ids[:, 1:]
mask = torch.tensor(loss_mask).unsqueeze(0)[:, 1:]

for step in range(NUM_STEPS):
    optimizer.zero_grad()
    logits = model(input_ids)

    # Position t is scored against token t + 1, so the last position has no
    # target and is dropped. `logits` itself keeps its full length.
    log_probs = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
    target_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    loss = -(target_log_probs * mask).sum() / mask.sum()

    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        print(f"step {step} | loss {loss.item():.4f}")

step 0 | loss 11.4196
step 10 | loss 3.2912
step 20 | loss 2.3380
step 30 | loss 1.7654
step 40 | loss 1.2236
step 50 | loss 0.7951
step 60 | loss 0.4974
step 70 | loss 0.3161
step 80 | loss 0.2135
step 90 | loss 0.1551


In [33]:
from nanochat.engine import Engine

# Stream a sampled reply from the model for the example conversation,
# printing each decoded piece as it arrives and collecting the full text.
engine = Engine(model, tokenizer)

print("Model output:")
print("-" * 40)

# Kept under its own name: rebinding `input_ids` here would replace the
# (1, T) tensor that the forward-pass and training cells feed to the model.
prompt_ids = tokenizer.render_for_completion(conversation)

output = ""
for token_ids, token_mask in engine.generate(
    prompt_ids,
    max_tokens=50,
    temperature=0.8,
    top_k=40,
):
    text_piece = tokenizer.decode(token_ids)
    # flush=True so each piece shows up immediately instead of being buffered.
    print(text_piece, end="", flush=True)
    output += text_piece

print("\n" + "-" * 40)

Model output:
----------------------------------------
A transformer is a neural network based on attention.<|assistant_end|>
----------------------------------------
