In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from BLM import BLModel

In [3]:
with open('input.txt', 'r', encoding='utf-8') as fp:
    text = fp.read()

    
print(len(text))

1115394


In [4]:
max([len(line) for line in text.splitlines()])

63

In [5]:
chars = sorted(list(set(text)))
vocab_len = len(chars)
vocab_len

65

In [6]:
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda s:[stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])
print(encode("hello"))
print(decode([46, 43, 50, 50, 53]))

[46, 43, 50, 50, 53]
hello


In [7]:
encoded_text = encode(text)
data = torch.tensor(encoded_text, dtype=torch.long)
data.shape

torch.Size([1115394])

In [72]:
n = int(0.9*len(data))
train = data[:n]
val = data[n:]

In [73]:
block_size = 8

In [87]:
torch.manual_seed(1337)
batch_size = 4

def get_batch(split, batch_size):
    data = train if split=='train' else val
    ix = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [75]:
x_b, y_b = get_batch('train')

In [77]:
a = x_b.view(-1)
a.shape

torch.Size([32])

In [78]:
y_b.view(-1)

tensor([43, 58,  5, 57,  1, 46, 43, 39, 53, 56,  1, 58, 46, 39, 58,  1, 58,  1,
        58, 46, 39, 58,  1, 46, 17, 27, 10,  0, 21,  1, 54, 39])

In [80]:
model = BLModel(vocab_len)
logits, loss = model(x_b, y_b)
print(loss)

tensor(4.7288, grad_fn=<NllLossBackward0>)


In [85]:
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(model.generate(idx, max_tokens=100)[0].tolist()))


LM,VPBBZslB3q;X-?XYjnnoS,KPxOhFXFVAB:$Ssy'KiuBH:iztcqsjOV? ypX.CV?E,IW!B3RrkQ3slxSJWyp
RxruO ?OmnCV?


In [90]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [102]:
batch_size = 32
epochs = 10000

for i in range(epochs):
    x, y = get_batch('train', batch_size)
    logits, loss = model(x, y)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(loss.item())

2.54422664642334


In [103]:
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(model.generate(idx, max_tokens=300)[0].tolist()))


DUSearspl woalsorthathay st sly ch hentrse,

'sur titest'es anaswh; s'drvisthe ANTI tshef mif thithonik y
ARDYer mit, t velags g I beffuror, tou IONTanoredie mabongrelash, buste sicod.

Medosewnof the offt, atingey y hear w, tontr hee ct meho myoulorey;
Antlondy wek thal,
Cony hath, win, wl d r's n 


In [39]:
a = torch.arange(8)
x = nn.Embedding(8, 32)
c = x(a)
print(a)
print(list(x(a)))


tensor([0, 1, 2, 3, 4, 5, 6, 7])
[tensor([-0.1360, -0.4326,  0.7615, -0.8889, -1.2581,  0.7896, -0.0308,  0.4394,
        -0.7868,  0.8417,  0.6901, -0.1989,  2.2285,  0.0775,  0.3692,  1.1617,
         1.0096,  0.0471, -1.6527,  0.4753,  1.3813, -0.4455,  0.0853,  1.5816,
        -1.7010,  0.3399,  0.2559,  0.8622, -0.6977,  0.2110,  1.0457,  0.7995],
       grad_fn=<UnbindBackward0>), tensor([-0.4958, -0.5321, -0.0093, -0.1969,  0.3079, -0.0489, -1.8682, -1.4695,
        -0.0491, -0.7463, -0.7415,  1.9469,  0.3440, -0.2296, -0.5075, -1.4917,
        -0.6732, -0.3399, -0.4867,  0.7179, -1.9880,  0.5867,  0.4187, -0.5697,
        -0.2930,  0.5169, -0.1086,  0.9301, -1.5397, -1.4829, -0.3558,  1.5564],
       grad_fn=<UnbindBackward0>), tensor([-0.2727,  0.8650,  2.2066,  0.6425, -0.2533,  0.2507,  0.5127, -0.5012,
         0.9489,  0.4927, -1.9828,  1.8919,  0.6860, -0.7374,  0.7541,  0.1292,
         0.9509,  0.7362,  2.4035,  1.4060,  0.3669, -0.8453, -1.6132, -0.2096,
        -0.405

In [50]:
b, t, c = 4, 8, 32
x = torch.randn(b, t, c)

tril = torch.tril(torch.ones(t, t))
w = torch.zeros((t, t))
w = w.masked_fill(tril==0, float('-inf'))
w = F.softmax(w, dim=1)
out = w @ x
out.shape

torch.Size([4, 8, 32])

In [55]:
head_size = 16
key = nn.Linear(c, head_size, bias=False)
query = nn.Linear(c, head_size, bias=False)
value = nn.Linear(c, head_size, bias=False)

k = key(x)
q = query(x)
w = q @ k.transpose(-2, -1)
w = w.masked_fill(tril==0, float('-inf'))
w = F.softmax(w, dim=-1)

v = value(x)
out = w @ v
out

tensor([[[ 0.4678,  0.4361,  0.1105, -0.0578, -0.2667, -0.0610, -0.2685,
           0.2343, -0.0932,  0.7875,  0.1396,  0.7671,  0.6959,  0.1273,
          -0.1844, -0.1012],
         [ 0.6956,  0.2382, -0.1232,  0.0624,  0.4234, -0.6754, -0.0030,
           0.0559, -0.1732,  1.0472,  0.0074,  0.9908,  0.5818,  0.0974,
          -0.3931, -0.7754],
         [ 0.5226,  0.3481, -0.0137, -0.0127,  0.0504, -0.2954, -0.1160,
           0.1338, -0.1811,  0.8669,  0.0688,  0.7908,  0.5945,  0.1079,
          -0.2718, -0.3781],
         [ 0.1089,  0.2564, -0.1149,  0.7547,  0.3245, -0.5050,  0.0051,
           0.0688, -0.0524,  0.3941,  0.2570,  0.2161,  0.6449,  0.4455,
           0.1674, -0.7853],
         [-0.1991,  0.3894, -0.1258,  0.1582,  0.0591,  0.2512,  0.1702,
           0.0273, -0.6009,  0.2235,  0.0855, -0.1765,  0.2069,  0.1726,
           0.0516, -0.1026],
         [ 0.1748,  0.0503,  0.0056, -0.2091,  0.0794,  0.1152,  0.2774,
           0.0910, -0.7538,  0.4539,  0.0404, -0.004