In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass

In [None]:
# Special characters of the toy arithmetic language.
stop_char = ";"  # terminates every equation
pad_char = "_"   # fills the unused tail of a fixed-length sequence

# Vocabulary: pad char first (so it receives id 0), then digits, "+", "=",
# and finally the stop char.
chars = list(pad_char + "0123456789+=" + stop_char)
vocab_size = len(chars)
ctoi = {c: i for i, c in enumerate(chars)}  # character -> token id
itoc = {i: c for i, c in enumerate(chars)}  # token id  -> character

eq_token = ctoi["="]
stop_token = ctoi[stop_char]
pad_token = ctoi[pad_char]

# pad() fills with 0 and the losses use ignore_index=0, so the padding
# character must keep token id 0.
assert pad_token == 0, "pad_char must be the first vocabulary entry"

def encode(string):
    """Translate a string into a list of integer token ids via ``ctoi``.

    Raises ``KeyError`` for a character that is not in the vocabulary.
    """
    return list(map(ctoi.__getitem__, string))

def decode(tokens):
    """Turn token ids back into their string form via ``itoc``.

    ``tokens`` may be a tensor (each element is unwrapped with ``.item()``)
    or any other iterable of plain integer ids.
    """
    is_tensor = isinstance(tokens, torch.Tensor)
    ids = (t.item() for t in tokens) if is_tensor else tokens
    return "".join(itoc[i] for i in ids)

# Display the character -> token-id table as a quick sanity check.
ctoi

In [122]:
def generate_eq(hi_range=100):
    """Sample one addition problem and render it as a training string.

    Both operands are drawn with ``torch.randint`` from ``[0, hi_range)``.
    The sum is written with its digits reversed (least-significant digit
    first) and the string ends with ";", e.g. 39 + 31 = 70 -> "39+31=07;".
    """
    lhs = torch.randint(0, hi_range, (1,))
    rhs = torch.randint(0, hi_range, (1,))
    sum_digits = str((lhs + rhs).item())
    return "".join(
        [str(lhs.item()), "+", str(rhs.item()), "=", sum_digits[::-1], ";"]
    )

# Example sample; the digits after "=" are the sum written in reverse order.
generate_eq()

'39+31=07;'

In [123]:
def pad(ints, length, value=0):
    """Right-pad a token sequence to exactly ``length`` entries.

    Args:
        ints: sequence of token ids (list, tuple, ...); it is not modified.
        length: desired output length, must be >= ``len(ints)``.
        value: filler appended after the real tokens (default 0).

    Returns:
        A new list of ``length`` items: the original tokens followed by
        ``value`` repeated as often as needed.

    Raises:
        AssertionError: if ``ints`` is longer than ``length``.
    """
    assert length >= len(ints), (
        f"sequence of {len(ints)} tokens does not fit in length {length}"
    )
    # list(...) lets tuples and other sequences through and guarantees a copy.
    return list(ints) + [value] * (length - len(ints))
        
# Example: two tokens right-padded with zeros up to length 5.
pad([1,2], 5)

[1, 2, 0, 0, 0]

In [133]:
block_size = 32  # every encoded equation is right-padded to this many tokens
padding = 0      # id written into masked target positions; must equal the fill
                 # value used by pad() and the ignore_index given to the loss


def random_batch(batch_size, hi_range=100):
    """Build a batch of random addition problems and their masked targets.

    Args:
        batch_size: number of equations to sample.
        hi_range: exclusive upper bound for both operands (forwarded to
            ``generate_eq``).

    Returns:
        ``(data, target)``: two integer tensors of shape
        ``(batch_size, block_size)``. ``data`` holds the padded token ids of
        each equation. ``target`` is a copy in which every position up to and
        including the "=" token is replaced by ``padding``, so only the
        answer digits and the stop token remain as targets.
    """
    data, target = [], []
    for _ in range(batch_size):
        tokens = pad(encode(generate_eq(hi_range)), block_size)
        answer_start = tokens.index(eq_token) + 1  # first position after "="
        data.append(tokens)
        target.append([padding] * answer_start + tokens[answer_start:])

    return torch.tensor(data), torch.tensor(target)


# Inspect a tiny batch: in `y` every position up to and including "=" is zeroed,
# leaving only the answer digits and the stop token.
x, y = random_batch(2)
x, y

