# Fine-tuning GPT-2

In [None]:
import os, time, math, pickle, random
import numpy as np, pandas as pd
from contextlib import nullcontext
import torch
from model import GPT

In [None]:
# Which dataset to train on; files are read from DATA_DIR as "<task>_train.bin"
# and "<task>_val.bin".
task = "commongen"
classification_task = task.find("classification") >= 0

# True: start from GPT.from_pretrained("gpt2"); False: resume from the
# checkpoint at MODEL_DIR + IN_CHECKPOINT.
from_scratch = True

# Number of prompt positions prepended to every input sequence (see get_batch).
prompt_vocab_size = 20

# Training-loop switches.
eval_only = False  # stop at the first pass through the eval check; never train or save
always_save_checkpoint = False  # save at every eval, not only when val loss improves

In [None]:
# Input files live under DATA_DIR; checkpoints are read from and written to
# MODEL_DIR. Resuming and saving use the same file name by default.
DATA_DIR, MODEL_DIR = "data/", "best_models/"
IN_CHECKPOINT = OUT_CHECKPOINT = "gpt.pt"

In [None]:
def _load_split(split):
    """Load ``DATA_DIR + task + "_<split>.bin"`` as ``[[concept, ...], scene]`` pairs.

    Each line holds ``@``-separated fields. The first field is a ``#``-separated
    list of concepts and the second is the target sentence. Surrounding
    whitespace is stripped from every concept and from the sentence.
    """
    rows = pd.read_csv(
        DATA_DIR + task + "_" + split + ".bin",
        header=None,
        sep="@",
        dtype=str,
        # Without this, a field reading exactly "NA", "null", "None", "nan", ...
        # is parsed as float NaN and the .split()/.strip() below raise.
        keep_default_na=False,
    ).values.tolist()
    return [
        [[concept.strip() for concept in entry[0].split("#")], entry[1].strip()]
        for entry in rows
    ]


print("loading dataset for task:", task)
train_data = _load_split("train")
val_data = _load_split("val")

In [None]:
# Prefer CUDA when available, otherwise fall back to the CPU. Apple "mps" is not
# selected here; the setup code that follows only special-cases "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device =", device)

In [None]:
# Mixed precision / compilation setup. Only CUDA gets TF32, autocast and
# torch.compile; every other device runs in full precision with a no-op context.
# NOTE: `compile` shadows the builtin; it is kept because later cells read it.
compile = device == "cuda"
if not compile:
    ctx = nullcontext()
    scaler = torch.cuda.amp.GradScaler(enabled=False)
else:
    torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for matmuls
    torch.backends.cudnn.allow_tf32 = True  # TF32 for cuDNN kernels
    use_bf16 = torch.cuda.is_bf16_supported()
    # bfloat16 keeps float32's exponent range, so gradient scaling is only
    # enabled for the float16 fallback.
    ctx = torch.amp.autocast(
        device_type=device, dtype=torch.bfloat16 if use_bf16 else torch.float16
    )
    scaler = torch.cuda.amp.GradScaler(enabled=not use_bf16)

In [None]:
import tiktoken

gpt2 = tiktoken.get_encoding("gpt2")

# Token ids. 50256 is GPT-2's end-of-text id; the four ids right after it are
# new special tokens layered on top of the GPT-2 vocabulary.
end_text_token = 50256
start_input_token, end_input_token, concept_delimiter_token, pad_token = range(
    50257, 50261
)

_added_special_tokens = {
    "<|start_of_input|>": start_input_token,
    "<|end_of_input|>": end_input_token,
    "<|concept_delimiter|>": concept_delimiter_token,
    "<|padding|>": pad_token,
}
# GPT-2's BPE merges and split pattern, with GPT-2's own special tokens plus
# the four added above.
enc = tiktoken.Encoding(
    name="gpt_modified",
    pat_str=gpt2._pat_str,
    mergeable_ranks=gpt2._mergeable_ranks,
    special_tokens={**gpt2._special_tokens, **_added_special_tokens},
)

In [None]:
if from_scratch:
    print("building model from scratch")
    # Start from pretrained GPT-2 weights, shrink the context window to 128 and
    # add embeddings for the 4 extra special tokens of the tokenizer.
    model = GPT.from_pretrained(
        "gpt2", dict(dropout=0.2, prompt_vocab_size=prompt_vocab_size)
    )
    model.crop_block_size(128)
    model.extend_vocab(n_added_tokens=4, pad_token=pad_token)
    config = model.config
    print(config)
    iter_num, best_val_loss = 0, 1e9
