In [80]:
import torch
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import nn
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import Dataset, DataLoader, random_split

import torch.nn.functional as F

%matplotlib inline

In [5]:
# with open(r'../data/input.txt', 'r', encoding='utf-8') as file:
    # data = file.read()

In [3]:
# vocab = sorted(list(set(data)))
# print(''.join(vocab))

In [4]:
# We do the tokenization here: simple digil encoding here
# ctoi = { c: i for i, c in enumerate(vocab) }
# itoc = { i: c for c, i in ctoi.items() }
# encode = lambda s: [ctoi[c] for c in s] 
# decode = lambda l: ''.join([itoc[i] for i in l])
# tokenized_data = [ ctoi[c] for c in data ]
# tokenized_data = torch.tensor(tokenized_data, dtype=torch.long)
# print(tokenized_data[:30])

In [53]:
class TextDS(Dataset):
    def __init__(self, location, window_size):
        self.location = location
        self.window_size = window_size
        print(self.location, self.window_size)
        print('Load Dataset')
        
        with open(self.location, 'r', encoding='utf-8') as file:
            self.data = file.read()

        print('Loading done')        
        self.vocab = sorted(list(set(data)))
        self.vocab_size = len(self.vocab)
        self.ctoi = { c: i for i, c in enumerate(self.vocab) }
        self.itoc = { i: c for c, i in self.ctoi.items() }

        self.data = [ self.ctoi[c] for c in self.data ]
        self.data = torch.tensor(self.data, dtype=torch.long)
        print('Init Dataset Done')
            
    def encode(self, s):
        return [self.ctoi[c] for c in s]

    def decode(self, l):
        return ''.join([self.itoc[i] for i in l])

    def __len__(self):
        return len(self.data) - self.window_size - 1

    def __getitem__(self, idx):
        x = self.data[idx:idx+self.window_size]
        y = self.data[idx+1:idx+self.window_size+1]

        return x, y
    

In [54]:
# # We prepare the dataset here:
# pct_split = 0.9
# idx_split = int(pct_split * len(data))
# train_data = tokenized_data[:idx_split]
# test_data = tokenized_data[idx_split:]

In [55]:
# batch_size = 4
# context_window = 8

# torch.manual_seed(42)

# def batch(set_type, batch_size, context_window, device):
#     ds = train_data if set_type == 'train' else test_data
#     samples = torch.randint(0, len(ds) - context_window, (batch_size, ))
#     x = torch.stack([ds[idx:idx+context_window] for idx in samples])
#     y = torch.stack([ds[idx+1:idx+context_window+1] for idx in samples])
#     return x.to(device), y.to(device)

# x, y = batch('train')

In [56]:
class AttentionHead(nn.Module):
    def __init__(self, n_embedding, head_size, context_window):
        super().__init__()
        self.key = nn.Linear(n_embedding, head_size, bias=False)
        self.query = nn.Linear(n_embedding, head_size, bias=False)
        self.value = nn.Linear(n_embedding, head_size, bias=False)
        self.register_buffer('mask', torch.tril(torch.ones(context_window, context_window)))


    def forward(self, x):
        B, T, C = x.shape
        key = self.key(x) 
        query = self.query(x)
        value = self.value(x)
        
        dot_product = query @ key.transpose(-2, -1) * (C ** -.5)
        dot_product = dot_product.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        softmax = F.softmax(dot_product, dim=-1)
        out = softmax @ value
        return out
    

In [57]:
class MultiAttentionHead(nn.Module):
    def __init__(self, n_heads, n_embedding, head_size, context_window):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(n_embedding, head_size, context_window) for _ in range(n_heads)])
        self.linear = nn.Linear(n_heads * head_size, n_heads * head_size)  

    def forward(self, x):
        out = torch.concat([head(x) for head in self.heads], dim=-1)
        out = self.linear(out)
        return out

In [58]:
class FeedForward(nn.Module):
    def __init__(self, s_input, s_intermediate):
        super().__init__()
        self.linear1 = nn.Linear(s_input, s_intermediate, bias=True)
        self.linear2 = nn.Linear(s_intermediate, s_input, bias=True)

    def forward(self, x):
        out = self.linear1(x)
        out = F.relu(out)
        out = self.linear2(out)
        return out

