# Teach an LLM to do additions
Student: HO Quang Phuoc, Master MVA, ENS Paris-Saclay

The goal of this project is to teach an LLM to do additions, playing only with two parts:
* the tokenizer
* the positional embedding

Both the model and the dataset are fixed.

We are allowed to tune the hyperparameters, but this is not the main goal. Depending on the quality of our tokenizer and positional embedding, we may change the number of bits. The initial value of 3 is very small.

Here I change the number of bits to 9, meaning we are teaching the model to do additions with 9 digit numbers.

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

In [99]:
number_bits = 9

dataset_size = 64_000
train_proportion = 0.9

log_interval = 200
batch_size = 64
epochs = 4
learning_rate = 8e-4

## Step 1: Construct a tokenizer

In [100]:
pad_token = "[PAD]"
eos_token = "[EOS]"

### Baseline: character-level tokenizer

In [101]:
class character_level_tokenizer:
    """
    character-level
    """

    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v: k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k: v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"

    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

# Implement your tokenizer here!

You can do anything (as long as you do not compute the addition!).
Some ideas:
* reversing numbers left to right
* arranging by groups (of, 2, 3,...)
* aligning numbers

**Here is my idea for tokenizer**:

1. **Reverse the Numbers:**
   - Reverse the digits of the numbers.
   - Example: 
     - `num1 = 123456789` → `987654321`
     - `num2 = 145` → `541`

2. **Align and Pad the Numbers:**
   - Pad the numbers with zeros to match a fixed length (`number_bits`, 9 in this case).
   - Example: 
     - `num1 = 987654321` (9 digits, no change).
     - `num2 = 541` → `000000541` (padded to 9 digits).

3. **Group Digits by Columns:**
   - Align the numbers in columns and group digits vertically.
   - Example:
     ```
     9 8 7 6 5 4 3 2 1
     0 0 0 0 0 0 5 4 1
     ```
   - Resulting pairs:
     - (5, 9), (4, 8), (1, 7), (0, 6), (0, 5), (0, 4), (0, 3), (0, 2), (0, 1)

