# Constants and Setup

In [1]:
path = "./"
root = "../"

SEED = 23

LR = 1e-3
BATCH_SIZE = 16
SEQ_LEN = 128
MAX_ITERS = 50000  # max num batches to train
PRINT_ITERS = 50  # frequency to print train loss
EVAL_ITERS = 500  # frequency to evaluate val loss and generate text from model
EVAL_ITER_COUNT = 100  # number of batches to estimate val loss with
# given a 10% val split, we have 111540 char, so 100 batches * batch size 16 * seq len 128 = roughly 2x num of chars chosen
# EVAL_ITER_COUNT * BATCH_SIZE
SAVE_ITERS = 1000  # frequency to save model and losses
N_EMBD = 128
N_FF = N_EMBD * 4
N_HEAD = 4
N_LAYER = 4

# automatic mixed precision (will be disabled if CPU, not available)
USE_AMP = True

In [2]:
## SMoE-specific hyperparameters TODO
# CAPACITY_FACTOR = 1.25
# N_EXPERT = 2
# AUX_LOSS_COEF = 0.01

MODEL_NAME = f"smoe_{N_LAYER}_LAYERs_{N_HEAD}_HEAD_{N_EMBD}_EMBD_DIM_{SEQ_LEN}_SEQ_LEN"
print("Model Name:", MODEL_NAME)

Model Name: switch_4_LAYERs_4_HEAD_128_EMBD_DIM_128_SEQ_LEN


# Imports

In [3]:
import json
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary
from tqdm import tqdm

In [4]:
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

root = root if drive is None else "/content/drive/MyDrive/moe-kit"
path = path if drive is None else "/content/drive/MyDrive/moe-kit/switch_transformer"

# cannot train in mixed precision on CPU (GradScaler needs cuda)
USE_AMP = USE_AMP if device.type == "cuda" else False

In [6]:
import sys

sys.path.append(root)

from utils import set_seed

In [7]:
set_seed(SEED)

# Model (TODO——implement SMoE)

# Action items:

## implement noisy top-k gating (gate weight Wg, softplus of x•W_noise + ε~N(0, 1)), set non-topk to -∞, take softmax
## compute coefficient of varation^2 for importance loss 
## router class to dispatch inputs, collect expert outputs, combine and return
## maybe add ST-MoE's router z loss for fun (logsumexp)

In [None]:
###### below is switch code unmodified yet

In [None]:
## Imports
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange

## Model
class MLP(nn.Module):
    def __init__(self, n_embd, n_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_ff),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(n_ff, n_embd),
        )

    def forward(self, x):
        return self.net(x)


class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = (
            n_embd // n_head
        )  # Dimension of each head's key, query, and value

        self.drop = nn.Dropout(p=dropout)

        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd, bias=False)

    def split_heads(self, x):
        B, S, D = x.size()
        # split dimension into n_head * head_dim, then transpose the sequence length w/ n_head
        # output: [B, n_head, S, head_dim]
        return x.view(B, S, self.n_head, self.head_dim).transpose(1, 2)

    def combine_heads(self, x):
        B, _, S, head_dim = x.size()  # _ is n_head which we will merge
        # output: [B, S, n_embd]
        return x.transpose(1, 2).contiguous().view(B, S, self.n_embd)

    def scaled_dot_product(self, q, k, v, dropout, mask=None):
        # q,k,v are [B, n_head, S, head_dim]
        # wei = [B, n_head, S, S]
        wei = q @ k.transpose(-2, -1) / np.sqrt(self.head_dim)
        # mask is [B, 1, S, S]
        if mask is not None:
            wei = wei.masked_fill(mask, float("-inf"))
        wei = dropout(F.softmax(wei, dim=-1))
        out = wei @ v
        return out

    def forward(self, x, mask=None):
        # x: (B, S, n_embd)
        # Step 1 and 2: Project full query, key, value, then split via reshaping
        q = self.split_heads(self.query(x))
        k = self.split_heads(self.key(x))
        v = self.split_heads(self.value(x))

        # Step 3: Compute scaled dot-product attention with causal mask
        attn = self.scaled_dot_product(q, k, v, self.drop, mask)

        # Step 4 and 5: Concatenate attention scores, return projected output matrix
        out = self.out(self.combine_heads(attn))  # (B, S, n_embd)
        return out


