# Constants and Setup

In [1]:
# Locations: `path` is this experiment's directory (outputs, checkpoints, logs);
# `root` is the repository root (data/ and shared modules).
path = "./"
root = "../"

SEED = 23

# --- Optimisation and schedule ---------------------------------------------
LR = 1e-3
BATCH_SIZE = 16
SEQ_LEN = 128
MAX_ITERS = 50_000   # max num batches to train
PRINT_ITERS = 50     # every this many steps: print train loss
EVAL_ITERS = 500     # every this many steps: estimate val loss and sample text
# Batches averaged per val-loss estimate. The 10% val split holds 111,540
# chars, and 100 batches x 16 sequences x 128 positions = 204,800 sampled
# positions, which is roughly 2x the size of the split.
EVAL_ITER_COUNT = 100
SAVE_ITERS = 1_000   # every this many steps: save model and losses

# --- Architecture ----------------------------------------------------------
N_EMBD = 128
N_FF = 4 * N_EMBD    # feed-forward hidden width
N_HEAD = 4
N_LAYER = 4

MODEL_NAME = (
    f"vt_{N_LAYER}_LAYERs_{N_HEAD}_HEAD_{N_EMBD}_EMBD_DIM_{SEQ_LEN}_SEQ_LEN"
)
print("Model Name:", MODEL_NAME)

Model Name: vt_4_LAYERs_4_HEAD_128_EMBD_DIM_128_SEQ_LEN


# Imports

In [2]:
import json
import os
import random
import re
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

In [3]:
import sys

sys.path.append(root)

from utils import set_seed, device
from models.transformer import Transformer

In [4]:
# Seed RNGs before any stochastic work (set_seed is the project helper imported from ../utils).
set_seed(SEED)

# Data

In [5]:
# Read the raw corpus. Decode explicitly as UTF-8 so the result does not depend
# on the platform's locale default encoding.
with open(f"{root}/data/tiny-shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: every distinct character, sorted so that the
# character -> id assignment is deterministic across runs.
chars = sorted(set(text))
VOCAB_SIZE = len(chars)
print(f"Vocab: {chars}")
print(f"Vocab size: {VOCAB_SIZE}")

Vocab: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocab size: 65


In [6]:
# Prepare mappings / tokenizer
# create a mapping from characters to integers
# Character-level tokenizer: lookup tables between characters and integer ids,
# built from the sorted vocabulary `chars`.
txt2idx = {ch: i for i, ch in enumerate(chars)}
idx2txt = dict(enumerate(chars))


def encode(s):
    """Convert a string to a list of integer token ids.

    Raises KeyError if ``s`` contains a character that is not in the vocabulary.
    """
    return [txt2idx[c] for c in s]


def decode(l):
    """Convert an iterable of integer token ids back into a string."""
    return "".join(idx2txt[i] for i in l)


print(encode("tiny-shakespeare is sick"))
print(decode(encode("tiny-shakespeare is sick")))

[58, 47, 52, 63, 7, 57, 46, 39, 49, 43, 57, 54, 43, 39, 56, 43, 1, 47, 57, 1, 57, 47, 41, 49]
tiny-shakespeare is sick


In [7]:
# tokenizer data
# Encode the whole corpus once and split it 90/10 into train/val token streams.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # 90-10 split
train_data = data[:n]
val_data = data[n:]
print("train_data len:", len(train_data), "val_data len:", len(val_data))


def get_batch(split, batch_size=None, seq_len=None):
    """Sample a random batch of input windows and their next-token targets.

    Args:
        split: "train" samples from ``train_data``; any other value samples
            from ``val_data``.
        batch_size: number of windows; defaults to ``BATCH_SIZE``.
        seq_len: window length; defaults to ``SEQ_LEN``.

    Returns:
        ``(x, y)``: long tensors of shape ``(batch_size, seq_len)`` moved to
        ``device``. ``y[b, t]`` is the token that follows ``x[b, t]`` in the
        source stream.

    Raises:
        ValueError: if the selected split is too short to hold one window plus
            its final target token.
    """
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    seq_len = SEQ_LEN if seq_len is None else seq_len
    split_data = train_data if split == "train" else val_data

    # A window needs seq_len inputs plus one more token for its last target, so
    # valid start offsets are 0 .. len - seq_len - 1 (randint's bound is exclusive).
    max_start = len(split_data) - seq_len
    if max_start <= 0:
        raise ValueError(
            f"split {split!r} has {len(split_data)} tokens; need at least {seq_len + 1}"
        )
    ix = torch.randint(max_start, (batch_size,))
    x = torch.stack([split_data[i : i + seq_len] for i in ix])
    y = torch.stack([split_data[i + 1 : i + seq_len + 1] for i in ix])
    return x.to(device), y.to(device)

train_data len: 1003854 val_data len: 111540


# Training

In [8]:
# Re-seed right before construction so that weight initialisation is
# reproducible no matter how many random numbers earlier cells consumed.
set_seed(SEED)

model = Transformer(
    VOCAB_SIZE, SEQ_LEN, N_EMBD, N_HEAD, N_FF, N_LAYER,  # positional config; order must match models.transformer.Transformer
    device=device,
    switch=False,  # project-specific option; see models/transformer.py
    mlp_dropout=0.1,
)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)

