# Benchmarking

This notebook provides official time and memory benchmarking experiments on Switch vs. vanilla decoder transformers and various architectural and training hyperparameters.

In [1]:
import gc
import itertools
import json
import random
import time

import warnings

warnings.simplefilter(action="ignore", category=(FutureWarning, UserWarning))

# !pip -q install einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.profiler import profile, record_function, ProfilerActivity
from tqdm import tqdm

# Models and data setup

In [2]:
path = "./"
root = path

SEED = 23

LR = 1e-3
BATCH_SIZE = 16
BENCH_ITERS = 100  # max num batches to benchmark

#### ^ 2000 later

PRINT_ITERS = 100  # frequency to print avg train loss
EVAL_ITERS = 250  # frequency to evaluate val loss
EVAL_ITER_COUNT = 50  # number of batches to estimate val loss with
# given a 10% val split, we have 111540 char, so 50 batches * batch size 16 * if seq len 128
# = roughly equal to all chars chosen

# automatic mixed precision (will be disabled if CPU, not available)
USE_AMP = True

### will not benchmark/modify: switch transformer auxiliary loss coefficient
AUX_LOSS_COEF = 0.01

In [3]:
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

root = root if drive is None else "/content/drive/MyDrive/self-learn/moe-kit"
path = root

# cannot train in mixed precision on CPU (GradScaler needs cuda)
USE_AMP = USE_AMP if device.type == "cuda" else False
# Tesla T4 does not support bfloat16, CPU does not support float16
AMP_DTYPE = torch.float16 if device.type == "cuda" else torch.bfloat16

In [4]:
import sys

sys.path.append(root)

from utils import set_seed
from models.transformer import MLP, Transformer

In [5]:
set_seed(SEED)

In [6]:
with open(f"{root}/data/tiny-shakespeare.txt", "r") as f:
    text = f.read()

chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)

# Prepare mappings / tokenizer
# create a mapping from characters to integers
txt2idx = {ch: i for i, ch in enumerate(chars)}
idx2txt = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [txt2idx[c] for c in s]
decode = lambda l: "".join([idx2txt[i] for i in l])

# tokenizer data
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # 90-10 split
train_data = data[:n]
val_data = data[n:]


def get_batch(split, seq_len):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - seq_len, (BATCH_SIZE,))
    x = torch.stack([data[i : i + seq_len] for i in ix])
    y = torch.stack([data[i + 1 : i + seq_len + 1] for i in ix])
    return x.to(device), y.to(device)

In [7]:
def calc_ce_loss(logits, targets):
    """
    TODO
    Computes cross-entropy loss.
    Inputs:
        -logits: Model output of shape (B, S, vocab_size)
        -counts:
    """
    B, S, C = logits.shape
    logits = logits.view(B * S, C)
    targets = targets.view(B * S)
    loss = F.cross_entropy(logits, targets)
    return loss

In [8]:
def calc_aux_loss(counts, prob_sum, n_experts):
    """
    Computes Switch Transformer auxiliary loss.
    Inputs:
        -counts: Number of tokens passed to each expert in each switch layer (num_switch_layers x n_experts)
        Note this is NOT equivalent to n_layer; num_switch_layers depends on `switch_first` and `every_n_switch`
        -prob_sum: Sum of probs across all tokens for each expert (num_switch_layers x n_experts)
    """

    # total number of tokens routed in that layer
    token_count = counts.sum(dim=-1, keepdims=True)

    # prop of tokens dispatched to each expert
    route_frac = counts / token_count

    # fraction of total probability allocated for each expert
    # recall prob_sum := softmaxed probs, which added to 1 across the experts for each token
    # we divide by num_tokens so that the overall 2D scalar sum of prob_frac is 1
    # intuitively we are forcing the total prob for each layer across the experts to be 1 so we can take proportions,
    # the same way as above
    prob_frac = prob_sum / token_count

    # Auxiliary loss
    # L = N \sum_{i=1}^N f_i • P_i
    aux_loss = n_experts * (route_frac * prob_frac).sum()
    return aux_loss

