In [42]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import random

In [43]:
# Unique example lines read back from addition.txt. Being a set, it collapses
# duplicate lines and does not preserve the order in which they were generated.
dataset = set()

def generate_number(num_digits):
  """Return a random integer that has exactly `num_digits` decimal digits.

  The value is drawn with `random.randint` from the inclusive range
  [10**(num_digits-1), 10**num_digits - 1], so it never has a leading zero
  (for num_digits == 1 the result is in 1..9, never 0).
  """
  lowest = 10 ** (num_digits - 1)
  highest = 10 ** num_digits - 1
  return random.randint(lowest, highest)

# Number of "a+b=c%" lines to generate.
num_iterations = 20

# One entry per generated line, holding d1 + d2 (total operand digit count).
# NOTE(review): these are stored in generation order, while `dataset` is a set
# (duplicates collapse, iteration order differs), so they are not row-aligned
# with anything later built from `dataset`.
masks_idx = torch.zeros(num_iterations, dtype=torch.long)

with open('addition.txt', 'w') as out_file:
  for sample in range(num_iterations):
    # Draw both digit counts first, then both operands, left to right.
    len_a, len_b = random.randint(1, 5), random.randint(1, 5)
    lhs, rhs = generate_number(len_a), generate_number(len_b)
    masks_idx[sample] = len_a + len_b
    # '%' closes every example.
    out_file.write(f'{lhs}+{rhs}={lhs + rhs}%')
    out_file.write('\n')

# Read the file back, stripping the newline; the set keeps each distinct line once.
with open('addition.txt', 'r') as in_file:
    dataset.update(row.strip() for row in in_file)

# Ten digit tokens (0-9) plus '+', '=' and '%'.
vocab_size = 13

In [44]:
# Inspect the d1 + d2 digit-count totals recorded for each generated line.
masks_idx

tensor([5, 8, 7, 4, 7, 7, 6, 8, 2, 7, 4, 4, 9, 5, 9, 4, 4, 6, 2, 2])

In [45]:
# Token ids for the three non-digit symbols; digits 0-9 map to themselves.
_SYMBOL_TO_ID = {'+': 10, '=': 11, '%': 12}
_ID_TO_SYMBOL = {10: '+', 11: '=', 12: '%'}

def encode(s):
  """Turn a string such as '12+3=15%' into a list of integer token ids.

  '+', '=' and '%' become 10, 11 and 12; every other character is passed to
  int(), so a character that is not a digit raises ValueError.
  """
  return [_SYMBOL_TO_ID[c] if c in _SYMBOL_TO_ID else int(c) for c in s]

def decode(l):
  """Turn a sequence of token ids back into a list of symbols and digits.

  Ids 10, 11 and 12 become '+', '=' and '%'; any other id is returned as a
  plain number. Elements may be plain Python ints or scalar objects exposing
  `.item()` (e.g. the 0-d tensors produced by iterating a 1-D tensor), so
  `decode(encode(s))` works as well as decoding a tensor row.
  """
  symbols = []
  for c in l:
    # Unwrap tensor/array scalars; plain ints have no .item() and are used as-is.
    value = c.item() if hasattr(c, 'item') else c
    symbols.append(_ID_TO_SYMBOL.get(value, value))
  return symbols

print(encode('12323+12321=34311'))

def real_decode(l):
  """Decode a sequence of token ids into a single string."""
  return ''.join(map(str,decode(l)))

[1, 2, 3, 2, 3, 10, 1, 2, 3, 2, 1, 11, 3, 4, 3, 1, 1]


In [46]:
# Encode every example once; the longest one fixes the padded width.
encoded_lines = [encode(line) for line in dataset]
max_length = max(map(len, encoded_lines))

# Start from a matrix filled with 12 and copy each example into the left part
# of its row. 12 is also the id `encode` assigns to '%', so padding cannot be
# told apart from the end-of-example marker.
tensor_np = np.full((len(encoded_lines), max_length), 12, dtype=int)
for row, tokens in zip(tensor_np, encoded_lines):
    row[:len(tokens)] = tokens

