In [75]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
import time
import random
import math

inf = torch.inf
context_length = 256 # No of tokens
model_dim = 64 # dimension of the model -> residual stream
n_layers = 2 # no of layers
n_heads = 0 # No of attention heads for layer # TODO
head_dim = 16
vocab_size = 65
learning_rate = 0.001
max_iters = 0
eval_iters = 10
batch_size = 64 #Takes 27k iters

lower_triangular_matrix = torch.tensor([[1 if i<=j else -torch.inf for i in range(context_length)] for j in range(context_length)]).float()

def tokenise(str: str):
    return torch.tensor([char_map[i] for i in str])

def decode(tokens: list[str]):
    return ''.join([reverse_char_map[i] for i in tokens])

file = open("tiny_shakesphere.txt", "r")
full_data = file.read()

vocab = list(sorted((set(full_data))))
char_map = {vocab[i]: i for i in range(len(vocab))}
reverse_char_map = {char_map[i] : i for i in char_map}
full_data = tokenise(full_data).tolist()
random.shuffle(full_data)
full_data = torch.tensor(full_data)

total_datapoints  = full_data.shape[0]

training_data : list[int] = full_data[:int(total_datapoints*0.9)]
validation_data = full_data[int(total_datapoints*0.9):total_datapoints]


def sample_data(split: str = "train"):
    data = training_data if split == 'train' else validation_data
    ix = torch.randint(len(data) - context_length, (batch_size,))
    x = torch.stack([data[i:i+context_length] for i in ix])
    y = torch.stack([data[i+1:i+context_length+1] for i in ix])
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = sample_data(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
        print(losses)
    model.train()
    
    return out["train"], out['val']


class Layer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        

class AttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.key = nn.Linear(model_dim, head_dim)
        self.query = nn.Linear(model_dim, head_dim)
        self.value = nn.Linear(model_dim, head_dim)
        self.proj = nn.Linear(head_dim, model_dim)
    
    def forward(self, idx):
        key = self.key(idx) # (batch, context_length, head_dim)
        query = self.query(idx)
        value = self.value(idx) # (batch, context_length, head_dim)

        attention = (query@torch.transpose(key,1,2))/(math.sqrt(head_dim)) # (batch, context_length, context_length)

        attention = torch.tril(attention)

        attention = attention.masked_fill(attention == 0, -inf)

        attention = F.softmax(attention,-1) # probs along context_length sum to 1

        attention_value = attention@value  # (batch, context_length, head_dim)

        return self.proj(attention_value)  # (batch, context_length, model_dim)
    

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(model_dim, 4*model_dim), nn.Linear(4*model_dim, model_dim))
        self.relu = nn.ReLU()
    
    def forward(self, idx):
        logits = self.layers(idx)
        return self.relu(logits)

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, model_dim)
        self.pos_embedding = nn.Embedding(context_length, model_dim)

        self.attention_layes = nn.ModuleList([AttentionHead() for i in range(n_layers)])
        self.mlp_layers = nn.ModuleList([MLP() for i in range(n_layers)])

        self.unembed_layer = nn.Linear(model_dim,vocab_size)

        self.total_parameters = sum([p.numel() for p in self.parameters()])
        print(f"Model has {self.total_parameters//1000}k params")

        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets = None):
        # idx -> [1,2,0,3..] (batch, context_length)

        input_sequence_length = idx.shape[-1]

        residual_stream = self.token_embedding(idx)  # (batch, context_length, model_dim)
        # residual_stream = residual_stream + self.pos_embedding(torch.tensor([i for i in range(input_sequence_length)])) # Pos embedding will be # (context_length, model_dim)
        # for i in range(n_layers):
        #     residual_stream = residual_stream + self.attention_layes[i](residual_stream)
        #     residual_stream = residual_stream + self.attention_layes[i](residual_stream)

        residual_stream = self.unembed_layer(residual_stream) # (batch, context_length, vocab_size)
        if targets is None:
            return residual_stream
        (x,y,z) = residual_stream.shape
        loss = F.cross_entropy(residual_stream.view(x*y,z), F.one_hot(targets, vocab_size).resize(x*y, z).float())
        return residual_stream, loss



model = Transformer()
train_loss,val_loss = estimate_loss()
print(f"Initial training loss: {train_loss}, val loss: {val_loss}")
#plt.show()



        

Model has 99k params
tensor([4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1743,
        4.1742])
tensor([4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1742, 4.1742,
        4.1742])
Initial training loss: 4.174211025238037, val loss: 4.174208641052246




In [61]:
X,Y = sample_data()
print(X.shape)

torch.Size([1, 256])


In [62]:
a,b = X.shape
F.cross_entropy(F.one_hot(X,vocab_size).view(a*b,vocab_size).float(), F.one_hot(Y,vocab_size).view(a*b,vocab_size).float())

tensor(4.1341)

In [40]:
model()

tensor([46, 16, 60, 50, 15, 58, 60,  1,  6, 50,  0, 58, 50, 61,  1,  1, 44, 56,
        46, 57, 43, 43, 30,  0, 52,  1, 39, 57, 59, 57, 61,  0, 58,  1,  1, 58,
        53, 53, 40, 39, 62, 12, 53, 61, 59, 57,  1, 39, 59, 43, 43, 40,  6, 50,
        44, 61, 47, 39, 53, 43, 47, 56, 43, 11, 47, 58, 58, 50, 52,  1, 53, 43,
        39,  0, 43, 18,  8, 53, 46,  7,  1,  0, 60, 58, 57, 46,  0, 56, 47, 56,
        47, 46,  1, 21, 56, 53, 43,  1, 25, 50, 52, 61, 58, 15, 53, 42, 39,  1,
         1, 31, 57,  1,  1,  1, 51, 39, 61, 10, 39, 43, 43, 46, 56, 63, 58, 52,
        56, 46, 57,  1, 43, 43, 15,  1, 43,  0, 41, 51, 53, 53, 59, 49, 47,  1,
        39, 58, 53,  0, 51, 46,  1, 45, 42,  1, 14, 46, 18, 57, 33,  1, 47,  1,
        57, 47, 59, 32, 53, 61, 43, 21,  1, 43,  1,  1, 59, 58,  0,  8, 53, 63,
         0, 21, 49, 51,  0,  6,  1, 56,  1, 50, 52, 46, 40, 42, 59, 39, 39, 58,
        52, 52, 57, 47, 44, 43, 59, 23, 56, 39, 10,  1, 51, 39, 14, 59, 56, 54,
         6, 58,  1, 43,  1, 46,  1, 43, 

In [19]:
a = F.one_hot(x, vocab_size)

In [20]:
a.shape

torch.Size([8, 65])