## Get all libraries

In [None]:
!pip install --upgrade accelerate -q
!pip install transformers evaluate wandb -q

[0m

In [None]:
import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import IterableDataset, Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import wandb
import random
import os

def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    # Hash randomization of the running interpreter is fixed at startup; this
    # export only affects Python processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may choose different kernels from run to run, which
    # would undo the deterministic setting above, so it must stay disabled.
    torch.backends.cudnn.benchmark = False
    
seed_everything(42)

## Model
I'll try to use the most basic BERT.

In [None]:
from transformers import BertTokenizer, BertForMaskedLM

# Tokenizer and masked-LM model are both loaded from the same pretrained checkpoint.
checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertForMaskedLM.from_pretrained(checkpoint)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Create dataset
Each example pairs the tokenized full equation, used as the target output (e.g. `2+2=4`), with a masked input (e.g. `2+2=[MASK]`).

In [None]:
from itertools import product


class NumbersDataset(Dataset):
    """Masked-LM dataset of addition equations.

    Every pair ``(a, b)`` with ``a`` in ``range(10**left_len)`` and ``b`` in
    ``range(10**right_len)`` yields one example: the model input is the
    tokenized text ``"a+b=[MASK]"`` and the labels are the tokenized text
    ``"a+b=<a+b>"``.

    Args:
        left_len: Exponent bounding the left operand (``0 .. 10**left_len - 1``).
        right_len: Exponent bounding the right operand (``0 .. 10**right_len - 1``).
        tokenizer: Callable that takes a list of strings and returns a mapping
            of tensors containing at least ``'input_ids'``. It is called once
            for the full equations and once for the masked ones.
    """

    def __init__(self, left_len, right_len, tokenizer):
        super().__init__()
        self.left_len = left_len
        self.right_len = right_len
        self.tokenizer = tokenizer
        self._build()

    def __len__(self):
        """Return the number of equations in the dataset."""
        return len(self.inputs['input_ids'])

    def __getitem__(self, idx):
        """Return a dict holding an independent copy of row ``idx`` of every tensor."""
        # The stored values are already tensors; ``torch.tensor(tensor)`` would
        # emit a UserWarning on every access, so copy with detach().clone().
        return {key: val[idx].detach().clone() for key, val in self.inputs.items()}

    def _build(self):
        """Generate all equation strings and tokenize them into ``self.inputs``."""
        full_texts = []
        masked_texts = []
        left_range = range(10 ** self.left_len)
        right_range = range(10 ** self.right_len)
        for left, right in product(left_range, right_range):
            prompt = f'{left}+{right}='
            full_texts.append(f'{prompt}{left + right}')
            masked_texts.append(f'{prompt}[MASK]')
        self.inputs = self.tokenizer(full_texts)
        self.masked = self.tokenizer(masked_texts)
        # Labels are the ids of the complete equation; the model input is the
        # masked variant. Other keys keep the values computed for the full text.
        self.inputs['labels'] = self.inputs['input_ids'].detach().clone()
        self.inputs['input_ids'] = self.masked['input_ids'].detach().clone()

You can tweak `max_length`; I tried 512 and it kind of worked, but it was too slow.

In [None]:
def tok(x, max_length=128):
    """Tokenize ``x`` with the notebook-level ``tokenizer`` into fixed-width PyTorch tensors.

    Args:
        x: Text (or list of texts) passed straight through to ``tokenizer``.
        max_length: Length every sequence is truncated and padded to.
            Defaults to 128, the value this notebook was run with.

    Returns:
        Whatever ``tokenizer`` returns for these arguments.
    """
    return tokenizer(x, return_tensors='pt', max_length=max_length, truncation=True, padding='max_length')
# All sums of two operands below 100, split 80% / 20% into train and eval subsets.
dataset = NumbersDataset(left_len=2, right_len=2, tokenizer=tok)
train_dataset, eval_dataset = torch.utils.data.random_split(dataset, lengths=[0.8, 0.2])

Quick testing

In [None]:
# Sanity check: decode the first example's model input (the masked equation) back to text.
tokenizer.decode(dataset[0]['input_ids'])

'[CLS] 0 + 0 = [MASK] [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [None]:
# Sanity check: decode the first example's labels (the complete equation) back to text.
tokenizer.decode(dataset[0]['labels'])

'[CLS] 0 + 0 = 0 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

## Preparations

Set some arguments

In [None]:
from transformers import TrainingArguments

# Trainer settings: checkpoints go to 'out', 64 examples per device per step,
# two epochs, a log entry every 50 steps, and evaluation on a step schedule.
training_kwargs = dict(
    output_dir='out',
    per_device_train_batch_size=64,
    num_train_epochs=2,
    logging_steps=50,
    evaluation_strategy='steps',
)
args = TrainingArguments(**training_kwargs)

Define some metrics

In [None]:
import evaluate
# Load the "mae" (mean absolute error) metric from the evaluate library.
mae_metric = evaluate.load("mae")

## And train!

In [None]:
from transformers import Trainer

# Hyperparameters and run metadata tracked with the run.
# NOTE(review): "group" and "name" are stored as plain config values here; they
# are not passed as the group=/name= arguments of wandb.init - confirm intended.
run_config = {
    "group": "e",
    "model": "bert",
    "name": "bert",
}

# Start the wandb run in the project where it will be logged.
wandb.init(project="llmcalc", config=run_config)

def preprocess_logits_for_metrics(logits, labels):
    """Reduce vocabulary logits to predicted token ids before they are accumulated.

    Only applied when labels are present (evaluation). For label-free
    prediction the raw logits are returned unchanged, so callers of
    ``trainer.predict`` can still take ``argmax`` themselves.
    """
    if labels is None:
        return logits
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


def compute_metrics(eval_pred):
    """Adapt the Trainer's ``EvalPrediction`` to ``mae_metric.compute``.

    ``mae_metric.compute`` only accepts ``predictions=``/``references=``
    keyword arguments, so it cannot be handed to the Trainer directly.
    NOTE(review): this is the mean absolute difference between token ids over
    all positions, not the arithmetic error of the predicted sum.
    """
    predictions, references = eval_pred
    if predictions.ndim > references.ndim:
        # Still raw logits: pick the most likely token id per position.
        predictions = predictions.argmax(-1)
    return mae_metric.compute(
        predictions=predictions.reshape(-1),
        references=references.reshape(-1),
    )


trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

In [None]:
# Run the fine-tuning loop and keep the value returned by trainer.train() in `bt`.
bt = trainer.train()



Step,Training Loss
50,0.0235
100,0.021
150,0.0191
200,0.0184
250,0.0183


## Inference
The predictions are mostly off by one.

In [None]:
class EvalSet(Dataset):
    """Dataset wrapping raw text prompts for inference.

    The prompts are tokenized with the notebook-level ``tok`` helper once, on
    first item access, and the encoding is reused afterwards.

    Args:
        inputs: Sequence of prompt strings, e.g. ``["41+25=[MASK]"]``.
    """

    def __init__(self, inputs):
        self.inputs = inputs
        # Filled lazily so constructing the dataset stays free of tokenizer work.
        self._encoded = None

    def __len__(self):
        """Return the number of prompts."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return a dict holding an independent copy of row ``idx`` of every encoded tensor."""
        if self._encoded is None:
            # Tokenize the whole list a single time instead of on every access.
            self._encoded = tok(self.inputs)
        return {key: val[idx].detach().clone() for key, val in self._encoded.items()}

# Predict on a single masked prompt, take the most likely token at every
# position, and decode the first (only) sequence back to text.
eval_examples = EvalSet(["41+25=[MASK]"])
predicted_ids = trainer.predict(eval_examples).predictions.argmax(-1)
tokenizer.decode(predicted_ids[0])

'[CLS] 41 + 25 = 67 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'