# toyGPT - a character-level GPT model trained on all of Shakespeare's works

## First let's explore the dataset

In [1]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text=f.read()

In [2]:
print("length of characters in text: ", len(text))

length of characters in text:  1115394


In [3]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


### We can now construct our vocabulary from the dataset

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("length of unique characters in text: ", vocab_size)

length of unique characters in text:  65


Let's now create the dictionaries and functions that will be used to encode and decode the text. A language model cannot consume raw text, so the input first has to be converted into integer tokens. Because our vocabulary is small (65 characters), we use the simplest possible scheme: each character maps to one integer. Models with larger vocabularies typically use a more complex encoding based on subword units, which is beyond the scope of this notebook.

In [5]:
ctoi = {c:i for i,c in enumerate(chars)}
itoc = {i:c for i,c in enumerate(chars)}

encode = lambda s: [ctoi[ch] for ch in s]  # string -> list of token ids (avoids shadowing the built-in str)
decode = lambda lst: ''.join([itoc[i] for i in lst])  # list of token ids -> string
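As a quick check of the idea, here is a minimal round-trip sketch on a toy string. The `toy_` names and the sample text are assumptions made for illustration, chosen so they do not overwrite the notebook's own `ctoi`, `itoc`, `encode`, and `decode`.

```python
# Toy corpus standing in for the Shakespeare text (an assumption for illustration).
toy_text = "hello world"

# Build the vocabulary and lookup tables the same way as the cell above.
toy_chars = sorted(set(toy_text))  # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
toy_ctoi = {c: i for i, c in enumerate(toy_chars)}
toy_itoc = {i: c for i, c in enumerate(toy_chars)}

def toy_encode(s):
    """Map a string to a list of integer token ids."""
    return [toy_ctoi[ch] for ch in s]

def toy_decode(ids):
    """Map a list of integer token ids back to a string."""
    return ''.join(toy_itoc[i] for i in ids)

ids = toy_encode("hello")
print(ids)              # [3, 2, 4, 4, 5]
print(toy_decode(ids))  # hello
```

Decoding an encoded string returns the original string, which is the property the rest of the notebook relies on. Any character outside the vocabulary would raise a `KeyError`.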

Now we can divide our dataset into training and validation sets. We will use 90% of the data for training and 10% for validation.

In [6]:
import torch

data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)

torch.Size([1115394])


In [7]:
split_idx = int(len(data)*0.9)
train_data = data[:split_idx]
val_data = data[split_idx:]

When we train the model we don't feed it all the data at once. That would be too computationally expensive. Instead we feed it in chunks, and we call the length of these chunks the block size. The important thing to note is that a single chunk translates to multiple training examples: every prefix of the chunk is a context, and the character that follows it is the target. For example, if we get a chunk that looks like this:

```
Hello World!
```

It can be split into the following training examples:

```
H -> e
He -> l
Hel -> l
Hell -> o
Hello -> (space)
Hello -> W
Hello W -> o
Hello Wo -> r
Hello Wor -> l
Hello Worl -> d
Hello World -> !
```

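In our case every position in a chunk contributes one training example: the characters up to that position form the context, and the next character is the target. A chunk of block_size + 1 characters therefore yields block_size training examples, because the last character is only ever used as a target. This is why we divide our dataset into chunks based on block_size.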
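The split shown for `Hello World!` can be reproduced with a few lines of plain Python. This is only a sketch on strings; the variable names are assumptions, and the notebook itself works on tensors of token ids.

```python
chunk = "Hello World!"

# Inputs are every character except the last; targets are the same text shifted by one.
inputs = chunk[:-1]
targets = chunk[1:]

# Each prefix of the inputs is a context; the character after it is the target.
pairs = [(inputs[:t + 1], targets[t]) for t in range(len(inputs))]

for context, target in pairs:
    print(f"{context!r} -> {target!r}")

print(len(chunk), "characters give", len(pairs), "examples")  # 12 characters give 11 examples
```

Using `!r` in the print makes the space target visible as `' '`.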

In [8]:
block_size = 8
x = train_data[:block_size]
y = train_data[1:block_size+1]

for i in range(block_size):
    context = x[:i+1]
    target = y[i]
    print(f"When the input is: {decode(context.tolist())}({context.tolist()}) the target is: {decode([target.item()])}({target})")

When the input is: F([18]) the target is: i(47)
When the input is: Fi([18, 47]) the target is: r(56)
When the input is: Fir([18, 47, 56]) the target is: s(57)
When the input is: Firs([18, 47, 56, 57]) the target is: t(58)
When the input is: First([18, 47, 56, 57, 58]) the target is:  (1)
When the input is: First ([18, 47, 56, 57, 58, 1]) the target is: C(15)
When the input is: First C([18, 47, 56, 57, 58, 1, 15]) the target is: i(47)
When the input is: First Ci([18, 47, 56, 57, 58, 1, 15, 47]) the target is: t(58)


Now that we have our chunks, we want to pass in batches of these chunks to the model. This is where batch_size comes in. We will take batch_size chunks and pass them to the model. This means that we will have batch_size * block_size training examples per batch.

In [9]:
torch.manual_seed(42)

batch_size = 4

def get_batch(split):
    dataset = train_data if split == 'train' else val_data
    start_idx = torch.randint(len(dataset) - block_size, (batch_size,))
    x = torch.stack([dataset[idx:idx+block_size] for idx in start_idx])
    y = torch.stack([dataset[idx+1:idx+block_size+1] for idx in start_idx])
    return x, y

xb, yb = get_batch('train')

for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When the input is: {decode(context.tolist())}({context.tolist()}) the target is: {decode([target.item()])}({target})")

When the input is: s([57]) the target is:  (1)
When the input is: s ([57, 1]) the target is: h(46)
When the input is: s h([57, 1, 46]) the target is: i(47)
When the input is: s hi([57, 1, 46, 47]) the target is: s(57)
When the input is: s his([57, 1, 46, 47, 57]) the target is:  (1)
When the input is: s his ([57, 1, 46, 47, 57, 1]) the target is: l(50)
When the input is: s his l([57, 1, 46, 47, 57, 1, 50]) the target is: o(53)
When the input is: s his lo([57, 1, 46, 47, 57, 1, 50, 53]) the target is: v(60)
When the input is:  ([1]) the target is: t(58)
When the input is:  t([1, 58]) the target is: h(46)
When the input is:  th([1, 58, 46]) the target is: e(43)
When the input is:  the([1, 58, 46, 43]) the target is: r(56)
When the input is:  ther([1, 58, 46, 43, 56]) the target is: e(43)
When the input is:  there([1, 58, 46, 43, 56, 43]) the target is:  (1)
When the input is:  there ([1, 58, 46, 43, 56, 43, 1]) the target is: c(41)
When the input is:  there c([1, 58, 46, 43, 56, 43, 1, 4

Now let's start with the most basic language model, a bigram model, and see how it performs on our dataset. It has a token embedding table of size vocab_size by vocab_size: the row for a given character holds the logits for the character that follows it, so each prediction depends only on the current character. The forward pass looks up these logits for the input and, when targets are provided, calculates the loss using the cross-entropy loss function.

We can also use the model to generate text. We will start by giving it a prompt and then we will let it generate the rest of the text.

In [10]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(42)

class BigramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, generate_tokens):
        for i in range(generate_tokens):
            logits, _ = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], 1)

        return idx
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(f"loss: {loss.item()}") # ideally should be -ln(1/65) = 4.174
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), 100)[0].tolist()))

torch.Size([32, 65])
loss: 4.886534690856934

o$,q&IWqW&xtCjaB?ij&bYRGkF?b; f ,CbwhtERCIfuWr,DzJERjhLlVaF&EjffPHDFcNoGIG'&$qXisWTkJPw
 ,b Xgx?D3sj
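The comment in the cell above gives -ln(1/65) = 4.174 as the ideal starting loss. A small sketch of where that number comes from, assuming a model that spreads its probability evenly over all 65 characters:

```python
import math

vocab_size = 65  # number of unique characters reported earlier in the notebook

# A model with no preference assigns probability 1/vocab_size to every character.
uniform_prob = 1.0 / vocab_size

# Cross-entropy for one example is -ln(probability assigned to the correct target).
baseline_loss = -math.log(uniform_prob)
print(f"{baseline_loss:.3f}")  # 4.174
```

The untrained model scores worse than this (about 4.89) because its randomly initialized logits are not uniform, so it is confidently wrong on some characters.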


We see that without any training the model generates text that looks nothing like what we would expect. So let's train it and see how it performs.

In [11]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [12]:
batch_size=32
for _ in range(10000):
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(f"loss: {loss.item()}")

loss: 2.5233194828033447


In [13]:
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), 1000)[0].tolist()))


QUDUThe chas.
F lisen tabr:
LI mus nk,
A: al l ayo cenghe's therinvar,
TEsen ithawaneit at islinerainy atsomo clour pad d wikn h,
HYy my Tholes:
it GBy ke m vilou xthazinderand llo chee lond Cld this lisesule wars, tirofof wnofan
Rou cthe p.

By hat celis ire m, aksthethe aur withAR wotoot.
Toy:me, of Ithed; bo r:
DWAy celowinoourne,
WIDYoukesu t I:f fowhilong bert irw:
I m;
ADWhit hor hy t I nd, billexve, war t, s
When re llyong thm ithinde!
Whem mire ow
MIAPet mad, trd br hay
ANG w t we illlaisthe:
CESk ewhaiowaue e;I'OND:
t m; br
Fergho br rosoulin rfe.

Lnoof by, bald woande: ay,
LABRKitirit t, ken,
Whisppal

And, r ar st
Blalist su s the,

AUC; har anorg mellban, ll w ny hand.

A s
I thitherol at ceres, sticco:
TI st ngere, t t!
IZABurnosoreivet at cay iendet ch ds frthanan g ilr INGrorsto itopllver hequleat anehmoqus t cthabyowoveal Bushanean orusun,
CO:
!

SWAUESow Sore t'SThomasth cor:
FO:
GOxt
Wherieatrerpethalfll t. fit RL:
I inondvedat ir'd icere.
Ben olan, te ENToullo ford

After 10,000 steps of optimization the loss on a training batch has dropped from about 4.89 to about 2.52, and the generated text now shows word-like chunks, line breaks, and speaker-style labels that loosely resemble the source. However, due to the simplicity of the model, which only looks at a single previous character, it still has a lot of trouble generating real words and sentences.

## Mathematical trick for self-attention

We want to move beyond using just the previous character to inform the next character. A simple improvement would be to use the average of all previous embeddings in our context. One way to perform this is to calculate the average using a for loop as shown below.

In [28]:
B, T, C = 4, 8, 2
torch.manual_seed(1337)
x = torch.randn((B,T,C))
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = torch.mean(x[b, :t+1], 0)

print((x[0,0]+x[0,1])/2)
print(xbow[0,1])

tensor([-0.0894, -0.4926])
tensor([-0.0894, -0.4926])


So we can see how the for loop can be used to calculate the average until time t. However, this is not very efficient and we can use a mathematical trick to improve the efficiency. We will be using matrix multiplication which will vectorize the operation and make it much faster.

In [31]:
weights = torch.tril(torch.ones((T,T)))
weights /= weights.sum(1, keepdim=True)
xbow2 = weights @ x
print(xbow[0])
print(xbow2[0])
print(torch.allclose(xbow, xbow2))

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
True


The `torch.tril` function returns the lower triangular part of the input matrix, with everything above the diagonal set to 0. When we divide each row by its sum, we get weights that sum to 1 and can therefore be used to calculate a weighted average of the embeddings. Using the lower triangular matrix ensures that row t only uses the embeddings up to and including time t, not the ones after it.
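To make the structure of the weight matrix concrete, here is a small sketch with three time steps. The size and the toy inputs are assumptions chosen so the averages are easy to check by hand.

```python
import torch

T_small = 3  # small size chosen for readability (an assumption for this sketch)

# Lower triangular matrix of ones: row t has ones in columns 0..t.
mask = torch.tril(torch.ones((T_small, T_small)))

# Normalize each row so it sums to 1.
# Rows become [1, 0, 0], [1/2, 1/2, 0], [1/3, 1/3, 1/3].
avg_weights = mask / mask.sum(1, keepdim=True)
print(avg_weights)

# Toy inputs: three time steps with one feature each.
toy_x = torch.tensor([[2.0], [4.0], [6.0]])

# Row t of the product is the mean of toy_x[0..t]: 2, 3, 4.
print(avg_weights @ toy_x)
```

Each output row is the running mean of the inputs seen so far, which is exactly what the nested loop computed.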

Now we can also obtain the weights matrix another way that may seem more intuitive. Let's implement that and talk about the intuition behind it.

In [33]:
tril = torch.tril(torch.ones((T,T)))
weights = torch.zeros((T,T))
weights = torch.masked_fill(weights, tril == 0, float('-inf'))
weights = F.softmax(weights, 1)
xbow3 = weights @ x
print(xbow[0])
print(xbow3[0])
print(torch.allclose(xbow, xbow3))

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
True


This approach is more intuitive than the previous one because the values in the weights matrix before the softmax can be read as the strength of the connection between the current embedding and every other embedding. Initially we set the connections to all previous embeddings to 0 and to all future embeddings to -inf. The softmax function then turns each row into weights that sum to 1: the -inf entries become exactly 0, and because the remaining entries are all equal, the result is a uniform average over the past. As the model learns, it can adjust these connection strengths so that some previous embeddings count more than others.
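A minimal sketch of that last point, assuming the connection strengths are no longer all 0. The random `scores` below are a stand-in for values a trained model would produce; they are not taken from this notebook.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T_small = 4  # small size chosen for readability (an assumption for this sketch)

# Hypothetical connection strengths; random here, standing in for learned values.
scores = torch.randn((T_small, T_small))

# Block every future position by setting its score to -inf before the softmax.
causal_mask = torch.tril(torch.ones((T_small, T_small)))
masked_scores = scores.masked_fill(causal_mask == 0, float('-inf'))

# Softmax over each row: future positions get weight 0, the rest sum to 1.
attn_weights = F.softmax(masked_scores, dim=-1)
print(attn_weights)
```

The rows still sum to 1 and still ignore the future, but the weights over the past are no longer uniform.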