4. **Treat Pairs as Sets (here I do this by generating pairs in form (min, max)):**
   - Treat each pair as a set (order doesn't matter).
   - Example: (5, 9) is the same as (9, 5).

5. **Create Token Mappings:**
   - Assign each unique pair an ID.
   - Pairs like (5, 9) and (9, 5) will share the same ID.

In [None]:
class CharacterAlignmentTokenizer:
    """
    Tokenizer using character alignment representation.
    """

    def __init__(self, max_digits=9):
        # Define special tokens and max_digits
        self.pad = "[PAD]"
        self.eos = "[EOS]"
        self.max_digits = max_digits

        # Initialize vocabulary with digits, "[PAD]" and "[EOS]"
        self.vocab = [str(x) for x in range(10)] + ["+", "=", self.pad, self.eos]

        # Init token-to-id and id-to-token mappings
        self.token_to_id = {v: i for i, v in enumerate(self.vocab)}
        self.id_to_token = {i: v for v, i in self.token_to_id.items()}

        # Define pair-to-id and id-to-pair mappings
        self.pair_to_id = {}
        self.id_to_pair = {}
        current_id = len(self.vocab)  # Start after single-character tokens

        for i in range(10):
            for j in range(i, 10):  # Ensures (min, max) order
                pair_str = f"({i},{j})"
                self.pair_to_id[pair_str] = current_id
                self.id_to_pair[current_id] = pair_str
                current_id += 1

        # Merge pair mappings into token dictionaries
        self.token_to_id.update(self.pair_to_id)
        self.id_to_token.update(self.id_to_pair)

        # Total number of tokens
        self.ntokens = len(self.token_to_id)

        # Regex to remove invalid characters
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"

    def clean(self, text):
        """
        Removes all characters not in the vocabulary.
        """
        return re.sub(self.pattern, "", text)

    def align_numbers(self, num1, num2):
        """
        Aligns two numbers for column-wise processing.
        After alignment, the pairs of digits are normalized to (min, max)
        so that (3,6) and (6,3) are treated the same.
        """
        num1, num2 = str(num1)[::-1], str(num2)[::-1]  # Reverse for right alignment
        num1 = num1.ljust(self.max_digits, "0")  # Zero-padding
        num2 = num2.ljust(self.max_digits, "0")

        # Create normalized (min, max) digit pairs
        return [
            f"({min(int(d1), int(d2))},{max(int(d1), int(d2))})"
            for d1, d2 in zip(num1, num2)
        ]

    def pre_tokenization(self, text):
        """
        Tokenizes numbers and symbols in a column-wise fashion.
        If '+' is found, it aligns the two numbers and converts them to token IDs.
        """
        if "+" in text:
            parts = text.split("+")
            if len(parts) != 2:
                raise ValueError("Invalid input format. Expected 'num1 + num2 ='")

            num1, num2 = parts[0].strip(), parts[1].split("=")[0].strip()
            aligned_numbers = self.align_numbers(num1, num2)

            tokens = [
                self.token_to_id.get(pair, self.token_to_id[self.pad])
                for pair in aligned_numbers
            ]
            tokens.append(self.token_to_id["+"])
            tokens.append(self.token_to_id["="])
            return tokens

        # Handle single number sequences
        # If c is out of vocab, return "[PAD]"
        return [self.token_to_id.get(c, self.token_to_id[self.pad]) for c in text]

    def encode(self, text):
        """
        Encodes the input text into a list of token IDs.
        """
        return self.pre_tokenization(self.clean(text))

    def decode(self, token_list):
        """
        Decodes a list of token IDs back into a string.
        """
        return "".join(self.id_to_token.get(x, self.pad) for x in token_list)

Here I try with some examples

In [None]:
tokenizer = CharacterAlignmentTokenizer(number_bits)
ntokens = tokenizer.ntokens
print(f"number of tokens in my tokenizer = {ntokens}")

prompt = "123456789 + 145 ="
input_encoded = tokenizer.encode(prompt)
output_decoded = tokenizer.decode(input_encoded)
print(input_encoded)
print(output_decoded)

number of tokens in my tokenizer = 69
[58, 52, 30, 20, 19, 18, 17, 16, 15, 10, 11]
(5,9)(4,8)(1,7)(0,6)(0,5)(0,4)(0,3)(0,2)(0,1)+=


In [162]:
prompt = "1265 + 54323 ="
input_encoded = tokenizer.encode(prompt)
output_decoded = tokenizer.decode(input_encoded)
print(input_encoded)
print(output_decoded)

[43, 37, 34, 27, 19, 14, 14, 14, 14, 10, 11]
(3,5)(2,6)(2,3)(1,4)(0,5)(0,0)(0,0)(0,0)(0,0)+=


## Step 2: Create a dataset for arithmetic operations

In [164]:
def sample_datapoint(number_bits=number_bits):
    """
    returns a string containing two random numbers on `number_bits` many bits and their sum.
    """
    a_list = [random.randint(0, 9) for _ in range(number_bits)]
    b_list = [random.randint(0, 9) for _ in range(number_bits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    sum_int = a_int + b_int
    return (str(a_int) + "+" + str(b_int) + "=", str(sum_int))


sample_datapoint(number_bits)

('294119799+791156163=', '1085275962')

In [165]:
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(number_bits))
data[:10]

[('65075665+358240227=', '423315892'),
 ('100825935+600741977=', '701567912'),
 ('763384340+390144737=', '1153529077'),
 ('207062109+600163507=', '807225616'),
 ('967088871+53531711=', '1020620582'),
 ('49086839+553676836=', '602763675'),
 ('609142974+483183143=', '1092326117'),
 ('917377039+557345635=', '1474722674'),
 ('516485983+856480403=', '1372966386'),
 ('72464960+6116180=', '78581140')]

In [166]:
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size) :]