data = torch.tensor(tensor_np, dtype=torch.long)

print("shape:", data.shape)

shape: torch.Size([19, 17])


In [47]:
# Inspect the padded token matrix (12 is both the '%' id and the padding value).
data

tensor([[ 2,  7,  6, 10,  3, 11,  2,  7,  9, 12, 12, 12, 12, 12, 12, 12, 12],
        [ 4, 10,  6, 11,  1,  0, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],
        [ 8,  5,  8, 10,  9,  3,  8,  7,  0, 11,  9,  4,  7,  2,  8, 12, 12],
        [ 1,  3,  9,  5, 10,  8,  5,  3,  1,  8, 11,  8,  6,  7,  1,  3, 12],
        [ 1,  8,  9, 10,  9,  1,  7,  1, 11,  9,  3,  6,  0, 12, 12, 12, 12],
        [ 5, 10,  5, 11,  1,  0, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],
        [ 4,  2,  8,  5, 10,  8,  6,  0,  5,  1, 11,  9,  0,  3,  3,  6, 12],
        [ 2,  5,  3, 10,  2,  2,  8, 11,  4,  8,  1, 12, 12, 12, 12, 12, 12],
        [ 7, 10,  6,  3,  6, 11,  6,  4,  3, 12, 12, 12, 12, 12, 12, 12, 12],
        [ 3,  2,  4,  4, 10,  6,  5, 11,  3,  3,  0,  9, 12, 12, 12, 12, 12],
        [ 1,  9,  5, 10,  9, 11,  2,  0,  4, 12, 12, 12, 12, 12, 12, 12, 12],
        [ 8,  0,  2, 10,  7,  2,  0,  2,  8, 11,  7,  2,  8,  3,  0, 12, 12],
        [ 2,  9,  7, 10,  7,  5,  6,  4, 11,  7,  8,  6,  1, 12,

In [48]:
# Column index of every token 11 ('=') found in `data`, replacing the digit-count
# values stored earlier under the same name. There is one entry per occurrence,
# so this lines up with the rows of `data` only when each row holds exactly one '='.
masks_idx = torch.nonzero(data == 11, as_tuple=True)[1]

In [49]:
# Inspect the column position of token 11 ('=') for the rows of `data`.
masks_idx

tensor([ 5,  3,  9, 10,  8,  3, 10,  7,  5,  7,  5,  9,  8,  8,  6,  5,  8,  6,
         5])

In [41]:
# Sample a random mini-batch of row indices (with replacement) from `data`.
batch_size = 5
ix = torch.randint(len(data), (batch_size,))

# Two windows over the same rows, shifted by one column: `targets` is `logits`
# moved one position to the right.
# NOTE(review): despite its name, `logits` holds token ids taken from `data`,
# not model outputs.
logits = data[ix, 3:-1]
targets = data[ix, 4:]

# Per-row index for this batch only. Kept under its own name so the full
# `masks_idx` tensor is not overwritten and the cell can be re-run safely.
batch_masks_idx = masks_idx[ix]

# Position index along the sequence axis, one copy per batch row. The width is
# taken from `targets` so the mask always has the same shape as the batch
# (the previous fixed width of 15 did not match the sliced rows).
seq_len = targets.shape[1]
range_tensor = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

# 1.0 where the position is below (index - 4 - 1), 0.0 elsewhere. The 4 equals
# the column offset of `targets` into `data`.
# NOTE(review): the purpose of the extra -1 is not evident here — confirm.
mask = (range_tensor < (batch_masks_idx - 4 - 1).unsqueeze(1)).float()

print(ix)
print(logits)
print(targets)
print(data[ix])

tensor([10, 14, 14,  0, 19])
tensor([[ 0,  4, 10,  1,  0,  1,  3, 11,  1,  0,  0,  2,  1,  7],
        [ 5,  0,  8,  4, 11,  5,  1,  7,  9, 12, 12, 12, 12, 12],
        [ 5,  0,  8,  4, 11,  5,  1,  7,  9, 12, 12, 12, 12, 12],
        [ 7,  6,  3,  3, 11,  7,  6,  7,  2, 12, 12, 12, 12, 12],
        [10,  4, 11,  2,  7,  7, 12, 12, 12, 12, 12, 12, 12, 12]])
tensor([[ 4, 10,  1,  0,  1,  3, 11,  1,  0,  0,  2,  1,  7, 12],
        [ 0,  8,  4, 11,  5,  1,  7,  9, 12, 12, 12, 12, 12, 12],
        [ 0,  8,  4, 11,  5,  1,  7,  9, 12, 12, 12, 12, 12, 12],
        [ 6,  3,  3, 11,  7,  6,  7,  2, 12, 12, 12, 12, 12, 12],
        [ 4, 11,  2,  7,  7, 12, 12, 12, 12, 12, 12, 12, 12, 12]])
tensor([[ 9,  9,  2,  0,  4, 10,  1,  0,  1,  3, 11,  1,  0,  0,  2,  1,  7, 12],
        [ 9,  5, 10,  5,  0,  8,  4, 11,  5,  1,  7,  9, 12, 12, 12, 12, 12, 12],
        [ 9,  5, 10,  5,  0,  8,  4, 11,  5,  1,  7,  9, 12, 12, 12, 12, 12, 12],
        [ 3,  9, 10,  7,  6,  3,  3, 11,  7,  6,  7,  2, 12, 12

In [18]:
# NOTE(review): this cell's execution count (18) is lower than the cells above
# it, and the row shown below has 18 columns while the shape printed earlier is
# 17 wide — the output comes from an earlier kernel state.
data[15]

tensor([ 1, 10,  5,  3,  4,  4,  6, 11,  5,  3,  4,  4,  7, 12, 12, 12, 12, 12])

In [22]:
# NOTE(review): executed out of order (count 22); the 20 values shown below do
# not match the 19-entry tensor displayed earlier in the notebook.
masks_idx

tensor([6, 8, 6, 8, 9, 6, 9, 8, 6, 3, 9, 2, 9, 4, 4, 6, 5, 7, 7, 8])

In [23]:
# NOTE(review): executed out of order (count 23); the rows shown below are 18
# columns wide, unlike the 17-column matrix displayed earlier in the notebook.
data

tensor([[ 3,  9, 10,  7,  6,  3,  3, 11,  7,  6,  7,  2, 12, 12, 12, 12, 12, 12],
        [ 4,  0, 10,  8,  2,  4,  8, 11,  8,  2,  8,  8, 12, 12, 12, 12, 12, 12],
        [ 8, 10,  8,  0,  0,  0,  9, 11,  8,  0,  0,  1,  7, 12, 12, 12, 12, 12],
        [ 5, 10,  3, 11,  8, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],
        [ 2,  1,  9,  8, 10,  6,  2,  5,  2, 11,  8,  4,  5,  0, 12, 12, 12, 12],
        [ 6,  2,  3,  5, 10,  8,  0,  2,  7, 11,  1,  4,  2,  6,  2, 12, 12, 12],
        [ 7,  6, 10,  7,  9,  9,  4,  7, 11,  8,  0,  0,  2,  3, 12, 12, 12, 12],
        [ 9,  4,  6, 10,  3,  5,  5,  2,  6, 11,  3,  6,  4,  7,  2, 12, 12, 12],
        [ 2, 10,  2,  4,  0,  2, 11,  2,  4,  0,  4, 12, 12, 12, 12, 12, 12, 12],
        [ 8,  3,  4,  2, 10,  2,  6,  3,  9,  4, 11,  3,  4,  7,  3,  6, 12, 12],
        [ 9,  9,  2,  0,  4, 10,  1,  0,  1,  3, 11,  1,  0,  0,  2,  1,  7, 12],
        [ 4,  0,  5,  3, 10,  4,  3,  2,  4,  5, 11,  4,  7,  2,  9,  8, 12, 12],
        [ 4, 10,