In [199]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
import math
import pandas as pd
import random
from livelossplot import PlotLosses

In [115]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
EMBEDDING_DIM = 256
BATCH_SIZE = 8
CONTEXT_LEN = 128
TEXT_FILE = "./datasets/tinyshakespeare.txt"

In [3]:
tokenizer = tiktoken.encoding_for_model("gpt2")

In [5]:
with open(TEXT_FILE, "r") as f:
    data = f.read() 

In [6]:
tokenized_data = torch.tensor(tokenizer.encode(data))

In [7]:
# partitioning scheme from 
# https://d2l.ai/chapter_recurrent-neural-networks/language-model.html#partitioning-sequences
d = torch.randint(CONTEXT_LEN, size = (1,))[0].item()

# discard last item which may be of diff size
token_partitions_X = torch.stack(torch.split(tokenized_data[d:], CONTEXT_LEN)[:-1])
token_partitions_Y = torch.stack(torch.split(tokenized_data[d+1:], CONTEXT_LEN)[:-1])

In [141]:
class Attention(nn.Module):

    def __init__(self, ctx_len, dropout = 0.5):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # TODO autoregressive mask; generalize this
        mask = torch.tril(torch.ones(ctx_len, ctx_len))
        self.register_buffer('mask', mask)

    def _masked_softmax(self, logits, mask):
        logits.masked_fill_(mask == 0, float('-inf'))
        return self.softmax(logits)

    def forward(self, queries, keys, values):
        assert queries.shape == keys.shape
        assert queries.shape[1] == values.shape[1]

        qk_dim = queries.shape[2]

        # transpose (ctx, embedding) dims
        scaled_dot_prod = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(qk_dim)

        mask = self.get_buffer('mask')
        attention_weights = self._masked_softmax(scaled_dot_prod, mask)

        # The book uses dropout for weights but doesn't explain why that specifically?
        # https://d2l.ai/chapter_attention-mechanisms-and-transformers/attention-scoring-functions.html#scaled-dot-product-attention
        return torch.bmm(self.dropout(attention_weights), values)


In [383]:
class MultiHeadAttention1(nn.Module):

    def __init__(self, num_heads, qkv_dim, ctx_len, dropout = 0.5):
        super().__init__()

        assert qkv_dim % num_heads == 0

        self.qkv_dim = qkv_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)

        self.attention_heads = nn.ModuleList([
            Attention(ctx_len, dropout = dropout)
            for _ in range(num_heads)
        ])

        self.q_linear = nn.LazyLinear(qkv_dim)
        self.k_linear = nn.LazyLinear(qkv_dim)
        self.v_linear = nn.LazyLinear(qkv_dim)
        self.output_layer = nn.LazyLinear(qkv_dim)

        # TODO autoregressive mask; generalize this
        mask = torch.tril(torch.ones(ctx_len, ctx_len))
        self.register_buffer('mask', mask)

        self.softmax = nn.Softmax(dim=-1)

    def _masked_softmax(self, logits, mask):
        logits.masked_fill_(mask == 0, float('-inf'))
        return self.softmax(logits)

    def forward(self, queries, keys, values):
        assert queries.shape == keys.shape
        assert queries.shape == values.shape
        assert queries.shape[-1] == self.qkv_dim

        head_dim = self.qkv_dim // self.num_heads

        # TODO these are just truncated in the book?!
        # https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html#implementation
        # why not use a linear layer to learn a lower dim value?
        pq = torch.split(self.q_linear(queries), head_dim, dim=-1)
        pk = torch.split(self.k_linear(keys), head_dim, dim=-1)
        pv = torch.split(self.v_linear(values), head_dim, dim=-1)

        # TODO use transpose trick to parallelize across heads
        head_vals = [
            self.attention_heads[i](pq[i], pk[i], pv[i])
            for i in range(self.num_heads)
        ]

        cat_head_vals = torch.cat(head_vals, dim=-1)
        return self.dropout(self.output_layer(cat_head_vals))


