# todo: try actually running.

#### tovalidate: whether early stopping of wandb.log actually works. maybe change config to only LR 1 or smth outrageous to try and trigger the condition 
#### remark: changed print and eval iters to functions of batch size since BS changes in the sweep

In [1]:
path = "./"
root = "../"

In [2]:
import wandb
from wandb import Api
import yaml

In [3]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mterru3[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
api = Api()

In [5]:
with open(f"{root}/switch_sweep.yaml", "r") as file:
    sweep_config = yaml.safe_load(file)

In [6]:
sweep_config

{'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'val_loss'},
 'parameters': {'LR': {'values': [0.0005, 0.001, 0.003, 0.005]},
  'batch_size': {'values': [16, 32, 64, 128, 256, 512]},
  'optimizer': {'values': ['adamw', 'sgd']},
  'activation': {'values': ['GELU', 'GEGLU', 'SwiGLU']},
  'n_experts': {'values': [2, 4, 8, 16, 32, 64]},
  'capacity_factor': {'values': [0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]},
  'aux_loss_coef': {'values': [0.005, 0.01, 0.05, 0.1, 0.15]},
  'norm_first': {'values': [True, False]},
  'switch_first': {'values': [True, False]},
  'every_n_switch': {'values': [1, 2, 3, 4]},
  'mlp_dropout': {'values': [0.1, 0.2, 0.3, 0.4]},
  'expert_dropout': {'values': [0.1, 0.2, 0.3, 0.4]},
  'rope_scale': {'values': [0.25, 0.5, 0.75, 1]}},
 'early_terminate': {'type': 'hyperband', 's': 2, 'eta': 3, 'max_iter': 27}}

In [7]:
sweep_id = wandb.sweep(sweep_config, project="switch_moe")

Create sweep with ID: qw2sczpt
Sweep URL: https://wandb.ai/terru3/switch_moe/sweeps/qw2sczpt


# Constants and Setup

In [8]:
path = "./"
root = "../"

SEED = 23

# LR = 1e-3
# BATCH_SIZE = 16
SEQ_LEN = 128
# MAX_ITERS = 50000  # max num batches to train
# PRINT_ITERS = 50  # frequency to print train loss
# EVAL_ITERS = 500  # frequency to evaluate val loss and generate text from model
# EVAL_ITER_COUNT = 100  # number of batches to estimate val loss with
# given a 10% val split, we have 111540 char, so 100 batches * batch size 16 * seq len 128 = roughly 2x num of chars chosen
# EVAL_ITER_COUNT * BATCH_SIZE
# SAVE_ITERS = 1000  # frequency to save model and losses
N_EMBD = 128
N_FF = N_EMBD * 4
N_HEAD = 4
N_KV_HEAD = 2  # GQA
N_LAYER = 4

# automatic mixed precision (will be disabled if CPU, not available)
USE_AMP = True

# RoPE
# ROPE_SCALE = 0.5

In [9]:
## Switch-specific hyperparameters
# CAPACITY_FACTOR = 1.25
# N_EXPERT = 2
# AUX_LOSS_COEF = 0.01

MODEL_NAME = (
    f"switch_wandb_{N_LAYER}_LAYERs_{N_HEAD}_HEAD_{N_EMBD}_EMBD_DIM_{SEQ_LEN}_SEQ_LEN"
)
print("Model Name:", MODEL_NAME)

Model Name: switch_wandb_4_LAYERs_4_HEAD_128_EMBD_DIM_128_SEQ_LEN


# Imports

In [10]:
import itertools
import json
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary
from tqdm import tqdm

In [11]:
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [12]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

root = root if drive is None else "/content/drive/MyDrive/moe-kit"
path = path if drive is None else "/content/drive/MyDrive/moe-kit/switch_transformer"

# cannot train in mixed precision on CPU (GradScaler needs cuda)
USE_AMP = USE_AMP if device.type == "cuda" else False
# Tesla T4 does not support bfloat16, CPU does not support float16
AMP_DTYPE = torch.float16 if device.type == "cuda" else torch.bfloat16

In [13]:
import sys

sys.path.append(root)

from utils import set_seed
from models.transformer import MLP, Transformer

In [14]:
set_seed(SEED)

# Data

In [15]:
with open(f"{root}/data/tiny-shakespeare.txt", "r") as f:
    text = f.read()

chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)
print(f"Vocab: {chars}")
print(f"Vocab size: {VOCAB_SIZE}")

Vocab: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocab size: 65


In [16]:
# Prepare mappings / tokenizer
# create a mapping from characters to integers
txt2idx = {ch: i for i, ch in enumerate(chars)}
idx2txt = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [txt2idx[c] for c in s]
decode = lambda l: "".join([idx2txt[i] for i in l])

