# Simple Transformer

In [1]:
import sys
import time
import torch
import torch.nn as nn
import numpy as np

from datetime import datetime
from torch.nn import functional as F

In [2]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7f64144f9670>

## Constants

In [3]:
batch_size = 24
block_size = 1024
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
loss_est_n_batch = 50

embedding_size = 384
num_heads = 6
head_size = embedding_size
n_layer = 6
dropout = 0.3

## Hardware

In [4]:
has_cuda = torch.cuda.is_available() and torch.backends.cuda.is_built()
has_mps = torch.backends.mps.is_available() and torch.backends.mps.is_built()
print(f'CUDA available: {has_cuda}')
if has_cuda:
    print(f'  * GPU count: {torch.cuda.device_count()}')
print(f'mps available: {has_mps}')

CUDA available: True
  * GPU count: 1
mps available: False


In [5]:
device = "cpu"
if has_cuda:
    device = "cuda"
elif has_mps:
    device = "mps"
torch.set_default_device(device)

In [6]:
device

'cuda'

In [7]:
should_compile = True
if sys.platform == "win32":
    should_compile = False
print(f"Should compile: {should_compile}")

Should compile: True


## Training Data

In [8]:
with open('data/tiny_shakespeare.txt', 'r', encoding='utf-8') as input:
    text = input.read()

In [9]:
print(f'Number of characters in text: {len(text)}')

Number of characters in text: 1115394


In [10]:
print(f'First 100 chars:\n\n```\n{text[:100]}\n```')

First 100 chars:

```
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
```


In [11]:
words = text.split(" ")

In [12]:
len(words)

169893

In [13]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [14]:
vocab_size

65

In [15]:
''.join(chars)

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [16]:
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

In [17]:
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

In [18]:
encoded_test = encode("hello world!")
encoded_test

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]

In [19]:
decode(encoded_test)

'hello world!'

In [20]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [21]:
train_data[:block_size + 1]

tensor([18, 47, 56,  ..., 63, 53, 59], device='cuda:0')

In [22]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

## Generate Batch Data

In [23]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [24]:
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb[:5])
print('targets:')
print(yb.shape)
print(yb[:5])

inputs:
torch.Size([24, 1024])
tensor([[17, 26, 17,  ...,  1, 40, 43],
        [43, 47, 64,  ..., 41, 53, 52],
        [52, 43,  0,  ..., 58, 11,  1],
        [53, 59, 50,  ..., 59, 57, 10],
        [53, 52, 43,  ...,  0, 20, 53]], device='cuda:0')
targets:
torch.Size([24, 1024])
tensor([[26, 17, 31,  ..., 40, 43, 58],
        [47, 64, 43,  ..., 53, 52, 42],
        [43,  0, 32,  ..., 11,  1, 41],
        [59, 50, 57,  ..., 57, 10,  0],
        [52, 43,  1,  ..., 20, 53, 61]], device='cuda:0')


In [25]:
for t in range(8):
    context = xb[0, :t+1]
    target = yb[0, t]
    print(f'input: {context}, target: {target}')

input: tensor([17], device='cuda:0'), target: 26
input: tensor([17, 26], device='cuda:0'), target: 17
input: tensor([17, 26, 17], device='cuda:0'), target: 31
input: tensor([17, 26, 17, 31], device='cuda:0'), target: 10
input: tensor([17, 26, 17, 31, 10], device='cuda:0'), target: 0
input: tensor([17, 26, 17, 31, 10,  0], device='cuda:0'), target: 13
input: tensor([17, 26, 17, 31, 10,  0, 13], device='cuda:0'), target: 1
input: tensor([17, 26, 17, 31, 10,  0, 13,  1], device='cuda:0'), target: 57


## Neural Net

In [26]:
scalar = None
if has_cuda:
    scalar = torch.cuda.amp.GradScaler()

