### Transformer Language Model (based on Adrej Karpathy's nanoGPT tutorial)

In [66]:
import torch
from torch import nn
from torch.nn import functional as F
torch.manual_seed(1234)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [1]:
# read the txt entire file into a single string
with open('tinyshakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
print(f"Total length of dataset in characters: {len(text)}")

Total length of dataset in characters: 1115394


In [16]:
# get vocabulary of characters
vocab = sorted(set(list(text)))
print("character vocabulary: ", vocab)
vocab_size = len(vocab)

character vocabulary:  ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [11]:
# tokenize the text
ctoi = {vocab[i]:i for i in range(vocab_size)}
itoc = {i:vocab[i] for i in range(vocab_size)}
encode = lambda s: [ctoi[c] for c in s]  # converts a string to integer token sequence
decode = lambda s: [itoc[ix] for ix in s]  # converts an integer token sequence to string of characters

In [17]:
print(encode('Hello world!'))
print(decode(encode('Hello world!')))

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
['H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!']


In [18]:
# tokenize the dataset into integer sequence, convert to torch tensor of type int64

data = torch.tensor(encode(text), dtype=torch.long) 
print(data.shape, data.dtype)
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [19]:
# train-validation splits (90-10)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

We now split the data into chuncks of size block_size. For each chunk, we create (input,target) pairs for next character prediction, where the input is a context window containing all characters preceding the target character. Note that the context sizes range from 1 up to block size, i.e. there will be block_size number of (input,target) pairs per chunk.

In [21]:
block_size = 8

# example showing the first chunk and all possible (input,target) pairs we can get from it
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"Context: {context} --> target: {target}")

Context: tensor([18]) --> target: 47
Context: tensor([18, 47]) --> target: 56
Context: tensor([18, 47, 56]) --> target: 57
Context: tensor([18, 47, 56, 57]) --> target: 58
Context: tensor([18, 47, 56, 57, 58]) --> target: 1
Context: tensor([18, 47, 56, 57, 58,  1]) --> target: 15
Context: tensor([18, 47, 56, 57, 58,  1, 15]) --> target: 47
Context: tensor([18, 47, 56, 57, 58,  1, 15, 47]) --> target: 58


Now lets create a batch generator which creates a batch of randomly selected blocks/chunks from the data

In [67]:
torch.manual_seed(1223)
batch_size = 4
block_size = 8

# data loader (generates a bvatch of randomly selected blocks)
def get_batch(split='train'):
    data = train_data if split=='train' else val_data

    # sample positions from which to grab blocks
    ix = torch.randint(len(data)-block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])      
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device) # move tensors to gpu 
    return x,y 

xbatch, ybatch = get_batch('train')
print("input batch: ")
print(xbatch.shape)
print(xbatch)
print("target batch: ")
print(ybatch.shape)     
print(ybatch)     
print("")

# context target pairs
print(f"A batch of {batch_size} blocks:")
for b in range(batch_size): # batch dimension
    print(f"\nBlock {b}:")
    for t in range(block_size):  # time dimension
        context = xbatch[b,:t+1]
        target = ybatch[b,t]
        print(f"Context: {context.tolist()} --> target: {target}")
    print("")

input batch: 
torch.Size([4, 8])
tensor([[17, 17, 26,  1, 17, 24, 21, 38],
        [14, 17, 24, 24, 13, 10,  0, 28],
        [63,  1, 47, 52,  1, 56, 43, 55],
        [56, 59, 57, 58,  1, 59, 54, 53]], device='cuda:0')
target batch: 
torch.Size([4, 8])
tensor([[17, 26,  1, 17, 24, 21, 38, 13],
        [17, 24, 24, 13, 10,  0, 28, 50],
        [ 1, 47, 52,  1, 56, 43, 55, 59],
        [59, 57, 58,  1, 59, 54, 53, 52]], device='cuda:0')

A batch of 4 blocks:

Block 0:
Context: [17] --> target: 17
Context: [17, 17] --> target: 26
Context: [17, 17, 26] --> target: 1
Context: [17, 17, 26, 1] --> target: 17
Context: [17, 17, 26, 1, 17] --> target: 24
Context: [17, 17, 26, 1, 17, 24] --> target: 21
Context: [17, 17, 26, 1, 17, 24, 21] --> target: 38
Context: [17, 17, 26, 1, 17, 24, 21, 38] --> target: 13


