In [4]:
import torch
from torch.utils.data import Dataset
import random

class ArithmeticDataset(Dataset):
    """Base class for in-memory synthetic datasets of ``(input, target)`` samples.

    All samples are produced eagerly in ``__init__`` by calling
    ``generate_data``, which concrete subclasses must override.

    Args:
        max_length: Size bound stored on the instance for ``generate_data``;
            its exact meaning is decided by the subclass.
        num_samples: Requested number of samples, stored on the instance for
            ``generate_data``.
    """

    def __init__(self, max_length, num_samples):
        self.max_length = max_length
        self.num_samples = num_samples
        # Must come last: subclass generate_data() implementations read the
        # two attributes assigned above.
        self.data = self.generate_data()

    def __len__(self):
        """Return how many samples were generated."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

    def generate_number(self, length):
        """Draw a random integer that has exactly ``length`` decimal digits.

        The value is drawn from ``[10**(length-1), 10**length - 1]``, so it
        never has a leading zero and ``0`` itself is never produced.

        Args:
            length: Number of decimal digits; must be at least 1.

        Returns:
            A positive ``int`` with ``length`` digits.

        Raises:
            ValueError: If ``length`` is smaller than 1.
        """
        # Without this guard 10**(length-1) becomes a float for length < 1 and
        # random.randint fails with an unhelpful, version-dependent error.
        if length < 1:
            raise ValueError(f"length must be >= 1, got {length!r}")
        return random.randint(10 ** (length - 1), 10 ** length - 1)

    def generate_data(self):
        """Produce the list of samples; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this method")