len(data_train), len(data_test)

(57600, 6400)

## Step 3: Construct a model

### Basline: the classical Positional Embedding

In [None]:
class PositionalEmbedding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEmbedder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEmbedder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """
        d_model: the embedding dimension
        """
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # positional encoding matrix
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(
            1
        )  # tensor of shape (max_len, 1) representing positions 0, 1, ..., max_len-1

        # torch.arrange(0, d_model, 2) = [0, 2, 4, ..., d_model-2]
        # -math.log(10000.0) / d_model = scaling factor
        # torch.exp = apply the exponential func
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(
            position * div_term
        )  # apply sine function to even indies
        pe[:, 1::2] = torch.cos(
            position * div_term
        )  # apply cosine function to odd indices

        pe = pe.unsqueeze(0).transpose(
            0, 1
        )  # add batch dimension, pe shape: (1, max_len, d_model)
        self.register_buffer(
            "pe", pe
        )  # pe as buffer, it will be save but won't be updated

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)

# Implement your positional embedding here!

You can do anything. Some ideas:
* RoPE
* (randomised) FIRE
* Abacus

**!!! IMPORTANT !!!** This model of Transformers is "input first", meaning that an input is a tensor with shape
(length_prompts, batch_size)

**Here I implement positional embedding using Rotary Position Embedding (RoPE) Algorithm**
1. Precompute Rotation Frequencies

$$
\theta_i = 10000^{-\frac{2i}{d}}
$$

where:
- $d$ is the embedding dimension (must be even).
- $i$ is the index over the half-dimensional space $i \in \left[0, \frac{d}{2} \right]$.

From this, we compute the rotation terms:

$$
\cos(\theta p), \quad \sin(\theta p)
$$

where $p$ is the position index.


2. Each embedding vector $x$ of shape $(\text{batch}, \text{seq\_len}, d)$ is split into two equal halves:

$$
(x_1, x_2) = x[:, :, :\frac{d}{2}], \quad x[:, :, \frac{d}{2}:]
$$

where:
- $x_1$ and $x_2$ each have shape $(\text{batch}, \text{seq\_len}, d/2)$.
- The first half $x_1$ represents even-indexed dimensions.
- The second half $x_2$ represents odd-indexed dimensions.


3. We now rotate the vector components using the precomputed rotation frequencies:

$$
x_1^{\text{new}} = x_1 \cos(\theta p) - x_2 \sin(\theta p)
$$

$$
x_2^{\text{new}} = x_1 \sin(\theta p) + x_2 \cos(\theta p)
$$


4. Merging the Rotated Components

$$
x^{\text{new}} = \left[ x_1^{\text{new}}, x_2^{\text{new}} \right]
$$

which restores the original embedding shape $(\text{batch}, \text{seq\_len}, d)$.

In [None]:
class RoPE(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(RoPE, self).__init__()
        assert d_model % 2 == 0, "Embedding size (d_model) must be even for RoPE"

        self.d_model = d_model
        self.d_half = d_model // 2  # RoPE applies to half the embedding size

        # Compute the frequency terms for rotation
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            -math.log(10000.0) * torch.arange(0, self.d_half, 1).float() / self.d_half
        )

        # Precompute cosine and sine values
        freqs = position * div_term  # (max_len, d_half)
        self.register_buffer("cos", freqs.cos())  # (max_len, d_half)
        self.register_buffer("sin", freqs.sin())  # (max_len, d_half)

    def _apply_rope(self, x, cos, sin):
        """
        Applies rotary position embeddings to the input tensor x.
        Args:
            x: (batch, seq_len, d_model)
            cos: (seq_len, d_half)
            sin: (seq_len, d_half)
        Returns:
            RoPE-applied tensor of the same shape.
        """
        batch, seq_len, d_model = x.shape
        assert (
            d_model == self.d_model
        ), "Input embedding size does not match model embedding size"

        # Split into two halves
        x1, x2 = x[..., : self.d_half], x[..., self.d_half :]

        # Apply rotary transformation
        x1_new = x1 * cos - x2 * sin
        x2_new = x1 * sin + x2 * cos

        # Concatenate back along the last dimension
        return torch.cat([x1_new, x2_new], dim=-1)

    def forward(self, x):
        """
        Args:
            x: (seq_len, batch, d_model)
        Returns:
            RoPE-applied tensor of the same shape.
        """
        seq_len, batch, d_model = x.shape
        assert (
            d_model == self.d_model
        ), "Input embedding size does not match RoPE model embedding size"

        # Transpose to (batch, seq_len, d_model) for processing
        x = x.transpose(0, 1)

        # Select cos and sin based on seq_len
        cos, sin = self.cos[:seq_len], self.sin[:seq_len]  # (seq_len, d_half)

        # Apply RoPE and transpose back
        x_rotated = self._apply_rope(x, cos.unsqueeze(0), sin.unsqueeze(0))
        return x_rotated.transpose(0, 1)  # (seq_len, batch, d_model)