In [9]:
# Per-layer parameter counts via torchinfo (no input size is passed, so no activation shapes are shown).
summary(model)

Layer (type:depth-idx)                   Param #
Transformer                              --
├─Embedding: 1-1                         8,320
├─PositionalEncoding: 1-2                --
├─Sequential: 1-3                        --
│    └─Block: 2-1                        --
│    │    └─MultiHeadAttention: 3-1      65,536
│    │    └─MLP: 3-2                     131,712
│    │    └─LayerNorm: 3-3               256
│    │    └─LayerNorm: 3-4               256
│    │    └─Dropout: 3-5                 --
│    │    └─Dropout: 3-6                 --
│    └─Block: 2-2                        --
│    │    └─MultiHeadAttention: 3-7      65,536
│    │    └─MLP: 3-8                     131,712
│    │    └─LayerNorm: 3-9               256
│    │    └─LayerNorm: 3-10              256
│    │    └─Dropout: 3-11                --
│    │    └─Dropout: 3-12                --
│    └─Block: 2-3                        --
│    │    └─MultiHeadAttention: 3-13     65,536
│    │    └─MLP: 3-14                    1

In [10]:
def calc_loss(logits, targets):
    """Mean token-level cross-entropy between ``logits`` and ``targets``.

    Args:
        logits: float tensor of shape (B, S, C): unnormalised scores over C
            classes for each of the B*S positions.
        targets: integer class ids with B*S elements (e.g. shape (B, S)).

    Returns:
        Scalar tensor: the cross-entropy averaged over all B*S positions.
    """
    B, S, C = logits.shape
    # reshape (not view) so non-contiguous inputs, e.g. after a transpose, also
    # work; it is still zero-copy whenever the memory layout allows it.
    loss = F.cross_entropy(logits.reshape(B * S, C), targets.reshape(B * S))
    return loss