In [5]:
class AdditionDataset(ArithmeticDataset):
    """Addition problems as string pairs: input ``"<a>+<b>="``, target ``"<a+b>"``.

    The digit counts of the two operands range over every pair ``(i, j)``
    with ``1 <= i, j <= max_length``.
    """

    def generate_data(self):
        """Build exactly ``num_samples`` addition samples.

        Samples are spread over the ``max_length ** 2`` digit-length
        combinations as evenly as possible. When ``num_samples`` is an exact
        multiple of ``max_length ** 2`` every combination receives the same
        count; otherwise the leftover samples are distributed across the
        grid instead of being dropped.

        Returns:
            A list of ``(input_str, output_str)`` tuples, ordered by first
            operand length and then by second operand length.

        Raises:
            ValueError: If ``max_length`` is smaller than 1.
        """
        if self.max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {self.max_length!r}")

        total_combinations = self.max_length ** 2
        data = []
        combination = 0
        for i in range(1, self.max_length + 1):
            for j in range(1, self.max_length + 1):
                # Cumulative-quota split: the per-combination counts telescope
                # to num_samples, so no remainder is lost to floor division.
                quota = ((combination + 1) * self.num_samples // total_combinations
                         - combination * self.num_samples // total_combinations)
                combination += 1
                for _ in range(quota):
                    num1 = self.generate_number(i)
                    num2 = self.generate_number(j)
                    result = num1 + num2
                    data.append((f"{num1}+{num2}=", str(result)))
        return data

In [6]:
class MultiplicationDataset(ArithmeticDataset):
    """Multiplication problems as string pairs: input ``"<a>*<b>="``, target ``"<a*b>"``.

    The digit counts of the two factors range over every pair ``(i, j)``
    with ``1 <= i, j <= max_length``.
    """

    def generate_data(self):
        """Build exactly ``num_samples`` multiplication samples.

        Samples are spread over the ``max_length ** 2`` digit-length
        combinations as evenly as possible. When ``num_samples`` is an exact
        multiple of ``max_length ** 2`` every combination receives the same
        count; otherwise the leftover samples are distributed across the
        grid instead of being dropped.

        Returns:
            A list of ``(input_str, output_str)`` tuples, ordered by first
            factor length and then by second factor length.

        Raises:
            ValueError: If ``max_length`` is smaller than 1.
        """
        if self.max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {self.max_length!r}")

        total_combinations = self.max_length ** 2
        data = []
        combination = 0
        for i in range(1, self.max_length + 1):
            for j in range(1, self.max_length + 1):
                # Cumulative-quota split: the per-combination counts telescope
                # to num_samples, so no remainder is lost to floor division.
                quota = ((combination + 1) * self.num_samples // total_combinations
                         - combination * self.num_samples // total_combinations)
                combination += 1
                for _ in range(quota):
                    num1 = self.generate_number(i)
                    num2 = self.generate_number(j)
                    result = num1 * num2
                    data.append((f"{num1}*{num2}=", str(result)))
        return data

In [7]:
class SortingDataset(ArithmeticDataset):
    """Sort-by-value problems over labelled integers.

    Each input is a comma-separated list of ``label:number`` pairs (labels
    ``a``, ``b``, ... in order); the target is the labels concatenated in
    ascending order of their numbers. Equal numbers are ordered by label,
    because the sort key is the ``(number, label)`` tuple.
    """

    # One lowercase letter labels each integer, so a sample can hold at most
    # len(LABELS) integers.
    LABELS = 'abcdefghijklmnopqrstuvwxyz'

    def generate_data(self):
        """Build exactly ``num_samples`` sorting samples.

        Samples are spread as evenly as possible over every pair
        ``(i, j)`` with ``1 <= i, j <= max_length``, where ``i`` is the number
        of integers in the sample and ``j`` is the maximum digit length (each
        integer's digit count is drawn uniformly from ``1..j``). Leftover
        samples are distributed across the grid instead of being dropped.

        Returns:
            A list of ``(input_str, output_str)`` tuples.

        Raises:
            ValueError: If ``max_length`` is smaller than 1, or larger than
                the number of available labels (previously such inputs were
                silently truncated to 26 integers).
        """
        if self.max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {self.max_length!r}")
        if self.max_length > len(self.LABELS):
            raise ValueError(
                f"max_length must be <= {len(self.LABELS)} (one letter label per "
                f"integer), got {self.max_length!r}"
            )

        total_combinations = self.max_length ** 2
        data = []
        combination = 0
        for i in range(1, self.max_length + 1):  # number of integers
            indices = list(self.LABELS[:i])
            for j in range(1, self.max_length + 1):  # max digit length
                # Cumulative-quota split: the per-combination counts telescope
                # to num_samples, so no remainder is lost to floor division.
                quota = ((combination + 1) * self.num_samples // total_combinations
                         - combination * self.num_samples // total_combinations)
                combination += 1
                for _ in range(quota):
                    numbers = [self.generate_number(random.randint(1, j)) for _ in range(i)]
                    input_str = ','.join([f"{idx}:{num}" for idx, num in zip(indices, numbers)])
                    sorted_indices = [idx for _, idx in sorted(zip(numbers, indices))]
                    output_str = ''.join(sorted_indices)
                    data.append((input_str, output_str))
        return data

In [8]:
def create_datasets(dataset_class, max_length, train_samples, test_samples):
    """Instantiate a train split and a test split of the same dataset type.

    Both splits are built independently by calling
    ``dataset_class(max_length, n)``; no de-duplication between the two is
    performed here.

    Args:
        dataset_class: Callable taking ``(max_length, num_samples)``.
        max_length: Passed unchanged to both constructor calls.
        train_samples: Sample count for the first (train) split.
        test_samples: Sample count for the second (test) split.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    # Train is constructed before test, in that order.
    train_split, test_split = [
        dataset_class(max_length, count) for count in (train_samples, test_samples)
    ]
    return train_split, test_split

In [9]:
# Set parameters
SEED = 0  # fixed seed so a fresh "Restart & Run All" regenerates identical datasets
max_length = 20  # maximum length of operands
train_samples = 200_000  # 200k for a quick run (the paper is noted as using 20 million)
test_samples = 1_000  # adjust as needed

# Seed the global `random` generator before any data is drawn; every dataset
# below is generated through it.
random.seed(SEED)

# Create datasets
addition_train, addition_test = create_datasets(AdditionDataset, max_length, train_samples, test_samples)
multiplication_train, multiplication_test = create_datasets(MultiplicationDataset, max_length, train_samples, test_samples)
sorting_train, sorting_test = create_datasets(SortingDataset, max_length, train_samples, test_samples)

# Print some samples
print("Addition sample:", addition_train[0])
print("Multiplication sample:", multiplication_train[0])
print("Sorting sample:", sorting_train[0])

Addition sample: ('5+6=', '11')
Multiplication sample: ('6*7=', '42')
Sorting sample: ('a:1', 'a')


In [10]:
import random

def print_samples(dataset, name, num_samples=10):
    """Print randomly chosen samples from ``dataset`` under a ``name`` header.

    Indices are drawn independently (with replacement), so the same sample
    may be printed more than once.

    Args:
        dataset: Object supporting ``len()`` and integer indexing, whose
            items expose the input at position 0 and the output at position 1.
        name: Label used in the printed header.
        num_samples: How many samples to print.
    """
    print(f"\n{name} Samples:")
    size = len(dataset)
    if size == 0:
        # Nothing to draw from; picking a random index would raise ValueError.
        print("(dataset is empty)")
        return
    for _ in range(num_samples):
        idx = random.randrange(size)
        sample = dataset[idx]
        print(f"Input: {sample[0]}, Output: {sample[1]}")

# Show random samples from each training set, in the order:
# Addition, Multiplication, Sorting.
for sample_source, sample_label in (
    (addition_train, "Addition"),
    (multiplication_train, "Multiplication"),
    (sorting_train, "Sorting"),
):
    print_samples(sample_source, sample_label)


Addition Samples:
Input: 4623494+298217012=, Output: 302840506
Input: 3736383673135537194+79578343076=, Output: 3736383752713880270
Input: 92539759590168941+8941235487487=, Output: 92548700825656428
Input: 85822118956+78=, Output: 85822119034
Input: 294909+32712015421647=, Output: 32712015716556
Input: 4127069+70319306423347=, Output: 70319310550416
Input: 80599427074+81006405357393=, Output: 81087004784467
Input: 747220552569195282+8017=, Output: 747220552569203299
Input: 6487202964900457382+5631818=, Output: 6487202964906089200
Input: 55+5421825494=, Output: 5421825549

Multiplication Samples:
Input: 8280107543689690*4172213134716008913=, Output: 34546373450643234240268814946206970
Input: 959762765314616*809983476469=, Output: 777391981235033638217770904
Input: 17995*627325765=, Output: 11288727141175
Input: 85819371937035625104*22282902805=, Output: 1912304723659109414014850016720
Input: 901487158224*1936855207086=, Output: 1746050096527315167975264
Input: 1787*4283074786=, Output: