In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

# Fix torch's RNG so weight initialisation, dropout and sampling are reproducible across runs.
torch.manual_seed(1234)

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

#### In this project, we will attempt to build a character-level GPT language model that learns to add two non-negative integers: given a problem string "a+b=c", the model is trained to predict each next character from the characters before it in a sliding context window.

#### This is a simple next-character prediction task. We will attempt two versions of it: 1) the digits of "c" are predicted left to right; 2) the digits of "c" are predicted right to left (i.e. backward), which is how humans typically add numbers — starting from the least significant digit so that carries can be propagated.


In [171]:
# Token vocabulary for the addition task. Digits and operators are single characters; three special
# tokens mark the start of a problem ('<START>'), its end ('<END>'), and post-padding up to a fixed
# length ('<PAD>').
# NOTE(review): '10' is not a single character and is never produced when problems are split into
# characters; it is kept so that the token ids (and the outputs printed below) stay unchanged.
vocab = sorted(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '+', '=', '<START>', '<END>', '<PAD>'])
vocab_size = len(vocab)
print(f"Vocabulary: {vocab}")
print(f"vocab_size = {vocab_size}")

# tokenization: each token maps to its position in the sorted vocabulary, and back
ctoi = {token: index for index, token in enumerate(vocab)}
itoc = dict(enumerate(vocab))


def encode(tokens):
    """Map an iterable of tokens (a list like ['<START>', '1', '+'] or a plain string of
    single-character tokens) to a list of integer ids."""
    return [ctoi[token] for token in tokens]


def decode(ids):
    """Map an iterable of integer ids back to a list of token strings (not a joined string)."""
    return [itoc[index] for index in ids]


print(ctoi)

Vocabulary: ['+', '0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9', '<END>', '<PAD>', '<START>', '=']
vocab_size = 16
{'+': 0, '0': 1, '1': 2, '10': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '<END>': 12, '<PAD>': 13, '<START>': 14, '=': 15}


#### Now let's implement the data loader, which generates a batch of input-target pairs. We make sure that the context block size is large enough to hold the entire problem string.

In [326]:
import numpy as np
np.random.seed(1223)

# generates a batch of input/target pairs for random "a+b=c" problems
def generate_batch(block_size=26, batch_size=32, max_digits=5, backward=False):
    """Generate a batch of random addition problems as (input, target) token tensors.

    Each problem is the token sequence

        <START> a + b = c <END>

    where a and b are drawn uniformly from [0, 10**max_digits - 1] and zero-padded to max_digits
    digits, and c = a + b is zero-padded to max_digits + 1 digits (written least-significant digit
    first when ``backward`` is True). The sequence is post-padded with <PAD> to block_size + 1 tokens.
    The input is its first block_size tokens, and the target is the same sequence shifted left by one.

    Args:
        block_size: context length; must be at least 3*max_digits + 4 so that the target window
            reaches the <END> token.
        batch_size: number of problems in the batch.
        max_digits: number of digits of each operand.
        backward: if True, the digits of c are emitted in reverse order.

    Returns:
        (x, y): integer tensors of shape (batch_size, block_size) moved to ``device``. Target
        positions up to and including '=' are set to -1 so that the loss ignores the prompt.
    """
    # <START> + a + '+' + b + '=' + c (max_digits+1 digits) + <END>
    problem_len = 3 * max_digits + 5
    # targets cover problem[1:block_size+1]; they must include <END>, so block_size + 1 >= problem_len
    min_block_size = problem_len - 1
    assert block_size >= min_block_size, f"block_size needs to be at least {min_block_size}"

    # target positions covering a, '+', b and '=' (everything before the first answer digit)
    num_masked = 2 * max_digits + 2

    inputs = []
    targets = []
    for _ in range(batch_size):
        # randint's upper bound is exclusive, so 10**max_digits makes the largest operand reachable
        a, b = np.random.randint(0, 10**max_digits, 2)
        c = a + b

        prompt = list(str(a).zfill(max_digits) + "+" + str(b).zfill(max_digits) + "=")
        answer = list(str(c).zfill(max_digits + 1))
        if backward:
            # reverse the digits of c; slicing keeps a list so it can be concatenated below
            # (reversed() would return an iterator and break the concatenation)
            answer = answer[::-1]

        # enclose with special start and end tokens, then post-pad to block_size + 1 tokens
        problem = ['<START>'] + prompt + answer + ['<END>']
        problem += ['<PAD>'] * (block_size + 1 - len(problem))

        input_ids = torch.tensor(encode(problem[:block_size]))
        target_ids = torch.tensor(encode(problem[1:block_size + 1]))
        # mask out the prompt labels (up to and including '=') with the ignore label -1
        target_ids[:num_masked] = -1

        inputs.append(input_ids)
        targets.append(target_ids)

    # create input, target batch tensors
    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y