In [9]:
@torch.no_grad()
def estimate_loss(model, val_losses, val_aux_losses, device):
    model.eval()
    losses = torch.zeros(EVAL_ITER_COUNT)
    aux_losses = torch.zeros(EVAL_ITER_COUNT)
    for k in range(EVAL_ITER_COUNT):
        inputs, targets = get_batch("test", model.seq_length)
        with torch.autocast(device_type=device.type, dtype=AMP_DTYPE, enabled=USE_AMP):
            if model.switch:
                logits, counts, prob_sum, n_dropped = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
                aux_losses[k] = calc_aux_loss(counts, prob_sum, model.n_experts)
                losses[k] += AUX_LOSS_COEF * aux_losses[k]
            else:
                logits = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
    val_loss, val_aux_loss = losses.mean().item(), aux_losses.mean().item()
    val_losses.append(val_loss)
    val_aux_losses.append(val_aux_loss)  # track separate aux loss for logging
    # keep model in eval
    print(f"Est. Val Loss: {val_loss:.3f} | Est. Aux Loss: {val_aux_loss:.3f}")
    return val_losses, val_aux_losses

In [10]:
def generate(model, empty_tokens, num_tokens):
    generation_text = decode(
        model.generate(
            empty_tokens, method="top-k", k=5, max_new_tokens=num_tokens, uncond=True
        )[0].tolist()
    )
    return generation_text

## Benchmarking functions

In [11]:
def get_model(config):
    set_seed(SEED)
    model = Transformer(
        VOCAB_SIZE,
        config["seq_len"],
        config["n_embd"],
        config["n_head"],
        config["n_embd"] * config["n_ff_ratio"],
        config["n_layer"],
        device=device,
        n_kv_head=config["n_kv_head"],
        norm_first=config["norm_first"],
        use_rotary_embd=config["use_rotary_embd"],
        softmax_off_by_one=config["softmax_off_by_one"],
        switch=config["switch"],
        switch_first=config["switch_first"],
        every_n_switch=config["every_n_switch"],
        capacity_factor=config["capacity_factor"],
        drop_tokens=config["drop_tokens"],
        n_experts=config["n_experts"],
        expert=MLP,
        use_amp=USE_AMP,
        amp_dtype=AMP_DTYPE,
        activation=config["activation"],
        mlp_dropout=config["mlp_dropout"],
        expert_dropout=config["expert_dropout"],
        scale=config["scale"],
    )
    return model

