GPT Calculator
---

In this notebook, I'll implement a basic calculator using a decoder-only transformer (GPT like) with binary operators (for instance `+`, `-`, `*`, `/`).

In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from typing import Union

import jax
import jax.numpy as jnp
import optax
import haiku as hk

import einops
from relax.models.gpt import GPT
from relax import Trainer, TrainingConfig

### Create our own naive character based tokenizer

In [2]:
from functools import lru_cache

class Tokenizer:
    def __init__(self):
        self.char_to_token_dict : dict = {}
        self.__eot : str = None
        self.__altered : bool = False
    
    def add_char(self, char: str) -> None:
        self.char_to_token_dict[char] = self.vocab_size
        self.__altered = True
    
    def encode_char(self, char: str) -> int:
        try:
            return self.char_to_token_dict[char]
        except KeyError as error:
            raise ValueError(f"Character `{char}` not contained in the vocabulary.") from error
    
    def decode_indice(self, indice: int) -> str:
        try:
            return self.token_to_char_dict[indice]
        except KeyError as error:
            raise ValueError(f"Indice `{indice}` not contained in the vocabulary.") from error
    
    def encode(self, string: str) -> list:
        return [self.encode_char(c) for c in string]
    
    def decode(self, indices: Union[list, jnp.ndarray], ignore : list = None, stop_after_eot : bool = False) -> list:
        if ignore is None:
            ignore : list = []
        if hasattr(indices, 'tolist'):
            indices = indices.tolist()
        raw_decoded: str = ''.join([self.decode_indice(indice) for indice in indices if indice not in ignore])
        return raw_decoded.split(self.eot)[0] + self.eot if stop_after_eot and self.eot in raw_decoded else raw_decoded
    
    @property
    def vocab_size(self):
        return len(self.char_to_token_dict)
    
    @property
    def token_to_char_dict(self):
        @lru_cache(maxsize=1)
        def lru_inner():
            return {token:char for char, token in self.char_to_token_dict.items()}
        if self.__altered:
            lru_inner.cache_clear()
        return lru_inner()
    
    @property
    def eot(self) -> str:
        return self.__eot
    
    @property
    def eot_token(self) -> int:
        return self.encode_char(self.eot)
    
    @eot.setter
    def eot(self, value: str):
        if not value in self.char_to_token_dict:
            self.add_char(value)
        self.__eot = value
    
    def __str__(self):
        return f"Tokenizer <vocab_size: {self.vocab_size}, eot: {self.eot}>"
    
    def __repr__(self):
        return str(self)
    
        
tokenizer = Tokenizer()
string : str = 'a' * 10
tokenizer.add_char('a')
assert string == tokenizer.decode(tokenizer.encode(string))
print(tokenizer)
del tokenizer, string

Tokenizer <vocab_size: 1, eot: None>


In [3]:
@hk.transform
def model(x):
    return GPT(
        vocab_size = 20,
        block_size = 32,
        n_blocks = 2,
        n_embed = 32,
        n_head = 2,
        dropout_rate = 0.2,
    )(x)

def softmax_cross_entropy(logits, labels, ignore_index=None):
    # vmap the actual cross entropy function per batch input - go brrrrr
    @jax.vmap
    def batch_ce(logits, oh_labels):
        return -jnp.sum(jax.nn.log_softmax(logits) * oh_labels)
        
    # Get the mask in order to filter out the non desired indexes (usually pad tokens for text)
    mask = labels != ignore_index if ignore_index != None else jnp.ones_like(labels, dtype=bool)
    # Turn the labels into one hot format
    one_hot = hk.one_hot(labels, logits.shape[-1])
    # Get the cross entropy per batch entry
    bloss = batch_ce(logits, one_hot)
    # And aggregate all desired entries
    return jnp.mean(bloss, where=mask)

