### Transformer Language Model (based on Andrej Karpathy's nanoGPT tutorial)

In [6]:
import torch
from torch import nn
from torch.nn import functional as F
torch.manual_seed(1234)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
# read the entire text file into a single string (decoded as UTF-8);
# the path is relative to the notebook's working directory
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [8]:
# sanity check on the load: report the corpus size in characters
print(f"Total length of dataset in characters: {len(text)}")

Total length of dataset in characters: 1115394


In [9]:
# character-level vocabulary: every distinct character of the corpus, in sorted order
vocab = sorted(set(text))
vocab_size = len(vocab)
print("character vocabulary: ", vocab)

character vocabulary:  ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [10]:
# character-level tokenizer: each vocabulary character gets the integer id of its
# position in the sorted vocabulary, and vice versa
ctoi = {ch: ix for ix, ch in enumerate(vocab)}
itoc = {ix: ch for ix, ch in enumerate(vocab)}


def encode(s):
    """Convert a string into a list of integer token ids.

    Raises KeyError if `s` contains a character that is not in the vocabulary.
    """
    return [ctoi[ch] for ch in s]


def decode(s):
    """Convert a sequence of integer token ids into a list of single characters.

    The result is a list, not a string; use "".join(...) on it to obtain text.
    """
    return [itoc[tok] for tok in s]

In [11]:
# round-trip check: encode a string to token ids, then decode the ids back
# (decode returns a list of single characters rather than a joined string)
print(encode('Hello world!'))
print(decode(encode('Hello world!')))

[20, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 2]
['H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!']


In [12]:
# tokenize the whole corpus into integer ids and store them as a 1-D torch tensor of type int64
# (one entry per character of `text`)

data = torch.tensor(encode(text), dtype=torch.long) 
print(data.shape, data.dtype)
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [13]:
# hold out the last 10% of the token stream for validation and train on the first 90%
# (a contiguous split, not a shuffled one)
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

We now split the data into chunks of size block_size. For each chunk, we create (input, target) pairs for next-character prediction, where the input is a context window containing all characters preceding the target character. Note that the context sizes range from 1 up to block_size, i.e. there will be block_size (input, target) pairs per chunk.

In [14]:
block_size = 8

# take the first chunk of the training data; y is x shifted one position to the right,
# so position t pairs the context x[:t+1] with the target y[t]
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t, target in enumerate(y):
    context = x[:t+1]
    print(f"Context: {context} --> target: {target}")

Context: tensor([18]) --> target: 47
Context: tensor([18, 47]) --> target: 56
Context: tensor([18, 47, 56]) --> target: 57
Context: tensor([18, 47, 56, 57]) --> target: 58
Context: tensor([18, 47, 56, 57, 58]) --> target: 1
Context: tensor([18, 47, 56, 57, 58,  1]) --> target: 15
Context: tensor([18, 47, 56, 57, 58,  1, 15]) --> target: 47
Context: tensor([18, 47, 56, 57, 58,  1, 15, 47]) --> target: 58


Now let's create a batch generator that produces a batch of randomly selected blocks/chunks from the data.

In [15]:
torch.manual_seed(1223)
batch_size = 4
block_size = 8


def get_batch(split='train'):
    """Sample a random batch of (input, target) blocks from one data split.

    Reads the module-level `train_data` / `val_data`, `batch_size`, `block_size` and
    `device` at call time, so later cells that reassign those names change what this
    function produces. Any `split` other than 'train' selects the validation data.

    Returns a pair of tensors of shape (batch_size, block_size) on `device`; the
    targets are the inputs shifted one position to the right.
    """
    source = train_data if split=='train' else val_data

    # random start offsets; the upper bound leaves room for the one-token target shift
    starts = torch.randint(len(source)-block_size, (batch_size,))
    inputs = torch.stack([source[s:s+block_size] for s in starts])
    targets = torch.stack([source[s+1:s+block_size+1] for s in starts])
    # move both tensors to the compute device before handing them out
    return inputs.to(device), targets.to(device)

# draw one training batch and show the raw input / target tensors
xbatch, ybatch = get_batch('train')
print("input batch: ")
print(xbatch.shape)
print(xbatch)
print("target batch: ")
print(ybatch.shape)
print(ybatch)
print("")

# spell out every (context, target) training example contained in the batch
print(f"A batch of {batch_size} blocks:")
for b, (inputs_row, targets_row) in enumerate(zip(xbatch, ybatch)):  # batch dimension
    print(f"\nBlock {b}:")
    for t in range(block_size):  # time dimension
        context = inputs_row[:t+1]
        target = targets_row[t]
        print(f"Context: {context.tolist()} --> target: {target}")
    print("")

input batch: 
torch.Size([4, 8])
tensor([[17, 17, 26,  1, 17, 24, 21, 38],
        [14, 17, 24, 24, 13, 10,  0, 28],
        [63,  1, 47, 52,  1, 56, 43, 55],
        [56, 59, 57, 58,  1, 59, 54, 53]], device='cuda:0')
target batch: 
torch.Size([4, 8])
tensor([[17, 26,  1, 17, 24, 21, 38, 13],
        [17, 24, 24, 13, 10,  0, 28, 50],
        [ 1, 47, 52,  1, 56, 43, 55, 59],
        [59, 57, 58,  1, 59, 54, 53, 52]], device='cuda:0')

A batch of 4 blocks:

Block 0:
Context: [17] --> target: 17
Context: [17, 17] --> target: 26
Context: [17, 17, 26] --> target: 1
Context: [17, 17, 26, 1] --> target: 17
Context: [17, 17, 26, 1, 17] --> target: 24
Context: [17, 17, 26, 1, 17, 24] --> target: 21
Context: [17, 17, 26, 1, 17, 24, 21] --> target: 38
Context: [17, 17, 26, 1, 17, 24, 21, 38] --> target: 13


Block 1:
Context: [14] --> target: 17
Context: [14, 17] --> target: 24
Context: [14, 17, 24] --> target: 24
Context: [14, 17, 24, 24] --> target: 13
Context: [14, 17, 24, 24, 13] --> target

### Now let's create a pytorch-ified Bi-gram language model (will serve as a baseline for comparing the transformer model later on)

In [16]:
# model hyperparameters (module-level names; get_batch, estimate_loss and the training
# loop read them at call time)
batch_size = 64        # number of blocks sampled per batch
block_size = 8         # context length, in tokens
max_iters = 6000       # number of optimisation steps in the training loop
learning_rate = 1e-2   # learning rate passed to the optimizer
eval_interval = 300    # report losses every this many steps
eval_iters = 100       # number of batches averaged per split when estimating the loss