print(encode("tiny-shakespeare is sick"))
print(decode(encode("tiny-shakespeare is sick")))

[58, 47, 52, 63, 7, 57, 46, 39, 49, 43, 57, 54, 43, 39, 56, 43, 1, 47, 57, 1, 57, 47, 41, 49]
tiny-shakespeare is sick


In [17]:
# tokenizer data
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # 90-10 split
train_data = data[:n]
val_data = data[n:]
print("train_data len:", len(train_data), "val_data len:", len(val_data))

train_data len: 1003854 val_data len: 111540


# Training

In [18]:
def get_batch(split, batch_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - SEQ_LEN, (batch_size,))
    x = torch.stack([data[i : i + SEQ_LEN] for i in ix])
    y = torch.stack([data[i + 1 : i + SEQ_LEN + 1] for i in ix])
    return x.to(device), y.to(device)

In [19]:
def calc_ce_loss(logits, targets):
    """
    TODO
    Computes cross-entropy loss.
    Inputs:
        -logits: Model output of shape (B, S, vocab_size)
        -counts:
    """
    B, S, C = logits.shape
    logits = logits.view(B * S, C)
    targets = targets.view(B * S)
    loss = F.cross_entropy(logits, targets)
    return loss

In [20]:
def calc_aux_loss(counts, prob_sum, n_expert):
    """
    Computes Switch Transformer auxiliary loss.
    Inputs:
        -counts: Number of tokens passed to each expert in each switch layer (num_switch_layers x n_experts)
        Note this is NOT equivalent to n_layer; num_switch_layers depends on `switch_first` and `every_n_switch`
        -prob_sum: Sum of probs across all tokens for each expert (num_switch_layers x n_experts)
    """

    # total number of tokens routed in that layer
    token_count = counts.sum(dim=-1, keepdims=True)

    # prop of tokens dispatched to each expert
    route_frac = counts / token_count

    # fraction of total probability allocated for each expert
    # recall prob_sum := softmaxed probs, which added to 1 across the experts for each token
    # we divide by num_tokens so that the overall 2D scalar sum of prob_frac is 1
    # intuitively we are forcing the total prob for each layer across the experts to be 1 so we can take proportions,
    # the same way as above
    prob_frac = prob_sum / token_count

    # Auxiliary loss
    # L = N \sum_{i=1}^N f_i • P_i
    aux_loss = n_expert * (route_frac * prob_frac).sum()
    return aux_loss

In [21]:
def build_model(config):
    set_seed(SEED)

    model = Transformer(
        VOCAB_SIZE,
        SEQ_LEN,
        N_EMBD,
        N_HEAD,
        N_FF,
        N_LAYER,
        device=device,
        n_kv_head=N_KV_HEAD,
        norm_first=config.get("norm_first"),
        use_rotary_embd=True,
        softmax_off_by_one=False,
        switch=True,
        switch_first=config.get("switch_first"),
        every_n_switch=config.get("every_n_switch"),
        capacity_factor=config.get("capacity_factor"),
        drop_tokens=True,
        n_experts=config.get("n_experts"),
        expert=MLP,
        use_amp=USE_AMP,
        activation=config.get("activation"),
        mlp_dropout=config.get("mlp_dropout"),
        expert_dropout=config.get("expert_dropout"),
        scale=config.get("rope_scale"),
    )
    return model

In [22]:
def build_optimizer(config, model):
    if config.get("optimizer") == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=config.get("LR"))
    elif config.get("optimizer") == "sgd":
        return torch.optim.SGD(model.parameters(), lr=config.get("LR"))