class SwitchFeedForward(nn.Module):
    """
    Switch FeedForward Layer.
    TODO
    Inputs:
        -
    Returns: Tuple of length 4
        -Layer output
        -Token count per expert (for auxiliary loss)
        -Sum of token probs per expert (for auxiliary loss)
        -Token count dropped (for logging)
    """

    def __init__(
        self,
        d_model,
        n_ff,
        use_amp,
        capacity_factor,
        drop_tokens: bool,
        n_experts: int,
        expert: MLP,
        noise=0.1,
        dropout=0.1,
    ):
        super().__init__()

        self.use_amp = use_amp
        self.capacity_factor = capacity_factor
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens
        self.noise = noise

        self.experts = nn.ModuleList(
            [copy.deepcopy(expert(d_model, n_ff, dropout)) for _ in range(n_experts)]
        )

        # Routing layer
        self.switch = nn.Linear(d_model, n_experts)

    def forward(self, x):

        x = x.float()  # cast to float32 for stability
        B, S, n_embd = x.shape

        # apply multiplicative jitter
        if self.noise > 0:
            x *= torch.empty_like(x).uniform_(1.0 - self.noise, 1.0 + self.noise)

        x = rearrange(x, "b s d -> (b s) d")
        probs = F.softmax(self.switch(x), dim=-1)  # (b*s) x n_experts

        # convert to half precision
        if self.use_amp:
            probs = probs.half()
            x = x.half()

        max_prob, route_idx = torch.max(probs, dim=-1)

        # compute expert capacity
        # (num tokens * CF) / n_experts
        capacity = int(x.shape[0] * self.capacity_factor / self.n_experts)

        # obtain token idx for each expert
        # list of len (n_expert) of tensors indicating token idx going to that expert
        token_indices = [
            torch.eq(route_idx, i).nonzero() for i in range(self.n_experts)
        ]

        # num tokens of each expert
        # new_tensor ensures same dtype and device
        expert_token_counts = x.new_tensor(
            [len(token_indices[i]) for i in range(self.n_experts)]
        )

        # check capacity and drop tokens
        dropped = []
        if self.drop_tokens:
            for i in range(self.n_experts):
                if expert_token_counts[i] > capacity:
                    # no shuffle——drop earlier tokens
                    dropped.append(token_indices[i][capacity:])
                    token_indices[i] = token_indices[i][:capacity]

        # feed tokens to relevant experts
        out = torch.zeros_like(x)
        expert_out = [
            self.experts[i](x[token_indices[i], :]) for i in range(self.n_experts)
        ]

        for i in range(self.n_experts):
            out[token_indices[i], :] = expert_out[i]
        if dropped:
            # concat dropped tokens, skip experts
            dropped = torch.cat(dropped)
            out[dropped, :] = x[dropped, :]

        # scale values by gating probabilities
        # unsqueeze max_prob for broadcasting
        out * rearrange(max_prob, "num_tokens -> num_tokens ()")

        # separate batch_size and seq_len
        # do not use SEQ_LEN or BATCH_SIZE. if inference, may have batch_size=1 and/or smaller seq len, for example
        out = rearrange(out, "(b s) d -> b s d", s=S)

        return out, expert_token_counts, probs.sum(0), len(dropped)


class Block(nn.Module):
    def __init__(
        self,
        n_embd,
        n_head,
        n_ff,
        norm_first,
        use_amp,
        switch,
        capacity_factor,
        drop_tokens,
        n_experts,
        expert,
        noise=0.1,
        mlp_dropout=0.1,
        expert_dropout=0.4,
    ):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, mlp_dropout)
        if switch:
            self.ff = SwitchFeedForward(
                n_embd,
                n_ff,
                use_amp,
                capacity_factor,
                drop_tokens,
                n_experts,
                expert=MLP,
                noise=noise,
                dropout=mlp_dropout,
            )  # no change to dropout here
        else:
            self.ff = MLP(n_embd, n_ff, dropout=mlp_dropout)

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.norm_first = norm_first
        self.mlp_drop = nn.Dropout(p=mlp_dropout)
        self.expert_drop = nn.Dropout(p=expert_dropout)
        self.switch = switch

    def forward(self, x, mask):
        # residual connection (stream)

        # pre layer norm
        if self.norm_first:
            x = x + self.mlp_drop(self.sa(self.ln1(x), mask))
            if self.switch:
                out, expert_token_counts, prob_sum, n_dropped = self.ff(self.ln2(x))
                x = x + self.expert_drop(out)  # expert dropout
                return x, expert_token_counts, prob_sum, n_dropped
            else:
                x = x + self.mlp_drop(self.ff(self.ln2(x)))
        else:
            x = self.ln1(x + self.mlp_drop(self.sa(x, mask)))
            if self.switch:
                out, expert_token_counts, prob_sum, n_dropped = self.ff(x)
                x = self.ln1(x + self.expert_drop(out))  # expert dropout
                return x, expert_token_counts, prob_sum, n_dropped
            else:
                x = self.ln2(x + self.mlp_drop(self.ff(x)))

        return x


