# Benchmarking

This notebook contains the official time and memory benchmarks comparing Switch and vanilla decoder-only transformers across various architectural and training hyperparameters.

In [1]:
import gc
import itertools
import json
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.profiler import profile, record_function, ProfilerActivity
from torchinfo import summary
from tqdm import tqdm

# Models and data setup

In [None]:
# Locations of this notebook's directory and the repo root (re-pointed at the
# Google Drive copy in the next cell when running on Colab).
path = "./"
root = "../"

SEED = 23

# ---- optimization / training loop ----
LR = 1e-3
BATCH_SIZE = 16
SEQ_LEN = 128
MAX_ITERS = 50000  # upper bound on the number of training batches
PRINT_ITERS = 50  # log the train loss every PRINT_ITERS batches
EVAL_ITERS = 500  # estimate val loss and sample text every EVAL_ITERS batches
SAVE_ITERS = 1000  # checkpoint model and losses every SAVE_ITERS batches
# Batches used to estimate the val loss. With a 10% val split (~111,540
# characters), 100 batches * 16 sequences * 128 chars = 204,800 sampled
# characters, i.e. roughly 1.8x (~2x) the size of the val set.
EVAL_ITER_COUNT = 100

# ---- model shape ----
N_EMBD = 128
N_FF = 4 * N_EMBD  # MLP hidden width
N_HEAD = 4
N_KV_HEAD = 2  # fewer KV heads than query heads -> grouped-query attention (GQA)
N_LAYER = 4

# Automatic mixed precision; forced off in the next cell when not on CUDA.
USE_AMP = True

# Rotary position embedding (RoPE) scale, passed as `scale=` to Transformer.
ROPE_SCALE = 0.5

# ---- Switch-specific hyperparameters ----
CAPACITY_FACTOR = 1.25
N_EXPERT = 2
AUX_LOSS_COEF = 0.01

In [None]:
# To run from Google Drive on Colab, uncomment the two lines below; they rebind
# `drive` to the mounted module, which switches the paths further down.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# On Drive, point the repo root and this notebook's directory at the mounted
# copy; locally, keep the values from the previous cell.
if drive is not None:
    root = "/content/drive/MyDrive/moe-kit"
    path = "/content/drive/MyDrive/moe-kit/switch_transformer"

if device.type == "cuda":
    # Tesla T4 does not support bfloat16, so autocast to float16 on GPU.
    AMP_DTYPE = torch.float16
else:
    # No mixed-precision training on CPU (GradScaler needs cuda), and CPU
    # does not support float16, so bfloat16 is the dtype kept for reference.
    USE_AMP = False
    AMP_DTYPE = torch.bfloat16

In [None]:
import sys

# Make the repo root importable so the project-local `utils` and `models`
# packages below resolve (root is "../" locally, or the Drive path on Colab).
sys.path.append(root)

from utils import set_seed
from models.transformer import MLP, Transformer

In [None]:
# Seed before data/model setup via the project helper `utils.set_seed`
# (presumably seeds the torch/Python RNGs — confirm in utils).
set_seed(SEED)

