# Preliminaries

Load dataset

In [1]:
with open("shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [2]:
print(f"Number of chars in dataset: {len(text)}")

Number of chars in dataset: 1115394


In [3]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


Build vocabulary

In [18]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"{vocab_size=}")
"".join(chars)

vocab_size=65


"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

Create "char to index" and "index to char" mappings

In [9]:
stoi = { s:i for i, s in enumerate(chars) }
itos = { i:s for s, i in stoi.items() }

In [10]:
stoi["&"]

4

In [11]:
itos[4]

'&'

In [108]:
encode = lambda s: [stoi[c] for c in s] # encodes a string
decode = lambda e: "".join([itos[i] for i in e]) # decodes an encoding

In [109]:
encode("Hello world")

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]

In [110]:
decode([20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42])

'Hello world'

Tokenize dataset

In [16]:
import torch

data = torch.tensor(encode(text), dtype=torch.long)
data.shape

torch.Size([1115394])

In [17]:
data[:100]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

Split dataset into train and validation

In [19]:
n = int(0.9*len(data))

train_data = data[:n]
val_data = data[n:]

print(train_data.shape)
print(val_data.shape)

torch.Size([1003854])
torch.Size([111540])


Experimenting with context length

In [32]:
context_len = 8
sample = train_data[:context_len+1].tolist()
sample

[18, 47, 56, 57, 58, 1, 15, 47, 58]

In [34]:
print("Input ==> Target")
print("----------------")

for i in range(context_len):
    x = sample[:i+1]
    y = sample[i+1]
    print(f"{x} ==> {y}")

Input ==> Target
----------------
[18] ==> 47
[18, 47] ==> 56
[18, 47, 56] ==> 57
[18, 47, 56, 57] ==> 58
[18, 47, 56, 57, 58] ==> 1
[18, 47, 56, 57, 58, 1] ==> 15
[18, 47, 56, 57, 58, 1, 15] ==> 47
[18, 47, 56, 57, 58, 1, 15, 47] ==> 58


In [36]:
print("Input (with padding) ==> Target")
print("-------------------------------")

for i in range(context_len):
    x = [0] * (context_len-(i+1)) + sample[:i+1]
    y = sample[i+1]
    print(f"{x} ==> {y}")

Input (with padding) ==> Target
-------------------------------
[0, 0, 0, 0, 0, 0, 0, 18] ==> 47
[0, 0, 0, 0, 0, 0, 18, 47] ==> 56
[0, 0, 0, 0, 0, 18, 47, 56] ==> 57
[0, 0, 0, 0, 18, 47, 56, 57] ==> 58
[0, 0, 0, 18, 47, 56, 57, 58] ==> 1
[0, 0, 18, 47, 56, 57, 58, 1] ==> 15
[0, 18, 47, 56, 57, 58, 1, 15] ==> 47
[18, 47, 56, 57, 58, 1, 15, 47] ==> 58


Function to get a batch

In [92]:
torch.manual_seed(1337)

# params
batch_size = 4
context_len = 8

def get_batch(split):
    data = train_data if split == "train" else val_data
    # get "batch_size" number of random indices
    ixs = torch.randint(low=0, high=len(data)-context_len, size=(batch_size,))
    # get inputs
    x = torch.stack([data[i:i+context_len] for i in ixs])
    # get labels
    y = torch.stack([data[i+1:i+context_len+1] for i in ixs])
    return x, y

# get a sample batch
xb, yb = get_batch("train")

print("------")
print("inputs")
print(xb.shape)
print(xb)
print("-------")
print("targets")
print(yb.shape)
print(yb)

------
inputs
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
-------
targets
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


# BigramLanguageModel

What is the expected negative log likelihood of a completely uniform model?


In [93]:
-torch.log(torch.tensor(1/vocab_size))

tensor(4.1744)

Define the model

In [155]:
import torch.nn as nn
from torch.nn import functional as F