class PositionalEncoding(nn.Module):
    """
    Formula taken from the original Transformer paper:
    PE(pos, 2i (even)) = sin(pos/(10000^{2i/d_model}))
    PE(pos, 2i+1 (odd)) = cos(pos/(10000^{2i/d_model}))

    See reference for more details:
    https://kikaben.com/transformers-positional-encoding/
    """

    def __init__(self, d_model, max_len):
        # just set d_model = n_embd and max_len = seq_len
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        divisor = torch.exp(
            torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
        )  # [d_model / 2, half for each of sin and cos]
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * divisor)
        pe[:, 1::2] = torch.cos(position * divisor)
        self.register_buffer(
            "pe", pe
        )  # result: self.pe = [max_len, d_model], mapping each token index to a vector of length d_model as desired

    def forward(self, x):
        # x = torch.arange(seq_length) has shape [seq_length], so x.size(0) extracts it, then we index self.pe for the first seq_length mappings
        # note we do not add the positional embeddings to x itself yet, we simply return them
        # output = (seq_length, d_model=n_embd)
        return self.pe[: x.size(0)]


### TODO


class SMoE(nn.Module):
    """
    TODO
    switch (bool): Indicates whether to insert Switch MoE layers
    """

    def __init__(
        self,
        vocab_size,
        seq_length,
        n_embd,
        n_head,
        n_ff,
        n_layer,
        device,
        norm_first=True,
        use_amp=False,
        switch=False,
        capacity_factor=None,
        drop_tokens=None,
        n_experts=None,
        expert=None,
        noise=0.1,
        mlp_dropout=0.1,
        expert_dropout=0.4,
    ):
        super().__init__()

        if switch:
            assert (
                isinstance(capacity_factor, (int, float))
                and isinstance(drop_tokens, bool)
                and isinstance(n_experts, int)
                and expert is not None
            ), "For a switch transformer, you must provide a numeric `capacity_factor`, boolean `drop_tokens`, \
                    integer `n_experts` and a MLP class `expert` to serve as the experts."

        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = PositionalEncoding(n_embd, seq_length)

        ### Alternate blocks with switch = True/False
        switch_args = np.full((n_layer,), False)
        if switch:
            switch_args[::2], switch_args[1::2] = True, False
        self.blocks = nn.Sequential(
            *[
                Block(
                    n_embd,
                    n_head,
                    n_ff,
                    norm_first,
                    use_amp,
                    switch_args[i],
                    capacity_factor,
                    drop_tokens,
                    n_experts,
                    expert,
                    noise,
                    mlp_dropout,
                    expert_dropout,
                )
                for i in range(n_layer)
            ]
        )

        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.drop = nn.Dropout(mlp_dropout)
        self.switch = switch
        self.seq_length = seq_length
        self.device = device
        self.use_amp = use_amp
        self.init_params()

    # weight initialization (Xavier uniform)
    def init_params(self, default_initialization=False):
        if not default_initialization:
            for name, p in self.named_parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)

    # Remark: Xavier normal is not supported at this time.

    def get_causal_mask(self, x):
        """
        Generates causal mask for decoding
        """
        B, S = x.shape  # x = (batch_size x seq_len)
        attn_shape = (B, 1, S, S)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype(
            "uint8"
        )  # k = 1 shifts the diagonal, so that the main diagonal gets 0's
        return (torch.from_numpy(subsequent_mask) == 0).to(self.device)
        # True along main diagonal + below, False elsewhere

    def forward(self, x):

        x = x.to(torch.int64)
        B, S = x.shape

        # get mask
        mask = self.get_causal_mask(x).to(self.device)
        # mask = (B x 1 x S x S)

        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(S))
        x = self.drop(tok_emb + pos_emb)
        # (B, S, n_embd)

        expert_token_counts, prob_sum, n_dropped = [], [], []
        for block in self.blocks:
            if block.switch:
                x, counts_i, prob_sum_i, n_dropped_i = block(x, ~mask)
                expert_token_counts.append(counts_i)
                prob_sum.append(prob_sum_i)
                n_dropped.append(n_dropped_i)
            else:
                x = block(x, ~mask)  # (B, S, n_embd)
        # negate mask to fill originally False values with -inf later
        logits = self.lm_head(x)  # (B, S, vocab_size)

        if self.switch:
            return (
                logits,
                torch.stack(expert_token_counts),
                torch.stack(prob_sum),
                n_dropped,
            )
        return logits

    def generate(
        self,
        input_ids,
        method="multinomial",
        max_new_tokens=1000,
        temp=None,
        num_beams=None,
        p_nucleus=None,
        k=None,
    ):

        # input_ids begins as (B, S)
        self.eval()

        for _ in range(max_new_tokens):
            if method in ["multinomial", "temperature", "greedy", "nucleus", "top-k"]:
                # i) Truncate to the most recent `max length` tokens
                text_cond = input_ids[:, -self.seq_length :]
                # ii) Retrieve predictions
                with torch.no_grad():
                    with torch.autocast(
                        device_type=self.device.type,
                        dtype=torch.bfloat16,
                        enabled=self.use_amp,
                    ):
                        if self.switch:
                            logits, _, _, _ = self(text_cond)
                        else:
                            logits = self(text_cond)
                # model output: (B, S, vocab_size)
                # iii) Find last token logits of each
                logits = logits[:, -1, :]  # (B, vocab_size)

                # if temperature sampling, divide logits by temp before applying softmax
                if method == "temperature":
                    logits = logits / temp

                # iv) Take softmax along each
                probs = F.softmax(logits, dim=-1)

                # v) Sample next token depending on method
                if method == "greedy":
                    next_idx = probs.argmax(dim=-1).unsqueeze(-1)

                elif method in ["multinomial", "temperature", "nucleus", "top-k"]:
                    if method == "nucleus":
                        assert (
                            p_nucleus is not None
                            and (0 < p_nucleus)
                            and (p_nucleus <= 1)
                        )

                        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
                        prob_cumsum = sorted_probs.cumsum(dim=-1)
                        idx_remove = prob_cumsum > p_nucleus
                        # shift one right to ensure the first token is above the threshold
                        idx_remove[..., 1:] = idx_remove[..., :-1].clone()
                        idx_remove[..., 0] = False
                        # retrieve original indices by reverse-sorting
                        remove_mask = idx_remove.gather(
                            dim=-1, index=sorted_idx.argsort(dim=-1)
                        )
                        # ^ specifically, we do this by first argsorting the indices which were returned from argsort
                        # you can show that this returns indices that when used to subset a sorted array, returns the original array in unsorted order
                        # https://stackoverflow.com/questions/52127723/pytorch-better-way-to-get-back-original-tensor-order-after-torch-sort
                        probs[remove_mask] = 0

                    if method == "top-k":
                        remove_mask = (
                            probs < torch.topk(probs, k).values[..., -1, None]
                        )  # topk returns (B, 1), leaving only the
                        # kth largest probs (i.e. the cutoff value for each). Then mask is same size as probs (B, vocab_size)
                        probs[remove_mask] = 0

                    # Sample probabilistically via scores
                    next_idx = torch.multinomial(probs, num_samples=1)  # (B, 1)

                # vi) Autoregressively append to input_text
                input_ids = torch.cat((input_ids, next_idx), dim=-1)

                # now input_text = (B, S + 1)

        return input_ids