In [None]:
# Load the raw Tiny Shakespeare corpus. The explicit UTF-8 encoding keeps the
# read independent of the platform's default locale encoding.
with open(f"{root}/data/tiny-shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
VOCAB_SIZE = len(chars)

# Prepare mappings / tokenizer:
# map each character to an integer id and back.
txt2idx = {ch: i for i, ch in enumerate(chars)}
idx2txt = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Convert a string into a list of integer token ids (one per character)."""
    return [txt2idx[c] for c in s]


def decode(l):
    """Convert an iterable of integer token ids back into a string."""
    return "".join([idx2txt[i] for i in l])


# Tokenize the whole corpus once, then split it 90/10 into train/val.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]


def get_batch(split, batch_size=None, seq_len=None):
    """
    Samples a random batch of (input, target) windows for next-character prediction.
    Inputs:
        -split: "train" samples from train_data; any other value samples from val_data
        -batch_size: number of windows (defaults to the global BATCH_SIZE)
        -seq_len: length of each window (defaults to the global SEQ_LEN)
    Returns:
        (x, y), each of shape (batch_size, seq_len) and moved to `device`, where
        y[b, t] is the token that follows x[b, t] in the source data.
    """
    # Globals are read at call time, so later edits to BATCH_SIZE/SEQ_LEN apply.
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    seq_len = SEQ_LEN if seq_len is None else seq_len
    data = train_data if split == "train" else val_data
    # Random window starts; the bound leaves room for the +1 target shift.
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    # Gather all windows at once via a (batch_size, seq_len) index matrix.
    offsets = ix.unsqueeze(1) + torch.arange(seq_len)
    x = data[offsets]
    y = data[offsets + 1]
    return x.to(device), y.to(device)

### Initialize models

In [None]:
# Switch (mixture-of-experts) variant. Judging by `switch_first=True` and
# `every_n_switch=2`, presumably every second block, starting with the first,
# routes tokens to N_EXPERT MLP experts — confirm in models.transformer.
set_seed(SEED)  # reseed so the weight initialization is reproducible
switch_model = Transformer(
    VOCAB_SIZE,
    SEQ_LEN,
    N_EMBD,
    N_HEAD,
    N_FF,
    N_LAYER,
    device=device,
    # attention / normalization
    n_kv_head=N_KV_HEAD,
    norm_first=True,
    use_rotary_embd=True,
    softmax_off_by_one=False,
    # Switch routing
    switch=True,
    switch_first=True,
    every_n_switch=2,
    capacity_factor=CAPACITY_FACTOR,
    drop_tokens=True,
    n_experts=N_EXPERT,
    expert=MLP,
    # mixed precision
    use_amp=USE_AMP,
    amp_dtype=AMP_DTYPE,
    # MLP / expert details
    activation="GELU",
    mlp_dropout=0.1,
    expert_dropout=0.4,
    # RoPE
    scale=ROPE_SCALE,
)

optimizer = torch.optim.AdamW(switch_model.parameters(), lr=LR)

# Gradient scaling for float16 mixed precision, enabled only when USE_AMP is True.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

In [None]:
# torchinfo summary (layers and parameter counts) of the Switch model.
summary(switch_model)

In [None]:
# model size in bytes
SWITCH_MODEL_SIZE = sum(
    [
        p.numel() * p.dtype.itemsize
        for p in itertools.chain(switch_model.parameters(), switch_model.buffers())
    ]
)
SWITCH_MODEL_SIZE

In [None]:
# Dense baseline: same backbone settings as the Switch model but with
# switch=False and no expert-specific options.
set_seed(SEED)  # reseed so the weight initialization is reproducible
vanilla_model = Transformer(
    VOCAB_SIZE,
    SEQ_LEN,
    N_EMBD,
    N_HEAD,
    N_FF,
    N_LAYER,
    device=device,
    # attention / normalization
    n_kv_head=N_KV_HEAD,
    norm_first=True,
    use_rotary_embd=True,
    softmax_off_by_one=False,
    # mixed precision
    use_amp=USE_AMP,
    amp_dtype=AMP_DTYPE,
    # MLP
    activation="GELU",
    switch=False,
    mlp_dropout=0.1,
    # RoPE
    scale=ROPE_SCALE,
)

# NOTE(review): this rebinds the module-level `optimizer` and `scaler`, so the
# ones created for `switch_model` in the previous init cell are no longer
# reachable. Give them distinct names before training both models.
optimizer = torch.optim.AdamW(vanilla_model.parameters(), lr=LR)

# Gradient scaling for float16 mixed precision, enabled only when USE_AMP is True.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

In [None]:
# torchinfo summary (layers and parameter counts) of the dense baseline
# built in the cell above.
summary(vanilla_model)

In [None]:
# model size in bytes
VANILLA_MODEL_SIZE = sum(
    [
        p.numel() * p.dtype.itemsize
        for p in itertools.chain(vanilla_model.parameters(), vanilla_model.buffers())
    ]
)
VANILLA_MODEL_SIZE

In [None]:
def calc_ce_loss(logits, targets):
    """
    Computes the token-level cross-entropy loss.
    Inputs:
        -logits: Model output of shape (B, S, vocab_size)
        -targets: Target token ids of shape (B, S)
    Returns:
        Scalar tensor: cross-entropy averaged over all B * S positions
        (F.cross_entropy's default 'mean' reduction).
    """
    B, S, C = logits.shape
    # Flatten the batch and sequence dims. Unlike .view, .reshape also accepts
    # non-contiguous inputs (e.g. transposed or sliced tensors).
    logits = logits.reshape(B * S, C)
    targets = targets.reshape(B * S)
    return F.cross_entropy(logits, targets)

In [None]:
def calc_aux_loss(counts, prob_sum, n_experts=None):
    """
    Computes Switch Transformer auxiliary (load-balancing) loss.
    Inputs:
        -counts: Number of tokens passed to each expert in each switch layer (num_switch_layers x n_experts)
        Note this is NOT equivalent to n_layer; num_switch_layers depends on `switch_first` and `every_n_switch`
        -prob_sum: Sum of probs across all tokens for each expert (num_switch_layers x n_experts)
        -n_experts: Number of experts N used to scale the loss. Defaults to
        counts.shape[-1], so the loss stays correct for models built with an
        expert count other than the global N_EXPERT.
    Returns:
        Scalar tensor: N * sum over switch layers and experts of f_i * P_i.
    """
    if n_experts is None:
        n_experts = counts.shape[-1]

    # total number of tokens routed in each layer
    token_count = counts.sum(dim=-1, keepdim=True)

    # f_i: fraction of tokens dispatched to each expert
    route_frac = counts / token_count

    # P_i: fraction of total router probability allocated to each expert.
    # prob_sum holds softmaxed probs that add to 1 across experts for each
    # token, so dividing by the token count makes each layer's row sum to 1,
    # putting it on the same footing as route_frac above.
    prob_frac = prob_sum / token_count

    # Auxiliary loss: L = N \sum_{i=1}^N f_i * P_i, summed over switch layers
    aux_loss = n_experts * (route_frac * prob_frac).sum()
    return aux_loss

## TODO: write `bench_train` and `bench_inference` functions
- To test:
    - Training: architecture vs. samples/s; loss vs. samples seen (sample efficiency)
    - Inference: num_tokens vs. time; decoding method vs. tokens/s; architecture vs. tokens/s
- Architecture: Switch (1/2/4/8 experts, capacity factor) vs. vanilla; hidden_dim / n_layer / n_head; total params; RoPE vs. sin/cos positional encoding; etc.
    - Also benchmark memory, e.g. the effect of GQA
    - Roadmap: benchmark the time/memory spent in the MLP vs. the Switch FF layer, and the attention vs. MLP proportions within a block, etc.

In [None]:
## TODO: write params above into config files, write get_model wrappers, etc., to use in these functions
## then you can easily delete model and config after every run


def bench_train():
    """
    TODO: benchmark training throughput (samples/s) and sample efficiency
    (loss vs. samples seen).

    Deliberately NOT decorated with @torch.no_grad(): a training benchmark
    runs backward passes, which require autograd to be enabled.
    """
    pass


@torch.no_grad()
def bench_inference():
    """
    TODO: benchmark inference (num_tokens vs. time, decoding method and
    architecture vs. tokens/s).

    Runs under torch.no_grad() because inference needs no gradients.
    """
    pass


# Sketch of the timing/memory measurement for the bench functions.
# Notes: use `time.perf_counter` (the `time` module is imported at the top);
# it returns SECONDS. CUDA kernels launch asynchronously, so synchronize
# before reading the clock, or only launch overhead gets measured. Run
# gc.collect() before empty_cache() so freed tensors' blocks can be released.
#
# if device.type == "cuda":
#     torch.cuda.synchronize()
# start = time.perf_counter()
# ## run the workload
# if device.type == "cuda":
#     torch.cuda.synchronize()
# total_time = time.perf_counter() - start  # in seconds
# mem_data = {'reserved_memory': torch.cuda.memory_reserved(0),
#             'allocated_memory': torch.cuda.memory_allocated(0)}
# del model
# del config
# gc.collect()
# torch.cuda.empty_cache()
# return total_time, mem_data

In [None]:
## try profiler for fun

# with profile(activities=[
#         ProfilerActivity.CPU, ProfilerActivity.CUDA],
#              profile_memory=True,
#              record_shapes=True) as prof:
#     with record_function("model_inference"):
#         model(inputs)

# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))

In [None]:
### driver code:

### create configs for what you want to try, e.g. a yaml file or just native dicts/lists here
## create result dataframe with configs
## loop over all and run training and inference fns, append results
## plot results from dataframe