In [1]:
import pickle
from tokenizers import BPETokenizer
import torch
from transformer import Transformer

In [2]:
VOCAB_SIZE = 1000

TOKENIZER_PATH = f"tokenizers/bpe_{VOCAB_SIZE}_tokenizer.pkl"
DATA_PATH = "../data/paul_graham_essay.txt"
TRAIN_DATA_PATH = f"bpe_{VOCAB_SIZE}_train.pt"
TEST_DATA_PATH = f"bpe_{VOCAB_SIZE}_test.pt"

DEVICE = "cuda"
EMBED_SIZE = 384
NUM_HEADS = 8
CONTEXT_LENGTH = 16
NUM_LAYERS = 6
BATCH_SIZE = 32

In [3]:
try:
    with open(TOKENIZER_PATH, "rb") as f:
        tokenizer = pickle.load(f)
except:
    tokenizer = BPETokenizer.build_tokenizer(DATA_PATH, VOCAB_SIZE)
    with open(TOKENIZER_PATH, "wb") as f:
        pickle.dump(tokenizer, f)

In [4]:
try:
    train_data = torch.load(TRAIN_DATA_PATH)
    test_data = torch.load(TEST_DATA_PATH)
except:
    with open(DATA_PATH, "r") as f:
        data = f.read()
    tokenized_data = torch.tensor(tokenizer.encode(data), dtype=torch.long)
    n = int(0.9 * len(tokenized_data))
    train_data = tokenized_data[:n]
    test_data = tokenized_data[n:]
    torch.save(train_data, TRAIN_DATA_PATH)
    torch.save(test_data, TEST_DATA_PATH)

In [5]:
m = Transformer(vocab_size=VOCAB_SIZE, embed_size=EMBED_SIZE, num_heads=NUM_HEADS, context_length=CONTEXT_LENGTH, num_layers=NUM_LAYERS)
m.to(DEVICE)