In [None]:
class TransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        """
        ntoken = number of unique tokens = vocab size
        ninp = embedding size = dim of input embedding
        nhead = number of attention heads
        nhid = hidden layer size in the feedforward network
        nlayers = number of transformer encoder layers
        dropout = dropout rate for regularization
        """
        super(TransformerModel, self).__init__(
            d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers
        )
        # embedding layer to map token indices to vectors of size ninp
        self.input_emb = nn.Embedding(ntoken, ninp)

        # apply positional embedding
        # self.pos_encoder = PositionalEmbedding(ninp, dropout)
        self.pos_encoder = RoPE(d_model=ninp)

        #  map transformer outputs (size ninp) back to vocabulary tokens
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz, sz)))

    def forward(self, src):
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask

        # convert token IDs to embedding + scaling
        src = self.input_emb(src) * math.sqrt(self.ninp)

        # apply positional embedding
        src = self.pos_encoder(src)

        # feeds the encoded input
        output_enc = self.encoder(src, mask=self.src_mask)

        # pass encoded ouput through decoder layer
        output_dec = self.decoder(output_enc)

        """
        Return:
        - log_softmax(output_dec) = prob distribution over tokens
        - output_enc = encoder representation
        """

        return F.log_softmax(output_dec, dim=-1), output_enc

In [172]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


Please do not change these parameters!

In [173]:
model = TransformerModel(ntoken=ntokens, ninp=128, nhead=16, nhid=64, nlayers=8)
model.to(device)



TransformerModel(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=69, bias=True)
  (input_emb): Embedding(69, 128)
  (pos_encoder): RoPE()
)

In [None]:
def generate(model, prompts, new_tokens=number_bits + 2):
    input_tensor = prompts  # (length_prompts, batch_size)
    input_tensor = input_tensor.to(device)
    for _ in range(new_tokens):
        output, _ = model(input_tensor)  # (length_prompts, batch_size, ntokens)
        last_output = output[-1, :, :]  # (batch_size, ntokens)
        token = torch.argmax(last_output, -1).view((1, -1))  # (1, batch_size)
        input_tensor = torch.cat((input_tensor, token), 0)
    return input_tensor

In [175]:
model.eval()

prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1, 1))
output = generate(model, prompt_tensor).view((1, -1))
output, tokenizer.decode(output.tolist()[0])

(tensor([[34, 14, 14, 14, 14, 14, 14, 14, 14, 10, 11, 52, 52, 52, 52, 52, 52, 52,
           2,  2,  2,  2]], device='cuda:0'),
 '(2,3)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)+=(4,8)(4,8)(4,8)(4,8)(4,8)(4,8)(4,8)2222')

In [176]:
output = generate(model, prompt_tensor).view((1, -1))
output, tokenizer.decode(output.tolist()[0])

