# Constants and Setup

In [1]:
# Directory this notebook writes outputs/, checkpoints/ and train_logs/ under.
path = "./"
# Repository root: holds data/, utils and models/ (appended to sys.path below).
root = "../"

SEED = 23

LR = 1e-3
BATCH_SIZE = 16
SEQ_LEN = 128  # context length (in characters) of each training window
MAX_ITERS = 50000  # max num batches to train
PRINT_ITERS = 50  # frequency to print train loss
EVAL_ITERS = 500  # frequency to evaluate val loss and generate text from model
EVAL_ITER_COUNT = 100  # number of batches to estimate val loss with
# With a 10% val split the val set has 111,540 chars; each estimate samples
# EVAL_ITER_COUNT * BATCH_SIZE * SEQ_LEN = 100 * 16 * 128 = 204,800 tokens,
# i.e. roughly 2x the size of the val set.
SAVE_ITERS = 1000  # frequency to save model and losses
N_EMBD = 128
N_FF = N_EMBD * 4  # passed to Transformer as N_FF (presumably the MLP hidden width)
N_HEAD = 4
N_LAYER = 4

# automatic mixed precision (forced off further below when CUDA is not available)
USE_AMP = True

In [2]:
## Switch-specific hyperparameters
CAPACITY_FACTOR = 1.25  # passed as capacity_factor (presumably caps tokens per expert; see drop_tokens=True)
N_EXPERT = 2  # experts per switch layer
AUX_LOSS_COEF = 0.01  # weight of the auxiliary load-balancing loss added to the CE loss

# Used in every checkpoint / log / generated-sample filename.
# NOTE(review): the name does not encode N_EXPERT or CAPACITY_FACTOR, so runs that
# differ only in those settings (with the same SEED) overwrite each other's files.
MODEL_NAME = (
    f"switch_{N_LAYER}_LAYERs_{N_HEAD}_HEAD_{N_EMBD}_EMBD_DIM_{SEQ_LEN}_SEQ_LEN"
)
print("Model Name:", MODEL_NAME)

Model Name: switch_4_LAYERs_4_HEAD_128_EMBD_DIM_128_SEQ_LEN


# Imports

In [3]:
import json
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary
from tqdm import tqdm

In [4]:
# Colab toggle: uncomment the lines below to mount Google Drive; when `drive` is not
# None, the next cell redirects `root` / `path` to the Drive project folder.
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# On Colab with Google Drive mounted, read data and write outputs under the Drive project folder.
if drive is not None:
    root = "/content/drive/MyDrive/moe-kit"
    path = "/content/drive/MyDrive/moe-kit/switch_transformer"

# cannot train in mixed precision on CPU (GradScaler needs cuda)
if device.type != "cuda":
    USE_AMP = False

In [6]:
import sys

sys.path.append(root)

from utils import set_seed
from models.transformer import MLP, Transformer

In [7]:
# Seed via the project helper before any data sampling / model initialization.
set_seed(SEED)

# Data

In [8]:
# Read the whole corpus as one string. The encoding is explicit so decoding does not
# depend on the platform's default locale encoding (e.g. cp1252 on Windows).
with open(f"{root}/data/tiny-shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: every distinct character, sorted so the ids are stable.
chars = sorted(set(text))
VOCAB_SIZE = len(chars)
print(f"Vocab: {chars}")
print(f"Vocab size: {VOCAB_SIZE}")

Vocab: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocab size: 65


In [9]:
# Prepare mappings / tokenizer
# character <-> integer-id lookup tables built from the sorted vocabulary `chars`
txt2idx = {ch: i for i, ch in enumerate(chars)}
idx2txt = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode string `s` as a list of character ids (KeyError on characters not in `chars`)."""
    return [txt2idx[c] for c in s]


def decode(l):
    """Decode an iterable of character ids back into a string (inverse of `encode`)."""
    return "".join([idx2txt[i] for i in l])


print(encode("tiny-shakespeare is sick"))
print(decode(encode("tiny-shakespeare is sick")))

[58, 47, 52, 63, 7, 57, 46, 39, 49, 43, 57, 54, 43, 39, 56, 43, 1, 47, 57, 1, 57, 47, 41, 49]
tiny-shakespeare is sick


In [10]:
# tokenizer data: encode the full corpus once, then hold out the final 10% for validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # 90-10 split
train_data, val_data = data[:n], data[n:]
print("train_data len:", len(train_data), "val_data len:", len(val_data))


def get_batch(split):
    """Sample BATCH_SIZE random windows and their next-token targets.

    `split == "train"` draws from `train_data`; any other value draws from `val_data`.
    Returns `(inputs, targets)`, two long tensors of shape (BATCH_SIZE, SEQ_LEN) on
    `device`; `targets` is `inputs` shifted one position to the right.
    """
    source = train_data if split == "train" else val_data
    # the largest start index is len - SEQ_LEN - 1, so the shifted target window still fits
    starts = torch.randint(len(source) - SEQ_LEN, (BATCH_SIZE,))
    inputs = torch.stack([source[s : s + SEQ_LEN] for s in starts])
    targets = torch.stack([source[s + 1 : s + 1 + SEQ_LEN] for s in starts])
    return inputs.to(device), targets.to(device)

train_data len: 1003854 val_data len: 111540


# Training setup

### TODO: tweak hyperparameters to keep effective parameter count or FLOP count constant?

In [21]:
set_seed(SEED)
# Switch Transformer under test. Per the summaries below, switch_first=True with
# every_n_switch=2 gives blocks 0 and 2 a SwitchFeedForward (N_EXPERT MLP experts,
# plus what looks like a small router) and blocks 1 and 3 a dense MLP.
model = Transformer(
    VOCAB_SIZE,
    SEQ_LEN,
    N_EMBD,
    N_HEAD,
    N_FF,
    N_LAYER,
    device=device,
    norm_first=True,
    switch=True,
    switch_first=True,
    every_n_switch=2,
    capacity_factor=CAPACITY_FACTOR,
    drop_tokens=True,
    n_experts=N_EXPERT,
    expert=MLP,
    use_amp=USE_AMP,
    mlp_dropout=0.1,
    expert_dropout=0.4,
)

# Gradient scaling for mixed precision (a pass-through when USE_AMP is False).
# NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent PyTorch releases in
# favour of torch.amp.GradScaler("cuda", ...); migrate once the minimum torch version allows.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

In [25]:
# Parameter breakdown. Each SwitchFeedForward holds N_EXPERT copies of the dense MLP
# (2 x 131,712) plus a 258-parameter remainder (presumably the router).
summary(model)

Layer (type:depth-idx)                             Param #
Transformer                                        --
├─Embedding: 1-1                                   8,320
├─PositionalEncoding: 1-2                          --
├─Sequential: 1-3                                  --
│    └─Block: 2-1                                  --
│    │    └─MultiHeadAttention: 3-1                65,536
│    │    └─SwitchFeedForward: 3-2                 263,682
│    │    └─LayerNorm: 3-3                         256
│    │    └─LayerNorm: 3-4                         256
│    │    └─Dropout: 3-5                           --
│    │    └─Dropout: 3-6                           --
│    └─Block: 2-2                                  --
│    │    └─MultiHeadAttention: 3-7                65,536
│    │    └─MLP: 3-8                               131,712
│    │    └─LayerNorm: 3-9                         256
│    │    └─LayerNorm: 3-10                        256
│    │    └─Dropout: 3-11                          -

## Computing activated parameter count

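Here, we sum the parameters that exist only because of the switch layers: anything whose name contains `switch` (presumably the router) and every expert beyond the first. We subtract them from the total count, since each token is routed through a single expert. (Strictly, the router runs on every token, so excluding it slightly undercounts; it is only a few hundred parameters per switch layer.)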

In [26]:
# Parameters that exist only because of the switch layers: every expert beyond the
# first ("experts.1", "experts.2", ...) plus anything whose name contains "switch".
switch_layer_names = ["switch"] + [f"experts.{i}" for i in range(1, N_EXPERT)]
switch_additional_params = sum(
    param.numel()
    for param_name, param in model.named_parameters()
    if any(marker in param_name for marker in switch_layer_names)
)

total_param_count = sum(p.numel() for p in model.parameters())
active_param_count = total_param_count - switch_additional_params
active_param_count

807745

## FLOP counter methods

In [None]:
# define wrapper that only returns model tensor output w/o other logging metrics
class SwitchWrapper(nn.Module):
    """Wrap a model and return only the first element of its output.

    The switch Transformer's forward returns the logits together with routing
    statistics. The FLOP / summary tools used below are handed this wrapper so they
    see only the logits tensor.
    """

    def __init__(self, model):
        super().__init__()
        # registered as a submodule when `model` is an nn.Module
        self.model = model

    def forward(self, input):
        outputs = self.model(input)
        logits = outputs[0]
        return logits


model_wrap = SwitchWrapper(model)

# define vanilla transformer for comparisons: same dimensions, a dense MLP in every
# block (switch=False); re-seeded so its initialization is reproducible
set_seed(SEED)
vanilla_model = Transformer(
    VOCAB_SIZE,
    SEQ_LEN,
    N_EMBD,
    N_HEAD,
    N_FF,
    N_LAYER,
    device=device,
    norm_first=True,
    use_amp=USE_AMP,
    switch=False,
    mlp_dropout=0.1,
)

### 1: fvcore

### Switch Transformer

In [97]:
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str

# fvcore traces one forward pass. It counts one fused multiply-add as one "flop"
# (so totals are effectively MACs). It skips ops it has no handle for, which are the
# "Unsupported operator" warnings below (add, softmax, gelu, ...).
inp = get_batch("train")[0]
switch_flops = FlopCountAnalysis(model_wrap, inp)
switch_flops.total(), switch_flops.by_operator()

Unsupported operator aten::lift_fresh encountered 1 time(s)
Unsupported operator aten::embedding encountered 1 time(s)
Unsupported operator aten::add encountered 9 time(s)
Unsupported operator aten::div encountered 4 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::empty_like encountered 2 time(s)
Unsupported operator aten::uniform_ encountered 2 time(s)
Unsupported operator aten::mul_ encountered 4 time(s)
Unsupported operator aten::mul encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
model.blocks.1.expert_drop, model.blocks.3.expert_drop


(1719795712,
 Counter({'linear': 1440874496, 'matmul': 268435456, 'layer_norm': 10485760}))

In [98]:
# per-module breakdown of parameters and (fvcore) flops for the switch model
print(flop_count_table(switch_flops))

| module                   | #parameters or shape   | #flops    |
|:-------------------------|:-----------------------|:----------|
| model                    | 1.072M                 | 1.72G     |
|  token_embedding         |  8.32K                 |  0        |
|   token_embedding.weight |   (65, 128)            |           |
|  blocks                  |  1.055M                |  1.703G   |
|   blocks.0               |   0.33M                |   0.384G  |
|    blocks.0.sa           |    65.536K             |    0.201G |
|    blocks.0.ff           |    0.264M              |    0.18G  |
|    blocks.0.ln1          |    0.256K              |    1.311M |
|    blocks.0.ln2          |    0.256K              |    1.311M |
|   blocks.1               |   0.198M               |   0.472G  |
|    blocks.1.sa           |    65.536K             |    0.201G |
|    blocks.1.ff.net       |    0.132M              |    0.268G |
|    blocks.1.ln1          |    0.256K              |    1.311M |
|    block

### Vanilla Transformer

In [91]:
from fvcore.nn import FlopCountAnalysis

# Same analysis for the dense baseline. Its forward already returns a single tensor,
# so no wrapper is needed. Note this samples a fresh batch and rebinds `inp`.
inp = get_batch("train")[0]
vanilla_flops = FlopCountAnalysis(vanilla_model, inp)
vanilla_flops.total(), vanilla_flops.by_operator()

Unsupported operator aten::lift_fresh encountered 1 time(s)
Unsupported operator aten::embedding encountered 1 time(s)
Unsupported operator aten::add encountered 9 time(s)
Unsupported operator aten::div encountered 4 time(s)
Unsupported operator aten::softmax encountered 4 time(s)
Unsupported operator aten::gelu encountered 4 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
blocks.0.expert_drop, blocks.1.expert_drop, blocks.2.expert_drop, blocks.3.expert_drop


(1906573312,
 Counter({'linear': 1627652096, 'matmul': 268435456, 'layer_norm': 10485760}))

In [92]:
# per-module breakdown of parameters and (fvcore) flops for the dense baseline
print(flop_count_table(vanilla_flops))

| module                   | #parameters or shape   | #flops    |
|:-------------------------|:-----------------------|:----------|
| model                    | 0.808M                 | 1.907G    |
|  token_embedding         |  8.32K                 |  0        |
|   token_embedding.weight |   (65, 128)            |           |
|  blocks                  |  0.791M                |  1.89G    |
|   blocks.0               |   0.198M               |   0.472G  |
|    blocks.0.sa           |    65.536K             |    0.201G |
|    blocks.0.ff.net       |    0.132M              |    0.268G |
|    blocks.0.ln1          |    0.256K              |    1.311M |
|    blocks.0.ln2          |    0.256K              |    1.311M |
|   blocks.1               |   0.198M               |   0.472G  |
|    blocks.1.sa           |    65.536K             |    0.201G |
|    blocks.1.ff.net       |    0.132M              |    0.268G |
|    blocks.1.ln1          |    0.256K              |    1.311M |
|    block

We see that **the switch transformer has ~264k more parameters** (1.072M vs. 0.808M) and yet requires **fewer FLOPs** in its forward pass (1.72G vs. 1.91G).

## 2: TorchTNT

TorchTNT's module summary reports the same totals as fvcore: 1.7 G (switch) vs. 1.9 G (vanilla) forward FLOPs.

In [169]:
# TNT is a library for PyTorch training tools and utilities.
from torchtnt.utils.module_summary import get_module_summary

# formatting fn: https://forums.fast.ai/t/lesson-18-official-topic/102750/15
import pandas as pd


def markdown_to_pandas(table_string):
    """Parse a pipe-delimited markdown table into a DataFrame of stripped strings.

    The first line is the header and the second the ``|---|`` separator. Cells are
    the pieces after the first ``|``, so a trailing ``|`` yields an extra empty column.
    The final row is dropped (presumably a footer line in torchtnt's summary output;
    TODO confirm).
    """

    def split_cells(line):
        # discard the text before the first "|" and strip whitespace from each cell
        return [cell.strip() for cell in line.split("|")[1:]]

    lines = table_string.strip().split("\n")
    columns = split_cells(lines[0])
    body = [split_cells(line) for line in lines[2:]]
    return pd.DataFrame(body[:-1], columns=columns)

### Switch Transformer

In [174]:
# torchtnt module summary (params, forward/backward FLOPs, sizes, timings) for the switch model
module_summary = get_module_summary(model_wrap, [inp])
summary_df = markdown_to_pandas(f"{module_summary}")[1:]  # first is our SwitchWrapper
summary_df
# note: values are human-readable strings (e.g. "1.1 M", "381 M"), so they cannot be
# sorted / filtered numerically as is

Unnamed: 0,Type,# Parameters,# Trainable Parameters,Size (bytes),Contains Uninitialized Parameters?,Forward FLOPs,Backward FLOPs,In size,Out size,Forward Elapsed Times (ms)
1,Transformer,1.1 M,1.1 M,4.4 M,No,1.7 G,3.4 G,"[16, 128]","[[16, 128, 65], [2, 2], [2, 2], ['?', '?']]",0.0001923215
2,Embedding,8.3 K,8.3 K,33.3 K,No,0,0,"[16, 128]","[16, 128, 128]",0.0000006968
3,PositionalEncoding,0,0,65.5 K,No,0,0,[128],"[128, 128]",0.0000002529
4,Sequential,1.1 M,1.1 M,4.2 M,No,0,0,?,?,?
5,Block,329 K,329 K,1.3 M,No,381 M,763 M,"[[16, 128, 128], [16, 1, 128, 128]]","[[16, 128, 128], [2], [2], '?']",0.0000640843
...,...,...,...,...,...,...,...,...,...,...
88,LayerNorm,256,256,1.0 K,No,0,0,"[16, 128, 128]","[16, 128, 128]",0.0000004883
89,Dropout,0,0,0,No,0,0,"[16, 128, 128]","[16, 128, 128]",0.0000002825
90,Dropout,0,0,0,No,0,0,?,?,?
91,Linear,8.4 K,8.4 K,33.5 K,No,17.0 M,34.1 M,"[16, 128, 128]","[16, 128, 65]",0.0000012268


### Vanilla model

In [161]:
# Same summary for the dense baseline. It is passed directly (no wrapper), so no row is skipped.
module_summary = get_module_summary(vanilla_model, [inp])
summary_df = markdown_to_pandas(f"{module_summary}")
summary_df

Unnamed: 0,Type,# Parameters,# Trainable Parameters,Size (bytes),Contains Uninitialized Parameters?,Forward FLOPs,Backward FLOPs,In size,Out size,Forward Elapsed Times (ms)
0,Transformer,807 K,807 K,3.3 M,No,1.9 G,3.8 G,"[16, 128]","[16, 128, 65]",0.0002140681
1,Embedding,8.3 K,8.3 K,33.3 K,No,0,0,"[16, 128]","[16, 128, 128]",0.0000005123
2,PositionalEncoding,0,0,65.5 K,No,0,0,[128],"[128, 128]",0.0000006297
3,Sequential,791 K,791 K,3.2 M,No,0,0,?,?,?
4,Block,197 K,197 K,791 K,No,469 M,939 M,"[[16, 128, 128], [16, 1, 128, 128]]","[16, 128, 128]",0.0000659738
...,...,...,...,...,...,...,...,...,...,...
69,LayerNorm,256,256,1.0 K,No,0,0,"[16, 128, 128]","[16, 128, 128]",0.0000005330
70,Dropout,0,0,0,No,0,0,"[16, 128, 128]","[16, 128, 128]",0.0000007307
71,Dropout,0,0,0,No,0,0,?,?,?
72,Linear,8.4 K,8.4 K,33.5 K,No,17.0 M,34.1 M,"[16, 128, 128]","[16, 128, 65]",0.0000011544


## 3: PyTorch Dev FLOP counter (credits: Horace He)
https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
https://gist.github.com/soumith/5f81c3d40d41bb9d08041431c656b233 

#### This more recent code can likely more accurately calculate backward FLOPs. It also allows us to insert our actual forward and backward pass code and loss calculations.

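#### Moreover, the previous calculations effectively report MACs (multiply-accumulates, i.e. a*b + c) rather than FLOPs. fvcore documents that it counts one fused multiply-add as one flop, and FLOPs ≈ 2 × MACs. Consistent with this, torchtnt's forward + backward total for the switch model (1.7 G + 3.4 G ≈ 5.1 G) is about half of the 10.52 GFLOPs reported below.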

In [12]:
from flop_counter import FlopCounterMode

In [114]:
# Count FLOPs of one training step (forward + loss + backward) for the switch model.
# NOTE: uses calc_ce_loss / calc_aux_loss from the Training section below; run those
# cells first.
# optimizer.step() is deliberately omitted. The counter only reports mm/bmm/addmm, which
# AdamW's update does not appear to contribute. Calling it here would also silently
# update the weights (and Adam state) before training starts.
flop_counter = FlopCounterMode(model)
with flop_counter:
    optimizer.zero_grad(set_to_none=True)
    inputs, targets = get_batch("train")
    logits, counts, prob_sum, n_dropped = model(inputs)
    loss = calc_ce_loss(logits, targets)
    aux_loss = calc_aux_loss(counts, prob_sum)
    loss = loss + AUX_LOSS_COEF * aux_loss
    loss.backward()
# discard the measurement's gradients so they cannot leak into training
optimizer.zero_grad(set_to_none=True)

Total: 10.52 GFlops
Module:  Global
aten.mm: 7.0 GFLOPS
aten.bmm: 1.62 GFLOPS
aten.addmm: 1.9 GFLOPS

Module:  lm_head
aten.addmm: 0.04 GFLOPS
aten.mm: 0.06 GFLOPS



In [118]:
# Same measurement for the dense baseline.
# NOTE: uses calc_ce_loss from the Training section below; run that cell first.
# Gradients are cleared on vanilla_model itself: `optimizer` only holds the switch
# model's parameters, so optimizer.zero_grad() / step() never touch vanilla_model.
flop_counter = FlopCounterMode(vanilla_model)
with flop_counter:
    vanilla_model.zero_grad(set_to_none=True)
    inputs, targets = get_batch("train")
    logits = vanilla_model(inputs)
    loss = calc_ce_loss(logits, targets)
    loss.backward()
vanilla_model.zero_grad(set_to_none=True)

Total: 11.38 GFlops
Module:  Global
aten.mm: 7.58 GFLOPS
aten.bmm: 1.62 GFLOPS
aten.addmm: 2.18 GFLOPS

Module:  lm_head
aten.addmm: 0.04 GFLOPS
aten.mm: 0.06 GFLOPS



# Training

In [14]:
def calc_ce_loss(logits, targets):
    """
    Computes mean token-level cross-entropy loss.
    Inputs:
        -logits: Model output of shape (B, S, vocab_size)
        -targets: Target token ids of shape (B, S)
    Returns:
        Scalar tensor: cross-entropy averaged over all B * S positions.
    """
    B, S, C = logits.shape
    # reshape (not view) so non-contiguous inputs, e.g. sliced or transposed, also work
    loss = F.cross_entropy(logits.reshape(B * S, C), targets.reshape(B * S))
    return loss

In [15]:
def calc_aux_loss(counts, prob_sum, n_experts=None):
    """
    Computes Switch Transformer auxiliary (load-balancing) loss, summed over switch layers.
    Inputs:
        -counts: Number of tokens passed to each expert in each switch layer (num_switch_layers x n_experts)
        Note this is NOT equivalent to n_layer; num_switch_layers depends on `switch_first` and `every_n_switch`
        -prob_sum: Sum of probs across all tokens for each expert (num_switch_layers x n_experts)
        -n_experts: Number of experts N used as the loss scale factor. Defaults to
         counts.shape[-1], the number of experts actually present in `counts`.
    Returns:
        Scalar tensor N * sum_{layers, i} f_i * P_i. Each layer contributes exactly 1 under
        perfectly balanced routing (f_i = P_i = 1/N).
    """
    if n_experts is None:
        n_experts = counts.shape[-1]

    # total number of tokens routed in each layer, kept as a column for broadcasting
    token_count = counts.sum(dim=-1, keepdim=True)

    # f_i: proportion of tokens dispatched to each expert
    route_frac = counts / token_count

    # P_i: fraction of total probability allocated to each expert.
    # prob_sum holds softmaxed router probs (summing to 1 across experts per token),
    # so dividing by the token count makes each layer's row sum to 1, the same way
    # route_frac does.
    prob_frac = prob_sum / token_count

    # Auxiliary loss
    # L = N \sum_{i=1}^N f_i • P_i   (also summed over switch layers)
    aux_loss = n_experts * (route_frac * prob_frac).sum()
    return aux_loss

In [16]:
def train(
    model,
    optimizer,
    scaler,
    device,
    train_loss_list=None,
    val_loss_list=None,
    train_time_list=None,
    val_aux_loss_list=None,
    dropped_list=None,
):
    """Train `model` for MAX_ITERS steps with periodic logging, evaluation and checkpointing.

    Each step does the following:
    - samples a batch via ``get_batch("train")``;
    - runs the forward pass under ``torch.autocast`` (float16 when USE_AMP);
    - adds ``AUX_LOSS_COEF * aux_loss`` for switch models;
    - updates the weights through the GradScaler ``scaler``.

    Args:
        model: Transformer. When ``model.switch`` is true, ``model(inputs)`` must return
            ``(logits, counts, prob_sum, n_dropped)``; otherwise it returns ``logits``.
        optimizer: optimizer over ``model.parameters()``.
        scaler: ``GradScaler``. When built with ``enabled=False`` it passes straight
            through to ``optimizer``.
        device: ``torch.device`` the model is moved to.
        train_loss_list, val_loss_list, train_time_list, val_aux_loss_list, dropped_list:
            optional lists to continue logging into (e.g. when resuming). They are
            appended to in place; fresh lists are used when omitted.

    Side effects:
        - Every PRINT_ITERS steps: prints running averages.
        - Every EVAL_ITERS steps: estimates the validation loss and appends generated
          samples to ``{path}/outputs``.
        - Every SAVE_ITERS steps: writes a checkpoint to ``{path}/checkpoints`` and JSON
          logs to ``{path}/train_logs``.
    """
    import os

    train_losses = train_loss_list if train_loss_list is not None else []
    val_losses = val_loss_list if val_loss_list is not None else []
    train_times = train_time_list if train_time_list is not None else []
    val_aux_losses = val_aux_loss_list if val_aux_loss_list is not None else []
    dropped = dropped_list if dropped_list is not None else []

    model.train()
    model.to(device)

    # Create the output directories up front, so the first eval/save step cannot
    # crash (possibly thousands of steps into a run) on a missing folder.
    for subdir in ("outputs", "checkpoints", "train_logs"):
        os.makedirs(f"{path}/{subdir}", exist_ok=True)

    # Set up prompt generation
    generation_file_path = f"{path}/outputs/OUTPUT_{MODEL_NAME}_SEED_{SEED}.txt"
    empty_tokens = torch.zeros((1, 1), dtype=torch.long).to(device)
    cond_prompts = ["KING TERRY: Thou art", "DANIEL: Ay, my dear,"]
    cond_token_list = [encode(prompt) for prompt in cond_prompts]

    for step in range(MAX_ITERS):
        start = time.perf_counter()

        optimizer.zero_grad(set_to_none=True)
        inputs, targets = get_batch("train")
        # float16 rather than bfloat16: Tesla T4 does not support bfloat16 at this time
        with torch.autocast(
            device_type=device.type, dtype=torch.float16, enabled=USE_AMP
        ):
            if model.switch:
                logits, counts, prob_sum, n_dropped = model(inputs)
                loss = calc_ce_loss(logits, targets)
                aux_loss = calc_aux_loss(counts, prob_sum)
                # out-of-place add: never mutate a tensor autograd may still need
                loss = loss + AUX_LOSS_COEF * aux_loss
                # fraction of this batch's tokens dropped, one entry per element of n_dropped
                drop_frac = (np.array(n_dropped) / (BATCH_SIZE * SEQ_LEN)).tolist()
                dropped.append(drop_frac)  # for logging purposes
            else:
                logits = model(inputs)
                loss = calc_ce_loss(logits, targets)

        train_losses.append(loss.item())  # for printing

        scaler.scale(loss).backward()
        # Unscale in place so the norm is in real gradient units. scaler.step() knows
        # the grads are already unscaled and will not unscale them a second time.
        scaler.unscale_(optimizer)
        norm = _grad_norm(model)

        scaler.step(optimizer)
        scaler.update()

        # Stop the clock only after the optimizer update (part of the step) and after
        # queued CUDA kernels finish, so tokens/sec reflects the full training step.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        train_time = time.perf_counter() - start
        train_times.append(BATCH_SIZE * SEQ_LEN / train_time)

        # print training statistics
        if step % PRINT_ITERS == 0 and step != 0:
            print(
                f"Step {step}/{MAX_ITERS} | Running Avg Train Loss: {np.mean(train_losses):.5f} |",
                f"Grad Norm: {norm:.3f} | Running Avg Tokens/Sec: {np.mean(train_times):.3f} |",
            )

        # estimate val loss, generate text and save
        if step % EVAL_ITERS == 0 and step != 0:
            val_losses, val_aux_losses = estimate_loss(
                model, val_losses, val_aux_losses, device
            )
            generate(model, generation_file_path, empty_tokens, cond_token_list, step)
            model.train()

        # save model, val losses (not train_losses), train times
        if step % SAVE_ITERS == 0 and step != 0:
            _save_progress(
                model, optimizer, step, val_losses, val_aux_losses, train_times, dropped
            )


def _grad_norm(model):
    """Global L2 norm over all parameter gradients (0 when no parameter has a grad)."""
    # norm of per-parameter norms == norm of all gradients concatenated, without the copy
    norms = [
        param.grad.detach().norm()
        for param in model.parameters()
        if param.grad is not None
    ]
    if not norms:
        return torch.tensor(0.0)
    return torch.stack(norms).norm()


def _save_progress(model, optimizer, step, val_losses, val_aux_losses, train_times, dropped):
    """Write the checkpoint for `step` and the JSON training logs under `path`."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        f"{path}/checkpoints/{MODEL_NAME}_STEP_{step}_SEED_{SEED}.pt",
    )

    log_prefix = f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}"
    logs = {
        "val_losses": val_losses,
        "val_aux_losses": val_aux_losses,
        # Per-step lists are subsampled to match the frequency of val_losses.
        # Note: if you load from a checkpoint to continue training you will have a
        # sparser train_times list when computing the running average.
        "train_times": train_times[EVAL_ITERS::EVAL_ITERS],
        "dropped": dropped[EVAL_ITERS::EVAL_ITERS],
    }
    for suffix, values in logs.items():
        with open(f"{log_prefix}_{suffix}.json", "w") as f:
            json.dump(values, f)

In [17]:
@torch.no_grad()
def estimate_loss(model, val_losses, val_aux_losses, device):
    """Estimate the validation loss by averaging over EVAL_ITER_COUNT random val batches.

    For switch models (``model.switch`` true, forward returning
    ``(logits, counts, prob_sum, n_dropped)``), the reported loss includes
    ``AUX_LOSS_COEF * aux_loss`` and the mean aux loss is tracked separately. For dense
    models (forward returning ``logits``), the aux loss is logged as 0.

    Args:
        model: the Transformer being trained.
        val_losses: list the mean validation loss is appended to.
        val_aux_losses: list the mean auxiliary loss is appended to.
        device: ``torch.device``; only its ``.type`` is used, to configure autocast.

    Returns:
        ``(val_losses, val_aux_losses)``: the same list objects, after appending.

    Leaves the model in eval mode; the caller switches back with ``model.train()``.
    """
    model.eval()
    losses = torch.zeros(EVAL_ITER_COUNT)
    aux_losses = torch.zeros(EVAL_ITER_COUNT)
    for k in range(EVAL_ITER_COUNT):
        # "val", not "test": there is no test split here, and get_batch draws from
        # val_data for any split other than "train"
        inputs, targets = get_batch("val")
        with torch.autocast(
            device_type=device.type, dtype=torch.float16, enabled=USE_AMP
        ):
            if model.switch:
                logits, counts, prob_sum, _ = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
                aux_losses[k] = calc_aux_loss(counts, prob_sum)
                # report the same objective that is optimized during training
                losses[k] += AUX_LOSS_COEF * aux_losses[k]
            else:
                logits = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
    val_loss, val_aux_loss = losses.mean().item(), aux_losses.mean().item()
    val_losses.append(val_loss)
    val_aux_losses.append(val_aux_loss)  # track separate aux loss for logging
    # keep model in eval, next call is to .generate() anyway
    print(f"Est. Val Loss: {val_loss:.5f} | Est. Aux Loss: {val_aux_loss:.5f}")
    return val_losses, val_aux_losses

In [18]:
def generate(model, generation_file_path, empty_tokens, cond_token_list, step):
    """Sample text from `model`, append it to `generation_file_path`, and print it.

    Produces:
    - an unconditional top-k (k=5) sample from `empty_tokens`;
    - an unconditional nucleus (p=0.5) sample from `empty_tokens`;
    - a top-k continuation of each prompt in `cond_token_list` (lists of token ids).
    All samples use ``max_new_tokens=500``.

    Sampling is reseeded with ``set_seed(42)`` so samples are comparable across steps.
    The caller's RNG state (torch CPU/CUDA, NumPy, Python ``random``) is saved before
    sampling and restored afterwards. Without this, and assuming ``set_seed`` reseeds
    torch, every call made during training would reset the global RNG. The batches and
    dropout masks after each evaluation would then likely replay the same random sequence.
    """
    import random

    py_rng_state = random.getstate()
    np_rng_state = np.random.get_state()
    # fork_rng restores torch's CPU RNG, and that of the listed CUDA devices, on exit
    cuda_devices = [torch.cuda.current_device()] if device.type == "cuda" else []

    # sample with dropout disabled and without building autograd graphs
    model.eval()
    try:
        with torch.random.fork_rng(devices=cuda_devices), torch.no_grad():
            set_seed(42)

            uncond_res1 = decode(
                model.generate(empty_tokens, method="top-k", k=5, max_new_tokens=500)[
                    0
                ].tolist()
            )
            uncond_res2 = decode(
                model.generate(
                    empty_tokens, method="nucleus", p_nucleus=0.5, max_new_tokens=500
                )[0].tolist()
            )

            cond_samples = []
            for prompt in cond_token_list:
                prompt_tokens = torch.tensor(
                    prompt, dtype=torch.long, device=device
                ).unsqueeze(0)
                cond_samples.append(
                    decode(
                        model.generate(
                            prompt_tokens, method="top-k", k=5, max_new_tokens=500
                        )[0].tolist()
                    )
                )
    finally:
        random.setstate(py_rng_state)
        np.random.set_state(np_rng_state)

    cond_res_list = "\n\n".join(cond_samples)

    generation_text = f"""{MODEL_NAME} Output, Step {step}
    UNCONDITIONAL GENERATION:

    Top-k (5) (500 max_tokens):
    {uncond_res1}

    Nucleus (0.5) (500 max_tokens):
    {uncond_res2}

    #####################################################
    CONDITIONAL GENERATION (Top-k (5), 500 max_tokens):
    {cond_res_list}
    -----------------------------------------------------
    """
    with open(generation_file_path, "a", encoding="utf-8") as file:
        file.write(generation_text)
    print(generation_text)

In [60]:
## Driver code
# Trains `model` in place for MAX_ITERS steps (interrupt the kernel to stop early).
# Checkpoints and logs are written under `path` every SAVE_ITERS steps.
train(model, optimizer, scaler, device)

Step 50/50000 | Running Avg Train Loss: 4.22206 | Grad Norm: 1.482 | Running Avg Tokens/Sec: 6810.113 | Running Avg Route Frac: [[0.528 0.472]
 [0.420 0.580]]
Step 100/50000 | Running Avg Train Loss: 3.78290 | Grad Norm: 1.492 | Running Avg Tokens/Sec: 6853.738 | Running Avg Route Frac: [[0.515 0.485]
 [0.454 0.546]]


KeyboardInterrupt: 

In [810]:
## Driver code
# Continues training the same model. train() takes (model, optimizer, scaler, device);
# the former 3-argument call train(model, optimizer, device) raises a TypeError.
train(model, optimizer, scaler, device)

Step 50/50000 | Running Avg Train Loss: 4.26320 | Grad Norm: 1.009 | Running Avg Tokens/Sec: 6688.876
Step 100/50000 | Running Avg Train Loss: 3.80411 | Grad Norm: 1.250 | Running Avg Tokens/Sec: 6746.328
Step 150/50000 | Running Avg Train Loss: 3.56538 | Grad Norm: 1.857 | Running Avg Tokens/Sec: 6763.554
Step 200/50000 | Running Avg Train Loss: 3.39562 | Grad Norm: 1.342 | Running Avg Tokens/Sec: 6709.624
Step 250/50000 | Running Avg Train Loss: 3.27404 | Grad Norm: 1.251 | Running Avg Tokens/Sec: 6643.860
Est. Val Loss: 2.62356 | Est. Aux Loss: 2.01001
switch_4_LAYERs_4_HEAD_128_EMBD_DIM_128_SEQ_LEN Output, Step 250
    UNCONDITIONAL GENERATION:

    Top-k (5) (500 max_tokens):
    

Are  coor oooto thelanttsst sond bateses m man m win m bes tounderthe an withilouneselle t thirer toulir seng t terllore bour athes w b wore ssessearate alllllese serol lallel soulland ssss wind t ararathilan s tor thind angor sens tene sthan anerouss arl s astele toung thit wer therer seras we wes than

KeyboardInterrupt: 

## We see similar training speed, but we have ~264,000 more parameters to work with.

### It remains to be seen whether our model is more sample-efficient. Our loss now includes an auxiliary loss, which inflates the numbers compared with a vanilla transformer. For a fair comparison we subtract the aux loss × its coefficient (0.01). Unfortunately, the loss still seems slightly higher than the vanilla transformer's, and training slightly slower. More investigation is needed.

#### Aside: on a T4 GPU we see approx. 100k tokens/sec, as opposed to the ~6.7k tokens/sec in the runs above.

# Generation

In [None]:
# Unconditional sample. train() leaves the model in train mode, so switch to eval to
# turn dropout (including expert_dropout) off. Create the start token on the model's
# device, as train() does with .to(device).
model.eval()
print(
    decode(
        model.generate(
            torch.zeros((1, 1), dtype=torch.long, device=device), max_new_tokens=500
        )[0].tolist()
    )
)

In [115]:
input_txt = "TERRY: thou art"
ctx = encode(input_txt)
# Conditional sample. Eval mode turns dropout (including expert_dropout) off, and the
# prompt tensor is created on the model's device, as train() does with .to(device).
model.eval()
print(
    decode(
        model.generate(
            torch.tensor(ctx, dtype=torch.long, device=device).unsqueeze(0),
            max_new_tokens=500,
        )[0].tolist()
    )
)

TERRY: thou arte a my she
Which have may and of contain.

DUKE VINCENTIO:
Good as I tall no knrow, for shalt agarnt
And mpo; and Kong a m, not outhpile Mesce.

HENRY VI:
When I will thy lookess, oner the pexstrey
The the the hee voagh gresed livioe.

MENCIO:
My her callis his peaced of to that
We where's by shall bore: as shall myselvea
The plender feuls!

PAPELLANT:
In the the into balby me dods to love,
In but the giving of nyou ase. I tall it-me?'e Goveuling
The theer haught art praver count madeng Camen:
T