else:
    print("loading model from checkpoint")
    checkpoint = torch.load(MODEL_DIR + IN_CHECKPOINT, map_location=device)
    config = checkpoint["config"]
    model = GPT(config)
    state_dict = checkpoint["model"]
    # A state_dict saved from a torch.compile'd model has every key prefixed
    # with "_orig_mod."; strip it (in place) so the plain model accepts it.
    compiled_prefix = "_orig_mod."
    for key in [k for k in state_dict if k.startswith(compiled_prefix)]:
        state_dict[key[len(compiled_prefix) :]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    iter_num = checkpoint["iter_num"]
    best_val_loss = checkpoint["best_val_loss"]
model = model.to(device)
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

In [None]:
# Display the model's module tree as the notebook cell output.
model

In [None]:
print("--- learnable parameters ---")
# One line per parameter whose requires_grad flag is set.
for name in (n for n, p in model.named_parameters() if p.requires_grad):
    print(name)

In [None]:
# Each optimizer step consumes batch_size sequences per micro-step, repeated
# gradient_accumulation_steps times; block_size counts prompt + input positions.
batch_size = 8
gradient_accumulation_steps = 2
tokens_per_iter = batch_size * gradient_accumulation_steps * config["block_size"]
print(f"tokens per iteration will be: {tokens_per_iter:,}")

In [None]:
# Optimizer / learning-rate schedule hyper-parameters.
learning_rate, min_lr = 1e-3, 1e-5  # peak LR and the floor used by get_lr
decay_lr = False  # False: constant learning_rate; True: warmup + cosine via get_lr
warmup_iters, lr_decay_iters, max_iters = 200, 2500, 5000
weight_decay = 1e-2
beta1, beta2 = 0.9, 0.99  # passed as a tuple to model.configure_optimizers
grad_clip = 10.0  # max gradient norm for clipping; 0.0 disables clipping

In [None]:
betas = (beta1, beta2)
optimizer = model.configure_optimizers(weight_decay, learning_rate, betas)
# When resuming, restore the optimizer state that was saved with the checkpoint.
if not from_scratch:
    print("loading optimizer from checkpoint")
    optimizer.load_state_dict(checkpoint["optimizer"])

In [None]:
# Evaluation / logging cadence, measured in optimizer iterations.
eval_interval, eval_iters = 100, 50  # evaluate every 100 iters, 50 batches per split
do_log, log_interval = False, 1  # per-iteration loss/time printing

In [None]:
# Peek at the first 10 parsed examples, each [[concept, ...], scene].
train_data[:10]

In [None]:
def _encode_example(concepts, scene):
    """Tokenize one example and return ``(input_ids, target_ids, target_start)``.

    input_ids:    [<|start_of_input|>, c1, <|concept_delimiter|>, c2, ...,
                   <|end_of_input|>, scene tokens..., end_text_token]
    target_ids:   scene tokens followed by end_text_token.
    target_start: index into the full (prompts + input) sequence whose target is
                  target_ids[0], i.e. the position of <|end_of_input|>.
    """
    concepts_encoded = enc.encode(
        "<|concept_delimiter|>".join(concepts),
        allowed_special={"<|concept_delimiter|>"},
    )
    target_ids = enc.encode_ordinary(scene) + [end_text_token]
    input_ids = [start_input_token] + concepts_encoded + [end_input_token] + target_ids
    target_start = prompt_vocab_size + 1 + len(concepts_encoded)
    return input_ids, target_ids, target_start


def get_batch(split, batch_index=None):
    """Assemble one batch ``(x, y, prompts)`` from ``split``.

    Args:
        split: "train" reads ``train_data``; any other value reads ``val_data``.
        batch_index: if given, take the contiguous slice number
            ``batch_index % (len(data) // batch_size)`` (deterministic);
            otherwise sample ``batch_size`` example indices with replacement.

    Returns:
        x: (batch_size, block_size - prompt_vocab_size) long tensor of input ids,
           right-padded with ``pad_token``.
        y: (batch_size, block_size) long tensor of targets aligned with the full
           prompts + input sequence. Only the scene tokens and the final
           end-of-text token are targets; every other position holds
           ``pad_token`` (presumably ignored by the model's loss; confirm in
           model.GPT).
        prompts: (batch_size, prompt_vocab_size) tensor whose rows are
           ``0 .. prompt_vocab_size-1``, or None when ``prompt_vocab_size`` is 0.

    Concepts are shuffled per example on every call, on a copy, so the stored
    dataset is never reordered. Examples longer than the input window are
    truncated rather than raising a shape-mismatch error.

    Raises:
        ValueError: ``batch_index`` is given but the split has fewer than
            ``batch_size`` examples.
    """
    data = train_data if split == "train" else val_data
    if batch_index is not None:
        max_batches = len(data) // batch_size
        if max_batches == 0:
            raise ValueError(
                f"{split} split has {len(data)} examples, "
                f"fewer than batch_size={batch_size}"
            )
        start = (batch_index % max_batches) * batch_size
        ix = range(start, start + batch_size)
    else:
        ix = torch.randint(len(data), (batch_size,)).tolist()

    block_size = config["block_size"]
    input_len = block_size - prompt_vocab_size  # prompts fill the first positions
    x = torch.full((batch_size, input_len), pad_token, dtype=torch.long)
    y = torch.full((batch_size, block_size), pad_token, dtype=torch.long)
    for row, index in enumerate(ix):
        concepts = list(data[index][0])  # copy: don't reorder the stored dataset
        random.shuffle(concepts)
        input_ids, target_ids, target_start = _encode_example(concepts, data[index][1])
        input_ids = input_ids[:input_len]
        x[row, : len(input_ids)] = torch.tensor(input_ids, dtype=torch.long)
        # Keep only targets whose positions fit inside the block.
        target_end = min(target_start + len(target_ids), block_size)
        if target_start < target_end:
            y[row, target_start:target_end] = torch.tensor(
                target_ids[: target_end - target_start], dtype=torch.long
            )
    x, y = x.to(device), y.to(device)
    prompts = None
    if prompt_vocab_size > 0:
        prompts = torch.arange(prompt_vocab_size).repeat(batch_size, 1).to(device)
    return x, y, prompts

In [None]:
x, y, prompts = get_batch("train")
# Inspect the second example of the batch. First print the targets at the
# prompt positions, then each input token next to the target at its position.
ip, op = x[1].tolist(), y[1].tolist()
for i in range(prompt_vocab_size):
    print("prompt", i, "-", enc.decode([op[i]]))
for in_tok, out_tok in zip(ip, op[prompt_vocab_size:]):
    print(enc.decode([in_tok]), "-", enc.decode([out_tok]))

In [None]:
@torch.no_grad()
def estimate_loss():
    """Return the mean loss over ``eval_iters`` batches for each split.

    Validation batches are the fixed contiguous slices ``batch_index=0..`` so
    successive evaluations see the same examples (concept order is still
    shuffled by get_batch). Training batches are sampled at random. The model
    is switched to eval mode for the measurement and back to train mode after.

    Returns:
        dict mapping "train" and "val" to a 0-d float tensor.
    """
    model.eval()
    out = {}
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            batch_index = step if split == "val" else None
            xb, yb, prompts_b = get_batch(split, batch_index=batch_index)
            with ctx:
                _, loss = model(xb, yb, prompts=prompts_b)
            losses[step] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


def get_lr(it, *, base_lr=None, warmup=None, decay_iters=None, floor_lr=None):
    """Return the learning rate for iteration ``it``: linear warmup, then cosine decay.

    Schedule:
      * ``it < warmup``: ``base_lr * it / warmup`` (0 at ``it == 0``).
      * ``warmup <= it <= decay_iters``: cosine from ``base_lr`` down to ``floor_lr``.
      * ``it > decay_iters``: ``floor_lr``.

    The keyword-only overrides default to the module-level ``learning_rate``,
    ``warmup_iters``, ``lr_decay_iters`` and ``min_lr``.

    If ``decay_iters <= warmup`` there is no decay window. In that case
    ``floor_lr`` is returned once warmup has finished, instead of dividing by zero.
    """
    base_lr = learning_rate if base_lr is None else base_lr
    warmup = warmup_iters if warmup is None else warmup
    decay_iters = lr_decay_iters if decay_iters is None else decay_iters
    floor_lr = min_lr if floor_lr is None else floor_lr

    # 1) linear warmup
    if it < warmup:
        return base_lr * it / warmup
    # 2) past the decay window (or no window at all): hold the floor
    if it > decay_iters or decay_iters <= warmup:
        return floor_lr
    # 3) cosine decay; decay_ratio lies in [0, 1] because warmup <= it <= decay_iters
    decay_ratio = (it - warmup) / (decay_iters - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # ranges 1 -> 0
    return floor_lr + coeff * (base_lr - floor_lr)

In [None]:
X, Y, Prompts = get_batch("train")  # fetch the very first batch
t0 = time.time()
while True:
    # Learning rate for this iteration (constant unless decay_lr is set).
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # Periodically evaluate both splits and write checkpoints.
    if iter_num % eval_interval == 0:
        losses = estimate_loss()
        print(
            f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
        )
        improved = losses["val"] < best_val_loss
        if not eval_only and (improved or always_save_checkpoint):
            # Only record genuine improvements, so best_val_loss stays the best
            # value even when always_save_checkpoint forces a save at every eval.
            if improved:
                best_val_loss = losses["val"]
            if iter_num > 0:
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "iter_num": iter_num,
                    "best_val_loss": best_val_loss,
                    "config": config,
                }
                print(f"saving checkpoint to {MODEL_DIR+OUT_CHECKPOINT}")
                torch.save(checkpoint, MODEL_DIR + OUT_CHECKPOINT)
    if eval_only:
        break

    # Gradient accumulation over several micro-batches. The next batch is built
    # between forward and backward, presumably to overlap host-side batch
    # preparation with asynchronous GPU work.
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y, prompts=Prompts)
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps  # average over micro-steps
        X, Y, Prompts = get_batch("train")
        scaler.scale(loss).backward()
    # Clip on the unscaled gradients.
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # Timing and logging.
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if do_log and iter_num % log_interval == 0:
        lossf = loss.item() * gradient_accumulation_steps  # undo micro-step averaging
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
    iter_num += 1

    # Termination: the step with iter_num == max_iters still trains, then stop.
    if iter_num > max_iters:
        break

In [None]:
# Best validation loss, rounded to 2 decimals. float() accepts both a 0-d tensor
# (after an improving eval, or loaded from a checkpoint) and the initial plain
# float 1e9 (e.g. eval_only=True from scratch), where .item() would raise.
round(float(best_val_loss), 2)