# Module 2 Project 2: T5

Implement a [T5 model](https://blog.research.google/2020/02/exploring-transfer-learning-with-t5.html) using everything that has been covered in previous modules so far
(Tokenization, Attention, Decoder layers)

## STEP 1: IMPORTS AND HYPERPARAMETERS
- Outside of the usual `torch` imports, we need `math` for `sqrt` and `requests` to fetch our data
- For our hyperparameters, we use a dropout of 0.1 and a `batch_size` of 32
- Our `chunk_size` (128) is the context length: each training example in a batch is a window of 128 consecutive tokens

In [None]:
import requests

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

import math

# --- Regularization ---
dropout = 0.1

# --- Data sampling: tokens per training window, windows per batch ---
chunk_size = 128
batch_size = 32

# --- Model size ---
num_layers = 4
embed_dim = 64
num_heads = 4

# --- Training / generation ---
num_epochs = 500      # training iterations; each one samples a single random batch
eval_interval = 20    # report losses every this many iterations
max_length = 1000     # number of tokens to sample when generating

## STEP 2: LOAD THE DATA
- We will use Project Gutenberg to load our data
- I chose [Treasure Island](https://www.gutenberg.org/cache/epub/120) for this project
- After splitting out the text of the book, print the first 1000 characters

In [None]:
# Download the plain-text edition of Treasure Island from Project Gutenberg.
resp = requests.get("https://www.gutenberg.org/cache/epub/120/pg120.txt", timeout=30)
# Fail loudly on HTTP errors instead of silently training on an error page.
resp.raise_for_status()
# Decode as UTF-8 explicitly rather than relying on requests' header-based guess
# (which falls back to ISO-8859-1 for text/* responses without a charset).
resp.encoding = "utf-8"
raw_text = resp.text

start = "*** START OF THE PROJECT GUTENBERG EBOOK TREASURE ISLAND ***"
end = "*** END OF THE PROJECT GUTENBERG EBOOK TREASURE ISLAND ***"

start_idx = raw_text.find(start)
end_idx = raw_text.find(end)
# str.find returns -1 when a marker is missing, which would produce a nonsense
# slice; stop with a clear error instead.
if start_idx == -1 or end_idx == -1 or end_idx < start_idx:
    raise ValueError("Could not locate the Project Gutenberg START/END markers in the downloaded text")

# Keep only the book body: skip past the START marker line itself and stop at END.
result = raw_text[start_idx + len(start):end_idx].strip()

print(result[:1000])

## STEP 3: TOKENIZATION
- We will be re-implementing our Tokenizer from [last module's projects](https://github.com/samherring99/NightwingCurriculum/blob/main/module_1_nlp_basics/module_1_project_1.ipynb)
- It starts with a base vocabulary of the 256 byte values, then performs a fixed number of BPE merges (230 here) to build a 486-token vocabulary

In [None]:
class Tokenizer:
    """Byte-level BPE (byte-pair encoding) tokenizer.

    Text is first turned into its UTF-8 bytes (ids 0-255). Training then
    repeatedly replaces the most frequent adjacent pair of ids with a new id
    (256, 257, ...), recording each merge.
    """

    def __init__(self):
        self.merges = {}  # (left_id, right_id) -> merged id, in the order learned
        self.vocab = {}   # id -> the bytes that id decodes to

    def get_pair_counts(self, token_ids):
        """Return a dict mapping each adjacent pair in token_ids to its occurrence count."""
        counts = {}
        for pair in zip(token_ids, token_ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def new_token(self, token_ids, pair, index):
        """Return a new list with each occurrence of `pair` replaced by `index`.

        Matches are taken greedily left to right and never overlap.
        """
        new_ids = []
        i = 0
        while i < len(token_ids):
            if i < len(token_ids) - 1 and token_ids[i] == pair[0] and token_ids[i+1] == pair[1]:
                new_ids.append(index)
                i += 2
            else:
                new_ids.append(token_ids[i])
                i += 1
        return new_ids

    def train(self, text, vocab_size, verbose=False):
        """Learn up to `vocab_size - 256` merges from `text`.

        Args:
            text: training corpus.
            vocab_size: target vocabulary size; must be at least 256.
            verbose: if True, print each merge as it is learned.

        Training stops early when fewer than two tokens remain to pair. In that
        case the resulting vocabulary is smaller than `vocab_size`.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        token_ids = list(text.encode("utf-8"))

        merges = {}
        vocab = {index: bytes([index]) for index in range(256)}

        for i in range(num_merges):
            pair_counts = self.get_pair_counts(token_ids)
            if not pair_counts:
                # Nothing left to merge (empty/short text or everything already
                # merged). max() would raise on an empty dict here.
                break
            # Most frequent pair; ties go to the pair seen first.
            pair = max(pair_counts, key=pair_counts.get)
            index = 256 + i
            token_ids = self.new_token(token_ids, pair, index)
            merges[pair] = index
            vocab[index] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {index} "
                      f"({vocab[index]!r}, {pair_counts[pair]} occurrences)")

        self.merges = merges
        self.vocab = vocab

    def encode(self, text):
        """Encode `text` into a list of token ids using the learned merges."""
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = self.get_pair_counts(tokens)
            # Apply the earliest-learned merge present; unknown pairs rank as inf.
            pair = min(stats, key=lambda x: self.merges.get(x, float("inf")))
            if pair not in self.merges:
                break
            index = self.merges[pair]
            tokens = self.new_token(tokens, pair, index)
        return tokens

    def decode(self, token_ids):
        """Decode token ids back to text, replacing invalid UTF-8 with U+FFFD."""
        tokens = b"".join(self.vocab[index] for index in token_ids)
        text = tokens.decode("utf-8", errors='replace')
        return text


token = Tokenizer()
# 256 base byte tokens + 230 learned BPE merges = a 486-entry vocabulary.
token.train(result, vocab_size=256 + 230)

vocab = token.vocab

print(token.merges)
# Round-trip a sample string as a quick sanity check of encode/decode.
print(token.decode(token.encode("Hello! this is a text string!")))


## STEP 4: SELF-ATTENTION
- We will be re-implementing our MultiHeadedAttention class from the last project in this module
- The MHA implementation remains the same as when we did it last, nothing has changed
- See Module 2 Project 1 for an in-depth explanation of attention and multi-headed attention

In [None]:
class AttentionHead(nn.Module):
    """A single causal self-attention head.

    Projects the input to keys, queries and values of width `head_size`. Each
    position attends only to itself and earlier positions, and the head returns
    the attention-weighted sum of values.

    Reads the module-level hyperparameters `embed_dim` (input width),
    `chunk_size` (largest sequence length the mask supports) and `dropout`.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # Lower-triangular ones: entry (i, j) is 1 when position i may attend to j (j <= i).
        self.register_buffer('tril', torch.tril(torch.ones(chunk_size, chunk_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal attention to `x` of shape (B, T, embed_dim).

        Returns a tensor of shape (B, T, head_size). T must not exceed
        `chunk_size`, because the mask buffer is only that large.
        """
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)

        # Scaled dot-product scores. Scale by the key width (head_size), not the
        # input width, so the logits keep roughly unit variance.
        weights = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        # Normalize over the key axis (last dim) so each query's weights sum to 1.
        # Normalizing over dim=1 (the query axis) would mix in scores computed from
        # later positions and leak future tokens into earlier outputs.
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        v = self.value(x)
        return weights @ v

class MultiHeadedAttention(nn.Module):
    """Run several AttentionHeads in parallel and mix their outputs.

    Every head sees the same input. Their outputs are concatenated along the
    feature axis, passed through a linear projection, then through dropout.

    NOTE: the projection is sized from the module-level `embed_dim`, so
    `num_heads * head_size` must equal it.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [AttentionHead(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.projection = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)
        projected = self.projection(concatenated)
        return self.dropout(projected)

## STEP 5: CROSS-ATTENTION
- Here, we need to implement cross-attention. The difference is that we 'attend' to the outputs of our Encoder layers in T5; we are no longer 'self-attending' across the same input for Q, K, and V
- Our Q still comes from the decoder's own input, but our K and V values now come from the outputs of the Encoder layers (more on this below)
- Attention across these inputs is calculated in a similar way. Note that the constructor's `embed_dim` argument is the *inner* attention width: the Linear layers map inputs of width `4 * embed_dim` (which must match the model's width) down to it, and `out_projection` maps the result back up

In [None]:
class MultiHeadedCrossAttention(nn.Module):
    """Multi-head attention with separate query and key/value inputs.

    Queries come from `query`, and keys/values from `key`/`value` (e.g. the
    encoder output). The inputs are processed in five steps:
    - projected from the model width down to an inner width of `embed_dim`
    - split into `num_heads` heads
    - attended
    - merged back together
    - projected back up to the model width

    Args:
        num_heads: number of attention heads; must divide `embed_dim`.
        embed_dim: inner attention width, summed over all heads.
        model_dim: width of the query/key/value inputs and of the output.
            Defaults to `4 * embed_dim`.

    Reads the module-level hyperparameter `dropout`.
    """

    def __init__(self, num_heads, embed_dim, model_dim=None):
        super().__init__()
        if model_dim is None:
            model_dim = 4 * embed_dim
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads

        self.query = nn.Linear(model_dim, embed_dim, bias=False)
        self.key = nn.Linear(model_dim, embed_dim, bias=False)
        self.value = nn.Linear(model_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(embed_dim, model_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from `query` (B, T_q, model_dim) over `key`/`value` (B, T_k, model_dim).

        Args:
            mask: optional tensor broadcastable to (B, num_heads, T_q, T_k).
                Positions where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape (B, T_q, model_dim).
        """
        B, T_q, _ = query.size()
        _, T_k, _ = key.size()

        # Project, then split the inner width into heads:
        # (B, T, embed_dim) -> (B, num_heads, T, head_size)
        q = self.query(query).view(B, T_q, self.num_heads, self.head_size).transpose(1, 2)
        k = self.key(key).view(B, T_k, self.num_heads, self.head_size).transpose(1, 2)
        v = self.value(value).view(B, T_k, self.num_heads, self.head_size).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_size)  # (B, H, T_q, T_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Normalize over the key axis.
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back: (B, H, T_q, head_size) -> (B, T_q, embed_dim)
        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(B, T_q, -1)
        return self.out_projection(context)

## STEP 6: FEED FORWARD
- Here, our Feed Forward network is implemented in much the same way as the last project on Transformers
- A linear layer projects our embedding to 4x, performs ReLU activation, and reduces the embedding dimension back to `embed_dim`

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, apply ReLU, project back, then dropout.

    Applied independently at every position. Reads the module-level
    hyperparameter `dropout`.
    """

    def __init__(self, embed_dim):
        super().__init__()
        hidden_dim = 4 * embed_dim
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

## STEP 7: ENCODER LAYER
- Our encoder layer is written much the same as our `BuildingBlocks` in the Transformers project previously
- We have our MHA implementation, our Feed Forward network, and 2 layer normalizations, each one performed before MHA and Feed Forward respectively.

In [None]:
class EncoderLayer(nn.Module):
    """Pre-norm encoder block: self-attention, then a feed-forward network.

    Each sub-layer normalizes its input first and adds its output back through
    a residual connection.

    NOTE(review): MultiHeadedAttention's heads apply a lower-triangular (causal)
    mask, so this encoder is causal rather than bidirectional as in the T5 paper.
    Those modules also size their projections from the module-level
    `embed_dim`, not from this constructor's `embed_dim` argument.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = MultiHeadedAttention(num_heads, embed_dim // num_heads)
        self.feed_forward = FeedForward(embed_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        # `mask` is accepted for interface symmetry with DecoderLayer but is unused.
        after_attention = x + self.attention(self.layer_norm1(x))
        return after_attention + self.feed_forward(self.layer_norm2(after_attention))

## STEP 8: DECODER LAYER
- Here, our decoder layer needs to 'cross-attend' to our encoder layer's outputs
- It's written out very much the same way, but with `CrossAttention` instead of our usual MHA
- The `x + ` operator below performs our residual connection to help deal with the vanishing gradient problem
- Here, layer normalization is applied to the inputs of cross-attention and of our feed forward layer (pre-norm), just as in the encoder

In [None]:
class DecoderLayer(nn.Module):
    """Pre-norm decoder block: cross-attention to the encoder output, then feed-forward.

    Each sub-layer normalizes its input first and adds its output back through
    a residual connection. This block has no self-attention over the decoder
    sequence, so decoder positions only exchange information through
    cross-attention over `encoder_output`.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # MultiHeadedCrossAttention's second argument is its *inner* attention
        # width: its Linear layers map 4x that width down to it and back up.
        # Passing embed_dim // 4 makes its input/output width equal embed_dim for
        # any num_heads, whereas embed_dim // num_heads only does so when
        # num_heads == 4.
        if embed_dim % 4 != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by 4")
        self.attention = MultiHeadedCrossAttention(num_heads, embed_dim // 4)
        self.feed_forward = FeedForward(embed_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x, encoder_output, mask=None):
        """Cross-attend from `x` to `encoder_output`, then apply the feed-forward sub-layer.

        Args:
            x: decoder hidden states, presumably (B, T_q, embed_dim).
            encoder_output: encoder hidden states, used as both keys and values.
            mask: optional cross-attention mask. It is forwarded unchanged, and
                entries equal to 0 are blocked.

        Returns:
            Updated decoder hidden states, same shape as `x`.
        """
        cross_attention_output = self.attention(self.layer_norm1(x), encoder_output, encoder_output, mask)
        x = x + cross_attention_output
        ff_output = self.feed_forward(self.layer_norm2(x))
        x = x + ff_output

        return x

## STEP 9: T5 MODEL
- Now we can put everything together
- Using `num_layers` we build out our Encoder and Decoder blocks
- We cap these off with a linear projection to our `vocab_size` to get logits for the next token
- In `forward`, the input ids run through the encoder stack and the target ids run through the decoder stack. Each decoder layer cross-attends from the decoder's hidden states to the final encoder output

In [None]:
class T5(nn.Module):
    """Encoder-decoder language model built from EncoderLayer and DecoderLayer stacks.

    The encoder and decoder share one token embedding table. The final decoder
    state is layer-normalized and projected to vocabulary logits.

    NOTE(review): no positional encoding is added to the token embeddings.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.layer_norm_f = nn.LayerNorm(embed_dim)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])

        self.decoder_layers = nn.ModuleList([
            DecoderLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])

        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, input_ids, target_ids, mask=None):
        """Encode `input_ids`, decode `target_ids` against it, and return logits.

        Args:
            input_ids: token ids for the encoder, presumably (B, T_in) long tensor.
            target_ids: token ids for the decoder, presumably (B, T_out).
            mask: optional cross-attention mask passed to every decoder layer
                (entries equal to 0 are blocked). The default None applies no mask.

        Returns:
            Logits over the vocabulary for every decoder position, i.e.
            (B, T_out, vocab_size) for (B, T_out) target ids.
        """
        # Keep activations in locals; storing them on `self` would hold tensors
        # (and their autograd graph) alive after the call returns.
        encoder_state = self.embedding(input_ids)
        decoder_state = self.embedding(target_ids)

        for layer in self.encoder_layers:
            encoder_state = layer(encoder_state)

        for layer in self.decoder_layers:
            decoder_state = layer(decoder_state, encoder_state, mask)

        return self.fc(self.layer_norm_f(decoder_state))

## STEP 10: TRAINING LOOP
- Here we implement a common training loop using CrossEntropyLoss and Adam for our optimizer
- We set a learning rate of 0.001 arbitrarily
- We set our initialization parameters for our model to be 4 layers, 4 heads in each layer, with an `embed_dim` of 64
- We train for 500 iterations (each "epoch" here is a single random batch) and generate 1000 tokens as a test output when training completes

In [None]:
vocab_size = len(vocab)

model = T5(vocab_size, embed_dim, num_layers, num_heads)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Encode the whole book once and hold out the last 10% for validation.
data = torch.tensor(token.encode(result), dtype=torch.long)
n = int(0.9 * len(data))

train_data = data[:n]
val_data = data[n:]


def get_batch(split):
    """Sample `batch_size` random windows of `chunk_size` tokens from `split`.

    Returns `(inputs, targets)`, each of shape (batch_size, chunk_size).
    `targets` is `inputs` shifted one token to the left, i.e. the next token
    at every position.
    """
    starts = torch.randint(len(split) - chunk_size, (batch_size,))
    inputs = torch.stack([split[i:i + chunk_size] for i in starts])
    targets = torch.stack([split[i + 1:i + chunk_size + 1] for i in starts])
    return inputs, targets


# Each "epoch" here is a single optimization step on one random batch.
for epoch in range(num_epochs):
    model.train()
    input_ids, target_ids = get_batch(train_data)

    optimizer.zero_grad()
    # Feed the decoder the same context the generation code uses
    # (`model(context, context)`). Decoder position t receives token t and is
    # trained to predict token t + 1. Passing target_ids to the decoder would
    # hand it the very token it must predict at each position.
    # NOTE(review): T5.forward passes no mask to the decoder's cross-attention,
    # so decoder position t can still attend to encoder positions > t, which
    # contain token t + 1. A causal cross-attention mask is needed to close
    # that leak.
    output = model(input_ids, input_ids)
    loss = criterion(output.view(-1, vocab_size), target_ids.view(-1))

    loss.backward()
    optimizer.step()

    if epoch % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            val_input_ids, val_target_ids = get_batch(val_data)
            val_output = model(val_input_ids, val_input_ids)
            val_loss = criterion(val_output.view(-1, vocab_size), val_target_ids.view(-1)).item()
        # CrossEntropyLoss already averages over every token in the batch, so
        # both values are reported as-is.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Validation Loss: {val_loss:.4f}")

print("Finished!")

## STEP 11: GENERATE TEXT
- Here we evaluate our trained model to generate sample text
- This should be nonsense, just confirming our model works for generating text from token sequences

In [None]:
model.eval()
with torch.no_grad():
    # Seed the sequence with the prompt, as a batch of one.
    input_ids = torch.tensor(token.encode("The ")).unsqueeze(0)

    for _ in range(max_length):
        # Crop to the last chunk_size tokens: the attention masks are only that large.
        context = input_ids[:, -chunk_size:]
        # The same context feeds the encoder and the decoder. Keep only the
        # final position's logits, i.e. the next-token distribution.
        logits = model(context, context)[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        next_index = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_index], dim=1)

    print(token.decode(input_ids[0].tolist()))