In [12]:
def bench_train(config, device):

    # Initialize model and training setup
    model = get_model(config)

    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    # num params
    num_params = sum([p.numel() for p in model.parameters()])
    if not model.switch:
        num_active_params = num_params
    else:
        switch_additional_params = 0
        switch_layer_names = [f"experts.{i}" for i in range(1, model.n_experts)] + [
            "switch"
        ]
        for (name, layer) in model.named_parameters():
            # if experts.i (i > 0) or switch in name, subtract from total count
            if any(substr in name for substr in switch_layer_names):
                switch_additional_params += layer.numel()
        num_active_params = num_params - switch_additional_params

    # model size in bytes
    model_size = sum(
        [
            p.numel() * p.dtype.itemsize
            for p in itertools.chain(model.parameters(), model.buffers())
        ]
    )

    # GQA
    gqa = model.is_gqa

    # Train fn
    train_losses = []
    val_losses = []
    train_times = []
    # below two are only relevant if benchmarking switch
    val_aux_losses = []
    dropped = []

    model.train()
    model.to(device)

    # text generation setup
    empty_tokens = torch.zeros((1, 1), dtype=torch.long).to(device)

    for step in range(BENCH_ITERS):

        start = time.perf_counter()

        optimizer.zero_grad(set_to_none=True)
        inputs, targets = get_batch("train", model.seq_length)

        with torch.autocast(device_type=device.type, dtype=AMP_DTYPE, enabled=USE_AMP):
            if model.switch:
                logits, counts, prob_sum, n_dropped = model(inputs)
                loss = calc_ce_loss(logits, targets)
                aux_loss = calc_aux_loss(counts, prob_sum, model.n_experts)
                loss += AUX_LOSS_COEF * aux_loss
                drop_frac = (
                    np.array(n_dropped) / (BATCH_SIZE * model.seq_length)
                ).tolist()
                dropped.append(drop_frac)
            else:
                logits = model(inputs)
                loss = calc_ce_loss(logits, targets)

        train_time = time.perf_counter() - start
        tokens_per_sec = (1 / train_time) * BATCH_SIZE * model.seq_length
        train_times.append(tokens_per_sec)

        train_losses.append(loss.item())

        scaler.scale(loss).backward()

        # Monitor gradient norm
        scaler.unscale_(optimizer)

        with torch.autocast(device_type=device.type, dtype=AMP_DTYPE, enabled=USE_AMP):
            grads = [
                param.grad.detach().flatten()
                for param in model.parameters()
                if param.grad is not None
            ]
            norm = torch.cat(grads).norm()

        scaler.step(optimizer)
        scaler.update()

        # print training statistics
        if step % PRINT_ITERS == 0 and step != 0:
            print(
                f"Step {step}/{BENCH_ITERS} | Running Avg Train Loss: {np.mean(train_losses):.3f} |",
                f"Grad Norm: {norm:.2f} | Running Avg Tokens/Sec: {np.mean(train_times):.2f} |",
                f"Bandwidth: {model_size * np.mean(train_times) / 1e9:.2f} GB/s",
            )

        # estimate val loss
        if step % EVAL_ITERS == 0 and step != 0:
            val_losses, val_aux_losses = estimate_loss(
                model, val_losses, val_aux_losses, device
            )
            model.train()

    generated_text = generate(model, empty_tokens, num_tokens=500)

    mem_data = {
        "reserved_memory": torch.cuda.memory_reserved(0) / 1e9,
        "allocated_memory": torch.cuda.memory_allocated(0) / 1e9,
    }

    del model
    torch.cuda.empty_cache()
    gc.collect()
    return (
        train_losses,
        val_losses,
        train_times,
        val_aux_losses,
        dropped,
        generated_text,
        mem_data,
        num_params,
        num_active_params,
        model_size,
        gqa,
    )

In [13]:
@torch.no_grad()
def bench_inference(config, device, num_tokens=500, gen_count=5):

    # Initialize model and training setup
    model = get_model(config)

    model.eval()
    model.to(device)

    empty_tokens = torch.zeros((1, 1), dtype=torch.long).to(device)

    gen_times = []

    # model generation params, fixed
    p_nucleus = None
    k = None
    if config["decode_method"] == "nucleus":
        p_nucleus = 0.5
    elif config["decode_method"] == "top-k":
        k = 5

    for i in range(gen_count):
        start = time.perf_counter()
        res = model.generate(
            empty_tokens,
            max_new_tokens=num_tokens,
            method=config["decode_method"],
            p_nucleus=p_nucleus,
            k=k,
        )
        gen_times.append(time.perf_counter() - start)

    mem_data = {
        "reserved_memory": torch.cuda.memory_reserved(0) / 1e9,
        "allocated_memory": torch.cuda.memory_allocated(0) / 1e9,
    }

    del model
    torch.cuda.empty_cache()
    gc.collect()
    return gen_times, mem_data

#### ––––

## Driver code

