# Part 2: Optimizing the Model

So far, we've initialized the GPT-2 model as described in the paper. Now we want to optimize the model so that training runs faster and, potentially, reaches better performance. We'll start with the code we had at the end of the last part and build on that.

Be aware, though, that these optimizations are meant for a GPU; they are not going to work on the CPU.

In [2]:
# imports
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
import math

# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Bare expression so the notebook cell displays the selected device.
device

'cpu'

In [3]:
@dataclass
class GPTConfig:
    """Hyperparameters that size the GPT model built in this file.

    Every field has a default, so ``GPTConfig()`` is valid and individual
    values can be overridden by keyword, e.g. ``GPTConfig(vocab_size=50304)``.
    """

    # Longest sequence the model accepts; sizes the position-embedding table
    # and the causal attention mask.
    block_size: int = 1024
    # Number of token ids; sizes the token-embedding table and the output head.
    vocab_size: int = 50257
    # Number of transformer blocks stacked in the model.
    n_layer: int = 12
    # Number of attention heads per block; must divide n_embd evenly.
    n_head: int = 12
    # Width of the embeddings / residual stream.
    n_embd: int = 768

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit masked softmax.

    A single linear layer produces queries, keys and values, which are split
    into ``n_head`` heads. Attention scores are scaled by 1/sqrt(head_dim),
    positions above the diagonal are masked out so a token only attends to
    itself and earlier tokens, and the per-head results are merged and passed
    through an output projection.
    """

    def __init__(self, config:GPTConfig):
        super().__init__()
        # The embedding width must split evenly across the heads.
        assert config.n_embd % config.n_head == 0

        width = config.n_embd
        max_len = config.block_size

        # One fused projection for q, k and v, followed by the output projection.
        self.c_attn = nn.Linear(width, 3 * width)
        self.c_proj = nn.Linear(width, width)
        # Marker attribute: flags this layer for the scaled-down 1 / sqrt(N) init.
        self.c_proj.NANOGPT_SCALE_INIT = 1

        self.n_head = config.n_head
        self.n_embd = width

        # Lower-triangular matrix of ones, shaped (1, 1, block_size, block_size)
        # so it broadcasts over batch and heads. Stored as a buffer named "bias".
        lower_tri = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("bias", lower_tri.view(1, 1, max_len, max_len))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size() # Batch Size, Sequence Length, Embedding Dim
        head_dim = C // self.n_head

        # Project once, cut into q / k / v, then reshape each to (B, n_head, T, head_dim).
        q, k, v = (
            part.view(B, T, self.n_head, head_dim).transpose(1, 2)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        # Scaled dot-product scores, shape (B, n_head, T, T).
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Positions where the mask is zero are "future" tokens: give them -inf
        # so softmax assigns them zero weight.
        is_future = self.bias[:, :, :T, :T] == 0
        weights = F.softmax(scores.masked_fill(is_future, float('-inf')), dim=-1)

        # Weighted sum of values, then merge the heads back into (B, T, C).
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(merged)

class MLP(nn.Module):
    """Feed-forward sub-layer: expand to 4x width, tanh-approximated GELU, project back."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width

        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden, width)
        # Marker attribute: flags this layer for the scaled-down 1/sqrt(N) init.
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``; the last dimension of ``x`` is preserved."""
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    """One transformer block: LayerNorm -> attention and LayerNorm -> MLP, each with a residual add."""

    def __init__(self, config: GPTConfig):
        super().__init__()

        width = config.n_embd
        # Sub-modules are registered in the order they are applied in forward().
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the result."""
        # Normalization is applied to the branch input only; the residual
        # path carries x through unnormalized.
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

