In [57]:
import torch
import torch.nn as nn
import tiktoken
import math

In [31]:
EMBEDDING_DIM = 256
BATCH_SIZE = 8
CONTEXT_LEN = 1024
TEXT_FILE = "./datasets/tinyshakespeare.txt"

In [32]:
tokenizer = tiktoken.encoding_for_model("gpt2")

In [33]:
with open(TEXT_FILE, "r") as f:
    data = f.read() 

In [34]:
tokenized_data = torch.tensor(tokenizer.encode(data))

In [35]:
# partitioning scheme from 
# https://d2l.ai/chapter_recurrent-neural-networks/language-model.html#partitioning-sequences
d = torch.randint(CONTEXT_LEN, size = (1,))[0].item()

# discard last item which may be of diff size
token_partitons = torch.stack(torch.split(tokenized_data[d:], CONTEXT_LEN)[:-1])

In [36]:
sample_batch = token_partitons[:BATCH_SIZE]

embeddings

In [37]:
# TODO do we need a max_norm? seems like this would be important
# depending on positional embedding scheme
embedder = nn.Embedding(
    num_embeddings = tokenizer.n_vocab,
    embedding_dim = EMBEDDING_DIM
)

In [38]:
# use learned positional embeddings for simplicity
# TODO what are the tradeoffs with fixed positional embeddings besides less storage?
positional_embedder = nn.Embedding(
    num_embeddings = CONTEXT_LEN,
    embedding_dim = EMBEDDING_DIM
)

In [39]:
context_idx_tensor = torch.tensor(list(range(CONTEXT_LEN)))

In [40]:
positional_embeddings = positional_embedder(context_idx_tensor)

In [41]:
token_embeddings = embedder(sample_batch)

In [44]:
embeddings = token_embeddings + positional_embeddings

attention

In [47]:
queries, keys, values = embeddings, embeddings, embeddings

In [59]:
assert queries.shape == keys.shape
assert keys.shape == values.shape

In [None]:
"""Generate a mask for subsequent positions (upper triangular mask)."""
subsequent_mask = torch.triu(torch.ones(), diagonal=1).type(torch.bool)
return subsequent_mask

In [61]:
# transpose (ctx, embedding) dims
scaled_dot_prod = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(EMBEDDING_DIM)

In [77]:
ninf_mask = torch.triu(torch.ones(scaled_dot_prod.shape[1:]), diagonal=1) * float('-inf')
ninf_mask[torch.isnan(ninf_mask)] = 0

In [81]:
softmax = nn.Softmax()

In [101]:
attention_weights = softmax(ninf_mask + scaled_dot_prod)
attention_weights[torch.isnan(attention_weights)] = 0

  return self._call_impl(*args, **kwargs)


In [102]:
dropout = nn.Dropout(0.5)

In [106]:
# The book uses dropout for weights but doesn't explain why that specifically?
# https://d2l.ai/chapter_attention-mechanisms-and-transformers/attention-scoring-functions.html#scaled-dot-product-attention
attention = torch.bmm(dropout(attention_weights), values)