## create a tokenizer

In [None]:
# Read the whole training corpus into memory as a single string.
path = "../data/input.txt"
with open(path, 'r', encoding='utf-8') as f:
    data = f.read()

# Report the corpus length in characters.
print(len(data))

In [None]:
# Preview the first 1000 characters of the corpus.
print(data[:1000])

In [None]:
# Build the character-level vocabulary: every distinct character in the
# corpus, sorted so the ordering (and therefore each character's id) is
# deterministic. sorted() accepts any iterable and always returns a list,
# so the set is passed to it directly.
unique_chars = sorted(set(data))
vocabulary_size = len(unique_chars)
print(''.join(unique_chars))
print(vocabulary_size)

In [None]:
# tokenization of characters
# Lookup tables between characters and integer token ids; a character's id is
# its position in `unique_chars`.
encoder_func = {ch: i for i, ch in enumerate(unique_chars)}
decoder_func = dict(enumerate(unique_chars))


def encoder(s):
    """Return the list of token ids for the characters of string *s*.

    Raises KeyError if *s* contains a character that is not in the vocabulary.
    """
    return [encoder_func[c] for c in s]


def decoder(c):
    """Return the string spelled by the iterable of token ids *c*.

    Raises KeyError if an id is not in the vocabulary.
    """
    return ''.join(decoder_func[i] for i in c)

In [None]:
# Round-trip check: encode a sample string to token ids, then decode the ids
# back to text.
print(encoder("hii there"))
print(decoder(encoder("hii there")))

In [None]:
import torch

# Encode the entire corpus into a 1-D tensor of int64 token ids.
# torch.tensor with an explicit dtype is the documented way to build a tensor
# from Python data; the typed torch.LongTensor(...) constructor is legacy API.
tensor_data = torch.tensor(encoder(data), dtype=torch.long)
print(tensor_data.size())
print(tensor_data.dtype)
print(tensor_data[:1000])

In [None]:
# The first 90% of the tokens is the training split; the remainder is held out.
train_upper_index = int(0.9 * len(tensor_data))
train_data = tensor_data[:train_upper_index]
test_data = tensor_data[train_upper_index:]

In [None]:
# Show the first context_length + 1 training tokens: one more token than the
# context length, so that every input position has a following token.
context_length = 8
train_data[:context_length+1]

In [None]:
# A single chunk yields context_length (context, target) pairs: for every t
# the context is the first t+1 tokens and the target is the token that
# follows them, so contexts range in length from 1 up to context_length.

x = train_data[:context_length]
y = train_data[1:context_length+1]  # x shifted by one: y[t] is the token after x[t]
for t in range(context_length):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context}, target is {target}")

In [None]:
# Seed the RNG so the randomly sampled batches are reproducible.
torch.manual_seed(1337)
batch_size = 4  # number of sequences sampled per batch
context_length = 8  # number of tokens in each sequence

def get_batch(split):
    """Sample a random batch of input sequences and their next-token targets.

    Reads the module-level `train_data`, `test_data`, `batch_size` and
    `context_length`. `split == 'train'` samples from `train_data`; any other
    value samples from `test_data`.

    Returns a pair `(x, y)` of stacked tensors with `batch_size` rows of
    `context_length` tokens each, where `y` is `x` shifted one position
    forward in the source data.
    """
    source = train_data if split == 'train' else test_data
    # Exclusive upper bound for start offsets: the target window reaches one
    # token past the input window and must still fit inside `source`.
    start_limit = len(source) - context_length
    starts = torch.randint(start_limit, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        stop = start + context_length
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])
    return torch.stack(inputs), torch.stack(targets)

# Sample one training batch and show the input and target tensors.
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('****************')

# Spell out every (context, target) pair contained in the batch: for row b and
# position t, the context is the first t+1 tokens of the row and the target is
# the token stored at the same position of yb.
for b in range(batch_size):
    for t in range(context_length):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"when input is {context.tolist()}, target is {target}")

In [None]:
# implement a simple language model

import torch
import torch.nn as nn
from torch.nn import functional as F
# Re-seed the RNG before the model is created and sampled from.
torch.manual_seed(1337)

class BigramModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    Args:
        vocab_size: number of distinct tokens; used as both the number of
            rows and the row width of the lookup table.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from the lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (b, t) integer tensor of token ids.
            targets: optional (b, t) integer tensor of next-token ids.

        Returns:
            A `(logits, loss)` pair. Without targets, `logits` has shape
            (b, t, c) and `loss` is None. With targets, `logits` is returned
            flattened to (b*t, c) - the layout passed to F.cross_entropy -
            and `loss` is the cross-entropy against the flattened targets.
        """
        logits = self.token_embedding_table(idx)  # (batch, time, channels)
        if targets is None:
            return logits, None

        b, t, c = logits.shape
        # cross_entropy is given one row of class scores per prediction
        logits = logits.view(b * t, c)
        loss = F.cross_entropy(logits, targets.view(-1))
        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so skip building the autograd graph
    def generate(self, idx, max_new_tokens):
        """Extend each sequence in `idx` by `max_new_tokens` sampled tokens.

        Args:
            idx: (b, t) integer tensor of token ids in the current context.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            (b, t + max_new_tokens) integer tensor: the input followed by
            the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # get predictions; no targets are passed, so the loss is None
            logits, _ = self(idx)
            # take only the last time step prediction
            logits = logits[:, -1, :]
            # calculate the probabilities over the vocabulary
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)  # (b, t+1)

        return idx
    
# Instantiate the model and run one forward pass with targets on the batch
# sampled earlier, printing the logits shape and the loss.
m = BigramModel(vocab_size=vocabulary_size)
out, loss = m(xb, yb)
print(out.shape)
print(loss)

# Generate 100 new tokens starting from a single token with id 0, then decode
# the resulting ids to text.
idx = torch.zeros((1,1), dtype=torch.long)
print(decoder(m.generate(idx, max_new_tokens=100)[0].tolist()))

In [None]:
# AdamW optimizer over all of the model's parameters, learning rate 1e-3.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)


In [None]:
# get_batch reads batch_size when it is called, so this enlarges every batch
# sampled in the loop below.
batch_size = 32
for steps in range(10000):

    # sample a batch
    xb, yb = get_batch('train')

    # forward
    logits, loss = m(xb, yb)
    # clear the previous step's gradients, backpropagate, update the parameters
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# loss of the final batch only
print(loss.item())

In [None]:
# still not shakespeare, but we're making progress
# Generate 500 new tokens from the model after the training loop and decode
# them to text.
print(decoder(m.generate(idx, max_new_tokens=500)[0].tolist()))