# bigram model doesn't really care about context length
# it simply predicts the next character based on the prev character

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # interesting how the output dimension is vocab_size
        self.embedding_table = nn.Embedding(vocab_size, vocab_size)
        # this is a very simple model
        # the embedding can be directly interpreted as the logits (predictions for next token)

    def forward(self, x, y=None):
        # x: (B, T)
        # y: (B, T)
        # embedding table essentially replaces each index with its corresponding embedding
        logits = self.embedding_table(x) # logits: (B, T, C)
        # calculate loss:
        if y is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            y = y.view(-1)
            loss = F.cross_entropy(logits, y)
        return logits, loss
    
    def generate(self, x, max_new_tokens):
        """
        quick note about this method:
        this is a simple bigram model, so it only needs the immediate previous
        character to predict the next token
        however this method's implementation feeds the entire previous context,
        and then we just extract the last prediction
        this is obviously inefficient, as we could simply pass the most recent token,
        to predict the next one
        however this method's implementation will scale to more complex architectures,
        which actually care about context length :)        
        """
        # x: (B, T)
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(x) # logits: (B, C, T)
            # for each prediction, get last timestep prediction
            logits = logits[:, -1, :] # (B, C)
            # calculate probs
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1, replacement=True)
            # append sampled index to the running sequence
            x = torch.cat((x, idx_next), dim=1) # (B, T+1)
        return x

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

torch.Size([256, 65])
tensor(4.5040, grad_fn=<NllLossBackward0>)


Sample generation before training

In [158]:
inputs = torch.zeros((1, 1), dtype=torch.long)
inputs

tensor([[0]])

In [159]:
generation = m.generate(inputs, max_new_tokens=100)
generation

tensor([[ 0, 19, 60, 19, 56, 46, 55, 43,  1, 37, 11, 41, 47, 41, 49, 24, 41, 37,
         63, 40, 22, 36, 44,  5,  4, 34, 64, 48, 46, 16, 40, 35, 27,  8,  8,  3,
         49, 25, 46, 63, 16, 61, 12, 18, 32, 41, 29, 52, 18, 15, 51, 54, 42, 16,
         13, 51, 43, 58, 41, 18,  4, 24, 22, 46, 46, 10, 53, 26, 48, 36,  9,  5,
         56, 62, 56, 25, 21, 12, 17, 32, 21,  0, 63, 14, 61,  6, 48, 42, 26, 36,
         42,  1,  1, 13, 35, 20, 27, 39, 47, 61, 27]])

In [160]:
print(decode(generation.squeeze().tolist()))


GvGrhqe Y;cickLcYybJXf'&VzjhDbWO..$kMhyDw?FTcQnFCmpdDAmetcF&LJhh:oNjX3'rxrMI?ETI
yBw,jdNXd  AWHOaiwO


Training

In [161]:
# initialize Adam optimizer with bigram model params
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [162]:
batch_size = 32

for steps in range(1, 10001):
    # get sample batch
    xb, yb = get_batch("train")
    # calculate loss
    logits, loss = m(xb, yb)
    # set gradients to zero
    optimizer.zero_grad(set_to_none=True)
    # calculate gradients
    loss.backward()
    # update parameters
    optimizer.step()

    if steps % 1000 == 0 or steps == 1:
        print(loss.item())

4.610204219818115
3.6542091369628906
3.1350882053375244
2.7605016231536865
2.66096830368042
2.5715842247009277
2.520346164703369
2.5460145473480225
2.6075916290283203
2.6352434158325195
2.654296636581421


Generation after training

In [163]:
inputs = torch.zeros((1, 1), dtype=torch.long)
generation = m.generate(inputs, max_new_tokens=200)
print("".join(decode(generation.squeeze().tolist())))


US:
Sesve her yos illaut!
Th y thallin

Micthe o, wn handssthisthoff t mat ictten w:
sul aresit ofo'Thinjesous and tyom
Tou:p-
inis semid!

AMay se falalsh O:
Thavorme RKEniro afrse Mua f Widrte icr,