In [17]:
class BigramLanguageModel(nn.Module):
    """Bigram character language model.

    The logits for the next token depend only on the current token: they are read
    directly from a (vocab_size, vocab_size) lookup table.
    """

    def __init__(self, vocab_size):
        """Create the model's only parameter, the (vocab_size, vocab_size) logit table."""
        super().__init__()
        # lookup table for finding logits for the next token
        # (i.e. log of counts for all possible next tokens given the input token)
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) # shape: (C,C)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B,T).
            targets: optional integer token ids of shape (B,T).

        Returns:
            (logits, loss). Without targets, logits has shape (B,T,C) and loss is None.
            With targets, logits is returned flattened to (B*T,C) and loss is the mean
            cross-entropy over all B*T positions.
        """
        # get logits for every input token
        logits = self.token_embedding_table(idx) # shape: (B,T,C)
        loss = None
        if targets is not None:
            B,T,C = logits.shape
            # F.cross_entropy expects (N,C) logits and (N,) targets, so flatten the
            # batch and time dimensions into one: (B,T) --> (B*T)
            logits = logits.view(B*T,C) # reshaped to (B*T,C)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad() # sampling never needs gradients; avoids building an autograd graph per step
    def generate(self, idx, max_new_tokens):
        """Extend each context in the batch by `max_new_tokens` sampled tokens.

        Args:
            idx: integer token ids of shape (B,T) holding the starting contexts.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(idx) # shape: (B,T,C)
            # only the logits at the last position are needed to predict the next token
            logits = logits[:,-1,:] # shape: (B,C)
            probs = F.softmax(logits, dim=-1)
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        return idx


In [18]:
# create a bigram language model and test it on the example batch
model = BigramLanguageModel(vocab_size=vocab_size)
# move model to device
m = model.to(device)
# targets are passed, so the returned logits are flattened to (B*T, vocab_size)
logits, loss = m(xbatch, ybatch)
print(logits.shape)
print(loss)

# generate a single sequence using the model, starting from token 0 as the context
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=100)[0].tolist()
# Decode integer tokens into characters (decode returns a list, hence the join below)
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))


torch.Size([32, 65])
tensor(4.3812, device='cuda:0', grad_fn=<NllLossBackward0>)

Generated sequence:
 
zuPdOf$JQmmgsWCjS,yoc.obDjcewb:by
ZzRAKGTx
Xn3YnigR,
T;t:'e$BszJiwljm?REKce'DuIN'-KY?fiwJpAS:bt-? M 


The generated sequence looks like gibberish because the model is untrained. We now train the model using a gradient-based optimiser.

In [19]:
# AdamW over all parameters of the current model `m`, with the learning rate set above
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [20]:
# evaluating training and validation losses averaged over lots of batches
@torch.no_grad() # disable gradient tracking
def estimate_loss(model):
    """Estimate the mean loss of `model` on the training and validation splits.

    For each split, `eval_iters` random batches are drawn with `get_batch` and their
    losses averaged (both names are read from module scope at call time).

    Args:
        model: module whose call `model(X, Y)` returns a `(logits, loss)` pair.

    Returns:
        dict with keys 'train' and 'val', each holding a 0-dim tensor with the mean loss.
    """
    out = {}
    was_training = model.training # remember the mode so it can be restored afterwards
    model.eval() # switch to inference mode
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            # evaluate the model that was passed in, not whatever global happens to be bound
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train(was_training) # restore the mode the caller had set
    return out

In [21]:
# training loop; note that each "epoch" here is one optimisation step on a single random
# batch, not a full pass over the training data
for epoch in range(max_iters):
    # sample a batch of training data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    # every eval_interval steps (including step 0, after its update) report the
    # train/validation losses averaged over eval_iters batches
    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")    


epoch: 0, training loss: 4.598198413848877, validation loss: 4.587023735046387
epoch: 300, training loss: 2.735320806503296, validation loss: 2.7429261207580566
epoch: 600, training loss: 2.524005174636841, validation loss: 2.54231858253479
epoch: 900, training loss: 2.4882287979125977, validation loss: 2.499865770339966
epoch: 1200, training loss: 2.4708216190338135, validation loss: 2.4979774951934814
epoch: 1500, training loss: 2.4736366271972656, validation loss: 2.495673894882202
epoch: 1800, training loss: 2.4659011363983154, validation loss: 2.486417770385742
epoch: 2100, training loss: 2.4783377647399902, validation loss: 2.4822909832000732
epoch: 2400, training loss: 2.4603683948516846, validation loss: 2.4931557178497314
epoch: 2700, training loss: 2.4593422412872314, validation loss: 2.4919941425323486
epoch: 3000, training loss: 2.4662582874298096, validation loss: 2.480480194091797
epoch: 3300, training loss: 2.4537932872772217, validation loss: 2.4837450981140137
epoch: 3

Now let's try generating some text using the trained bigram model.

In [22]:
# generate a single sequence using the trained model, starting from token 0 as the context
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=300)[0].tolist()
# Decode integer tokens into characters (decode returns a list, hence the join below)
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))


Generated sequence:
 

They?
Wha dre
By'haty.
MERThe may mem ad tees atomy acoucoowe,
My?
An went yollle
Tunis ghest he

NCory s wir theanoreimeca I Theist her peshainds bu? whamou larnd isthy t s, s dony atheeawioure s

Mono at the, wicekeee'lowhe he it wipise bethan horomered.

ARo, ovearannumime
Boumet ho t ator tull 


This is better! It has a similar syntactic structure to the training text and even contains some correct words. The quality is still very bad because the context window is too small: only the previous character is used to predict the next character.

### Self-attention basics

In [23]:
# consider a batch of 4 sequences, with 8 token embeddings in each sequence, with embedding dims=2
# (B = batch size, T = sequence length, C = embedding dimension)
B, T, C = 4, 8, 2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

Now for each sequence, we will create context windows ranging from size 1 up to T. The context of size t is then computed as the average of the embeddings of all tokens up to position t. This simple averaging gives us a "bag of words" context which has no awareness of the relative positions of the tokens.

In [24]:
# naive reference implementation: explicit Python loops over the batch and time dimensions
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # embedding vectors of tokens 0..t (inclusive) --> shape: (t+1,C)
        # bag of words context at position t of the bth sequence: mean over those embeddings
        xbow[b,t] = xprev.mean(dim=0) 

A more efficient way of computing these context vectors is using matrix multiplication.

In [25]:
# consider a single sequence of 3 tokens with embedding dims of 2 --> shape: (3,2)
x = torch.tensor([[1,1], [0,1], [1,0]], dtype=torch.float32) # each row is an embedding vector
print(x)
print("")
# then to get the context of shape: (3,2), in which the t-th row is the sum of the first t rows in x
# we can simply left-multiply by a lower triangular matrix in which all elements above the diagonal are zero and the rest are 1s
W = torch.tril(torch.ones(3,3))
print(W)   
xbow = W @ x
print(xbow)
print("")

# however, we don't want the sum of embedding vectors but the mean, so instead we
# normalize each row of W to sum to one before multiplying
W = torch.tril(torch.ones(3,3))
W = W / W.sum(dim=1, keepdims=True)
print(W)
xbow = W @ x
print(xbow)
 

