In [2]:
import pickle
from tokenizers import BPETokenizer
import torch
from transformer import Transformer

In [3]:
# Tokenizer / dataset configuration. VOCAB_SIZE is embedded in the cache file
# names below, so artifacts built for different vocabulary sizes never collide.
VOCAB_SIZE = 1000

TOKENIZER_PATH = f"tokenizers/bpe_{VOCAB_SIZE}_tokenizer.pkl"
DATA_PATH = "../data/paul_graham_essay.txt"
TRAIN_DATA_PATH = f"bpe_{VOCAB_SIZE}_train.pt"
TEST_DATA_PATH = f"bpe_{VOCAB_SIZE}_test.pt"

# Prefer the GPU, but fall back to the CPU so the notebook still runs end to
# end on machines without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model / batching hyperparameters.
EMBED_SIZE = 384
NUM_HEADS = 8
CONTEXT_LENGTH = 64  # tokens per training window
NUM_LAYERS = 12
BATCH_SIZE = 32  # windows per optimisation step

# Seed PyTorch's RNG so weight initialisation, batch sampling and generation
# are repeatable under Restart Kernel -> Run All.
SEED = 1337
torch.manual_seed(SEED)

In [4]:
# Reuse the cached tokenizer when one exists; otherwise build one from
# DATA_PATH and cache it for the next run.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so TOKENIZER_PATH must only ever point at a cache this notebook wrote itself.
try:
    with open(TOKENIZER_PATH, "rb") as f:
        tokenizer = pickle.load(f)
except FileNotFoundError:
    # Only a missing cache triggers a rebuild. Any other failure (corrupt
    # pickle, import error, Ctrl-C) now surfaces instead of being silently
    # papered over by a bare `except:`.
    tokenizer = BPETokenizer.build_tokenizer(DATA_PATH, VOCAB_SIZE)
    with open(TOKENIZER_PATH, "wb") as f:
        pickle.dump(tokenizer, f)

In [5]:
# Load the cached token tensors when both exist; otherwise tokenize the corpus,
# split it 90/10 into train/test, and cache both halves.
try:
    train_data = torch.load(TRAIN_DATA_PATH)
    test_data = torch.load(TEST_DATA_PATH)
except FileNotFoundError:
    # Only a missing cache file triggers re-tokenization; other errors (corrupt
    # file, Ctrl-C) propagate instead of being hidden by a bare `except:`.
    # Explicit encoding so the text is decoded identically on every platform.
    # NOTE(review): assumes the corpus file is UTF-8 - confirm.
    with open(DATA_PATH, "r", encoding="utf-8") as f:
        data = f.read()
    tokenized_data = torch.tensor(tokenizer.encode(data), dtype=torch.long)
    # First 90% of the token stream is used for training, the rest is held out.
    n = int(0.9 * len(tokenized_data))
    train_data = tokenized_data[:n]
    test_data = tokenized_data[n:]
    torch.save(train_data, TRAIN_DATA_PATH)
    torch.save(test_data, TEST_DATA_PATH)

In [6]:
# Gather the architecture hyperparameters from the config cell in one mapping,
# then build the model from it.
model_kwargs = {
    "vocab_size": VOCAB_SIZE,
    "embed_size": EMBED_SIZE,
    "num_heads": NUM_HEADS,
    "context_length": CONTEXT_LENGTH,
    "num_layers": NUM_LAYERS,
}
m = Transformer(**model_kwargs)

# Kept as the cell's last expression on purpose: the value returned by `.to`
# is what the notebook renders as the module summary.
m.to(DEVICE)

