In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

# Seed torch's global RNG so parameter initialisation and sampling are repeatable.
torch.manual_seed(1234)

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

#### In this project, we will build a character-level GPT language model that learns to add two non-negative integers: given a problem string of the form "a+b=c", the model is trained to predict each next character from the sliding context window that precedes it.

#### This is a simple next-character prediction task. We will try two versions of it: 1) the digits of "c" are predicted left to right, and 2) the digits of "c" are predicted right to left (i.e. backward, least-significant digit first), which is how humans typically carry out addition by hand.


In [67]:
# First, set up the token vocabulary for this problem.
# There are two special tokens:
#   '<*>'   marks the beginning and the end of a problem string
#   '<PAD>' post-pads every problem string to a fixed length
# The delimiter must be spelled '<*>' here because generate_batch() wraps each
# problem with '<*>'; any other spelling makes encode() raise a KeyError.
# NOTE(review): '10' is never produced, since problems are split into single
# characters. Consider removing it (doing so shifts the ids of later tokens).
vocab = sorted(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '+', '=', '<*>', '<PAD>'])
vocab_size = len(vocab)
print(f"Vocabulary: {vocab}")
print(f"vocab_size = {vocab_size}")

# tokenization: a token's id is its index in the sorted vocabulary
ctoi = {tok: ix for ix, tok in enumerate(vocab)}
itoc = {ix: tok for ix, tok in enumerate(vocab)}
encode = lambda s: [ctoi[c] for c in s]  # sequence of tokens -> list of integer ids
decode = lambda s: [itoc[ix] for ix in s]  # sequence of integer ids -> list of tokens (not a joined string)
print(ctoi)

Vocabulary: ['+', '0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9', '<*>', '<PAD>', '=']
vocab_size = 15
{'+': 0, '0': 1, '1': 2, '10': 3, '2': 4, '3': 5, '4': 6, '5': 7, '6': 8, '7': 9, '8': 10, '9': 11, '<*>': 12, '<PAD>': 13, '=': 14}


#### Now let's implement the data loader, which generates a batch of input-target pairs. We make sure that the context block size is large enough to hold an entire problem string, so that a single context window covers the whole problem.

In [142]:
import random
random.seed(1223)

# generates a batch of input/target pairs, one random problem string "a+b=c" per row
def generate_batch(max_digits, block_size, batch_size, backward=False):
    """Generate a batch of (input, target) token sequences for random additions.

    Each sample is a problem "a+b=c" with a and b drawn uniformly from
    [0, 10**max_digits - 1]. It is tokenized character by character as
    ['<*>', *"a+b=", *c, '<*>'] and post-padded with '<PAD>' to
    block_size + 1 tokens. The input is the first block_size tokens; the
    target is the same sequence shifted left by one, so target[t] is the
    token that follows input[t].

    Args:
        max_digits: maximum number of digits of the operands a and b.
        block_size: context length. It must be large enough that the longest
            possible problem fits in block_size + 1 tokens (input plus the
            shifted target), otherwise the tail would be silently truncated.
        batch_size: number of problems in the batch.
        backward: if True, the digits of the answer c are reversed
            (least-significant digit first).

    Returns:
        (x, y): two integer tensors of shape (batch_size, block_size), moved
        to the module-level `device`.

    Raises:
        AssertionError: if block_size is too small for max_digits.
    """
    # Longest possible problem: both operands at their maximum value, whose sum
    # may need one extra digit for the final carry, plus two '<*>' delimiters.
    largest = 10**max_digits - 1
    max_problem_size = len(f"{largest}+{largest}={2 * largest}") + 2
    # input spans tokens [0, block_size) and target spans [1, block_size],
    # so together they cover block_size + 1 tokens.
    min_block_size = max_problem_size - 1
    assert block_size >= min_block_size, f"block_size needs to be at least {min_block_size}"

    inputs = []
    targets = []

    for _ in range(batch_size):

        # randomly generate two integers
        a = random.randint(0, largest)
        b = random.randint(0, largest)
        c = a + b

        prompt = list(f"{a}+{b}=")
        answer = list(f"{c}")

        if backward:
            # reverse the digits of "c"; slicing keeps a list so it can be
            # concatenated below (reversed() would return an iterator)
            answer = answer[::-1]

        # enclose with the special delimiter token
        problem = ['<*>'] + prompt + answer + ['<*>']

        # post-pad the problem string to make it (block_size+1) long
        problem = problem + ['<PAD>'] * (block_size + 1 - len(problem))

        # tokenized input and target sequences (target is input shifted by one)
        inputs.append(torch.tensor(encode(problem[:block_size])))
        targets.append(torch.tensor(encode(problem[1:block_size + 1])))

    # create input,target batch tensors
    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)

    return x, y


In [143]:
# Small demo batch to sanity-check the data loader.
batch_size = 4
max_digits = 5   # max number of digits for input integers 'a' and 'b'
block_size = 26  # context length; generate_batch asserts it fits the longest problem

x, y = generate_batch(max_digits=max_digits, block_size=block_size, batch_size=batch_size)

In [146]:
# Inspect the sampled batch: first its shape, then the raw token ids.
print(x.shape, x, sep="\n")

torch.Size([4, 26])
tensor([[12, 11,  9,  7, 11,  2,  0, 10,  7,  4,  1, 10, 14,  2, 10,  4,  9, 11,
         11, 12, 13, 13, 13, 13, 13, 13],
        [12,  7,  9,  6,  7,  9,  0,  8,  9,  4,  6,  7, 14,  2,  4,  6,  9,  1,
          4, 12, 13, 13, 13, 13, 13, 13],
        [12,  9,  7,  2, 11, 11,  0,  8,  5,  6,  2,  9, 14,  2,  5, 10,  8,  2,
          8, 12, 13, 13, 13, 13, 13, 13],
        [12,  4,  7,  9, 11,  8,  0,  6,  2,  5,  1,  8, 14,  8,  9,  2,  1,  4,
         12, 13, 13, 13, 13, 13, 13, 13]], device='cuda:0')