## TODO: 
        -Training: architecture vs. samples/s, loss vs. samples seen (sample efficiency)
        -Inference: Decoding method vs. tokens/s, architecture vs tokens/s
        
        -Architectures to-vary: switch (1/2/4/8 experts, capacity) vs. vanilla, hidden_dim/n_layer/n_head, total_params, ROPE vs. cos/sin PE, etc.
        -Also, benchmark memory, e.g. for effect of GQA
        -Roadmap: Benchmark specific time/memory spent for MLP vs. switch FF, attention vs. MLP proportions within a block, etc.

## todo: be more selective w/ architecture params lol

In [156]:
from itertools import product

metric_keys = [
    "Train Loss",
    "Val Loss",
    "Train Tokens/s",
    "Bandwidth (GB/s)",
    "Train Memory (GB)",
    "Inference Tokens/s",
    "Inference Memory (GB)",
]
hp_keys = [
    "seq_len",
    "n_embd",
    "n_ff_ratio",
    "n_head",
    "n_kv_head",
    "n_layer",
    "norm_first",
    "use_rms_norm",
    "switch_first",
    "use_rotary_embd",
    "softmax_off_by_one",
    "switch",
    "every_n_switch",
    "capacity_factor",
    "drop_tokens",
    "n_experts",
    "activation",
    "mlp_dropout",
    "expert_dropout",
    "scale",
    "decode_method",
]
param_keys = ["num_params", "num_active_params"]
other_keys = ["gqa", "top5-generated_text"]

seq_len = np.array([128])
n_embd = np.array([128, 256, 384])
n_ff_ratio = np.array([2, 4])
n_head = np.array([4, 8])
n_kv_head = np.array([1, 2, 4, 8])
n_layer = np.array([4, 6])
norm_first = np.array([True])
use_rms_norm = np.array([True, False])
switch_first = np.array([True])
use_rotary_embd = np.array([True, False])
softmax_off_by_one = np.array([True, False])
switch = np.array([True, False])
every_n_switch = np.array([2])
capacity_factor = np.array([1, 1.25])
drop_tokens = np.array([True])
n_experts = np.array([2, 4])
activation = np.array(["GELU"])
mlp_dropout = np.array([0.1])
expert_dropout = np.array(
    [0.2]
)  #### HMMMMMM. want to check this first. do in TODO manual benchmarking ipynb!
scale = np.array([0.5])  # RoPE
decode_method = np.array(["multinomial", "nucleus", "top-k"])

configs = [
    dict(zip(hp_keys, combo))
    for combo in product(
        seq_len,
        n_embd,
        n_ff_ratio,
        n_head,
        n_kv_head,
        n_layer,
        norm_first,
        use_rms_norm,
        switch_first,
        use_rotary_embd,
        softmax_off_by_one,
        switch,
        every_n_switch,
        capacity_factor,
        drop_tokens,
        n_experts,
        activation,
        mlp_dropout,
        expert_dropout,
        scale,
        decode_method,
    )
]

df = pd.DataFrame(columns=metric_keys + param_keys + hp_keys + other_keys)

In [157]:
def filter_configs(configs):
    """
    Filter illogical config combinations.
    """
    res = []
    for config in configs:
        too_many_kv_heads = config["n_kv_head"] > config["n_head"]
        no_switch_ever = config["every_n_switch"] > config["n_layer"]
        switch_first_but_also_no = config["switch_first"] and not config["switch"]
        scale_but_no_rotary = config["scale"] != 1 and not config["use_rotary_embd"]
        conds = [
            too_many_kv_heads,
            no_switch_ever,
            switch_first_but_also_no,
            scale_but_no_rotary,
        ]
        if any(conds):
            continue
        res.append(config)
    return res

In [158]:
configs = filter_configs(configs)
set_seed(SEED)
random.shuffle(configs)
len(configs)

672

In [160]:
df

Unnamed: 0,Train Loss,Train Tokens/s,Bandwidth (GB/s),Train Memory,Inference Tokens/s,Inference Memory,num_params,num_active_params,seq_len,n_embd,...,every_n_switch,capacity_factor,drop_tokens,n_experts,activation,mlp_dropout,expert_dropout,scale,gqa,generated_text