In [327]:
# Build a single example problem so its encoding can be inspected below.
max_digits = 5  # max number of digits for input integers 'a' and 'b'
batch_size = 1  # one problem is enough to look at
block_size = 26  # context length used throughout the notebook

x, y = generate_batch(block_size=block_size, batch_size=batch_size, max_digits=max_digits)

In [328]:
# Show the batch shape, the tokenized inputs and the masked targets (-1 = ignored by the loss).
for item in (x.shape, x, y):
    print(item)

torch.Size([1, 26])
tensor([[14,  5,  4,  5,  9,  7,  0, 10,  8,  7,  2,  8, 15,  2,  2, 10, 10, 11,
          2, 12, 13, 13, 13, 13, 13, 13]], device='cuda:0')
tensor([[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  2,  2, 10, 10, 11,  2,
         12, 13, 13, 13, 13, 13, 13, 13]], device='cuda:0')


#### Transformer Decoder model

In [329]:
import math

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input to queries/keys/values of total width ``total_head_size``, splits that width
    evenly across ``num_heads`` heads, applies causal scaled dot-product attention to every head in
    one batched call, concatenates the heads and projects back to ``embedding_dim``.

    Input and output shape: (B, T, embedding_dim).
    """

    def __init__(self, block_size, embedding_dim, total_head_size, num_heads, dropout_rate):
        super().__init__()
        assert total_head_size % num_heads == 0, "head_size needs to be integer multiple of num_heads"

        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.total_head_size = total_head_size
        self.head_size = total_head_size // num_heads  # per-head width h, with H = n*h
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Projections are created in a fixed order (key, query, value, then proj): under
        # torch.manual_seed this order determines the initial weights.
        def qkv_projection():
            return nn.Linear(embedding_dim, total_head_size, bias=False)

        self.key = qkv_projection()
        self.query = qkv_projection()
        self.value = qkv_projection()
        # NOTE(review): attn_dropout and the 'tril' buffer are not used by forward(); the fused
        # attention call applies its own dropout and causal mask. The buffer is kept because
        # removing it would presumably change the saved state_dict keys.
        self.attn_dropout = nn.Dropout(dropout_rate)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        # map the concatenated heads back to the residual width so they can be added to the input
        self.proj = nn.Linear(total_head_size, embedding_dim)
        self.output_dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Attend causally over x of shape (B, T, C=embedding_dim); returns (B, T, C)."""
        B, T, _ = x.shape
        n, h = self.num_heads, self.head_size

        def to_heads(t):
            # (B,T,H) -> (B,T,n,h) -> (B,n,T,h): heads become a batch dimension
            return t.view(B, T, n, h).transpose(1, 2)

        k = to_heads(self.key(x))
        q = to_heads(self.query(x))
        v = to_heads(self.value(x))

        # attention-weight dropout is active only in training mode
        p_drop = self.dropout_rate if self.training else 0
        heads = F.scaled_dot_product_attention(q, k, v, dropout_p=p_drop, is_causal=True)  # (B,n,T,h)

        # (B,n,T,h) -> (B,T,n,h); the transpose is non-contiguous, so copy before merging heads
        merged = heads.transpose(1, 2).contiguous().view(B, T, self.total_head_size)  # (B,T,H)
        return self.output_dropout(self.proj(merged))
    