(tensor([[34, 14, 14, 14, 14, 14, 14, 14, 14, 10, 11, 52, 52, 52, 52, 52, 52, 52,
           2,  2,  2,  2]], device='cuda:0'),
 '(2,3)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)+=(4,8)(4,8)(4,8)(4,8)(4,8)(4,8)(4,8)2222')

In [177]:
def pad(token_list, type_list="prompts"):
    max_length = max([len(x) for x in token_list])
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append([tokenizer.token_to_id[pad_token]] * (max_length - len(x)) + x)
        if type_list == "answers":
            out.append(
                x
                + [tokenizer.token_to_id[eos_token]]
                + [tokenizer.token_to_id[pad_token]] * (max_length - len(x))
            )
    return out, max_length

In [None]:
prompts = [
    tokenizer.encode("1+1="),
    tokenizer.encode("21+35="),
    tokenizer.encode("549865423+932543354="),
]
answers = [
    tokenizer.encode("2"),
    tokenizer.encode("56"),
    tokenizer.encode("1482408777"),
]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [
    tokenizer.decode(p) for p in padded_answers
]

(['(1,1)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)+=',
  '(1,5)(2,3)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)(0,0)+=',
  '(3,4)(2,5)(3,4)(3,5)(4,6)(5,8)(2,9)(3,4)(5,9)+='],
 ['2[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]',
  '56[EOS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]',
  '1482408777[EOS]'])

In [179]:
def get_batch(split, i):
    data = data_train if split == "train" else data_test
    prompts = [tokenizer.encode(data[i][0]) for i in range(i, i + batch_size)]
    padded_prompts, length_prompts = pad(prompts, "prompts")
    answers = [tokenizer.encode(data[i][1]) for i in range(i, i + batch_size)]
    padded_answers, length_answers = pad(answers, "answers")
    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, length_prompts, length_answers

In [180]:
X, Y, length_prompts, length_answers = get_batch("train", 243)
X.shape, Y.shape, length_prompts, length_answers

(torch.Size([11, 64]), torch.Size([11, 64]), 11, 10)

## Step 4: Evaluate

In [181]:
def evaluate():
    # Turn on evaluation mode disables dropout.
    model.eval()
    correct = 0.0
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, length_prompts, length_answers = get_batch(
                "test", i
            )
            prompts = prompts.to(device)  # (length_prompts, batch_size)
            target_answers = target_answers.to(
                device
            )  # (length_answers + 1, batch_size)
            output = generate(
                model, prompts, length_answers + 1
            )  # (length_prompts + length_answers + 1, batch_size)
            answers_tokens = output[
                length_prompts:, :
            ]  # (length_answers + 1, batch_size), contains tokens
            equality_test = (
                answers_tokens == target_answers
            )  # (length_answers + 1, batch_size), contains boolean values
            correct += torch.all(equality_test, axis=0).float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [182]:
evaluate()

0.0

## Step 4: Train the model

In [None]:
def train_epoch():
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    total_loss = 0.0
    start_time = time.time()
    for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
        prompts, target_answers, length_prompts, length_answers = get_batch("train", i)
        # print(f"prompts: {length_prompts}, answers: {length_answers}")

        prompts = prompts.to(device)  # (length_prompts, batch_size)
        target_answers = target_answers.to(device)  # (length_answers, batch_size)

        input_tensor = torch.cat(
            (prompts, target_answers), 0
        )  # (length_prompts + length_answers, batch_size)

        model.zero_grad()
        output, _ = model(
            input_tensor
        )  # (length_prompts + length_answers, batch_size, ntokens)
        output_answers = output[length_prompts - 1 : -1, :, :].reshape(
            -1, ntokens
        )  # (length_answers * batch_size, ntokens)
        target_answers = target_answers.view(-1)
        loss = F.cross_entropy(output_answers, target_answers)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print(
                "| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}".format(
                    batch,
                    len(data_train) // batch_size,
                    elapsed * 1000 / log_interval,
                    cur_loss,
                    math.exp(cur_loss),
                )
            )
            total_loss = 0
            start_time = time.time()


