In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Seed PyTorch's global RNG so batch sampling, weight init and text generation
# are reproducible. The trailing semicolon suppresses the
# `<torch._C.Generator ...>` repr Jupyter would otherwise show as cell output.
SEED = 1337
torch.manual_seed(SEED);

<torch._C.Generator at 0x1d34c0cd1b0>

In [2]:
# Load the whole Tiny Shakespeare corpus into memory as a single UTF-8 string.
with open('tiny_shakespeare.txt', mode='r', encoding='utf-8') as corpus_file:
    shakespeare = corpus_file.read()

In [3]:
# Peek at the first 100 characters of the corpus.
print(shakespeare[0:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [4]:
# Vocabulary = every distinct character in the corpus, sorted so that the
# char <-> id mapping built below is deterministic.
unique_chars = set(shakespeare)
chars = sorted(unique_chars)
vocab_size = len(chars)
for line in (''.join(chars), vocab_size):
    print(line)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
# Character-level tokenizer: map each vocabulary character to an integer id and back.
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> id
itos = {i: ch for ch, i in stoi.items()}      # id -> char (exact inverse of stoi)


def encode(s):
    """Encode string ``s`` as a list of integer token ids.

    Raises:
        KeyError: if ``s`` contains a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable ``l`` of integer token ids back into a string."""
    return ''.join(itos[i] for i in l)

In [6]:
# Round-trip check: decoding the encoding of a string must give the string back.
sample_text = "Hello world!"
print(encode(sample_text))
print(decode(encode(sample_text)))
assert decode(encode(sample_text)) == sample_text, "encode/decode round trip failed"

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
Hello world!


In [7]:
# Tokenize the whole corpus into a 1-D tensor of integer character ids.
token_ids = encode(shakespeare)
data = torch.tensor(token_ids, dtype=torch.long)

In [8]:
# Hold out the final 10% of the token stream for validation. There is no
# shuffling, so the validation set is a contiguous tail of the corpus.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]
for split in (train_data, val_data):
    print(len(split))

1003854
111540


In [9]:
# One chunk of text packs several training examples: position t asks the model
# to predict y[t] given the first t+1 characters of x.
block_size = 9
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t, target in enumerate(y):
    context = x[:t + 1]
    print(f'when input is {context} the target is: {target}')

when input is tensor([18]) the target is: 47
when input is tensor([18, 47]) the target is: 56
when input is tensor([18, 47, 56]) the target is: 57
when input is tensor([18, 47, 56, 57]) the target is: 58
when input is tensor([18, 47, 56, 57, 58]) the target is: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58]) the target is: 47


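We will train the transformer on all of the above examples at once. This way the model learns to predict the next character from a context as short as a single character and as long as a full `block_size` characters. That is exactly what it needs at generation time, when it may start from very little context.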

This adds essentially no computational cost. All of these sub-contexts come from the same chunk of text and are processed in parallel (e.g. on a GPU).

In [10]:
batch_size = 4  # sequences to process in parallel
block_size = 8  # maximum sequence length to process

def get_batch(data, batch_size=None, block_size=None):
    """Sample a random batch of (input, target) windows from a 1-D token tensor.

    Args:
        data: 1-D tensor of token ids (e.g. ``train_data``).
        batch_size: number of windows to sample. Defaults to the notebook-level
            ``batch_size`` *at call time*, so a later cell that rebinds the
            global (e.g. the training cell) still takes effect.
        block_size: length of each window. Defaults to the notebook-level
            ``block_size`` at call time.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``. ``y`` is ``x``
        shifted one position ahead, so ``y[b, t]`` is the token that follows
        ``x[b, :t+1]``.

    Raises:
        ValueError: if ``data`` is too short to hold one window plus its target.
    """
    # Late-bound defaults: read the globals now, not when the function is defined.
    if batch_size is None:
        batch_size = globals()['batch_size']
    if block_size is None:
        block_size = globals()['block_size']
    if len(data) <= block_size:
        raise ValueError(
            f'data has {len(data)} tokens; need at least block_size + 1 = {block_size + 1}'
        )
    # Start offsets are drawn from [0, len(data) - block_size) so the shifted
    # target window data[i+1 : i+block_size+1] never runs past the end.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

In [11]:
xb, yb = get_batch(data=train_data)  # one (inputs, targets) batch for inspection

In [12]:
# Show the raw batch tensors, then unpack the first two rows into their
# individual (context -> target) examples.
for label, tensor in (('inputs:', xb), ('targets:', yb)):
    print(label)
    print(tensor.shape)
    print(tensor)
print('-----')
for row in range(2):
    for pos in range(block_size):
        context, target = xb[row, :pos + 1], yb[row, pos]
        print(f"when input is {context.tolist()} the target is: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-----
when input is [24] the target is: 43
when input is [24, 43] the target is: 58
when input is [24, 43, 58] the target is: 5
when input is [24, 43, 58, 5] the target is: 57
when input is [24, 43, 58, 5, 57] the target is: 1
when input is [24, 43, 58, 5, 57, 1] the target is: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target is: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target is: 39
when input is [44] the target is: 53
when input is [44, 53] the target is: 56
when input is [44, 53, 56] the target is: 1
when input is [44, 53, 56, 1] the target is: 58
when input is [44, 53, 56, 1, 58]

The bigram language model uses an embedding table of size `(vocab_size, vocab_size)`. Passing a token index pulls out the corresponding row of the table, and that row is used directly as the logits (scores) for the next character.

Looking up a `(B, T)` batch of token indices therefore returns a `(B, T, C)` tensor of logits, where: <br>
B = Batch (batch_size) <br>
T = Time (block_size) <br>
C = Channel (vocab_size) <br>

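The bigram language model does NOT use the rest of the sequence. Each prediction depends only on the current (most recent) character, so the model effectively has a context length of 1, even though it is fed blocks of `block_size` characters.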

In [13]:
class BigramLanguageModel(nn.Module):
    """Character-level bigram model: each token's logits depend only on that token.

    A single ``(vocab_size, vocab_size)`` embedding table is used directly as a
    lookup table of next-token logits: row ``i`` holds the unnormalized scores
    for the character that follows token ``i``.
    """

    # initialize (embed)
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    # forward pass
    def forward(self, idx, targets=None):
        """Look up next-token logits and optionally compute the cross-entropy loss.

        Args:
            idx: ``(B, T)`` tensor of token ids.
            targets: optional ``(B, T)`` tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without ``targets``: logits are ``(B, T, C)`` and
            loss is ``None``. With ``targets``: logits are flattened to
            ``(B*T, C)`` (the layout ``F.cross_entropy`` expects) and loss is a
            scalar tensor.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)
        if targets is None:
            # generation: there is nothing to compare against
            return logits, None
        # training: cross_entropy wants (N, C) logits and (N,) class indices.
        # reshape (not view) also accepts non-contiguous targets, e.g. a
        # transposed or strided-sliced tensor.
        B, T, C = logits.shape
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    # generate next set of tokens up to max_new_tokens
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: ``(B, T)`` tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to sample.

        Returns:
            ``(B, T + max_new_tokens)`` tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # A bigram prediction depends only on the last token, so feeding
            # idx[:, -1:] yields the same logits as the full context while
            # keeping each step constant-cost instead of growing with length.
            logits, _ = self(idx[:, -1:])  # (B, 1, C)
            logits = logits[:, -1, :]  # (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)
        return idx

In [14]:
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
for item in (logits.shape, loss):
    print(item)

# Sample 100 tokens from the *untrained* model (expect random-looking text),
# starting from token id 0 (the newline character).
start = torch.zeros((1, 1), dtype=torch.long)
sampled = m.generate(start, max_new_tokens=100)
print(decode(sampled[0].tolist()))

torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)

l-QYjt'CL?jLDuQcLzy'RIo;'KdhpV
vLixa,nswYZwLEPS'ptIZqOZJ$CA$zy-QTkeMk x.gQSFCLg!iW3fO!3DGXAqTsq3pdgq


In [15]:
# AdamW over every parameter of the bigram model.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-3)

In [16]:
# Train with larger batches than the demo above. get_batch reads the global
# batch_size when it is called, so rebinding it here affects every batch below.
batch_size = 32
total_steps = 15_000
for step in range(total_steps):
    xb, yb = get_batch(train_data)          # sample a fresh training batch
    logits, loss = m(xb, yb)                # forward pass + loss
    optimizer.zero_grad(set_to_none=True)   # drop stale gradients
    loss.backward()                         # backpropagate
    optimizer.step()                        # update parameters
    if not step % 1_000:
        print(f'{step} / {total_steps} | loss: {round(loss.item(), 4)}')

0 / 15000 | loss: 4.6583
1000 / 15000 | loss: 3.7861
2000 / 15000 | loss: 3.1343
3000 / 15000 | loss: 2.6584
4000 / 15000 | loss: 2.6051
5000 / 15000 | loss: 2.5792
6000 / 15000 | loss: 2.5388
7000 / 15000 | loss: 2.4331
8000 / 15000 | loss: 2.4081
9000 / 15000 | loss: 2.4967
10000 / 15000 | loss: 2.4376
11000 / 15000 | loss: 2.4759
12000 / 15000 | loss: 2.4248
13000 / 15000 | loss: 2.4868
14000 / 15000 | loss: 2.3344


In [17]:
def generate_text(n_tokens, model=None):
    """Sample ``n_tokens`` new tokens from a model and print the decoded text.

    Args:
        n_tokens: number of new tokens to generate.
        model: object exposing ``generate(idx, max_new_tokens)``. Defaults to
            the notebook-level ``m``, looked up at call time, so existing calls
            keep working unchanged.
    """
    if model is None:
        model = m
    # Every sample starts from a single token with id 0 (the first character of
    # the sorted vocabulary).
    context = torch.zeros((1, 1), dtype=torch.long)
    sampled_ids = model.generate(context, max_new_tokens=n_tokens)[0].tolist()
    print(decode(sampled_ids))

In [18]:
generate_text(n_tokens=300)


TEY ishwarod, se ttha's; I ppry memitth we ieeelo me,
Fait Cloog se tt afre ce tim wary stuklel d lofran VID g
OLI'tof areris nde imowlmandise wineatingiomanh y Mave,
NCHegr.
Trs y thaurymeresththonglast ffomofo, thiles's Ble t ireowor-mito shigee mer
Wank
Myo I t Wh wiakn I ad y, LICophouplefarouk 


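The following is the mathematical trick at the heart of self-attention: letting each token aggregate information from itself and the tokens before it.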

In [19]:
# Toy batch: B sequences of T time steps with C channels each.
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
print(x.shape)
# "Bag of words": xbow[b, t] is the mean of x[b, 0..t], i.e. of the current
# token and every token before it. Deliberately written as a naive double loop.
xbow = torch.zeros((B, T, C))
for batch_idx in range(B):
    for time_idx in range(T):
        xbow[batch_idx, time_idx] = x[batch_idx, :time_idx + 1].mean(dim=0)
print('values')
print(x[0])
print('averages')
print(xbow[0])

torch.Size([4, 8, 2])
values
tensor([[ 1.8236, -1.7576],
        [-0.5178, -1.2025],
        [ 0.6484,  1.6954],
        [-0.8710, -1.2504],
        [ 0.9018, -0.0196],
        [-0.4445,  0.5347],
        [ 1.3271,  0.0285],
        [-0.3768, -0.5363]])
averages
tensor([[ 1.8236, -1.7576],
        [ 0.6529, -1.4801],
        [ 0.6514, -0.4216],
        [ 0.2708, -0.6288],
        [ 0.3970, -0.5069],
        [ 0.2568, -0.3333],
        [ 0.4097, -0.2816],
        [ 0.3113, -0.3135]])


The nested for loop above is inefficient. The 'mathematical trick' is to compute the same running averages with a single matrix multiplication.

In [20]:
torch.manual_seed(42)
# Lower-triangular ones: row i of `a` selects rows 0..i of `b` ...
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3, 2)).float()
# ... so `a @ b` is the running (cumulative) sum down the rows of `b`.
c = a @ b

print('a=', a, '-----', 'b=', b, '-----', 'c=', c, sep='\n')

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
-----
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
-----
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


Multiplying by the lower-triangular matrix of ones computes running sums of the rows of `b`. To get running averages instead, normalize each row of the triangular matrix so that it sums to 1.

In [21]:
torch.manual_seed(42)
# Same lower-triangular mask, but every row is rescaled to sum to 1 ...
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
# ... so `a @ b` is the running mean down the rows of `b`.
c = a @ b

print('a=', a, '-----', 'b=', b, '-----', 'c=', c, sep='\n')

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
-----
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
-----
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Returning to the toy example...

In [22]:
# Row-normalized lower-triangular weights: row t averages positions 0..t.
ones_lower = torch.tril(torch.ones(T, T))
wei = ones_lower / ones_lower.sum(dim=-1, keepdim=True)
print(wei)
# (T, T) @ (B, T, C): wei broadcasts over the batch, i.e.
# (B, T, T) @ (B, T, C) -> (B, T, C).
xbow2 = wei @ x
# The single matmul reproduces the nested-loop averages.
print(torch.allclose(xbow, xbow2))

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
True


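There is another way to build the same weight matrix. Fill the positions above the diagonal with $-\infty$ and apply softmax to each row. Since $e^{-\infty} = 0$, future positions get zero weight, and the remaining (equal) entries become a uniform average.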

In [23]:
tril = torch.tril(torch.ones(T, T))
print(tril)
# Start from all-zero affinities and forbid looking ahead by writing -inf above
# the diagonal.
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
print(wei)
# Row-wise softmax: exp(-inf) = 0 removes future positions, and the equal zeros
# that remain turn into a uniform average over positions 0..t.
wei = F.softmax(wei, dim=-1)
print(wei)
xbow3 = wei @ x
print(xbow3[0])
print(torch.allclose(xbow, xbow3))

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,

Now let's implement self-attention. The weights are no longer a simple average of the previous tokens. Instead they become data-dependent, so each token can attend more strongly to some earlier tokens than to others.

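We implement this with the bread and butter of self-attention: query and key vectors. Each token emits a query vector (what am I looking for?) and a key vector (what do I contain?). Each token's query is dotted with the keys of all tokens. Token pairs whose query and key are well aligned get a larger dot product, and therefore a larger attention weight after the softmax. Each token also emits a value vector (what I communicate if attended to), and the output is the attention-weighted sum of those values.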

$$ \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

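where $d_k$ is `head_size`. In the decoder setting below, the scores for future positions are additionally set to $-\infty$ before the softmax, so each token only attends to itself and earlier tokens.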

In [25]:
# Toy input: B sequences of T tokens with C channels each.
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# One attention head of size H. The projections are created in the order
# key, query, value.
head_size = 16  # H
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k, q = key(x), query(x)  # (B, T, H) each

# Affinities: (B, T, H) @ (B, H, T) -> (B, T, T), scaled by head_size**-0.5
# (the 1/sqrt(d_k) factor of scaled dot-product attention).
wei = (q @ k.transpose(-2, -1)) * head_size**-0.5

# Causal mask, then softmax: position t can only attend to positions <= t.
tril = torch.tril(torch.ones((T, T)))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

# Attention-weighted sum of the value vectors -> (B, T, H)
v = value(x)
out = wei @ v
print(out.shape)

torch.Size([4, 8, 16])