In [11]:
def _save_checkpoint(model, optimizer, step, val_losses, train_times):
    """Write a checkpoint for ``step`` and overwrite the val-loss / throughput JSON logs."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        f"{path}/checkpoints/{MODEL_NAME}_STEP_{step}_SEED_{SEED}.pt",
    )

    with open(
        f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_val_losses.json", "w"
    ) as f:
        json.dump(val_losses, f)

    with open(
        f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_train_times.json", "w"
    ) as f2:
        # Keep every EVAL_ITERS-th entry so this log matches the frequency of
        # val_losses. Note: if these saved values are passed back in as
        # `train_time_list` to continue training, the running average printed
        # during training is computed over this sparser history.
        json.dump(train_times[EVAL_ITERS::EVAL_ITERS], f2)


def train(
    model,
    optimizer,
    device,
    train_loss_list=None,
    val_loss_list=None,
    train_time_list=None,
):
    """Run the training loop for ``MAX_ITERS`` steps.

    Each step samples a batch with ``get_batch("train")``, computes
    ``calc_loss`` and takes one optimizer step. Periodically (never at step 0):

    * every ``PRINT_ITERS`` steps: print the running-average train loss, the
      current global gradient L2 norm and the running-average tokens/sec;
    * every ``EVAL_ITERS`` steps: run ``estimate_loss`` and ``generate``
      (samples are appended to a file under ``{path}/outputs/``), then switch
      the model back to train mode;
    * every ``SAVE_ITERS`` steps, and once more after the final step: save the
      model and optimizer state under ``{path}/checkpoints/`` and the val-loss
      and throughput logs under ``{path}/train_logs/``.

    Args:
        model: model to train; ``model(inputs)`` must return logits accepted
            by ``calc_loss``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the model is moved to before training.
        train_loss_list, val_loss_list, train_time_list: optional lists to
            append to (e.g. to continue earlier logs). Fresh lists are used when
            None. Lists that are passed in are mutated in place.
    """
    train_losses = train_loss_list if train_loss_list is not None else []
    val_losses = val_loss_list if val_loss_list is not None else []
    train_times = train_time_list if train_time_list is not None else []

    model.train()
    model.to(device)

    # Create the output directories up front; otherwise the first eval/save
    # step would crash with FileNotFoundError.
    for subdir in ("outputs", "checkpoints", "train_logs"):
        os.makedirs(f"{path}/{subdir}", exist_ok=True)

    # Set up prompt generation
    generation_file_path = f"{path}/outputs/OUTPUT_{MODEL_NAME}_SEED_{SEED}.txt"
    empty_tokens = torch.zeros((1, 1), dtype=torch.long).to(device)
    cond_prompts = ["KING TERRY: Thou art", "DANIEL: Ay, my dear,"]
    cond_token_list = [encode(prompt) for prompt in cond_prompts]

    last_saved_step = None
    for step in range(MAX_ITERS):
        start = time.perf_counter()

        optimizer.zero_grad(set_to_none=True)

        inputs, targets = get_batch("train")
        logits = model(inputs)
        loss = calc_loss(logits, targets)
        train_losses.append(loss.item())

        loss.backward()

        # Global L2 gradient norm, for monitoring only. The norm of the
        # per-parameter norms equals norm(cat(all grads)) without building one
        # large concatenated tensor every step.
        grad_norms = [
            param.grad.detach().norm(2)
            for param in model.parameters()
            if param.grad is not None
        ]
        grad_norm = torch.stack(grad_norms).norm(2) if grad_norms else None

        optimizer.step()

        # Converting to a Python float waits for the device work queued before
        # it on the current stream (backward and optimizer step). The timing
        # below therefore covers the whole update even on async backends such
        # as CUDA.
        norm = grad_norm.item() if grad_norm is not None else 0.0
        train_time = time.perf_counter() - start
        tokens_per_sec = BATCH_SIZE * SEQ_LEN / train_time
        train_times.append(tokens_per_sec)

        # print training statistics
        if step % PRINT_ITERS == 0 and step != 0:
            print(
                f"Step {step}/{MAX_ITERS} | Running Avg Train Loss: {np.mean(train_losses):.5f} |",
                f"Grad Norm: {norm:.3f} | Running Avg Tokens/Sec: {np.mean(train_times):.3f}",
            )

        # estimate val loss, generate text, then return to training mode
        if step % EVAL_ITERS == 0 and step != 0:
            val_losses = estimate_loss(model, val_losses)
            generate(model, generation_file_path, empty_tokens, cond_token_list, step)
            model.train()

        # save model, val losses (not train_losses), train times
        if step % SAVE_ITERS == 0 and step != 0:
            _save_checkpoint(model, optimizer, step, val_losses, train_times)
            last_saved_step = step

    # The last loop index (MAX_ITERS - 1) is usually not a multiple of
    # SAVE_ITERS. Without this save, the final updates would never be persisted.
    final_step = MAX_ITERS - 1
    if final_step >= 0 and last_saved_step != final_step:
        _save_checkpoint(model, optimizer, final_step, val_losses, train_times)

In [12]:
@torch.no_grad()
def estimate_loss(model, val_losses):
    """Estimate the validation loss as the mean over ``EVAL_ITER_COUNT`` random val batches.

    Appends the estimate to ``val_losses`` in place, prints it, and returns the
    same list. The model is left in eval mode; callers that keep training must
    call ``model.train()`` afterwards.
    """
    model.eval()
    losses = torch.zeros(EVAL_ITER_COUNT)
    for k in range(EVAL_ITER_COUNT):
        # Request the validation split explicitly rather than relying on
        # get_batch mapping any non-"train" label to val_data.
        inputs, targets = get_batch("val")
        logits = model(inputs)
        losses[k] = calc_loss(logits, targets).item()
    val_loss = losses.mean().item()
    val_losses.append(val_loss)
    # keep model in eval; in train() the next call is generate() anyway
    print(f"Est. Val Loss: {val_loss:.5f}")
    return val_losses

In [13]:
@torch.no_grad()
def generate(model, generation_file_path, empty_tokens, cond_token_list, step, seed=42):
    """Sample text from ``model``, print it, and append it to ``generation_file_path``.

    Produces two unconditional samples from ``empty_tokens`` (top-k with k=5,
    and nucleus with p=0.5), plus one top-k (k=5) continuation per prompt in
    ``cond_token_list``. Each sample requests 500 new tokens.

    Args:
        model: model exposing ``generate(tokens, method=..., max_new_tokens=...)``.
        generation_file_path: text file that the formatted samples are appended to.
        empty_tokens: start context for unconditional sampling (train() passes a
            (1, 1) long tensor of zeros). Conditional prompts are placed on the
            same device.
        cond_token_list: prompts, each given as a list of token ids.
        step: training step; used only in the header line.
        seed: RNG seed used for sampling, so that samples are comparable across steps.
    """
    # Re-seeding makes samples comparable across steps, but the re-seed must not
    # leak into training. set_seed (presumably seeding python/numpy/torch
    # globally) would otherwise reset the training RNG after every evaluation,
    # so each EVAL_ITERS-step window could replay the same sequence of batches.
    # Snapshot the caller's RNG state and restore it afterwards.
    py_rng_state = random.getstate()
    np_rng_state = np.random.get_state()
    try:
        with torch.random.fork_rng():  # restores torch CPU (and CUDA) RNG state on exit
            set_seed(seed)
            model.eval()  # no dropout while sampling; train() calls model.train() afterwards

            uncond_res1 = decode(
                model.generate(
                    empty_tokens, method="top-k", k=5, max_new_tokens=500
                )[0].tolist()
            )
            uncond_res2 = decode(
                model.generate(
                    empty_tokens, method="nucleus", p_nucleus=0.5, max_new_tokens=500
                )[0].tolist()
            )

            cond_results = []
            for prompt in cond_token_list:
                prompt_tokens = torch.tensor(
                    prompt, dtype=torch.long, device=empty_tokens.device
                ).unsqueeze(0)
                cond_results.append(
                    decode(
                        model.generate(
                            prompt_tokens, method="top-k", k=5, max_new_tokens=500
                        )[0].tolist()
                    )
                )
    finally:
        random.setstate(py_rng_state)
        np.random.set_state(np_rng_state)

    cond_res_text = "\n\n".join(cond_results)

    generation_text = f"""{MODEL_NAME} Output, Step {step}
    UNCONDITIONAL GENERATION:

    Top-k (5) (500 max_tokens):
    {uncond_res1}

    Nucleus (0.5) (500 max_tokens):
    {uncond_res2}

    #####################################################
    CONDITIONAL GENERATION (Top-k (5), 500 max_tokens):
    {cond_res_text}
    -----------------------------------------------------
    """
    with open(generation_file_path, "a") as file:
        file.write(generation_text)
    print(generation_text)

In [25]:
## Driver code
# Launch the full training run; no log lists are passed, so train() starts fresh ones.
train(model=model, optimizer=optimizer, device=device)

Step 0/50000 | Running Avg Train Loss: 6.24023 | Grad Norm: 20.527 | Running Avg Tokens/Sec: 2993.337
Step 50/50000 | Running Avg Train Loss: 4.21355 | Grad Norm: 1.394 | Running Avg Tokens/Sec: 6741.455
Step 100/50000 | Running Avg Train Loss: 3.77272 | Grad Norm: 2.327 | Running Avg Tokens/Sec: 6705.369
Est. Val Loss: 3.10417
vt_4_LAYERs_4_HEAD_128_EMBD_DIM_128_SEQ_LEN Output, Step 100
    UNCONDITIONAL GENERATION:

    Top-k (5) (500 max_tokens):
    
  ae  t ae e ha t ses sthist sh toue hee th   h at s sh hes ho s he th a  as hi thon enes t t t ha ate h  seat s teatothe  tha hh ta h h s

e shes harot  a he t he s hhh  aresh  thoshe  sh s
h ho t hhhe hhea h sh s h s hhhe hoth  he he

 s
hah ha hos s ahh he s
heh  hhho hahe he  thha t he at   toth tha a hothhar se  h h h s a a th ho h t hh h sh she anhan h s tha  t  the he han

he h s h har a soh ha t ah ha   a ha th tos shor h arhe thohh  tho arorarohe


 hos


ho  thar h at
he thos hhe t


 

    Nucleus (0.5) (500 max_tokens):
   

KeyboardInterrupt: 

# Generation

*After 2250 steps at batch size 16 (training loss 1.8277):*

In [114]:
# Unconditional sample: start from a single token id 0 (the "\n" character in
# this vocabulary). The start tensor must live on the same device as the model.
model.eval()  # sampling should not apply dropout
with torch.no_grad():
    start_tokens = torch.zeros((1, 1), dtype=torch.long, device=device)
    sample_ids = model.generate(start_tokens, max_new_tokens=500)[0].tolist()
print(decode(sample_ids))



AUCENTIO:
Which the may sich that nough to the slay'd one so to;
His be knot, I Vistrengs, thy the good our kim in to call:
No, thou man I good, Say, for pmburds tell eack.

HESSend.

MENCIO:
Dever and will'd my Vaing, life to suke in lise,
These'll was of yret, fol smy his no Fear shom gestard:
Retil appoutis commentex e'epon tend her his him buse,
And what ityer am the iends, come; God foll ding:
by appeerk.

LOUCIO:
Petwild, bake you, that I same, what wear;
from in in or my speak as For Jul


In [115]:
# Conditional sample: continue a custom prompt. The prompt tensor must live on
# the same device as the model.
input_txt = "TERRY: thou art"
ctx = encode(input_txt)
model.eval()  # sampling should not apply dropout
with torch.no_grad():
    ctx_tokens = torch.tensor([ctx], dtype=torch.long, device=device)  # shape (1, len(ctx))
    sample_ids = model.generate(ctx_tokens, max_new_tokens=500)[0].tolist()
print(decode(sample_ids))

TERRY: thou arte a my she
Which have may and of contain.

DUKE VINCENTIO:
Good as I tall no knrow, for shalt agarnt
And mpo; and Kong a m, not outhpile Mesce.

HENRY VI:
When I will thy lookess, oner the pexstrey
The the the hee voagh gresed livioe.

MENCIO:
My her callis his peaced of to that
We where's by shall bore: as shall myselvea
The plender feuls!

PAPELLANT:
In the the into balby me dods to love,
In but the giving of nyou ase. I tall it-me?'e Goveuling
The theer haught art praver count madeng Camen:
T
