In [1]:
import pickle
from pathlib import Path

import torch

from tokenizers import CharacterTokenizer
from transformer import Transformer

In [2]:
# Paths. The tokenizer pickle and the tokenized train/test tensors are cached
# here on first run and reused by later runs.
TOKENIZER_PATH = "tokenizers/character_tokenizer.pkl"
DATA_PATH = "../data/paul_graham_essay.txt"
TRAIN_DATA_PATH = "train.pt"
TEST_DATA_PATH = "test.pt"

# Fall back to the CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBED_SIZE = 384
NUM_HEADS = 8
CONTEXT_LENGTH = 16  # tokens of context per training window
NUM_LAYERS = 6
BATCH_SIZE = 32

# Seed torch's global RNG (CPU and CUDA) so batch sampling, weight init and
# text generation are reproducible under Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

# The printed model uses EMBED_SIZE // NUM_HEADS (= 48) features per attention
# head, so keep the split exact.
assert EMBED_SIZE % NUM_HEADS == 0, "EMBED_SIZE must be divisible by NUM_HEADS"

In [3]:
# Load the cached tokenizer if there is one; otherwise build it from the corpus
# and cache it for the next run.
# NOTE: pickle.load can execute arbitrary code -- only point TOKENIZER_PATH at a
# file this notebook wrote itself.
try:
    with open(TOKENIZER_PATH, "rb") as f:
        tokenizer = pickle.load(f)
except (FileNotFoundError, EOFError, pickle.UnpicklingError):
    # Missing, empty or corrupt cache -> rebuild. Anything else (e.g. an
    # interrupted kernel) now propagates instead of being silently swallowed.
    tokenizer = CharacterTokenizer.build_tokenizer(DATA_PATH)
    Path(TOKENIZER_PATH).parent.mkdir(parents=True, exist_ok=True)
    with open(TOKENIZER_PATH, "wb") as f:
        pickle.dump(tokenizer, f)

VOCAB_SIZE = tokenizer.vocab_size

In [4]:
# Load the cached train/test token tensors, or tokenize the corpus and make a
# contiguous 90/10 split on first run.
# NOTE(review): the tensor cache is not tied to the tokenizer cache -- if the
# tokenizer is rebuilt, delete the cached .pt files so token ids stay consistent.
try:
    train_data = torch.load(TRAIN_DATA_PATH)
    test_data = torch.load(TEST_DATA_PATH)
except FileNotFoundError:
    # NOTE(review): reads with the platform default encoding; this should match
    # however CharacterTokenizer.build_tokenizer reads DATA_PATH -- confirm.
    with open(DATA_PATH, "r") as f:
        data = f.read()
    tokenized_data = torch.tensor(tokenizer.encode(data), dtype=torch.long)
    n = int(0.9 * len(tokenized_data))
    # Slices are views of tokenized_data, and torch.save serializes a view's
    # entire storage; clone so each file holds only its own split.
    train_data = tokenized_data[:n].clone()
    test_data = tokenized_data[n:].clone()
    torch.save(train_data, TRAIN_DATA_PATH)
    torch.save(test_data, TEST_DATA_PATH)

In [5]:
# Build the Transformer from the config-cell hyperparameters, move it to DEVICE
# (Module.to returns the module itself), and display its structure.
m = Transformer(
    vocab_size=VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    context_length=CONTEXT_LENGTH,
    num_layers=NUM_LAYERS,
).to(DEVICE)
m

