In [248]:
import random

import torch
from torch.utils.data import DataLoader, Dataset

In [139]:
OPERAND = ["+", "//", "*", "%"]

In [140]:
def data_gen(max_pow, operand: list[str]):
    i = random.randint(0, len(operand) - 1)

    operand = operand[i]

    left = str(random.randint(0, 10**max_pow))
    right = str(random.randint(0, 10**max_pow))

    code = left + operand + right

    return code, str(eval(code))

In [141]:
data_gen(3, OPERAND)

('745//63', '11')

In [285]:
class MathTokenizer:
    max_digit = 10

    def __init__(self, operand: list[str]):
        self.vocab = {str(i): i for i in range(self.max_digit)}
        for i, op in enumerate(operand):
            self.vocab[op] = self.max_digit + i

        self.vocab["/"] = self.vocab["//"]
        del self.vocab["//"]

        self.vocab["P"] = self.max_digit + len(operand)
        
        self.anti_vocab = {value: key for key, value in self.vocab.items()}
        
        self.vocab_size = len(self.vocab) 

    def encode(self, x: str) -> list[int]:
        x.replace("//", "/")
        return list(map(lambda x: self.vocab[x], list(x)))

    def decode(self, x: list[int]) -> list[str]:
        decoded = list(map(lambda x: self.anti_vocab[x], list(x)))
        return "".join(decoded)

In [286]:
tokenizer = MathTokenizer(OPERAND)

In [287]:
tokenizer.decode(tokenizer.encode(data_gen(3, OPERAND)[0]))

'563*88'

In [288]:
class MathDataset(Dataset):
    def __init__(self, lenght, max_pow):
        self.n = lenght
        self.max_pow = max_pow

        self.tokenizer = MathTokenizer(OPERAND)

    def __getitem__(self, index):
        code, result = data_gen(self.max_pow, OPERAND)
        return self.tokenizer.encode(code), self.tokenizer.encode(result)

    def __len__(self):
        return self.n

In [289]:
def collate_fn(batch, pad_token):
    # Example: Sum the data from the batch    
    print(batch)
    return batch

In [290]:
dataset = MathDataset(10, 3)

In [291]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

In [292]:
for data in dataloader:
    pass

[([5, 3, 8, 13, 4, 7, 1], [6, 7]), ([2, 1, 8, 12, 2, 4, 7], [5, 3, 8, 4, 6]), ([5, 6, 7, 10, 3, 0, 9], [8, 7, 6]), ([3, 9, 2, 10, 4, 0], [4, 3, 2])]
[([6, 2, 12, 6, 1, 2], [3, 7, 9, 4, 4]), ([1, 0, 2, 11, 11, 3, 0, 0], [0]), ([7, 0, 4, 12, 3, 8, 1], [2, 6, 8, 2, 2, 4]), ([4, 5, 3, 11, 11, 9, 7, 7], [0])]
[([7, 4, 7, 10, 3, 7, 9], [1, 1, 2, 6]), ([3, 9, 10, 6, 1, 0], [6, 4, 9])]