Block 1:
Context: [14] --> target: 17
Context: [14, 17] --> target: 24
Context: [14, 17, 24] --> target: 24
Context: [14, 17, 24, 24] --> target: 13
Context: [14, 17, 24, 24, 13] --> target

### Now let's create a pytorch-ified Bi-gram language model (will serve as a baseline for comparing the transformer model later on)

In [82]:
# model hyperparameters
batch_size = 64
block_size = 8
max_iters = 6000
learning_rate = 1e-2
eval_interval = 300
eval_iters = 100

In [83]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        '''
        Define model parameters
        '''
        # lookup table for finding logits for the next token (i.e. log of counts for all possible next token given input token)
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) # shape: (C,C)


    # forward pass takes in a batch of input token sequences of shape (B,T) and corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        # get logits for every input token
        logits = self.token_embedding_table(idx) # shape: (B,T,C)
        loss = None
        if targets is not None:
            B,T,C = logits.shape
            # reshape the logits and targets such that batch of input sequences are flattened into a single big input sequence
            # i.e. (B,T) --> (B*T)
            logits = logits.view(B*T,C) # reshaped to (B*T,C)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    # generates new sequences continuing from a given batch of context tokens
    def generate(self, idx, max_new_tokens):
        # batch of contexts, idx has shape (B,T)
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(idx) # shape: (B,T,C)
            # for each context sequence (in the batch), compute the probability of the next token using the logits of the last token in the context sequence
            logits = logits[:,-1,:] # shape: (B,C)
            probs = F.softmax(logits, dim=-1) 
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        return idx


In [84]:
# create a bigram language model and test it on the example batch
model = BigramLanguageModel(vocab_size=vocab_size)
# move model to device
m = model.to(device)
logits, loss = m(xbatch, ybatch)
print(logits.shape)
print(loss)

# generate a single sequences using the model with start token 0
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=100)[0].tolist()
# Decode integer tokens into characters
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))


torch.Size([32, 65])
tensor(4.5195, device='cuda:0', grad_fn=<NllLossBackward0>)

Generated sequence:
 
CPryCqOOqL 'xTaAvrDiqnGa$PxT
?k3Jy$vjOpus.ca$KsY?u?s$-qwc: !!rzuniPGJ'qvs3VMkxHUWY qeqmefwjYG,LLZ
xM


Generated sequence looks like gibberish, because model is untrained. We now train the model using a graident based optimiser.

In [85]:
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [154]:
# evaluating training and validation losses averaged over lots of batches
@torch.no_grad() # disable gradient tracking
def estimate_loss(model):
    out = {}
    model.eval() # swicth to inference mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split) 
            _, loss = m(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean() 
    model.train() # switch back to training mode
    return out       

In [87]:
for epoch in range(max_iters):
    # sample a batch of trainin data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")    


epoch: 0, training loss: 4.615415573120117, validation loss: 4.633676052093506
epoch: 300, training loss: 2.7299535274505615, validation loss: 2.7564690113067627
epoch: 600, training loss: 2.519435167312622, validation loss: 2.5423777103424072
epoch: 900, training loss: 2.488534927368164, validation loss: 2.5016705989837646
epoch: 1200, training loss: 2.4763269424438477, validation loss: 2.4994447231292725
epoch: 1500, training loss: 2.4710662364959717, validation loss: 2.4984192848205566
epoch: 1800, training loss: 2.466684103012085, validation loss: 2.4925339221954346
epoch: 2100, training loss: 2.467583656311035, validation loss: 2.4871132373809814
epoch: 2400, training loss: 2.4666929244995117, validation loss: 2.4932312965393066
epoch: 2700, training loss: 2.459779739379883, validation loss: 2.485511064529419
epoch: 3000, training loss: 2.454580783843994, validation loss: 2.482884168624878
epoch: 3300, training loss: 2.4595024585723877, validation loss: 2.4840993881225586
epoch: 3

Now let's try generating some text using the trained bigram model.

In [90]:
# generate a single sequences using the model with start token 0
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=300)[0].tolist()
# Decode integer tokens into characters
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))


