# Transformer Architecture

## Setup

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# read it in to inspect it
with open('tinyshakespeare.txt', 'r', encoding='utf-8') as f:
  text = f.read()

In [4]:
print("length of dataset in characters:", len(text))

length of dataset in characters: 1115394


In [5]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [6]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print("Vocab size:", vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size: 65


In [7]:
# create a mapping from characters to integers
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode([46, 47, 47, 1, 58, 46, 43, 56, 43]))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


**Note:** This tokenizer (encoder/decoder) maps each character to a unique integer. In practice, people usually use **sub-word level tokenizers**.
  - They have **larger vocabulary sizes**, but use **fewer tokens** to represent a given text.
  - There is a **tradeoff** between vocabulary size and the length of the encoded sequence.
  - Popular **sub-word tokenizers** are [sentencepiece](https://github.com/google/sentencepiece) and [tiktoken](https://github.com/openai/tiktoken).
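
The tradeoff can be made concrete with a toy byte-pair-style merge: each merge adds one entry to the vocabulary and shortens the encoded sequence. This is only an illustrative sketch, not the actual algorithm used by sentencepiece or tiktoken; the helper `merge_most_frequent_pair` and the sample string are made up for this example.

```python
from collections import Counter

def merge_most_frequent_pair(tokens):
    """Replace every occurrence of the most frequent adjacent pair with one merged token."""
    pairs = Counter(zip(tokens, tokens[1:]))
    (a, b), _count = pairs.most_common(1)[0]
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
            merged.append(a + b)  # the pair becomes a single, longer token
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

tokens = list("hii there, hii there")  # start from character-level tokens
base_vocab_size = len(set(tokens))
lengths = [len(tokens)]
print(f"vocab size={base_vocab_size:2d}  sequence length={len(tokens):2d}")
for n_merges in range(1, 5):
    tokens = merge_most_frequent_pair(tokens)
    lengths.append(len(tokens))
    # each merge adds (at most) one new vocabulary entry
    print(f"vocab size={base_vocab_size + n_merges:2d}  sequence length={len(tokens):2d}")
```

Every merge trades one extra vocabulary entry for a shorter sequence, which is the same tradeoff that real sub-word tokenizers make at a much larger scale.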

In [8]:
data = torch.tensor(encode(text), dtype=torch.int64)
print(f"{data.shape=}")
data[:100]

data.shape=torch.Size([1115394])


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

In [9]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [10]:
block_size = 8
print("example sequence:", train_data[:block_size+1], "\n")
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f"when input is {context} the target: {target}")

example sequence: tensor([18, 47, 56, 57, 58,  1, 15, 47, 58]) 

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


- The transformer learns to predict the target given a context of any length from 1 up to `block_size`.
- This is not just for efficiency. It is in fact necessary for inference, since the transformer needs to be able to generate text from a context as short as one token and then keep extending it.

In [11]:

torch.manual_seed(1337)
batch_size = 4 # number of independent sequences we will process in parallel
block_size = 8 # maximum context length for prediction

def get_batch(data):
  ix = torch.randint(len(data)-block_size, size=(batch_size,))
  x = torch.stack([data[i : i+block_size] for i in ix])
  y = torch.stack([data[i+1 : i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x, y

xb, yb = get_batch(train_data)
print('inputs:')
print(xb)
print('targets:')
print(yb)

inputs:
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]], device='cuda:0')
targets:
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]], device='cuda:0')


Each batch contains $4 \times 8$ training examples.
- given `[24]` predict `43`
- given `[24, 43]` predict `58`
- etc...
- given `[25, 17, 27, 10,  0, 21,  1, 54]` predict `39`

Each batch has 2 dimensions. The first dimension is the *batch* dimension of size `4`, and the second dimension is the *time* dimension of size `8`.

## "Neural Network" Bigram Model

In [12]:
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
  
  def __init__(self, vocab_size:int) -> None:
    super().__init__()
    # the embedding for each token is simply the logits for the next token
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx:torch.Tensor, targets:torch.Tensor|None = None) -> tuple[torch.Tensor, torch.Tensor|None]:
    # idx and targets are both (B,T) tensors of integers
    logits = self.token_embedding_table(idx) # (B, T, C)
    
    if targets is None:
      loss = None
    else:
      B, T, C = logits.size()
      # merge the batch and time dimension since F.cross_entropy expects only one "batch" dimension
      loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
    
    return logits, loss
  
  def generate(self, idx:torch.Tensor, max_new_tokens:int) -> torch.Tensor:
    # idx is (B,T) tensor of integers of current context
    for _ in range(max_new_tokens):
      logits, _loss = self(idx) # (B, T, C)
      # we only care about the predicted token for the last time-step
      logits = logits[:, -1, :] # (B, C)
      probs = F.softmax(logits, dim=1)
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

Currently the `generate` function is wildly inefficient for a simple bigram model: it recomputes the predictions for the full context at every step, even though it only needs the previous token to predict the next one. This is on purpose. The later models will need the whole context, and this way we won't have to write a new `generate` function.
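
As a small check of that claim, the sketch below uses a standalone embedding table (a stand-in for `token_embedding_table`, with a made-up random context) and compares the last-step logits computed from the full context with the logits looked up from the last token alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 65  # same size as the character vocabulary above
table = nn.Embedding(vocab_size, vocab_size)  # plays the role of token_embedding_table

idx = torch.randint(vocab_size, (1, 8))  # an arbitrary (B, T) context of 8 tokens

logits_full = table(idx)[:, -1, :]  # process the whole context, keep the last time-step
logits_last = table(idx[:, -1])     # look up only the last token

print(torch.equal(logits_full, logits_last))  # True
```

For a bigram model the two are identical, so passing the whole context is redundant work here. It only starts to matter once the model actually uses earlier tokens.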

In [13]:
torch.manual_seed(1337)

m = BigramLanguageModel(vocab_size).to(device)

logits, loss = m(xb, yb)
print(f"{logits.shape=}")
print(f"{loss=}")

context_idx = torch.zeros((1, 1), dtype=torch.int64).to(device)
print(decode(m.generate(context_idx, max_new_tokens=100)[0].tolist()))

logits.shape=torch.Size([4, 8, 65])
loss=tensor(4.8786, device='cuda:0', grad_fn=<NllLossBackward0>)

pYCXxfRkRZd
wc'wfNfT;OLlTEeC K
jxqPToTb?bXAUG:C-SGJO-33SM:C?YI3a
hs:LVXJFhXeNuwqhObxZ.tSVrddXlaSZaNe
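
A useful sanity check on the initial loss is the cross-entropy of a model that assigns equal probability to all 65 characters. The sketch below computes that baseline; it assumes nothing beyond the vocabulary size printed earlier.

```python
import math

vocab_size = 65
uniform_loss = -math.log(1 / vocab_size)  # cross-entropy of a uniform guess
print(f"{uniform_loss:.4f}")  # 4.1744
```

The untrained model's loss of 4.8786 is somewhat worse than this baseline, which is plausible because its randomly initialized logits are not uniform.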


In [14]:
model = BigramLanguageModel(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) # you can get away with larger lr for smaller networks

In [15]:
batch_size = 32
for step in range(10_000):
  
  xb, yb = get_batch(train_data)
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  
  if step%1000==0:
    print(loss.item())

4.598329067230225
3.657761812210083
3.120805263519287
2.6083383560180664
2.5169241428375244
2.612285614013672
2.517780303955078
2.4636025428771973
2.619422197341919
2.6189043521881104
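
The losses printed above are single-batch values, so they fluctuate from one print to the next. One way to get a steadier estimate is to average the loss over several random batches without tracking gradients. The sketch below is self-contained: `TinyBigram`, `toy_data`, `toy_get_batch`, and `estimate_loss` are illustrative names introduced here (with random tokens standing in for `train_data`), not part of the notebook's code.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class TinyBigram(nn.Module):
    """Stand-in with the same (logits, loss) interface as BigramLanguageModel."""
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.table(idx)  # (B, T, C)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        return logits, F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

@torch.no_grad()  # evaluation does not need gradients
def estimate_loss(model, get_batch_fn, eval_iters=50):
    """Average the loss over `eval_iters` random batches."""
    model.eval()
    losses = torch.zeros(eval_iters)
    for i in range(eval_iters):
        xb, yb = get_batch_fn()
        _, loss = model(xb, yb)
        losses[i] = loss.item()
    model.train()  # restore training mode for the caller
    return losses.mean().item()

torch.manual_seed(0)
toy_data = torch.randint(65, (1000,))  # random tokens standing in for train_data

def toy_get_batch(batch_size=32, block_size=8):
    ix = torch.randint(len(toy_data) - block_size, size=(batch_size,))
    x = torch.stack([toy_data[i : i + block_size] for i in ix])
    y = torch.stack([toy_data[i + 1 : i + block_size + 1] for i in ix])
    return x, y

toy_model = TinyBigram(65)
avg_loss = estimate_loss(toy_model, toy_get_batch)
print(avg_loss)
```

The same function could be called on batches drawn from `val_data` to compare training and validation loss.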


In [16]:
context_idx = torch.zeros((1, 1), dtype=torch.int64).to(device)
print(decode(model.generate(context_idx, max_new_tokens=500)[0].tolist()))


Wawice my.

HDEdarom orou waowh$Frtof isth ble mil ndill, ath iree sengmin lat Heriliovets, and Win nghirileranousel lind me l.
HAshe ce hiry:
Supr aisspllw y.
Hurindu n Boopetelaves
MP:

Pl, d mothakleo Windo whthCoribyo the m dourive we higend t so mower; te

AN ad nterupt f s ar igr t m:

Thiny aleronth,
Mad
Whed my o myr f-NLIERor,
SS&Y:

Sadsal thes ghesthidin cour ay aney Iry ts I fr y ce.
Jken pand, bemary.
Yor 'Wour menm sora anghy t-e nomes twe men.
Wand thot sulin s th llety ome.
I muc


The bigram model only takes the previous token as context. What we want is a model where all the previous tokens "talk" to each other and figure out the context that can be used to predict the next token. Enter transformers.

## Transformer Model

### The mathematical trick in self-attention

**Bag of words:** the simplest way to let tokens communicate is to average the current token and all the previous tokens into one "bag of words", which is then used for prediction at that time-step. The term "bag of words" almost always refers to some sort of **averaging**.

In [17]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

We want to create another tensor `xbow` with entries such that $\text{xbow}[b, t] = \text{mean}_{i \leq t} \left( \text{x}[b, i] \right)$.


In [18]:
# version 1
xbow = torch.zeros(B, T, C)
for b in range(B):
  for t in range(T):
    xbow[b, t] = x[b, :t+1].mean(dim=0)

In [19]:
# version 2
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(dim=1, keepdim=True) # weights used to "average" just like above
# print(f"{weights=}")
xbow2 = weights @ x # weights is broadcast over the batch: (T, T) @ (B, T, C) ---> (B, T, C)

In [20]:
# version 3
affinities = torch.zeros((T, T)) # affinities between tokens (0 results in all tokens being equally weighted)
tril = torch.tril(torch.ones(T, T))
weights = affinities.masked_fill(tril==0, float('-inf')) # so that tokens from the future cannot communicate
weights = F.softmax(weights, dim=1) # this gives us the same weight matrix as before
xbow3 = weights @ x

In [21]:
torch.allclose(xbow, xbow2), torch.allclose(xbow2, xbow3)

(True, True)

- This illustrates a mathematical trick: we can "average" using matrix multiplication instead of for loops.
- This extends very easily to weighted averages. Here each row of the weights is normalized to sum to one, and the lower-triangular structure stops a token from looking into the future along the time dimension.
- *Version 3* is especially useful since it starts with "affinities" between tokens, which can be data dependent instead of just being set to zero.
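
To illustrate the last point, the sketch below swaps the all-zero affinities for data-dependent ones. Using raw dot products between token vectors is an assumption made purely for illustration. The softmax is taken over the last dimension so that it also works for a batched `(B, T, T)` tensor.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, C = 2, 4, 3
x = torch.randn(B, T, C)

# Illustrative data-dependent affinities: dot products between token vectors.
affinities = x @ x.transpose(-2, -1)  # (B, T, T)
tril = torch.tril(torch.ones(T, T))
masked = affinities.masked_fill(tril == 0, float('-inf'))  # hide the future
weights = F.softmax(masked, dim=-1)   # normalize each row
out = weights @ x                     # (B, T, C) weighted averages of past tokens

print(weights[0])
print(weights.sum(dim=-1))  # every row sums to 1
```

Each row is still an average over the current and earlier tokens, but the tokens are no longer weighted equally.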

### Crux of self attention

- Every token will emit three vectors - **query**, **key**, and **value**
- The key tells other tokens what it *has*, and the query tells other tokens what it's *looking for*.
- The value contains the *information* of the token that is relevant for other tokens.
- The **dot product** between one token's query and another token's key becomes the *affinity* between the two tokens, i.e. how interested the first token is in the second.
- For example, a pronoun could be looking for a noun, so the dot product between the query of the pronoun and the key of the noun could be high.
- The model then computes the weight matrix by performing softmax on the masked affinities.
- This weight matrix is then multiplied with the values of all the tokens to produce the output of the self-attention layer.

In [22]:
# version 4: self-attention!
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # C could be n_embd for example
x = torch.randn(B, T, C)

# let's see a single head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False) # just a simple matrix-multiply layer (no bias)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
# each embedded token has produced a key and a query independently and in parallel - no communication has happened yet between tokens
affinities = q @ k.transpose(-2, -1) # basically a batch dot product for every token-pair: (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
weights = affinities.masked_fill(tril==0, float('-inf')) # so that tokens from the future cannot communicate
weights = F.softmax(weights, dim=-1) # softmax over the last (key) dimension, so that each row sums to 1

v = value(x) # (B, T, 16)
out = weights @ v # (B, T, T) @ (B, T, 16) ---> (B, T, 16)

out.shape
print(weights[0])
print(weights[0].sum(dim=1, keepdim=True))

With the softmax taken over the last dimension, every row of `weights[0]` is a probability distribution over the current and earlier tokens: the entries above the diagonal are exactly 0, the first row is `[1, 0, ..., 0]` because the first token can only attend to itself, and the row sums printed by the last line are all 1. Note that `weights` is a `(B, T, T)` tensor here, so `dim=1` would normalize each column instead, which mixes weights across query positions and leaves rows that do not sum to 1.


### Notes:

- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all the nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- The examples across the batch dimension are of course processed completely independently and never "talk" to each other.
- In an "encoder" attention block, just delete the single line that does the masking with `tril`, allowing all tokens to communicate. The block here is called a "decoder" attention block because it has triangular masking, and it is usually used in autoregressive settings, like language modeling.
- "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries still get produced from `x`, but the keys and values come from some other, external source (e.g. an encoder module).
- "Scaled" attention additionally divides the affinities by `sqrt(head_size)` (i.e. multiplies them by `1/sqrt(head_size)`) before the softmax. When the inputs `k` and `q` have unit variance, the affinities then have unit variance too, so the softmax stays diffuse instead of saturating towards one-hot vectors.
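
The effect of the scaling can be checked numerically. The sketch below uses random unit-variance tensors as stand-ins for the keys and queries (an assumption; real keys and queries come from learned projections) and a hand-picked logits vector to show how larger values sharpen the softmax.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, head_size = 4, 8, 16
k = torch.randn(B, T, head_size)  # unit-variance stand-ins for the keys
q = torch.randn(B, T, head_size)  # unit-variance stand-ins for the queries

raw = q @ k.transpose(-2, -1)   # variance grows roughly like head_size
scaled = raw * head_size**-0.5  # variance brought back to roughly 1

print(f"raw variance:    {raw.var().item():.3f}")
print(f"scaled variance: {scaled.var().item():.3f}")

# Larger logits push softmax towards a one-hot vector.
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
diffuse = F.softmax(logits, dim=-1)
sharp = F.softmax(logits * 8, dim=-1)
print(diffuse)
print(sharp)
```

Without the scaling, the spread of the affinities grows with `head_size`, and the softmax output concentrates on whichever entry happens to be largest.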

### More notes:
- **Multi-head attention**
  - Often, the tokens "have a lot to talk about". This communication can be facilitated by multiple attention heads between the same tokens.
  - For example, one of the heads could be looking for consonants, while another could be looking for vowels, etc.
  - The final output is a concatenation of the outputs of all the attention heads.
- **Communication and computation**
  - It is often useful to let the network **"think over"** the outputs of the attention blocks. This can be achieved by adding more **computation** steps using feed-forward layers.
  - In fact, a common architecture is to **intersperse** the attention blocks, which allow communication between tokens, with feed-forward blocks, which allow tokens to individually process the results of this communication.
- **Skip/residual connections**
  - Deep neural networks often suffer from optimization issues due to vanishing gradients. Skip connections (a.k.a. residual connections) help with this.
  - The main concept is that you "branch" the data off to transform the data in some way, but also add the non-transformed version back to the "main branch". 
  - Since addition routes gradients, this creates a "gradient super-highway" that allows gradients to backpropagate more effectively to the earlier layers. 
  - The "branches" off the highway are where the computation happens. They are typically initialized so that they contribute very little to the main branch at first, and they become more relevant as training goes on.
- **LayerNorm**
  - This is another innovation that helps with numerical stability of the gradients during optimization of deep neural networks.
  - The idea is similar to BatchNorm: we want the activations to have zero mean and unit variance.
  - However, instead of normalizing each feature over the batch dimension, LayerNorm normalizes each token's feature vector over the channel (embedding) dimension.
  - Compared with BatchNorm, the examples in a batch are no longer coupled together, and there are no running statistics to maintain, so the layer behaves the same during training and testing.
- **Dropout**
  - This is a regularization technique that helps prevent overfitting.
  - On every forward pass during training, a random subset of neurons is disabled, so only a subset of the network is trained in each optimization step.
  - This has the effect of training an ensemble of networks that are merged together at test time.
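
As a rough sketch of the multi-head idea, the code below runs several causal attention heads in parallel and concatenates their outputs. The class names, hyperparameters, and the final linear projection `proj` are assumptions for illustration (the projection follows common practice and is not described in the notes above).

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One head of causal (masked) self-attention."""
    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # the mask is not a trainable parameter, so it is stored as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)       # each (B, T, head_size)
        affinities = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5  # scaled, (B, T, T)
        affinities = affinities.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = self.dropout(F.softmax(affinities, dim=-1))
        return weights @ v                                        # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    """Several heads in parallel, concatenated along the channel dimension."""
    def __init__(self, n_embd, n_head, block_size, dropout=0.0):
        super().__init__()
        head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            [Head(n_embd, head_size, block_size, dropout) for _ in range(n_head)]
        )
        self.proj = nn.Linear(n_head * head_size, n_embd)  # assumed output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, n_head * head_size)
        return self.dropout(self.proj(out))

torch.manual_seed(0)
mha = MultiHeadAttention(n_embd=32, n_head=4, block_size=8).eval()
x = torch.randn(4, 8, 32)
out = mha(x)
print(out.shape)  # torch.Size([4, 8, 32])
```

Splitting `n_embd` across the heads keeps the total width unchanged, so the concatenated output has the same shape as the input.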

The transformer architecture has stood the test of time and has remained relatively unchanged since its introduction in 2017. One change that has become common is the pre-norm formulation, which applies LayerNorm before the self-attention block and again before the feed-forward block, rather than after each of them as in the original paper.
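
A pre-norm block could be sketched as follows: LayerNorm is applied to the input of each sub-layer, and the sub-layer's output is added back onto the main branch. To stay short, this sketch relies on PyTorch's built-in `nn.MultiheadAttention` rather than hand-written heads; the class names, the ReLU feed-forward layer, and its 4x hidden width are assumptions for illustration.

```python
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Per-token computation applied after the tokens have communicated."""
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # assumed 4x expansion
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Pre-norm transformer block: communication followed by computation."""
    def __init__(self, n_embd, n_head, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = FeedForward(n_embd, dropout)

    def forward(self, x):
        T = x.shape[1]
        # True marks positions that may NOT be attended to (the future)
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
        )
        h = self.ln1(x)                      # normalize before attention
        attn_out, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = x + attn_out                     # residual connection
        x = x + self.ffwd(self.ln2(x))       # normalize before feed-forward, then residual
        return x

torch.manual_seed(0)
block = Block(n_embd=32, n_head=4).eval()
x = torch.randn(4, 8, 32)
out = block(x)
print(out.shape)  # torch.Size([4, 8, 32])

# LayerNorm normalizes each token's feature vector (the last dimension).
ln_out = nn.LayerNorm(32)(x)
ln_mean = ln_out.mean(dim=-1)
ln_var = ((ln_out - ln_mean.unsqueeze(-1)) ** 2).mean(dim=-1)
print(ln_mean.abs().max().item(), ln_var.mean().item())
```

Because the block maps `(B, T, n_embd)` to the same shape, several such blocks can be stacked one after another.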

- When we want the generated text to be conditioned on some other text, we can use encoder transformer blocks (without the triangular mask) to process that text and produce the keys and values, while the decoder transformer blocks produce the queries.
- Combining the two gives an encoder-decoder transformer. The attention that connects them is known as cross-attention.
- See the "Attention Is All You Need" paper for more details.
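
A single cross-attention head might look like the sketch below. The tensors `x_dec` and `x_enc` are random stand-ins for decoder-side tokens and encoder outputs, and all sizes are made up for illustration. No triangular mask is applied, since every decoder position may look at the whole conditioning text.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
B, T_dec, T_enc, C, head_size = 2, 5, 7, 32, 16
x_dec = torch.randn(B, T_dec, C)  # decoder-side tokens (the text being generated)
x_enc = torch.randn(B, T_enc, C)  # encoder output (the text we condition on)

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x_dec)  # (B, T_dec, head_size): queries come from the decoder
k = key(x_enc)    # (B, T_enc, head_size): keys come from the encoder
v = value(x_enc)  # (B, T_enc, head_size): values come from the encoder

affinities = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T_dec, T_enc)
weights = F.softmax(affinities, dim=-1)  # no mask over the encoder positions
out = weights @ v                        # (B, T_dec, head_size)
print(out.shape)  # torch.Size([2, 5, 16])
```

The two sequences can have different lengths; the output always has one vector per decoder position.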

## How to train ChatGPT
### Stage 1: Pretraining
- The first stage is to train a decoder transformer language model on a large amount of text from the internet, so that it simply learns to babble on.
- Here we trained a character-level model with 10 million parameters on a dataset of about 1 million character-level tokens, which would be about 300,000 tokens in the OpenAI vocabulary.
- The biggest OpenAI GPT-3 transformer has 175 billion parameters and was trained on 300 billion tokens. These numbers are not even large by current standards. See OpenAI's [GPT-3 paper](https://arxiv.org/abs/2005.14165) for more details.
- Due to scale, this becomes a massive infrastructure challenge.
### Stage 2: Fine-tuning
- So far, the pre-trained model will just autocomplete and babble. It could complete the sentence, write a news article, or ask more questions: pretty much anything resembling what it has seen on the internet.
- The next stage is to "align it" to be an assistant. See this [blog post](https://openai.com/blog/chatgpt) by OpenAI for more information.
  - Step 1 fine-tunes this model based on a small dataset of good example "question-answer" pairs. This works because pre-trained LLMs are very sample efficient.
  - Step 2 trains a reward model by asking humans to rank different responses from bad to good. This reward model is a way to score any response.
  - Step 3 uses PPO, a reinforcement learning algorithm, to further fine-tune the model against the reward model.
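
To give a feel for Step 2, the sketch below shows one common way to turn human rankings into a training signal: a pairwise loss that pushes the reward of the preferred response above the reward of the rejected one. The loss choice and the scores are assumptions for illustration; the text above does not specify how the reward model is trained.

```python
import torch
from torch.nn import functional as F

# Made-up scalar scores from a hypothetical reward model, one pair per prompt.
reward_preferred = torch.tensor([1.2, 0.3, -0.5])  # responses humans ranked higher
reward_rejected = torch.tensor([0.4, 0.9, -1.5])   # responses humans ranked lower

# Pairwise ranking loss: small when the preferred response scores higher.
loss = -F.logsigmoid(reward_preferred - reward_rejected).mean()
print(loss.item())

# A larger margin in favor of the preferred responses gives a lower loss.
loss_wide = -F.logsigmoid(reward_preferred + 2.0 - reward_rejected).mean()
print(loss_wide.item())
```

Once trained, such a reward model can score any response, which is what the reinforcement learning step then optimizes against.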