# Data

In [9]:
with open(f"{root}/data/tiny-shakespeare.txt", "r") as f:
    text = f.read()

chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)
print(f"Vocab: {chars}")
print(f"Vocab size: {VOCAB_SIZE}")

Vocab: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocab size: 65


In [10]:
# Prepare mappings / tokenizer
# create a mapping from characters to integers
txt2idx = {ch: i for i, ch in enumerate(chars)}
idx2txt = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [txt2idx[c] for c in s]
decode = lambda l: "".join([idx2txt[i] for i in l])

print(encode("tiny-shakespeare is sick"))
print(decode(encode("tiny-shakespeare is sick")))

[58, 47, 52, 63, 7, 57, 46, 39, 49, 43, 57, 54, 43, 39, 56, 43, 1, 47, 57, 1, 57, 47, 41, 49]
tiny-shakespeare is sick


In [11]:
# tokenizer data
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # 90-10 split
train_data = data[:n]
val_data = data[n:]
print("train_data len:", len(train_data), "val_data len:", len(val_data))


def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - SEQ_LEN, (BATCH_SIZE,))
    x = torch.stack([data[i : i + SEQ_LEN] for i in ix])
    y = torch.stack([data[i + 1 : i + SEQ_LEN + 1] for i in ix])
    return x.to(device), y.to(device)