# a simple position-wise MLP
class FeedForward(nn.Module):
    """Two-layer MLP applied over the last dimension: widen 4x, ReLU, project back, dropout.

    Input and output shape: (..., embedding_dim).
    """

    def __init__(self, embedding_dim, dropout_rate):
        super().__init__()
        hidden_dim = 4 * embedding_dim  # extra capacity: hidden layer is 4x the embedding width
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),  # back to embedding_dim so the residual add lines up
            nn.Dropout(dropout_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP; the output has the same shape as x."""
        return self.net(x)
    

# transformer block with residual connections and pre-layer-norm
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: x + MHA(LN1(x)), followed by + FeedForward(LN2(.)).

    ``head_size`` is forwarded to MultiHeadAttention as its total (all-heads) width.
    """

    def __init__(self, block_size, embedding_dim, head_size, num_heads, dropout_rate):
        super().__init__()
        # submodules are created in this order so seeded weight initialisation is reproducible
        self.sa = MultiHeadAttention(block_size, embedding_dim, head_size, num_heads, dropout_rate)
        self.ff = FeedForward(embedding_dim, dropout_rate)
        self.ln1 = nn.LayerNorm(embedding_dim)  # normalises the attention input
        self.ln2 = nn.LayerNorm(embedding_dim)  # normalises the feed-forward input

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each wrapped in a residual connection."""
        attn_out = self.sa(self.ln1(x))
        hidden = x + attn_out
        ff_out = self.ff(self.ln2(hidden))
        return hidden + ff_out
    

# language model with multiple transformer blocks
class TransformerLanguageModel(nn.Module):
    """Decoder-only (GPT-style) language model over the token vocabulary.

    Args:
        vocab_size: number of distinct tokens.
        block_size: maximum context length; also the number of learned position embeddings.
        embedding_dim: width C of the token/position embeddings and of the residual stream.
        head_size: total attention width passed to each block's MultiHeadAttention.
        num_heads: number of attention heads per block.
        num_blocks: number of stacked TransformerBlocks.
        dropout_rate: dropout probability used inside the blocks.
    """

    def __init__(self, vocab_size, block_size, embedding_dim, head_size, num_heads, num_blocks, dropout_rate=0.2):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_blocks = num_blocks

        # NOTE: keep the submodule creation order; under torch.manual_seed it fixes the initial weights
        # token embedding layer
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)  # shape: (vocab_size, C)
        # position embedding layer
        self.pos_embedding = nn.Embedding(block_size, embedding_dim)  # shape: (block_size, C)

        # stack of transformer blocks
        self.blocks = nn.Sequential(*[TransformerBlock(block_size, embedding_dim, head_size, num_heads, dropout_rate) for _ in range(num_blocks)])

        # final layer norm before the output layer
        self.ln_f = nn.LayerNorm(embedding_dim)

        # output layer: the residual stream is embedding_dim wide, so map embedding_dim -> vocab_size
        # (using head_size here only worked when head_size == embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)  # shape: (C, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the masked cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T) with T <= block_size.
            targets: optional integer labels of shape (B, T); positions labelled -1 are ignored.

        Returns:
            (logits, loss): logits are (B, T, vocab_size) when targets is None and flattened to
            (B*T, vocab_size) otherwise; loss is None when targets is None.

        Raises:
            ValueError: if T exceeds block_size (there is no position embedding for it).
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")

        token_embeds = self.token_embedding(idx)  # (B,T,C)
        # build positions on the input's device rather than relying on a notebook-global `device`
        pos_embeds = self.pos_embedding(torch.arange(T, device=idx.device))  # (T,C)
        x = token_embeds + pos_embeds  # (B,T,C), position embeddings broadcast over the batch
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            # flatten batch and time so cross_entropy sees (B*T, vocab_size) logits and (B*T,) labels;
            # the masked prompt labels (-1) are ignored
            logits = logits.view(B * T, -1)
            loss = F.cross_entropy(logits, targets.reshape(B * T), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively after the context ``idx`` of shape (B, T).

        The model runs in eval mode while sampling, and its previous train/eval mode is restored
        afterwards (also if an error is raised). Returns a tensor of shape (B, T + max_new_tokens).
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # positions only exist up to block_size, so keep the last block_size tokens
                idx_crop = idx[:, -self.block_size:]
                logits, _ = self(idx_crop)  # (B,T,vocab_size)
                # next-token distribution from the last position of each sequence
                probs = F.softmax(logits[:, -1, :], dim=-1)  # (B,vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)
        finally:
            self.train(was_training)
        return idx
    

# evaluating training and validation losses averaged over lots of batches
@torch.no_grad()  # disable gradient tracking
def estimate_loss(model, eval_iters, block_size, batch_size, max_digits=5, backward=False):
    """Average the model's masked loss over ``eval_iters`` freshly generated batches per split.

    NOTE(review): both 'train' and 'val' draw problems from the same random generator, so there is
    no held-out set; 'val' is a second, independent estimate of the same loss.

    Args:
        model: callable as ``model(X, Y) -> (logits, loss)``.
        eval_iters: number of batches averaged per split.
        block_size, batch_size, max_digits, backward: forwarded to generate_batch.

    Returns:
        dict with keys 'train' and 'val' mapping to scalar tensors of mean loss.

    The model is evaluated in eval mode and its previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = generate_batch(block_size=block_size, batch_size=batch_size,
                                      max_digits=max_digits, backward=backward)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

In [330]:
def test_gpt_adder(a, b, max_digits):
    # generate a single sequences using the model with start token 
    input = ['<START>'] + list(str(a).zfill(max_digits)+"+"+str(b).zfill(max_digits)+"=")
    input = [encode(input)]
    idx = torch.tensor(input,dtype=torch.long, device=device)
    #print(idx)
    generated_seq = m.generate(idx, max_new_tokens=30)[0].tolist()
    # Decode integer tokens into characters
    generated_seq = decode(generated_seq)
    # remove pad tokens
    generated_seq = list(filter(lambda c: c!='<PAD>', generated_seq))
    print("\nGenerated sequence:\n","".join(generated_seq)) 


#### Now let's train the model

In [331]:
# hyperparameters
batch_size = 32  # problems per optimisation step
block_size = 26  # context length in tokens; a 5-digit problem is 20 tokens including <START>/<END>
embedding_dim = 128  # width of token/position embeddings and of the residual stream
head_size = embedding_dim  # total attention width, split evenly across num_heads
num_heads = 4
num_blocks = 3  # number of stacked transformer blocks
dropout_rate = 0.2
max_iters = 20000  # optimisation steps, each on a freshly generated batch
learning_rate = 5e-4
eval_interval = 200  # steps between loss estimates
eval_iters = 100  # batches averaged per loss estimate

# initial weights are drawn from torch's RNG, which is seeded in the first cell
model = TransformerLanguageModel(vocab_size=vocab_size, block_size=block_size, embedding_dim=embedding_dim, head_size=head_size, num_heads=num_heads, num_blocks=num_blocks, dropout_rate=dropout_rate)
# move model to device (`m` is the name used by the training loop and by test_gpt_adder)
m = model.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

# total number of parameter elements across all layers
num_params = sum(p.numel() for p in m.parameters())
print(f"Total number of parameters in transformer network: {num_params/1e6} M")

Total number of parameters in transformer network: 0.60136 M


In [332]:
# training loop: each "epoch" below is one optimisation step on a freshly generated batch
train_loss = None
val_loss = None

pbar = tqdm(range(max_iters), desc="Epochs")
for epoch in pbar:
    # sample a batch of training data (same operand width as used for testing)
    xb, yb = generate_batch(block_size=block_size, batch_size=batch_size, max_digits=max_digits)
    # evaluate the loss
    _, loss = m(xb, yb)
    # reset parameter gradients
    optimizer.zero_grad(set_to_none=True)
    # backward pass
    loss.backward()
    # optimizer step
    optimizer.step()

    # estimate losses periodically, and once more after the final step so the reported
    # numbers describe the fully trained model rather than a checkpoint eval_interval steps earlier
    if epoch % eval_interval == 0 or epoch == max_iters - 1:
        losses = estimate_loss(m, eval_iters, block_size, batch_size)
        train_loss = losses['train'].item()
        val_loss = losses['val'].item()

    pbar.set_description(f"Epoch {epoch + 1}, Train Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}")

Epoch 20000, Train Loss: 0.000, Val Loss: 0.001: 100%|██████████| 20000/20000 [08:29<00:00, 39.23it/s]


In [339]:
test_gpt_adder(a=82292, b=1013, max_digits=max_digits)  # expected answer: 83305


Generated sequence:
 <START>82292+01013=083305<END>


In [335]:
test_gpt_adder(a=9912, b=543, max_digits=max_digits)  # expected answer: 10455


Generated sequence:
 <START>09912+00543=010455<END>