In [59]:
class Block(nn.Module):
    def __init__(self, n_heads, n_embedding, context_window):
        super().__init__()
        head_size = n_embedding // n_heads
        self.attention_layer = MultiAttentionHead(n_heads, n_embedding, head_size, context_window)
        self.ffw = FeedForward(head_size * n_heads, 4 * n_embedding)
        self.ln1 = nn.LayerNorm(head_size * n_heads)
        self.ln2 = nn.LayerNorm(head_size * n_heads)

    def forward(self, x):
        out = x + self.attention_layer(self.ln1(x))
        out = out + self.ffw(self.ln2(out))
        return out

In [60]:
class Model(nn.Module):
    def __init__(self, n_vocab, n_blocks, context_size, n_embedding, n_heads):
        super().__init__()
        self.context_size = context_size
        self.token_embedding = nn.Embedding(n_vocab, n_embedding)
        self.positional_embedding = nn.Embedding(context_size, n_embedding)
        self.blocks = nn.Sequential(*[Block(n_heads, n_embedding, context_size) for _ in range(n_blocks)])
        self.layer_norm = nn.LayerNorm(n_embedding)
        self.layer = nn.Linear(n_embedding, n_vocab)

    def forward(self, device, x, y=None):
        B, T = x.shape
        
        tok_emb = self.token_embedding(x)
        pos_emb = self.positional_embedding(torch.arange(T, device=device))

        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.layer_norm(x)
        x = self.layer(x)

        if y is None:
            loss = None
        else:
            B, T, C = x.shape
            x = x.view(B*T, C)
            targets = y.view(B*T)
            loss = F.cross_entropy(x, targets)

        return x, loss
        
    def generate(self, device, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.context_size:]
            # get the predictions
            logits, loss = self(device, idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
    

In [70]:
context_window = 32
batch_size = 64

dataset = TextDS("../data/input.txt", context_window)
train_set, test_set = random_split(dataset, [0.98, 0.02])

train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True)



../data/input.txt 32
Load Dataset
Loading done
Init Dataset Done


In [72]:
eval_iters = 200

@torch.no_grad()
def estimate_loss(model, device):
    out = {}
    model.eval()
    # for split in ['train', 'val']:
    losses = torch.zeros(len(test_dataloader))

    for i, (x, y) in enumerate(test_dataloader):
        logits, loss = model(device, x, y)
        losses[i] = loss.item()

    model.train()
    return losses.mean()
        
    #     for k in range(eval_iters):
    #         X, Y = batch(split, batch_size, context_window, device)
    #         logits, loss = model(device, X, Y)
    #         losses[k] = loss.item()
    #     out[split] = losses.mean()
    # model.train()
    # return out

In [None]:
torch.manual_seed(42)

context_window = 32
# batch_size = 64

if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")



device = torch.device("cpu")  
model = Model(n_vocab=dataset.vocab_size, n_blocks=4, context_size=context_window, n_embedding=64, n_heads=4)
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)

max_iters = 5000
max_epoch = 10
eval_interval = 100


for i in range(max_epoch):
    losses = estimate_loss(model, device)
    print(f"step {i}: val loss {losses:.4f}")
    
    for i, (x, y) in enumerate(pbar:=tqdm(train_dataloader)):
        
        
    # if i % eval_interval == 0 or i == max_iters - 1:
    #     losses = estimate_loss(model, batch_size, context_window, device)
    #     print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    # xb, yb = batch('train', batch_size, context_window, device)

    # evaluate the loss
        logits, loss = model(device, x, y)

        if i % 100 == 0:
            pbar.set_description(f"Training Error {loss:.4f}")
        
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        
    scheduler1.step()

0.209729 M parameters
step 0: val loss 4.4400


Training Error 1.4498: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17079/17079 [10:55<00:00, 26.06it/s]


step 1: val loss 1.4049


Training Error 1.3556: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17079/17079 [12:09<00:00, 23.41it/s]


step 2: val loss 1.3551


Training Error 1.3516: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17079/17079 [10:35<00:00, 26.86it/s]


step 3: val loss 1.3274


Training Error 1.2805: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17079/17079 [11:20<00:00, 25.11it/s]


step 4: val loss 1.3072


Training Error 1.2672:  27%|██████████████████████████████▉                                                                                  | 4678/17079 [03:08<08:29, 24.35it/s]

In [None]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(device, context, max_new_tokens=200)[0].tolist()))

In [306]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(device, context, max_new_tokens=200)[0].tolist()))

	arly. They were fut.
<|endoftext|>
Once upon a time there was a fun share. So he meates to explore. 
The miserah went to try for to play with it him.
Mitten try to your boy, showed whiled Lon't loak i