@jax.jit
def loss_fn(params, rng, data) -> jnp.ndarray:
    inputs, labels = data
    logits = model.apply(params, rng, inputs)
    # Flatten the batch dim and sequence dim for both the logits and labels
    logits = einops.rearrange(logits, '... d -> (...) d')
    labels = einops.rearrange(labels, '... -> (...)')
    # Get the batch + sequences losses
    bloss = softmax_cross_entropy(logits, labels, ignore_index=0)
    return jnp.mean(bloss)

In [4]:
def get_dataset(tokenizer, main_key, cardinality : int, minimum : int, maximum : int, operations : list[str], sequence_length : int, pad_token : int):
    def get_sample(key):
        ops = {
            '+' : '__add__',
            '*' : '__mul__',
            '-' : '__sub__',
            '/' : '__div__',
            }
        
        # Get the keys for choosing the 2 integers and the operation
        k0, k1 = jax.random.split(key, 2)
        a, b = jax.random.randint(k0, (2,), minimum, maximum)
        op = operations[jax.random.choice(k1, jnp.arange(len(operations)))]
        
        x = f"{a.item()} {op} {b.item()} = "
        result = f"{(getattr(a, ops[op])(b)).item()}"
        
        # Encode both the input and the output that is `input + result + eot`
        input_tokens = tokenizer.encode(x)
        output_tokens = input_tokens + tokenizer.encode(result) + [tokenizer.eot_token]
        
        # Then pad the sequences
        input_tokens += [pad_token] * (sequence_length - len(input_tokens))
        output_tokens += [pad_token] * (sequence_length - len(output_tokens))
        
        return input_tokens, output_tokens
    
    keys = jax.random.split(main_key, cardinality)
    return jnp.array([get_sample(key) for key in keys])

In [5]:
tokenizer = Tokenizer()
chars = ['p'] # Naive padding character, can also use only -1 when creating the dataset
chars += list(range(10)) # The ten integers 
chars += list('+-*/=. ') # The 4 operations, equal sign, the dot for floats and space for nice formatting
for char in chars:
    tokenizer.add_char(str(char))
tokenizer.eot = ';' # Then, use semicolons for the end of text token

print(tokenizer)
tokenizer.char_to_token_dict

Tokenizer <vocab_size: 19, eot: ;>


{'p': 0,
 '0': 1,
 '1': 2,
 '2': 3,
 '3': 4,
 '4': 5,
 '5': 6,
 '6': 7,
 '7': 8,
 '8': 9,
 '9': 10,
 '+': 11,
 '-': 12,
 '*': 13,
 '/': 14,
 '=': 15,
 '.': 16,
 ' ': 17,
 ';': 18}

In [19]:
entries = get_dataset(
    tokenizer,
    jax.random.PRNGKey(42),
    10240, # Get 1024 observations
    1,  
    100,
    list('+-'),
    sequence_length = 32,
    pad_token = tokenizer.encode_char('p')
)

# Get batched dataset
ds = einops.rearrange(entries, '(b bs) xy s -> b xy bs s', bs=64)
ds.shape

(160, 2, 64, 32)

In [20]:
print("Example of calculations : ")
for x, y in einops.rearrange(ds[0, :, :10, :], 'xy b s -> b xy s'):
    print(f"\t{tokenizer.decode(x, ignore=[0])}-> {tokenizer.decode(y, ignore=[0])}")

Example of calculations : 
	5 + 25 = -> 5 + 25 = 30;
	4 - 10 = -> 4 - 10 = -6;
	46 + 94 = -> 46 + 94 = 140;
	89 + 70 = -> 89 + 70 = 159;
	42 + 38 = -> 42 + 38 = 80;
	42 + 73 = -> 42 + 73 = 115;
	95 - 34 = -> 95 - 34 = 61;
	81 - 85 = -> 81 - 85 = -4;
	18 - 96 = -> 18 - 96 = -78;
	9 + 45 = -> 9 + 45 = 54;


In [21]:
config = TrainingConfig(
            epochs=1000,
            )

optimizer = optax.adam(0.001)

trainer = Trainer(model, optimizer, config)

