In [1]:
import os
import random
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm.auto import tqdm

In [2]:
# Load the entire corpus into memory as one string so it can be inspected below.
with open('../data/sample_scripts.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [3]:
# Report how many characters the corpus contains.
num_chars = len(text)
print("length of dataset in characters: ", num_chars)

length of dataset in characters:  2247598


In [4]:
# Peek at the first 100 characters of the corpus.
preview = text[:100]
print(preview)

from typing import Dict, Union, Iterator

import torch

from allennlp.common.registrable import Regi


In [5]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
# Show the vocabulary as one string (last expression, so the notebook displays it).
''.join(chars)

'\t\n\x1b !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~Ġ▁🤗'

In [6]:
# Number of distinct characters in the corpus, i.e. the size of the character vocabulary.
print(vocab_size)

101


# encoding and decoding for chars

In [7]:
# Lookup tables between characters and integer ids; a character's id is its
# position in the sorted vocabulary `chars`.
ch_to_idx = {ch: idx for idx, ch in enumerate(chars)}
idx_to_ch = dict(enumerate(chars))


def encode(s):
    """Encode a string as a list of vocabulary ids, one per character.

    Args:
        s: the string to encode.

    Returns:
        A list of ints, ``ch_to_idx[ch]`` for every character of ``s`` in order.

    Raises:
        KeyError: if ``s`` contains a character that is not in ``ch_to_idx``.
    """
    return [ch_to_idx[ch] for ch in s]


def decode(l):
    """Decode an iterable of vocabulary ids back into a string.

    Args:
        l: iterable of ids (keys of ``idx_to_ch``).

    Returns:
        The concatenation of ``idx_to_ch[idx]`` for every id in ``l``.

    Raises:
        KeyError: if an id is not a key of ``idx_to_ch``.
    """
    return ''.join(idx_to_ch[idx] for idx in l)


# Round-trip sanity check: encoding then decoding must give back the input.
print(encode("import torch"))
print(decode(encode("import torch")))

[76, 80, 83, 82, 85, 87, 3, 87, 82, 85, 70, 75]
import torch


In [8]:
# Encode the whole corpus once and keep it as a single 1-D tensor of int64 character ids.
data = torch.tensor(encode(text), dtype=torch.long)
# Sanity check: overall size and dtype, then the ids of the first 100 characters.
print(data.size(), data.dtype)
print(data[:100])

torch.Size([2247598]) torch.int64
tensor([73, 85, 82, 80,  3, 87, 92, 83, 76, 81, 74,  3, 76, 80, 83, 82, 85, 87,
         3, 39, 76, 70, 87, 15,  3, 56, 81, 76, 82, 81, 15,  3, 44, 87, 72, 85,
        68, 87, 82, 85,  1,  1, 76, 80, 83, 82, 85, 87,  3, 87, 82, 85, 70, 75,
         1,  1, 73, 85, 82, 80,  3, 68, 79, 79, 72, 81, 81, 79, 83, 17, 70, 82,
        80, 80, 82, 81, 17, 85, 72, 74, 76, 86, 87, 85, 68, 69, 79, 72,  3, 76,
        80, 83, 82, 85, 87,  3, 53, 72, 74, 76])


# train / validation split

In [9]:
# Train on the first 90% of the encoded corpus and hold out the remaining tail for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [10]:
# Illustration on a tiny window: y is x shifted by one position, so every
# prefix of x is a context and the matching entry of y is the id to predict.
# NOTE: the config cell later rebinds context_length to the real training
# value; re-running this cell afterwards would reset it to 8.
context_length = 8
x = train_data[:context_length]
y = train_data[1:context_length+1]
for t in range(context_length):
    context, target = x[:t+1], y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([73]) the target: 85
when input is tensor([73, 85]) the target: 82
when input is tensor([73, 85, 82]) the target: 80
when input is tensor([73, 85, 82, 80]) the target: 3
when input is tensor([73, 85, 82, 80,  3]) the target: 87
when input is tensor([73, 85, 82, 80,  3, 87]) the target: 92
when input is tensor([73, 85, 82, 80,  3, 87, 92]) the target: 83
when input is tensor([73, 85, 82, 80,  3, 87, 92, 83]) the target: 76


# config

In [11]:
# --- batching ---
batch_size = 64        # number of windows sampled per batch
context_length = 256   # length (in characters) of each sampled window

# --- training / evaluation schedule (consumed outside this cell) ---
max_iters = 1000
learning_rate = 3e-4
eval_interval = 100
eval_iters = 200

# --- hardware: use the GPU when one is available ---
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# --- model size ---
num_heads = 6
emb_dim = 64 * num_heads   # presumably 64 dimensions per head -- confirm against the model code
num_layers = 6
dropout = 0.2

# data loader

In [12]:
# Fix PyTorch's RNG seed so the randomly drawn batch start positions are reproducible.
torch.manual_seed(seed=111)

def get_batch(split):
    """Sample a random batch of context windows and their next-token targets.

    Args:
        split: ``'train'`` draws from ``train_data``; any other value draws
            from ``val_data``.

    Returns:
        A tuple ``(context_idxs, target_idxs)`` of tensors, each of shape
        ``(batch_size, context_length)`` and moved to ``device``.
        ``target_idxs[b, t]`` is the id that follows ``context_idxs[b, t]``
        in the source data, i.e. the targets are the contexts shifted by one.
    """
    # Named `source` rather than `data` so the module-level `data` tensor is not shadowed.
    source = train_data if split == 'train' else val_data
    # One random window start per batch row. The exclusive upper bound leaves
    # room for the one extra element the shifted targets need.
    start_idxs = torch.randint(len(source) - context_length, (batch_size,))
    # Broadcast the starts against 0..context_length-1 so all windows are
    # gathered in one indexing operation instead of a Python loop of slices.
    positions = start_idxs.unsqueeze(1) + torch.arange(context_length)
    context_idxs = source[positions].to(device)
    target_idxs = source[positions + 1].to(device)

    return context_idxs, target_idxs

# Draw one training batch and show the shape and contents of both tensors.
context_idxs, target_idxs = get_batch('train')
print('inputs:', context_idxs.shape, context_idxs, sep='\n')
print('targets:', target_idxs.shape, target_idxs, sep='\n')

print('----')

# Unroll the first few rows: at each position the context is the prefix seen
# so far and the target is the id that follows it.
for b in range(4): # batch dimension
    row_inputs, row_targets = context_idxs[b], target_idxs[b]
    for step in range(8): # context length dimension
        context = row_inputs[:step+1]
        target = row_targets[step]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([64, 256])
tensor([[74, 11,  5,  ..., 82, 71, 72],
        [32,  3, 71,  ..., 68, 80, 83],
        [72, 89, 68,  ..., 49, 82, 81],
        ...,
        [74, 72, 87,  ..., 76, 81, 74],
        [68, 76, 81,  ..., 72, 15,  3],
        [72, 81, 66,  ...,  0, 83, 68]], device='cuda:0')
targets:
torch.Size([64, 256])
tensor([[11,  5, 60,  ..., 71, 72, 79],
        [ 3, 71, 85,  ..., 80, 83, 79],
        [89, 68, 79,  ..., 82, 81, 72],
        ...,
        [72, 87, 66,  ..., 81, 74, 66],
        [76, 81, 66,  ..., 15,  3, 68],
        [81, 66, 80,  ..., 83, 68, 86]], device='cuda:0')
----
when input is [74] the target: 11
when input is [74, 11] the target: 5
when input is [74, 11, 5] the target: 60
when input is [74, 11, 5, 60] the target: 82
when input is [74, 11, 5, 60, 82] the target: 88
when input is [74, 11, 5, 60, 82, 88] the target: 3
when input is [74, 11, 5, 60, 82, 88, 3] the target: 68
when input is [74, 11, 5, 60, 82, 88, 3, 68] the target: 85
when input is [32]

In [13]:
# Toy demonstration of causal averaging with a lower-triangular mask.
# Step 1: lower-triangular matrix of ones (row i covers positions 0..i).
tril_test = torch.tril(torch.ones(10, 10))
weight_test = tril_test
print(weight_test)
# Step 2: replace the zeros above the diagonal with -inf so softmax gives them weight 0.
masked_test = tril_test.masked_fill(tril_test == 0, float('-inf'))
weight_test = masked_test
print(weight_test)
# Step 3: row-wise softmax; the remaining entries in a row are equal, so row i
# becomes a uniform average over positions 0..i.
weight_test = F.softmax(masked_test, dim=-1)
print(weight_test)
# Step 4: apply the weights to random values; row i of `out` is the mean of v_test[0..i].
v_test = torch.rand((10,3))
print(v_test)
out = weight_test @ v_test
print(out)

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., 1., -inf, -inf],
        [1.