In [367]:
class MultiHeadAttention2(nn.Module):

    def __init__(self, num_heads, qkv_dim, ctx_len, dropout = 0.5):
        super().__init__()

        assert qkv_dim % num_heads == 0

        self.qkv_dim = qkv_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)

        # self.attention_heads = nn.ModuleList([
        #     Attention(ctx_len, dropout = dropout)
        #     for _ in range(num_heads)
        # ])

        self.q_linear = nn.LazyLinear(qkv_dim)
        self.k_linear = nn.LazyLinear(qkv_dim)
        self.v_linear = nn.LazyLinear(qkv_dim)
        self.output_layer = nn.LazyLinear(qkv_dim)

        # TODO autoregressive mask; generalize this
        mask = torch.tril(torch.ones(ctx_len, ctx_len))
        self.register_buffer('mask', mask)

        self.softmax = nn.Softmax(dim=-1)

    def _masked_softmax(self, logits, mask):
        logits.masked_fill_(mask == 0, float('-inf'))
        return self.softmax(logits)

    def forward(self, queries, keys, values):
        assert queries.shape == keys.shape
        assert queries.shape == values.shape
        assert queries.shape[-1] == self.qkv_dim

        # # TODO these are just truncated in the book?!
        # # https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html#implementation
        # # why not use a linear layer to learn a lower dim value?
        # pq = torch.split(self.q_linear(queries), head_dim, dim=-1)
        # pk = torch.split(self.k_linear(keys), head_dim, dim=-1)
        # pv = torch.split(self.v_linear(values), head_dim, dim=-1)

        # # TODO use transpose trick to parallelize across heads
        # head_vals = [
        #     self.attention_heads[i](pq[i], pk[i], pv[i])
        #     for i in range(self.num_heads)
        # ]

        # cat_head_vals = torch.cat(head_vals, dim=-1)
        # return self.dropout(self.output_layer(cat_head_vals))

        batch_size = queries.shape[0]

        d_k = self.qkv_dim // self.num_heads
        q = self.q_linear(queries).view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        k = self.k_linear(keys).view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)
        v = self.v_linear(values).view(batch_size, -1, self.num_heads, d_k).transpose(1, 2)

        # Attention mechanism
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        mask = self.get_buffer('mask')
        attn = self._masked_softmax(scores, mask)

        # The book uses dropout for weights but doesn't explain why just that specifically?
        # https://d2l.ai/chapter_attention-mechanisms-and-transformers/attention-scoring-functions.html#scaled-dot-product-attention
        output = (
            torch.matmul(self.dropout(attn), v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.qkv_dim)
        )

        return output


In [370]:
class PositionwiseFFN(nn.Sequential):
    def __init__(self, ffn_hidden_dim: int, ffn_output_dim: int):
        super().__init__(
            nn.LazyLinear(ffn_hidden_dim),
            nn.ReLU(),
            nn.LazyLinear(ffn_output_dim)
        )

In [371]:
class AddNorm(nn.Module):

    def __init__(self, norm_shape, dropout = 0.5):
        super().__init__()

        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        return self.ln(X + self.dropout(Y))

In [384]:
class TransformerDecoderBlock(nn.Module):

    def __init__(self, num_heads: int, model_dim: int, ffn_hidden_dim: int, ctx_len: int, dropout = 0.5):
        super().__init__()

        self.attention = MultiHeadAttention1(num_heads, model_dim, ctx_len, dropout)
        self.add_norm1 = AddNorm(model_dim, dropout)
        self.ffn = PositionwiseFFN(ffn_hidden_dim, model_dim)
        self.add_norm2 = AddNorm(model_dim, dropout)

    def forward(self, X):
        attention = self.attention(X, X, X)
        resid_out = self.add_norm1(X, attention)

        ffn_out = self.ffn(resid_out)
        return self.add_norm2(resid_out, ffn_out)

In [385]:
class RalphGPT(nn.Module):

    def __init__(
        self,
        num_blocks: int,
        num_heads: int,
        model_dim: int,
        ffn_hidden_dim: int,
        vocab_size: int,
        context_len: int,
        dropout = 0.5
    ):
        super().__init__()

        self.context_len = context_len
        self.model_dim = model_dim

        # TODO do we need a max_norm? seems like this would be important
        # depending on positional embedding scheme
        self.token_embedder = nn.Embedding(
            num_embeddings = vocab_size,
            embedding_dim = model_dim
        )

        # use learned positional embeddings for simplicity
        # TODO what are the tradeoffs with fixed positional embeddings besides less storage?
        self.positional_embedder = nn.Embedding(
            num_embeddings = context_len,
            embedding_dim = model_dim
        )

        self.decoder = nn.Sequential(*[
            TransformerDecoderBlock(num_heads, model_dim, ffn_hidden_dim, context_len, dropout)
            for _ in range(num_blocks)
        ])

        self.output = nn.LazyLinear(vocab_size)

        context_idx_tensor = torch.tensor(list(range(context_len)))
        self.register_buffer('context_idx_tensor', context_idx_tensor)


    def forward(self, tokens):
        # tokens.shape is (batch size, ctx len)
        assert tokens.shape[-1] == self.context_len

        context_idx_tensor = self.get_buffer('context_idx_tensor')
        positional_embeddings = self.positional_embedder(context_idx_tensor)

        token_embeddings = self.token_embedder(tokens)
        embeddings = token_embeddings + positional_embeddings

        return self.output(self.decoder(embeddings))