Generated sequence:
 
TI owind,

Y:
A:
F awind fllaverat heroshilds CEdd n!
ISarext me s merd; w klishe t in
Thes
TE k s t thaifor than OLo, s th t ng, ffeey
AS:
Foeers ordld ann d itou makenche thir in. pr.
Hot h po cceplcy aue as YCE t! the
Th--


Thinongamarearr thes ci.
fekesmy tot eritond e;
Whyoke aknthin; y w,-h; 


This is better! It has similar syntactic structure as the training text and even has some correct words. The quality is still very bad because the context window is too small, only the previous character is used to predict the next character.

### Self-attention basics

In [94]:
# consider a batch of 4 sequences, with 8 token embeddings in each sequence, with embedding dims=2
B, T, C = 4, 8, 2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

Now for each sequence, we will create context windows randing from size 1 up to T. The context of size t is then computed as the average of the embeddings of all tokens up to position t. This simple averaging gives us a "bag of words" context which has no awareness of relative positions of the tokens.

In [98]:
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # embedding vectors for all tokens in the context window of size t --> shape: (t,C)  
        # compute bag of words context of size t for the bth sequence
        xbow[b,t] = xprev.mean(dim=0) 

A more efficient way of computing these context vectors is using matrix multiplication.

In [126]:
# consider a single sequence of 3 tokens with embedding dims of 3 --> shape: (3,3)
x = torch.tensor([[1,1], [0,1], [1,0]], dtype=torch.float32) # each row is an embedding vector
print(x)
print("")
# then to get the context of shape: (3,2), in which the the t-th row is the sum of the first t rows in x
# we can simply multiply a lower triangular matrix in which all elements above the diagonal are zero and the rest are 1s
W = torch.tril(torch.ones(3,3))
print(W)   
xbow = W @ x
print(xbow)
print("")

# however, we don't want the sum of embedding vectors, but the mean so instead do the following
W = torch.tril(torch.ones(3,3))
W = W / W.sum(dim=1, keepdims=True)
print(W)
xbow = W @ x
print(xbow)
 

tensor([[1., 1.],
        [0., 1.],
        [1., 0.]])

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[1., 1.],
        [1., 2.],
        [2., 2.]])

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[1.0000, 1.0000],
        [0.5000, 1.0000],
        [0.6667, 0.6667]])


So each row in the weights matrix W gives us the weights for summing up the embedding vectors of all the token. For the ith row, all weights after the ith column are zero,
which means the weighted sum for the ith context only includes tokens up to and including the ith position (in our example, we used uniform weights). So effectively, we're masking out all "future" tokens so that the context only depends on current and past tokens.

In [127]:
# then for a batch of sequences, we can do the following
x = torch.randn(B,T,C)
W = torch.tril(torch.ones(T,T))
W = W / W.sum(dim=1, keepdims=True)
xbow = W @ x # batch matrix multiplication


In [131]:
# another way to generate the lower triangular matrix needed to compute the mean context vectors is as follows
A = torch.tril(torch.ones(T,T))
W = torch.zeros((T,T))
W = W.masked_fill(A == 0, float('-inf')) # masked fill replaces every element in A which equals 0 with -infinity 
print(W)
# then by taking the softmax, we get the desired matrix 
W = F.softmax(W, dim=-1)
print(W)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


### In a self-attention head, these attention weights are not uniform, but are instead computed using the (key, query) vectors of each token in the sequence. Then the output of the attention-head is the weighted sum of value vectors of the tokens.

In [148]:
head_size = 16 # this is the dimensions of the key and query vectors

# the key and query vector are ontained by a linear transform of the embeddings obtained by multiplying with (C, head_size) matrices of learnable weights
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# then given a batch of token sequence embeddings
B, T, C = 4, 8, 2
x = torch.randn(B,T,C) # (B,T,C)

# we can compute the query, key and value vectors for each token as follows
k = key(x) # (B,T,h) where h=16 is the head_size
q = query(x) # (B,T,h)
v = value(x) # (B,T,h)

