In [2]:
 %load_ext nb_mypy

Version 1.0.4


In [3]:
import torch
import torch.nn as nn
from typing import Optional

torch.manual_seed(1337)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x125e50270>

In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
batch_size = 64
block_size = 256
embed_size = 384
num_heads = 6
num_layers = 6
dropout = 0.2

In [6]:
charToIndex : dict[str, int] = { ch:i for i,ch in enumerate(chars) }
indexToChar : dict[int, str] = { i:ch for i,ch in enumerate(chars) }

def encode(text: str) -> list[int]: 
    return [charToIndex[c] for c in text]

def decode(values: list[int]) -> str:
    return ''.join([indexToChar[value] for value in values])

In [7]:
data = torch.tensor(encode(text), dtype=torch.long)

In [8]:
n = int(0.9 * len(data))

train_data = data[:n]
val_data  = data[n:]

In [9]:
def get_batch(train = True):
    data = train_data if train else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [10]:
class Head(nn.Module):
    def __init__(self, block_size, embed_size, head_size):
        super().__init__()
        self.key = torch.nn.Linear(embed_size, head_size, bias=False)
        self.query = torch.nn.Linear(embed_size, head_size, bias=False)
        self.value = torch.nn.Linear(embed_size, head_size, bias=False)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        B, T, C = x.shape # (B=batch_size, T=block_size, C=embed_size)
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        v = self.value(x) # (B, T, head_size)

        wei = q @ k.transpose(-2, -1) * C**-0.5 # (B, T, T)
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # future tokens have -inf affinities
        wei = torch.nn.functional.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        return wei @ v # (B, T, head_size)

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, block_size, embed_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(block_size, embed_size, embed_size / num_heads) for _ in range(num_heads)])
        self.projection = nn.Linear(embed_size, embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.projection(out)
        out = self.dropout(out)
        return out

In [12]:
class FeedForward(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(embed_size, 4 * embed_size),
            torch.nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

In [13]:
class Block(nn.Module):
    def __init__(self, num_heads, block_size, embed_size):
        super().__init__()
        self.sa_heads = MultiHeadAttention(num_heads, block_size, embed_size)
        self.feed_forward = FeedForward(embed_size)
        self.ln1 = torch.nn.LayerNorm(embed_size)
        self.ln2 = torch.nn.LayerNorm(embed_size)

    def forward(self, x):
        x = x + self.sa_heads(self.ln1(x)) # (B, T, C)
        x = x + self.feed_forward(self.ln2(x)) # (B, T, C)
        return x

In [21]:
class TransformerLanguageModel(nn.Module):
    def __init__(self, block_size, embed_size, vocab_size, num_heads, num_layers):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        
        print(num_layers)
        Block(num_heads, block_size, embed_size)

        self.blocks = nn.Sequential(*[Block(num_heads, block_size, embed_size) for _ in range(num_layers)])
        self.ln = torch.nn.LayerNorm(embed_size)
        self.lm_head = nn.Linear(embed_size, vocab_size)
    
    # idx and targets are both (batch_size, block_size) tensors of integers
    def forward(self, idx, targets = None) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        B, T = idx.shape

        token_embeddings = self.token_embedding_table(idx) # (B=batch_size, T=block_size, C=embed_size)
        position_embeddings = self.position_embedding_table(torch.arange(T))
        x = token_embeddings + position_embeddings # (B, T, C)
        x = self.blocks(x)
        x = self.ln(x)
                
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            loss = nn.functional.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        
        return logits, loss

    def generate(self, idx, max_new_tokens) -> torch.Tensor:
        # idx is (B, T) array of indexes
        for _ in range(max_new_tokens):
            # crop idx to the block_size
            idx_crop = idx[:, -block_size:]
            logits, _ = self.forward(idx_crop)
            logits = logits[:, -1, :] # becomes (B, C)
            probs = nn.functional.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

In [20]:
import torch

torch.manual_seed(1337)

m = TransformerLanguageModel(block_size, embed_size, vocab_size, num_heads, num_layers)

xb, yb = get_batch()
logits, loss = m.forward(xb, yb)

print(logits.shape)
print(loss)

print(m.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=100))
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of SymInts size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [None]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32

for steps in range(1000):
    xb, yb = get_batch(True)
    logits, loss = m.forward(xb, yb)

    if loss is not None:    
        optimizer.zero_grad(True)
        loss.backward()
        optimizer.step()

if loss is not None:            
    print(loss.item())

2.163740873336792


In [None]:
eval_iters = 1000

@torch.no_grad()
def estimate_loss():
    out = {}
    m.eval()
    for split in [True, False]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xs, ys = get_batch(split)
            logits, loss = m.forward(xs, ys)
            if loss is not None:
                losses[k] = loss.item()
        out[split] = losses.mean()
    m.train()
    return out

In [None]:
max_iters = 5000
eval_interval = 300

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses[True]:.4f}, val loss {losses[False]:.4f}")
    
    xb, yb = get_batch(True)
    logits, loss = m.forward(xb, yb)

    if loss is not None:    
        optimizer.zero_grad(True)
        loss.backward()
        optimizer.step()

step 0: train loss 2.2625, val loss 2.2693
step 300: train loss 2.2184, val loss 2.2401
step 600: train loss 2.1882, val loss 2.2063
step 900: train loss 2.1533, val loss 2.1950
step 1200: train loss 2.1354, val loss 2.1782
step 1500: train loss 2.1057, val loss 2.1553
step 1800: train loss 2.0824, val loss 2.1346
step 2100: train loss 2.0788, val loss 2.1438
step 2400: train loss 2.0586, val loss 2.1257
step 2700: train loss 2.0411, val loss 2.1132
step 3000: train loss 2.0389, val loss 2.1109
step 3300: train loss 2.0183, val loss 2.0982
step 3600: train loss 2.0180, val loss 2.1076
step 3900: train loss 1.9986, val loss 2.0825
step 4200: train loss 1.9936, val loss 2.0832
step 4500: train loss 1.9931, val loss 2.0988
step 4800: train loss 1.9818, val loss 2.0861


In [None]:
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


KISTINGSBWollodt holiend than thy vadswas shode,
And fairen our he.
The pose what na. Ins rown;
Thou arrain should for be avery to welliel but reges of Donemn tway ditelvond had beinsold,
Farb nog this in herechild chan.

RUKE VINCENTIO:
Heast!

BOLOUS:
So have your kind in son.

Fands,
A slome frower do Mrinclow.

METINGETI:
Fane their and him sabock, with so be wifell Nordys kes I world.
What?

EDWARAyULANNE:
Whening we
Enchis sound-uman Compellousimment my hand fale
with in?
The hast yet tow.