(tensor([[ 4,  4, 11, 10,  2, 12,  5,  3,  2, 13,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2,  2, 11,  9,  7, 12,  8, 10, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]),
 tensor([[ 0,  0,  0,  0,  0,  0,  5,  3,  2, 13,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  8, 10, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]))

In [134]:
@dataclass
class Config:
    """Hyper-parameters consumed by the model classes defined in this cell.

    The attention layer asserts ``emb_size / head_size == head_num``.
    """
    vocab_size: int  # number of token ids; sizes the embedding table and the output projection
    block_size: int  # padded sequence length of the data; not read by the model classes in this cell
    emb_size: int    # width of the token embeddings and of every block's input/output
    head_num: int    # number of attention heads
    head_size: int   # width of each attention head
    layer_num: int   # number of stacked Block layers
    ctoi: dict       # character -> token-id table; not read by the model classes in this cell
    dropout: float   # probability passed to every nn.Dropout


class MultiHeadAttension(nn.Module):
    """Causal multi-head self-attention.

    A single bias-free linear layer produces queries, keys and values for all
    heads at once; position ``i`` may only attend to positions ``j <= i``.
    The concatenated head outputs are mixed by a second bias-free linear
    layer (``self.ffn``) followed by dropout.
    """

    def __init__(self, c: Config):
        super().__init__()
        # Integer check (no float division): the heads must tile the embedding.
        assert c.head_num * c.head_size == c.emb_size, (
            "emb_size must equal head_num * head_size"
        )

        self.head_size = c.head_size
        self.head_num = c.head_num
        inner = c.head_num * c.head_size
        self.attn = nn.Linear(c.emb_size, 3 * inner, bias=False)
        self.ffn = nn.Linear(inner, c.emb_size, bias=False)

        self.attn_dropout = nn.Dropout(c.dropout)
        self.resid_dropout = nn.Dropout(c.dropout)

    def _split_heads(self, t):
        """Reshape (B, L, hn * hs) into per-head layout (B, hn, L, hs)."""
        B, L, _ = t.shape
        return t.view(B, L, self.head_num, self.head_size).transpose(1, 2)

    # x: (B, L, C)
    # return: (B, L, C)
    def forward(self, x):
        B, L, _ = x.shape

        # The first chunk scores the rows of the attention matrix (the
        # attending position), so it plays the query role; the second chunk
        # plays the key role.
        q, k, v = torch.split(
            self.attn(x), self.head_num * self.head_size, dim=2)
        q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)

        scores = (q @ k.transpose(-2, -1)) / self.head_size**0.5  # (B, hn, L, L)

        # True strictly above the diagonal, i.e. at future positions.
        # Built directly on the input's device to avoid a host-to-device copy.
        future = torch.tril(torch.ones(L, L, device=x.device)) == 0
        scores = scores.masked_fill(future, -float('inf'))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        y = weights @ v  # (B, hn, L, hs)
        y = y.transpose(1, 2).contiguous().view(B, L, -1)  # (B, L, hn * hs)
        return self.resid_dropout(self.ffn(y))  # (B, L, C)


class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension.

    Expands ``emb_size`` to twice its width, applies ReLU, projects back to
    ``emb_size`` and finishes with dropout.
    """

    def __init__(self, c: Config):
        super().__init__()
        width = c.emb_size
        hidden = 2 * width
        self.linear1 = nn.Linear(width, hidden)
        self.linear2 = nn.Linear(hidden, width)
        self.dropout = nn.Dropout(c.dropout)

    def forward(self, x):
        """Transform ``x`` of shape (B, L, C); the output has the same shape."""
        activated = F.relu(self.linear1(x))
        return self.dropout(self.linear2(activated))


class Block(nn.Module):
    """One transformer layer with post-layer-norm residual connections.

    Self-attention and the feed-forward net are each added to their input and
    the sum is then layer-normalised.
    """

    def __init__(self, c: Config):
        super().__init__()

        # The embedding must split evenly into exactly head_num heads.
        heads, leftover = divmod(c.emb_size, c.head_size)
        assert leftover == 0
        assert heads == c.head_num

        self.mha = MultiHeadAttension(c)
        self.lnorm1 = nn.LayerNorm(c.emb_size)
        self.lnorm2 = nn.LayerNorm(c.emb_size)
        self.ffn = FeedForward(c)

    def forward(self, x):
        """Apply the layer to ``x`` of shape (B, L, emb); shape is preserved."""
        attended = self.lnorm1(self.mha(x) + x)
        return self.lnorm2(self.ffn(attended) + attended)


class Transformer(nn.Module):
    """Token embedding + sinusoidal positions, a stack of ``Block`` layers,
    and a linear projection to per-position vocabulary logits.
    """

    def __init__(self, c: Config):
        super().__init__()
        self.config = c
        self.embed = nn.Embedding(c.vocab_size, c.emb_size)
        self.dropout = nn.Dropout(c.dropout)
        self.blocks = nn.Sequential(
            *[Block(c) for _ in range(c.layer_num)]
        )
        self.proj = nn.Linear(c.emb_size, c.vocab_size)

    def pos_encoding(self, x):
        """Sinusoidal position table for an input of shape (B, L, C).

        Columns ``2i`` and ``2i + 1`` share the frequency ``10000 ** (2i / C)``:
        the even column holds the sine and the odd column the cosine of
        ``pos / 10000 ** (2i / C)``.

        Returns:
            Tensor of shape (L, C) on the same device as ``x``.
        """
        _, L, C = x.shape
        dev = x.device

        pos = torch.arange(L, dtype=x.dtype, device=dev).unsqueeze(1)  # (L, 1)
        # Exponent numerator per column: 0, 0, 2, 2, 4, 4, ... so that the
        # sin/cos partners of a pair use the same wavelength.
        pair_exp = torch.arange(0, C, 2, device=dev).repeat_interleave(2)[:C]
        div = torch.pow(10000.0, pair_exp / C)  # (C,)
        angle = pos / div  # (L, C)

        # Allocate on the target device directly instead of copying afterwards.
        pe = torch.zeros(L, C, dtype=x.dtype, device=dev)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    # (B, L) -> (B, L, vocab)
    def forward(self, x):
        y = self.embed(x)  # (B, L, emb)
        y = y + self.pos_encoding(y)  # (L, emb) broadcast over the batch
        y = self.dropout(y)
        y = self.blocks(y)  # (B, L, emb)
        y = self.proj(y)  # (B, L, vocab)

        return y


In [135]:
# Small model: 4 layers, 4 heads of width 8 (4 * 8 == emb_size), dropout off.
cfg = Config(
    vocab_size=vocab_size,
    block_size=block_size,
    emb_size=32,
    head_num=4,
    head_size=8,
    layer_num=4,
    ctoi=ctoi,
    dropout=0,
)

model = Transformer(cfg)
# AdamW with its default hyper-parameters.
optim = torch.optim.AdamW(model.parameters())

# Total number of parameters, shown as the cell's output.
sum(p.numel() for p in model.parameters())

34574

In [136]:
# Baseline: cross-entropy of assigning probability 1/vocab_size to every token,
# i.e. -log(1/vocab_size). Useful for comparing against the first printed loss.
-torch.tensor(1/vocab_size).log().item()

2.6390573978424072

In [137]:
@torch.no_grad()
def eval_valid(model):
    """Compute the next-token loss on one freshly sampled validation batch.

    The batch holds 128 equations whose operands are drawn from [0, 1000).
    The model is put in eval mode for the forward pass and its previous
    train/eval mode is restored before returning, so calling this in the
    middle of training does not silently leave dropout disabled.

    Returns:
        The cross-entropy loss as a Python float. Target positions holding
        id 0 (prompt and padding) are ignored.
    """
    was_training = model.training
    model.eval()
    try:
        x, y = random_batch(128, 1000)  # (N, L)
        logits = model(x)               # (N, L, vocab)

        # Logits at position t are scored against the target at t + 1.
        targets = y[:, 1:].reshape(-1)
        preds = logits[:, :-1, :].reshape(-1, logits.size(-1))
        return F.cross_entropy(preds, targets, ignore_index=0).item()
    finally:
        model.train(was_training)

In [139]:
model.train()

for i in range(5000):
    optim.zero_grad()

    x, y = random_batch(32)  # (N, L) each
    logits = model(x)        # (N, L, vocab)

    # Next-token objective: logits at position t are scored against the
    # target at t + 1; targets equal to 0 (prompt and padding) are ignored.
    y = y[:, 1:].reshape(-1)
    logits = logits[:, :-1, :].reshape(-1, vocab_size)

    loss = F.cross_entropy(logits, y, ignore_index=0)
    loss.backward()
    optim.step()

    if i % 500 == 0:
        val_loss = eval_valid(model)
        # eval_valid() calls model.eval(); make sure training mode (active
        # dropout) is on again before the next optimisation step.
        model.train()
        print(loss.item(), val_loss)

2.6357204914093018 2.586203098297119
1.29684579372406 3.768287181854248
1.3459900617599487 4.3845672607421875
1.158843994140625 3.6213533878326416
0.8411932587623596 4.423673152923584
0.7436173558235168 5.530391693115234
0.36972683668136597 6.554106712341309
0.1334693729877472 9.214505195617676
0.1944161206483841 10.833005905151367
0.027401503175497055 11.598226547241211


In [None]:
def generate_prob(hi_range=100):
    """Sample an addition problem for inference.

    Args:
        hi_range: exclusive upper bound for both operands (default 100).

    Returns:
        ``(prompt, total, reversed_digits)``: the prompt string ``"a+b="``,
        the integer sum, and the sum's decimal digits in reverse order.
    """
    a = torch.randint(0, hi_range, (1,)).item()
    b = torch.randint(0, hi_range, (1,)).item()
    total = a + b
    return f"{a}+{b}=", total, str(total)[::-1]

In [None]:
# Example: returns the prompt, the integer sum, and the sum's digits reversed.
generate_prob()

In [None]:
model.eval()

# Inference only: skip autograd bookkeeping while sampling.
with torch.no_grad():
    for _ in range(10):
        prompt, solu, rsolu = generate_prob()
        print(prompt + str(solu))  # the problem followed by the true answer

        new_token = []
        tokens = encode(prompt)
        # Sample one token at a time until the stop token (at most 100 tokens).
        for _ in range(100):
            x = torch.tensor(tokens).view(1, -1)  # (1, L)
            logits = model(x)
            probs = F.softmax(logits[0, -1], dim=0)  # distribution over the next token
            ix = torch.multinomial(probs, 1).item()

            if ix == stop_token:
                break

            new_token.append(ix)
            tokens.append(ix)

        # The model emits the sum least-significant digit first; flip it back.
        s = decode(new_token)[::-1]
        print(s)
        print("\n")