tensor([[1., 1.],
        [0., 1.],
        [1., 0.]])

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[1., 1.],
        [1., 2.],
        [2., 2.]])

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[1.0000, 1.0000],
        [0.5000, 1.0000],
        [0.6667, 0.6667]])


So each row in the weights matrix W gives us the weights for summing up the embedding vectors of all the tokens. For the ith row, all weights after the ith column are zero,
which means the weighted sum for the ith context only includes tokens up to and including the ith position (in our example, we used uniform weights). So effectively, we're masking out all "future" tokens so that the context only depends on current and past tokens.

In [26]:
# then for a batch of sequences, we can do the following
x = torch.randn(B,T,C)
W = torch.tril(torch.ones(T,T))
W = W / W.sum(dim=1, keepdims=True) # each row now holds uniform weights over positions 0..t
# batch matrix multiplication: the (T,T) weights are broadcast over the batch, (T,T) @ (B,T,C) --> (B,T,C)
xbow = W @ x # batch matrix multiplication


In [27]:
# another way to generate the lower triangular matrix needed to compute the mean context vectors is as follows
A = torch.tril(torch.ones(T,T))
W = torch.zeros((T,T))
W = W.masked_fill(A == 0, float('-inf')) # masked fill sets every element of W to -infinity at the positions where A equals 0
print(W)
# then by taking the softmax over each row, we get the desired matrix
# (exp(-inf) = 0, and the remaining equal entries share the weight uniformly)
W = F.softmax(W, dim=-1)
print(W)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


### In a self-attention head, these attention weights are not uniform, but are instead computed using the (key, query) vectors of each token in the sequence. Then the output of the attention-head is the weighted sum of value vectors of the tokens.

In [28]:
head_size = 16 # this is the dimensions of the key and query vectors (the value vectors use the same size here)

# dimensions of the example batch: B sequences of T tokens, each embedded in C dims.
# They are defined before the projection layers because the layers need C as their
# input size (otherwise this cell silently relies on C from an earlier cell).
B, T, C = 4, 8, 2

# the key, query and value vectors are obtained by a linear transform of the embeddings,
# i.e. by multiplying with (C, head_size) matrices of learnable weights
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# then given a batch of token sequence embeddings
x = torch.randn(B,T,C) # (B,T,C)

# we can compute the query, key and value vectors for each token as follows
k = key(x) # (B,T,h) where h=16 is the head_size
q = query(x) # (B,T,h)
v = value(x) # (B,T,h)

# then for a sequence of tokens, the (i,j)th attention weight is assigned to be the dot product of the query vector of ith token
# with key vector of jth token, so for the entire batch we have the following
W = q @ k.transpose(-2,-1)  # we've transposed the key matrix: (B,T,h) --> (B,h,T), the shape of the matrix multiplication result is: (B,T,h) @ (B,h,T) = (B,T,T)

# we also scale the unnormalized weights to have variance of roughly 1
W = W * head_size**(-0.5)


print("\nun-normalized attention weights for first sequence in batch:\n")
print(W[0]) 

# then we apply the "temporal" masking so that the attention weights of future tokens is zero and also normalize so that weights sum to one
A = torch.tril(torch.ones(T,T))
W = W.masked_fill(A == 0, float('-inf')) # masked fill sets every element of W to -infinity at the positions where A equals 0
print("\nunnormalized masked attention weights:\n")
print(W[0]) 
W = F.softmax(W, dim=-1)
print("\nNormalized masked attention weights:\n")
print(W[0])

# the output of the self-attention head is then the sum of the value vectors weighted by the attention weights
out = W @ v # (B,T,T) @ (B,T,h) = (B,T,h)


un-normalized attention weights for first sequence in batch:

tensor([[-0.3255, -0.1359,  0.7023,  0.7843,  0.9132, -0.0642,  0.7554, -0.7008],
        [ 0.4533, -0.0343, -0.0129, -0.1061, -0.1898, -0.3820, -0.2144,  0.0173],
        [-1.8406,  0.1965, -0.1939,  0.1792,  0.4945,  1.6717,  0.6568,  0.1746],
        [-1.8140,  0.2287, -0.3421,  0.0223,  0.3180,  1.7213,  0.5163,  0.3221],
        [-1.9375,  0.2729, -0.4890, -0.1024,  0.2012,  1.8989,  0.4442,  0.4667],
        [ 1.1779,  0.0204, -0.5070, -0.7595, -1.0239, -0.7616, -0.9681,  0.5151],
        [-1.4515,  0.2315, -0.4831, -0.1960,  0.0198,  1.4796,  0.2314,  0.4657],
        [ 1.8251, -0.1966,  0.1995, -0.1704, -0.4822, -1.6611, -0.6450, -0.1803]],
       grad_fn=<SelectBackward0>)

unnormalized masked attention weights:

tensor([[-0.3255,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.4533, -0.0343,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.8406,  0.1965, -0.1939,    -i

### Now using the idea of self-attention, we will design a better language model.

In [29]:
# first, we create a single self-attention head module
class Head(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention."""

    def __init__(self, block_size, embedding_dim, head_size):
        """
        Args:
            block_size: maximum sequence length the causal mask is built for.
            embedding_dim: size C of the input embeddings.
            head_size: size h of the key, query and value vectors.
        """
        super().__init__()

        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size

        # define parameters
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)

        # non-parameter tensor of lower triangular ones; registered as a buffer so it
        # moves with the module across devices but is not trained
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: embeddings of shape (B,T,C) with T <= block_size and C = embedding_dim.

        Returns:
            Tensor of shape (B,T,h), where h is the head_size.
        """
        B, T, C = x.shape
        k = self.key(x) # (B,T,h) where h is the head_size
        q = self.query(x) # (B,T,h)
        v = self.value(x) # (B,T,h)
        # scaled query-key dot products
        W = q @ k.transpose(-2,-1)  * self.head_size**(-0.5) # (B,T,T)
        # hide future positions: after the softmax their weight is exactly zero
        W = W.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        W = F.softmax(W, dim=-1)
        # weighted sum of the value vectors
        out = W @ v # (B,T,h)
        return out


class ImprovedLanguageModel(nn.Module):
    """Language model with token + position embeddings and one self-attention head."""

    def __init__(self, vocab_size, block_size, embedding_dim, head_size):
        """
        Args:
            vocab_size: number of distinct tokens.
            block_size: maximum context length (number of position embeddings).
            embedding_dim: size C of the token and position embeddings.
            head_size: output size h of the attention head.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size

        # token embedding layer
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim) # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim) # shape: (T,C)

        # self-attention layer, maps (B,T,C) --> (B,T,h)
        self.sa_head = Head(block_size, embedding_dim, head_size)
        # output layer logits
        self.lm_head = nn.Linear(head_size, vocab_size) # shape: (h,vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B,T), with T <= block_size.
            targets: optional integer token ids of shape (B,T).

        Returns:
            (logits, loss). Without targets, logits has shape (B,T,vocab_size) and
            loss is None. With targets, logits is returned flattened to
            (B*T,vocab_size) and loss is the mean cross-entropy over all positions.
        """
        B, T = idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx) # (B,T,C)
        # add positional encoding; positions are created on the same device as the
        # input so the model does not depend on a notebook-level `device` variable
        pos_embeds = self.pos_embedding(torch.arange(T, device=idx.device)) # (T,C)
        x = token_embeds + pos_embeds # (B,T,C)
        # apply self-attention
        x = self.sa_head(x) # (B,T,h)
        # compute output logits
        logits = self.lm_head(x) # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B,T,vocab_size = logits.shape
            # F.cross_entropy expects (N,C) logits and (N,) targets, so flatten the
            # batch and time dimensions into one: (B,T) --> (B*T)
            logits = logits.view(B*T,vocab_size) # reshaped to (B*T,vocab_size)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad() # sampling never needs gradients; avoids building an autograd graph per step
    def generate(self, idx, max_new_tokens):
        """Extend each context in the batch by `max_new_tokens` sampled tokens.

        Args:
            idx: integer token ids of shape (B,T) holding the starting contexts.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # the position embedding only covers block_size positions, so keep just the
            # last block_size tokens; use the model's own block_size rather than a
            # global that later cells may reassign
            idx_crop = idx[:,-self.block_size:]
            # get predictions
            logits, _ = self(idx_crop) # shape: (B,T,vocab_size)
            # only the logits at the last position are needed to predict the next token
            logits = logits[:,-1,:] # shape: (B,vocab_size)
            probs = F.softmax(logits, dim=-1)
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        return idx

Now let's train this improved model

In [30]:
# hyperparameters for the single-head attention model (these rebind the module-level
# names that get_batch, estimate_loss and the training loop read)
batch_size = 32
block_size = 8
embedding_dim = 32
head_size = 32
max_iters = 5000
learning_rate = 1e-3
eval_interval = 500
eval_iters = 200

model = ImprovedLanguageModel(vocab_size=vocab_size, block_size=block_size, embedding_dim=embedding_dim, head_size=head_size)
# move model to device
m = model.to(device)

# a fresh optimizer bound to the new model's parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [31]:
# training loop; note that each "epoch" here is one optimisation step on a single random
# batch, not a full pass over the training data
for epoch in range(max_iters):
    # sample a batch of training data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    # every eval_interval steps report the train/validation losses averaged over eval_iters batches
    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")    

epoch: 0, training loss: 4.1714324951171875, validation loss: 4.177326202392578
epoch: 500, training loss: 2.7168636322021484, validation loss: 2.7289960384368896
epoch: 1000, training loss: 2.5374414920806885, validation loss: 2.538370132446289
epoch: 1500, training loss: 2.472914695739746, validation loss: 2.4781994819641113
epoch: 2000, training loss: 2.4376065731048584, validation loss: 2.4514846801757812
epoch: 2500, training loss: 2.4210972785949707, validation loss: 2.4393491744995117
epoch: 3000, training loss: 2.409301996231079, validation loss: 2.4124605655670166
epoch: 3500, training loss: 2.4062507152557373, validation loss: 2.4262123107910156
epoch: 4000, training loss: 2.3829004764556885, validation loss: 2.4174094200134277
epoch: 4500, training loss: 2.382000684738159, validation loss: 2.401224374771118


In [32]:
# generate a single sequence using the trained model, starting from token 0 as the context
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=300)[0].tolist()
# Decode integer tokens into characters (decode returns a list, hence the join below)
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))


Generated sequence:
 
Ma gi.
SOF noy:
LI I at by;
Wun thir bathr I igrithangesarde thald sat therin
Sho oms wizew cay my,
Thaze.

AULAfave cere ns tmy EOnd tpral the yoro picetr.

An Isome, kutinif chan ouse, tche my houch yidnt whe, se whame gowamendene, leare
GMa-tho we ther hasindem hin cmbigse!n?

A'F:
D, boural.
Hy,


Note that the improved model achieves a slightly lower loss and generates slightly better sequences.

#### To achieve even better performance, we can use multiple attention heads in parallel and concatenate their outputs. This is called "multi-head attention".

In [33]:
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run in parallel on the same input.

    `head_size` is the total output width: it is split evenly across `num_heads`
    heads, whose outputs are concatenated back to `head_size` features.
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        assert head_size % num_heads == 0, "head_size needs to be integer multiple of num_heads"
        per_head_size = head_size//num_heads
        head_list = []
        for _ in range(num_heads):
            head_list.append(Head(block_size, embedding_dim, per_head_size))
        self.heads = nn.ModuleList(head_list)

    def forward(self, x):
        """Run every head on `x` and join their outputs along the feature dimension."""
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)


In [34]:
# improved language model with multi-head self attention
class ImprovedLanguageModelMultiHead(nn.Module):
    """Language model with token + position embeddings and multi-head self-attention."""

    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads):
        """
        Args:
            vocab_size: number of distinct tokens.
            block_size: maximum context length (number of position embeddings).
            embedding_dim: size C of the token and position embeddings.
            head_size: total output size h of the multi-head attention layer.
            num_heads: number of parallel attention heads.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.hum_heads = num_heads # original (misspelled) attribute name, kept as an alias

        # token embedding layer
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim) # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim) # shape: (T,C)

        # multi-head self-attention layer, maps (B,T,C) --> (B,T,h)
        self.sa_heads = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads)
        # output layer logits
        self.lm_head = nn.Linear(head_size, vocab_size) # shape: (h,vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B,T), with T <= block_size.
            targets: optional integer token ids of shape (B,T).

        Returns:
            (logits, loss). Without targets, logits has shape (B,T,vocab_size) and
            loss is None. With targets, logits is returned flattened to
            (B*T,vocab_size) and loss is the mean cross-entropy over all positions.
        """
        B, T = idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx) # (B,T,C)
        # add positional encoding; positions are created on the same device as the
        # input so the model does not depend on a notebook-level `device` variable
        pos_embeds = self.pos_embedding(torch.arange(T, device=idx.device)) # (T,C)
        x = token_embeds + pos_embeds # (B,T,C)
        # apply self-attention
        x = self.sa_heads(x) # (B,T,h)
        # compute output logits
        logits = self.lm_head(x) # (B,T,vocab_size)

        loss = None
        if targets is not None:
            B,T,vocab_size = logits.shape
            # F.cross_entropy expects (N,C) logits and (N,) targets, so flatten the
            # batch and time dimensions into one: (B,T) --> (B*T)
            logits = logits.view(B*T,vocab_size) # reshaped to (B*T,vocab_size)
            targets = targets.view(B*T) # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad() # sampling never needs gradients; avoids building an autograd graph per step
    def generate(self, idx, max_new_tokens):
        """Extend each context in the batch by `max_new_tokens` sampled tokens.

        Args:
            idx: integer token ids of shape (B,T) holding the starting contexts.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # the position embedding only covers block_size positions, so keep just the
            # last block_size tokens; use the model's own block_size rather than a
            # global that later cells may reassign
            idx_crop = idx[:,-self.block_size:]
            # get predictions
            logits, _ = self(idx_crop) # shape: (B,T,vocab_size)
            # only the logits at the last position are needed to predict the next token
            logits = logits[:,-1,:] # shape: (B,vocab_size)
            probs = F.softmax(logits, dim=-1)
            # sample from the probability distribution to get next token
            idx_next = torch.multinomial(probs, num_samples=1) # shape: (B,1)
            # append to the current context
            idx = torch.cat((idx, idx_next), dim=1) # shape: (B,T+1)
        return idx

In [35]:
# hyperparameters for the multi-head attention model (these rebind the module-level
# names that get_batch, estimate_loss and the training loop read)
batch_size = 32
block_size = 8
embedding_dim = 32
head_size = 32   # total attention width; each of the num_heads heads gets head_size // num_heads
num_heads = 4
max_iters = 5000
learning_rate = 1e-3
eval_interval = 500
eval_iters = 200

model = ImprovedLanguageModelMultiHead(vocab_size=vocab_size, block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads)
# move model to device
m = model.to(device)

# a fresh optimizer bound to the new model's parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [36]:
# training loop; note that each "epoch" here is one optimisation step on a single random
# batch, not a full pass over the training data
for epoch in range(max_iters):
    # sample a batch of training data
    xb, yb = get_batch('train')
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True) 
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    # every eval_interval steps report the train/validation losses averaged over eval_iters batches
    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        print(f"epoch: {epoch}, training loss: {losses['train'].item()}, validation loss: {losses['val'].item()}")     

