In [None]:
#| default_exp speedup

In [None]:
#| export
import random, math, torch, numpy as np, matplotlib.pyplot as plt
from tinyai.model import *
from tinyai.learner import *
from tinyai.hooks import *
from tinyai.init import *
import fastcore.all as fc
from functools import partial
import time

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
import tiktoken
import os

# GPT-2 BPE tokenizer used by `get_tokens` below.
enc = tiktoken.get_encoding("gpt2")

def get_tokens(input_file, encoder=None):
    """Read a text file and return its token ids.

    Args:
        input_file: Path to a UTF-8 encoded text file.
        encoder: Object with an ``encode(str) -> list[int]`` method.
            Defaults to the module-level GPT-2 tokenizer ``enc``.

    Returns:
        The token ids produced by ``encoder.encode`` on the whole file text.
    """
    if encoder is None:
        encoder = enc
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(input_file, encoding="utf-8") as f:
        text = f.read()
    return encoder.encode(text)

# Corpus lives next to the notebook's working directory.
cwd = os.getcwd()
input_file = f"{cwd}/fast-nanogpt/input.txt"

# Keep only the first 20k tokens so each benchmark epoch is short,
# then split 80/20 into train / validation.
N_TOKENS, TRAIN_FRAC = 20000, 0.8
tokens = get_tokens(input_file)[:N_TOKENS]
n_train = int(len(tokens) * TRAIN_FRAC)
train = tokens[:n_train]
valid = tokens[n_train:]

In [None]:
SEQ_LEN, BATCH_SIZE = 512, 4

# Training set only: the `fit` helper below trains with valid=False,
# so `valid` is intentionally not wrapped in a DataSet here.
tds = DataSet(torch.tensor(train), T=SEQ_LEN)
dls = DataLoaders.from_dd([tds, None], batch_size=BATCH_SIZE)

# Peek at one batch to sanity-check shapes and dataset/loader lengths.
x, y = next(iter(dls.train))
x.shape, y.shape, len(tds), len(dls.train)

In [None]:
stats = ActivationStats(fc.risinstance(Block))
cbs = [TrainCB(), InitWeightsCB(), DeviceCB(), MetricsCB(), ProgressCB()]

def fit(model, epochs=1, xtra_cbs=None, base_cbs=None, lr=3e-4):
    """Train `model` on the training set only and return the Learner.

    Args:
        model: The model to train.
        epochs: Number of epochs to run.
        xtra_cbs: Optional extra callbacks appended after the base callbacks
            (e.g. a timing callback).
        base_cbs: Base callback list. When None, the module-level ``cbs`` is
            read at call time, so rebinding ``cbs`` in a later cell changes the
            callbacks used by subsequent calls.
        lr: Learning rate passed to the Learner (optimizer is AdamW).

    Returns:
        The fitted ``Learner``.
    """
    if base_cbs is None:
        # Deliberately looked up at call time (see docstring).
        # NOTE(review): the same callback instances are reused by every call;
        # this assumes they reset their state in before_fit — confirm.
        base_cbs = cbs
    lrn = Learner(model, dls=dls, opt_func=optim.AdamW, cbs=base_cbs + fc.L(xtra_cbs), lr=lr)
    lrn.fit(epochs, valid=False)
    return lrn

In [None]:
??get_model

In [None]:
#| export
import time

class TimeCallback(Callback):
    """Print per-batch wall-clock time and token throughput.

    The clock starts in ``before_batch`` and stops in ``after_batch``. CUDA
    kernels run asynchronously, so the device is synchronized before each
    clock read. Without that, the measurement would mostly capture
    kernel-launch overhead rather than the GPU work done for the batch.
    """

    def before_batch(self, learn):
        self._sync()
        self.t0 = time.perf_counter()

    def _log(self, d):
        # NOTE(review): no-op and unused within this notebook; kept in case
        # something external calls it — confirm before removing.
        pass

    @staticmethod
    def _sync():
        # Block until all queued CUDA work has finished (no-op without CUDA).
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def after_batch(self, learn):
        self._sync()
        elapsed = time.perf_counter() - self.t0  # seconds
        dt = elapsed * 1000  # milliseconds
        x, _ = learn.batch
        n_tokens = x.shape[0] * x.shape[1]  # batch size * sequence length
        # Guard against a zero-length interval on coarse clocks.
        tokens_per_sec = n_tokens / elapsed if elapsed > 0 else float("inf")

        print(
            f"step {learn.iter}, loss: {learn.loss.item():.2f}, time: {dt:.2f}ms, tok/sec: {tokens_per_sec:.0f}"
        )

## Baseline

In [None]:
SEED = 1337
set_seed(SEED)
# NOTE(review): `get_model` here is presumably the one from the tinyai star
# imports; it is redefined later in this notebook.
model = get_model()
fit(model, xtra_cbs=[TimeCallback()])

## TODO: explain the floating-point dtypes (FP32, TF32, BF16, FP16)
![](https://devblogs.nvidia.com/wp-content/uploads/2020/05/TensorFloat32-TF32.jpg)

In [None]:
# Use TensorFloat32 (TF32) for float32 matmuls.
# TF32 tensor cores exist only on Ampere and newer GPUs. When no faster
# algorithm is available, PyTorch computes float32 matmuls at full
# ("highest") precision, so this call is harmless on older hardware.
torch.set_float32_matmul_precision('high')

In [None]:
# Release memory held by the previous run before the next benchmark.
# NOTE(review): `clean_mem` comes from tinyai; presumably it runs gc and
# empties the CUDA cache — confirm.
clean_mem()

In [None]:
# Reseed so this run starts from the same initialization as the baseline,
# keeping the loss values comparable across experiments.
set_seed(1337)
model = get_model()
fit(model, xtra_cbs=[TimeCallback()])

Enable [automatic mixed precision (AMP)](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-torch-autocast).

Lower precision can speed up both training and inference.
**Precision support matrix**

|             | Ampere                                       | Turing                 | Volta                  |
|-------------|----------------------------------------------|------------------------|------------------------|
| Tensor Core | FP64, TF32, bfloat16, FP16, INT8, INT4, INT1 | FP16, INT8, INT4, INT1 | FP16                   |
| CUDA® Core  | FP64, FP32, FP16, bfloat16, INT8             | FP64, FP32, FP16, INT8 | FP64, FP32, FP16, INT8 |

In [None]:
#| export
# Reduced-precision dtype used for autocast: bfloat16 when the current CUDA
# device supports it, otherwise float16 (despite the name, it may be bfloat16).
# `torch.cuda.is_available()` is checked first: on older PyTorch versions
# `is_bf16_supported()` queries the current CUDA device unconditionally and
# raises on CPU-only machines, which would break importing this module.
torch_dtype_float16 = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

In [None]:
#| export
class MixedPrecisionTrainCB(TrainCB):
    """TrainCB variant that runs the model call (which returns preds and loss) under autocast.

    Autocast uses ``torch_dtype_float16`` and is enabled only while
    ``learn.training`` is true; otherwise the model runs without autocast.
    """

    def predict(self, learn):
        # torch.autocast expects a bare device *type* ("cuda", "cpu", ...).
        # torch.device(...).type accepts either a string or a torch.device and
        # drops any ":<index>" suffix such as "cuda:0".
        device_type = torch.device(default_device).type
        # NOTE(review): when torch_dtype_float16 is float16 (no bf16 support),
        # training usually also needs loss scaling (torch.amp.GradScaler) to
        # avoid gradient underflow; this class adds none — confirm TrainCB's
        # backward/step handle that if fp16 is expected.
        with torch.autocast(device_type=device_type, enabled=learn.training, dtype=torch_dtype_float16):
            learn.preds, learn.loss = learn.model(*learn.batch)


In [None]:
# Does the current CUDA device support bfloat16? (Decides torch_dtype_float16.)
torch.cuda.is_bf16_supported()

In [None]:
# Swap the plain TrainCB for the autocast version; `fit` reads this global `cbs`.
cbs = [MixedPrecisionTrainCB(), InitWeightsCB(), DeviceCB(), MetricsCB(), ProgressCB()]

In [None]:
# Train a freshly initialized model (same seed as the baseline) instead of
# continuing from the model already trained in the TF32 cell, so the
# mixed-precision run is directly comparable with the earlier ones.
set_seed(1337)
model = get_model()
fit(model, xtra_cbs=[TimeCallback()])

## TODO: `torch.compile`
1. GELU example
2. Why compile? Explain the HBM ↔ SM round trip
3. Troubleshooting: little or no speedup on older cards

In [None]:
# Reseed for an initialization comparable with the other runs.
set_seed(1337)
model = get_model()
# torch.compile returns a wrapped module and compiles lazily on the first
# forward pass, so the first training step will be noticeably slower.
model = torch.compile(model)

In [None]:
# Expect the first step(s) to be slow while torch.compile traces and compiles.
fit(model, xtra_cbs=[TimeCallback()])

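`torch.compile` primarily helps memory-bandwidth-bound workloads. By fusing operations, it reduces the data round trips between HBM (GPU main memory) and the SMs (streaming multiprocessors). In such workloads the SMs are so fast that they spend much of their time waiting for data to arrive from HBM. Older cards may see little speedup because they are bottlenecked by compute rather than by memory bandwidth.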

See https://huggingface.co/docs/transformers/perf_torch_compile for compile speed up benchmarks.

In [None]:
#| export
class CompileCB(Callback):
    """Wrap the learner's model with ``torch.compile`` at the start of fitting."""

    def before_fit(self, learn):
        # Modules returned by torch.compile keep the original under `_orig_mod`.
        # Skip those so a repeated fit (or an already-compiled model) is not
        # wrapped a second time.
        if hasattr(learn.model, "_orig_mod"):
            return
        learn.model = torch.compile(learn.model)

## Flash attention

Flash attention is more memory efficient: it never materializes the full (T, T) attention matrix in GPU memory.


In [None]:
#| export
class FastCausalSelfAttention(CausalSelfAttention):
    """Causal self-attention backed by ``F.scaled_dot_product_attention``.

    Reuses the parent's ``c_attn``/``c_proj`` projections and ``n_head``/``n_embd``
    settings. It replaces the hand-written softmax(q @ k^T / sqrt(hs)) @ v with
    PyTorch's fused SDPA call (``is_causal=True``). That lets PyTorch select an
    efficient kernel such as FlashAttention instead of materializing the full
    (T, T) attention matrix.
    """

    def forward(self, x):
        n_batch, seq_len, n_channels = x.size()  # (B, T, C); C = n_head * head size
        head_size = n_channels // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, T, nh, hs) -> (B, nh, T, hs)
            return t.view(n_batch, seq_len, self.n_head, head_size).transpose(1, 2)

        # One fused projection yields q, k and v side by side along the channel dim.
        q, k, v = map(split_heads, self.c_attn(x).split(self.n_embd, dim=2))

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (B, nh, T, hs)
        # Merge heads back: (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(n_batch, seq_len, n_channels)
        return self.c_proj(out)

In [None]:
# Fresh model using SDPA-based attention, reseeded like the other runs so
# results stay comparable.
set_seed(1337)
model = GPT(GPTConfig(), proj=ResidualLinear, attn=FastCausalSelfAttention)

In [None]:
# CompileCB wraps the model with torch.compile in before_fit.
fit(model, xtra_cbs=[CompileCB(), TimeCallback()])

## Use kernel-friendly numbers

Many CUDA kernels process data in blocks (tiles) whose sizes are powers of 2. If a dimension is not a multiple of the block size, the kernel handles all the full blocks and then has to do extra, often inefficient, work to handle the leftover remainder.

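So scan the model code for its sizes. A number divisible by a large power of 2 is a "nice" number; otherwise it is an "ugly" one. Where possible, round an ugly number up to the next multiple of a power of 2. For example, GPT-2's vocabulary of 50257 tokens becomes 50304 = 393 × 128, as in `get_model` below.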


In [None]:
??GPTConfig

In [None]:
#| export
def get_model(vocab_size=50304):
    """Build the GPT model used by the speed-up benchmarks.

    Args:
        vocab_size: Vocabulary size passed to ``GPTConfig``. The default, 50304,
            is GPT-2's 50257-token vocabulary rounded up to the next multiple of
            128 (393 * 128). The extra ids are never produced by the GPT-2
            tokenizer; they only make the vocabulary dimension kernel-friendly
            (assuming GPT uses vocab_size for its embedding/output layers).

    Returns:
        A ``GPT`` built with ``ResidualLinear`` projections and
        ``FastCausalSelfAttention``.
    """
    return GPT(GPTConfig(vocab_size=vocab_size), proj=ResidualLinear, attn=FastCausalSelfAttention)

In [None]:
# Reseed for an initialization comparable with the other runs, then display the model.
set_seed(1337)
model = get_model()
model

In [None]:
# Same compiled benchmark as before, now with get_model's padded vocab size.
fit(model, xtra_cbs=[CompileCB(), TimeCallback()])