def train():
    best_test_accuracy = None
    test_accuracy = evaluate()
    print("-" * 89)
    print("| initialisation | test accuracy {:5.2f}".format(test_accuracy))
    print("-" * 89)
    for epoch in range(1, epochs + 1):
        # for epoch in range(1, 20+1):
        epoch_start_time = time.time()
        train_epoch()
        test_accuracy = evaluate()
        print("-" * 89)
        print(
            "| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}".format(
                epoch, (time.time() - epoch_start_time), test_accuracy
            )
        )
        print("-" * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic.pt", "wb") as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [184]:
train()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 22.54 | loss  1.97 | perplexity     7.16
|   400/  900 batches | ms/batch 21.94 | loss  1.27 | perplexity     3.57
|   600/  900 batches | ms/batch 22.27 | loss  0.50 | perplexity     1.66
|   800/  900 batches | ms/batch 22.38 | loss  0.14 | perplexity     1.15
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 25.66s | test accuracy  0.99
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 22.39 | loss  0.06 | perplexity     1.06
|   400/  900 batches | ms/batch 22.21 | loss  0.04 | perplexity     1.04
|   600/  900 batches | ms/batch 22.06 | loss  0.02 | perplexity     1.02
|   800/  900 batches | ms/

In [None]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1, 1))
    output = generate(model, prompt_tensor, len(answers)).view((1, -1))
    print(
        f"prompt: {prompt:<20} | actual answer: {answers:<13} | decode: {tokenizer.decode(output.tolist()[0]):<20}"
    )

prompt: 951112863+218386430= | actual answer: 1169499293    | decode: (0,3)(3,6)(4,8)(2,6)(1,8)(1,3)(1,8)(1,5)(2,9)+=1169499293
prompt: 79644146+575846641=  | actual answer: 655490787     | decode: (1,6)(4,4)(1,6)(4,6)(4,4)(6,8)(5,9)(7,7)(0,5)+=655490787
prompt: 549047276+473313306= | actual answer: 1022360582    | decode: (6,6)(0,7)(2,3)(3,7)(1,4)(0,3)(3,9)(4,7)(4,5)+=1022360582
prompt: 336246734+123438005= | actual answer: 459684739     | decode: (4,5)(0,3)(0,7)(6,8)(3,4)(2,4)(3,6)(2,3)(1,3)+=459684739
prompt: 278783775+934186881= | actual answer: 1212970656    | decode: (1,5)(7,8)(7,8)(3,6)(8,8)(1,7)(4,8)(3,7)(2,9)+=1212970656
prompt: 860539585+621775351= | actual answer: 1482314936    | decode: (1,5)(5,8)(3,5)(5,9)(3,7)(5,7)(0,1)(2,6)(6,8)+=1482314936
prompt: 96041807+617987978=  | actual answer: 714029785     | decode: (7,8)(0,7)(8,9)(1,7)(4,8)(0,9)(6,7)(1,9)(0,6)+=714029785
prompt: 516626546+667387943= | actual answer: 1184014489    | decode: (3,6)(4,4)(5,9)(6,7)(2,8)(3,6)(6,7)(1

## Probing

This is just for fun...

In [158]:
import numpy as np

train_size = 1000
test_size = 100

model.eval()


def data_probing(size):
    X = []
    y = np.zeros(size)
    for i in range(size):
        input = torch.tensor(tokenizer.encode(data[i][0])).view((-1, 1)).to(device)
        _, output = model(input)
        output = output[-1, :, :].flatten()
        # determine whether there was a carry in the result:
        carry = len(data[i][1]) > len(data[i][0]) / 2
        X.append(output.cpu().detach().numpy())
        y[i] = carry
    return np.array(X), y

In [159]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

X_train, y_train = data_probing(train_size)
X_test, y_test = data_probing(test_size)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.fit_transform(X_test)

reg = LogisticRegression()
reg.fit(X_train, y_train)
reg.score(X_test, y_test)

0.99