### Train + Inference

In [161]:
def set_config_df(row, config):
    """
    Set config parameters into the specified DataFrame row.
    """
    row = row.copy()
    for k, v in config.items():
        row[k] = v
    return row

In [162]:
for i, config in enumerate(configs):

    if i > 1:
        break

    ## Train
    print(f"Training config {i+1}/{len(configs)}")
    (
        train_losses,
        val_losses,
        train_times,
        val_aux_losses,
        dropped,
        generated_text,
        mem_data,
        num_params,
        num_active_params,
        model_size,
        gqa,
    ) = bench_train(config, device=device)

    ## Inference
    print(f"Performing inference \n")
    NUM_TOKENS = 20  ####
    gen_times, gen_mem_data = bench_inference(
        config, device=device, num_tokens=NUM_TOKENS, gen_count=3
    )

    # Compute stats
    avg_train_loss = np.mean(train_losses)
    avg_val_loss = np.mean(val_losses)
    tokens_per_sec = np.mean(train_times)
    bandwidth = model_size * tokens_per_sec / 1e9  # GB/s

    gen_tokens_per_sec = NUM_TOKENS / np.mean(gen_times)

    ## Update results dataframe
    df.loc[i] = {
        "Train Loss": avg_train_loss,
        "Val Loss": avg_val_loss,
        "Train Tokens/s": tokens_per_sec,
        "Bandwidth (GB/s)": bandwidth,
        "Train Memory (GB)": mem_data,
        "Inference Tokens/s": gen_tokens_per_sec,
        "Inference Memory (GB)": gen_mem_data,
        "num_params": num_params,
        "num_active_params": num_active_params,
        "gqa": gqa,
        "top5-generated_text": generated_text,
    }
    df.loc[i] = set_config_df(df.loc[i], config)
    df.to_csv(f"{path}/benchmark_logs/bench_stats.csv")

Training config 1/672
Performing inference
Training config 2/672
Performing inference


In [163]:
df

Unnamed: 0,Train Loss,Train Tokens/s,Bandwidth (GB/s),Train Memory,Inference Tokens/s,Inference Memory,num_params,num_active_params,seq_len,n_embd,...,every_n_switch,capacity_factor,drop_tokens,n_experts,activation,mlp_dropout,expert_dropout,scale,gqa,generated_text
0,7.455641,6196.759307,125.873062,"{'reserved_memory': 0, 'allocated_memory': 0}",126.404787,"{'reserved_memory': 0, 'allocated_memory': 0}",5078085,3894593,128.0,384.0,...,2.0,1.0,True,2.0,GELU,0.1,0.4,0.5,True,"Rt'IcEShOII spt, bemehithethesthy bCI thy mes ..."
1,6.854943,5442.228753,181.301473,"{'reserved_memory': 0, 'allocated_memory': 0}",118.746279,"{'reserved_memory': 0, 'allocated_memory': 0}",8328265,4779329,128.0,384.0,...,2.0,1.25,True,4.0,GELU,0.1,0.4,0.5,False,AOOOOAMOOLOEOEOOOEOLOEOEOOLOEEEAOEEEEEEEEOEEOo...


In [164]:
df.iloc[1]

Train Loss                                                     6.854943
Train Tokens/s                                              5442.228753
Bandwidth (GB/s)                                             181.301473
Train Memory              {'reserved_memory': 0, 'allocated_memory': 0}
Inference Tokens/s                                           118.746279
Inference Memory          {'reserved_memory': 0, 'allocated_memory': 0}
num_params                                                      8328265
num_active_params                                               4779329
seq_len                                                           128.0
n_embd                                                            384.0
n_ff_ratio                                                          2.0
n_head                                                              4.0
n_kv_head                                                           4.0
n_layer                                                         

In [None]:
print(df.iloc[1]["top5-generated_text"])