In [113]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
import math

In [31]:
SEED = 0
EMBEDDING_DIM = 256
BATCH_SIZE = 8
CONTEXT_LEN = 1024
TEXT_FILE = "./datasets/tinyshakespeare.txt"

# Seed torch's global RNG so the random partition offset, embedding
# initialization and dropout masks are reproducible under Restart & Run All.
torch.manual_seed(SEED);

In [32]:
# GPT-2 byte-pair encoding; this is the encoding tiktoken maps the "gpt2" model to.
tokenizer = tiktoken.get_encoding("gpt2")

In [33]:
# Read the whole corpus as one string. The encoding is pinned so the result
# does not depend on the platform's default locale encoding.
with open(TEXT_FILE, "r", encoding="utf-8") as f:
    data = f.read()

In [34]:
# Encode the corpus into GPT-2 token ids, stored as a 1-D int64 tensor.
tokenized_data = torch.tensor(tokenizer.encode(data), dtype=torch.long)

In [35]:
# Partitioning scheme from
# https://d2l.ai/chapter_recurrent-neural-networks/language-model.html#partitioning-sequences
# Start at a random offset in [0, CONTEXT_LEN), then cut the remaining tokens
# into consecutive CONTEXT_LEN-long sequences.
d = torch.randint(CONTEXT_LEN, size=(1,)).item()

# torch.split produces a shorter final chunk only when the length is not a
# multiple of CONTEXT_LEN. Keep exactly the full-length chunks, so a complete
# final chunk is not thrown away when the length divides evenly.
token_partitons = torch.stack([
    chunk
    for chunk in torch.split(tokenized_data[d:], CONTEXT_LEN)
    if len(chunk) == CONTEXT_LEN
])

In [36]:
# Take the leading BATCH_SIZE sequences as a fixed sample batch for the cells below.
sample_batch = token_partitons[0:BATCH_SIZE]

## Embeddings

In [37]:
# Lookup table mapping each token id in the GPT-2 vocabulary to an
# EMBEDDING_DIM-sized learned vector.
# TODO: is a max_norm needed here? It may matter depending on the
# positional embedding scheme.
embedder = nn.Embedding(tokenizer.n_vocab, EMBEDDING_DIM)

In [38]:
# Learned positional embeddings (one vector per position in the context
# window) rather than fixed ones, chosen for simplicity.
# TODO: besides storage, what are the tradeoffs versus fixed positional embeddings?
positional_embedder = nn.Embedding(CONTEXT_LEN, EMBEDDING_DIM)

In [39]:
# Position indices 0 .. CONTEXT_LEN - 1 (int64), built directly without a Python list.
context_idx_tensor = torch.arange(CONTEXT_LEN)

In [40]:
# One learned vector per position: shape (CONTEXT_LEN, EMBEDDING_DIM).
positional_embeddings = positional_embedder(input=context_idx_tensor)

In [41]:
# Look up every token of the sample batch: shape (batch, CONTEXT_LEN, EMBEDDING_DIM).
token_embeddings = embedder(input=sample_batch)

In [44]:
# Add each position's embedding to the matching token embedding; the
# (seq, EMBEDDING_DIM) positional table broadcasts over the batch dimension.
# Slicing to the batch's sequence length also supports sequences shorter
# than CONTEXT_LEN (identical result when they are exactly CONTEXT_LEN long).
embeddings = token_embeddings + positional_embeddings[: token_embeddings.shape[1]]

## Attention