class GPT(nn.Module):
    """GPT-style language model: token + position embeddings, a stack of blocks, and an output head.

    The token-embedding matrix and the output head share one weight tensor.
    """

    def __init__(self, config:GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
            "wpe": nn.Embedding(config.block_size, config.n_embd),   # position embeddings
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_final": nn.LayerNorm(config.n_embd),
        })

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the embedding table and the output head use the same tensor.
        self.transformer.wte.weight = self.lm_head.weight

        # nn.Module.apply visits every submodule and calls _init_weights on it.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize one submodule: N(0, 0.02) weights, zero biases, scaled-down flagged projections."""
        if isinstance(module, nn.Embedding):
            # Embedding tables: zero mean, 0.02 stdev.
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)
            return

        if not isinstance(module, nn.Linear):
            # Everything else (e.g. LayerNorm) keeps the PyTorch default init.
            return

        # Linear layers: zero mean, 0.02 stdev; layers flagged with
        # NANOGPT_SCALE_INIT are shrunk by 1/sqrt(2 * n_layer).
        if hasattr(module, 'NANOGPT_SCALE_INIT'):
            std = 0.02 * (self.config.n_layer * 2) ** -0.5
        else:
            std = 0.02
        torch.nn.init.normal_(module.weight, mean=0, std=std)

        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (B, T).

        Returns ``(logits, loss)``. ``loss`` is the cross-entropy against
        ``targets`` when they are given, otherwise ``None``.
        """
        _, T = idx.size()

        # The position-embedding table only has block_size rows.
        assert T <= self.config.block_size

        positions = torch.arange(0, T, dtype=torch.long, device=idx.device)
        # Position embeddings broadcast across the batch dimension.
        x = self.transformer.wpe(positions) + self.transformer.wte(idx)

        for layer in self.transformer.h:
            x = layer(x)

        logits = self.lm_head(self.transformer.ln_final(x))

        if targets is None:
            return logits, None

        # Flatten batch and time so cross_entropy sees (B*T, vocab) vs (B*T,).
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

In [4]:
# Dataloader
class DataLoaderLite:
    """Minimal sequential data loader over a single tokenized text file.

    The file is read and tokenized once with the GPT-2 tiktoken encoding.
    ``next_batch`` then walks through the token stream in chunks of ``B * T``
    tokens and wraps around to the start when a full batch no longer fits.

    Args:
        B: Batch size (number of sequences per batch); must be positive.
        T: Sequence length (tokens per sequence); must be positive.
        input_path: Text file to load. Defaults to ``'input.txt'``.

    Raises:
        ValueError: If ``B`` or ``T`` is not positive, or if the file holds
            fewer than ``B * T + 1`` tokens (not enough for one batch).
    """

    def __init__(self, B, T, input_path='input.txt'):
        # Fail early with a clear message instead of a ZeroDivisionError in
        # the epoch-size print below.
        if B <= 0 or T <= 0:
            raise ValueError(f"B and T must be positive, got B={B}, T={T}")

        self.B = B
        self.T = T

        # Explicit encoding: the default depends on the platform locale.
        with open(input_path, 'r', encoding='utf-8') as f:
            text = f.read()

        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)

        # One batch needs B*T inputs plus one extra token for the shifted targets.
        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f"{input_path!r} has {len(self.tokens)} tokens; "
                f"at least {B * T + 1} are needed for one batch with B={B}, T={T}"
            )

        print(f"1 Epoch = {len(self.tokens) // (B * T)} Batches") # In one epoch, we're going to see these many batches, and then start again

        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each of shape (B, T).

        ``y`` is ``x`` shifted left by one token, i.e. the next-token targets.
        """
        B, T = self.B, self.T
        span = B * T
        start = self.current_position

        # Take one extra token so the targets can be shifted by one.
        buf = self.tokens[start:start + span + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)

        self.current_position = start + span

        # Start again if the next full batch would run past the end of the dataset.
        if self.current_position + span + 1 > len(self.tokens):
            self.current_position = 0

        return x, y

In [21]:
# training loop
# Seed the CPU (and, if present, the CUDA) RNG so the run is reproducible.
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

# nn.Module.to returns the module itself, so construction and the device move can be chained.
model = GPT(GPTConfig()).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
train_loader = DataLoaderLite(B=4, T=32)

# Take 50 optimization steps, one batch per step
for step in range(50):
    x, y = (t.to(device) for t in train_loader.next_batch())
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
print(f"After 50 steps, Loss is: {loss.item()}")

1 Epoch = 2640 Batches
After 50 steps, Loss is: 6.799213886260986


When optimizing, you should always start with:

1. What hardware do you have?
2. What does it offer?
3. Are you fully utilizing it?

In PyTorch, by default, all tensors are of dtype float32. That is, every element of every tensor occupies 32 bits of memory, including parameters, activations, etc. Empirically, for deep learning, a 32-bit float representation is more than is needed: you can lower the precision and still get good enough models. This can speed things up quite a bit because there is much less memory to move around, and memory bandwidth is the bottleneck for GPU workloads. That is, most of the time the tensor cores just sit idle because you're moving memory around to feed them. If you're getting 60% hardware utilization, you're doing quite well.

For training, you still want floats (i.e. not integers), but precision can be lower. However, during inference time, you can use integer precision and still get decent results.

Inspect your hardware and GPU. How many TFLOPS does it offer theoretically? 1 TFLOPS = 1 trillion floating-point operations per second. If you lower the precision, the available FLOPS increase quite a bit.

### Tensor Cores

What are tensor cores? Tensor cores are basically simple instructions for the GPU that perform a $4 \times 4$ matrix multiplication. That is, when you pass a big matrix multiplication to the GPU, it breaks it down into these $4 \times 4$ units and performs the small matmuls in parallel. And deep learning is mostly matrix multiplication!