epoch: 0, training loss: 4.205431938171387, validation loss: 4.203817844390869
epoch: 500, training loss: 2.7132978439331055, validation loss: 2.7140204906463623
epoch: 1000, training loss: 2.540811538696289, validation loss: 2.5597493648529053
epoch: 1500, training loss: 2.448634386062622, validation loss: 2.465970993041992
epoch: 2000, training loss: 2.4019908905029297, validation loss: 2.4068872928619385
epoch: 2500, training loss: 2.3573694229125977, validation loss: 2.3764777183532715
epoch: 3000, training loss: 2.317030429840088, validation loss: 2.3318843841552734
epoch: 3500, training loss: 2.302166700363159, validation loss: 2.3186120986938477
epoch: 4000, training loss: 2.2924306392669678, validation loss: 2.3139798641204834
epoch: 4500, training loss: 2.263468027114868, validation loss: 2.293119430541992


In [37]:
# generate a single sequence using the trained model, starting from token 0 as the context
idx = torch.zeros((1,1), dtype=torch.long, device=device)
generated_seq = m.generate(idx, max_new_tokens=800)[0].tolist()
# Decode integer tokens into characters (decode returns a list, hence the join below)
generated_seq = decode(generated_seq)
print("\nGenerated sequence:\n","".join(generated_seq))      


Generated sequence:
 
Whertent
't the rer tis ther maly Rovere.

CSoo are prorow oul ibe comd ye hakre cod eend you, thilesfladew And thie sourlak forust
Marce
Anl all iqeay Os! soron ange, will ge? helll doleant gio your, I cove I aris the awithtellaverat heros hads ford ton ralexr ge sime Reck kliker tain
Thes
The xorr tha's?
LURTAROLOD ELIO MAMIK:
I Cyourr
Aned soord
Malfn thit yomd In he im:
Wilirsk.

Thavepo diestinhater shin Bin the
The wis be but mallas, Ead his.
fe lim:
QUCferither?
No hyoul ankimingfore, houth dis samos writ Couf co win not wighst, tall; wy In pis so pluson lxandie one my se livy!
What awathim fe jo'lllouele hist lon,
Hat,
And He a in wiss out my the loue arop I shily you youldeis ter Bucy At dat
lalat ofd.

MOENBENRICHK The fay tour eise, misch presterin mow she old, a, ingsefer torse


#### Instead of calculating the output logits directly from the attention head output, we can put a feed-forward (multi-layer perceptron) layer in between the attention head and the output layer. This allows us to pack in extra computation and extract more meaningful representations from the attention output. This brings us to the Transformer (Decoder) Block. Each transformer block consists of a multi-head self-attention layer followed by a feed-forward layer. Transformer blocks are also designed to be stacked up (similar to stacked CNNs and stacked RNNs) and therefore also incorporate residual connections and layer normalization to ensure that the gradients can backpropagate without any difficulty as the stack of transformer blocks becomes deeper.

In [38]:
# a simple mlp
class FeedForward(nn.Module):
    """Position-wise feed-forward layer: one linear map followed by a ReLU.

    Input and output both have `head_size` features. `num_heads` is accepted but not
    used by this layer.
    """

    def __init__(self, head_size, num_heads):
        super().__init__()
        layers = [nn.Linear(head_size, head_size), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the linear + ReLU stack to the last dimension of `x`."""
        return self.net(x)

# a transformer block consisting of a multihead attention-layer followed by a feed-forward layer
class TransformerBlock(nn.Module):
    """Plain transformer block: multi-head self-attention, then a feed-forward layer.

    This version has no residual connections and no layer normalization; the
    output of the attention layer is fed straight into the feed-forward layer.
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        attention = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads)
        feed_forward = FeedForward(head_size, num_heads)
        self.sa = attention
        self.ff = feed_forward

    # in the forward pass, run self-attention and feed its output through the feed-forward layer
    def forward(self, x):
        return self.ff(self.sa(x))


