# Transformer Architecture

## Problem 1: Root Mean Square Layer Normalization

**Deliverable**: Implement RMSNorm as a `torch.nn.Module`. To test your implementation against our provided test, you will first need to implement the test adapter at [adapters.run_rmsnorm]. Then, run `pytest -k test_rmsnorm` to test your implementation.

In [None]:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, epsilon=1e-5):
        super(RMSNorm, self).__init__()
        self.epsilon = epsilon
        self.d_model = d_model
        # Learnable per-feature gain, initialized to 1
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Divide by the root mean square over the last (feature) dimension, then apply the gain
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.epsilon)
        return x / rms * self.weight
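
A quick sanity check for the module above, as a minimal sketch: with the gain left at its initial value of 1, every normalized vector should have a root mean square close to 1. The class is repeated so the snippet runs on its own; the tensor sizes and the scale factor are arbitrary choices made for illustration.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Same computation as the cell above, repeated so this snippet is self-contained."""
    def __init__(self, d_model, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.epsilon)
        return x / rms * self.weight

torch.manual_seed(0)
x = 5.0 * torch.randn(2, 3, 8)            # (batch, seq_len, d_model), illustrative sizes
y = RMSNorm(d_model=8)(x).detach()
rms_out = y.pow(2).mean(dim=-1).sqrt()    # RMS of each normalized feature vector
print(y.shape)                            # same shape as the input
print(rms_out)                            # values close to 1
```

Because `epsilon` is added inside the square root, the output RMS sits very slightly below 1 rather than exactly at 1.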

## Problem 2: Implement the position-wise feed-forward network

- (a) Deliverable: Implement the GELU activation function. To test your implementation against our provided tests, you will need to implement the test adapter at [adapters.run_gelu]. Then, run `pytest -k test_gelu` to test your implementation.

- (b) Deliverable: Implement the position-wise feed-forward network. To test your implementation, implement the test adapter at [adapters.run_positionwise_feedforward]. Then, run `pytest -k test_positionwise_feedforward` to test your implementation.

In [None]:
import math
import torch
import torch.nn as nn

class GELU(nn.Module):
    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
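
The `forward` above uses the tanh approximation of GELU. As a rough check of how close that is to the erf-based definition $x\,\Phi(x)$, the sketch below compares the two on a grid. The function names `gelu_tanh` and `gelu_exact` and the range $[-5, 5]$ are choices made here for illustration, not part of the assignment.

```python
import math
import torch

def gelu_tanh(x):
    # Tanh approximation, same expression as in the GELU module above
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF written via erf
    return 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))

x = torch.linspace(-5, 5, steps=1001)
max_gap = (gelu_tanh(x) - gelu_exact(x)).abs().max().item()
print(f"max |tanh approx - exact| on [-5, 5]: {max_gap:.2e}")
```

If the provided test compares against the erf form with a tight tolerance, switching to `gelu_exact` is a one-line change.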

In [None]:
import torch.nn as nn 

class FFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FFN, self).__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        # self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.activation = GELU()
        
    def forward(self, x):
        x = self.activation(self.w1(x))
        # x = self.dropout(x)
        x = self.w2(x)
        return x

## Problem 3: Implement softmax

Deliverable: Write a function to apply the softmax operation on a tensor. Your function should take two parameters: a tensor and a dimension i, and apply softmax to the i-th dimension of the input tensor. The output tensor should have the same shape as the input tensor, but its i-th dimension will now have a normalized probability distribution. Use the same trick as your cross-entropy loss calculation to avoid numerical stability issues.

In [None]:
def softmax(x: torch.Tensor, dim: int):
    e_x = torch.exp(x - torch.max(x, dim=dim, keepdim=True)[0])
    return e_x / e_x.sum(dim=dim, keepdim=True)
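
To see why the max is subtracted before exponentiating, the sketch below compares the function above with a naive version on large logits. `naive_softmax` is a hypothetical helper written only for this comparison.

```python
import torch

def softmax(x: torch.Tensor, dim: int):
    # Subtracting the per-slice max leaves the result unchanged but keeps exp() from overflowing
    e_x = torch.exp(x - torch.max(x, dim=dim, keepdim=True)[0])
    return e_x / e_x.sum(dim=dim, keepdim=True)

def naive_softmax(x: torch.Tensor, dim: int):
    # No shift: exp(1000.0) overflows to inf in float32, and inf / inf is nan
    e_x = torch.exp(x)
    return e_x / e_x.sum(dim=dim, keepdim=True)

x = torch.tensor([[1000.0, 1001.0, 1002.0],
                  [0.0, 1.0, 2.0]])
stable = softmax(x, dim=-1)
naive = naive_softmax(x, dim=-1)

print(naive[0])    # all nan
print(stable[0])   # same distribution as stable[1]: softmax is shift-invariant
print(stable[1])
```

Both rows differ only by a constant offset of 1000, so the stable version returns the same probabilities for each.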

## Problem 4: Implement scaled dot-product attention
Deliverable: Implement the scaled dot-product attention function. Your implementation should handle keys and queries of shape `(batch_size, ..., seq_len, d_k)` and values of shape `(batch_size, ..., seq_len, d_v)`, where `...` represents any number of other batch-like dimensions (if provided). The implementation should return an output with the shape `(batch_size, ..., seq_len, d_v)`. See section 3.3 for a discussion on batch-like dimensions.

Your implementation should also support an optional user-provided boolean mask of shape `(seq_len, seq_len)`. The attention probabilities of the masked positions should be zero, and the relative probabilities on the non-masked positions should remain the same.

In [None]:
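import math

def scaled_dot_product_attention(query, key, value, mask=None):
    d_k = query.size(-1)
    # (..., seq_len, d_k) @ (..., d_k, seq_len) -> (..., seq_len, seq_len), scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are hidden: a very negative score gives them ~zero probability
        scores = scores.masked_fill(mask == 0, -1e9)
    attention = softmax(scores, -1)
    # Returns both the weighted values and the attention probabilities
    return torch.matmul(attention, value), attention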
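
A small usage sketch for the function above with extra batch-like dimensions and a lower-triangular mask. It follows the convention of the cell above, where a mask entry of 0 means "hide this position". The sizes are arbitrary, and `torch.softmax` stands in for the notebook's own `softmax` so the snippet is self-contained.

```python
import math
import torch

def scaled_dot_product_attention(query, key, value, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attention = torch.softmax(scores, dim=-1)   # stand-in for the notebook's softmax
    return torch.matmul(attention, value), attention

torch.manual_seed(0)
batch, heads, seq_len, d_k, d_v = 2, 3, 4, 8, 5      # illustrative sizes
q = torch.randn(batch, heads, seq_len, d_k)
k = torch.randn(batch, heads, seq_len, d_k)
v = torch.randn(batch, heads, seq_len, d_v)
causal = torch.tril(torch.ones(seq_len, seq_len))    # 1 = keep, 0 = hide

out, attn = scaled_dot_product_attention(q, k, v, causal)
print(out.shape)    # torch.Size([2, 3, 4, 5])
print(attn[0, 0])   # entries above the diagonal are zero, each row sums to 1
```

The `(seq_len, seq_len)` mask broadcasts over the leading batch and head dimensions, so the same mask is applied to every head.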

## Problem 5: Implement causal multi-head self-attention

Implement causal multi-head self-attention as a `torch.nn.Module`. Your implementation should accept (at least) the following parameters:

- `d_model`: `int` Dimensionality of the Transformer block inputs.
- `num_heads`: `int` Number of heads to use in multi-head self-attention.
- `attn_pdrop`: `float | None = None` Dropout rate for softmax-normalized attention probabilities. 

Following Vaswani et al., set $d_k = d_v = d_{model} / h$, where $h$ is `num_heads`. To test your implementation against our provided tests, implement the test adapter at [adapters.run_multihead_self_attention]. Then, run `pytest -k test_multihead_self_attention` to test your implementation.

In [None]:
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads,attn_pdrop=None):
        super(MultiHeadSelfAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.attention = None
        if attn_pdrop is not None:
            self.attn_dropout = nn.Dropout(attn_pdrop)
        else:
            self.attn_dropout = None
        
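    def forward(self, x, mask=None):
        batch_size, seq_len = x.size(0), x.size(1)
        if mask is None:
            # Causal by default: position i may attend only to positions <= i (nonzero = keep)
            mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device))
        # Project, then split d_model into heads: (batch, num_heads, seq_len, d_k)
        query = self.query(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        key = self.key(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        value = self.value(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        _, attention = scaled_dot_product_attention(query, key, value, mask)
        # attn_pdrop applies to the softmax-normalized probabilities, not to the block output
        if self.attn_dropout is not None:
            attention = self.attn_dropout(attention)
        x = torch.matmul(attention, value)
        # Merge heads back into (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        self.attention = attention
        return self.proj(x)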
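
The `view(...).transpose(1, 2)` pattern used in `forward` above can be checked in isolation. The sketch below, with arbitrary sizes, confirms that splitting into heads and merging back is lossless and that head `i` sees features `[i * d_k, (i + 1) * d_k)` of each token.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 12, 3   # illustrative sizes
d_k = d_model // num_heads

x = torch.arange(batch_size * seq_len * d_model, dtype=torch.float32)
x = x.view(batch_size, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
heads = x.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
# Merge: transpose back, make memory contiguous, then flatten the head dimension
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, num_heads * d_k)

print(heads.shape)                                       # torch.Size([2, 3, 5, 4])
print(torch.equal(merged, x))                            # True
print(torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k]))    # True
```

The `.contiguous()` call matters because `transpose` returns a non-contiguous view, and `view` requires compatible memory layout.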

## Problem 6: Implement the Transformer block
Implement the pre-norm Transformer block as described in §3.4 and illustrated in Figure 2. Your Transformer block should accept (at least) the following parameters.

- `d_model: int` Dimensionality of the Transformer block inputs.
- `num_heads: int` Number of heads to use in multi-head self-attention.
- `d_ff: int` Dimensionality of the position-wise feed-forward inner layer.
- `attn_pdrop: float | None = None` Dropout rate for softmax-normalized attention probabilities.
- `residual_pdrop: float | None = None` Dropout rate for embeddings and Transformer block sublayer outputs.

To test your implementation, implement the adapter [adapters.run_transformer_block]. Then run `python -m pytest -k test_transformer_block` to test your implementation.

In [None]:
import torch.nn as nn 

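class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, attn_pdrop=None, residual_pdrop=None):
        super(TransformerBlock, self).__init__()
        self.attn = MultiHeadSelfAttention(d_model, num_heads, attn_pdrop)
        self.ffn = FFN(d_model, d_ff)
        # Pre-norm: each sublayer gets its own RMSNorm with separate weights
        self.ln1 = RMSNorm(d_model)
        self.ln2 = RMSNorm(d_model)
        self.residual_pdrop = residual_pdrop
        # nn.Dropout rejects None, so treat None as "no dropout"
        self.residual_dropout = nn.Dropout(residual_pdrop if residual_pdrop is not None else 0.0)

    def forward(self, x):
        # x + Dropout(Sublayer(RMSNorm(x))), first for attention and then for the feed-forward network
        x = x + self.residual_dropout(self.attn(self.ln1(x)))
        x = x + self.residual_dropout(self.ffn(self.ln2(x)))
        return x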


## Problem 7: Implement the Transformer LM
Time to put it all together! Implement the Transformer language model as described in §3.1 and illustrated in Figure 1. At minimum, your implementation should accept all of the aforementioned construction parameters for the Transformer block, as well as these additional parameters:

- `vocab_size: int` The size of the vocabulary, necessary for determining the dimensionality of the token embedding matrix.
- `context_length: int` The maximum context length, necessary for determining the dimensionality of the position embedding matrix.
- `num_layers: int` The number of Transformer blocks to use.

To test your implementation against our provided tests, you will first need to implement the test adapter at [adapters.run_transformer_lm]. Then, run `pytest -k test_transformer_lm` to test your implementation.

In [None]:
class TransformerLM(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, vocab_size, context_length, attn_pdrop=None, residual_pdrop=None):
        super(TransformerLM, self).__init__()
        self.d_model = d_model
        self.context_length = context_length
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(context_length, d_model)
        self.layers = nn.ModuleList([TransformerBlock(d_model, num_heads, d_ff, attn_pdrop, residual_pdrop) for _ in range(num_layers)])
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # nn.Dropout rejects None, so treat None as "no dropout"
        self.drop = nn.Dropout(residual_pdrop if residual_pdrop is not None else 0.0)
        self.ln_final = RMSNorm(d_model)

    def forward(self, x):
        B,T = x.size()
        positions = torch.arange(T, device=x.device, dtype=x.dtype).expand(B, T)
        x = self.token_embeddings(x) + self.position_embeddings(positions)
        x = self.drop(x)

        for block in self.layers:
            x = block(x)
        
        x = self.ln_final(x)
        x = self.lm_head(x)
        return x

## Problem 8: Transformer LM resource accounting 

(a) Consider GPT-2 XL, which has the following configuration:
```
vocab_size : 50,257  
context_length : 1,024  
num_layers : 48  
d_model : 1,600  
num_heads : 25  
d_ff : 6,400
```
Suppose we constructed our model using this configuration. How many trainable parameters would our model have? Assuming each parameter is represented using single-precision floating point, how much memory is required to just load this model?  

Identify the matrix multiplies required to complete a forward pass of our GPT-2 XL-shaped model. How many FLOPs do these matrix multiplies require in total? Assume that our input sequence has context_length tokens.  

(b) Based on your analysis above, which parts of the model require the most FLOPs?

(c) Repeat your analysis with GPT-2 small (12 layers, 768 d_model, 12 heads), GPT-2 medium (24 layers, 1024 d_model, 16 heads), and GPT-2 large (36 layers, 1280 d_model, 20 heads). As the model size increases, which parts of the Transformer LM take up proportionally more or less of the total FLOPs?

(d) Take GPT-2 XL and increase the context length to 16,384. How do the total FLOPs for one forward pass change? How do the relative FLOP contributions of the model components change?

In [22]:
from collections import OrderedDict
from tabulate import tabulate
import pandas as pd

def count_params(
    num_decoder_layer: int = 12,
    context_length: int = 1024,
    n_embd: int = 768,
    ffw_size: int = 3072,
    vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    """estimates the number of parameters in the model"""
    out = OrderedDict()

    # token and position embeddings
    out["embedding/position"] = n_embd * context_length
    out["embedding/token"] = n_embd * vocab_size
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    # attention blocks
    out["attention/ln"] = n_embd  # note, bias=False in our LN
    out["attention/kqv"] = n_embd * 3 * n_embd
    out["attention/proj"] = n_embd**2
    out["attention"] = out["attention/ln"] + out["attention/kqv"] + out["attention/proj"]

    # MLP blocks
    assert ffw_size == 4 * n_embd, "ffw_size must be 4 * n_embd"
    out["mlp/ln"] = n_embd
    out["mlp/ffw"] = n_embd * ffw_size
    out["mlp/proj"] = ffw_size * n_embd
    out["mlp"] = out["mlp/ln"] + out["mlp/ffw"] + out["mlp/proj"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = num_decoder_layer * out["block"]
    out["ln_f"] = n_embd  # final layernorm
    out["read_out"] = n_embd * vocab_size
    out["dense"] = 0  # kept at 0 to avoid double counting: the untied output projection is already counted as "read_out"

    # total
    out["total"] = out["embedding"] + out["transformer"] + out["ln_f"] + out["dense"] + out["read_out"]

    return out


params = count_params(
    num_decoder_layer=48,
    context_length=1024,
    n_embd=1600,
    ffw_size=6400,
    vocab_size=50257,
)


print(f"Total memory:{params['total']*4/1024/1024/1024}GB")
params = OrderedDict(sorted(params.items(), key=lambda x: x[1]))
data = {
    "Name": params.keys(),
    "Parameters": params.values(),
    "Ratio": [value/params["total"] for value in params.values()],
}
df = pd.DataFrame(data)
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))


Total memory:6.098955869674683GB
+--------------------+------------+-----------------------+
|        Name        | Parameters |         Ratio         |
+--------------------+------------+-----------------------+
|       dense        |     0      |          0.0          |
|    attention/ln    |    1600    | 9.772926062927871e-07 |
|       mlp/ln       |    1600    | 9.772926062927871e-07 |
|        ln_f        |    1600    | 9.772926062927871e-07 |
| embedding/position |  1638400   | 0.001000747628843814  |
|   attention/proj   |  2560000   | 0.0015636681700684594 |
|   attention/kqv    |  7680000   | 0.004691004510205378  |
|      mlp/ffw       |  10240000  | 0.0062546726802738374 |
|      mlp/proj      |  10240000  | 0.0062546726802738374 |
|     attention      |  10241600  |  0.00625564997288013  |
|        mlp         |  20481600  | 0.012510322653153967  |
|       block        |  30723200  |  0.0187659726260341   |
|  embedding/token   |  80411200  |  0.0491157945144566   |
|      

In [23]:
def count_flops(
    num_decoder_blocks: int = 12,
    context_length: int = 1024,
    n_embd: int = 768,
    n_head: int = 12,
    ffw_size: int = 3072,
    vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant
    # we count actual FLOPs, not MACs. Hence 2* all over the place
    # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D

    out = OrderedDict()
    head_size = n_embd // n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/kqv"] = 2 * context_length * (n_embd * 3 * n_embd)
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * context_length * context_length * n_embd
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * n_head * (context_length * context_length * head_size)
    # 4) the final linear projection
    out["attention/proj"] = 2 * context_length * (n_embd * n_embd)
    out["attention"] = sum(out["attention/" + k] for k in ["kqv", "scores", "reduce", "proj"])

    # MLP blocks
    ffw_size = 4 * n_embd  # feed forward size
    out["mlp/ffw1"] = 2 * context_length * (n_embd * ffw_size)
    out["mlp/ffw2"] = 2 * context_length * (ffw_size * n_embd)
    out["mlp"] = out["mlp/ffw1"] + out["mlp/ffw2"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = num_decoder_blocks * out["block"]
    out["dense"] = 2 * context_length * (n_embd * vocab_size)

    # forward,backward,total
    out["forward_total"] = out["transformer"] + out["dense"]
    # out["backward_total"] = 2 * out["forward_total"]  # use common estimate of bwd = 2*fwd
    # out["total"] = out["forward_total"] + out["backward_total"]

    # Keep insertion order here; callers sort when they want a ranked table
    return out

flops = count_flops(
    num_decoder_blocks=48,
    context_length=1024,
    n_embd=1600,
    n_head=25,
    ffw_size=6400,
    vocab_size=50257,
)

flops = OrderedDict(sorted(flops.items(), key=lambda x: x[1]))
data = {
    "Name": flops.keys(),
    "FLOPs": flops.values(),
    "Ratio": [value/flops["forward_total"] for value in flops.values()],
}
df = pd.DataFrame(data)
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))


+------------------+---------------+-----------------------+
|       Name       |     FLOPs     |         Ratio         |
+------------------+---------------+-----------------------+
| attention/scores |  3355443200   | 0.0009568653688557142 |
| attention/reduce |  3355443200   | 0.0009568653688557142 |
|  attention/proj  |  5242880000   | 0.0014951021388370535 |
|  attention/kqv   |  15728640000  |  0.00448530641651116  |
|     mlp/ffw1     |  20971520000  | 0.005980408555348214  |
|     mlp/ffw2     |  20971520000  | 0.005980408555348214  |
|    attention     |  27682406400  | 0.007894139293059642  |
|       mlp        |  41943040000  | 0.011960817110696428  |
|      block       |  69625446400  |  0.01985495640375607  |
|      dense       | 164682137600  |  0.04696209261970862  |
|   transformer    | 3342021427200 |  0.9530379073802914   |
|  forward_total   | 3506703564800 |          1.0          |
+------------------+---------------+-----------------------+
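
The per-block rows in the table above can be cross-checked by hand. As a sketch, write $T$ for `context_length`, $d$ for `n_embd`, $L$ for `num_decoder_blocks`, and $V$ for `vocab_size` (shorthand introduced here), and assume $d_{ff} = 4d$ and that $d$ is divisible by the number of heads. Counting $2mnp$ FLOPs for an $(m \times n)(n \times p)$ product, as `count_flops` does:

$$
\begin{aligned}
F_{\text{attn}} &= \underbrace{6Td^2}_{\text{Q, K, V projections}} + \underbrace{2T^2 d}_{QK^\top} + \underbrace{2T^2 d}_{\text{weights} \times V} + \underbrace{2Td^2}_{\text{output projection}} = 8Td^2 + 4T^2 d \\
F_{\text{mlp}} &= 2 \cdot 2\,T\,d\,d_{ff} = 16Td^2 \\
F_{\text{block}} &= F_{\text{attn}} + F_{\text{mlp}} = 24Td^2 + 4T^2 d \\
F_{\text{forward}} &= L \, F_{\text{block}} + 2TdV
\end{aligned}
$$

With $T = 1024$ and $d = 1600$, $24Td^2 = 62{,}914{,}560{,}000$ and $4T^2 d = 6{,}710{,}886{,}400$, which sum to the $69{,}625{,}446{,}400$ reported in the `block` row. Under the same assumptions, the share of each block spent on the two terms that are quadratic in $T$ is

$$
\frac{4T^2 d}{24Td^2 + 4T^2 d} = \frac{T}{T + 6d},
$$

roughly $9.6\%$ at $T = 1024$ and roughly $63\%$ at $T = 16{,}384$ for $d = 1600$.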


In [25]:
# GPT-2 small 

n_layer = 12
d_model = 768
n_head = 12
d_ff = 3072

vocab_size = 50257
context_length = 1024

# params = count_params(n_layer, context_length, d_model, d_ff, vocab_size)
flops = count_flops(n_layer, context_length, d_model, n_head, d_ff, vocab_size)

data = {
    "Name": flops.keys(),
    "FLOPs": flops.values(),
    "Ratio": [value/flops["forward_total"] for value in flops.values()],
}
df = pd.DataFrame(data)
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))


+------------------+--------------+----------------------+
|       Name       |    FLOPs     |        Ratio         |
+------------------+--------------+----------------------+
|  attention/kqv   |  3623878656  | 0.012425508965889174 |
| attention/scores |  1610612736  | 0.005522448429284077 |
| attention/reduce |  1610612736  | 0.005522448429284077 |
|  attention/proj  |  1207959552  | 0.004141836321963058 |
|    attention     |  8053063680  | 0.027612242146420385 |
|     mlp/ffw1     |  4831838208  | 0.016567345287852232 |
|     mlp/ffw2     |  4831838208  | 0.016567345287852232 |
|       mlp        |  9663676416  | 0.033134690575704465 |
|      block       | 17716740096  | 0.06074693272212485  |
|   transformer    | 212600881152 |  0.7289631926654981  |
|      dense       | 79047426048  |  0.2710368073345018  |
|  forward_total   | 291648307200 |         1.0          |
+------------------+--------------+----------------------+


In [27]:
# GPT-2 medium 

n_layer = 24
d_model = 1024
n_head = 16
d_ff = d_model * 4

vocab_size = 50257
context_length = 1024

# params = count_params(n_layer, context_length, d_model, d_ff, vocab_size)
flops = count_flops(n_layer, context_length, d_model, n_head, d_ff, vocab_size)

data = {
    "Name": flops.keys(),
    "FLOPs": flops.values(),
    "Ratio": [value/flops["forward_total"] for value in flops.values()],
}
df = pd.DataFrame(data)
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))


+------------------+--------------+-----------------------+
|       Name       |    FLOPs     |         Ratio         |
+------------------+--------------+-----------------------+
|  attention/kqv   |  6442450944  | 0.0077906071449402895 |
| attention/scores |  2147483648  |  0.00259686904831343  |
| attention/reduce |  2147483648  |  0.00259686904831343  |
|  attention/proj  |  2147483648  |  0.00259686904831343  |
|    attention     | 12884901888  | 0.015581214289880579  |
|     mlp/ffw1     |  8589934592  |  0.01038747619325372  |
|     mlp/ffw2     |  8589934592  |  0.01038747619325372  |
|       mlp        | 17179869184  |  0.02077495238650744  |
|      block       | 30064771072  |  0.03635616667638802  |
|   transformer    | 721554505728 |  0.8725480002333125   |
|      dense       | 105396568064 |  0.12745199976668756  |
|  forward_total   | 826951073792 |          1.0          |
+------------------+--------------+-----------------------+


In [26]:
# GPT-2 large 

n_layer = 36
d_model = 1280
n_head = 20
d_ff = 1280*4

vocab_size = 50257
context_length = 1024

# params = count_params(n_layer, context_length, d_model, d_ff, vocab_size)
flops = count_flops(n_layer, context_length, d_model, n_head, d_ff, vocab_size)

data = {
    "Name": flops.keys(),
    "FLOPs": flops.values(),
    "Ratio": [value/flops["forward_total"] for value in flops.values()],
}
df = pd.DataFrame(data)
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))


+------------------+---------------+-----------------------+
|       Name       |     FLOPs     |         Ratio         |
+------------------+---------------+-----------------------+
|  attention/kqv   |  10066329600  | 0.0056725435596688065 |
| attention/scores |  2684354560   | 0.0015126782825783482 |
| attention/reduce |  2684354560   | 0.0015126782825783482 |
|  attention/proj  |  3355443200   | 0.0018908478532229354 |
|    attention     |  18790481920  | 0.010588747978048438  |
|     mlp/ffw1     |  13421772800  | 0.007563391412891742  |
|     mlp/ffw2     |  13421772800  | 0.007563391412891742  |
|       mlp        |  26843545600  | 0.015126782825783483  |
|      block       |  45634027520  |  0.02571553080383192  |
|   transformer    | 1642824990720 |  0.9257591089379492   |
|      dense       | 131745710080  |  0.07424089106205083  |
|  forward_total   | 1774570700800 |          1.0          |
+------------------+---------------+-----------------------+


In [28]:
# GPT-2 XL 

n_layer = 48
d_model = 1600
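n_head = 48  # note: the GPT-2 XL config above lists 25 heads; 1600 // 48 = 33, so attention/reduce below comes out slightly under attention/scores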
d_ff = d_model*4

vocab_size = 50257
context_length = 16384

# params = count_params(n_layer, context_length, d_model, d_ff, vocab_size)
flops = count_flops(n_layer, context_length, d_model, n_head, d_ff, vocab_size)

data = {
    "Name": flops.keys(),
    "FLOPs": flops.values(),
    "Ratio": [value/flops["forward_total"] for value in flops.values()],
}
df = pd.DataFrame(data)
print(tabulate(df, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))


+------------------+-----------------+-----------------------+
|       Name       |      FLOPs      |         Ratio         |
+------------------+-----------------+-----------------------+
|  attention/kqv   |  251658240000   | 0.0018921053119957884 |
| attention/scores |  858993459200   | 0.006458386131612291  |
| attention/reduce |  850403524608   | 0.006393802270296168  |
|  attention/proj  |   83886080000   | 0.0006307017706652628 |
|    attention     |  2044941303808  |  0.01537499548456951  |
|     mlp/ffw1     |  335544320000   | 0.0025228070826610514 |
|     mlp/ffw2     |  335544320000   | 0.0025228070826610514 |
|       mlp        |  671088640000   | 0.005045614165322103  |
|      block       |  2716029943808  | 0.020420609649891612  |
|   transformer    | 130369437302784 |  0.9801892631947974   |
|      dense       |  2634914201600  |  0.01981073680520257  |
|  forward_total   | 133004351504384 |          1.0          |
+------------------+-----------------+-----------------------+