# then for a sequence of tokens, the (i,j)th attention weight is assigned to be the dot product of the query vector of ith token
# with key vector of jth token, so for the entuire batch we have the following
W = q @ k.transpose(-2,-1)  # we've transposed the key matrix: (B,T,h) --> (B,h,T), the shape of the matrix multiplication result is: (B,T,h) @ (B,h,T) = (B,T,T)

# we also scale the unnormalized weights to have variance of roughly 1
W = W * head_size**(-0.5)


print("\nun-normalized attention weights for first sequence in batch:\n")
print(W[0]) 

# then we apply the "temporal" masking so that the attention weights of future tokens is zero and also normalize so that weights sum to one
A = torch.tril(torch.ones(T,T))
W = W.masked_fill(A == 0, float('-inf')) # masked fill replaces every element in A which equals 0 with -infinity 
print("\nunnormalized masked attention weights:\n")
print(W[0]) 
W = F.softmax(W, dim=-1)
print("\nNormalized masked attention weights:\n")
print(W[0])

# the output of the self-attention head is then the sums of the token embeddings weighted by the attention weights
out = W @ v # (B,T,T) @ (B,T,h) = (B,T,h)


un-normalized attention weights for first sequence in batch:

tensor([[ 0.0545, -0.0593, -0.4230, -0.2394, -0.2138, -0.0163,  0.3961, -0.2642],
        [-0.0224,  0.0344,  0.2661,  0.1355,  0.1128,  0.0098, -0.2368,  0.1565],
        [-0.0832,  0.1832,  1.4977,  0.7088,  0.5580,  0.0534, -1.2894,  0.8469],
        [-0.1025,  0.1488,  1.1382,  0.5882,  0.4951,  0.0421, -1.0202,  0.6751],
        [-0.1214,  0.1572,  1.1731,  0.6261,  0.5387,  0.0440, -1.0676,  0.7085],
        [-0.0049,  0.0084,  0.0666,  0.0330,  0.0270,  0.0024, -0.0585,  0.0386],
        [ 0.1230, -0.2082, -1.6385, -0.8157, -0.6681, -0.0596,  1.4433, -0.9521],
        [-0.0872,  0.1431,  1.1201,  0.5616,  0.4624,  0.0409, -0.9899,  0.6534]],
       grad_fn=<SelectBackward0>)

unnormalized masked attention weights:

tensor([[ 0.0545,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0224,  0.0344,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0832,  0.1832,  1.4977,    -i

### Now using the idea of self-attention, we will design a better language model.

In [156]:
# first, we create a single self-attention head module
class Head(nn.Module):
    def __init__(self, block_size, embedding_dim, head_size):
        super().__init__()

        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size

        # define parameters
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)

        # non-parameter tensor of lower triangular ones
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

   
    # define forward pass, input shape: (B,T,C) where B=batch size, T=block_size, C=embedding_dim
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B,T,h) where h is the head_size
        q = self.query(x) # (B,T,h)
        v = self.value(x) # (B,T,h)
        W = q @ k.transpose(-2,-1)  * self.head_size**(-0.5) # (B,T,T)
        W = W.masked_fill(self.tril[:T, :T] == 0, float('-inf')) 
        W = F.softmax(W, dim=-1)
        out = W @ v
        return out


class ImprovedLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, embedding_dim, head_size):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size

        '''
        Define model parameters
        '''
        # token embedding layer 
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim) # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim) # shape: (T,C)

        # self-attention layer
        self.sa_head = Head(block_size, embedding_dim, head_size) # shape: (T,C,h)
        # output layer logits
        self.lm_head = nn.Linear(head_size, vocab_size) # shape: (h,vocab_size)


        # forward pass takes in a batch of input token sequences of shape (B,T) and corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        B, T =idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx) # (B,T,C)
        # add positional encoding
        pos_embeds = self.pos_embedding(torch.arange(T, device=device)) # (T,C) 
        x = token_embeds + pos_embeds # (B,T,C)
        # apply self-attention
        x = self.sa_head(x) # (B,T,h)
        # compute output logits
        logits = self.lm_head(x) # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B,T,vocab_size = logits.shape
            # reshape the logits and targets such that batch of input sequences are flattened into a single big input sequence
            # i.e. (B,T) --> (B*T)
            logits = logits.view(B*T,vocab_size) # reshaped to (B*T,vocab_size)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    

    # generates new sequences continuing from a given batch of context tokens
    def generate(self, idx, max_new_tokens):
        # batch of contexts, idx has shape (B,T)
        for _ in range(max_new_tokens):
            # since we're using positional encoding, we need to crop idx if input sequence length exceeds block size (keep last block_size tokens)
            idx_crop = idx[:,-block_size:] 
            # get predictions
            logits, _ = self(idx_crop) # shape: (B,T,C)
            # for each context sequence (in the batch), compute the probability of the next token using the logits of the last token in the context sequence
            logits = logits[:,-1,:] # shape: (B,C)
            probs = F.softmax(logits, dim=-1) 
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        return idx