In [23]:
def train(
    config=None,
    train_loss_list=None,
    val_loss_list=None,
    train_time_list=None,
    val_aux_loss_list=None,
    dropped_list=None,
):

    train_losses = train_loss_list if train_loss_list is not None else []
    val_losses = val_loss_list if val_loss_list is not None else []
    train_times = train_time_list if train_time_list is not None else []
    val_aux_losses = val_aux_loss_list if val_aux_loss_list is not None else []
    dropped = dropped_list if dropped_list is not None else []

    #### NEW:
    # Initialize a new wandb run
    with wandb.init(config=config) as wandb_r:
        # also has `resume` arg,

        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb_r.config
        print(f"Run ID: {wandb_r.id}, Name: {wandb_r.name}")

        # training setup
        model = build_model(config)
        model.train()
        model.to(device)

        optimizer = build_optimizer(config, model)
        scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

        # model size in bytes
        MODEL_SIZE = sum(
            [
                p.numel() * p.dtype.itemsize
                for p in itertools.chain(model.parameters(), model.buffers())
            ]
        )

        # configure print / logging freq to be fn of changing batch size
        # same for eval iter count for estimating val loss in eval fn.
        B = config.get("batch_size")

        MAX_ITERS = int(8e5 / B)
        PRINT_ITERS = MAX_ITERS // 1000
        EVAL_ITERS = MAX_ITERS // 100
        SAVE_ITERS = MAX_ITERS // 500  ########## TEMP. should be 50

        # begin training
        for step in range(MAX_ITERS):

            start = time.perf_counter()

            optimizer.zero_grad(set_to_none=True)
            inputs, targets = get_batch("train", B)
            with torch.autocast(
                device_type=device.type, dtype=AMP_DTYPE, enabled=USE_AMP
            ):
                if model.switch:
                    logits, counts, prob_sum, n_dropped = model(inputs)
                    loss = calc_ce_loss(logits, targets)
                    aux_loss = calc_aux_loss(counts, prob_sum, config.get("n_experts"))
                    loss += config.get("aux_loss_coef") * aux_loss
                    drop_frac = (np.array(n_dropped) / (B * SEQ_LEN)).tolist()
                    dropped.append(drop_frac)  # for logging purposes
                else:
                    logits = model(inputs)
                    loss = calc_ce_loss(logits, targets)

            train_losses.append(loss.item())  # for printing

            scaler.scale(loss).backward()
            # loss.backward()

            # Monitor gradient norm
            scaler.unscale_(optimizer)

            with torch.autocast(
                device_type=device.type, dtype=AMP_DTYPE, enabled=USE_AMP
            ):
                grads = [
                    param.grad.detach().flatten()
                    for param in model.parameters()
                    if param.grad is not None
                ]
                norm = torch.cat(grads).norm()

            train_time = time.perf_counter() - start
            tokens_per_sec = (1 / train_time) * B * SEQ_LEN
            train_times.append(tokens_per_sec)

            scaler.step(optimizer)
            scaler.update()
            # optimizer.step()

            # print training statistics
            if step % PRINT_ITERS == 0 and step != 0:
                print(
                    f"Step {step}/{MAX_ITERS} | Running Avg Train Loss: {np.mean(train_losses):.3f} |",
                    f"Grad Norm: {norm:.2f} | Running Avg Tokens/Sec: {np.mean(train_times):.2f} |",
                    f"Bandwidth: {MODEL_SIZE * np.mean(train_times) / 1e9:.2f} GB/s",
                )

                # wandb
                wandb.log(
                    {
                        "train_loss": loss.item(),
                        "samples_seen": (step + 1) * B,
                    }
                )

            # estimate val loss, generate text and save
            if step % EVAL_ITERS == 0 and step != 0:
                val_losses, val_aux_losses = estimate_loss(
                    config, model, val_losses, val_aux_losses, device
                )

                # for wandb
                wandb.log(
                    {
                        "val_loss": val_losses[-1],
                        "samples_seen": (step + 1) * B,
                    }
                )

            # no idea if this works or not. also didn't add api run id into name to avoid overlap, should I??
            if step % SAVE_ITERS == 0 and step != 0:
                model_artifact = wandb.Artifact(
                    MODEL_NAME, type="model", metadata=dict(config)
                )

                torch.save(
                    model.state_dict(), f"{path}/wandb_artifacts/{MODEL_NAME}.pt"
                )
                model_artifact.add_file(f"{path}/wandb_artifacts/{MODEL_NAME}.pt")
                wandb_r.log_artifact(
                    model_artifact
                )  # log artifact version e.g. "MODEL_NAME:v0"
                # caution: calls to log_artifact are performed asynchronously for performant uploads

In [24]:
@torch.no_grad()
def estimate_loss(config, model, val_losses, val_aux_losses, device):
    model.eval()
    B = config.get("batch_size")
    EVAL_ITER_COUNT = int(1600 / B)
    losses = torch.zeros(EVAL_ITER_COUNT)
    aux_losses = torch.zeros(EVAL_ITER_COUNT)
    for k in range(EVAL_ITER_COUNT):
        inputs, targets = get_batch("test", B)
        with torch.autocast(device_type=device.type, dtype=AMP_DTYPE, enabled=USE_AMP):
            if model.switch:
                logits, counts, prob_sum, n_dropped = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
                aux_losses[k] = calc_aux_loss(counts, prob_sum, config.get("n_experts"))
                losses[k] += config.get("aux_loss_coef") * aux_losses[k]
            else:
                logits = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
    val_loss, val_aux_loss = losses.mean().item(), aux_losses.mean().item()
    val_losses.append(val_loss)
    val_aux_losses.append(val_aux_loss)  # track separate aux loss for logging
    # keep model in eval, next call is to .generate() anyway
    print(f"Est. Val Loss: {val_loss:.3f} | Est. Aux Loss: {val_aux_loss:.3f}")
    return val_losses, val_aux_losses