Transformer(
  (token_embedding): Embedding(1000, 384)
  (pos_embedding): Embedding(64, 384)
  (attn_blocks): ModuleList(
    (0-11): 12 x AttentionBlock(
      (attn_heads): ModuleList(
        (0-7): 8 x CausalSelfAttention(
          (Q): Linear(in_features=384, out_features=48, bias=False)
          (K): Linear(in_features=384, out_features=48, bias=False)
          (V): Linear(in_features=384, out_features=48, bias=False)
        )
      )
      (mlp): MLP(
        (fcn): Linear(in_features=384, out_features=1536, bias=True)
        (activation): ReLU()
        (proj): Linear(in_features=1536, out_features=384, bias=True)
      )
      (layer_norm_1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=384, out_features=1000, bias=False)
)

In [7]:
# Uniform sampling weights with one entry per valid window start. The last
# CONTEXT_LENGTH positions are excluded so that the target window, which is
# shifted one token to the right, never runs past the end of train_data.
unif = torch.ones(train_data.shape[0] - CONTEXT_LENGTH, device=DEVICE)


def get_batch():
    """Sample a random training batch of next-token-prediction windows.

    BATCH_SIZE distinct start offsets are drawn uniformly (without
    replacement) from `unif`. For each offset `i` the input row is
    `train_data[i : i + CONTEXT_LENGTH]` and the target row is the same
    window shifted right by one token.

    Returns:
        (inputs, targets): two tensors of shape (BATCH_SIZE, CONTEXT_LENGTH),
        both moved to DEVICE.
    """
    # Move the sampled offsets to train_data's device once. Iterating over a
    # DEVICE-resident index tensor in Python would convert (and, on CUDA,
    # synchronise) one element at a time.
    starts = unif.multinomial(BATCH_SIZE, replacement=False).to(train_data.device)
    # positions[b, t] = starts[b] + t, so a single gather builds every window.
    window = torch.arange(CONTEXT_LENGTH, device=train_data.device)
    positions = starts.unsqueeze(1) + window
    inputs = train_data[positions].to(DEVICE)
    targets = train_data[positions + 1].to(DEVICE)
    return inputs, targets


# Smoke test: one forward pass through the freshly built model.
logits, loss = m(*get_batch())

In [8]:
# Sanity check on the forward pass: the expected shape is
# (BATCH_SIZE, CONTEXT_LENGTH, VOCAB_SIZE).
logits.shape

torch.Size([32, 64, 1000])

In [9]:
# Loss of the still-untrained model on a single batch.
# NOTE(review): if the model returns a cross-entropy over VOCAB_SIZE classes,
# a uniform guess would score ln(1000) ~= 6.91 - confirm against Transformer.
loss.item()

7.084176540374756

In [10]:
# Use only the input half of a fresh batch as prompts; the targets are unused.
start_idx, _unused_targets = get_batch()
print(start_idx)

# Ask the model for one extra token per prompt. This is left as the final
# expression so the resulting tensor is rendered as the cell output.
m.generate(start_idx, 1)

tensor([[286, 292, 273,  ..., 791, 310, 914],
        [296, 119, 491,  ..., 110, 321, 511],
        [789, 318, 766,  ..., 984, 555, 285],
        ...,
        [647, 466, 379,  ..., 569, 296,  74],
        [260, 424, 280,  ..., 505, 339, 359],
        [386, 336, 267,  ..., 406, 860, 412]], device='cuda:0')


tensor([[286, 292, 273,  ..., 310, 914, 395],
        [296, 119, 491,  ..., 321, 511, 926],
        [789, 318, 766,  ..., 555, 285, 741],
        ...,
        [647, 466, 379,  ..., 296,  74,  44],
        [260, 424, 280,  ..., 339, 359, 208],
        [386, 336, 267,  ..., 860, 412, 692]], device='cuda:0')

In [11]:
# Qualitative baseline: sample 100 tokens from the untrained model.
prompt_ids = torch.tensor([tokenizer.encode("I was a")], device=DEVICE)
with torch.no_grad():
    generated = m.generate(prompt_ids, 100)
# The batch holds a single prompt, so decode row 0 back to text.
print(tokenizer.decode(generated[0].tolist()))

I was aexcow<ant`a loLisp�to get blemembm ting since working �first inte�who paintfor the intmemb�MQbeen ealso qua loeven s.

Hul epprogrammmukds ce ant it's ig buYCgrableLisp*�see WRobtrstrefr�. Whst�col�firte eeljMcCarInterdoing end en't while studcolle�Cambrid their me. had been decidunme. Youlugood t, I J5 ed that 5 v�rote 


In [12]:
# AdamW over all model parameters with a fixed learning rate of 1e-4.
optim = torch.optim.AdamW(m.parameters(), lr=1e-4)

In [13]:
# First 1000 optimisation steps; the batch loss is printed every 100th step.
for step in range(1000):
    # Clear gradients left over from the previous step before the new pass.
    optim.zero_grad()
    _, loss = m(*get_batch())
    loss.backward()
    optim.step()
    if not step % 100:
        print(loss.item())

# Qualitative check: sample 100 tokens from the partially trained model.
prompt_ids = torch.tensor([tokenizer.encode("I was a")], device=DEVICE)
with torch.no_grad():
    generated = m.generate(prompt_ids, 100)
print(tokenizer.decode(generated[0].tolist()))

7.061542987823486
6.245255470275879
5.356468677520752
4.347721099853516
3.8550236225128174
3.4973127841949463
2.962090253829956
2.2358956336975098
1.5125809907913208
0.9945037961006165
I was ain make them. It wasn't stresly binternedes, on, even at the ticularchins, the produess quests for felt like all for ons for to do was learn company to New York all texpend, but like facult.

The friing for expressivelds of ious dows to the right in cessor "Yahoo bought usually ft plann't, 


In [14]:
# Another 1000 steps on top of the previous run (optimizer state carries over),
# again printing the batch loss every 100th step.
for step in range(1000):
    optim.zero_grad()
    _, loss = m(*get_batch())
    loss.backward()
    optim.step()
    if not step % 100:
        print(loss.item())

# Same prompt as before, so the samples can be compared across training stages.
prompt_ids = torch.tensor([tokenizer.encode("I was a")], device=DEVICE)
with torch.no_grad():
    generated = m.generate(prompt_ids, 100)
print(tokenizer.decode(generated[0].tolist()))

0.6650863289833069
0.48293110728263855
0.35592928528785706
0.31363072991371155
0.27621620893478394
0.2521527111530304
0.22401702404022217
0.21059758961200714
0.20255036652088165
0.18780313432216644
I was axible both thing.

So most of a bThe good , I scare, respons when I had ding this time e's a geadvice lunnatural hMcme, and completrobrain any moment vilosophor the fway of lack investors part Yahoo bought us. In princience is an arlook ld new likegrousered in was profi


In [15]:
# A third round of 1000 steps, logging the batch loss every 100th step.
# NOTE(review): only the training-batch loss is tracked here; test_data is not
# evaluated in any cell shown, so overfitting would go unnoticed.
for step in range(1000):
    optim.zero_grad()
    _, loss = m(*get_batch())
    loss.backward()
    optim.step()
    if not step % 100:
        print(loss.item())

# Sample once more from the same prompt for comparison with earlier stages.
prompt_ids = torch.tensor([tokenizer.encode("I was a")], device=DEVICE)
with torch.no_grad():
    generated = m.generate(prompt_ids, 100)
print(tokenizer.decode(generated[0].tolist()))

0.17992229759693146
0.16258060932159424
0.1676357239484787
0.16767027974128723
0.14056962728500366
0.16359587013721466
0.1349770575761795
0.13539303839206696
0.12565146386623383
0.13672854006290436
I was awreesmake it so intange YC. I don't think it was reading 

I do next? Rtm's advice hadn't rospecranneted sinkknow how blaround with visits currself people (Flory for shown at interrich model of startups working on batch processing to memured control you'd have a smanging with Y Combinat


In [16]:
# Sample 100 tokens from a longer prompt than the one used in earlier cells.
prompt_ids = torch.tensor(
    [tokenizer.encode("I wanted not just to build things")],
    device=DEVICE,
)
with torch.no_grad():
    generated = m.generate(prompt_ids, 100)
# Single prompt in the batch: decode row 0 back to text.
print(tokenizer.decode(generated[0].tolist()))

I wanted not just to build thingsatisused to langabout , because this whole se. It was micult . And in retrospect, because he would still be working on it almost miract later on, someone mean like a mall prospect of a book about Lisp hacking."Wowed the bYC gest source of stress in one's work should at least be something close to the core of the work. Whereas it was they like someone languag


In [17]:
# Total number of tokens across the train and test splits combined.
len(train_data) + len(test_data)

24465

In [18]:
# Loss readings copied by hand from the printed output of the three training
# cells above: one reading every 100 steps, 30 readings in total.
losses = """
7.061542987823486
6.245255470275879
5.356468677520752
4.347721099853516
3.8550236225128174
3.4973127841949463
2.962090253829956
2.2358956336975098
1.5125809907913208
0.9945037961006165
0.6650863289833069
0.48293110728263855
0.35592928528785706
0.31363072991371155
0.27621620893478394
0.2521527111530304
0.22401702404022217
0.21059758961200714
0.20255036652088165
0.18780313432216644
0.17992229759693146
0.16258060932159424
0.1676357239484787
0.16767027974128723
0.14056962728500366
0.16359587013721466
0.1349770575761795
0.13539303839206696
0.12565146386623383
0.13672854006290436
"""
# Split the pasted text on whitespace and convert every reading to a float.
losses = [float(reading) for reading in losses.split()]

In [19]:
# Display the parsed list to confirm that all 30 readings were converted.
losses

[7.061542987823486,
 6.245255470275879,
 5.356468677520752,
 4.347721099853516,
 3.8550236225128174,
 3.4973127841949463,
 2.962090253829956,
 2.2358956336975098,
 1.5125809907913208,
 0.9945037961006165,
 0.6650863289833069,
 0.48293110728263855,
 0.35592928528785706,
 0.31363072991371155,
 0.27621620893478394,
 0.2521527111530304,
 0.22401702404022217,
 0.21059758961200714,
 0.20255036652088165,
 0.18780313432216644,
 0.17992229759693146,
 0.16258060932159424,
 0.1676357239484787,
 0.16767027974128723,
 0.14056962728500366,
 0.16359587013721466,
 0.1349770575761795,
 0.13539303839206696,
 0.12565146386623383,
 0.13672854006290436]