Now let's train this improved model

In [157]:
batch_size = 32
block_size = 8
embedding_dim = 32
head_size = 32
max_iters = 5000
learning_rate = 1e-3
eval_interval = 500
eval_iters = 200

model = ImprovedLanguageModel(vocab_size=vocab_size, block_size=block_size, embedding_dim=embedding_dim, head_size=head_size)
# move model to device
m = model.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [159]:
# training loop
for epoch in range(max_iters):
    # sample a batch of trainin data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")    

epoch: 0, training loss: 2.394552230834961, validation loss: 2.399357318878174
epoch: 500, training loss: 2.379080295562744, validation loss: 2.4083101749420166
epoch: 1000, training loss: 2.3789455890655518, validation loss: 2.3955671787261963
epoch: 1500, training loss: 2.3779523372650146, validation loss: 2.398219585418701
epoch: 2000, training loss: 2.379781723022461, validation loss: 2.384589433670044
epoch: 2500, training loss: 2.3580307960510254, validation loss: 2.386016607284546
epoch: 3000, training loss: 2.371300458908081, validation loss: 2.394611120223999
epoch: 3500, training loss: 2.366600275039673, validation loss: 2.3787920475006104
epoch: 4000, training loss: 2.3598947525024414, validation loss: 2.3750083446502686
epoch: 4500, training loss: 2.3642451763153076, validation loss: 2.3858792781829834


In [160]:
# generate a single sequences using the model with start token 0
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=300)[0].tolist()
# Decode integer tokens into characters
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))


Generated sequence:
 
Fe dis spe shan thess sco win not wigrs,
Fithangts In pisis; pluson lpleon ighake y ealf y!
Whatierathimy fus gatwou le hay sthis,
Ato' omod, a it wiss out wheave louce ake I st I syok yow oris terwathy ut dat
FR ES:
The dares
O RID:
Thhe!
ANIDUSCHe seden schapreche!
Yome, ow IO:
we ds
Wiseifof arse


Note that the improved model acheives a slightly lower loss and generate slightly better sequences.

#### To acheive even better performance, we can use multiple attention heads in parallel and concatenate their outputs. This is called "multi-head attention"

In [176]:
class MultiHeadAttention(nn.Module):
    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        assert head_size % num_heads == 0, "head_size needs to be integer multiple of num_heads"
        self.heads = nn.ModuleList([Head(block_size, embedding_dim, head_size//num_heads) for _ in range(num_heads)])

    # in the forward pass, concatenate the outputs from all the attention heads
    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)


In [177]:
# improved language model with multi-head self attention
class ImprovedLanguageModelMultiHead(nn.Module):
    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.hum_heads = num_heads

        '''
        Define model parameters
        '''
        # token embedding layer 
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim) # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim) # shape: (T,C)

        # self-attention layer
        self.sa_heads = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads) # shape: (T,C,h)
        # output layer logits
        self.lm_head = nn.Linear(head_size, vocab_size) # shape: (h,vocab_size)


        # forward pass takes in a batch of input token sequences of shape (B,T) and corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        B, T =idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx) # (B,T,C)
        # add positional encoding
        pos_embeds = self.pos_embedding(torch.arange(T, device=device)) # (T,C) 
        x = token_embeds + pos_embeds # (B,T,C)
        # apply self-attention
        x = self.sa_heads(x) # (B,T,h)
        # compute output logits
        logits = self.lm_head(x) # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B,T,vocab_size = logits.shape
            # reshape the logits and targets such that batch of input sequences are flattened into a single big input sequence
            # i.e. (B,T) --> (B*T)
            logits = logits.view(B*T,vocab_size) # reshaped to (B*T,vocab_size)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    

    # generates new sequences continuing from a given batch of context tokens
    def generate(self, idx, max_new_tokens):
        # batch of contexts, idx has shape (B,T)
        for _ in range(max_new_tokens):
            # since we're using positional encoding, we need to crop idx if input sequence length exceeds block size (keep last block_size tokens)
            idx_crop = idx[:,-block_size:] 
            # get predictions
            logits, _ = self(idx_crop) # shape: (B,T,C)
            # for each context sequence (in the batch), compute the probability of the next token using the logits of the last token in the context sequence
            logits = logits[:,-1,:] # shape: (B,C)
            probs = F.softmax(logits, dim=-1) 
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        return idx

