In [584]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [585]:
stop_char = ";"
pad_char = "_"
chars = list("_0123456789+=;")
vocab_size = len(chars)
ctoi = {c:i for i, c in enumerate(chars)}
itoc = {i:c for i, c in enumerate(chars)}

eq_token = ctoi["="]
stop_token = ctoi[";"]

# return list of integer
def encode(string):
    return [ctoi[c] for c in string] 

def decode(tokens):
    if isinstance(tokens, torch.Tensor):
        return "".join([itoc[t.item()] for t in tokens])
    else:
        return "".join([itoc[t] for t in tokens])

ctoi

{'_': 0,
 '0': 1,
 '1': 2,
 '2': 3,
 '3': 4,
 '4': 5,
 '5': 6,
 '6': 7,
 '7': 8,
 '8': 9,
 '9': 10,
 '+': 11,
 '=': 12,
 ';': 13}

In [586]:
def generate_eq():
    a = torch.randint(0, 100, (1,))
    b = torch.randint(0, 100, (1,))
    c = a + b
    cs = f"{c.item()}"[::-1]
    
    return f"{a.item()}+{b.item()}={cs};"


generate_eq()

'39+30=96;'

In [587]:
def pad(ints, length):
    assert length >= len(ints)
    pn = length - len(ints)
    return ints + [0]*pn
        
pad([1,2], 5)

[1, 2, 0, 0, 0]

In [588]:
max_len = 12
padding = 0

# (B, L)
def random_batch(batch_size):
    eq_str = [generate_eq() for _ in range(batch_size)]
    data = [pad(encode(s), max_len) for s in eq_str]
    target = []
    
    for x in data:
        y = list(x)
        i = x.index(eq_token)
        y[0:i+1] = [0]*(i+1)
        target.append(y)
        
    return torch.tensor(data), torch.tensor(target)


x, y = random_batch(2)
x, y

(tensor([[ 5,  5, 11,  9,  8, 12,  2,  4,  2, 13,  0,  0],
         [ 4, 10, 11, 10,  8, 12,  7,  4,  2, 13,  0,  0]]),
 tensor([[ 0,  0,  0,  0,  0,  0,  2,  4,  2, 13,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  7,  4,  2, 13,  0,  0]]))

In [589]:
emb_size = 64
hidden_size = 128

class Model(nn.Module):
    
    def __init__(self):
        super().__init__();
        self.emb = nn.Embedding(vocab_size, emb_size)
        self.rnn = nn.RNN(emb_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        y = self.emb(x)
        y, h = self.rnn(y, hidden)
        y = self.linear(y)
        return y, h

    
model = Model()
optim = torch.optim.AdamW(model.parameters())

sum([p.numel() for p in model.parameters()])

# x, y = random_batch(1) # (N, L)
# logits = model(x) # (N, L, vocab)

# y = y[:,1:].view(-1)
# logits = logits[:,1:,:].view(-1, vocab_size)
# loss = F.cross_entropy(logits, y)
# loss

27534

In [590]:
-torch.tensor(1/vocab_size).log().item()

2.6390573978424072

In [591]:
a = [1, 2, 3]
a[0:-1]

[1, 2]

In [640]:
model.train()

for i in range(10000):
    optim.zero_grad()
    
    x, y = random_batch(32) # (N, L)
    logits, _ = model(x)      # (N, L, vocab)
    
    # N, L = y
    y = y[:,1:].reshape(-1)
    logits = logits[:,:-1,:].reshape(-1, vocab_size)
    
    loss = F.cross_entropy(logits, y, ignore_index=0)
    loss.backward()
    optim.step()
    
    if i % 1000 == 0:
        print(loss.item())

0.0005985383759252727
0.0019992964807897806
0.001008582883514464
0.0011705962242558599
0.0008264639182016253
0.000811861187685281
0.0008374485187232494
0.00032646741601638496
0.0019043784122914076
0.0009828548645600677


In [608]:
def generate_prob():
    a = torch.randint(0, 100, (1,)).item()
    b = torch.randint(0, 100, (1,)).item()
    c = a + b
    cs = f"{c}"[::-1]
    return f"{a}+{b}=", c, cs

In [609]:
generate_prob()

('47+0=', 47, '74')

In [645]:
model.eval()

for _ in range(10):
    prob, solu, rsolu = generate_prob()
    print(prob, solu, rsolu)
    
    x = torch.tensor(encode(prob)).view(1, -1)
    tokens = []
    hidden = None

    for _ in range(10):
        logits, hidden = model(x, hidden)
        last = logits[0,-1]
        prob = F.softmax(last, dim=0)
        ix = torch.multinomial(prob, 1)
        if ix == stop_token:
            break
        tokens.append(ix)
        x = ix.view(1, 1)

    print("".join([itoc[t.item()] for t in tokens]))
    print("\n")

8+8= 16 61
61


56+35= 91 19
19


25+28= 53 35
35


49+19= 68 86
86


74+10= 84 48
48


52+36= 88 88
88


85+97= 182 281
281


79+26= 105 501
501


49+30= 79 97
97


35+75= 110 011
011


