# Transformers from Scratch

We will implement a Transformer from scratch, specifically a decoder-only Transformer. The architecture is based on the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et al. (2017).

The Transformer is a model architecture that has been used in many NLP tasks, such as machine translation, text summarization, and question-answering. It is based on the idea of self-attention, which allows the model to weigh the importance of different words in a sentence when encoding or decoding it.

This has been greatly inspired by Andrej Karpathy's [video on building GPT](https://www.youtube.com/watch?v=kCc8FmEb1nY). While this is a poor imitation at best, I hope to explore some aspects I initially found confusing a bit deeper, to play around with different ways to implement the same thing, and to tinker with some ideas that the video does not go into.

In [1]:
# Import in our libraries
import os
import requests
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

import tiktoken

In [3]:
x = torch.tensor([1.2, 5.3, -0.8, 0.9, float("-inf"), float("-inf"), float("-inf"), float("-inf")])
torch.softmax(x, dim=0)

tensor([0.0161, 0.9698, 0.0022, 0.0119, 0.0000, 0.0000, 0.0000, 0.0000])

## Bringing in our data

We will be using the Tiny Shakespeare dataset, the same one used in the video.

Let's load in our data in the cell below and work on processing it to feed it into our model.

In [4]:
# Download the tiny shakespeare dataset
input_file_path = 'input.txt'

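if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

    # Fetch first and fail loudly on HTTP errors, so a failed download
    # does not leave an empty input.txt behind
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()

    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)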

In [5]:
# Read in the data
with open(input_file_path, 'r', encoding='utf-8') as f:
    text = f.read()

print(text[:250])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.



Now that we've loaded in our data, let's start by tokenizing it.

This can be very easily handled with the `tiktoken` library (read up [here](https://github.com/openai/tiktoken)), which is also what the GPT family uses :p

A tokenizer splits text into tokens (pieces of text, each a word or part of a word) and maps each token to an integer, which is what we need to feed into our model. We will use the same tokenizer as GPT-2 for this task, which has a vocabulary size of 50257: there are 50257 unique tokens in our vocabulary, each represented by its own integer.

In [13]:
tokenizer = tiktoken.get_encoding("gpt2")

# Small example
sample_text = "Hello, world!"
encoded = tokenizer.encode(sample_text)
decoded = tokenizer.decode(encoded)

print(f"Original text: \"{sample_text}\" --> Encoded: {encoded}")
print(f"List of tokens: {[tokenizer.decode([token]) for token in encoded]}")

Original text: "Hello, world!" --> Encoded: [15496, 11, 995, 0]
List of tokens: ['Hello', ',', ' world', '!']


In [31]:
# Let's tokenize our entire dataset
encoded_text = tokenizer.encode(text)
# Convert to torch tensor of int64 (important for later on)
encoded_text = torch.tensor(encoded_text, dtype=torch.long)

print(f"First 10 tokens: {encoded_text[:10]}")
print(f"Length of original dataset: {len(text)}")
print(f"Length of tokenized dataset: {len(encoded_text)}")

First 10 tokens: tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597,  2252,    11])
Length of original dataset: 1115394
Length of tokenized dataset: 338025


Now that we have our textual data in numeric form, we can think about how to process it in a way to feed to our model.

We will start training our model to predict the next token in a sequence of tokens - the essence of language modeling.

We will do this by creating *windows* in our dataset, where each window is a sequence of tokens of a fixed length. We will then train our model to predict the next token in the sequence given the previous tokens.

This means that the input to our model will be a sequence of tokens, and the desired output will simply be that sequence shifted by one token.

In [32]:
# Define the size of the window
ctx_len = 8

# Create a window of size 8 to feed to our model
x = encoded_text[:ctx_len]
y = encoded_text[1: ctx_len+1]

print(f"Context: {x}")
print(f"Target: {y}")
print('-'*80)

for t in range(ctx_len):
    context = x[:t+1]
    target = y[t]
    print(f"Context: {context} --> Target: {target}")

Context: tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597])
Target: tensor([22307,    25,   198,  8421,   356,  5120,   597,  2252])
--------------------------------------------------------------------------------
Context: tensor([5962]) --> Target: 22307
Context: tensor([ 5962, 22307]) --> Target: 25
Context: tensor([ 5962, 22307,    25]) --> Target: 198
Context: tensor([ 5962, 22307,    25,   198]) --> Target: 8421
Context: tensor([ 5962, 22307,    25,   198,  8421]) --> Target: 356
Context: tensor([ 5962, 22307,    25,   198,  8421,   356]) --> Target: 5120
Context: tensor([ 5962, 22307,    25,   198,  8421,   356,  5120]) --> Target: 597
Context: tensor([ 5962, 22307,    25,   198,  8421,   356,  5120,   597]) --> Target: 2252


Now we can start thinking of what our model will actually do.

The model takes a chunk of integers whose length is at most `ctx_len` and predicts the next integer in the sequence. The chunks can be any length up to that limit; we just have to specify the maximum context length for our model.

When we take a chunk of 8 tokens, we don't just predict the token that follows all 8 of them; we train our model to predict **at each and every one of these positions**. This means we get $n$ different training examples from each context of length $n$.

The model is therefore trained on context sizes all the way from 1 to `ctx_len`, which means it can predict the next token and start generating even when it has been given just one token of context.

Now we can begin to gather multiple windows at once to create minibatches to feed into our model.

In [35]:
batch_size = 16

def get_batch():
    """
    Returns a batch of data for training
    """
    # Sample batch_size number of starting indices to create our windows
    idxs = torch.randint(0, len(encoded_text) - ctx_len, (batch_size,))

    # Get our inputs and targets
    x = torch.stack([encoded_text[idx:idx+ctx_len] for idx in idxs])
    y = torch.stack([encoded_text[idx+1:idx+ctx_len+1] for idx in idxs])

    return x, y

xb, yb = get_batch()
print(xb.shape, yb.shape)
print('-'*80)

# Print the examples from one window - this is a single item in our batch
for b in range(1):
    for t in range(ctx_len):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'Context: {context.tolist()} -> Target: {target.item()}')

torch.Size([16, 8]) torch.Size([16, 8])
--------------------------------------------------------------------------------
Context: [21067] -> Target: 11
Context: [21067, 11] -> Target: 3367
Context: [21067, 11, 3367] -> Target: 11
Context: [21067, 11, 3367, 11] -> Target: 466
Context: [21067, 11, 3367, 11, 466] -> Target: 26
Context: [21067, 11, 3367, 11, 466, 26] -> Target: 329
Context: [21067, 11, 3367, 11, 466, 26, 329] -> Target: 356
Context: [21067, 11, 3367, 11, 466, 26, 329, 356] -> Target: 1276


## Language Model

The output of the cell above shows exactly what our model would see and what it would try to predict for each item in the batch.

Our model would take as input a $(B, T)$ tensor and output a $(B, T, V)$ tensor, where $B$ is the batch size, $T$ is the sequence length, and $V$ is the vocabulary size. 

This is because we are predicting the probability distribution of the next token for all $B \times T$ **positions** in the input. This means we will have exactly $B \times T$ different training examples for each batch.
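
As a rough sketch of these shapes, the snippet below pushes a random $(B, T)$ batch through a stand-in model and scores all $B \times T$ positions with cross-entropy. The stand-in (an embedding followed by a linear layer) is an assumption for illustration only, not the Transformer we are building, and the sizes `B`, `T`, `V`, and `n_embd` are hypothetical values kept small so the example stays cheap.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical small sizes for illustration (GPT-2's real vocabulary is 50257)
B, T, V = 4, 8, 100
n_embd = 16

# Stand-in model: token embedding followed by a projection to vocabulary logits
tok_emb = nn.Embedding(V, n_embd)
lm_head = nn.Linear(n_embd, V)

x_demo = torch.randint(0, V, (B, T))   # inputs:  (B, T)
y_demo = torch.randint(0, V, (B, T))   # targets: (B, T)

logits = lm_head(tok_emb(x_demo))      # (B, T, V): one set of scores per position

# F.cross_entropy expects (N, C) logits and (N,) targets,
# so flatten the B*T positions into a single dimension
loss = F.cross_entropy(logits.view(B * T, V), y_demo.view(B * T))

print(logits.shape)  # torch.Size([4, 8, 100])
print(loss.item())
```

Flattening to `(B * T, V)` is what makes each position count as its own training example: the loss is averaged over all 32 positions in this batch.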

## Self-Attention

The whole point of Self-Attention is to get the tokens interacting and talking with one another. The idea is to bake context into the raw representations of the tokens themselves.

The way this is done is to take each token, and to extract three pieces of information from it: the Query, the Key, and the Value. These have the following ideas behind them:

* The Query says "What am I looking for/Here's what I'm interested in..."

* The Key says "What do I contain/This is what I have..."

* The Value says "If you find me interesting, here's what I will communicate to you..."

We get the *affinities between tokens* by taking the dot product between one token's Query and another token's Key.

As an example: if the token at position 8 (whose Query we take) finds that the token at position 4 produces a Key that gives a high dot product with that Query, then the model has learned something important about the meaning of the 8th token. It now knows that to better understand the 8th token, it should also look at the 4th token for context.
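
To make the dot-product idea concrete, here is a small sketch with hand-picked vectors. The numbers are hypothetical and not learned by any model; they only show that the Key pointing in the same direction as the Query gets the highest affinity.

```python
import torch

# Hypothetical 4-dimensional Query for a single token
query = torch.tensor([1.0, 0.0, 2.0, 0.0])

# Hypothetical Keys for three other tokens, one per row
keys = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],    # orthogonal to the query
    [1.0, 0.0, 2.0, 0.0],    # same direction as the query
    [-1.0, 0.0, -1.0, 0.0],  # roughly opposite direction
])

# One dot product per Key: the affinity of the Query with each token
affinities = keys @ query
print(affinities)  # values are 0, 5, and -3

# The token whose Key best matches the Query
best_match = affinities.argmax().item()
print(best_match)  # 1
```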

We extract these pieces by applying learned linear transformations to the input embeddings, which gives us the Query, Key, and Value matrices.
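
A minimal sketch of that projection step, under some assumptions: the sizes below are hypothetical, the random tensor `emb` stands in for real token embeddings, and using bias-free `nn.Linear` layers is one common choice rather than a requirement.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: batch, sequence length, embedding width, head width
B, T, n_embd, head_size = 2, 8, 32, 16

# Stand-in for token embeddings; in practice these come from an embedding table
emb = torch.randn(B, T, n_embd)

# One linear layer per role (bias=False is an assumption made here)
query_proj = nn.Linear(n_embd, head_size, bias=False)
key_proj = nn.Linear(n_embd, head_size, bias=False)
value_proj = nn.Linear(n_embd, head_size, bias=False)

q = query_proj(emb)   # (B, T, head_size)
k = key_proj(emb)     # (B, T, head_size)
v = value_proj(emb)   # (B, T, head_size)

# Dot product of every Query with every Key:
# a (T, T) grid of affinities for each item in the batch
affinities = q @ k.transpose(-2, -1)   # (B, T, T)

print(q.shape, k.shape, v.shape)
print(affinities.shape)  # torch.Size([2, 8, 8])
```

Entry `affinities[b, i, j]` is the dot product between the Query of token `i` and the Key of token `j` in batch item `b`, which is the affinity described above.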