# The dataset

First we will download the tinyshakespeare dataset and examine its contents.

In [206]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('input.txt') as f:
    text = f.read()

print(text[:100])

--2024-05-25 13:58:03--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-05-25 13:58:03 (18.9 MB/s) - ‘input.txt’ saved [1115394/1115394]

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [207]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print('Vocab size:', vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size: 65


# Simple character level tokenizer.

We will create a simple character level tokenizer. Each unique character in the dataset becomes a token mapped to an integer value. In practice character level tokenizers are not used.
Check out BPE tonezizers like SentencePiece and tiktoken.

In [208]:
class Tokenizer:
    def __init__(self, chars):
        self.char_to_ix = {ch: i for i, ch in enumerate(chars)}
        self.ix_to_char = {i: ch for i, ch in enumerate(chars)}
    
    def char_to_index(self, ch):
        return self.char_to_ix[ch]
    
    def index_to_char(self, ix):
        return self.ix_to_char[ix]

    def encode(self, text):
        return [self.char_to_index(ch) for ch in text]
    
    def decode(self, indices):
        return ''.join([self.index_to_char(ix) for ix in indices])


In [209]:
tokenizer = Tokenizer(chars)
tokens = tokenizer.encode("Hello")
print(tokens)
print(tokenizer.decode(tokens))

[20, 43, 50, 50, 53]
Hello


In [211]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)

print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [212]:
from sklearn.model_selection import train_test_split

train_data, test_data = train_test_split(data, test_size=0.1, shuffle=False)

In [213]:
batch_size = 64
block_size = 256

def get_batch(data):
    indices = torch.randint(0, data.size(0) - block_size, (batch_size,))

    x = torch.stack([data[i:i + block_size] for i in indices])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in indices])

    x = x.to(device)
    y = y.to(device)

    return x, y


In [214]:
xb, yb = get_batch(train_data)

print(xb.shape, yb.shape)
print("X:")
print(tokenizer.decode(xb[0].tolist()))
print("-" * 20)
print("Y:")
print(tokenizer.decode(yb[0].tolist()))

torch.Size([64, 256]) torch.Size([64, 256])
X:
 give me worship and quietness;
I like it better than a dangerous honour.
If Warwick knew in what estate he stands,
'Tis to be doubted he would waken him.

First Watchman:
Unless our halberds did shut up his passage.

Second Watchman:
Ay, wherefore else gu
--------------------
Y:
give me worship and quietness;
I like it better than a dangerous honour.
If Warwick knew in what estate he stands,
'Tis to be doubted he would waken him.

First Watchman:
Unless our halberds did shut up his passage.

Second Watchman:
Ay, wherefore else gua


In [215]:
import torch.nn as nn
import torch.nn.functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets = None):
        logits = self.embedding(idx) # (batch, seq_len, vocab_size)

        if targets is None:
            loss = None
        else:

            batch_size, seq_len, vocab_size = logits.shape
            logits = logits.view(batch_size * seq_len, vocab_size)
            targets = targets.view(batch_size * seq_len)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens = 25):
        # idx: (batch, seq_len)

        for _ in range(max_new_tokens):
            logits, _ = self(idx)                # (batch, seq_len, vocab_size) for logits

            # this preserves the each sequence in the batch, and the vocab logits, but only the last token in each sequence
            logits = logits[:, -1, :]                           # (batch, vocab_size)

            probs = F.softmax(logits, dim=-1)                   # (batch, vocab_size)

            idx_next = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            idx = torch.cat([idx, idx_next], dim=1)             # (batch, seq_len + 1)
        
        return idx

model = BigramLanguageModel(vocab_size)
model = model.to(device)

logits, loss = model(xb, yb)
print(logits.shape, loss)

idx = torch.zeros((1,1), dtype=torch.long).to(device)
logits = model.generate(idx, max_new_tokens=100)
print(tokenizer.decode(logits[0].tolist()))

torch.Size([16384, 65]) tensor(4.5788, device='cuda:0', grad_fn=<NllLossBackward0>)

mUl.,yeGomxtZXTJMOq-RKfaTyAlrGp'3
AAW?$3SPpppzd ebyrXfJwZOoZLS':AKfrq:$ax:
jXX dirBn,EkvBkPYTbG,YTnp


In [217]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

batch_size = 32
for step in range(10000):
    model.train()
    xb, yb = get_batch(train_data)

    logits, loss = model(xb, yb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"Loss {loss.item()}")

