In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# Special characters: `pad_char` fills unused positions, `stop_char` terminates an equation.
stop_char = ";"
pad_char = "_"

# Vocabulary. `pad_char` must stay first so that it maps to token id 0, the value
# that pad() appends and that the training loss skips (ignore_index=0).
chars = list(pad_char + "0123456789+=" + stop_char)
vocab_size = len(chars)
ctoi = {c: i for i, c in enumerate(chars)}  # character -> token id
itoc = dict(enumerate(chars))               # token id  -> character

# Token ids of the special characters, derived from the named constants above
# instead of repeating the literals.
eq_token = ctoi["="]
stop_token = ctoi[stop_char]
pad_token = ctoi[pad_char]
assert pad_token == 0, "pad_char must be the first vocabulary entry (token id 0)"

def encode(string):
    """Translate `string` character by character into token ids using `ctoi`.

    Returns a list of ints; a character missing from `ctoi` raises KeyError.
    """
    return list(map(ctoi.__getitem__, string))

def decode(tokens):
    """Turn a sequence of token ids back into a string using `itoc`.

    Args:
        tokens: a 1-D tensor, or any iterable whose items are plain ints or
            single-element tensors (e.g. the values returned by
            torch.multinomial). Mixed input is fine.

    Returns:
        The concatenated characters; an empty input gives "".
    """
    # int() normalises Python ints and one-element tensors alike, so a list
    # of sampled tensors no longer fails the dictionary lookup.
    return "".join(itoc[int(t)] for t in tokens)

# Display the character -> token-id table built above.
ctoi

{'_': 0,
 '0': 1,
 '1': 2,
 '2': 3,
 '3': 4,
 '4': 5,
 '5': 6,
 '6': 7,
 '7': 8,
 '8': 9,
 '9': 10,
 '+': 11,
 '=': 12,
 ';': 13}

In [3]:
def generate_eq():
    """Build one random training string of the form "a+b=<reversed digits of a+b>;".

    Both operands are drawn from 0..999. The sum is written least-significant
    digit first, and the string is terminated with ";".
    """
    lhs = torch.randint(0, 1000, (1,)).item()
    rhs = torch.randint(0, 1000, (1,)).item()
    reversed_sum = "".join(reversed(str(lhs + rhs)))
    return f"{lhs}+{rhs}={reversed_sum};"


# Show one randomly generated equation string.
generate_eq()

'559+397=659;'

In [4]:
def pad(ints, length, value=0):
    """Right-pad a token sequence to exactly `length` items.

    Args:
        ints: sequence of token ids (list, tuple, ...); it is not modified.
        length: target length; must be at least len(ints).
        value: filler appended on the right. Defaults to 0, the id of the
            pad character "_".

    Returns:
        A new list of `length` items: the original tokens followed by `value`.

    Raises:
        AssertionError: if `ints` is longer than `length`.
    """
    ints = list(ints)  # accept any sequence and never alias the caller's list
    assert length >= len(ints), f"cannot pad {len(ints)} tokens to length {length}"
    return ints + [value] * (length - len(ints))
        
# Quick check: two tokens padded out to length five.
pad([1,2], 5)

[1, 2, 0, 0, 0]

In [5]:
# Fixed length every encoded sample is padded to (4 + 4 + 5 + 3 = 16 tokens).
# The longest string generate_eq() can emit is "999+999=8991;" (13 characters),
# so every sample fits.
max_len = 4 + 4 + 5 + 3
# Intended pad token id. NOTE(review): not referenced by any other visible cell;
# pad() and the loss's ignore_index both hard-code 0, so keep them in sync.
padding = 0

def random_batch(batch_size):
    """Sample `batch_size` equations and return `(data, target)` tensors of shape (B, L).

    `data` holds the padded token ids of each equation. `target` is a copy in
    which every position up to and including the "=" token is set to 0, so
    only the answer digits and the stop token keep non-zero labels.
    """
    data = [pad(encode(generate_eq()), max_len) for _ in range(batch_size)]

    target = []
    for row in data:
        answer_start = row.index(eq_token) + 1
        target.append([0] * answer_start + row[answer_start:])

    return torch.tensor(data), torch.tensor(target)


# Draw a tiny batch and display the (data, target) pair.
x, y = random_batch(2)
x, y

(tensor([[ 8,  3,  1, 11, 10,  3,  8, 12,  8,  5,  7,  2, 13,  0,  0,  0],
         [ 5,  7,  8, 11,  9,  3,  2, 12,  9,  9,  3,  2, 13,  0,  0,  0]]),
 tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  8,  5,  7,  2, 13,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  9,  9,  3,  2, 13,  0,  0,  0]]))

In [6]:
# Width of each token embedding vector (used for nn.Embedding and as the RNN input size in Model).
emb_size = 64
# Size of the RNN hidden state (used for nn.RNN and as the input size of the output nn.Linear in Model).
hidden_size = 128