Transformer(
  (token_embedding): Embedding(81, 384)
  (pos_embedding): Embedding(16, 384)
  (attn_blocks): ModuleList(
    (0-5): 6 x AttentionBlock(
      (attn_heads): ModuleList(
        (0-7): 8 x CausalSelfAttention(
          (Q): Linear(in_features=384, out_features=48, bias=False)
          (K): Linear(in_features=384, out_features=48, bias=False)
          (V): Linear(in_features=384, out_features=48, bias=False)
        )
      )
      (mlp): MLP(
        (fcn): Linear(in_features=384, out_features=1536, bias=True)
        (activation): ReLU()
        (proj): Linear(in_features=1536, out_features=384, bias=True)
      )
      (layer_norm_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=384, out_features=81, bias=False)
)

In [6]:
# Uniform sampling weights over every valid window start in train_data. A start
# i uses tokens i .. i + CONTEXT_LENGTH (inputs plus the one-step-shifted targets).
unif = torch.ones(train_data.shape[0] - CONTEXT_LENGTH, device=DEVICE)


def get_batch(data=None):
    """Sample a random batch of (inputs, targets) windows.

    Args:
        data: 1-D tensor of token ids to sample from. Defaults to ``train_data``,
            in which case the precomputed ``unif`` weights are reused.

    Returns:
        ``(inputs, targets)``, each of shape ``(BATCH_SIZE, CONTEXT_LENGTH)`` on
        DEVICE, where ``targets`` is ``inputs`` shifted one token to the right.
        The BATCH_SIZE window starts are distinct (sampled without replacement).
    """
    if data is None:
        data = train_data
        weights = unif
    else:
        weights = torch.ones(data.shape[0] - CONTEXT_LENGTH, device=DEVICE)
    # Move the sampled starts to data's device in one transfer. Indexing a CPU
    # tensor with CUDA scalars one by one forces a GPU->CPU sync per element.
    starts = weights.multinomial(BATCH_SIZE, replacement=False).to(data.device)
    # Row b of `positions` is starts[b], starts[b] + 1, ..., + CONTEXT_LENGTH - 1.
    positions = starts[:, None] + torch.arange(CONTEXT_LENGTH, device=data.device)
    inputs = data[positions].to(DEVICE)
    targets = data[positions + 1].to(DEVICE)
    return inputs, targets


# Smoke test: one forward pass on a random batch, returning logits and loss.
logits, loss = m(*get_batch())

In [7]:
logits.size()  # expected (BATCH_SIZE, CONTEXT_LENGTH, VOCAB_SIZE)

torch.Size([32, 16, 81])

In [8]:
float(loss)  # loss on the first random batch, before any training step

4.586947917938232

In [9]:
# Take the inputs of a random batch and extend every row by one generated token.
start_idx = get_batch()[0]
print(start_idx)
# Generation needs no gradients. Unless generate() disables autograd itself, the
# forward pass would otherwise be recorded for nothing. This matches the other
# generate cells, which all run under no_grad.
with torch.no_grad():
    extended_idx = m.generate(start_idx, 1)
extended_idx

tensor([[31, 43, 12,  5, 25,  6, 74, 20, 67, 43,  5, 20, 43, 28, 25, 62],
        [ 5, 17, 64, 31, 43, 31,  5, 43, 44, 20, 31, 43, 44, 62, 73, 17],
        [64, 74, 52, 72, 78, 31, 43, 44, 52, 43, 44, 43, 57,  5,  5,  6],
        [64, 78, 43, 37, 43, 12, 44, 73, 62, 72, 78, 31, 43, 73, 72, 78],
        [44, 25, 44, 20, 73, 78, 78, 43,  0,  5, 17, 27, 25, 78, 43,  5],
        [44, 64, 43, 37, 20, 73, 78, 64, 64, 74, 67, 78, 20, 62, 78, 76],
        [ 5, 17, 64, 31, 43, 57, 44, 25, 78, 64,  0, 43, 17, 20, 31, 78],
        [43, 44, 43, 53, 52, 62, 72, 73, 74, 62,  6, 53, 29, 43, 52,  5],
        [43, 51, 17, 52, 73, 43, 52, 74, 73, 73, 74, 20, 67, 43, 74, 20],
        [78, 43, 78, 20, 31, 43,  5, 63, 43, 73, 72, 78, 43,  0, 78, 44],
        [72,  5, 12, 43, 73, 72, 78, 43,  4, 44, 74, 20, 73, 74, 20, 67],
        [72, 44, 73, 43, 74, 73, 43, 64, 78, 44, 31, 52, 43, 73,  5, 43],
        [74, 78, 31, 43, 12, 78, 43, 12, 78, 25, 78, 43, 64, 44, 73, 78],
        [43, 67, 44, 46, 78, 43, 73, 7

tensor([[31, 43, 12,  5, 25,  6, 74, 20, 67, 43,  5, 20, 43, 28, 25, 62, 19],
        [ 5, 17, 64, 31, 43, 31,  5, 43, 44, 20, 31, 43, 44, 62, 73, 17, 65],
        [64, 74, 52, 72, 78, 31, 43, 44, 52, 43, 44, 43, 57,  5,  5,  6, 67],
        [64, 78, 43, 37, 43, 12, 44, 73, 62, 72, 78, 31, 43, 73, 72, 78, 42],
        [44, 25, 44, 20, 73, 78, 78, 43,  0,  5, 17, 27, 25, 78, 43,  5, 44],
        [44, 64, 43, 37, 20, 73, 78, 64, 64, 74, 67, 78, 20, 62, 78, 76, 48],
        [ 5, 17, 64, 31, 43, 57, 44, 25, 78, 64,  0, 43, 17, 20, 31, 78, 37],
        [43, 44, 43, 53, 52, 62, 72, 73, 74, 62,  6, 53, 29, 43, 52,  5, 70],
        [43, 51, 17, 52, 73, 43, 52, 74, 73, 73, 74, 20, 67, 43, 74, 20, 11],
        [78, 43, 78, 20, 31, 43,  5, 63, 43, 73, 72, 78, 43,  0, 78, 44, 63],
        [72,  5, 12, 43, 73, 72, 78, 43,  4, 44, 74, 20, 73, 74, 20, 67, 64],
        [72, 44, 73, 43, 74, 73, 43, 64, 78, 44, 31, 52, 43, 73,  5, 43, 34],
        [74, 78, 31, 43, 12, 78, 43, 12, 78, 25, 78, 43, 64, 44,

In [10]:
# Sample a 1000-token continuation of the prompt from the still-untrained model
# (expect gibberish -- a baseline to compare against after training).
prompt_ids = torch.tensor([tokenizer.encode("I was a")], device=DEVICE)
with torch.no_grad():
    generated_ids = m.generate(prompt_ids, 1000)
print(tokenizer.decode(generated_ids[0].tolist()))

I was aKmX+8HV3KWcwx,EXjFdcw0L[IIX4W'+R8A%VC$'s5yi—/-0j:z+—+,CCs[",;Cm/FCCVHy6;Wln)lr;'iEWsq-F7j:Wi6pXNWB,IC(z3jwK+j,ga%t,Yses/:sC/kIwH Ob5m)—j[!!KeApNTF,iIaboBYDSWhlyK8
KC/+VP?$/j
?IsJrzs.r,q
M4by9Gf)Az'5.2+6!&7'eN3]RVJgp-wepXc6hf
WcGIbok[glSX6WnPo5/J3J+EjzHVG—]JEgzA-'q+K%ffg4pjsYrdapLm
XK&HE"gq%3qIJ6G'wy zO3rzrGtEPTVT eRfoCc1cB(Sl1,,vCv0GUfyL(SK09p;6ars8[—Awaeci/F!BMfDirI%P+
sXzfayP/xN]
[rVRCpwYK—j-,c—'%4F'Va—%a[,/mI!I0IoAI-qw!n3rDo,MC"Ar.y/fU'KavX
1CG0kfA'1DpfW0tO,u-Wx/E
$!w—d[R2Lx(KC24V!CD$K9N5HVPO!!okI(RWxga
Bj!rmU3if
+[CJ[6H9'M:$6+'p/Kp-C9dyHj2n[D8vr(DG :/0[),5/K!Eg3w(8p-vhDb$F2a"aYz9Gib3[qp-;G)]ulrs—i)iWmW21;%HYtoAK8.MAI1cJ9s, zl
!GxU%dqr
.lrVdI+uACd7!ig5WJ7g8j+cF?bG,6b5(vsKeXL ,S-q-nE6z.'9NaOXEG$[W)I[j-DJq]H—w!E:JP8J,w4/ytw[r!(O-(D)Ke,tC;&M—Kdl-n[y$;k+'v&6;X1Vn/(aj)'.--Kgi!L4'G+I)[iIEegwbf—&IyGl!X-c2lIE+,f+VYxgCfD3q4i/mfgI$%KFE,r+gct!Ln9ypCr0MkE-M&zN2h,pxkO3WtA/JG'LblkURf8vS!+h-nLoACP'jKjyWne0hprEm-zcC'$P'Jd7—4VU(dxyjxWi,W9lp;F/cdq:MNa"-E;IYB&r$PLHhKl-py1L:B'8K'e[jBcD—Rs[jR[JBl

In [11]:
# AdamW over every model parameter at a fixed learning rate (no schedule).
LEARNING_RATE = 1e-4
optim = torch.optim.AdamW(m.parameters(), lr=LEARNING_RATE)

In [12]:
def train_steps(num_steps=1000, log_every=100):
    """Run ``num_steps`` AdamW updates of ``m`` on random training batches.

    Prints the single-batch training loss every ``log_every`` steps, starting at
    step 0, and returns the loss tensor from the final step.
    """
    m.train()
    loss = None
    for step in range(num_steps):
        _, loss = m(*get_batch())
        optim.zero_grad()
        loss.backward()
        optim.step()
        if step % log_every == 0:
            print(loss.item())
    return loss


def print_sample(prompt="I was a", max_new_tokens=1000):
    """Print the prompt followed by ``max_new_tokens`` tokens generated by ``m``."""
    # The printed model has no dropout, so eval() changes nothing today; it keeps
    # sampling correct if dropout is ever added.
    m.eval()
    with torch.no_grad():
        prompt_ids = torch.tensor([tokenizer.encode(prompt)], device=DEVICE)
        generated_ids = m.generate(prompt_ids, max_new_tokens)
    m.train()
    print(tokenizer.decode(generated_ids[0].tolist()))


loss = train_steps(1000)
print_sample("I was a", 1000)

4.612915992736816
2.387751817703247
2.195810317993164
1.8805665969848633
1.8831758499145508
1.939691185951233
1.7789957523345947
1.7594327926635742
1.6662935018539429
1.607583999633789
I was alrealy comprety, ideal think, bett[ rearsiou for of 199E, as were ginal where jecambe there frou not mely founds. There as lot fill arthund. At to kneows blow money sexpecialy foundsyrate Oneel of called for means I firunders lanking appayinting MBC5 thery definers refounders. So be foundat's fiust feew better for was wored ond what 4 scatives never kind, a but paint, because "I mome. HOn't we retard a  company mare the (obN. I" his side thinkings of whty anter ofied of have moinxting I sold by to fick trunsess I coloasi were all because for wrling fic&iple thinked from exceprespece would rally pristific, that be Now $3 sees I was occomianifed fer I, toceps you'd bu know more ould was that it nais interstive. I was ot: kefthought was d— whoul, was that be summers that sturtups Ollieir things to mi

In [13]:
# Let's do 1000 more steps and see whether the sample improves.
for step in range(1000):
    batch_inputs, batch_targets = get_batch()
    _, loss = m(batch_inputs, batch_targets)
    optim.zero_grad()
    loss.backward()
    optim.step()
    if not step % 100:
        print(loss.item())

prompt_ids = torch.tensor([tokenizer.encode("I was a")], device=DEVICE)
with torch.no_grad():
    generated_ids = m.generate(prompt_ids, 1000)
print(tokenizer.decode(generated_ids[0].tolist()))

1.5965429544448853
1.5324620008468628
1.5667592287063599
1.4346281290054321
1.4636309146881104
1.5177146196365356
1.456603765487671
1.4044984579086304
1.4710373878479004
1.2809756994247437
I was actually fund at the first, which. I'l phile idea of he elt be an intemora+s at the long I'd bever but it was notually latergap if wall. Now I alked on on means about to print going the gatt the wall. A(I realized then in the for the day make stracked on at lett leave thear it. I had let tel; the talk for the pastratter. So you there's for why was stricture staten ence was me? ANSIWX48 Y Comman its as the were a PUW2 the Hists where ISD), was it fast was then was site the way plisting were. My sents of 130% could, I did to not— and funding in to make was stuff fu was smet become users you wouldn't know grad school, on then about the fit for procas though New Yor6 a'd are thing on other web, and was now it's bank. Ex atfectually prodenty printer really and me. The poppy working on the was better

In [14]:
# Another 1000 steps: does the sample keep getting more English-like?
for step in range(1000):
    _, loss = m(*get_batch())
    optim.zero_grad()
    loss.backward()
    optim.step()
    should_log = step % 100 == 0
    if should_log:
        print(loss.item())

with torch.no_grad():
    seed_ids = torch.tensor([tokenizer.encode("I was a")], device=DEVICE)
    continuation = m.generate(seed_ids, 1000)[0]
    sample_text = tokenizer.decode(continuation.tolist())
print(sample_text)

1.337056040763855
1.288700819015503
1.215787649154663
1.180588722229004
1.3454545736312866
1.256841778755188
1.235177755355835
1.2567518949508667
1.2188429832458496
1.027768850326538
I was a big caller Live I was going to New York 7x years I wanted, or ganderies exanted me-inded mY Computer for hi wanted out, but know but I noticed un, and it was batch these originally of completely users the Robert gandled for the toge except than the would last focurious had to be far, it was not the last long shirt, because how else's work only building from an audience. I'd write to write the soft the contranslated, just a blike the other of a complete, whose worn in, it the eimners later seed for things really for he processes of cornercal become I needed more exciting YC more to color ques.

The work, I was a writing es out of rent fagain, but I was doing a for the inticuoted it it could be oxce so.

At For the Some bright there was what point then I livinked around when I got kenounded start pro