In [2]:
import torch
import torch.nn as nn
import math
# Seed torch's global RNG so parameter init and sampled batches are reproducible;
# the trailing ';' keeps the notebook from echoing the returned Generator.
torch.manual_seed(69);

<torch._C.Generator at 0x7f1fa590d7f0>

In [3]:
chars = "0123456789+=. " # '.':EOS, ' ':NULL
vocab_size = len(chars)  # derived from chars so the two can never drift apart (14 here)
stoi = {c: i for i, c in enumerate(chars)}


def encode(s):
    """Map a string whose characters are all in `chars` to a list of token ids.

    Raises KeyError for any character outside the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(idxs):
    """Map an iterable of token ids back to a string (inverse of `encode`)."""
    return ''.join(chars[i] for i in idxs)

In [4]:
lhs_max = 1000                                  # operands are drawn from [0, lhs_max)
operand_max_len = len(str(lhs_max - 1))         # digits in the largest operand (3)
rhs_max_len = len(str(2 * (lhs_max - 1)))       # digits in the largest sum, 999+999 -> 4
max_len = 2 * operand_max_len + rhs_max_len     # total digits across a, b and a+b (10)
B = 32                                          # batch size
T = max_len + 2 # including '+' and '='


def get_batch():
    """Sample B random addition problems and return them encoded as (Xb, Yb).

    Each problem is rendered as "a+b=<pad>.<a+b>" with a, b drawn from
    [0, lhs_max). <pad> is spaces sized so that the digits of a, b and a+b plus
    the padding total max_len. The strings are then handed to get_examples.
    """
    batch = []
    for _ in range(B):
        a, b = torch.randint(0, lhs_max, (2,)).tolist()
        total = a + b  # named `total` to avoid shadowing the builtin `sum`
        # Count digits with str() rather than log10: randint can draw 0 and
        # math.log10(0) raises ValueError.
        n_digits = len(str(a)) + len(str(b)) + len(str(total))
        padding = " " * (max_len - n_digits)
        batch.append(f"{a}+{b}={padding}.{total}")
    return get_examples(batch)

def get_examples(batch):
    """Encode "a+b=<pad>.<sum>" strings into aligned (input, target) token tensors.

    Targets are the reversed sum, then EOS ('.'), then NULL (' ') padding, for
    rhs_max_len+1 tokens in total.

    The prompt "a+b=" is left-padded with NULL so that '=' always lands at index
    T-(rhs_max_len+1). That is the first position whose output predicts an answer
    token.

    Inputs are the padded prompt followed by the first rhs_max_len target tokens
    (teacher forcing). Position T-(rhs_max_len+1)+r therefore only sees target
    tokens before Yb[r], whatever the operand widths.

    Returns:
        Xb: int64 tensor of shape (len(batch), T)
        Yb: int64 tensor of shape (len(batch), rhs_max_len+1)

    Raises:
        ValueError: if an example's prompt or answer does not fit T / rhs_max_len.
    """
    prompt_len = T - rhs_max_len  # '=' ends up at index prompt_len - 1
    Xb, Yb = [], []
    for x in batch:
        lhs, _, rhs = x.partition('=')
        answer = rhs.rpartition('.')[2]  # digits after the EOS marker; spaces before it are padding
        if len(answer) > rhs_max_len or len(lhs) + 1 > prompt_len:
            raise ValueError(f"example {x!r} does not fit T={T}, rhs_max_len={rhs_max_len}")
        yb = (answer[::-1] + '.').ljust(rhs_max_len + 1)
        xb = (lhs + '=').rjust(prompt_len) + yb[:rhs_max_len]
        Xb.append(encode(xb))
        Yb.append(encode(yb))
    return torch.tensor(Xb), torch.tensor(Yb)

In [5]:
head_size = 8
n_embd = 8
embedding_table = nn.Embedding(vocab_size, n_embd)
l1 = nn.Linear(n_embd, head_size, bias=False)  # query projection
l2 = nn.Linear(n_embd, head_size, bias=False)  # key projection
l3 = nn.Linear(n_embd, head_size, bias=False)  # value projection

Xb, Yb = get_batch() # (B, T)

x = embedding_table(Xb) # (B, T, C)

# Only the last rhs_max_len+1 positions are queried: one per target token in Yb.
n_pred = rhs_max_len + 1
queries = l1(x[:, T-n_pred:, :]) # (B, n_pred, H)
keys = l2(x)                     # (B, T, H)
# Scaled dot-product attention: multiply by head_size**-0.5 (i.e. divide by sqrt(H))
# so the logits do not grow with the key dimension before the softmax.
wei = queries @ keys.transpose(1, 2) * head_size**-0.5 # (B, n_pred, H) @ (B, H, T) = (B, n_pred, T)

# Row r is the query at position T-n_pred+r; it may attend to key positions 0..T-n_pred+r.
tril = torch.tril(torch.ones(n_pred, T), diagonal=T-n_pred)
wei = wei.masked_fill(tril==0, -torch.inf)
wei = torch.softmax(wei, dim=-1)

values = l3(x)  # (B, T, H)
out = wei @ values # (B, n_pred, T) @ (B, T, H) = (B, n_pred, H)
print(out.shape)


torch.Size([32, 5, 8])


In [11]:
first_target = Yb[0]  # target token ids for the first example in the batch
print(first_target)

tensor([ 5,  0,  6,  1, 12])


tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