train_data len: 1003854 val_data len: 111540


# Training setup

In [12]:
set_seed(SEED)
model = SMoE(
    VOCAB_SIZE,
    SEQ_LEN,
    N_EMBD,
    N_HEAD,
    N_FF,
    N_LAYER,
    device=device,
    norm_first=True,
    switch=True,
    capacity_factor=CAPACITY_FACTOR,
    drop_tokens=True,
    n_experts=N_EXPERT,
    expert=MLP,
    use_amp=USE_AMP,
    mlp_dropout=0.1,
    expert_dropout=0.4,
)

# Gradient scaling for mixed precision
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)



In [25]:
summary(model)

Layer (type:depth-idx)                             Param #
Transformer                                        --
├─Embedding: 1-1                                   8,320
├─PositionalEncoding: 1-2                          --
├─Sequential: 1-3                                  --
│    └─Block: 2-1                                  --
│    │    └─MultiHeadAttention: 3-1                65,536
│    │    └─SwitchFeedForward: 3-2                 263,682
│    │    └─LayerNorm: 3-3                         256
│    │    └─LayerNorm: 3-4                         256
│    │    └─Dropout: 3-5                           --
│    │    └─Dropout: 3-6                           --
│    └─Block: 2-2                                  --
│    │    └─MultiHeadAttention: 3-7                65,536
│    │    └─MLP: 3-8                               131,712
│    │    └─LayerNorm: 3-9                         256
│    │    └─LayerNorm: 3-10                        256
│    │    └─Dropout: 3-11                          -

## Computing activated parameter count

In [26]:
### TODO

807745

# Training

In [14]:
def calc_ce_loss(logits, targets):
    """
    TODO
    Computes cross-entropy loss.
    Inputs:
        -logits: Model output of shape (B, S, vocab_size)
        -counts:
    """
    B, S, C = logits.shape
    logits = logits.view(B * S, C)
    targets = targets.view(B * S)
    loss = F.cross_entropy(logits, targets)
    return loss

In [15]:
def calc_aux_loss(counts, prob_sum):
    """
    Computes Switch Transformer auxiliary loss.
    Inputs:
        -counts: Number of tokens passed to each expert in each switch layer (num_switch_layers x n_experts)
        Note this is NOT equivalent to n_layer; num_switch_layers = (n_layer//2) + (n_layer % 2)
        -prob_sum: Sum of probs across all tokens for each expert (num_switch_layers x n_experts)
    """

    # total number of tokens routed in that layer
    token_count = counts.sum(dim=-1, keepdims=True)

    # prop of tokens dispatched to each expert
    route_frac = counts / token_count

    # fraction of total probability allocated for each expert
    # recall prob_sum := softmaxed probs, which added to 1 across the experts for each token
    # we divide by num_tokens so that the overall 2D scalar sum of prob_frac is 1
    # intuitively we are forcing the total prob for each layer across the experts to be 1 so we can take proportions,
    # the same way as above
    prob_frac = prob_sum / token_count

    # Auxiliary loss
    # L = N \sum_{i=1}^N f_i • P_i
    aux_loss = N_EXPERT * (route_frac * prob_frac).sum()
    return aux_loss