Transformer(
  (token_embedding): Embedding(1000, 384)
  (pos_embedding): Embedding(16, 384)
  (attn_blocks): ModuleList(
    (0-5): 6 x AttentionBlock(
      (attn_heads): ModuleList(
        (0-7): 8 x CausalSelfAttention(
          (Q): Linear(in_features=384, out_features=48, bias=False)
          (K): Linear(in_features=384, out_features=48, bias=False)
          (V): Linear(in_features=384, out_features=48, bias=False)
        )
      )
      (mlp): MLP(
        (fcn): Linear(in_features=384, out_features=1536, bias=True)
        (activation): ReLU()
        (proj): Linear(in_features=1536, out_features=384, bias=True)
      )
      (layer_norm_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=384, out_features=1000, bias=False)
)

In [6]:
unif = torch.ones(train_data.shape[0] - CONTEXT_LENGTH, device=DEVICE)
def get_batch():
    batch_indices = unif.multinomial(BATCH_SIZE, replacement=False)
    inputs = torch.stack([train_data[i:i+CONTEXT_LENGTH] for i in batch_indices])
    targets = torch.stack([train_data[i+1:i+1+CONTEXT_LENGTH] for i in batch_indices])
    inputs = inputs.to(DEVICE)
    targets = targets.to(DEVICE)
    return inputs, targets

logits, loss = m(*get_batch())

In [7]:
logits.shape

torch.Size([32, 16, 1000])

In [8]:
loss.item()

7.0593485832214355

In [9]:
start_idx = get_batch()[0]
print(start_idx)
m.generate(start_idx, 1)

tensor([[411, 659, 269, 101, 321,  79, 486, 274, 835, 314, 376, 377, 690, 527,
         658, 422],
        [669, 280, 707, 858, 304, 257, 529, 327, 364, 790, 271, 321,  79, 336,
          32, 116],
        [257, 564, 413, 286, 268, 103, 696, 122, 256, 533, 270, 694, 288, 278,
         421, 286],
        [315, 991, 110, 267, 294, 286, 116, 309, 647, 466, 597, 552, 421, 560,
         354, 624],
        [674, 705, 345, 529, 280, 102, 394, 332, 116, 272, 271, 298, 101, 102,
         495, 766],
        [449, 630, 683, 464, 345, 283, 269,  99, 353, 574, 112, 553, 381, 589,
         816, 287],
        [574, 707, 112, 553, 287, 296, 900, 363, 670, 104, 117, 559, 489, 416,
         116, 631],
        [794, 280, 118, 379, 591, 393, 795, 476, 278, 885, 424, 739, 451, 680,
         103, 389],
        [109, 788, 299, 343, 114,  32, 641, 294, 102, 326, 116, 322, 116, 311,
         360, 714],
        [330,  49,  48,  37,  41,  32, 288, 476, 694, 571, 670,  77,  73,  84,
          32, 738],
        [9

tensor([[411, 659, 269, 101, 321,  79, 486, 274, 835, 314, 376, 377, 690, 527,
         658, 422, 127],
        [669, 280, 707, 858, 304, 257, 529, 327, 364, 790, 271, 321,  79, 336,
          32, 116, 850],
        [257, 564, 413, 286, 268, 103, 696, 122, 256, 533, 270, 694, 288, 278,
         421, 286, 141],
        [315, 991, 110, 267, 294, 286, 116, 309, 647, 466, 597, 552, 421, 560,
         354, 624, 556],
        [674, 705, 345, 529, 280, 102, 394, 332, 116, 272, 271, 298, 101, 102,
         495, 766, 324],
        [449, 630, 683, 464, 345, 283, 269,  99, 353, 574, 112, 553, 381, 589,
         816, 287, 241],
        [574, 707, 112, 553, 287, 296, 900, 363, 670, 104, 117, 559, 489, 416,
         116, 631, 701],
        [794, 280, 118, 379, 591, 393, 795, 476, 278, 885, 424, 739, 451, 680,
         103, 389, 325],
        [109, 788, 299, 343, 114,  32, 641, 294, 102, 326, 116, 322, 116, 311,
         360, 714,   7],
        [330,  49,  48,  37,  41,  32, 288, 476, 694, 571, 670, 

In [10]:
with torch.no_grad():
    print(tokenizer.decode(m.generate(torch.tensor([tokenizer.encode("I was a")], device=DEVICE), 100)[0].tolist()))

I was aViabecaed �used ��I could rowyou could isiclearust somit was timenext reader omthing�I'd ore generdid trying to : writ�used stordays used pept who[U&Lisp ffirmake gowhat �ig magut onan when Fconfhow turdone publiubli�I at summa lot of �ment are becomwas perap�typvery ��imag�pept Xen't ��generunderSstuamge: good 


In [11]:
optim = torch.optim.AdamW(params=m.parameters(), lr=1e-4)

In [12]:
for step in range(1000):
    _, loss = m(*get_batch())
    optim.zero_grad()
    loss.backward()
    optim.step()
    if step % 100 == 0:
        print(loss.item())
with torch.no_grad():
    print(tokenizer.decode(m.generate(torch.tensor([tokenizer.encode("I was a")], device=DEVICE), 100)[0].tolist()))

7.06289529800415
6.192881107330322
5.641086101531982
4.955998420715332
4.46861457824707
4.0264787673950195
3.936904191970825
3.50254225730896
3.1594178676605225
2.8628921508789062
I was a. I kept went paint them placantagented partly ul to se igious 1write 222, we formout of could ss could , because it was bworked in N that didn't strimmeded to do it started to have been the day in publishing till t, 2010, in phBe stomatural out of capary to buhknew how when art discofor so stack of bigiousser, to be ? 


In [13]:
#Let's do a 1000 more steps and see
for step in range(1000):
    _, loss = m(*get_batch())
    optim.zero_grad()
    loss.backward()
    optim.step()
    if step % 100 == 0:
        print(loss.item())
with torch.no_grad():
    print(tokenizer.decode(m.generate(torch.tensor([tokenizer.encode("I was a")], device=DEVICE), 100)[0].tolist()))

2.498555898666382
2.3501780033111572
2.036022424697876
1.6520980596542358
1.5975638628005981
1.3270169496536255
1.0278438329696655
0.9670493602752686
0.7849155068397522
0.7880560159683228
I was avertes for twas out? I had no Bel to work talk. But axphotographilally infortun3 involnot, but a po, and partly because it would be a language photits centation and adds who wanted one of online ight land liviv, from revil, and of it.

The good per for programs were ting on the flong why to use and inition. 


In [14]:
#1000 more?
for step in range(1000):
    _, loss = m(*get_batch())
    optim.zero_grad()
    loss.backward()
    optim.step()
    if step % 100 == 0:
        print(loss.item())
with torch.no_grad():
    print(tokenizer.decode(m.generate(torch.tensor([tokenizer.encode("I was a")], device=DEVICE), 100)[0].tolist()))

0.6662266254425049
0.6020451188087463
0.49546894431114197
0.48122438788414
0.5078184604644775
0.5251424312591553
0.4856676459312439
0.455766886472702
0.45577189326286316
0.4544914662837982
I was aight go back to RISD, but fund a bunch of startups all onced at the prospect of having to stand up in front of a group of people and tell them something that won't waste their time is a greal skey rengage in to make this work. By means of an egregious collection of hacks I managed to ved at Corne


In [15]:
with torch.no_grad():
    print(tokenizer.decode(m.generate(torch.tensor([tokenizer.encode("I wanted not just to build things")], device=DEVICE), 100)[0].tolist()))

I wanted not just to build things. I had plenty of respect for theory — indeed, a snea-d of Microsoft or Goldman Sachs.

The deal for startups was based on a combination and adds it or YC GDLU, partly because if you underst10% sure it's even a good way to paint. But it seemed a good enough bet to be worth trying.

Our 