In [180]:
batch_size = 32
block_size = 8
embedding_dim = 32
head_size = 32
num_heads = 4
max_iters = 5000
learning_rate = 1e-3
eval_interval = 500
eval_iters = 200

model = ImprovedLanguageModelMultiHead(vocab_size=vocab_size, block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads)
# move model to device
m = model.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [181]:
# training loop
for epoch in range(max_iters):
    # sample a batch of trainin data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")     

epoch: 0, training loss: 4.176211357116699, validation loss: 4.173470973968506
epoch: 500, training loss: 2.6718404293060303, validation loss: 2.6823761463165283
epoch: 1000, training loss: 2.520519495010376, validation loss: 2.509432077407837
epoch: 1500, training loss: 2.4418094158172607, validation loss: 2.4534318447113037
epoch: 2000, training loss: 2.3901007175445557, validation loss: 2.402662992477417
epoch: 2500, training loss: 2.3557698726654053, validation loss: 2.3734166622161865
epoch: 3000, training loss: 2.3254342079162598, validation loss: 2.3540947437286377
epoch: 3500, training loss: 2.3081467151641846, validation loss: 2.324794054031372
epoch: 4000, training loss: 2.288052797317505, validation loss: 2.3057618141174316
epoch: 4500, training loss: 2.2802255153656006, validation loss: 2.296482801437378


In [182]:
# generate a single sequences using the model with start token 0
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=800)[0].tolist()
# Decode integer tokens into characters
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))      


Generated sequence:
 
Bure,
I gaisck yres thas wo his a omeptilosce;
Ne sor sear:
Ther my nerch hat thougm; obelalivedcorusowno an we, sey don aesndsof nod, Sot hofnucourand py thawke','tber.

QUEREIOn woweabe
Wour ind may gan pat am tis fors.

KICANUS:
RTume wen.
YOMjumy athe to dime fan thoul sold noule shives
Wo's Go sar thoust hat of kerifarn,
ame don tharteat te hor,
Mcam, both' thed aven hat detway by; ad, se hais wow I isnd to to dos ne;
yourdent hese by wert blo's thend of you to wir, a tou's buds us this cereryou sarcald, ut ET:
Sang ande;
Nem Paceim
And sopor picomee ath cher nownIm'D ETHETH!
Tha mernoo nobst sole I inse may
Cour
Milot a magish esilor,-e loon hit Cy:
Sat le, of ing,
Eugant bos!

Theld me theat dos's yousthearis
Tond of,-
The as ot noon, id domavend preght,
Per, ye-fis men, tit moncit 


#### Instead of calculating the output logits directly from the attention head output, we can put a feed forward (multi-layer perceptron) layer in between the attention head and output layer. This allows us to pack in extra computations and extract more meaningful representations from the attention output. This brings us to the Transformer (Decoder) Block. Each transformer block consists of a multihead self-attemtion layer followed by a feed-forward layer. Transformer blocks are also designed to be stacked up (similar to stacked CNNs and stacked RNNs) and therefore also incorporate residual conections and layer normalization to ensure that the gradients can backpropagate without any difficulty as the stack of transformer blocks become deeper.  