In [386]:
ralphgpt = RalphGPT(
    num_blocks = 10,
    num_heads = 4,
    model_dim = EMBEDDING_DIM,
    ffn_hidden_dim = EMBEDDING_DIM,
    vocab_size = tokenizer.n_vocab,
    context_len = CONTEXT_LEN,
    dropout = 0.1
)



In [387]:
ralphgpt.to(device)

RalphGPT(
  (token_embedder): Embedding(50257, 256)
  (positional_embedder): Embedding(128, 256)
  (decoder): Sequential(
    (0): TransformerDecoderBlock(
      (attention): MultiHeadAttention1(
        (dropout): Dropout(p=0.1, inplace=False)
        (attention_heads): ModuleList(
          (0-3): 4 x Attention(
            (softmax): Softmax(dim=-1)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (q_linear): LazyLinear(in_features=0, out_features=256, bias=True)
        (k_linear): LazyLinear(in_features=0, out_features=256, bias=True)
        (v_linear): LazyLinear(in_features=0, out_features=256, bias=True)
        (output_layer): LazyLinear(in_features=0, out_features=256, bias=True)
        (softmax): Softmax(dim=-1)
      )
      (add_norm1): AddNorm(
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      )
      (ffn): PositionwiseFFN(
        (0): LazyLinear(in_features=0, o

In [388]:
batches_x = torch.split(token_partitions_X.to(device), BATCH_SIZE, dim = 0)
batches_y = torch.split(token_partitions_Y.to(device), BATCH_SIZE, dim = 0)

In [389]:
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(ralphgpt.parameters(), lr=0.0005)

In [390]:
num_epochs = 100

In [391]:
gradients = []

In [392]:
model = ralphgpt
liveloss = PlotLosses()
for epoch in range(num_epochs):
    batches = list(zip(batches_x, batches_y))
    random.shuffle(batches)

    losses = []
    for batch_i, batch in enumerate(batches):
        batch_x, batch_y = batch

        pred_y = model(batch_x)

        flat_pred_y = torch.flatten(pred_y, 0, 1)
        flat_batch_y = torch.flatten(batch_y, 0, 1)

        loss = loss_fn(flat_pred_y, flat_batch_y)
        loss_val = loss.item()

        losses.append(loss_val)
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # # Vector to store each gradient's L2 norm squared
        # grads = []

        # for parameter in model.parameters():
        #     if parameter.requires_grad and parameter.grad is not None:
        #         # Add the squared L2 norm of each gradient to the list
        #         grads.append(parameter.grad.data.norm(2).item() ** 2)

        # # Sum up the squared norms and take the square root to get the total L2 norm
        # total_grad_norm = torch.sqrt(torch.tensor(grads).sum())
        # gradients.append(total_grad_norm)

        opt.step()
        opt.zero_grad()

    mean_loss = pd.Series(losses).mean()
    logs = dict(
        loss = mean_loss
    )
    liveloss.update(logs)
    liveloss.send()

    print(f"[epoch: {epoch}] loss", loss_val)
        # val_pred_y = self.model(val_X)
        # val_loss = self.loss_fn(val_pred_y, val_y)

In [286]:
def pad_input(s: str):
    tokens = tokenizer.encode(s)

    pad_len = CONTEXT_LEN - len(tokens)

    if pad_len < 0:
        raise Exception(f"Got {len(tokens)} tokens but context len is {CONTEXT_LEN}")

    padding = pad_len * [0]

    return tokens + padding

In [291]:
def infer(model, s, num_out = 50):

    input_tokens = pad_input(s)
    output_tokens = []

    for _ in range(num_out):
        batch = torch.tensor(input_tokens).unsqueeze(0).to(device)

        logits = model(batch)
        tokens = logits.argmax(-1).squeeze(0).tolist()

        pred_token = tokens[-1]
        output_tokens.append(pred_token)
        input_tokens = input_tokens[:-1]
        input_tokens.append(pred_token)


    return tokenizer.decode(output_tokens)

In [293]:
print(infer(model, "For all the world's a stage"))

 this day the trump day your trump back back back on Thursday ho devil back back, not a word hours am back again are Thursday ho gods th are Thursday am noise we'll be a canopy back, be gods are you say this hard Thursday Henry am