In [46]:
# language model with multiple transformer blocks
class ImprovedLanguageModelTransformer(nn.Module):
    """Language model made of token + position embeddings, three stacked
    TransformerBlocks and a linear output head that produces next-token logits.

    Args:
        vocab_size: number of distinct tokens.
        block_size: maximum context length; also the number of rows in the
            position-embedding table.
        embedding_dim: width C of the token and position embeddings.
        head_size: width handed to each TransformerBlock and used as the input
            width of the output head; must equal embedding_dim because the
            blocks are chained and feed the output head directly.
        num_heads: number of attention heads handed to each TransformerBlock.

    Raises:
        ValueError: if head_size != embedding_dim.
    """

    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads):
        super().__init__()

        # Fail early with a clear message instead of a shape mismatch deep inside forward().
        if head_size != embedding_dim:
            raise ValueError(
                f"head_size ({head_size}) must equal embedding_dim ({embedding_dim}) "
                "so that the stacked blocks and the output head line up"
            )

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.num_heads = num_heads
        # legacy misspelled attribute name, kept so existing code reading it keeps working
        self.hum_heads = num_heads

        # token embedding layer
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)  # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim)  # shape: (T,C)

        # stack of 3 transformer blocks
        self.blocks = nn.Sequential(
            TransformerBlock(block_size, embedding_dim, head_size, num_heads),
            TransformerBlock(block_size, embedding_dim, head_size, num_heads),
            TransformerBlock(block_size, embedding_dim, head_size, num_heads),
        )

        # output layer logits
        self.lm_head = nn.Linear(head_size, vocab_size)  # shape: (h,vocab_size)

    # forward pass takes in a batch of input token sequences of shape (B,T) and,
    # optionally, the corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Returns:
            (logits, loss). Without targets, logits has shape (B,T,vocab_size)
            and loss is None. With targets, logits is flattened to
            (B*T,vocab_size) and loss is the mean cross entropy.
        """
        B, T = idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx)  # (B,T,C)
        # add positional encoding; build the positions on the same device as the
        # input rather than relying on a notebook-level `device` variable
        positions = torch.arange(T, device=idx.device)
        pos_embeds = self.pos_embedding(positions)  # (T,C)
        x = token_embeds + pos_embeds  # (B,T,C)
        # pass through transformer blocks
        x = self.blocks(x)  # (B,T,h)
        # compute output logits
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            # flatten the batch of sequences into one long sequence, i.e. (B,T) --> (B*T)
            logits = logits.view(B * T, -1)  # reshaped to (B*T,vocab_size)
            targets = targets.view(B * T)  # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    # generates new sequences continuing from a given batch of context tokens
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample max_new_tokens tokens after the context idx of shape (B,T).

        Sampling runs without gradient tracking and in eval mode; the previous
        training/eval mode is restored before returning.

        Returns:
            Tensor of shape (B, T + max_new_tokens) holding context plus samples.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # the position table only has block_size rows, so keep just the
                # last block_size tokens (use the model's own block_size, not a global)
                idx_crop = idx[:, -self.block_size:]
                # get predictions
                logits, _ = self(idx_crop)  # shape: (B,T,vocab_size)
                # the logits at the last position give the next-token distribution
                logits = logits[:, -1, :]  # shape: (B,vocab_size)
                probs = F.softmax(logits, dim=-1)
                # sample from the probability distribution to get next token
                idx_next = torch.multinomial(probs, num_samples=1)  # shape: (B,1)
                # append to the current context
                idx = torch.cat((idx, idx_next), dim=1)  # shape: (B,T+1)
        finally:
            self.train(was_training)
        return idx

In [40]:
# hyperparameters
batch_size, block_size = 32, 8
embedding_dim, head_size, num_heads = 32, 32, 4
max_iters, learning_rate = 5000, 1e-3
eval_interval, eval_iters = 500, 200

# build the model from the hyperparameters above
model = ImprovedLanguageModelTransformer(
    vocab_size=vocab_size,
    block_size=block_size,
    embedding_dim=embedding_dim,
    head_size=head_size,
    num_heads=num_heads,
)
# move model to device
m = model.to(device)

# AdamW over all model parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [41]:
# training loop
for epoch in range(max_iters):
    # draw a batch of training data and compute its loss
    xb, yb = get_batch('train')
    _, loss = m(xb, yb)

    # clear old gradients, backpropagate, then take an optimizer step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # report the estimated losses only on every eval_interval-th iteration
    if epoch % eval_interval != 0:
        continue
    losses = estimate_loss(m)
    train_loss, val_loss = losses['train'].item(), losses['val'].item()
    print(f"epoch: {epoch}, training loss: {train_loss}, validation loss: {val_loss}")

epoch: 0, training loss: 4.143747806549072, validation loss: 4.144392013549805
epoch: 500, training loss: 3.0773422718048096, validation loss: 3.0846164226531982
epoch: 1000, training loss: 2.732046127319336, validation loss: 2.716034173965454
epoch: 1500, training loss: 2.5512638092041016, validation loss: 2.5352306365966797
epoch: 2000, training loss: 2.455603837966919, validation loss: 2.442411422729492
epoch: 2500, training loss: 2.409034490585327, validation loss: 2.4058051109313965
epoch: 3000, training loss: 2.3722264766693115, validation loss: 2.3759915828704834
epoch: 3500, training loss: 2.341911554336548, validation loss: 2.34486722946167
epoch: 4000, training loss: 2.3220529556274414, validation loss: 2.3295435905456543
epoch: 4500, training loss: 2.298309326171875, validation loss: 2.3365747928619385


#### Note that the validation loss has not decreased any further; in fact, having the single multi-head attention layer seems to have been better. This is because we need to add residual connections in the transformer block to improve gradient backpropagation and also include layer norms.

In [42]:
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, followed by a linear projection.

    head_size is split evenly across num_heads heads; the concatenated head
    outputs (head_size features) are projected to embedding_dim features so the
    result can be added to the block input as a residual.
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        assert head_size % num_heads == 0, "head_size needs to be integer multiple of num_heads"
        per_head_size = head_size // num_heads
        heads = []
        for _ in range(num_heads):
            heads.append(Head(block_size, embedding_dim, per_head_size))
        self.heads = nn.ModuleList(heads)

        # linear projection from the concatenated heads back to the input dimension
        self.proj = nn.Linear(head_size, embedding_dim)

    # in the forward pass, concatenate the outputs from all the attention heads and project them
    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        concatenated = torch.cat(head_outputs, dim=-1)
        return self.proj(concatenated)

# a simple mlp
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4*embedding_dim, apply ReLU, project back to embedding_dim."""

    def __init__(self, embedding_dim):
        super().__init__()
        # the hidden layer is 4x wider than the input to add extra computation;
        # the second linear layer brings the width back so the output can be
        # added to the input as a residual
        hidden_dim = 4 * embedding_dim
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        ]
        self.net = nn.Sequential(*layers)

    # in the forward pass, run the input through the MLP
    def forward(self, x):
        out = self.net(x)
        return out
    

