In [2]:
import torch
import torch.nn as nn
import numpy as np  # reserved for later use
import einops

In [3]:
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

In [4]:
class Config():
    def __init__(self):
        self.vocab_size = 50257
        self.embedding_dim = 1024
        self.mlp_dim = 4 * self.embedding_dim
        self.num_blocks = 4
        self.num_heads = 8
        self.context_len = 1024
        self.attention_dim = self.embedding_dim // self.num_heads
        # mps (Apple Silicon) support, reserved for further training
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

config = Config()

In [5]:
class DataLoader:
    def __init__(self, B, T):
        self.batch_size = B
        self.seq_len = T
    
        with open("tiny_shakespeare_dataset.txt", "r") as f:
            text = f.read()
        
        encoding = tokenizer.encode(text)
        self.tokens = torch.tensor(tokens)

        self.current_pos = 0 # maintain the index of the current data sample

        print(f"loaded {len(self.tokens)} tokens")
        print(f"each epoch has {len(self.tokens) / len(self.batch_size * self.seq_len)} ")

    def next_batch(self):
        B, T = self.batch_size, self.seq_len
        x = self.tokens[self.current_position:self.current_position+BT]
        y = self.tokens[self.current_position+1:self.current_position+BT+1]
        x = x.view(B, T)
        y = y.view(B, T)
        self.current_position += B * T

        if (len(self.tokens) - self.current_position + 1) < B * T:
            self.current_position = 0

        return x, y

In [6]:
class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_E = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.W_pos = nn.Embedding(config.context_len, config.embedding_dim)

    def forward(self, tokens):
        tokens = torch.tensor(tokens)
        embeddings = self.W_E(tokens)

        # print(f"---------------------------{tokens.shape}")
        positions = self.W_pos(torch.arange(tokens.shape[1]))

        return embeddings + positions

class DeEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.W_D = nn.Linear(config.embedding_dim, config.vocab_size)

    def forward(self, x):
        embeddings = self.W_D(x)

        return embeddings

In [7]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.W_Q = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.W_K = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.W_V = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        
        self.W_out = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)

        self.num_heads = config.num_heads
        self.attention_dim = config.attention_dim
    
    def forward(self, x):
        Q = einops.rearrange(self.W_Q(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)
        K = einops.rearrange(self.W_K(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)
        V = einops.rearrange(self.W_V(x), 'batch seq (head dim) -> batch head seq dim', head=self.num_heads)

        QK = torch.softmax((Q @ K.transpose(-2, -1))/(torch.sqrt(torch.tensor(self.attention_dim))), dim = -1)

        QKV = einops.rearrange(QK @ V, 'batch head seq dim -> batch seq (head dim)', head=config.num_heads)
        QKV_Out = self.W_out(QKV)

        return QKV_Out

In [8]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer1 = nn.Linear(config.embedding_dim, config.mlp_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(config.mlp_dim, config.embedding_dim)
    
    def forward(self, x):
        x = self.layer2(self.gelu(self.layer1(x)))

        return x

In [9]:
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.MLP_Layers = MLP(config)
        self.Attention_Layers = Attention(config)

    def forward(self, x):
        x += self.Attention_Layers(x)
        return x + self.MLP_Layers(x)

In [10]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.embed = Embedding(config)

        self.blocks = nn.ModuleList([
            TransformerBlock(config) for i in range(config.num_blocks)
        ])

        self.deembed = DeEmbedding(config)
        # self.deembed.W_D.weight = self.embed.W_E.weight  # tie weights

    def forward(self, x):

        x = self.embed(x)
        # print(f"after embed: {x}")

        for block in self.blocks:
            x = block(x)
        
        x = self.deembed(x)
        # print(f"after deembed: {x}")
        x = torch.softmax(x, dim=-1)

        return x

In [None]:
def main():
    text = "The quick brown fox"
    tokens = tokenizer.encode(text)

    x = torch.tensor(tokens)
    x = x.unsqueeze(0)

    # print(f"---------------{x.shape}")

    model = Transformer(config)

    out = model(x)
    pred_tokens = out.argmax(dim=-1)
    print(f"predicted tokens: {pred_tokens}")

    # Only take the prediction from the last position (next token after "fox")
    next_token = pred_tokens[0, -1].item()
    predicted_word = tokenizer.decode([next_token])
    print(f"predicted word: {predicted_word}")

    print(f"full sentence: {text}{predicted_word}")

    print("sanity check: all predicted tokens")
    for token in pred_tokens.flatten():
        decoded = tokenizer.decode([token.item()])
        print(f"Token {token} -> '{decoded}'")
if __name__ == "__main__":
    main()

predicted tokens: tensor([[39806, 44034, 26845,  8633]])
predicted word:  dict
full sentence: The quick brown fox dict
sanity check: all predicted tokens
Token 39806 -> 'blast'
Token 44034 -> 'avement'
Token 26845 -> 'articles'
Token 8633 -> ' dict'


  tokens = torch.tensor(tokens)


In [None]:
# %% [markdown]
# ## TODOs
# 
# - [ ] Add support for displaying calculated # of learnable and non-learnable params
# - [ ] Attention sink?
