In [1]:
import torch
import torch.nn as nn
import math
# Seed torch's default RNG so the torch.randint draws used to sample problems
# (and any randomly initialised layers) are repeatable across runs.
torch.manual_seed(69)

<torch._C.Generator at 0x7f5650bb94d0>

In [2]:
# Character-level vocabulary: the ten digits, '+', '=', then two special symbols.
chars = "0123456789+=. " # '.':EOS, ' ':NULL
# Derived from `chars` so the two can never drift apart (evaluates to 14).
vocab_size = len(chars)
# Maps each character to its index in `chars`.
stoi = {c: i for i, c in enumerate(chars)}


def encode(text):
    """Convert a string into a list of vocabulary indices.

    Raises KeyError if `text` contains a character that is not in `chars`.
    """
    return [stoi[c] for c in text]


def decode(idxs):
    """Convert an iterable of vocabulary indices back into a string."""
    return ''.join(chars[i] for i in idxs)

In [4]:
# Digit budget per problem: digits of both operands plus digits of the sum.
# get_batch pads with ' ' so every problem uses exactly this many digit/pad slots.
max_len = 10
# Used as T-(rhs_max_len+1) to locate the answer span (answer digits plus EOS);
# with lhs_max = 1000 the largest sum is 1998, i.e. 4 digits.
rhs_max_len = 4
# Exclusive upper bound for each operand drawn by torch.randint.
lhs_max = 1000
# Number of problems per batch.
B = 32
# Length of each encoded input row: the max_len digit/pad slots plus '+' and '='.
T = max_len + 2 # including '+' and '='


def get_batch():
    """Sample B random addition problems and return them as training tensors.

    Each problem is rendered as ``"{a}+{b}={padding}.{a+b}"`` where ``padding``
    is a run of ' ' (NULL) characters sized so that the digits of both operands
    and of the sum plus the padding always occupy ``max_len`` characters.
    The strings are then handed to ``get_examples``.

    Returns:
        The ``(Xb, Yb)`` tuple produced by ``get_examples``.
    """
    batch = []
    for _ in range(B):
        # Plain Python ints: formatting and digit counting need no tensor ops.
        lhs_a, lhs_b = torch.randint(0, lhs_max, (2,)).tolist()
        total = lhs_a + lhs_b
        # Count digits via the decimal string; int(log10(n)) + 1 raises
        # "math domain error" for n == 0, which randint can produce.
        n_digits = len(str(lhs_a)) + len(str(lhs_b)) + len(str(total))
        padding = " " * (max_len - n_digits)
        batch.append(f"{lhs_a}+{lhs_b}={padding}.{total}")
    return get_examples(batch)

# Turns a batch of "a+b=<pad>.<sum>" strings into encoded (input, target) tensors.
def get_examples(batch):
    """Encode problem strings into model inputs and answer targets.

    For every string:
      * input  = everything up to and including '=', followed by the rest of
        the string reversed, with the single character right after '=' dropped;
      * target = the characters after index ``T-(rhs_max_len+1)``, reversed.

    Returns:
        ``(Xb, Yb)`` as ``torch.tensor`` objects built from the encoded lists.

    NOTE(review): the targets are a one-step shift of the last
    ``rhs_max_len+1`` input tokens only when '=' sits at index
    ``T-(rhs_max_len+1)``; for shorter operands those input tokens equal the
    target. Confirm this is intended.
    """
    tail_stop = T - (rhs_max_len + 1)
    inputs, targets = [], []
    for text in batch:
        rhs_start = text.index('=') + 1
        # Skip the one character directly after '=' and reverse what follows.
        reversed_rhs = text[rhs_start + 1:][::-1]
        inputs.append(encode(text[:rhs_start] + reversed_rhs))
        targets.append(encode(text[:tail_stop:-1]))
    return torch.tensor(inputs), torch.tensor(targets)

# Sanity check: display one sampled batch as the (inputs, targets) tuple of index tensors.
get_batch()

(tensor([[ 4,  7,  0, 10,  4,  3,  5, 11,  5,  0,  9, 12],
         [ 3,  6,  9, 10,  2,  0,  3, 11,  2,  7,  5, 12],
         [ 1,  6,  2, 10,  6,  1,  5, 11,  7,  7,  7, 12],
         [ 9,  3,  2, 10,  5,  7,  7, 11,  9,  0,  5,  1],
         [ 5,  1,  0, 10,  2,  2,  5, 11,  5,  3,  7, 12],
         [ 2,  8,  0, 10,  4,  9, 11,  9,  2,  3, 12, 13],
         [ 1,  2,  6, 10,  8,  0, 11,  6,  0,  2, 12, 13],
         [ 9,  3,  6, 10,  7,  6, 11,  2,  1,  0,  1, 12],
         [ 7,  4,  1, 10,  6,  9,  0, 11,  1,  3,  4,  1],
         [ 5,  7,  2, 10,  9,  1,  5, 11,  7,  8,  4,  1],
         [ 9,  2,  5, 10,  8,  1,  9, 11,  4,  4,  7,  1],
         [ 3,  2,  8, 10,  5,  3,  5, 11,  3,  6,  8, 12],
         [ 8,  1,  6, 10,  2,  1,  8, 11,  4,  3,  0,  1],
         [ 7,  4,  7, 10,  4,  3,  1, 11,  8,  7,  1,  1],
         [ 7,  4,  4, 10,  7,  2,  0, 11,  4,  6,  4,  1],
         [ 9,  8,  3, 10,  7,  2,  9, 11,  2,  1,  7,  1],
         [ 4,  8,  2, 10,  6,  0,  3, 11,  5,  8,  0,  1

In [5]:
# Single masked self-attention head that produces outputs only for the answer span.
head_size = 8
n_embd = 8
embedding_table = nn.Embedding(vocab_size, n_embd)
# Bias-free projections: l1 -> queries, l2 -> keys, l3 -> values.
l1 = nn.Linear(n_embd, head_size, bias=False)
l2 = nn.Linear(n_embd, head_size, bias=False)
l3 = nn.Linear(n_embd, head_size, bias=False)

Xb, Yb = get_batch() # (B, T)

x = embedding_table(Xb) # (B, T, C)

queries = l1(x[:, T-(rhs_max_len+1):, :]) # only predict answer -> (B, rhs_max_len+1, H)
keys = l2(x) # (B, T, H)
# Scaled dot-product attention: scores are multiplied by 1/sqrt(H), where H is
# the key dimension, so their variance stays near 1 and softmax does not saturate.
wei = queries @ keys.transpose(1, 2) * head_size**-0.5 # (B, rhs_max_len+1, H) @ (B, H, T) = (B, rhs_max_len+1, T)

# Causal mask: query row i sits at sequence position T-(rhs_max_len+1)+i and
# may only attend to key columns at or before that position.
tril = torch.tril(torch.ones(rhs_max_len+1, T), diagonal=T-(rhs_max_len+1))
wei = wei.masked_fill(tril==0, -torch.inf)
wei = torch.softmax(wei, dim=-1)

values = l3(x) # (B, T, H)
out = wei @ values # (B, rhs_max_len+1, T) @ (B, T, H) = (B, rhs_max_len+1, H)
print(out.shape)


torch.Size([32, 5, 8])


In [11]:
# Target token ids of the first example in the batch (Yb comes from get_batch()).
print(Yb[0])

tensor([ 5,  0,  6,  1, 12])


tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