In [184]:
# a simple mlp 
class FeedForward(nn.Module):
    def __init__(self, head_size, num_heads):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(head_size, head_size),
            nn.ReLU()
        )
    
    # in the forward pass, concatenate the outputs from all the attention heads
    def forward(self, x):
        return self.net(x)

# a transformer block consisting of a multihead attention-layer followed by a feed-forward layer
class TransformerBlock(nn.Module):
    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        self.sa = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads)
        self.ff = FeedForward(head_size, num_heads)
        
    # in the forward pass, concatenate the outputs from all the attention heads
    def forward(self, x):
        x = self.sa(x)
        x = self.ff(x)
        return x


In [190]:
# language model with multiple transfop4mer blocks
class ImprovedLanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.hum_heads = num_heads

        '''
        Define model parameters
        '''
        # token embedding layer 
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim) # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim) # shape: (T,C)

        # stack of 3 transformer blocks
        self.blocks = nn.Sequential(
            TransformerBlock(block_size, embedding_dim, head_size, num_heads),
            TransformerBlock(block_size, embedding_dim, head_size, num_heads),
            TransformerBlock(block_size, embedding_dim, head_size, num_heads),
        )

        # output layer logits
        self.lm_head = nn.Linear(head_size, vocab_size) # shape: (h,vocab_size)


        # forward pass takes in a batch of input token sequences of shape (B,T) and corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        B, T =idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx) # (B,T,C)
        # add positional encoding
        pos_embeds = self.pos_embedding(torch.arange(T, device=device)) # (T,C) 
        x = token_embeds + pos_embeds # (B,T,C)
        # pass through transformer blocks
        x = self.blocks(x) # (B,T,h)
        # compute output logits
        logits = self.lm_head(x) # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B,T,vocab_size = logits.shape
            # reshape the logits and targets such that batch of input sequences are flattened into a single big input sequence
            # i.e. (B,T) --> (B*T)
            logits = logits.view(B*T,vocab_size) # reshaped to (B*T,vocab_size)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    

    # generates new sequences continuing from a given batch of context tokens
    def generate(self, idx, max_new_tokens):
        # batch of contexts, idx has shape (B,T)
        for _ in range(max_new_tokens):
            # since we're using positional encoding, we need to crop idx if input sequence length exceeds block size (keep last block_size tokens)
            idx_crop = idx[:,-block_size:] 
            # get predictions
            logits, _ = self(idx_crop) # shape: (B,T,C)
            # for each context sequence (in the batch), compute the probability of the next token using the logits of the last token in the context sequence
            logits = logits[:,-1,:] # shape: (B,C)
            probs = F.softmax(logits, dim=-1) 
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        return idx

In [186]:
batch_size = 32
block_size = 8
embedding_dim = 32
head_size = 32
num_heads = 4
max_iters = 5000
learning_rate = 1e-3
eval_interval = 500
eval_iters = 200

model = ImprovedLanguageModelTransformer(vocab_size=vocab_size, block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads)
# move model to device
m = model.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [187]:
# training loop
for epoch in range(max_iters):
    # sample a batch of trainin data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")     

epoch: 0, training loss: 4.1873459815979, validation loss: 4.184993743896484
epoch: 500, training loss: 3.0473477840423584, validation loss: 3.0452191829681396
epoch: 1000, training loss: 2.646038770675659, validation loss: 2.638247489929199
epoch: 1500, training loss: 2.5332300662994385, validation loss: 2.51843523979187
epoch: 2000, training loss: 2.46379017829895, validation loss: 2.4540207386016846
epoch: 2500, training loss: 2.4229698181152344, validation loss: 2.4220693111419678
epoch: 3000, training loss: 2.373483657836914, validation loss: 2.3759841918945312
epoch: 3500, training loss: 2.3541784286499023, validation loss: 2.3471546173095703
epoch: 4000, training loss: 2.3195488452911377, validation loss: 2.320119619369507
epoch: 4500, training loss: 2.293874740600586, validation loss: 2.301326274871826


#### Note that the validation loss has not decreased any further, in fact having the single multi-head attention layer seems to have been better. This because we need to add residual connections in the transformer block to improve gradient backpropagation and also include layer norms.