In [153]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with a causal (lower-triangular) mask.

    Args:
        dropout: probability for the ``nn.Dropout`` applied to the attention
            weights before they are used to mix the values.
    """

    def __init__(self, dropout = 0.5):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

    def _masked_softmax(self, logits, mask):
        """Softmax over the last dim after setting entries where ``mask == 0`` to -inf.

        Returns a new tensor; ``logits`` itself is not modified.
        """
        return self.softmax(logits.masked_fill(mask == 0, float('-inf')))

    def forward(self, queries, keys, values):
        """Let each query position attend to itself and all earlier key positions.

        Args:
            queries, keys: 3-D ``(batch, ctx_len, qk_dim)`` tensors of identical shape.
            values: ``(batch, ctx_len, v_dim)``; dims 0-1 must match ``queries``.

        Returns:
            ``(batch, ctx_len, v_dim)`` attention-weighted sums of ``values``.
        """
        assert queries.shape == keys.shape
        assert queries.shape[1] == values.shape[1]

        ctx_len = queries.shape[1]
        qk_dim = queries.shape[2]

        # (batch, ctx, d) @ (batch, d, ctx) -> (batch, ctx, ctx) scores, scaled by
        # sqrt(d) so their magnitude does not grow with the feature size.
        scaled_dot_prod = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(qk_dim)

        # Autoregressive mask: query i may only see keys 0..i. It is built on the
        # inputs' device so the module also works when run on a GPU.
        # TODO generalize to other mask types (e.g. padding masks).
        mask = torch.tril(
            torch.ones(ctx_len, ctx_len, dtype=torch.bool, device=queries.device)
        )

        attention_weights = self._masked_softmax(scaled_dot_prod, mask)

        # Dropout on the attention weights (as in d2l) randomly removes individual
        # query->key links during training; nn.Dropout is a no-op in eval mode.
        # https://d2l.ai/chapter_attention-mechanisms-and-transformers/attention-scoring-functions.html#scaled-dot-product-attention
        return torch.bmm(self.dropout(attention_weights), values)


In [186]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent causal ``Attention`` heads.

    Queries, keys and values are each passed through a learned linear layer
    of output size ``qkv_dim``. The result is split into ``num_heads`` slices
    of ``qkv_dim // num_heads`` features, and each slice is attended by its own
    head. The head outputs are concatenated and passed through an output
    projection.

    Args:
        num_heads: number of attention heads; must divide ``qkv_dim``.
        qkv_dim: feature size of the inputs and of every projection.
        dropout: attention-weight dropout probability used by each head.
    """

    def __init__(self, num_heads, qkv_dim, dropout = 0.5):
        super().__init__()

        assert qkv_dim % num_heads == 0

        self.qkv_dim = qkv_dim
        self.num_heads = num_heads
        # Kept for reference; each head owns its own nn.Dropout module.
        self.dropout = dropout

        # nn.ModuleList (not a plain Python list) registers the heads as
        # submodules, so .train()/.eval() and .to(device) reach them. With a
        # plain list, each head's dropout would stay active even after .eval().
        self.attention_heads = nn.ModuleList([
            Attention(dropout = dropout)
            for _ in range(num_heads)
        ])

        self.q_linear = nn.LazyLinear(qkv_dim)
        self.k_linear = nn.LazyLinear(qkv_dim)
        self.v_linear = nn.LazyLinear(qkv_dim)
        self.output_layer = nn.LazyLinear(qkv_dim)

    def forward(self, queries, keys, values):
        """Run every head on its slice of the projected inputs and mix the results.

        Args:
            queries, keys, values: tensors of identical shape whose last dim is
                ``qkv_dim``. Presumably ``(batch, ctx_len, qkv_dim)``, since each
                head uses ``torch.bmm``.

        Returns:
            Tensor with the same shape as ``queries``.
        """
        assert queries.shape == keys.shape
        assert queries.shape == values.shape
        assert queries.shape[-1] == self.qkv_dim

        head_dim = self.qkv_dim // self.num_heads
        # Splitting one learned qkv_dim-wide projection into num_heads contiguous
        # slices is equivalent to giving each head its own learned projection to
        # head_dim features (each slice comes from a disjoint set of weight rows),
        # so no information is truncated away.
        pq = torch.split(self.q_linear(queries), head_dim, dim=-1)
        pk = torch.split(self.k_linear(keys), head_dim, dim=-1)
        pv = torch.split(self.v_linear(values), head_dim, dim=-1)

        # TODO use transpose trick to parallelize across heads
        head_vals = [
            head(q, k, v)
            for head, q, k, v in zip(self.attention_heads, pq, pk, pv)
        ]

        cat_head_vals = torch.cat(head_vals, dim=-1)
        return self.output_layer(cat_head_vals)

In [187]:
# 8 heads, each working on EMBEDDING_DIM // 8 features of the projected inputs.
mha = MultiHeadAttention(num_heads=8, qkv_dim=EMBEDDING_DIM)

In [188]:
# Smoke-test self-attention on the sample batch. Display only the output shape
# (it matches `embeddings`) instead of dumping the whole tensor into the notebook.
mha_output = mha(embeddings, embeddings, embeddings)
mha_output.shape