For reference, look at the white paper on the GPU architecture that you are using.

### Floating Point Precisions

1. FP32 is 32 bits
2. TF32 is again stored in 32 bits. It is exactly the same as FP32, except that the mantissa is truncated by 13 bits; the exponent stays the same. Thus, only 19 bits are used. This is all handled internally in hardware, so nothing has to change in your code, which means the dtype in your code is still float32. Empirically, this is almost the same as FP32.
3. FP16: This truncates the exponent as well as the mantissa. This is bad! You can lose a lot of important information during training with this, because the range of representable numbers itself changes. Historically, this format came first, which is why it is still around. It requires gradient scalers.
4. BF16: The exponent is the same as in FP32, but the mantissa is cropped even further, to only 7 bits. The range of numbers is the same, but the precision is lower. This doesn't require any gradient scalers, etc. We lose some precision, but it runs faster, so you can train for longer to make up for that. `autocast` in PyTorch does this. Not all tensors are converted to bfloat16: normalization, loss calculations, etc. are not converted. PyTorch has some internal rules for this.

### GPU and CPU

The CPU just orders work on the GPU. That is, it queues up a lot of work on the GPU. So, if you want to wait for the GPU to finish its execution, you need to use: `torch.cuda.synchronize()`.

Also, you don't want to let the space on GPU go free. So keep increasing the batch size by powers of two till you can squeeze your model training on the GPU (i.e. basically, right before you get out of memory errors).

When measuring speed, tokens processed per second is a more objective measure than time.

### Timing: Caveat

1. If you're timing your work, be careful to use `torch.cuda.synchronize()`.
2. Don't rely too much on the first time. Because PyTorch internally may be doing a lot of initialization, etc., which may slow the process. 

In [7]:
# Allow TF32 for float32 matmuls. Include at the top of your script
torch.set_float32_matmul_precision('high')

import time

train_loader = DataLoaderLite(B=1, T=1024)
model = GPT(GPTConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for i in range(50):
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    t0 = time.perf_counter()
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()

    optimizer.step()
    # The CPU only queues work on the GPU; wait for it to finish before
    # stopping the clock, otherwise the measured time is too short.
    if device == 'cuda':
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    # Throughput is easier to compare across batch sizes than raw step time.
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)

    print(f"Step {i} Loss: {loss.item()} Tok / Sec: {tokens_per_sec}")

1 Epoch = 330 Batches
Step 0 Loss: 10.953959465026855 Tok / Sec: 111.19398488516616
Step 1 Loss: 9.50391674041748 Tok / Sec: 118.7039687585358
Step 2 Loss: 9.123509407043457 Tok / Sec: 117.93089306876011
Step 3 Loss: 8.86156177520752 Tok / Sec: 118.89022221651372
Step 4 Loss: 8.8906831741333 Tok / Sec: 116.95544750490332
Step 5 Loss: 8.379107475280762 Tok / Sec: 120.66237026694736
Step 6 Loss: 8.229143142700195 Tok / Sec: 117.54321952043479
Step 7 Loss: 8.006213188171387 Tok / Sec: 116.16962268089722
Step 8 Loss: 7.927321910858154 Tok / Sec: 119.23626511306877
Step 9 Loss: 7.698894500732422 Tok / Sec: 117.02392237857198
Step 10 Loss: 7.291344165802002 Tok / Sec: 117.91372372453168
Step 11 Loss: 7.626772403717041 Tok / Sec: 120.76976891744894
Step 12 Loss: 7.092843532562256 Tok / Sec: 119.29011028572712
Step 13 Loss: 7.251816749572754 Tok / Sec: 120.72310333425415
Step 14 Loss: 6.97932767868042 Tok / Sec: 121.88466098665633
Step 15 Loss: 6.898041725158691 Tok / Sec: 117.43754670194578
S

### torch.compile

This compiles your PyTorch code- speeds up the code that is slower due to python overhead and GPU read-write ( finds out optimizations in your code and applies them )! So makes your code go much faster. Compilation may take some time but the running of the model will be much faster.

Kernel Fusion:

Basically, you want to avoid moving data back and forth between HBM (the GPU's main memory) and the GPU's compute units again and again. So you want to look for multiple operations on the same data that you can perform without moving the memory again. That is, move it once, do a bunch of operations, and move it back (as opposed to moving it once, doing one operation, moving it back, and repeating).

Almost always use `torch.compile` unless debugging.

I've also included the bfloat16 code in here.

In [None]:
import time

train_loader = DataLoaderLite(B=1, T=1024)
model = GPT(GPTConfig())
model.to(device)
# Compile once up front; the first steps are slower while compilation happens.
model = torch.compile(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for i in range(5):
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    t0 = time.perf_counter()
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad()

    # Only loss calc and forward pass needs to be wrapped in here. Nothing else. Again, will work with GPUs
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = model(x, y)
    loss.backward()

    optimizer.step()
    # The CPU only queues work on the GPU; wait for it to finish before
    # stopping the clock, otherwise the measured time is too short.
    if device == 'cuda':
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    # Throughput is easier to compare across batch sizes than raw step time.
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)

    print(f"Step {i} Loss: {loss.item()} Tok / Sec: {tokens_per_sec}")

### Flash Attention

Not all optimizations are found by `torch.compile`. For example, it does not find flash attention. Flash attention is an exact, full self-attention, but it is architecture-aware. Flash attention is again a kernel-fusion operation, and so it is faster. To add flash attention, you have to change a bit of code in the `CausalSelfAttention` class. Flash attention fuses $\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$ into one fused operation. It's a different algorithm, and thus `torch.compile()` doesn't find it. Further, flash attention does more FLOPs than the standard self-attention implementation, but since it is a kernel-fusion operation, it does far fewer reads and writes to the HBM.

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention computed with ``F.scaled_dot_product_attention``.

    A single linear layer produces queries, keys and values, which are split
    into ``n_head`` heads. The attention itself is delegated to PyTorch's
    fused kernel with ``is_causal=True``, and the per-head results are merged
    and passed through an output projection.
    """

    def __init__(self, config:GPTConfig):
        super().__init__()
        # The embedding width must split evenly across the heads.
        assert config.n_embd % config.n_head == 0

        width = config.n_embd
        max_len = config.block_size

        # One fused projection for q, k and v, followed by the output projection.
        self.c_attn = nn.Linear(width, 3 * width)
        self.c_proj = nn.Linear(width, width)
        # Marker attribute: flags this layer for the scaled-down 1 / sqrt(N) init.
        self.c_proj.NANOGPT_SCALE_INIT = 1

        self.n_head = config.n_head
        self.n_embd = width

        # Lower-triangular mask, shaped (1, 1, block_size, block_size).
        # NOTE: forward() below no longer reads this buffer; it is still
        # registered here under the name "bias".
        lower_tri = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("bias", lower_tri.view(1, 1, max_len, max_len))

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size() # Batch Size, Sequence Length, Embedding Dim
        head_dim = C // self.n_head

        # Project once, cut into q / k / v, then reshape each to (B, n_head, T, head_dim).
        q, k, v = (
            part.view(B, T, self.n_head, head_dim).transpose(1, 2)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        # Flash attention: one fused call replaces the explicit four-step version
        #   scores = (q @ k^T) / sqrt(head_dim)  ->  mask future positions with -inf
        #   ->  softmax over the last dim  ->  weights @ v
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Merge the heads back into (B, T, C) before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(merged)

### Ugly Numbers, Nice Numbers

**Nice numbers:** The numbers that have many powers of 2 in them. That is, you can divide those numbers by two a lot of times. 64 is awesome. 24 is also decent.

**Ugly Numbers:** The opposite of nice numbers- the numbers that cannot be divided by 2 a lot of times. Prime numbers are the worst, odd numbers are also bad. For example, 13 and 25 are ugly numbers. 

The reason we're discussing this is that a lot of kernels in CUDA are implemented in terms of powers of two (*block tiles*). But if your input is not a nice number, then CUDA kernels have boundary kernels written for the remaining part and these boundary kernels can be a lot slower.

The simple fix for this: scan your code for nice and ugly numbers. If you find ugly numbers, replace them with the next nice number above them (for example, the next power of two, or the next multiple of a large power of two such as 64 or 128). Yes, you will be increasing the number of parameters in your model, and the number of FLOPs will increase, but your code will still run much faster.

In our code, `vocab_size` = 50,257, which is a very ugly number. So we increase it to 50,304, which has a lot of powers of two in it (it is divisible by 128). (Note: you can't randomly change the sizes and hope that all works well. You may have to check whether it breaks anything. In this case it doesn't, so we can simply replace `vocab_size` with the new one.) We're effectively adding fake tokens, and the model will drive their probability to zero. Because we made it a nice number, we will do more FLOPs but still get a faster run time.

`torch.compile()` doesn't find this also.

Depending on the PyTorch version, the impact of these numbers can be huge.

In [None]:
train_loader = DataLoaderLite(B=1, T=1024)

# Overwrite vocab_size with 50304 (the default is 50257); build the model,
# move it to the device, then compile it.
padded_config = GPTConfig(vocab_size=50304)
model = torch.compile(GPT(padded_config).to(device))

## Rest of the code can stay the same