# transformer block with residual connection
class TransformerBlock(nn.Module):
    """Transformer block with residual connections around both sub-layers.

    Computes h = x + self_attention(x), then returns h + feed_forward(h).
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads):
        super().__init__()
        attention = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads)
        feed_forward = FeedForward(embedding_dim)
        self.sa = attention
        self.ff = feed_forward

    # in the forward pass, each sub-layer's output is added back onto its own input
    def forward(self, x):
        # residual connection around the multi-head attention layer
        attended = x + self.sa(x)
        # residual connection around the feed-forward layer
        return attended + self.ff(attended)

In [43]:
# hyperparameters
batch_size, block_size = 32, 8
embedding_dim, head_size, num_heads = 32, 32, 4
max_iters, learning_rate = 5000, 1e-3
eval_interval, eval_iters = 500, 200

model = ImprovedLanguageModelTransformer(
    vocab_size=vocab_size,
    block_size=block_size,
    embedding_dim=embedding_dim,
    head_size=head_size,
    num_heads=num_heads,
)
# move model to device
m = model.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

# training loop
for epoch in range(max_iters):
    # sample a batch of training data and evaluate the loss on it
    xb, yb = get_batch('train')
    _, loss = m(xb, yb)

    # reset parameter gradients, run the backward pass, take an optimizer step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # periodically print the estimated training and validation losses
    report_now = (epoch % eval_interval == 0)
    if report_now:
        losses = estimate_loss(m)
        train_loss = losses['train'].item()
        val_loss = losses['val'].item()
        print(f"epoch: {epoch}, training loss: {train_loss}, validation loss: {val_loss}")

epoch: 0, training loss: 4.703917026519775, validation loss: 4.697531700134277
epoch: 500, training loss: 2.385143518447876, validation loss: 2.3879919052124023
epoch: 1000, training loss: 2.242703914642334, validation loss: 2.270765781402588
epoch: 1500, training loss: 2.184208393096924, validation loss: 2.206484317779541
epoch: 2000, training loss: 2.1248888969421387, validation loss: 2.173213005065918
epoch: 2500, training loss: 2.095299005508423, validation loss: 2.1506059169769287
epoch: 3000, training loss: 2.064594030380249, validation loss: 2.1176111698150635
epoch: 3500, training loss: 2.0370447635650635, validation loss: 2.1078944206237793
epoch: 4000, training loss: 2.025120973587036, validation loss: 2.0856645107269287
epoch: 4500, training loss: 2.007169723510742, validation loss: 2.0710246562957764


In [44]:
# generate a single sequence using the model, with token 0 as the only starting context
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_batch = m.generate(idx, max_new_tokens=800)
# take the only row of the batch and decode its integer tokens into characters
generated_seq = decode(generated_batch[0].tolist())
print("\nGenerated sequence:\n", "".join(generated_seq))


Generated sequence:
 

DUCHENTBULERCENT:
QUENEN IIZABHARD IV:
Thereringates the want say.

CORIOMENT:
AgUS:
The af headmen most hews, lokegise I with cartless sovon's and your, Weeith hath a it to should my look their to there him hald in Marnted ett
gent and it real'd-woord.

ICABELL:
I wath to thy the criend fathery maid all ade nos I my lyond Yor Wow my and the the affal did fill af my earton that I'll mark you abther;
Lecab the terefer and the frather knecce.

RICHARETIO:
Nrike Rotake a atoo.

FICKINGSOUCESTES:
Non I countlees from thane, and dess of excons, he dood 'the khes?
For Rest mere.

COMINay of sumpidiisbe mein, whell Ratioud I knot but you may let wilfnstren; if and me your death; no me of epere wor am'd thou dath with at?
ClALAND:
Hy, gunk most we them not weet Hech all Hore, funt my poor to twam


#### We can see that adding residual connections inside the transformer block has significantly improved the performance. The quality of the generated text has also improved dramatically.

#### The final icing on the cake is Layer Normalization, which serves a similar purpose as batch normalization. Layer normalization normalizes the features of each token (i.e. across the embedding dimension, rather than across the batch as in batch normalization) to zero mean and unit variance, and then applies a learnable scale and shift. This ensures that activations and gradients are more stable and well-behaved during training. We will use pre-layer norms, i.e. layer norm will be applied at the input of the multi-head attention layer and the input of the feed-forward layer.

#### We can also add dropout layers at the output of the multi-head attention and feed-forward layers (just before their outputs are added back through the residual connections) and also to the attention weights. Then all together, we have the following:

In [61]:
class Head(nn.Module):
    """A single causal self-attention head with dropout on the attention weights.

    Args:
        block_size: side length of the lower-triangular causal mask buffer.
        embedding_dim: number of input features C.
        head_size: number of output features h of the key/query/value projections.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, block_size, embedding_dim, head_size, dropout_rate):
        super().__init__()

        self.block_size, self.embedding_dim, self.head_size = block_size, embedding_dim, head_size

        # bias-free key, query and value projections (created in that order)
        for name in ('key', 'query', 'value'):
            setattr(self, name, nn.Linear(embedding_dim, head_size, bias=False))

        # lower-triangular matrix of ones, stored as a buffer rather than a parameter
        causal_ones = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_ones)

        self.dropout = nn.Dropout(dropout_rate)

    # forward pass, input shape: (B,T,C) where B=batch size, T=sequence length, C=embedding_dim
    def forward(self, x):
        _, T, _ = x.shape
        queries = self.query(x)  # (B,T,h)
        keys = self.key(x)  # (B,T,h)
        values = self.value(x)  # (B,T,h)

        # scaled dot-product scores between every pair of positions
        scale = self.head_size ** (-0.5)
        scores = (queries @ keys.transpose(-2, -1)) * scale  # (B,T,T)

        # positions where the mask is zero are set to -inf so softmax gives them zero weight
        blocked = self.tril[:T, :T] == 0
        scores = scores.masked_fill(blocked, float('-inf'))

        # normalize into attention weights, then apply dropout to them
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values


class MultiHeadAttention(nn.Module):
    """Parallel attention heads, followed by a linear projection and dropout.

    head_size is divided evenly between num_heads heads; their concatenated
    outputs are projected to embedding_dim features and passed through dropout.
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads, dropout_rate):
        super().__init__()
        assert head_size % num_heads == 0, "head_size needs to be integer multiple of num_heads"
        per_head_size = head_size // num_heads
        heads = []
        for _ in range(num_heads):
            heads.append(Head(block_size, embedding_dim, per_head_size, dropout_rate))
        self.heads = nn.ModuleList(heads)

        # linear projection from the concatenated heads back to the input dimension
        self.proj = nn.Linear(head_size, embedding_dim)
        self.dropout = nn.Dropout(dropout_rate)

    # in the forward pass, concatenate the head outputs, project them, then apply dropout
    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        projected = self.proj(torch.cat(head_outputs, dim=-1))
        return self.dropout(projected)

# a simple mlp
class FeedForward(nn.Module):
    """Two-layer MLP with dropout: expand to 4*embedding_dim, ReLU, project back, dropout."""

    def __init__(self, embedding_dim, dropout_rate):
        super().__init__()
        # the hidden layer is 4x wider than the input to add extra computation;
        # the second linear layer brings the width back to embedding_dim so the
        # output can be added to the input as a residual
        hidden_dim = 4 * embedding_dim
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(dropout_rate),
        ]
        self.net = nn.Sequential(*layers)

    # in the forward pass, run the input through the MLP
    def forward(self, x):
        out = self.net(x)
        return out
    

# transformer block with residual connection and layer norm
class TransformerBlock(nn.Module):
    """Pre-layer-norm transformer block.

    Computes h = x + self_attention(ln1(x)), then returns h + feed_forward(ln2(h)).
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads, dropout_rate):
        super().__init__()
        # multi-head attention layer
        attention = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads, dropout_rate)
        # feed-forward layer
        feed_forward = FeedForward(embedding_dim, dropout_rate)
        self.sa = attention
        self.ff = feed_forward
        # layer norms applied to the inputs of the attention and feed-forward layers
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    # in the forward pass, each sub-layer sees a normalized input and its output is added back onto the residual
    def forward(self, x):
        # residual connection around layer norm + multi-head attention
        attended = x + self.sa(self.ln1(x))
        # residual connection around layer norm + feed-forward
        return attended + self.ff(self.ln2(attended))
    