In [16]:
def train(
    model,
    optimizer,
    scaler,
    device,
    train_loss_list=None,
    val_loss_list=None,
    train_time_list=None,
    val_aux_loss_list=None,
    dropped_list=None,
):

    train_losses = train_loss_list if train_loss_list is not None else []
    val_losses = val_loss_list if val_loss_list is not None else []
    train_times = train_time_list if train_time_list is not None else []
    val_aux_losses = val_aux_loss_list if val_aux_loss_list is not None else []
    dropped = dropped_list if dropped_list is not None else []

    model.train()
    model.to(device)

    # Set up prompt generation
    generation_file_path = f"{path}/outputs/OUTPUT_{MODEL_NAME}_SEED_{SEED}.txt"
    empty_tokens = torch.zeros((1, 1), dtype=torch.long).to(device)
    cond_prompts = ["KING TERRY: Thou art", "DANIEL: Ay, my dear,"]

    cond_token_list = [encode(prompt) for prompt in cond_prompts]

    for step in range(MAX_ITERS):

        start = time.perf_counter()

        optimizer.zero_grad(set_to_none=True)
        inputs, targets = get_batch("train")
        with torch.autocast(
            device_type=device.type, dtype=torch.bfloat16, enabled=USE_AMP
        ):
            if model.switch:
                logits, counts, prob_sum, n_dropped = model(inputs)
                loss = calc_ce_loss(logits, targets)
                aux_loss = calc_aux_loss(counts, prob_sum)
                loss += AUX_LOSS_COEF * aux_loss
                drop_frac = (np.array(n_dropped) / (BATCH_SIZE * SEQ_LEN)).tolist()
                dropped.append(drop_frac)  # for logging purposes
            else:
                logits = model(inputs)
                loss = calc_ce_loss(logits, targets)

        train_losses.append(loss.item())  # for printing

        scaler.scale(loss).backward()
        # loss.backward()

        # Monitor gradient norm
        scaler.unscale_(optimizer)

        with torch.autocast(
            device_type=device.type, dtype=torch.bfloat16, enabled=USE_AMP
        ):
            grads = [
                param.grad.detach().flatten()
                for param in model.parameters()
                if param.grad is not None
            ]
            norm = torch.cat(grads).norm()

        train_time = time.perf_counter() - start
        tokens_per_sec = (1 / train_time) * BATCH_SIZE * SEQ_LEN
        train_times.append(tokens_per_sec)

        scaler.step(optimizer)
        scaler.update()
        # optimizer.step()

        # print training statistics
        if step % PRINT_ITERS == 0 and step != 0:
            print(
                f"Step {step}/{MAX_ITERS} | Running Avg Train Loss: {np.mean(train_losses):.5f} |",
                f"Grad Norm: {norm:.3f} | Running Avg Tokens/Sec: {np.mean(train_times):.3f} |",
            )

        # estimate val loss, generate text and save
        if step % EVAL_ITERS == 0 and step != 0:
            val_losses, val_aux_losses = estimate_loss(
                model, val_losses, val_aux_losses, device
            )
            generate(model, generation_file_path, empty_tokens, cond_token_list, step)
            model.train()

        # save model, val losses (not train_losses), train times
        if step % SAVE_ITERS == 0 and step != 0:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                f"{path}/checkpoints/{MODEL_NAME}_STEP_{step}_SEED_{SEED}.pt",
            )

            with open(
                f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_val_losses.json", "w"
            ) as f:
                json.dump(val_losses, f)

            with open(
                f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_val_aux_losses.json", "w"
            ) as f2:
                json.dump(val_aux_losses, f2)

            with open(
                f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_train_times.json", "w"
            ) as f3:
                json.dump(
                    train_times[EVAL_ITERS::EVAL_ITERS], f3
                )  # match freq of val_losses
                # note this means if you load from checkpoint to continue training you will have a sparser train_times
                # list in computing running avg

            with open(
                f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_dropped.json", "w"
            ) as f4:
                json.dump(dropped[EVAL_ITERS::EVAL_ITERS], f4)  # same here

In [17]:
@torch.no_grad()
def estimate_loss(model, val_losses, val_aux_losses, device):
    model.eval()
    losses = torch.zeros(EVAL_ITER_COUNT)
    aux_losses = torch.zeros(EVAL_ITER_COUNT)
    for k in range(EVAL_ITER_COUNT):
        inputs, targets = get_batch("test")
        with torch.autocast(
            device_type=device.type, dtype=torch.bfloat16, enabled=USE_AMP
        ):
            if model.switch:
                logits, counts, prob_sum, n_dropped = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
                aux_losses[k] = calc_aux_loss(counts, prob_sum)
                losses[k] += AUX_LOSS_COEF * aux_losses[k]
            else:
                logits = model(inputs)
                losses[k] = calc_ce_loss(logits, targets)
    val_loss, val_aux_loss = losses.mean().item(), aux_losses.mean().item()
    val_losses.append(val_loss)
    val_aux_losses.append(val_aux_loss)  # track separate aux loss for logging
    # keep model in eval, next call is to .generate() anyway
    print(f"Est. Val Loss: {val_loss:.5f} | Est. Aux Loss: {val_aux_loss:.5f}")
    return val_losses, val_aux_losses