In [24]:
## Driver code
wandb.agent(sweep_id, train, count=5)

[34m[1mwandb[0m: Agent Starting Run: sngup0hu with config:
[34m[1mwandb[0m: 	LR: 0.003
[34m[1mwandb[0m: 	activation: GELU
[34m[1mwandb[0m: 	aux_loss_coef: 0.01
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	capacity_factor: 1.25
[34m[1mwandb[0m: 	every_n_switch: 3
[34m[1mwandb[0m: 	expert_dropout: 0.4
[34m[1mwandb[0m: 	mlp_dropout: 0.2
[34m[1mwandb[0m: 	n_experts: 2
[34m[1mwandb[0m: 	norm_first: True
[34m[1mwandb[0m: 	optimizer: adamw
[34m[1mwandb[0m: 	rope_scale: 0.5
[34m[1mwandb[0m: 	switch_first: False


Step 25/25000 | Running Avg Train Loss: 5.039 | Grad Norm: 1.53 | Running Avg Tokens/Sec: 7449.05 | Bandwidth: 28.00 GB/s
Step 50/25000 | Running Avg Train Loss: 4.036 | Grad Norm: 0.75 | Running Avg Tokens/Sec: 7716.43 | Bandwidth: 29.01 GB/s
Step 75/25000 | Running Avg Train Loss: 3.607 | Grad Norm: 0.77 | Running Avg Tokens/Sec: 7498.41 | Bandwidth: 28.19 GB/s
Step 100/25000 | Running Avg Train Loss: 3.356 | Grad Norm: 0.70 | Running Avg Tokens/Sec: 7545.13 | Bandwidth: 28.36 GB/s
Step 125/25000 | Running Avg Train Loss: 3.189 | Grad Norm: 0.78 | Running Avg Tokens/Sec: 7635.85 | Bandwidth: 28.70 GB/s
Step 150/25000 | Running Avg Train Loss: 3.068 | Grad Norm: 0.61 | Running Avg Tokens/Sec: 7757.37 | Bandwidth: 29.16 GB/s
Step 175/25000 | Running Avg Train Loss: 2.976 | Grad Norm: 0.69 | Running Avg Tokens/Sec: 7830.98 | Bandwidth: 29.44 GB/s
Step 200/25000 | Running Avg Train Loss: 2.899 | Grad Norm: 0.68 | Running Avg Tokens/Sec: 7863.53 | Bandwidth: 29.56 GB/s
Step 225/25000 | Ru

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [50]:
######### if you want to load saved model artifact for further wandb runs, etc., do this
# model_artifact = api.artifact(f"terru3/switch_moe/{MODEL_NAME}:latest")
# model_dir = model_artifact.download()

In [64]:
# to load in wandb checkpoints (note only model saved, not optimizer)

# create model, Transformer(_____)
model.load_state_dict(torch.load(f"{path}/wandb_artifacts/{MODEL_NAME}.pt"))

<All keys matched successfully>

In [None]:
## TODO: try accessing best run
## most ppl do sweep = api.sweep but I'm not using API for sweep, so how to access??? plus
## even when you use API that's inside train() no?

# sweep.best_run()
# sweep.best_run().config

# Run this when done

In [25]:
wandb.finish()

0,1
loss,█▄▂▁
samples_seen,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
train_loss,█▆▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁

0,1
loss,1.8559
samples_seen,32832.0
train_loss,1.78517


Step 1050/25000 | Running Avg Train Loss: 2.143 |

Traceback (most recent call last):


# –––––––––––––––––––––––––––––––––––––––––––––––-

# Generation

In [None]:
print(
    decode(
        model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[
            0
        ].tolist()
    )
)

In [115]:
input_txt = "TERRY: thou art"
ctx = encode(input_txt)
print(
    decode(
        model.generate(torch.tensor(ctx).unsqueeze(0).long(), max_new_tokens=500)[
            0
        ].tolist()
    )
)

TERRY: thou arte a my she
Which have may and of contain.

DUKE VINCENTIO:
Good as I tall no knrow, for shalt agarnt
And mpo; and Kong a m, not outhpile Mesce.

HENRY VI:
When I will thy lookess, oner the pexstrey
The the the hee voagh gresed livioe.

MENCIO:
My her callis his peaced of to that
We where's by shall bore: as shall myselvea
The plender feuls!

PAPELLANT:
In the the into balby me dods to love,
In but the giving of nyou ase. I tall it-me?'e Goveuling
The theer haught art praver count madeng Camen:
T
