# Exercise: Self-attention mechanism in GPT Transformer

In this exercise you will implement the `forward()` function of the `MultiHeadSelfAttention` module in a minified GPT implementation. GPT (Generative Pre-trained Transformer) is the OpenAI model family originally described in [this paper](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) [1].

A full GPT model consists of a stack of GPT Blocks (Transformer blocks). The self-attention module is invoked inside each block: the input sequence of token embeddings is first layer-normalized, then passed through self-attention, and the attention output is added back to the block's input, forming a residual connection. The MLP sublayer is wrapped in the same way. For extra reading on residual connections and why they help avoid vanishing gradients, see [this image](https://miro.medium.com/v2/resize:fit:640/format:webp/1*mxJ5gBvZnYPVo0ISZE5XkA.png) [3] and the [ResNet paper](https://arxiv.org/pdf/1512.03385.pdf) [2].

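```python
def forward(self, x):
    x = x + self.attn(self.ln1(x))  # pre-LN self-attention, added back to the input (residual)
    x = x + self.mlp(self.ln2(x))   # pre-LN MLP, added back in the same way
    return x
```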
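The two lines above follow the pre-LayerNorm residual pattern `x + sublayer(LN(x))`. A minimal NumPy sketch, assuming a layer norm without the learnable scale and shift that `nn.LayerNorm` has, shows the key property of the skip path: if a sublayer outputs zeros, the block reduces to the identity, so the input (and its gradient) always has a direct route through the stack. The random linear sublayer is a hypothetical stand-in for `attn` or `mlp`.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token vector (last axis) to zero mean and unit variance.
    # The learnable scale and shift of nn.LayerNorm are omitted here.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_ln_residual(x, sublayer):
    # Pre-LN residual step: normalize, apply the sublayer, add the raw input back.
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))  # (batch, tokens, channels)

# Hypothetical stand-in sublayer: a fixed random linear map.
W = rng.normal(size=(8, 8)) * 0.02
out = pre_ln_residual(x, lambda h: h @ W)
print(out.shape)  # (2, 4, 8)

# A sublayer that outputs zeros reduces the block to the identity.
print(np.allclose(pre_ln_residual(x, np.zeros_like), x))  # True
```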

## References

[1] Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018). Improving Language Understanding by Generative Pre-Training. Technical report, OpenAI.

[2] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

[3] The Vanishing Gradient Problem. Towards Data Science. https://towardsdatascience.com/the-vanishing-gradient-problem-69bf08b15484

## Scaled Multiplicative Attention

Recall this attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This is represented in the Python code as follows:

* $Q$: `q`
* $K$: `k`
* $V$: `v`
* $\text{softmax}$: `F.softmax()`
* $K^T$: `k_t`
* $QK^T$ (matrix multiplication): `q @ k_t`
* $\sqrt{}$: `math.sqrt()`
* $d_k$: `d_k`, the per-head key dimension (`k.size(-1)`)
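Using the same names as the mapping above, here is a framework-agnostic NumPy sketch of the formula (single head, no mask, no dropout; `np.swapaxes` plays the role of PyTorch's `transpose`). It also checks empirically why the scores are divided by $\sqrt{d_k}$: for vectors with independent unit-variance entries, the dot product has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1 so the softmax does not saturate as $d_k$ grows.

```python
import math
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max for numerical stability; the result is unchanged.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # q: (..., T, d_k), k: (..., T, d_k), v: (..., T, d_v)
    k_t = np.swapaxes(k, -2, -1)             # K^T
    d_k = k.shape[-1]
    att = softmax(q @ k_t / math.sqrt(d_k))  # (..., T, T), each row sums to 1
    return att @ v, att                      # (..., T, d_v)

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(5, 16)) for _ in range(3))
y, att = attention(q, k, v)
print(y.shape)  # (5, 16)

# Scaling check: unscaled dot products have variance close to d_k,
# scaled ones have variance close to 1.
d_k = 64
a, b = rng.normal(size=(2, 10000, d_k))
raw = (a * b).sum(axis=-1)
print(raw.var(), (raw / math.sqrt(d_k)).var())
```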

## Your Task

Within the `MultiHeadSelfAttention` class, fill in the TODO portion of the `forward` method.
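Before the TODO, the provided code splits the embedding dimension `C` into `n_head` heads of size `hs = C // n_head`; after it, the heads are glued back together. A small NumPy sketch of that reshape round trip (`reshape`/`swapaxes` standing in for PyTorch's `view`/`transpose`; the sizes are arbitrary) shows that head `h` sees channels `h*hs:(h+1)*hs` of every token and that the merge is lossless.

```python
import numpy as np

B, T, C, n_head = 2, 3, 8, 4
hs = C // n_head  # head size

x = np.arange(B * T * C, dtype=float).reshape(B, T, C)

# Split: (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs), like .view(...).transpose(1, 2)
heads = x.reshape(B, T, n_head, hs).swapaxes(1, 2)

# Head h holds a contiguous slice of channels for every token.
h = 1
print(np.array_equal(heads[:, h], x[:, :, h * hs:(h + 1) * hs]))  # True

# Merge: (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C),
# like .transpose(1, 2).contiguous().view(B, T, C)
merged = heads.swapaxes(1, 2).reshape(B, T, C)
print(np.array_equal(merged, x))  # True
```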

```python
import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F


class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    cross_attention = False

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k,v in kwargs.items():
            setattr(self, k, v)

class MultiHeadSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # START OF SOLUTION
        # multiply q and k_t matrices, then divide by the square root of d_k
        k_t = k.transpose(-2, -1)
        d_k = k.size(-1)
        att = q @ k_t / math.sqrt(d_k)

        # causal mask: scores for future positions become -inf, so softmax gives them zero weight
        masked_fill_value = float('-inf')
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, masked_fill_value)

        # softmax over the key dimension, then attention dropout (regularization)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        # multiply att and v matrices
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # END OF SOLUTION

        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
```
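The `masked_fill` step is what makes this layer causal. A NumPy sketch (single head, no projections or dropout, `q = k = v = x` for brevity, and `np.tril` in place of the registered `mask` buffer) checks the two consequences: attention weights above the diagonal are exactly zero, and changing a later token leaves the outputs at all earlier positions unchanged.

```python
import math
import numpy as np

def causal_attention(q, k, v):
    T, d_k = q.shape
    att = q @ k.T / math.sqrt(d_k)
    mask = np.tril(np.ones((T, T)))              # 1 on and below the diagonal
    att = np.where(mask == 0, -np.inf, att)      # hide future positions
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    att = att / att.sum(axis=-1, keepdims=True)  # softmax; exp(-inf) == 0
    return att @ v, att

rng = np.random.default_rng(0)
T, d = 6, 8
x = rng.normal(size=(T, d))
y, att = causal_attention(x, x, x)

# Weights above the diagonal are exactly zero.
print(np.allclose(np.triu(att, k=1), 0))  # True

# Perturb only the last token: outputs for positions 0..T-2 are unchanged.
x2 = x.copy()
x2[-1] += 10.0
y2, _ = causal_attention(x2, x2, x2)
print(np.allclose(y[:-1], y2[:-1]))  # True
```

The diagonal is never masked, so every row keeps at least one finite score and the softmax never divides by zero.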

## GPT model definition

- the initial stem combines a token embedding and a learned positional embedding
- the meat of it is a uniform sequence of Transformer blocks
    - each block applies a self-attention sublayer followed by a 1-hidden-layer MLP sublayer
    - all blocks feed into a central residual pathway, similar to ResNets
- the final decoder is a LayerNorm followed by a linear projection to vocabulary logits, which feed a softmax classifier (the softmax is applied inside `F.cross_entropy`)

Run this cell without changes.

```python
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        if config.cross_attention:
            self.attn = CrossAttention(config)
        else:
            self.attn = MultiHeadSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """  the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

        print("number of parameters: {}".format(sum(p.numel() for p in self.parameters())))

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0)

        return logits, loss
```

## Training the GPT model

We will now pretrain our GPT transformer on Wikipedia text using the **self-supervised pretext objective of next-token prediction**, generating one character at a time. The goal is for the model to later predict a notable person's birth place: exposing the transformer to world knowledge during pretraining is what enables it to perform considerably above chance on that task.

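The accompanying `dataset.py` prepares the training data with a span corruption denoising objective in the spirit of the [T5 paper](https://arxiv.org/pdf/1910.10683.pdf) [4]: it randomly selects a span of text in a document, replaces it with a `MASK` character, and appends the removed span to the end of the document after a second `MASK`. The model then learns to reconstruct the span through ordinary next-token prediction on this rearranged sequence. See `dataset.py` for the details.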

Example pretraining pair (x, y) for next-token prediction. `⁇` is the `MASK` character that delimits the corrupted span, `□` is the padding character, and `y` is `x` shifted left by one character:

```
x: Khatchig Mouradian. Khatchig Mouradian is a journalist, writer and translator bo⁇non .⁇rn in Leba□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y: hatchig Mouradian. Khatchig Mouradian is a journalist, writer and translator bo⁇non .⁇rn in Leba□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
```
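The layout of `x` above can be reproduced with a small pure-Python sketch. The helper `span_corrupt` is hypothetical, not the `dataset.py` API: the span is passed in explicitly instead of being chosen at random, and padding the string to `block_size + 1` characters (so that `x` and `y` are each `block_size` long) is an assumption.

```python
MASK_CHAR = "\u2047"  # ⁇
PAD_CHAR = "\u25A1"   # □

def span_corrupt(doc, start, end, block_size):
    """Hypothetical helper: move doc[start:end] to the end, delimited by MASK_CHAR,
    pad to block_size + 1 characters, and return the shifted (x, y) pair."""
    prefix, masked, suffix = doc[:start], doc[start:end], doc[end:]
    s = prefix + MASK_CHAR + suffix + MASK_CHAR + masked
    if len(s) > block_size + 1:
        raise ValueError("document too long for block_size")
    s += PAD_CHAR * (block_size + 1 - len(s))
    # Next-token prediction: y[i] is the character that follows x[i].
    return s[:-1], s[1:]

doc = "Khatchig Mouradian. Khatchig Mouradian is a journalist, writer and translator born in Lebanon ."
start = doc.index("rn in Leba")
x, y = span_corrupt(doc, start, start + len("rn in Leba"), block_size=128)
print(x.rstrip(PAD_CHAR))  # matches the x line above, minus the padding
```

Because `y` is `x` shifted by one, the positions from the second `⁇` onward ask the model to predict the hidden span one character at a time, which is where the denoising signal comes from.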

### References
[4] Raffel, C., Shazeer, N. M., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., & Liu, P. J. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. Journal of Machine Learning Research, 21(140), 1-67.

Run this cell without changes.

```python
import dataset
import trainer

import multiprocess as mp
# allow_none=True returns None when no start method has been fixed yet;
# without it, get_start_method() fixes and returns the platform default, never None.
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method("fork")
else:
    print("A multiprocessing context has been set.")

BLOCK_SIZE = 128
PRETRAIN_CORPUS = 'wiki.txt'
SAVE_CKPT_PATH = 'pretrain.params'

with open(PRETRAIN_CORPUS) as f:
    text = f.read()
pretrain_dataset = dataset.CharCorruptionDataset(text, BLOCK_SIZE)

model_config = GPTConfig(pretrain_dataset.vocab_size, pretrain_dataset.block_size,
    n_layer=4, n_head=8, n_embd=256)

train_config = trainer.TrainerConfig(max_epochs=1,
                                     batch_size=16,
                                     learning_rate=6e-4,
                                     lr_decay=True,
                                     warmup_tokens=512*20,
                                     final_tokens=200*len(pretrain_dataset)*BLOCK_SIZE,
                                     num_workers=1,
                                     ckpt_path=SAVE_CKPT_PATH)
```

```
A multiprocessing context has been set.
data has 418352 characters, 256 unique.
```
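`trainer.py` is not shown, so the following is an assumption: the `lr_decay`, `warmup_tokens`, and `final_tokens` settings match those of Karpathy's minGPT trainer, which raises the learning rate linearly over the first `warmup_tokens` target tokens and then follows a cosine decay down to a floor of 10% of `learning_rate`. A sketch of that schedule (the dataset length used for `final_tokens` is hypothetical):

```python
import math

def lr_at(tokens, final_tokens, learning_rate=6e-4, warmup_tokens=512 * 20):
    """Assumed minGPT-style schedule (lr_decay=True): linear warmup, then
    cosine decay from learning_rate down to a floor of 10% of it."""
    if tokens < warmup_tokens:
        lr_mult = tokens / max(1, warmup_tokens)  # linear warmup
    else:
        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))  # cosine decay
    return learning_rate * lr_mult

# Hypothetical dataset length of 1000 documents, only to get a concrete final_tokens.
final_tokens = 200 * 1000 * 128
for t in (0, 512 * 10, 512 * 20, final_tokens // 2, final_tokens):
    print(t, lr_at(t, final_tokens))
```

With the notebook's `final_tokens=200*len(pretrain_dataset)*BLOCK_SIZE`, one epoch covers only about 1/200 of the decay, so the learning rate should stay close to its peak; this is consistent with the `lr 5.999655e-04` printed at the end of training.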


In one epoch (a few minutes, depending on hardware), your loss should decrease from ~5 to the 2.5 ± 0.2 range.
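A quick check on the starting value: `_init_weights` draws weights from a normal distribution with standard deviation 0.02, so the initial logits are small and the model's first predictions are close to uniform over the vocabulary. For the 256-character vocabulary reported above, the cross-entropy of a uniform prediction is

$$-\ln\frac{1}{256} = \ln 256 = 8\ln 2 \approx 5.55$$

Likewise, a loss of 2.5 means the model assigns the correct next character a probability of roughly $e^{-2.5} \approx 0.08$, as a geometric mean over positions.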

Run this cell without changes. It should take 184 iterations to complete (the progress bar's `iter` counter starts at 0, so the last one reads `iter 183`).

```python
model = GPT(model_config)
gpt_trainer = trainer.Trainer(model, pretrain_dataset, None, train_config)
gpt_trainer.train()
```

```
number of parameters: 3323392

epoch 1 iter 183: train loss 2.56776. lr 5.999655e-04: 100%|██████████| 184/184 [00:22<00:00,  8.17it/s]
```
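As a sanity check on the printed `number of parameters`, the count can be tallied by hand from the layer shapes in `Block` and `GPT`, taking `vocab_size` to be 256 as the dataset output reports. A pure-Python sketch:

```python
def linear_params(n_in, n_out, bias=True):
    # nn.Linear stores an (n_out, n_in) weight plus an optional n_out bias.
    return n_in * n_out + (n_out if bias else 0)

def gpt_param_count(vocab_size, block_size, n_layer, n_embd):
    layer_norm = 2 * n_embd                                   # weight + bias
    attn = 4 * linear_params(n_embd, n_embd)                  # key, query, value, proj
    mlp = linear_params(n_embd, 4 * n_embd) + linear_params(4 * n_embd, n_embd)
    block = 2 * layer_norm + attn + mlp                       # ln1, ln2, attn, mlp
    stem = vocab_size * n_embd + block_size * n_embd          # tok_emb + pos_emb
    decoder = layer_norm + linear_params(n_embd, vocab_size, bias=False)  # ln_f + head
    # The causal mask is a registered buffer, not a parameter, so it is not counted.
    return stem + n_layer * block + decoder

print(gpt_param_count(vocab_size=256, block_size=128, n_layer=4, n_embd=256))  # 3323392
```

Note that `n_head` does not appear: splitting `C` into heads only reshapes the outputs of the same key, query, and value projections, so the number of heads does not change the parameter count.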