In [27]:
@torch.no_grad()
def estimate_loss(model, eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            if scalar is None:
                _, loss = model(X, Y)
            else:
                with torch.cuda.amp.autocast():
                    _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

In [28]:
class Head(nn.Module):
    def __init__(self, n_embd: int, head_size: int):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Input: (B, T, C)
        Ouput: (B, T, head_size)
        """
        B, T, C = x.shape

        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        wei = q @ k.transpose(-2, -1) * (C**-0.5) # (B, T, head_size) @ (B, head_size, T) = (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        
        out = wei @ v # (B, T, T) @ (B, T, head_size) = (B, T, head_size)
        return out

In [29]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd: int, n_heads: int, head_size: int):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embd, head_size//n_heads) for _ in range(n_heads)])
        self.proj = nn.Linear(head_size, head_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Input: (B, T, C)
        Ouput: (B, T, head_size)
        """
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out

In [30]:
class FeedForward(nn.Module):
    def __init__(self, head_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(head_size, head_size * 4),
            nn.ReLU(),
            nn.Linear(head_size * 4, head_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Input: (B, T, head_size)
        Ouput: (B, T, head_size)
        """
        return self.net(x)

In [31]:
class Block(nn.Module):
    def __init__(self, n_embd: int, n_head: int, head_size: int):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, head_size)
        self.ffwd = FeedForward(head_size)
        self.ln1 = nn.LayerNorm(head_size)
        self.ln2 = nn.LayerNorm(head_size)

    def forward(self, x):
        """
        Input: (B, T, C)
        Ouput: (B, T, head_size)
        """
        x = self.ln1(x)
        x = x + self.sa(x) # (B, T, C) ---> (B, T, head_size)
        x = self.ln2(x)
        x = x + self.ffwd(x) # (B, T, head_size) ---> (B, T, head_size)
        return x

In [32]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size: int, n_embd: int, n_head: int, head_size: int, n_layer: int) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, head_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(head_size)
        self.lm_head = nn.Linear(head_size, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) batch, time, channel
        pos_emb = self.pos_embedding_table(torch.arange(T)) # (T, C)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)

        loss = None
        if targets is not None:
            B, T, C = logits.shape
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits.view(B*T, C), targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_conx = idx[:, -block_size:]
            if scalar is not None:
                logits, _ = self(idx_conx)
            else:
                with torch.cuda.amp.autocast():
                    logits, _ = self(idx_conx)
            logits = logits[:, -1, :] # (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [33]:
m = BigramLanguageModel(vocab_size, embedding_size, num_heads, head_size, n_layer)

In [34]:
torch.set_float32_matmul_precision('medium')
if should_compile:
    m = torch.compile(m)

In [35]:
num_params = sum(p.numel() for p in m.parameters())
print("{num_params:.2f}M parameters".format(num_params=num_params/1e6))

11.08M parameters


In [36]:
logits, loss = m(xb, yb)
print(f"logits shape: {logits.shape}")
print(f"loss: {loss}")

logits shape: torch.Size([24, 1024, 65])
loss: 4.404087066650391


In [37]:
idx = torch.zeros((1, 1), dtype=torch.long)
generated = m.generate(idx, max_new_tokens=100)
print(decode(generated[0].tolist()))


QHATp'EVfB3olZ3K;xvcx
HNHWBvBUcD.&-EskAlN?WSNuwAaBs;qfcAHq$B?CxcUc,MYJkM&XhL3!3v,:mXGOlD
fjKNf;xOp3j


In [38]:
estimate_loss(m, loss_est_n_batch)

{'train': 4.41819429397583, 'val': 4.421832084655762}

## Training

In [39]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [40]:
train_step = 0

In [41]:
def save_model(model, train_step, should_print=False, save_half=False):
    dropout_text = "{dropout:.1f}".format(dropout=dropout).replace(".", "_")
    ts = datetime.now().strftime("%Y%m%dT%H%M%S")
    model_save_path = f"model/{block_size}_{n_layer}x{num_heads}x{embedding_size}_{dropout_text}_fp32_iter{train_step}_{ts}.pt"
    if save_half:
        model_save_path_half = model_save_path.replace("fp32", "fp16")
        torch.save(m.half().state_dict(), model_save_path_half)
        if should_print:
            print(f"Saved model to \"{model_save_path_half}\"")
    else:
        torch.save(m.state_dict(), model_save_path)
        if should_print:
            print(f"Saved model to \"{model_save_path}\"")

In [42]:
start = time.time()
for _ in range(max_iters):
    optimizer.zero_grad(set_to_none=True)

    train_step += 1
    xb, yb = get_batch('train')
    if scalar is None:
        logits, loss = m(xb, yb)
        loss.backward()
        optimizer.step()
    else:
        with torch.cuda.amp.autocast():
            logits, loss = m(xb, yb)
        scalar.scale(loss).backward()
        scalar.step(optimizer)
        scalar.update()
    
    if train_step <= 1 or train_step % eval_interval == 0:
        batch_loss = estimate_loss(m, loss_est_n_batch)
        print("Step {step}: train_loss={train:.5f}, val_loss={val:.5f}, dur={dur:.2f}s".format(step=train_step, train=batch_loss['train'], val=batch_loss['val'], dur=time.time() - start))
        if train_step > 1 and train_step % max_iters > 0:
            save_model(m, train_step, should_print=False, save_half=False)

Step 1: train_loss=5.55992, val_loss=5.60891, dur=42.68s
Step 500: train_loss=2.45181, val_loss=2.48498, dur=94.76s
Step 1000: train_loss=1.83304, val_loss=1.98096, dur=147.28s
Step 1500: train_loss=1.48769, val_loss=1.68838, dur=199.85s
Step 2000: train_loss=1.34856, val_loss=1.61017, dur=253.96s
Step 2500: train_loss=1.28370, val_loss=1.58126, dur=307.56s
Step 3000: train_loss=1.22707, val_loss=1.54274, dur=360.33s
Step 3500: train_loss=1.19192, val_loss=1.53189, dur=413.16s
Step 4000: train_loss=1.14671, val_loss=1.50873, dur=465.89s
Step 4500: train_loss=1.11858, val_loss=1.51996, dur=518.67s
Step 5000: train_loss=1.09290, val_loss=1.51635, dur=571.28s


In [43]:
save_model(m, train_step, should_print=True, save_half=False)
save_model(m, train_step, should_print=True, save_half=True)

Saved model to "model/1024_6x6x384_0_3_fp32_iter5000_20230814T202102.pt"
Saved model to "model/1024_6x6x384_0_3_fp16_iter5000_20230814T202102.pt"


In [46]:
start_text = "SHUYANG FROM NEW YORK:"
idx = torch.tensor(np.array(encode(start_text)), dtype=torch.long).view(1, len(start_text))
m.eval()
generated = m.half().generate(idx, max_new_tokens=500)
m.train()
print(decode(generated[0].tolist()))

SHUYANG FROM NEW YORK:
Ay, ay, of with 'once times, now give me from thee.

BUCKINGHAM:
Nor news into my lord. But if it be done,
Poisoner of a deed, nor a-flattering it;
But in the bout of Edward, who, and writed me;
Nor strong Edward that time to relious mother;
Nor I doth father a princious subjects
Which true Edward at thy grandation,
Else doth three stateful Edward's sake,
And tell thee storm on, or rush the rest.
More afterward to wish the depant arms,
Or Rivers, sir France, they like unto my shame,
Nor ingorou