class Model(nn.Module):
    """Character-level RNN: embedding -> single-layer nn.RNN -> linear projection to vocabulary logits.

    Args:
        num_tokens: vocabulary size; defaults to the notebook-level `vocab_size`.
        emb_dim: embedding width; defaults to the notebook-level `emb_size`.
        hidden_dim: RNN hidden-state size; defaults to the notebook-level `hidden_size`.
    """

    def __init__(self, num_tokens=None, emb_dim=None, hidden_dim=None):
        super().__init__()
        # Resolve the defaults at call time so that a plain `Model()` builds
        # exactly the same network as before, while other sizes stay possible.
        num_tokens = vocab_size if num_tokens is None else num_tokens
        emb_dim = emb_size if emb_dim is None else emb_dim
        hidden_dim = hidden_size if hidden_dim is None else hidden_dim

        self.emb = nn.Embedding(num_tokens, emb_dim)
        self.rnn = nn.RNN(emb_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, num_tokens)

    def forward(self, x, hidden=None):
        """Run the network over a batch of token sequences.

        Args:
            x: integer token ids of shape (B, L).
            hidden: optional initial RNN state, e.g. the one returned by a
                previous call; None lets nn.RNN start from its default state.

        Returns:
            (logits, hidden): logits of shape (B, L, num_tokens) and the RNN
            state after the last position, which can be fed back in to
            continue the same sequence.
        """
        y = self.emb(x)
        y, h = self.rnn(y, hidden)
        y = self.linear(y)
        return y, h

    
# Seed PyTorch's global RNG so that weight initialisation, the randomly
# generated training batches and the sampling during evaluation repeat on
# every Restart & Run All.
torch.manual_seed(0)

model = Model()
optim = torch.optim.AdamW(model.parameters())  # AdamW with its default hyperparameters

# Total parameter count of the model
sum(p.numel() for p in model.parameters())

27534

In [7]:
# Reference value: cross-entropy of a uniform guess over the vocabulary, -log(1/vocab_size).
-torch.tensor(1/vocab_size).log().item()

2.6390573978424072

In [8]:
# Scratch check: a slice ending at -1 drops the last element; the training
# loop uses the same slice to trim the final logit position.
a = [1, 2, 3]
a[:-1]

[1, 2]

In [9]:
model.train()

for step in range(50000):
    # A fresh synthetic batch every step: inputs and masked targets, both (N, L)
    inputs, targets = random_batch(32)

    optim.zero_grad()
    logits, _ = model(inputs)  # (N, L, vocab)

    # Next-token objective: the logits at position t are scored against the
    # target at position t + 1, so drop the last logit and the first target.
    shifted_logits = logits[:, :-1, :].reshape(-1, vocab_size)
    shifted_targets = targets[:, 1:].reshape(-1)

    # Target id 0 (padding and the zeroed-out prompt part) is excluded from the loss.
    loss = F.cross_entropy(shifted_logits, shifted_targets, ignore_index=0)
    loss.backward()
    optim.step()

    # Report the current batch loss every 5000 steps
    if step % 5000 == 0:
        print(loss.item())

2.647028684616089
0.8864972591400146
0.7910671830177307
0.5939415693283081
0.5654222369194031
0.5836650133132935
0.5539506077766418
0.6139837503433228
0.5237212777137756
0.6793401837348938


In [10]:
def generate_prob():
    """Create a random addition prompt for evaluation.

    Returns:
        (prompt, total, reversed_total): the prompt string "a+b=", the integer
        sum, and the sum's decimal digits in reverse order.
    """
    lhs, rhs = (torch.randint(0, 1000, (1,)).item() for _ in range(2))
    total = lhs + rhs
    return f"{lhs}+{rhs}=", total, str(total)[::-1]

In [11]:
# Show one sample (prompt, sum, reversed-sum string) triple.
generate_prob()

('691+503=', 1194, '4911')

In [32]:
model.eval()

# Inference only: switch autograd off so that no graph is recorded while the
# hidden state is carried from one generation step to the next.
with torch.no_grad():
    for _ in range(10):
        prompt, solu, _rsolu = generate_prob()
        print(prompt + str(solu))

        # Feed the whole prompt once; afterwards feed one sampled token at a time.
        x = torch.tensor(encode(prompt)).view(1, -1)  # (B=1, L)
        tokens = []
        hidden = None

        for _step in range(100):  # hard cap in case the stop token is never sampled
            logits, hidden = model(x, hidden)
            probs = F.softmax(logits[0, -1], dim=0)  # distribution over the next token
            ix = torch.multinomial(probs, 1)
            if ix.item() == stop_token:
                break
            tokens.append(ix.item())
            x = ix.view(1, 1)

        # The model writes the answer least-significant digit first; reverse it for display.
        s = decode(tokens)[::-1]
        print(s)
        print("\n")

856+849=1705
1745


558+788=1346
1356


681+536=1217
1257


736+210=946
986


477+149=626
646


53+105=158
178


741+195=936
946


518+669=1187
1197


912+639=1551
1541


859+782=1641
1641