Loss 2.4549293518066406


In [238]:
def generate_text(model, tokenizer, seed_text, max_new_tokens):
    model.eval()

    idx = torch.tensor([tokenizer.encode(seed_text)], dtype=torch.long, device=device)
    idx = model.generate(idx, max_new_tokens)

    return tokenizer.decode(idx[0].tolist())

print(generate_text(model, tokenizer, "I say", 250))

I say, l hthelisonedr y o w!

WING thtre I thes h thanga owaice t bl; ththe.
HARAR:

inearaikinishig LUSTouccasingopilomy f f thoreryif okefo; y.
m owiny wh

FLIVIUCalis, ar the th ar ashaverhareceardy purr o mancow LON:
Th,
Cid; mamyok.
Lore we t'ledofer


# Improved Loss Estimation

The previous training loop gave a noisy loss metric and only considered the training loss. Here we are using an average loss over 100 evaluations for both training and validation data.

In [13]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()

    for split in ['train', 'val']:
        losses = []

        for step in range(100):
            X, Y = get_batch(train_data if split == 'train' else test_data)
            _, loss = model(X, Y)
            losses.append(loss.item())

        out[split] = sum(losses) / len(losses)

    model.train()
    return out

# Revised Training Loop

In [14]:
def train(model, optimizer, train_data, test_data, n_steps=1000):
    for step in range(n_steps):
        model.train()
        X, Y = get_batch(train_data)

        logits, loss = model(X, Y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 1000 == 0:
            print(estimate_loss())

In [15]:
train(model, optimizer, train_data, test_data, n_steps=10000)

{'train': 2.4531653833389284, 'val': 2.483633105754852}
{'train': 2.4555484461784363, 'val': 2.4888607978820803}
{'train': 2.451585259437561, 'val': 2.485650296211243}
{'train': 2.4584910559654234, 'val': 2.4844886326789855}
{'train': 2.449352834224701, 'val': 2.4828339052200317}
{'train': 2.4509687209129334, 'val': 2.4893119311332703}
{'train': 2.4535897898674013, 'val': 2.490015721321106}
{'train': 2.4513447952270506, 'val': 2.4801024127006532}
{'train': 2.4534989309310915, 'val': 2.4852124357223513}
{'train': 2.461052176952362, 'val': 2.482593240737915}


# Self-attention

We will start with a simple trick. We want every token to be "aware" of the context of the tokens that have occured before it (But not the tokens that follow).

**Idea** - We could make each token an average of the tokens that preceded it.

In [16]:
weights = torch.tril(torch.ones(block_size, block_size)).to(device)
weights = weights / weights.sum(1, keepdim=True) 

weights

tensor([[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0333, 0.0333, 0.0333,  ..., 0.0333, 0.0000, 0.0000],
        [0.0323, 0.0323, 0.0323,  ..., 0.0323, 0.0323, 0.0000],
        [0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.0312]],
       device='cuda:0')

We now have a set of weights that makes each token an average of the tokens that preceeded it. We can actually do the same thing with Softmax and tril (lower triangle)

In [17]:
tril = torch.tril(torch.ones(block_size, block_size)).to(device)
weights = torch.zeros(block_size, block_size).to(device)
weights = weights.masked_fill(tril == 0, float('-inf'))

weights = F.softmax(weights, dim=1)

weights

tensor([[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0333, 0.0333, 0.0333,  ..., 0.0333, 0.0000, 0.0000],
        [0.0323, 0.0323, 0.0323,  ..., 0.0323, 0.0323, 0.0000],
        [0.0312, 0.0312, 0.0312,  ..., 0.0312, 0.0312, 0.0312]],
       device='cuda:0')

# Data-dependant Averaging for Self-Attention

The information that is relevant to a given token when predicting future tokens is based only upon previous tokens. However, not all previously seen tokens carry equal importance. To that end, we parameterize this unequal importance using learned key and query vectors for each token. In this way, each token now is represented as a weighted average of all previous tokens, with more attention placed on some tokens vs others.

In [239]:
batch_size = 8
seq_len = 32
x = torch.randn(batch_size, seq_len, vocab_size).to(device)

head_size = 16
key = nn.Linear(vocab_size, head_size, bias=False).to(device)
query = nn.Linear(vocab_size, head_size, bias=False).to(device)
value = nn.Linear(vocab_size, head_size, bias=False).to(device)

K = key(x)  # (batch, seq_len, head_size)
Q = query(x)  # (batch, seq_len, head_size)
V = value(x)  # (batch, seq_len, head_size)

weights = Q @ K.transpose(-2, -1)  # (batch, seq_len, head_size) @ (batch, head_size, seq_len) -> (batch, seq_len, seq_len)

tril = torch.tril(torch.ones(seq_len, seq_len)).to(device)
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

output = weights @ V  # (batch, seq_len, seq_len) @ (batch, seq_len, head_size) -> (batch, seq_len, head_size)
output.shape

torch.Size([8, 32, 16])

# Scaling the attention.

In the orignal *Attention is All You Need* paper, an additional scaling is applied, dividing by the square root of the dimensionality of the key embeddings.

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This is done as a normlization step. If the key and query values are guassian with unit variance and mean 0 (are they?), multiplying the weighs by the values results in variance on the order of the head size. (Why?)

In [240]:
k = torch.randn(batch_size, seq_len, head_size).to(device)
q = torch.randn(batch_size, seq_len, head_size).to(device)
weights = q @ k.transpose(-2, -1)

print("K variance:", k.var().item())
print("Q variance:", q.var().item())

print("Weights variance:", weights.var().item())

weights = weights / (head_size ** 0.5)

print("Weights variance after scaling:", weights.var().item())

K variance: 0.9985864758491516
Q variance: 0.9695006012916565
Weights variance: 15.501200675964355
Weights variance after scaling: 0.9688250422477722


# Why does this matter?

The reason is the Softmax function. If the input to a Softmax function has very high values and very low values it will actually converge to one-hot vectors

(Can I find a paper about this?)

In [241]:
diffuse_weights = torch.Tensor([0.1, -0.1, 1.2, -0.8]).to(device)

print("Softmax of diffuse weights:", F.softmax(diffuse_weights, dim=0))
print("Softmax of high variance weights:", F.softmax(diffuse_weights * 100, dim=0))

Softmax of diffuse weights: tensor([0.1912, 0.1566, 0.5745, 0.0777], device='cuda:0')
Softmax of high variance weights: tensor([0., 0., 1., 0.], device='cuda:0')


In [243]:
# dropout = 0.1

class AttentionHead(nn.Module):
    MAX_SEQ_LEN = 1024

    def __init__(self, embedding_size, head_size):
        super().__init__()

        self.head_size = head_size

        self.key = nn.Linear(embedding_size, head_size, bias=False)
        self.query = nn.Linear(embedding_size, head_size, bias=False)
        self.value = nn.Linear(embedding_size, head_size, bias=False)

        # self.dropout = nn.Dropout(dropout)

        self.register_buffer('tril', torch.tril(torch.ones(AttentionHead.MAX_SEQ_LEN, AttentionHead.MAX_SEQ_LEN)).to(device))
    
    def forward(self, x):
        batch, seq_len, vocab_size = x.shape

        K = self.key(x)     # (batch, seq_len, embedding_size) -> (batch, seq_len, head_size)
        Q = self.query(x)   # (batch, seq_len, embedding_size) -> (batch, seq_len, head_size)
        V = self.value(x)   # (batch, seq_len, embedding_size) -> (batch, seq_len, head_size)

        weights = Q @ K.transpose(-2, -1)                               # (batch, seq_len, head_size) @ (batch, head_size, seq_len) -> (batch, seq_len, seq_len)
        
        weights = weights.masked_fill(self.tril[:seq_len, :seq_len] == 0, float('-inf'))
        weights = F.softmax(weights / (self.head_size ** 0.5), dim=-1)

        # weights = self.dropout(weights)

        return weights @ V


In [253]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()

        embedding_size = num_heads * head_size

        self.heads = nn.ModuleList([AttentionHead(embedding_size, head_size) for _ in range(num_heads)])
        self.projection = nn.Linear(num_heads * head_size, embedding_size)
        # self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.projection(x)
        # x = self.dropout(x)
        
        return x

In [254]:
class FeedForward(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.ReLU(),
            nn.Linear(4 * embed_size, embed_size),
            # nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

In [255]:
class LayerNorm(nn.Module):
    def __init__(self, embed_size, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(embed_size))
        self.beta = nn.Parameter(torch.zeros(embed_size))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

In [256]:
class Block(nn.Module):
    def __init__(self, embedding_size, num_heads):
        super().__init__()
        head_size = embedding_size // num_heads
        self.attn = MultiHeadAttention(head_size, num_heads)
        self.ff = FeedForward(embedding_size)

        self.ln1 = LayerNorm(embedding_size)
        self.ln2 = LayerNorm(embedding_size)
    
    def forward(self, x):
        x = self.attn(self.ln1(x)) + x
        x = self.ff(self.ln2(x)) + x
        return x

In [262]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = nn.Embedding(block_size, embedding_dim)

        self.blocks = nn.Sequential(*[Block(embedding_dim, num_heads=8) for _ in range(6)])

        self.ln = LayerNorm(embedding_dim)

        self.lm_head = nn.Linear(embedding_dim, vocab_size)
    
    def forward(self, idx, targets = None):
        batch_size, seq_len = idx.shape

        embeddings = self.embedding(idx) # (batch, seq_len, embedding_dim)
        positions = torch.arange(seq_len).to(device)

        x = embeddings + self.position_embedding(positions) # (batch, seq_len, embedding_dim)

        x = self.blocks(x) # (batch, seq_len, embedding_dim)
        
        x = self.ln(x) # (batch, seq_len, embedding_dim)

        logits = self.lm_head(x) # (batch, seq_len, vocab_size)

        if targets is None:
            loss = None
        else:

            batch_size, seq_len, vocab_size = logits.shape
            logits = logits.view(batch_size * seq_len, vocab_size)
            targets = targets.view(batch_size * seq_len)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens = 25):
        # idx: (batch, seq_len)

        print(idx.shape)
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]     # (batch, block_size)

            logits, _ = self(idx_cond) 
            print(idx.shape)               # (batch, seq_len, vocab_size) for logits

            # this preserves the each sequence in the batch, and the vocab logits, but only the last token in each sequence
            logits = logits[:, -1, :]                           # (batch, vocab_size)

            probs = F.softmax(logits, dim=-1)                   # (batch, vocab_size)

            idx_next = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            idx = torch.cat([idx, idx_next], dim=1)             # (batch, seq_len + 1)
        
        return idx
    
