# Shakespeare GPT 

In this notebook we build a GPT that generates Shakespeare-like text. We start with a simple bigram model and build up to a full multi-head GPT model, trying a range of techniques along the way and checking quantitatively and qualitatively how each one affects training.

Note: the full model at the end should be trained on a GPU to reach a loss close to 1.4 and produce reasonable Shakespeare-like writing.

# Prepare data

In [2]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-09-13 17:13:18--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-09-13 17:13:18 (21.2 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [6]:
# the wget call above saved the corpus as input.txt
with open('input.txt', 'r') as f:
    text = f.read()
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [9]:
chars = sorted(list(set(text)))
char_to_ix = {ch:i for i,ch in enumerate(chars)}
ix_to_char = {i:ch for i,ch in enumerate(chars)}

def encode(text):
    return [char_to_ix[ch] for ch in text]

def decode(vec):
    return ''.join([ix_to_char[ix] for ix in vec])

print(''.join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [12]:
text_vec = encode(text)
print(text_vec[:100])
print(len(text_vec))

[18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1, 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49, 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 37, 53, 59]
1115394


In [14]:
import torch

data = torch.tensor(text_vec, dtype=torch.long)
print(data.shape)
print(data[:100])

torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [26]:
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [283]:
torch.manual_seed(1337)

def get_batch(split, batch_size, block_size):
    """
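    Sample a batch of (input, target) sequences from the given split.
    Each row packs block_size training examples: y is x shifted one token ahead.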
    """
    if split == 'train':
        data = train_data
    else:
        # 1-D tensor of token ids
        data = val_data
    # start: B random offsets, each leaving room for block_size + 1 tokens
    start = torch.randint(0, len(data) - block_size, (batch_size,))
    # e.g. data = [1, 2, 3, 4, 5, 6], block_size = 3, start = [0, 1, 2]
    # x = [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
    x = torch.stack([data[s:s+block_size] for s in start])
    # y = [[2, 3, 4], [3, 4, 5], [4, 5, 6]]
    y = torch.stack([data[s+1:s+block_size+1] for s in start])
    return x, y

x, y = get_batch('train', 4, 8)
# 4 sequences x 8 positions = 32 (context, target) training examples
print(x.shape, y.shape)
print(x)
print(y)

torch.Size([4, 8]) torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
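
Each row of `x` and `y` packs several training examples, one per position. A minimal sketch in plain Python, using the first row of the batch printed above (the helper name `expand_examples` is made up for illustration):

```python
def expand_examples(x_row, y_row):
    """List the (context, target) pairs packed into one row of a batch."""
    return [(x_row[: t + 1], y_row[t]) for t in range(len(x_row))]

# First row of x and y from the batch printed above
x_row = [24, 43, 58, 5, 57, 1, 46, 43]
y_row = [43, 58, 5, 57, 1, 46, 43, 39]

pairs = expand_examples(x_row, y_row)
for context, target in pairs:
    print(context, "->", target)
```

With `block_size = 8`, one row therefore yields 8 examples, with context lengths 1 through 8.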


# Bigram model

In a bigram model, each token carries a fixed probability distribution over the next token, independent of any earlier context.

The bigram model is the simplest possible language model, so we use it as a baseline.
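
To make the "fixed distribution per token" idea concrete, here is a small counting-based sketch. It is not the neural model used in this notebook: the function name `bigram_table` and the toy corpus are assumptions for illustration only.

```python
from collections import Counter, defaultdict

def bigram_table(text):
    """Count next-character frequencies and normalize them to probabilities."""
    counts = defaultdict(Counter)
    for cur, nxt in zip(text, text[1:]):
        counts[cur][nxt] += 1
    return {
        cur: {nxt: c / sum(ctr.values()) for nxt, c in ctr.items()}
        for cur, ctr in counts.items()
    }

toy_text = "abababac"  # hypothetical toy corpus
table = bigram_table(toy_text)
print(table["a"])  # distribution over the character that follows "a"
print(table["b"])  # distribution over the character that follows "b"
```

The `BigramLanguageModel` in this section stores one row of logits per token instead of counts and fits those rows by gradient descent, but the information it can represent is the same kind of lookup table.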

In [44]:
# simplest possible language model 
class BigramLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_emb = torch.nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x block_size
        # B x block_size x vocab_size
        # each token id is mapped to one row of the table
        # here the row is not a real embedding but the logits for the next token
        logits = self.token_emb(idx)  
        # logits: B x block_size x vocab_size
        # target: B x block_size
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        
        if target is not None:
            target = target.view(-1)
            loss = torch.nn.functional.cross_entropy(logits, target)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # T is getting updated
            B, T = idx.shape
            # B*T x C
            logits, _ = self(idx, None)
            logits = logits.view(B, T, -1)
            # only interested in prediction from last token
            # B x C (no squeeze: it would also drop the batch dim when B == 1)
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx
        
m = BigramLanguageModel(len(chars))
logits, loss = m(x, y)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(5.0496, grad_fn=<NllLossBackward0>)
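
As a sanity check on the initial loss, a model that assigns equal probability to all 65 characters would score a cross-entropy of ln(65). The short calculation below is only an illustration, with `vocab_size` taken from the logits shape printed above:

```python
import math

vocab_size = 65  # second dimension of the logits shape printed above
uniform_loss = -math.log(1.0 / vocab_size)
print(f"{uniform_loss:.4f}")  # about 4.17
```

The untrained model's 5.0496 sits above this value, plausibly because `torch.nn.Embedding` initializes its weights from a standard normal distribution, so the initial logits are not uniform.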


In [80]:
# generate examples, continuing from the batch x sampled earlier
preds = m.generate(x, 100)
print(decode(preds[0].tolist()))

resent mpaldiver pras'seve k's:
faknd t to'le.
Tesaip t y mivenul
Tovecyonet sofond mis fongu rr fo hoou but


In [54]:
# create optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)

In [285]:
@torch.no_grad()
def estimate_loss(batch_size, block_size, eval_iters=100):
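    out = {}
    # NOTE: evaluates the global model `m`, so `m` must be the model being trained
    # eval mode disables dropout while the loss is measured
    m.eval()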
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split, batch_size, block_size)
            logits, loss = m(x, y)
            losses[k] = loss
        out[f'{split}_loss'] = losses.mean().item()
    m.train()
    return out 

In [287]:
# training loop
def train(model, optimizer, batch_size, block_size, num_steps, eval_iters=100):
    for step in range(num_steps):
        x, y = get_batch('train', batch_size, block_size)
        optimizer.zero_grad(set_to_none=True)
        logits, loss = model(x, y)
        loss.backward()
        optimizer.step()
        
        if step % 500 == 0:
            print(estimate_loss(batch_size, block_size, eval_iters))

    return loss

# Bag of words model  

With an embedding for each token, the simplest way to summarize the context is to average the embeddings of all tokens seen so far, then use this averaged context embedding to predict the next token.
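
The causal average can be written as a single matrix multiplication. The sketch below uses made-up toy sizes (`T`, `E`) to check that the two ways of building the weights in this section agree, and that applying them to `x` gives the mean of each prefix:

```python
import torch

torch.manual_seed(0)
T, E = 4, 2
x = torch.randn(T, E)  # toy token embeddings for one sequence

# Option 1: row-normalized lower-triangular matrix
tril = torch.tril(torch.ones(T, T))
wei_avg = tril / tril.sum(-1, keepdim=True)

# Option 2: softmax over zeros, with future positions masked to -inf
wei_soft = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
wei_soft = torch.nn.functional.softmax(wei_soft, dim=-1)

bow = wei_avg @ x  # row t is the mean of x[0..t]
prefix_means = torch.stack([x[: t + 1].mean(dim=0) for t in range(T)])

print(torch.allclose(wei_avg, wei_soft))
print(torch.allclose(bow, prefix_means, atol=1e-6))
```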

In [None]:
# Compute BoW: option 1
T = 3
mask = torch.tril(torch.ones(T, T))
mask = mask / mask.sum(-1, keepdim=True)
mask

In [None]:
# compute BoW: option 2
# Same weights as option 1, but softmax turns arbitrary scores into probabilities,
# so the zeros can later be replaced by learned attention scores
T = 3
mask = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T)
wei = wei.masked_fill(mask == 0, float('-inf'))
wei = torch.nn.functional.softmax(wei, dim=-1)
wei

In [143]:
class BagOfWordsLanguageModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, emb_size):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, emb_size)
        self.position_emb = torch.nn.Embedding(block_size, emb_size)
        self.out = torch.nn.Linear(emb_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x block_size
        B, T = idx.shape
        # B x block_size x emb_size
        # each term in the sequence is mapped to an embedding
        token_emb = self.token_emb(idx)  
        # position_idx: block_size
        position_idx = torch.arange(T)
        # position_emb: block_size x emb_size
        position_emb = self.position_emb(position_idx)
        # B x block_size x emb_size
        x = token_emb + position_emb
        # mask: T x T lower-triangular, so position t only sees positions <= t
        mask = torch.tril(torch.ones(T, T))
        wei = torch.zeros(T, T)
        wei = wei.masked_fill(mask == 0, float('-inf'))
        # wei: T x T, broadcast over the batch in the matmul below
        wei = torch.nn.functional.softmax(wei, dim=-1)
        # B x block_size x emb_size 
        bow = wei @ x
        # B x block_size x vocab_size
        logits = self.out(bow)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x C (indexing with -1 drops the T dimension)
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx

In [159]:
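# block_size must match the block_size passed to train() below,
# otherwise the position embedding lookup goes out of range
m = BagOfWordsLanguageModel(block_size=16, vocab_size=len(chars), emb_size=32)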
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)
train(
    model=m,
    optimizer=optimizer, 
    batch_size=16, 
    block_size=16, 
    num_steps=5000, 
    eval_iters=100
)

In [161]:
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 100)
print(decode(out[0].tolist()))


YA
eyALt
:B, 'w , iat
htsia cco uChvaooeri ldfeoten u p  aanikfDig,ot-L 
N ho.ien.a

si lAmhhSe ,lte


# Self attention model

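Each token in the context can now attend to the other tokens: dedicated linear layers produce a key, a query, and a value per token, and the query-key scores decide how much each earlier token contributes.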

In [176]:
# scale attention scores by 1/sqrt(head_size) to keep them near unit variance, so softmax does not saturate
k = torch.randn(8, 10, 16)
q = torch.randn(8, 10, 16)
att = k @ q.transpose(-1, -2)
att_probs = torch.nn.functional.softmax(att, dim=-1)
norm_att = att / 16 ** 0.5
norm_att_probs = torch.nn.functional.softmax(norm_att, dim=-1)

In [174]:
k.var(), q.var(), att.var(), norm_att.var()

(tensor(0.9328), tensor(0.9549), tensor(13.5740), tensor(0.8484))

In [179]:
att_probs[0, 0], norm_att_probs[0, 0]

(tensor([2.0836e-04, 6.9544e-04, 6.3730e-04, 1.0454e-01, 1.1655e-04, 8.3147e-01,
         3.9159e-02, 2.3124e-02, 1.9388e-05, 3.2084e-05]),
 tensor([0.0395, 0.0533, 0.0522, 0.1867, 0.0341, 0.3136, 0.1461, 0.1281, 0.0218,
         0.0247]))
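
The variances printed above follow from a short argument: a dot product of two independent vectors with `d` zero-mean, unit-variance entries is a sum of `d` terms of variance 1, so its variance is about `d`, and dividing by `sqrt(d)` brings it back to about 1. A plain-Python sketch under those assumptions (the sample count and seed are arbitrary):

```python
import random
import statistics

random.seed(0)
d = 16            # head size used in the cell above
n_samples = 5000  # arbitrary number of random (q, k) pairs

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

raw, scaled = [], []
for _ in range(n_samples):
    q = [random.gauss(0.0, 1.0) for _ in range(d)]
    k = [random.gauss(0.0, 1.0) for _ in range(d)]
    s = dot(q, k)
    raw.append(s)
    scaled.append(s / d ** 0.5)

raw_var = statistics.pvariance(raw)        # expected to be near d
scaled_var = statistics.pvariance(scaled)  # expected to be near 1
print(raw_var, scaled_var)
```

Large-variance scores push softmax toward a near one-hot output, which is what the first (unscaled) row of probabilities above shows.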

In [199]:
class Head(torch.nn.Module):
    def __init__(self, emb_size, head_size):
        super().__init__()
        self.emb_size = emb_size
        self.k = torch.nn.Linear(emb_size, head_size)
        self.q = torch.nn.Linear(emb_size, head_size)
        self.v = torch.nn.Linear(emb_size, head_size)
    
    def forward(self, x):
        # x: B x T x E
        B, T, E = x.shape
        # B x T x L
        k = self.k(x)
        # B x T x L
        q = self.q(x)
        # B x T x L
        v = self.v(x)
        # B x T x T: row i holds the scores of query i against every key
        att = q @ k.transpose(-2, -1)
        att = att / (k.shape[-1] ** 0.5)
        # T x T causal mask, broadcast over the batch
        mask = torch.tril(torch.ones(T, T))
        # B x T x T
        masked_att = att.masked_fill(mask == 0, float('-inf'))
        # B x T x T
        masked_att = torch.nn.functional.softmax(masked_att, dim=-1)
        # B x T x L
        att_bow = masked_att @ v
        return att_bow

In [200]:
class AttentionLanguageModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, emb_size, linear_size):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, emb_size)
        self.position_emb = torch.nn.Embedding(block_size, emb_size)
        self.head = Head(emb_size, linear_size)
        self.out = torch.nn.Linear(linear_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x T
        B, T = idx.shape
        # B x T x E
        token_emb = self.token_emb(idx)  
        # position_idx: block_size
        position_idx = torch.arange(T)
        # position_emb: block_size x emb_size
        position_emb = self.position_emb(position_idx)
        # B x T x E
        x = token_emb + position_emb
        # B x T x L
        att_bow = self.head(x)        
        # B x block_size x vocab_size
        logits = self.out(att_bow)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x C (indexing with -1 drops the T dimension)
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx

In [201]:
m = AttentionLanguageModel(
    block_size=16,  # must match the block_size passed to train()
    vocab_size=len(chars), 
    emb_size=32,
    linear_size=32
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)
train(
    model=m,
    optimizer=optimizer, 
    batch_size=16, 
    block_size=16, 
    num_steps=5000, 
    eval_iters=100
)

In [210]:
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 100)
print(decode(out[0].tolist()))


Ald.
My ti y'sak alsppyu
Way oureens role to;
USice sa yon.

AMNIO:
And wad weat rs thes ree hors as


# Multi-head attention model

Add multiple attention heads to expand the model capacity.

Add a feed-forward layer so the model can further process each token's attention-weighted bag of words.
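
The heads run independently on the same input, and their outputs are joined along the feature dimension. The sketch below uses random tensors as stand-ins for the head outputs (the sizes are arbitrary) to show that stacking on a new dimension and flattening, as the `MultiHeadAttention` class in this section does, gives the same result as a plain concatenation:

```python
import torch

torch.manual_seed(0)
B, T, H, n = 2, 3, 4, 5  # toy batch size, sequence length, head size, number of heads
head_outputs = [torch.randn(B, T, H) for _ in range(n)]  # stand-ins for head(x)

stacked = torch.stack(head_outputs, dim=2).view(B, T, -1)  # B x T x n*H
concatenated = torch.cat(head_outputs, dim=-1)             # B x T x n*H

print(stacked.shape)
print(torch.equal(stacked, concatenated))
```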

In [220]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, emb_size, head_size, num_heads):
        super().__init__()
        self.heads = torch.nn.ModuleList([
            Head(emb_size, head_size) for _ in range(num_heads)
        ])
    
    def forward(self, x):
        # x: B x T x E
        B, T, E = x.shape
        # B x T x n x H
        att_bows = torch.stack([head(x) for head in self.heads], dim=2)
        # B x T x n*H
        att_bows = att_bows.view(B, T, -1)
        return att_bows

In [232]:
class FeedForward(torch.nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear = torch.nn.Linear(in_size, out_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # x: B x T x in
        # B x T x out
        x = self.relu(self.linear(x))
        return x

In [233]:
class MultiHeadAttentionLanguageModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, emb_size, head_size, num_heads, hidden_size):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, emb_size)
        self.position_emb = torch.nn.Embedding(block_size, emb_size)
        self.multihead = MultiHeadAttention(emb_size, head_size, num_heads)
        self.feedforward = FeedForward(head_size*num_heads, hidden_size)
        self.out = torch.nn.Linear(hidden_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x T
        B, T = idx.shape
        # B x T x E
        token_emb = self.token_emb(idx)  
        # position_idx: block_size
        position_idx = torch.arange(T)
        # position_emb: block_size x emb_size
        position_emb = self.position_emb(position_idx)
        # B x T x E
        x = token_emb + position_emb
        # B x T x H*n
        att_bow = self.multihead(x)     
        # B x T x hidden
        x = self.feedforward(att_bow)   
        # B x T x C
        logits = self.out(x)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x C (indexing with -1 drops the T dimension)
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx

In [234]:
m = MultiHeadAttentionLanguageModel(
    block_size=16,  # must match the block_size passed to train()
    vocab_size=len(chars), 
    emb_size=32,
    head_size=32,
    num_heads=4,
    hidden_size=32,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)
train(
    model=m,
    optimizer=optimizer, 
    batch_size=16, 
    block_size=16, 
    num_steps=5000, 
    eval_iters=100
)

In [236]:
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 100)
print(decode(out[0].tolist()))


COMIFFRe, capiis
Whut hif kied!
Wing: I chat I I grotheir; id dear grong of wor wer mires, thempeare


# Multi layered attention model 

In [238]:
class Block(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads, out_size):
        super().__init__()
        self.multihead = MultiHeadAttention(in_size, head_size, num_heads)
        self.feedforward = FeedForward(head_size*num_heads, out_size)
    
    def forward(self, x):
        # x: B x T x in
        # B x T x n*H
        att_bow = self.multihead(x)
        # B x T x out
        out = self.feedforward(att_bow)
        return out

In [239]:
class MultiBlockAttentionModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, head_size, num_heads, hidden_size, num_blocks):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, hidden_size)
        self.position_emb = torch.nn.Embedding(block_size, hidden_size)
        self.blocks = torch.nn.ModuleList([
            Block(hidden_size, head_size, num_heads, hidden_size) for _ in range(num_blocks)
        ])
        self.out = torch.nn.Linear(hidden_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x T
        B, T = idx.shape
        # B x T x E (hidden_size)
        token_emb = self.token_emb(idx)  
        # position_idx: block_size
        position_idx = torch.arange(T)
        # position_emb: block_size x emb_size
        position_emb = self.position_emb(position_idx)
        # B x T x E
        x = token_emb + position_emb
        # B x T x E
        for block in self.blocks:
            x = block(x)   
        # B x T x C
        logits = self.out(x)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x C (indexing with -1 drops the T dimension)
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx

In [240]:
m = MultiBlockAttentionModel(
    block_size=16,  # must match the block_size passed to train()
    vocab_size=len(chars), 
    head_size=32,
    num_heads=4,
    hidden_size=32,
    num_blocks=4,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)
train(
    model=m,
    optimizer=optimizer, 
    batch_size=16, 
    block_size=16, 
    num_steps=5000, 
    eval_iters=100
)

In [242]:
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 100)
print(decode(out[0].tolist()))


Can my'; deavy nowns
byou,
A comy
A and soalinde mis lend esest.

FRICCES:
Onof a me punf theentets.


# Residual block

Residual (skip) connections let each block learn an update that is added to its input, which helps optimize deep networks. Validation loss should drop noticeably faster than with the non-residual blocks.
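
One way to see why the skip connection helps is to compare how much gradient reaches the input of a deep stack. The toy below is an assumption-heavy sketch, not the model in this notebook: each "layer" is just a multiplication by a small scalar `weight`, and `input_gradient` is a hypothetical helper.

```python
import torch

def input_gradient(num_layers, weight, residual):
    """Gradient of the output w.r.t. the input for a stack of scalar layers."""
    x = torch.tensor(1.0, requires_grad=True)
    h = x
    for _ in range(num_layers):
        # plain layer: h -> w*h ; residual layer: h -> h + w*h
        h = (h + weight * h) if residual else (weight * h)
    h.backward()
    return x.grad.item()

plain_grad = input_gradient(num_layers=8, weight=0.1, residual=False)
resid_grad = input_gradient(num_layers=8, weight=0.1, residual=True)
print(plain_grad, resid_grad)
```

Without the skip path the gradient is `weight ** num_layers`, which vanishes quickly; with it, the identity term keeps the gradient on the order of 1.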

In [248]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads):
        super().__init__()
        self.heads = torch.nn.ModuleList([
            Head(in_size, head_size) for _ in range(num_heads)
        ])
        # added a new projection layer
        self.proj = torch.nn.Linear(head_size*num_heads, in_size)
    
    def forward(self, x):
        # x: B x T x E
        B, T, E = x.shape
        # B x T x n x H
        att_bows = torch.stack([head(x) for head in self.heads], dim=2)
        # B x T x n*H
        att_bows = att_bows.view(B, T, -1)
        # B x T x E
        out = self.proj(att_bows)
        return out

In [249]:
class FeedForward(torch.nn.Module):
    def __init__(self, in_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_size, in_size * 4),
            torch.nn.ReLU(),
            # added one more linear layer 
            torch.nn.Linear(in_size * 4, in_size),
        )

    def forward(self, x):
        return self.net(x)

In [255]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads):
        super().__init__()
        self.multihead = MultiHeadAttention(in_size, head_size, num_heads)
        self.feedforward = FeedForward(in_size)
    
    def forward(self, x):
        # x: B x T x in
        # B x T x in_size
        x = x + self.multihead(x)
        # B x T x in_size
        x = x + self.feedforward(x)
        return x

In [251]:
class MultiResidualBlockAttentionModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, head_size, num_heads, hidden_size, num_blocks):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, hidden_size)
        self.position_emb = torch.nn.Embedding(block_size, hidden_size)
        self.blocks = torch.nn.ModuleList([
            ResidualBlock(hidden_size, head_size, num_heads) for _ in range(num_blocks)
        ])
        self.out = torch.nn.Linear(hidden_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x T
        B, T = idx.shape
        # B x T x E (hidden_size)
        token_emb = self.token_emb(idx)  
        # position_idx: block_size
        position_idx = torch.arange(T)
        # position_emb: block_size x emb_size
        position_emb = self.position_emb(position_idx)
        # B x T x E
        x = token_emb + position_emb
        # B x T x E
        for block in self.blocks:
            x = block(x)   
        # B x T x C
        logits = self.out(x)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x C (indexing with -1 drops the T dimension)
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx

In [252]:
m = MultiResidualBlockAttentionModel(
    block_size=16,  # must match the block_size passed to train()
    vocab_size=len(chars), 
    head_size=32,
    num_heads=4,
    hidden_size=32,
    num_blocks=4,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)
train(
    model=m,
    optimizer=optimizer, 
    batch_size=16, 
    block_size=16, 
    num_steps=5000, 
    eval_iters=100
)

In [254]:
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 100)
print(decode(out[0].tolist()))


And you therexing, encle, the conder metin ciat and will of in hadly of me'
The coning,
O, pitead.




# Layer norm

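Layer normalization rescales each token's feature vector to zero mean and unit variance (followed by a learned scale and shift), which stabilizes activations and helps gradient flow.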
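
A quick check of what `torch.nn.LayerNorm` does to a `B x T x E` tensor, using arbitrary toy sizes: the statistics are computed per token over the last (feature) dimension, not across the batch or across layers.

```python
import torch

torch.manual_seed(0)
B, T, E = 2, 3, 8
x = torch.randn(B, T, E) * 5 + 3  # deliberately off-center, large-scale features

layernorm = torch.nn.LayerNorm(E)  # default affine parameters: weight=1, bias=0
out = layernorm(x)

per_token_mean = out.mean(dim=-1)                # B x T, close to 0
per_token_var = out.var(dim=-1, unbiased=False)  # B x T, close to 1
print(per_token_mean)
print(per_token_var)
```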

In [256]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads):
        super().__init__()
        self.multihead = MultiHeadAttention(in_size, head_size, num_heads)
        self.feedforward = FeedForward(in_size)
        self.layernorm1 = torch.nn.LayerNorm(in_size)
        self.layernorm2 = torch.nn.LayerNorm(in_size)
    
    def forward(self, x):
        # x: B x T x in_size
        # pre-norm: normalize only the sublayer input and keep the residual path untouched
        x = x + self.multihead(self.layernorm1(x))
        # B x T x in_size
        x = x + self.feedforward(self.layernorm2(x))
        return x

In [257]:
class LayerNormResidualBlockAttentionModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, head_size, num_heads, hidden_size, num_blocks):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, hidden_size)
        self.position_emb = torch.nn.Embedding(block_size, hidden_size)
        self.blocks = torch.nn.ModuleList([
                ResidualBlock(hidden_size, head_size, num_heads) for _ in range(num_blocks)
            ] + [
                torch.nn.LayerNorm(hidden_size)
            ]
        )
        self.out = torch.nn.Linear(hidden_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x T
        B, T = idx.shape
        # B x T x E (hidden_size)
        token_emb = self.token_emb(idx)  
        # position_idx: block_size
        position_idx = torch.arange(T)
        # position_emb: block_size x emb_size
        position_emb = self.position_emb(position_idx)
        # B x T x E
        x = token_emb + position_emb
        # B x T x E
        for block in self.blocks:
            x = block(x)   
        # B x T x C
        logits = self.out(x)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x C (indexing with -1 drops the T dimension)
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx

In [258]:
m = LayerNormResidualBlockAttentionModel(
    block_size=16,  # must match the block_size passed to train()
    vocab_size=len(chars), 
    head_size=32,
    num_heads=4,
    hidden_size=32,
    num_blocks=4,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)
train(
    model=m,
    optimizer=optimizer, 
    batch_size=16, 
    block_size=16, 
    num_steps=5000, 
    eval_iters=100
)

In [260]:
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 100)
print(decode(out[0].tolist()))


lood, Loly:
And melg,
Fore thered:
I am essence? I waty will and you Edward tougar; manine, cicke he


# Dropout 

Dropout randomly zeroes activations during training, so each step effectively trains a random sub-network; this acts as regularization.
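
A small sketch of how `torch.nn.Dropout` behaves in the two module modes (the rate `p = 0.5` and the input of ones are arbitrary choices for illustration). In training mode the surviving elements are scaled by `1 / (1 - p)`; in eval mode the layer is the identity, which is why `estimate_loss` switches the model to `eval()` before measuring the loss.

```python
import torch

torch.manual_seed(0)
p = 0.5
dropout = torch.nn.Dropout(p)
x = torch.ones(1000)

dropout.train()
train_out = dropout(x)  # each element is either 0 or 1 / (1 - p)

dropout.eval()
eval_out = dropout(x)   # identity at evaluation time

zero_fraction = (train_out == 0).float().mean().item()
print(train_out.unique(), zero_fraction)
print(torch.equal(eval_out, x))
```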

In [265]:
class FeedForward(torch.nn.Module):
    def __init__(self, in_size, dropout):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_size, in_size * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(in_size * 4, in_size),
            # add dropout
            torch.nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [266]:
class Head(torch.nn.Module):
    def __init__(self, emb_size, head_size, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.k = torch.nn.Linear(emb_size, head_size)
        self.q = torch.nn.Linear(emb_size, head_size)
        self.v = torch.nn.Linear(emb_size, head_size)
        # add dropout
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, x):
        # x: B x T x E
        B, T, E = x.shape
        # B x T x L
        k = self.k(x)
        # B x T x L
        q = self.q(x)
        # B x T x L
        v = self.v(x)
        # B x T x T: row i holds the scores of query i against every key
        att = q @ k.transpose(-2, -1)
        att = att / (k.shape[-1] ** 0.5)
        # T x T causal mask, broadcast over the batch
        mask = torch.tril(torch.ones(T, T))
        # B x T x T
        masked_att = att.masked_fill(mask == 0, float('-inf'))
        # B x T x T
        masked_att = torch.nn.functional.softmax(masked_att, dim=-1)
        masked_att = self.dropout(masked_att)
        # B x T x L
        att_bow = masked_att @ v
        return att_bow

In [267]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads, dropout):
        super().__init__()
        self.heads = torch.nn.ModuleList([
            Head(in_size, head_size, dropout) for _ in range(num_heads)
        ])
        self.proj = torch.nn.Linear(head_size*num_heads, in_size)
        # add dropout
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, x):
        # x: B x T x E
        B, T, E = x.shape
        # B x T x n x H
        att_bows = torch.stack([head(x) for head in self.heads], dim=2)
        # B x T x n*H
        att_bows = att_bows.view(B, T, -1)
        # B x T x E
        out = self.proj(att_bows)
        out = self.dropout(out)
        return out
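
The `stack` followed by `view` in `forward` is a way of concatenating the head outputs along the feature axis. A small sketch with made-up tensors (the sizes `B, T, n, H` below are assumptions for illustration) shows that it gives the same result as `torch.cat`:

```python
import torch

B, T, n, H = 2, 3, 4, 5                                  # hypothetical toy sizes
head_outputs = [torch.randn(B, T, H) for _ in range(n)]  # stand-ins for head(x)

# As in forward(): stack on a new axis, then flatten the last two axes.
stacked = torch.stack(head_outputs, dim=2).view(B, T, -1)   # B x T x n*H
# Equivalent: concatenate directly along the feature axis.
concatenated = torch.cat(head_outputs, dim=-1)              # B x T x n*H

print(stacked.shape, torch.equal(stacked, concatenated))
```

Either form feeds the same `B x T x n*H` tensor into `self.proj`; which one to use is a matter of taste.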

In [268]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, in_size, head_size, num_heads, dropout):
        super().__init__()
        self.multihead = MultiHeadAttention(in_size, head_size, num_heads, dropout)
        self.feedforward = FeedForward(in_size, dropout)
        self.layernorm1 = torch.nn.LayerNorm(in_size)
        self.layernorm2 = torch.nn.LayerNorm(in_size)
    
    def forward(self, x):
        # x: B x T x in_size
        # pre-norm: normalize only the branch input so the residual path carries x unchanged
        x = x + self.multihead(self.layernorm1(x))
        # B x T x in_size
        x = x + self.feedforward(self.layernorm2(x))
        return x

In [271]:
class DropoutAttentionModel(torch.nn.Module):
    def __init__(self, block_size, vocab_size, num_heads, hidden_size, num_blocks, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_emb = torch.nn.Embedding(vocab_size, hidden_size)
        self.position_emb = torch.nn.Embedding(block_size, hidden_size)
        self.blocks = torch.nn.ModuleList([
                ResidualBlock(hidden_size, hidden_size // num_heads, num_heads, dropout) for _ in range(num_blocks)
            ] + [
                torch.nn.LayerNorm(hidden_size)
            ]
        )
        self.out = torch.nn.Linear(hidden_size, vocab_size)
        
    def forward(self, idx, target):
        # idx: B x T
        B, T = idx.shape
        # B x T x E (hidden_size)
        token_emb = self.token_emb(idx)  
        # position_idx: T, created on the same device as idx
        position_idx = torch.arange(T, device=idx.device)
        # position_emb: T x E (broadcast over the batch when added below)
        position_emb = self.position_emb(position_idx)
        # B x T x E
        x = token_emb + position_emb
        # B x T x E
        for block in self.blocks:
            x = block(x)   
        # B x T x C
        logits = self.out(x)
        
        if target is not None:
            # target: B x block_size
            target = target.view(-1)
            # B*block_size x vocab_size
            logits = logits.view(B*T, -1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            # B x block_size x vocab_size
            logits = logits.view(B, T, -1)
        else:
            loss = None
        return logits, loss 

    def generate(self, idx, max_new_tokens):
        # idx: B x T
        for _ in range(max_new_tokens):
            # B x T x C
            logits, _ = self(idx[:, -self.block_size:], None)
            # only interested in prediction from last token
            # B x C
            logits = logits[:, -1, :]
            # B x C
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # B x 1
            n = torch.multinomial(probs, num_samples=1)
            # B x T+1
            idx = torch.cat([idx, n], dim=1)
        return idx
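
The crop `idx[:, -self.block_size:]` in `generate` matters because the position embedding table only has `block_size` rows. The following is a rough sketch of one generation step with toy sizes and uniform made-up logits (all values here are assumptions, not taken from the trained model):

```python
import torch

block_size, emb_size = 8, 4                       # hypothetical toy sizes
position_emb = torch.nn.Embedding(block_size, emb_size)

idx = torch.zeros(1, 12, dtype=torch.long)        # a context that has outgrown block_size
context = idx[:, -block_size:]                    # keep only the most recent block_size tokens

# Positions 0..T-1 are always valid after cropping, because T <= block_size.
T = context.shape[1]
pos = position_emb(torch.arange(T))               # T x emb_size

# Sample one next token per row from made-up uniform logits (vocabulary of 5).
probs = torch.softmax(torch.zeros(1, 5), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # B x 1
idx = torch.cat([idx, next_token], dim=1)             # B x T+1
print(context.shape, pos.shape, idx.shape)
```

Without the crop, `torch.arange(T)` would eventually index past the end of the embedding table once the generated sequence grows longer than `block_size`.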

In [272]:
m = DropoutAttentionModel(
    block_size=8, 
    vocab_size=len(chars), 
    num_heads=4,  # head size is derived inside the model: hidden_size // num_heads
    hidden_size=32,
    num_blocks=4,
    dropout=0.1,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)
train(
    model=m,
    optimizer=optimizer, 
    batch_size=16, 
    block_size=8,  # must not exceed the model's block_size (position embedding has 8 rows)
    num_steps=5000, 
    eval_iters=100
)

In [274]:
m.eval()  # disable dropout while sampling
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 100)
print(decode(out[0].tolist()))


Wall
Selt men wash litter awwrongecey in and nain,
Afur hed, what
Turull's of 'tannr. I S Madmen I p


# Final model

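Scale the model up and see how good it can get. As noted at the top, this configuration should be trained on a GPU to reach a loss close to 1.4. The run recorded below was interrupted early (only one evaluation was logged), so the sample that follows comes from a model that has barely started training.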
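
To get a feel for how large the configuration in the next cell is, the parameter count can be worked out by hand from the layer definitions above. This is a back-of-the-envelope sketch; `vocab_size = 65` is an assumption, since `len(chars)` is not shown in this section.

```python
# Assumption: vocab_size = 65; len(chars) is not shown in this section.
vocab_size, block_size = 65, 256
hidden_size, num_heads, num_blocks = 384, 6, 6
head_size = hidden_size // num_heads            # 64, as computed inside the model


def linear(n_in, n_out):
    """Parameters of a torch.nn.Linear with bias."""
    return n_in * n_out + n_out


attention = num_heads * 3 * linear(hidden_size, head_size)   # k, q, v for every head
attention += linear(head_size * num_heads, hidden_size)      # output projection
feedforward = linear(hidden_size, 4 * hidden_size) + linear(4 * hidden_size, hidden_size)
layernorm = 2 * hidden_size                                   # weight and bias
block = attention + feedforward + 2 * layernorm

total = (
    vocab_size * hidden_size        # token embedding
    + block_size * hidden_size      # position embedding
    + num_blocks * block
    + layernorm                     # final LayerNorm
    + linear(hidden_size, vocab_size)
)
print(f"{total:,} parameters")
```

Under these assumptions the model has roughly 10.8 million parameters, almost all of them in the six residual blocks. Once the model is built, the figure can be cross-checked with `sum(p.numel() for p in m.parameters())`.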

In [288]:
block_size = 256
m = DropoutAttentionModel(
    block_size=block_size, 
    vocab_size=len(chars), 
    num_heads=6,  # head size is derived inside the model: hidden_size // num_heads = 64
    hidden_size=384,
    num_blocks=6,
    dropout=0.2,
)
optimizer = torch.optim.AdamW(m.parameters(), lr=0.0003)
train(
    model=m, 
    optimizer=optimizer, 
    batch_size=64,  
    block_size=block_size,
    num_steps=5000,
    eval_iters=500,
)

{'train_loss': 3.638077735900879, 'val_loss': 3.6654458045959473}


KeyboardInterrupt: 
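
To put the logged loss in context, it helps to compare it with the loss of a model that guesses every character with equal probability, which is `ln(vocab_size)`. A small sketch, assuming a vocabulary of 65 characters (`len(chars)` is not shown in this section):

```python
import math

vocab_size = 65                      # assumption: len(chars) is not shown in this section
uniform_loss = math.log(vocab_size)  # cross-entropy of guessing every character equally

logged_val_loss = 3.6654458045959473  # from the output above
target_loss = 1.4                     # the figure quoted in the introduction

for name, loss in [("uniform", uniform_loss), ("logged", logged_val_loss), ("target", target_loss)]:
    # exp(loss) is the perplexity: the effective number of equally likely characters.
    print(f"{name:8s} loss={loss:.3f} perplexity={math.exp(loss):.1f}")
```

The logged value sits only a little below the uniform baseline, which is consistent with the run having been stopped very early.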

In [290]:
m.eval()  # disable dropout while sampling
x = torch.zeros(1, 1, dtype=torch.long)
out = m.generate(x, 1000)
print(decode(out[0].tolist()))


RUMA'st theas aro nod ond e sheru
GRWesea he thoul feafr, 'st:
Sthe lof lo, ishon h!
He wnk Romime o he al haird,
Asovevery,
Ay shiace nor t he O we be oouthtotrlothe ivarepr the
Whor.
PFr; s hatherendis
RDWht o s;
NRo CARICKES:
Hoootit s:
K:


CLI lldo henacar y s pbrise youlf:
IIURESARDithe:
Ay e mpen sins wo hithowes,
AURTeced inonoiny hererdof ar histhicary un um h IO:
Ia gais ur, we:
ARI ou attsout TEve on n:
IOLARLelavayou herot I:

DUSe adl y hareasu ESe! y gouthe s pe wse d, oflapurince t m had learean arinore in
Pr qrou ano theind at:
Ce rago hue sigrk;
Wiceror y.
Yo, inoouce's IA Vhaist premayotel sigu ee--ongour be, desh mie. bror's:
:
OYUCEORou Fothisaker,
TUS:
Toustroe INA buthille whantheu hblee y is,
A pre. a I harldest COnerd gi wisees abel, wilo my?
YIEN K:
NEurt o,
TARESaner, t mm? nd, he mouryoroiomietrthingomeryond atef onour 'deduin LO:
Ayenkeghentu pank; med DO:
LAD:
Hetreseat yoverd f gorsith, g he'lithat cave bydeath. pelodare IV: huro;
Ar du.
Whandsu y, urco f