In [192]:
class MultiHeadAttention(nn.Module):
    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        assert head_size % num_heads == 0, "head_size needs to be integer multiple of num_heads"
        self.heads = nn.ModuleList([Head(block_size, embedding_dim, head_size//num_heads) for _ in range(num_heads)])

        # we also need to apply a linear projection to make the output residual the same dimension as the input
        self.proj = nn.Linear(head_size, embedding_dim) 

    # in the forward pass, concatenate the outputs from all the attention heads
    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        return x

# a simple mlp 
class FeedForward(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # we add extra computations by growing out the feed-forward hidden size by a factor of 4
        # we also add an extra linear layer at the end to project the residual back to same dimensions as input
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, 4*embedding_dim),  
            nn.ReLU(),
            nn.Linear(4*embedding_dim, embedding_dim), 
        )
    
    # in the forward pass, concatenate the outputs from all the attention heads
    def forward(self, x):
        return self.net(x)
    

# transformer block with residual connection 
class TransformerBlock(nn.Module):
    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        self.sa = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads)
        self.ff = FeedForward(embedding_dim)
        
    # in the forward pass, concatenate the outputs from all the attention heads
    def forward(self, x):
        # residual connection between input and multi-head attention output
        x = x + self.sa(x)
        # residual connection between multi-head attention output and feed-forward output
        x = x + self.ff(x) 
        return x

In [193]:
batch_size = 32
block_size = 8
embedding_dim = 32
head_size = 32
num_heads = 4
max_iters = 5000
learning_rate = 1e-3
eval_interval = 500
eval_iters = 200

model = ImprovedLanguageModelTransformer(vocab_size=vocab_size, block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads)
# move model to device
m = model.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

# training loop
for epoch in range(max_iters):
    # sample a batch of trainin data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")     

epoch: 0, training loss: 4.509718418121338, validation loss: 4.5064544677734375
epoch: 500, training loss: 2.395660638809204, validation loss: 2.4006667137145996
epoch: 1000, training loss: 2.260805606842041, validation loss: 2.271179437637329
epoch: 1500, training loss: 2.1748201847076416, validation loss: 2.2145626544952393
epoch: 2000, training loss: 2.126512050628662, validation loss: 2.1818671226501465
epoch: 2500, training loss: 2.0978622436523438, validation loss: 2.145966053009033
epoch: 3000, training loss: 2.0685031414031982, validation loss: 2.136467218399048
epoch: 3500, training loss: 2.052429437637329, validation loss: 2.1092963218688965
epoch: 4000, training loss: 2.033585548400879, validation loss: 2.1026153564453125
epoch: 4500, training loss: 2.010476589202881, validation loss: 2.103306531906128


In [194]:
# generate a single sequences using the model with start token 0
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=800)[0].tolist()
# Decode integer tokens into characters
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))  


Generated sequence:
 
That u not and maughter tayer,
But thears, I she you so ard hill peotild lovey.

KING ILANES:

ALANIO:
Site you,
Awhat and shines thour down!
 thre you his undip ably that up gry hereford min tandy he good landiouch he royad
as they to conthis for knath sper groyal give,
Noos it doppet, sech a const as this for the tifusord expoth boinint can troinses frest
Ansud well it this fent,
Thilesss: giaw?

The fale wouncegayou fir you shall
And this gas no fead Ethers, agaiass?

PALUCIO:
Now, of surhan of the secon:
Thy bloves, sher hind, tord's hithy with yours spleaced
fing, pluck,
Turns
Eateed we semast'd blovesandayst foble doth her his resarw goe for for clath geady span spelf'd is agrian semen;
And her;
here his as
ar king towers the seal!

KING HENRWOLA:
Fir msight
I low---me's to's as so c


#### We can see that adding residual conmnections inside the transformer block has significantly improved the performance. The quality of generated text has also improved dramatically.

#### The final icing on the cake is Layer Normalization, which serves a similar purpose as batch normalization. Layer normalization forces the output neurons of a linear layer to all have zero mean and unit variance. This ensures that activations and gradients are more stable and well-behaved during training. 