model = BigramLanguageModel(vocab_size, embedding_dim=64)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
train(model, optimizer, train_data, test_data, n_steps=10000)

{'train': 4.187445931434631, 'val': 4.204956769943237}


KeyboardInterrupt: 

In [263]:
idx = torch.zeros((1,1), dtype=torch.long).to(device)
logits = model.generate(idx, max_new_tokens=250)
print(tokenizer.decode(logits[0].tolist()))

torch.Size([1, 1])
torch.Size([1, 1])
torch.Size([1, 2])
torch.Size([1, 3])
torch.Size([1, 4])
torch.Size([1, 5])
torch.Size([1, 6])
torch.Size([1, 7])
torch.Size([1, 8])
torch.Size([1, 9])
torch.Size([1, 10])
torch.Size([1, 11])
torch.Size([1, 12])
torch.Size([1, 13])
torch.Size([1, 14])
torch.Size([1, 15])
torch.Size([1, 16])
torch.Size([1, 17])
torch.Size([1, 18])
torch.Size([1, 19])
torch.Size([1, 20])
torch.Size([1, 21])
torch.Size([1, 22])
torch.Size([1, 23])
torch.Size([1, 24])
torch.Size([1, 25])
torch.Size([1, 26])
torch.Size([1, 27])
torch.Size([1, 28])
torch.Size([1, 29])
torch.Size([1, 30])
torch.Size([1, 31])
torch.Size([1, 32])
torch.Size([1, 33])
torch.Size([1, 34])
torch.Size([1, 35])
torch.Size([1, 36])
torch.Size([1, 37])
torch.Size([1, 38])
torch.Size([1, 39])
torch.Size([1, 40])
torch.Size([1, 41])
torch.Size([1, 42])
torch.Size([1, 43])
torch.Size([1, 44])
torch.Size([1, 45])
torch.Size([1, 46])
torch.Size([1, 47])
torch.Size([1, 48])
torch.Size([1, 49])
torch.Size

In [199]:
torch.save(model.state_dict(), 'model.pth')

In [205]:
# Load the model

model2 = BigramLanguageModel(vocab_size, embedding_dim=384).to(device)
model2.load_state_dict(torch.load('model.pth'))

idx = torch.zeros((1,1), dtype=torch.long).to(device)
logits = model2.generate(idx, max_new_tokens=250)
print(tokenizer.decode(logits[0].tolist()))

torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 1, 1]) torch.Size([1, 1, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size([1, 2, 1])
torch.Size([1, 2, 1]) torch.Size