# language model with multiple transformer blocks
class TransformerLanguageModel(nn.Module):
    """Language model made of token + position embeddings, a stack of
    num_blocks TransformerBlocks, a final layer norm and a linear output head
    that produces next-token logits.

    Args:
        vocab_size: number of distinct tokens.
        block_size: maximum context length; also the number of rows in the
            position-embedding table.
        embedding_dim: width C of the embeddings and of every block's input/output.
        head_size: combined width of the attention heads inside each block
            (split evenly across num_heads).
        num_heads: number of attention heads per block.
        num_blocks: number of stacked transformer blocks.
        dropout_rate: dropout probability handed to every block.
    """

    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads, num_blocks, dropout_rate=0.2):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.num_heads = num_heads
        # legacy misspelled attribute name, kept so existing code reading it keeps working
        self.hum_heads = num_heads
        self.num_blocks = num_blocks

        # token embedding layer
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)  # shape: (vocab_size,C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim)  # shape: (T,C)

        # stack of transformer blocks
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(block_size, embedding_dim, head_size, num_heads, dropout_rate)
                for _ in range(num_blocks)
            ]
        )

        # we also add a layer norm before the final output layer
        self.ln_f = nn.LayerNorm(embedding_dim)

        # output layer logits; its input is the output of ln_f, which has
        # embedding_dim features (not head_size)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)  # shape: (C,vocab_size)

    # forward pass takes in a batch of input token sequences of shape (B,T) and,
    # optionally, the corresponding targets of shape (B,T)
    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Returns:
            (logits, loss). Without targets, logits has shape (B,T,vocab_size)
            and loss is None. With targets, logits is flattened to
            (B*T,vocab_size) and loss is the mean cross entropy.
        """
        B, T = idx.shape
        # get token embeddings
        token_embeds = self.token_embedding(idx)  # (B,T,C)
        # add positional encoding; build the positions on the same device as the
        # input rather than relying on a notebook-level `device` variable
        positions = torch.arange(T, device=idx.device)
        pos_embeds = self.pos_embedding(positions)  # (T,C)
        x = token_embeds + pos_embeds  # (B,T,C)
        # pass through transformer blocks
        x = self.blocks(x)  # (B,T,C)
        # apply layer norm
        x = self.ln_f(x)  # (B,T,C)
        # compute output logits
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            # flatten the batch of sequences into one long sequence, i.e. (B,T) --> (B*T)
            logits = logits.view(B * T, -1)  # reshaped to (B*T,vocab_size)
            targets = targets.view(B * T)  # reshaped to (B*T)
            # compute cross entropy loss (i.e. average negative log likelihood)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    # generates new sequences continuing from a given batch of context tokens
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample max_new_tokens tokens after the context idx of shape (B,T).

        Sampling runs without gradient tracking and in eval mode, so dropout is
        disabled while generating; the previous training/eval mode is restored
        before returning.

        Returns:
            Tensor of shape (B, T + max_new_tokens) holding context plus samples.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # the position table only has block_size rows, so keep just the
                # last block_size tokens (use the model's own block_size, not a global)
                idx_crop = idx[:, -self.block_size:]
                # get predictions
                logits, _ = self(idx_crop)  # shape: (B,T,vocab_size)
                # the logits at the last position give the next-token distribution
                logits = logits[:, -1, :]  # shape: (B,vocab_size)
                probs = F.softmax(logits, dim=-1)
                # sample from the probability distribution to get next token
                idx_next = torch.multinomial(probs, num_samples=1)  # shape: (B,1)
                # append to the current context
                idx = torch.cat((idx, idx_next), dim=1)  # shape: (B,T+1)
        finally:
            self.train(was_training)
        return idx



#### Training with the improved transformer block which has both residual connections and pre-layer norms and scaling up the network

In [66]:
from tqdm import tqdm

# hyperparameters for the scaled-up network
batch_size, block_size = 64, 256
embedding_dim = 384
head_size = embedding_dim
num_heads, num_blocks = 6, 6
dropout_rate = 0.2
max_iters, learning_rate = 5000, 5e-4
eval_interval, eval_iters = 500, 200

# build the model from the hyperparameters above
model = TransformerLanguageModel(
    vocab_size=vocab_size,
    block_size=block_size,
    embedding_dim=embedding_dim,
    head_size=head_size,
    num_heads=num_heads,
    num_blocks=num_blocks,
    dropout_rate=dropout_rate,
)
# move model to device
m = model.to(device)

# AdamW over all model parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [69]:
# add up the number of elements of every parameter tensor of the model
num_params = 0
for param in m.parameters():
    num_params += param.numel()
print(f"Total number of parameters in transformer network: {num_params/1e6} M")

Total number of parameters in transformer network: 10.788929 M


In [67]:
# training loop
# most recent loss estimates; refreshed every eval_interval iterations
train_loss = None
val_loss = None

pbar = tqdm(range(max_iters), desc="Epochs")
for epoch in pbar:
    # sample a batch of training data and evaluate the loss on it
    xb, yb = get_batch('train')
    _, loss = m(xb, yb)

    # reset parameter gradients, run the backward pass, take an optimizer step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # iteration 0 satisfies this check, so both values are set before the first description update
    if epoch % eval_interval == 0:
        losses = estimate_loss(m)
        train_loss, val_loss = (losses[split].item() for split in ('train', 'val'))

    # show the latest estimates in the progress bar
    progress_text = f"Epoch {epoch + 1}, Train Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}"
    pbar.set_description(progress_text)

Epoch 5000, Train Loss: 0.984, Val Loss: 1.500: 100%|██████████| 5000/5000 [27:27<00:00,  3.76it/s]   


In [68]:
# generate a single sequence using the model, starting from a context that holds only token 0
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
token_ids = m.generate(idx, max_new_tokens=2000)[0].tolist()
# decode integer tokens into characters, then join them into the final text
generated_seq = decode(token_ids)
generated_text = "".join(generated_seq)
print("\nGenerated sequence:\n", generated_text)


Generated sequence:
 
That have his act much tongue
And hurlined upon their own desires,
Are how my prince be sweeted, now strongly are
As musters, and have met our attempted body,
Been witness to thy husband's growthoods;
But of their deaths, province should be plaing on agon,
Put for cordserves would be consul to say.
I'll afide you tell your true hearts' torms!
I'll not swift.

HASTINGS:
I would advise the man to cheque,
I to do that were sense thou suspectures?
Like may be a sweet son in earth; devery wedding;
For away, I sing, a potson! Your which that's next
With immitted with musiciate be broke, a sea,
And thousan Montague: before so quick to his holy
Assoleme. A pluck on my look friend, and it
Is good a tent, or as a whoo's entertain, I
knowing both forget, which I show'd it.

LEONTES:

OLANTES:
Ha! why shall I will have herefore the hour
To wash here stands, whispering heaven stretch'd at the Lord Stanles
To go but with me: alas! I pluck the love
That man I am commanded. Prov