In [None]:
try:
    import torch, datasets, transformers  # noqa
except Exception as e:
    !pip -q install torch --index-url https://download.pytorch.org/whl/cu121
    !pip -q install datasets transformers
    import torch, datasets, transformers  # noqa

print("PyTorch:", torch.__version__)

  from .autonotebook import tqdm as notebook_tqdm


PyTorch: 2.8.0


In [1]:

# Core imports
import math, os, random
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

# Special tokens + vocab (byte-level + EOS/CLS)
EOS_ID = 256
CLS_ID = 257
VOCAB_SIZE = 258

class ByteTokenizer:
    def __init__(self):
        self.vocab_size = VOCAB_SIZE
    def encode(self, s: str, add_eos: bool = False):
        ids = list(s.encode('utf-8', errors='ignore'))
        if add_eos:
            ids.append(EOS_ID)
        return ids
    def decode(self, ids):
        byte_vals = [i for i in ids if 0 <= i <= 255]
        try:
            return bytes(byte_vals).decode('utf-8', errors='ignore')
        except Exception:
            return ""

@dataclass
class GPTConfig:
    vocab_size: int = VOCAB_SIZE
    block_size: int = 128
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 128
    dropout: float = 0.1


In [5]:

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        mask = torch.tril(torch.ones(config.block_size, config.block_size)).view(1,1,config.block_size,config.block_size)
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1,2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1,2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1,2)
        att = (q @ k.transpose(-2,-1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.mask[:,:,:T,:T]==0, float('-inf'))
        att = att.softmax(dim=-1)
        att = self.dropout(att)
        y = att @ v
        y = y.transpose(1,2).contiguous().view(B,T,-1)
        y = self.proj(y)
        y = self.dropout(y)
        return y

class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.n_embd, 4*config.n_embd)
        self.fc2 = nn.Linear(4*config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp  = MLP(config)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.clf_head = nn.Linear(config.n_embd, 2)  # binary sentiment
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None: nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, return_hidden=False):
        B, T = idx.size()
        assert T <= self.config.block_size
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if return_hidden:
            return logits, x
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits = self(idx_cond)[:, -1, :] / max(1e-6, temperature)
            if top_k is not None:
                v, ix = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('inf')
            probs = logits.softmax(dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx


In [6]:

def build_pretrain_corpus():
    passages = [
        "Once upon a time there was a tiny transformer model learning to predict words.",
        "Transfer learning lets a model learn broadly first and specialize later for tasks.",
        "Language models read text and try to guess the next token in the sequence.",
        "Pretraining followed by finetuning changed how NLP systems are built."
    ]
    return "\n".join(passages)

def sample_lm_batch(raw_ids, block_size, batch_size, device):
    x = torch.zeros((batch_size, block_size), dtype=torch.long, device=device)
    y = torch.zeros((batch_size, block_size), dtype=torch.long, device=device)
    for i in range(batch_size):
        start = random.randint(0, max(0, len(raw_ids) - block_size - 2))
        chunk = raw_ids[start : start + block_size + 1]
        x[i] = torch.tensor(chunk[:-1], dtype=torch.long, device=device)
        y[i] = torch.tensor(chunk[1:],  dtype=torch.long, device=device)
    return x, y

def pad_or_truncate(ids, block_size):
    ids = ids[:block_size]
    if len(ids) < block_size:
        ids = ids + [EOS_ID] * (block_size - len(ids))
    return ids


In [7]:

#@title Pre-train (toy corpus) { run: "auto" }
steps = 300 #@param {type:"slider", min:50, max:2000, step:50}
batch_size = 16 #@param {type:"slider", min:4, max:64, step:4}
lr = 3e-4 #@param {type:"number"}
n_layer = 4 #@param {type:"slider", min:2, max:8, step:1}
n_head  = 4 #@param {type:"slider", min:2, max:8, step:1}
n_embd  = 128 #@param {type:"slider", min:64, max:512, step:32}
block_size = 128 #@param {type:"slider", min:64, max:256, step:32}
dropout = 0.1 #@param {type:"number"}

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

tokenizer = ByteTokenizer()
config = GPTConfig(block_size=block_size, n_layer=n_layer, n_head=n_head, n_embd=n_embd, dropout=dropout)
model = GPT(config).to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

raw = ByteTokenizer().encode(build_pretrain_corpus(), add_eos=True) * 50
for step in range(1, steps+1):
    x, y = sample_lm_batch(raw, block_size, batch_size, device)
    logits = model(x)
    loss = F.cross_entropy(logits.view(-1, config.vocab_size), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if step % 50 == 0:
        print(f"step {step:4d} | lm_loss {loss.item():.4f}")

torch.save({"model": model.state_dict(), "config": config.__dict__}, "gpt1_tiny_pretrain.pt")
print("Saved: gpt1_tiny_pretrain.pt")


Device: cpu
step   50 | lm_loss 2.8102
step  100 | lm_loss 1.9693
step  150 | lm_loss 1.4381
step  200 | lm_loss 1.0412
step  250 | lm_loss 0.7282
step  300 | lm_loss 0.5440
Saved: gpt1_tiny_pretrain.pt


In [8]:

#@title Fine‑tune (choose toy/SST‑2/IMDB)
dataset_choice = "toy" #@param ["toy", "sst2", "imdb"]
epochs = 5 #@param {type:"slider", min:1, max:10, step:1}
batch_size = 16 #@param {type:"slider", min:4, max:64, step:4}
lr = 3e-4 #@param {type:"number"}

from datasets import load_dataset

# Load model
ckpt = torch.load("gpt1_tiny_pretrain.pt", map_location=device)
config = GPTConfig(**ckpt["config"])
model = GPT(config).to(device)
model.load_state_dict(ckpt["model"])
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
tok = ByteTokenizer()

def build_toy_sentiment():
    pos = ["I loved this movie so much!",
           "What a fantastic, uplifting experience.",
           "Absolutely wonderful and inspiring.",
           "Great cast, smart writing, had a blast."]
    neg = ["I hated every minute of it.",
           "This was boring and predictable.",
           "Terrible script and weak acting.",
           "A disappointing and messy film."]
    texts = pos + neg
    labels = [1]*len(pos) + [0]*len(neg)
    return texts, labels

def to_batches(texts, labels, bs):
    for i in range(0, len(texts), bs):
        yield texts[i:i+bs], labels[i:i+bs]

def build_cls_batch(texts, labels):
    idxs, ys = [], []
    for t, y in zip(texts, labels):
        ids = [CLS_ID] + tok.encode(t, add_eos=True)
        ids = ids[:config.block_size] + [EOS_ID] * max(0, config.block_size - len(ids))
        idxs.append(ids); ys.append(y)
    x = torch.tensor(idxs, dtype=torch.long, device=device)
    y = torch.tensor(ys,   dtype=torch.long, device=device)
    return x, y

# Prepare data
if dataset_choice == "toy":
    train_texts, train_labels = build_toy_sentiment()
elif dataset_choice == "sst2":
    ds = load_dataset("glue", "sst2")
    train_texts = [ex["sentence"] for ex in ds["train"]]
    train_labels = [ex["label"]   for ex in ds["train"]]
elif dataset_choice == "imdb":
    ds = load_dataset("imdb")
    train_texts = [ex["text"] for ex in ds["train"]]
    train_labels = [1 if ex["label"]==1 else 0 for ex in ds["train"]]
else:
    raise ValueError("Unknown dataset_choice")

# Simple training
for epoch in range(1, epochs+1):
    # shuffle
    perm = torch.randperm(len(train_texts))
    train_texts = [train_texts[i] for i in perm]
    train_labels = [train_labels[i] for i in perm]
    total, n = 0.0, 0
    for bt, bl in to_batches(train_texts, train_labels, batch_size):
        x, y = build_cls_batch(bt, bl)
        logits_lm, hidden = model(x, return_hidden=True)
        cls_h = hidden[:, 0, :]
        logits = model.clf_head(cls_h)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad(); loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total += loss.item(); n += 1
    print(f"epoch {epoch:2d} | clf_loss {total/max(1,n):.4f}")

torch.save({"model": model.state_dict(), "config": config.__dict__}, "gpt1_tiny_finetune.pt")
print("Saved: gpt1_tiny_finetune.pt")


  from .autonotebook import tqdm as notebook_tqdm


epoch  1 | clf_loss 0.6875
epoch  2 | clf_loss 0.6737
epoch  3 | clf_loss 0.7239
epoch  4 | clf_loss 0.6471
epoch  5 | clf_loss 0.7394
Saved: gpt1_tiny_finetune.pt


In [9]:

#@title Generate from a prompt
prompt = "Once upon a time" #@param {type:"string"}
max_new_tokens = 80 #@param {type:"slider", min:10, max:200, step:10}
temperature = 1.0 #@param {type:"number"}
top_k = 50 #@param {type:"number"}

# load any checkpoint
path = "gpt1_tiny_finetune.pt" if os.path.exists("gpt1_tiny_finetune.pt") else "gpt1_tiny_pretrain.pt"
ckpt = torch.load(path, map_location=device)
config = GPTConfig(**ckpt["config"])
model = GPT(config).to(device)
model.load_state_dict(ckpt["model"])
model.eval()

tok = ByteTokenizer()
ids = tok.encode(prompt, add_eos=False)
if not ids: ids = [EOS_ID]
idx = torch.tensor([ids[-config.block_size:]], dtype=torch.long, device=device)
out = model.generate(idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
print(tok.decode(out[0].tolist()))


Once upon a time the s tdan  therext se token tin pran pre wordg fera predict xt worasks.
Transk