rng = jax.random.PRNGKey(42)
fake_input = jnp.zeros((1, 32), dtype=int)
init_state = trainer.init(rng, fake_input)

In [22]:
trained_state = trainer.train(init_state, loss_fn, ds, jit_update_step=True)

Training:   0%|          | 0/1000 [00:00<?, ?epoch/s]

In [23]:
x, y = ds[0]
key = jax.random.PRNGKey(0)
logits = model.apply(trained_state.params, key, x)
y_hat = jnp.argmax(jax.nn.log_softmax(logits, -1), -1)

print("Example of outputs:")
print("x -> y_hat (y)")
for (x_, y_), y_hat_ in zip(einops.rearrange(ds[0, :, :10, :], 'xy b s -> b xy s'), y_hat):
    print(f"\t{tokenizer.decode(x_, ignore=[0])}-> {tokenizer.decode(y_hat_, ignore=[0], stop_after_eot=True)} ({tokenizer.decode(y_, ignore=[0])})")

Example of outputs:
x -> y_hat (y)
	5 + 25 = -> 5 + 25 = 30; (5 + 25 = 30;)
	4 - 10 = -> 4 - 10 = -1; (4 - 10 = -6;)
	46 + 94 = -> 46 + 94 = 140; (46 + 94 = 140;)
	89 + 70 = -> 89 + 70 = 159; (89 + 70 = 159;)
	42 + 38 = -> 42 + 38 = 80; (42 + 38 = 80;)
	42 + 73 = -> 42 + 73 = 115; (42 + 73 = 115;)
	95 - 34 = -> 95 - 34 = 61; (95 - 34 = 61;)
	81 - 85 = -> 81 - 85 = -4; (81 - 85 = -4;)
	18 - 96 = -> 18 - 96 = -78; (18 - 96 = -78;)
	9 + 45 = -> 9 + 45 = 64; (9 + 45 = 54;)


In [24]:
x, y = einops.rearrange(ds, 'b xy bs s -> xy (b bs) s')
loss = loss_fn(trained_state.params, jax.random.PRNGKey(0), (x, y))
logits = model.apply(trained_state.params, jax.random.PRNGKey(0), x)
y_hat = jnp.argmax(jax.nn.log_softmax(logits, -1), -1)
y_str : list[str] = [tokenizer.decode(indices, ignore=[0], stop_after_eot=True) for indices in y]
y_hat_str : list[str] = [tokenizer.decode(indices, ignore=[0], stop_after_eot=True) for indices in y_hat]
print(f"Crossentropy loss : {loss}")
print(f"Accuracy on train set : {sum([a == b for a, b in zip(y_str, y_hat_str)]) / len(y_str)}")

Crossentropy loss : 0.07728936523199081
Accuracy on train set : 0.69208984375


#### Try it on unseen data

In [25]:
val_ds = get_dataset(
    tokenizer,
    jax.random.PRNGKey(43),
    128, # Get 128 observations
    1,  
    100,
    list('+-'),
    sequence_length = 32,
    pad_token = tokenizer.encode_char('p')
)
x, y = val_ds = einops.rearrange(val_ds, 'a b ... -> b a ...')

In [27]:
loss = loss_fn(trained_state.params, jax.random.PRNGKey(0), (x, y))
logits = model.apply(trained_state.params, jax.random.PRNGKey(0), x)
y_hat = jnp.argmax(jax.nn.log_softmax(logits, -1), -1)
y_str : list[str] = [tokenizer.decode(indices, ignore=[0], stop_after_eot=True) for indices in y]
y_hat_str : list[str] = [tokenizer.decode(indices, ignore=[0], stop_after_eot=True) for indices in y_hat]
print(f"Crossentropy loss : {loss}")
print(f"Accuracy on val set : {sum([a == b for a, b in zip(y_str, y_hat_str)]) / len(y_str)}")

Crossentropy loss : 0.09523070603609085
Accuracy on val set : 0.6484375