In [18]:
def generate(model, generation_file_path, empty_tokens, cond_token_list, step):

    set_seed(42)

    uncond_res1 = decode(
        model.generate(empty_tokens, method="top-k", k=5, max_new_tokens=500)[
            0
        ].tolist()
    )
    uncond_res2 = decode(
        model.generate(
            empty_tokens, method="nucleus", p_nucleus=0.5, max_new_tokens=500
        )[0].tolist()
    )

    cond_res_list = []
    for prompt in cond_token_list:
        cond_res = decode(
            model.generate(
                torch.tensor(prompt).unsqueeze(0).long().to(device),
                method="top-k",
                k=5,
                max_new_tokens=500,
            )[0].tolist()
        )
        cond_res_list.append(cond_res)

    cond_res_list = "\n\n".join(cond_res_list)

    generation_text = f"""{MODEL_NAME} Output, Step {step}
    UNCONDITIONAL GENERATION:

    Top-k (5) (500 max_tokens):
    {uncond_res1}

    Nucleus (0.5) (500 max_tokens):
    {uncond_res2}

    #####################################################
    CONDITIONAL GENERATION (Top-k (5), 500 max_tokens):
    {cond_res_list}
    -----------------------------------------------------
    """
    with open(generation_file_path, "a") as file:
        file.write(generation_text)
    print(generation_text)

In [60]:
## Driver code
train(model, optimizer, scaler, device)

Step 50/50000 | Running Avg Train Loss: 4.22206 | Grad Norm: 1.482 | Running Avg Tokens/Sec: 6810.113 | Running Avg Route Frac: [[0.528 0.472]
 [0.420 0.580]]
Step 100/50000 | Running Avg Train Loss: 3.78290 | Grad Norm: 1.492 | Running Avg Tokens/Sec: 6853.738 | Running Avg Route Frac: [[0.515 0.485]
 [0.454 0.546]]


KeyboardInterrupt: 

In [810]:
## Driver code
train(model, optimizer, device)

Step 50/50000 | Running Avg Train Loss: 4.26320 | Grad Norm: 1.009 | Running Avg Tokens/Sec: 6688.876
Step 100/50000 | Running Avg Train Loss: 3.80411 | Grad Norm: 1.250 | Running Avg Tokens/Sec: 6746.328
Step 150/50000 | Running Avg Train Loss: 3.56538 | Grad Norm: 1.857 | Running Avg Tokens/Sec: 6763.554
Step 200/50000 | Running Avg Train Loss: 3.39562 | Grad Norm: 1.342 | Running Avg Tokens/Sec: 6709.624
Step 250/50000 | Running Avg Train Loss: 3.27404 | Grad Norm: 1.251 | Running Avg Tokens/Sec: 6643.860
Est. Val Loss: 2.62356 | Est. Aux Loss: 2.01001
switch_4_LAYERs_4_HEAD_128_EMBD_DIM_128_SEQ_LEN Output, Step 250
    UNCONDITIONAL GENERATION:

    Top-k (5) (500 max_tokens):
    

Are  coor oooto thelanttsst sond bateses m man m win m bes tounderthe an withilouneselle t thirer toulir seng t terllore bour athes w b wore ssessearate alllllese serol lallel soulland ssss wind t ararathilan s tor thind angor sens tene sthan anerouss arl s astele toung thit wer therer seras we wes than

KeyboardInterrupt: 

# Generation

In [None]:
print(
    decode(
        model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[
            0
        ].tolist()
    )
)

In [115]:
input_txt = "TERRY: thou art"
ctx = encode(input_txt)
print(
    decode(
        model.generate(torch.tensor(ctx).unsqueeze(0).long(), max_new_tokens=500)[
            0
        ].tolist()
    )
)

TERRY: thou arte a my she
Which have may and of contain.

DUKE VINCENTIO:
Good as I tall no knrow, for shalt agarnt
And mpo; and Kong a m, not outhpile Mesce.

HENRY VI:
When I will thy lookess, oner the pexstrey
The the the hee voagh gresed livioe.

MENCIO:
My her callis his peaced of to that
We where's by shall bore: as shall myselvea
The plender feuls!

PAPELLANT:
In the the into balby me dods to love,
In but the giving of nyou ase. I tall it-me?'e Goveuling
The theer haught art praver count madeng Camen:
T
