<a href="https://colab.research.google.com/github/sanggusti/30-days-kaggle/blob/main/Gusti_Winata_ScholarsTakehome.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cohere Labs Scholars 2️⃣0️⃣2️⃣6️⃣: Take-Home Assignment

# **Background**

Welcome to the Cohere Labs Scholars Program Take-Home Challenge! This exercise is designed to let you showcase your engineering and problem-solving skills. The assessment consists of several challenges, including:

*   Identifying bugs and getting the code working. This is designed to test your ability to grapple with real-world engineering challenges.
*   Testing your ability to generate code for a specified problem.
*   An opportunity for you to attempt an optional challenge question that extends the original problem set.

These tasks were chosen to see how you think about problems, even if they are outside your own research field of interest. The tasks and dataset are not meant to be indicative of the research goals of the Scholars Program. We purposefully selected a simple toy problem so that the focus is on how you think and so that it does not require significant machine learning resources (it can be run in this Colab).

Good luck! 🍀

**How to Use and Submit this Document?**

*   Make a copy of this document and rename it **Firstname_Lastname_ScholarsTakehome**
* Once you have completed all tasks:
  * Save and pin your revisions
  * Download the colab as a .ipynb file
  * Submit the assignment via the submission link you received via email (subject line: "Cohere Labs: Research Scholar Program - Next Steps") by **September 16 by 11pm PDT**.


### This Coding Challenge (🚨 25 points) consists of 4 parts:

1. **Debugging custom SmolMoELM** 🔍🐛[*10 points*]
2. **Upcycling a Dense Model into an MoE 🔄 🚴** [*3 points*]
3. **Continued Pretraining 📚💪** [*7 points*]
4. **Exploring The Unknown 🧙 ✨** [*5 points*]

Each part builds on the previous one, so you are encouraged to work through them in order.

**NOTE**: Part 4 can also be attempted independently (*if you don't wish to build on the previous sections*).

## **Coding Challenge Part 1: Debugging custom Smol`MoE`LM 🔍🐛 [🚨 10 points]**

**Mixture-of-Experts (MoE)** models are all the rage in 2025, powering some of the most advanced large-scale AI systems. In this coding challenge, you will dive into the core idea behind MoE and fix a bare-bones implementation.

We have **🚨 10 bugs** in the following implementation.
Section `3. Test` is provided for your convenience to verify that you have correctly identified all the bugs (both `Check #1` and `Check #2` will help you confirm this).

**Rules**:
1. **Bug Definition:**
  - There are **🚨 10 bugs** to be fixed.
  - A bug is *defined as **{incorrect, missing, unnecessary}** lines of code*.
  - You earn 1 point for each correctly identified and fixed bug.
2. **Fix Guidelines:**
  - You are encouraged to make the smallest possible fix, wherever possible (e.g. edit a line instead of replacing it entirely).
  - Do not optimize the code in any way (e.g. combining functions or changing variable names); **only fix the bugs**. The implementation is *intentionally* non-optimized but valid.
  - **Note:** Some bugs may require more than one line of correction/addition.

3. **Documentation:** Document each fix by adding the comment `### BUG FIX ###` on the line above the fix.
4. **Sections:** *1. Setup [Helper Functions]* and *3. Test* don't contain bugs and shouldn't be changed.
5. **Multiple Bug Fixes:** Do not worry about possibly solving multiple bugs with a single fix. Should that rare case arise, you will still be awarded the correct number of points as long as the fixes are the only changes made.
6. **Rewriting the Implementation:** Rewriting the implementation to get around the bugs will not count towards any points. You are to work strictly within the implementation, extending/modifying it only as much as required (indicated by *Step 3*).
7. **Submission:** Your final submission should be the exact same notebook except with your proposed fixes in the cells and the respective comments as per Rule #3.

In [None]:
# Example of a bug fix

def _calc_square_root(x):
    ### BUG FIX ###
    # ans = x*2
    ans = x**(1/2)
    return ans

### 1. Setup [Helper Functions]

In [2]:
######################################################################################################################
############################################## DO NOT CHANGE[START] ##################################################
######################################################################################################################

# # Download the weights from HF
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="dsouzadaniel/C4AI_SmolMoELM",
                       filename="trial_weights.pt",
                      local_dir=".",)


In [3]:
# Libraries
import time
import math
import torch
import numpy as np
import pandas as pd
from torch import nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


def timed(fn):
    '''Simple Timing Decorator'''
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        out = fn(*args, **kwargs)
        total_time = time.perf_counter() - start_time
        print(f"time={total_time:.3f}s")
        return out
    return wrapper

def labelthis(label):
    '''Simple Label Assigner'''
    def deco(fn):
        fn.label = label
        return fn
    return deco

def pretty_dt(s: float) -> str:
    '''Print Time Taken(but pretty :) )'''
    if s < 1e-6: return f"{s*1e9:.0f} ns"
    if s < 1e-3: return f"{s*1e6:.0f} µs"
    if s < 1:    return f"{s*1e3:.0f} ms"
    if s < 60:   return f"{s:.3f} s"
    h, s = divmod(s, 3600); m, s = divmod(s, 60)
    return (f"{int(m)}m {int(s)}s" if h < 1 else f"{int(h)}h {int(m)}m {int(s)}s")

@timed
def __generate(model, tokenizer, inputs, num_tokens):
    '''Helper function. Recommended to use via `generation_compare`'''
    collect = []
    for _ in range(num_tokens):
        output = model(**inputs)
        output_id = torch.argmax(output['logits'][0,-1]).item()
        collect.append(output_id)
        if output_id==tokenizer.eos_token_id:
            break
        inputs['input_ids'] = torch.unsqueeze(torch.cat([inputs['input_ids'][0],torch.tensor([output_id])]),dim=0)
        inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(collect))

def generation_compare(prompt, num_tokens, tokenizer, model_A, model_B=None):
    '''Compares generations of two models. Passing just one model provides simple generation utility'''
    print()
    print(f"{'>'*20}\n\tPrompt\n{'<'*20}\n{prompt}\n\n")
    model_inputs = tokenizer(prompt, return_tensors='pt')
    print(f"{'>'*30}\n\tModel_A Generation\n{'<'*30}\n{__generate(model_A,  tokenizer, model_inputs, num_tokens)}")
    print("\n\n")
    if model_B:
        model_inputs = tokenizer(prompt, return_tensors='pt')
        print(f"{'>'*30}\n\tModel_B Generation\n{'<'*30}\n{__generate(model_B,  tokenizer, model_inputs, num_tokens)}")

def detach_metrics(metrics: dict):
    '''helper for metrics'''
    def to_cpu(x):
        if isinstance(x, torch.Tensor):
            # If scalar, return float; if vector/matrix, return list
            return x.detach().cpu().item() if x.dim() == 0 else x.detach().cpu().tolist()
        elif isinstance(x, list):
            return [to_cpu(y) for y in x]
        elif isinstance(x, dict):
            return {k: to_cpu(v) for k, v in x.items()}
        return x

    return {k: to_cpu(v) for k, v in metrics.items()}

def plot_metrics(metrics: dict, x_vals=None, suptitle="Training Metrics"):
    '''For grid plotting a collection of metrics'''
    metrics = detach_metrics(metrics)

    keys = list(metrics.keys())
    n = len(keys)
    length = len(next(iter(metrics.values())))
    if not x_vals:
        x_vals = list(range(1,length + 1))

    fig, axes = plt.subplots(1, n, figsize=(4*n, 3), constrained_layout=True)
    if n == 1:
        axes = [axes]

    palette = plt.cm.tab10.colors

    for i, (ax, key_str) in enumerate(zip(axes, keys)):
        y_vals = metrics[key_str]
        ax.plot(x_vals, y_vals, marker="o", color=palette[i % len(palette)])
        ax.set_title(key_str)
        ax.grid(True, alpha=0.3)

    fig.suptitle(suptitle)
    fig.supxlabel("Steps")
    plt.show()


class smolMoEConfig:
    vocab_size=49152
    hidden_size=576
    intermediate_size=1536
    num_hidden_layers = 30
    num_heads=9
    kv_heads=3
    num_experts = 3
    num_experts_per_tok = 1

config = smolMoEConfig

TEST_PROMPT = "Where is the Great Wall?"

######################################################################################################################
############################################### DO NOT CHANGE[END] ###################################################
######################################################################################################################

### 2. Custom Smol`MoE`LM (for Bug Fixes)

Bugs found:

1. `RotaryEmbedder` -> wrong dtype in the inverse-frequency computation.
Changes:
```
-- self.freq = 1/(base ** (torch.arange(0, dim, 2, dtype=torch.int64).float()/dim))
++ self.freq = 1.0/(base ** (torch.arange(0, dim, 2, dtype=torch.float32) /dim))
```
2. `MoE` initialization -> the checkpoint names the router `gate`, not `router`, and it must output one logit per expert (`self.E`), not `self.k`.
Changes:
```
-- self.router = nn.Linear(self.D, self.k, bias=False, dtype=dtype)
++ self.gate = nn.Linear(self.D, self.E, bias=False, dtype=dtype)
```
3. `MoE.expert_utilization` -> wrong load-balancing loss formula.
Changes:
```
-- load = torch.mean(selected.float(), dim=(0,1))
-- self._aux_lb = self.E * load.sum()
-- self._expert_utilization = selected

++ probs = F.softmax(logits, dim=-1)
++ load = torch.mean(selected.float(), dim=(0,1))
++ density_proxy = torch.mean(probs, dim=(0,1))
++ self._aux_lb = (self.E**2) * torch.mean(density_proxy * load)
++ self._expert_utilization = selected
```
4. `MoE.forward` -> call `self.gate`, not `self.router`.
Changes:
```
-- logits = self.router(x)
++ logits = self.gate(x)
```
5. `MoE.forward` -> the SwiGLU activation needs to be `F.silu(a) * u`, and the einsum for `y` must be `"bteh,ehd->bted"`. Changes:
```
-- h = F.silu(u)
-- y = torch.einsum("bteh,ehd->bteh", h, self.down_bank)
++ h = F.silu(a) * u
++ y = torch.einsum("bteh,ehd->bted", h, self.down_bank)
```
6. `RopeAttention` -> `rotary_emb` recomputes `dim` instead of using the already-defined `self.head_dim`. Changes:
```
-- self.rotary_emb = RotaryEmbedder(base=self.rope_theta,
                                    dim=config.hidden_size//self.num_heads)
++ self.rotary_emb = RotaryEmbedder(base=self.rope_theta,
                                    dim=self.head_dim)
```
7. `RopeAttention.forward` -> `rotary_emb` should be called on `q_states`, not `v_states`. Changes:
```
-- cos, sin = self.rotary_emb(v_states)
++ cos, sin = self.rotary_emb(q_states)
```
8. `smolMoEModel.forward` -> return the `hidden_states` tensor itself, not wrapped in a list. Changes:
```
-- return [hidden_states]
++ return hidden_states
```
9. `smolMoELM.forward` -> `hidden_states` should take the whole `outputs` tensor, not only its first index. Changes:
```
-- hidden_states = outputs[0]
++ hidden_states = outputs
```
10. Missing implementation of `smolMoELM.get_expert_utilization`. Changes:
```
-- lb_loss, expert_utilization_per_layer = 0, 0
-- return expert_utilization_per_layer, lb_loss

++ expert_utils = []
++ lb_losses = []
++ for layer in self.model.layers:
++     moe = getattr(layer, "moe", None)
++     if moe is not None and hasattr(moe, "_expert_utilization") and hasattr(moe, "_aux_lb"):
++         # (_expert_utilization) shape: (B,T,E) one-hot per token
++         util = moe._expert_utilization.float().mean(dim=(0,1))  # (E,)
++         expert_utils.append(util)
++         lb_losses.append(moe._aux_lb)
++ if expert_utils:
++     expert_utilization_per_layer = torch.stack(expert_utils, dim=0)  # (L,E)
++     lb_loss = torch.stack(lb_losses).mean()
++ else:
++     expert_utilization_per_layer = torch.tensor(0.)
++     lb_loss = torch.tensor(0.)
++ return expert_utilization_per_layer, lb_loss
```
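To make fix 3 concrete, here is a minimal pure-Python sketch of the Switch Transformer auxiliary loss as written in `expert_utilization`: `E**2 * mean(f * P)`, where `f` is the fraction of tokens whose argmax lands on each expert and `P` is the mean softmax probability per expert. The helper name `switch_lb_loss` and the toy logits are hypothetical, for illustration only; the batch and sequence axes are assumed to be flattened into a single list of tokens.

```python
import math

def softmax(row):
    # Numerically stable softmax over one token's expert logits
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def switch_lb_loss(logits, num_experts):
    """Hypothetical helper: `logits` is a list of per-token rows of length
    num_experts (the (B, T, E) router logits flattened over B and T)."""
    n_tokens = len(logits)
    load = [0.0] * num_experts     # f_i: fraction of tokens routed (top-1) to expert i
    density = [0.0] * num_experts  # P_i: mean router probability for expert i
    for row in logits:
        chosen = max(range(num_experts), key=lambda e: row[e])  # argmax routing
        load[chosen] += 1.0 / n_tokens
        for e, p in enumerate(softmax(row)):
            density[e] += p / n_tokens
    # E**2 * mean(f * P), which equals E * sum_i f_i * P_i
    mean_fp = sum(f * p for f, p in zip(load, density)) / num_experts
    return num_experts ** 2 * mean_fp

# Uniform case: each expert wins exactly one of three tokens and the
# probabilities are symmetric across tokens, so f_i = P_i = 1/3.
uniform = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
print(round(switch_lb_loss(uniform, 3), 6))  # 1.0

# Collapsed case: every token routes to expert 0, and the loss rises above 1.
collapsed = [[2.0, 0.0, 0.0]] * 3
print(switch_lb_loss(collapsed, 3))
```

The `E**2 * mean(...)` form used in the notebook and the `E * sum(...)` form from the paper are the same quantity, since the mean over `E` experts divides the sum by `E`.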

With these fixes, the model is able to generate coherent text:

```
>>>>>>>>>>>>>>>>>>>>
	Prompt
<<<<<<<<<<<<<<<<<<<<
Where is the Great Wall?


time=9.400s
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
	Model_A Generation
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<


The Great Wall of China is a 13,000-mile-long wall that spans 13,000 miles from the Yellow River in Shaanxi province in China to the border with Mongolia. It is


```

and the second check, on `lb_loss`, also passes:

```
(Expected) Load Balance Loss => 1.00
(Actual) Load Balance Loss => 1.00
```

In [32]:
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    if cos.device != q.device:
        cos = cos.to(q.device)
        sin = sin.to(q.device)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def repeat_kv(hidden_states, n_rep):
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

class RotaryEmbedder(nn.Module):
    def __init__(self, dim, base):
        super().__init__()
        ### BUG FIX 1 ###
        self.freq = 1.0/(base ** (torch.arange(0, dim, 2, dtype=torch.float32) /dim))

    @torch.no_grad()
    def forward(self,x):
        # x : (B,H,T,D)
        pos = torch.arange(x.shape[-2], dtype=torch.long)
        angles = torch.einsum('p,f->pf', pos.float(), self.freq).unsqueeze(dim=0)
        # (B,T,dim)
        emb = torch.cat((angles, angles), dim=-1)
        return emb.cos(), emb.sin()

class MoE(nn.Module):
    """
    An MoE layer with MLP block with swiglue activation function.
    """

    def __init__(self, num_experts_per_tok: int, num_experts: int, emb_dim: int, moe_dim: int, dtype=torch.float32):
        super().__init__()
        self.k = int(num_experts_per_tok)
        self.E = int(num_experts)
        self.D = int(emb_dim)
        self.H = int(moe_dim)

        ### BUG FIX 2 ###
        # 2: router output size was wrong: logits must be (B,T,E), not (B,T,k); the checkpoint names it `gate`, not `router`, and has no bias
        self.gate = nn.Linear(self.D, self.E, bias=False, dtype=dtype)
        self.gate_bank = nn.Parameter(torch.empty(self.E, self.D, self.H, dtype=dtype))
        self.up_bank   = nn.Parameter(torch.empty(self.E, self.D, self.H, dtype=dtype))
        self.down_bank = nn.Parameter(torch.empty(self.E, self.H, self.D, dtype=dtype))

    def expert_utilization(self, logits):
        """
        This function compute expert utilization per token and also compute load balancer loss.
        Details of this load balancer can be found in https://arxiv.org/abs/2101.03961
        """
        selected = torch.argmax(logits, dim=-1)
        selected = F.one_hot(selected, num_classes=self.E)

        ### BUG FIX 3 ###
        # Per the Switch Transformer paper, the loss combines the fraction of tokens routed to each expert (f_i)
        # with the mean router probability per expert (P_i): E * sum_i f_i * P_i, written here as E**2 * mean(f * P)
        probs = F.softmax(logits, dim=-1)

        load = torch.mean(selected.float(), dim=(0,1))
        density_proxy = torch.mean(probs, dim=(0,1))
        self._aux_lb = (self.E ** 2) * torch.mean(density_proxy * load)

        self._expert_utilization = selected

    def forward(self, x):
        B, T, D = x.shape
        assert D == self.D, f"Expected emb_dim={self.D}, got {D}"

        ### BUG FIX 4 ###
        logits = self.gate(x)

        if self.training:
            logits = logits + torch.randn_like(logits) * 1e-1

        selected = torch.argmax(logits, dim=-1)
        a = torch.einsum("btd,edh->bteh", x, self.gate_bank)
        u = torch.einsum("btd,edh->bteh", x, self.up_bank)
        ### BUG FIX 5 ###
        h = F.silu(a) * u
        ### BUG FIX 6 ###
        y = torch.einsum("bteh,ehd->bted", h, self.down_bank)

        gather_idx = selected.view(B,T,1,1).expand(-1, -1, -1, D)
        y = torch.gather(y, dim=2, index=gather_idx).squeeze(-2)

        self.expert_utilization(logits)
        return y


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


class RopeAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.kv_heads = config.kv_heads
        self.rope_theta = 10000.0

        self.W_query = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.W_key = nn.Linear(config.hidden_size, self.kv_heads * self.head_dim, bias=False)
        self.W_value = nn.Linear(config.hidden_size, self.kv_heads * self.head_dim, bias=False)
        self.W_output = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        ### BUG FIX 7 ###
        self.rotary_emb = RotaryEmbedder(base=self.rope_theta,
                                         dim=self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask= None,
    ):
        b, q, _ = hidden_states.size()

        q_states = self.W_query(hidden_states)
        k_states = self.W_key(hidden_states)
        v_states = self.W_value(hidden_states)

        q_states = q_states.view(b, q, self.num_heads, self.head_dim).transpose(1, 2)
        k_states = k_states.view(b, q, self.kv_heads, self.head_dim).transpose(1, 2)
        v_states = v_states.view(b, q, self.kv_heads, self.head_dim).transpose(1, 2)

        ### BUG FIX 8 ###
        cos, sin = self.rotary_emb(q_states)
        q_states, k_states = apply_rotary_pos_emb(q_states, k_states, cos, sin)

        __kv_groups = self.num_heads // self.kv_heads

        k_states = repeat_kv(k_states, __kv_groups)
        v_states = repeat_kv(v_states, __kv_groups)

        attn_weights = torch.matmul(q_states, k_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        ### BUG FIX Attempt ###
        if attention_mask is not None:
            # Ensure mask is (B, 1, T, T) for broadcasting
            if attention_mask.dim() == 2:  # (B, T) -> (B, 1, T, T)
                attention_mask = attention_mask[:, None, :, None].expand(-1, 1, -1, attention_mask.size(-1))
            attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = nn.functional.dropout(attn_weights,p=0)


        attn_output = torch.matmul(attn_weights, v_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(b, q, -1)

        attn_output = self.W_output(attn_output)

        return attn_output

class LlamaDecoder(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.self_attn = RopeAttention(config)
        self.moe = MoE(num_experts=config.num_experts,
                       num_experts_per_tok=config.num_experts_per_tok,
                       emb_dim=config.hidden_size,
                       moe_dim=config.intermediate_size)
        self.pre_attn_rmsnorm = RMSNorm(config.hidden_size, eps=1e-05)
        self.pre_moe_rmsnorm = RMSNorm(config.hidden_size, eps=1e-05)

    def forward(self,hidden_states, attention_mask):
        residual = hidden_states
        hidden_states = self.pre_attn_rmsnorm(hidden_states)
        ### BUG FIX Attempt ###
        seq_len = attention_mask.shape[-1]
        attention_mask = torch.triu(torch.full((seq_len, seq_len), fill_value=float('-inf'), device = hidden_states.device),diagonal=1)
        attention_mask = attention_mask.unsqueeze(0).unsqueeze(1)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )

        hidden_states += residual
        residual = hidden_states

        hidden_states = self.pre_moe_rmsnorm(hidden_states)

        # MLP block
        hidden_states = self.moe(hidden_states)
        hidden_states += residual

        outputs = (hidden_states,)

        return outputs

class smolMoEModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(num_embeddings=config.vocab_size,
                                         embedding_dim=config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoder(config) for _ in range(config.num_hidden_layers)
            ])
        self.norm = RMSNorm(config.hidden_size, eps=1e-05)

    def forward(
        self,
        input_ids= None,
        attention_mask= None,
    ):
        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        ### BUG FIX ###
        return hidden_states

class smolMoELM(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.model = smolMoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(self,input_ids,attention_mask):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = outputs

        logits = self.lm_head(hidden_states)
        logits = logits.float()
        return {'logits':logits}

    def get_expert_utilization(self):
        ### BUG FIX ###
        # 9: previously returned (0, 0) every time; aggregate expert utilization and load-balance loss across layers
        expert_utils = []
        lb_losses = []

        for layer in self.model.layers:
            moe = getattr(layer, "moe", None)
            if moe is not None and hasattr(moe, "_expert_utilization") and hasattr(moe, "_aux_lb"):
                # (_expert_utilization) shape: (B,T,E) one-hot per token
                util = moe._expert_utilization.float().mean(dim=(0,1))  # (E,)
                expert_utils.append(util)
                lb_losses.append(moe._aux_lb)

        if expert_utils:
            expert_utilization_per_layer = torch.stack(expert_utils, dim=0)  # (L,E)
            lb_loss = torch.stack(lb_losses).mean()
        else:
            expert_utilization_per_layer = torch.tensor(0.)
            lb_loss = torch.tensor(0.)

        return expert_utilization_per_layer, lb_loss

    def reset_weights_and_metrics(self):
        with torch.no_grad():
            modules = list(self.modules())[1:]
            for m in modules:
                fn = getattr(m, "reset_parameters_", None) or getattr(m, "reset_parameters", None)
                if callable(fn):
                    fn()

            for m in modules:
                if hasattr(m, "reset_parameters") or hasattr(m, "reset_parameters_"):
                    continue
                any_param = False
                for name, p in m.named_parameters(recurse=False):
                    any_param = True
                    if p.dim() == 1:
                        if name == "bias":
                            p.zero_()
                        else:
                            p.fill_(1.0)
                    else:
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
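The rotary helpers at the top of this cell (`RotaryEmbedder`, `rotate_half`, `apply_rotary_pos_emb`) pair dimension `i` of each head with dimension `i + dim/2` and rotate each pair by a position-dependent angle. Below is a minimal pure-Python sketch of the same math for a single head vector, assuming `base=10000.0` as in `RopeAttention`; the helper names and toy vectors are hypothetical. It checks two properties RoPE is designed to have: the rotation preserves the vector's norm, and the query-key dot product depends only on the offset between the two positions.

```python
import math

def inv_freq(dim, base=10000.0):
    # Same formula as RotaryEmbedder.freq: 1 / base**(2j/dim) for j = 0 .. dim/2 - 1
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]

def rotate_half(x):
    # Mirrors rotate_half above: [x1, x2] -> [-x2, x1]
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def apply_rope(x, pos, base=10000.0):
    freqs = inv_freq(len(x), base)
    angles = [pos * f for f in freqs] * 2  # like torch.cat((angles, angles), dim=-1)
    cos = [math.cos(a) for a in angles]
    sin = [math.sin(a) for a in angles]
    rot = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.5, 2.0, 0.7, -0.4, 1.1, 0.9]
k = [1.0, 0.2, -0.6, 0.4, -1.5, 0.8, 0.3, -0.2]

# The norm is preserved at any position
print(math.isclose(dot(q, q), dot(apply_rope(q, 7), apply_rope(q, 7))))  # True

# The score depends only on the offset m - n (3 in both cases)
s1 = dot(apply_rope(q, 5), apply_rope(k, 2))
s2 = dot(apply_rope(q, 105), apply_rope(k, 102))
print(math.isclose(s1, s2, rel_tol=1e-9, abs_tol=1e-9))  # True
```

Because `RotaryEmbedder.forward` reads only `x.shape[-2]` (the sequence length) from its input, the angles depend on positions alone, not on the tensor's values.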

In [19]:
# personal testing cell

checkpoint="HuggingFaceTB/SmolLM-135M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

config = smolMoEConfig

model = smolMoELM(config)
model.load_state_dict(torch.load('trial_weights.pt'), strict=True)
model.eval()

input_ids = torch.tensor([[1,2,3,4]])
attention_mask = torch.ones(1,4)

with torch.no_grad():
  outputs = model(input_ids, attention_mask)
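In this cell the `attention_mask` of ones contributes only its length: `LlamaDecoder.forward` (the `BUG FIX Attempt`) builds a causal mask with `torch.triu(torch.full((seq_len, seq_len), -inf), diagonal=1)` and `RopeAttention` adds it to the scores before the softmax. The following is a minimal pure-Python sketch of what that additive mask does, using hypothetical toy scores: each position can attend only to itself and earlier positions.

```python
import math

def causal_mask(seq_len):
    # Same pattern as torch.triu(torch.full((T, T), -inf), diagonal=1):
    # 0 on and below the diagonal, -inf strictly above it (future positions)
    return [[0.0 if j <= i else -math.inf for j in range(seq_len)]
            for i in range(seq_len)]

def masked_softmax(scores, mask):
    out = []
    for s_row, m_row in zip(scores, mask):
        vals = [s + m for s, m in zip(s_row, m_row)]
        top = max(vals)                            # finite: the diagonal is never masked
        exps = [math.exp(v - top) for v in vals]   # exp(-inf) evaluates to 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

scores = [[0.5, 1.0, -0.3, 2.0],
          [1.2, 0.1, 0.4, -1.0],
          [0.0, 0.0, 0.0, 0.0],
          [0.3, -0.2, 0.8, 0.1]]
weights = masked_softmax(scores, causal_mask(4))
for row in weights:
    print([round(w, 3) for w in row])
```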


In [33]:
model.state_dict().keys()

odict_keys(['model.embed_tokens.weight', 'model.layers.0.self_attn.W_query.weight', 'model.layers.0.self_attn.W_key.weight', 'model.layers.0.self_attn.W_value.weight', 'model.layers.0.self_attn.W_output.weight', 'model.layers.0.moe.gate_bank', 'model.layers.0.moe.up_bank', 'model.layers.0.moe.down_bank', 'model.layers.0.moe.gate.weight', 'model.layers.0.pre_attn_rmsnorm.weight', 'model.layers.0.pre_moe_rmsnorm.weight', 'model.layers.1.self_attn.W_query.weight', 'model.layers.1.self_attn.W_key.weight', 'model.layers.1.self_attn.W_value.weight', 'model.layers.1.self_attn.W_output.weight', 'model.layers.1.moe.gate_bank', 'model.layers.1.moe.up_bank', 'model.layers.1.moe.down_bank', 'model.layers.1.moe.gate.weight', 'model.layers.1.pre_attn_rmsnorm.weight', 'model.layers.1.pre_moe_rmsnorm.weight', 'model.layers.2.self_attn.W_query.weight', 'model.layers.2.self_attn.W_key.weight', 'model.layers.2.self_attn.W_value.weight', 'model.layers.2.self_attn.W_output.weight', 'model.layers.2.moe.gate

In [34]:
util, lb_loss = model.get_expert_utilization()
print(f"Expert utilization per layer: {util.mean(dim=0)}")

Expert utilization per layer: tensor([0.2583, 0.6667, 0.0750])
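The vector printed above is `util.mean(dim=0)`. `get_expert_utilization` first averages each layer's one-hot `(B, T, E)` selections over batch and tokens, giving an `(L, E)` matrix of routing fractions, and the cell then averages over layers, so the three entries sum to 1. A minimal pure-Python sketch of that reduction, together with the top-1 `argmax` + `gather` step from `MoE.forward`, follows; the helper names, toy logits, and expert outputs are hypothetical.

```python
def top1_route(token_logits, expert_outputs):
    """Mimic MoE.forward for one token: every expert's output is computed
    (as in the einsum-based implementation), then the argmax expert's is gathered."""
    chosen = max(range(len(token_logits)), key=lambda e: token_logits[e])
    return chosen, expert_outputs[chosen]

def utilization(layers_logits, num_experts):
    """layers_logits[l][t] = router logits of token t at layer l (batch flattened).
    Returns the (L, E) per-layer routing fractions and their mean over layers."""
    per_layer = []
    for token_logits in layers_logits:
        counts = [0] * num_experts
        for row in token_logits:
            counts[max(range(num_experts), key=lambda e: row[e])] += 1
        per_layer.append([c / len(token_logits) for c in counts])
    mean_over_layers = [sum(col) / len(per_layer) for col in zip(*per_layer)]
    return per_layer, mean_over_layers

chosen, y = top1_route([0.1, 2.3, -0.5], [[1.0, 1.0], [2.0, -1.0], [0.0, 3.0]])
print(chosen, y)  # 1 [2.0, -1.0]

# Two hypothetical layers, four tokens each, three experts
layers = [
    [[3, 0, 0], [0, 3, 0], [0, 3, 0], [0, 3, 0]],  # layer 0: expert 1 dominates
    [[3, 0, 0], [3, 0, 0], [0, 0, 3], [0, 3, 0]],  # layer 1
]
per_layer, mean_util = utilization(layers, 3)
print(per_layer)  # [[0.25, 0.75, 0.0], [0.5, 0.25, 0.25]]
print(mean_util)  # [0.375, 0.5, 0.125]
```

Note that, as in `MoE.forward`, the selected expert's output is returned as-is, without being scaled by its router probability.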


In [35]:
print(tokenizer.decode(torch.argmax(outputs['logits'][0], dim=-1).tolist()))
print(f"Logits mean={outputs['logits'].mean().item()}, std={outputs['logits'].std().item()}")

ation's.ation
Logits mean=6.71024751663208, std=3.2853877544403076


In [36]:
ref_model = AutoModelForCausalLM.from_pretrained(checkpoint)
ref_model.eval()
input_ids = torch.tensor([[1, 2, 3, 4]])
attention_mask = torch.ones(1, 4)

with torch.no_grad():
    try:
        ref_outputs = ref_model(input_ids, attention_mask=attention_mask)
        print(f"Reference logits shape={ref_outputs.logits.shape}, mean={ref_outputs.logits.mean().item()}")
        # Access final hidden states
        ref_hidden_states = ref_model.model(input_ids, attention_mask=attention_mask)[0]
        ref_hidden_states = ref_model.model.norm(ref_hidden_states)
        print(f"Reference final hidden_states shape={ref_hidden_states.shape}, mean={ref_hidden_states.mean().item()}")
    except Exception as e:
        print(f"Error in reference model: {e}")
        # Fallback: Run only embedding and first layer
        ref_hidden_states = ref_model.model.embed_tokens(input_ids)
        print(f"Reference embedding mean={ref_hidden_states.mean().item()}")
        ref_hidden_states = ref_model.model.layers[0](ref_hidden_states, attention_mask=attention_mask)[0]
        print(f"Reference layer 0 mean={ref_hidden_states.mean().item()}")

Reference logits shape=torch.Size([1, 4, 49152]), mean=4.597760200500488
Reference final hidden_states shape=torch.Size([1, 4, 576]), mean=0.0934964045882225


In [37]:
from transformers import AutoConfig
ref_model_config = AutoConfig.from_pretrained(checkpoint)
print(ref_model_config)

LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dtype": "bfloat16",
  "eos_token_id": 0,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 576,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 9,
  "num_hidden_layers": 30,
  "num_key_value_heads": 3,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": true,
  "transformers_version": "4.56.1",
  "use_cache": true,
  "vocab_size": 49152
}



In [38]:
print(ref_model.modules)

<bound method Module.modules of LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
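The printout confirms the grouped-query layout that `RopeAttention` assumes: `q_proj` outputs 576 = 9 × 64 features while `k_proj`/`v_proj` output 192 = 3 × 64, i.e. 3 KV heads with `head_dim` 64. `repeat_kv` expands these to 9 heads so that each KV head serves 9 // 3 = 3 consecutive query heads. The sketch below reproduces that index mapping in pure Python on head labels instead of tensors; `repeat_kv_heads` is a hypothetical helper, not part of the notebook.

```python
def repeat_kv_heads(kv_heads, n_rep):
    # Same ordering as repeat_kv: expand (B, kv, 1, T, D) -> (B, kv, n_rep, T, D),
    # then reshape to (B, kv * n_rep, T, D), so each KV head is repeated in place.
    return [h for h in kv_heads for _ in range(n_rep)]

num_heads, kv_heads, hidden_size = 9, 3, 576
head_dim = hidden_size // num_heads  # 64
n_rep = num_heads // kv_heads        # 3, i.e. __kv_groups in RopeAttention

expanded = repeat_kv_heads(["kv0", "kv1", "kv2"], n_rep)
print(head_dim * kv_heads)           # 192, the k_proj/v_proj width above
for q_head, kv in enumerate(expanded):
    assert kv == f"kv{q_head // n_rep}"  # query head j reads KV head j // n_rep
print(expanded)
```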

In [39]:
print(tokenizer.decode(300))
print(tokenizer.decode(torch.argmax(outputs['logits'][0], dim=-1).tolist()))

om
ation's.ation


In [40]:
util, lb_loss = model.get_expert_utilization()
print(f"Expert utilization per layer: {util.mean(dim=0)}")  # Should show balanced use across experts

Expert utilization per layer: tensor([0.2583, 0.6667, 0.0750])


In [41]:
from transformers import AutoModelForCausalLM
ref_model = AutoModelForCausalLM.from_pretrained(checkpoint)
ref_model.eval()
with torch.no_grad():
    ref_outputs = ref_model(input_ids, attention_mask=attention_mask)
print(ref_outputs.logits.shape, ref_outputs.logits.mean().item())

torch.Size([1, 4, 49152]) 4.597760200500488


In [42]:
print(tokenizer.decode(10))

<issue_closed>


### 3. Test

In [43]:
######################################################################################################################
############################################## DO NOT CHANGE[START] ##################################################
######################################################################################################################

# Load the Tokenizer
checkpoint="HuggingFaceTB/SmolLM-135M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

########################
#### SANITY CHECK ######
########################

# # Instantiate the model
# __test_model = smolMoELM(config)

# #💡 You expect a nonsensical/garbled output here since the weights are random
# generation_compare(
#     prompt=TEST_PROMPT,
#     tokenizer=tokenizer,
#     num_tokens=50,
#     model_A= __test_model,
# )

In [44]:
########################
###### CHECK #1 ########
########################

# Instantiate the model
__test_model = smolMoELM(config)

# Load the weights into your "fixed" implementation
__test_model.load_state_dict(torch.load('trial_weights.pt'), strict=True)


#💡 If you fixed all bugs, you will see a sensible generation here :)
generation_compare(
    prompt=TEST_PROMPT,
    tokenizer=tokenizer,
    num_tokens=50,
    model_A= __test_model,
    model_B=None
)


>>>>>>>>>>>>>>>>>>>>
	Prompt
<<<<<<<<<<<<<<<<<<<<
Where is the Great Wall?


time=9.840s
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
	Model_A Generation
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<


The Great Wall of China is a 2,000-mile-long wall that spans 13,000 miles from the southern border of the Yellow River in Shandong province to the northern border of the ancestral





In [45]:
########################
###### CHECK #2 ########
########################


#💡 If you fixed all the bugs and completed the missing implementation you will match the load balancer loss that we precomputed
correct_lb_loss = torch.tensor(1.0)
_, lb_loss = __test_model.get_expert_utilization()
print(f"(Expected) Load Balance Loss => {correct_lb_loss:0.2f}")
print(f"(Actual) Load Balance Loss => {lb_loss:0.2f}")
assert torch.isclose(lb_loss, correct_lb_loss, atol=1e-2), "Load Balance Check doesn't match!"

######################################################################################################################
############################################### DO NOT CHANGE[END] ###################################################
######################################################################################################################

(Expected) Load Balance Loss => 1.00
(Actual) Load Balance Loss => 1.00
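The precomputed reference of 1.0 is easier to interpret with a concrete formula in mind. One widely used load-balancing loss is the Switch Transformer auxiliary loss, `N * sum_i f_i * P_i`, where `f_i` is the fraction of tokens dispatched to expert `i` and `P_i` is the mean router probability given to expert `i`. The NumPy sketch below assumes that formulation with top-1 dispatch. `get_expert_utilization` is not shown in this section, so the notebook's loss may be computed differently. Under this formula, perfectly uniform routing gives exactly 1.0, and full collapse onto a single expert gives `N`.

```python
import numpy as np


def switch_load_balance_loss(router_probs: np.ndarray, assigned: np.ndarray, num_experts: int) -> float:
    """Switch-Transformer-style auxiliary loss: N * sum_i f_i * P_i.

    router_probs: (tokens, num_experts) softmax outputs of the router.
    assigned:     (tokens,) index of the expert each token was dispatched to (top-1).
    """
    # f_i: fraction of tokens dispatched to expert i
    f = np.bincount(assigned, minlength=num_experts) / len(assigned)
    # P_i: mean router probability given to expert i
    p = router_probs.mean(axis=0)
    return float(num_experts * np.sum(f * p))


# Perfectly uniform routing over 3 experts gives 1.0
balanced = switch_load_balance_loss(np.full((6, 3), 1 / 3), np.array([0, 1, 2, 0, 1, 2]), 3)

# Full collapse onto expert 0 gives num_experts (3.0)
one_hot = np.zeros((6, 3))
one_hot[:, 0] = 1.0
collapsed = switch_load_balance_loss(one_hot, np.zeros(6, dtype=int), 3)
print(balanced, collapsed)
```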


# **Coding Challenge Part 2: Upcycling a Dense Model into an MoE 🔄 🚴 [🚨 3 points]**


Now that we have worked through an implementation of the MoE architecture, let's look at a procedure called "Upcycling", in which a dense model is converted into an MoE.

**Guidelines** :

You will upcycle the dense model loaded below into our MoE implementation from Part 1. No changes to the MoE implementation are required for this part.


**🚨 Reference paper:** [Sparse Upcycling: Training Mixture-of-Experts from Dense Checkpoints](https://arxiv.org/abs/2212.05055)

> *This paper introduces a method to transform pre-trained dense models into Mixture-of-Experts (MoE) models, leveraging existing weights instead of training from scratch. This "upcycling" approach selectively sparsifies the model into expert modules, enabling more efficient scaling and training while reducing computational costs. Experiments show that these upcycled MoEs can outperform both standard dense models and traditionally trained MoEs, demonstrating that dense checkpoints contain useful knowledge that can be repurposed for sparse architectures.*
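The core idea of sparse upcycling can be checked on a toy example. If every expert starts as an exact copy of the dense MLP, and the top-k routing weights sum to one, the MoE layer reproduces the dense layer's output no matter which experts are picked. The NumPy sketch below assumes:

* a LLaMA-style SwiGLU MLP with weights stored as `(in, out)`,
* top-2 routing with renormalised weights,
* a zero-initialised router.

The function names and sizes are illustrative and are not taken from the notebook's `smolMoELM`.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def swiglu_mlp(x, w_gate, w_up, w_down):
    # LLaMA-style gated MLP with (in, out) weights: (silu(x W_g) * (x W_u)) W_d
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


def moe_forward(x, gate_bank, up_bank, down_bank, router_w, top_k=2):
    # Router softmax over experts for every token
    logits = x @ router_w
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        chosen = np.argsort(-probs[t])[:top_k]
        weights = probs[t, chosen] / probs[t, chosen].sum()  # renormalise over the top-k
        for e, w in zip(chosen, weights):
            out[t] += w * swiglu_mlp(x[t], gate_bank[e], up_bank[e], down_bank[e])
    return out


rng = np.random.default_rng(0)
hidden, inter, num_experts = 8, 16, 3
w_gate, w_up = rng.standard_normal((hidden, inter)), rng.standard_normal((hidden, inter))
w_down = rng.standard_normal((inter, hidden))

# Sparse upcycling: every expert starts as an exact copy of the dense MLP
gate_bank = np.repeat(w_gate[None], num_experts, axis=0)
up_bank = np.repeat(w_up[None], num_experts, axis=0)
down_bank = np.repeat(w_down[None], num_experts, axis=0)
router_w = np.zeros((hidden, num_experts))  # zero router -> uniform routing probabilities

x = rng.standard_normal((5, hidden))
dense_out = swiglu_mlp(x, w_gate, w_up, w_down)
moe_out = moe_forward(x, gate_bank, up_bank, down_bank, router_w)
print(np.allclose(dense_out, moe_out))
```

Zeroing the router is not needed for the equivalence itself, because identical experts give the same output under any routing. It does make the initial routing distribution uniform, so no expert is favoured before continued pretraining begins.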



### 1. Setup

In [46]:
######################################################################################################################
############################################## DO NOT CHANGE[START] ##################################################
######################################################################################################################


# Loading the Dense Model
dense_model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Resetting the weights for a clean upcycle!
__test_model.reset_weights_and_metrics()

#💡 This is expected to be garbled due to resetting weights before upcycling.
generation_compare(
    prompt=TEST_PROMPT,
    tokenizer=tokenizer,
    num_tokens=50,
    model_A= dense_model,
    model_B=__test_model
)

######################################################################################################################
############################################### DO NOT CHANGE[END] ###################################################
######################################################################################################################


>>>>>>>>>>>>>>>>>>>>
	Prompt
<<<<<<<<<<<<<<<<<<<<
Where is the Great Wall?


time=5.626s
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
	Model_A Generation
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

The Great Wall of China is the longest wall in the world. It stretches over 13,000 miles and is 13,000 feet high. It is located in the northern part of China, in the country



time=9.599s
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
	Model_B Generation
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
amonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamonamon


### 2. Upcycling (for Implementation)

In [48]:
model.state_dict().keys()

odict_keys(['model.embed_tokens.weight', 'model.layers.0.self_attn.W_query.weight', 'model.layers.0.self_attn.W_key.weight', 'model.layers.0.self_attn.W_value.weight', 'model.layers.0.self_attn.W_output.weight', 'model.layers.0.moe.gate_bank', 'model.layers.0.moe.up_bank', 'model.layers.0.moe.down_bank', 'model.layers.0.moe.gate.weight', 'model.layers.0.pre_attn_rmsnorm.weight', 'model.layers.0.pre_moe_rmsnorm.weight', 'model.layers.1.self_attn.W_query.weight', 'model.layers.1.self_attn.W_key.weight', 'model.layers.1.self_attn.W_value.weight', 'model.layers.1.self_attn.W_output.weight', 'model.layers.1.moe.gate_bank', 'model.layers.1.moe.up_bank', 'model.layers.1.moe.down_bank', 'model.layers.1.moe.gate.weight', 'model.layers.1.pre_attn_rmsnorm.weight', 'model.layers.1.pre_moe_rmsnorm.weight', 'model.layers.2.self_attn.W_query.weight', 'model.layers.2.self_attn.W_key.weight', 'model.layers.2.self_attn.W_value.weight', 'model.layers.2.self_attn.W_output.weight', 'model.layers.2.moe.gate

In [50]:
######################################################################
##################### Write "Upcycling" code here ####################
######################################################################


dense_sd = dense_model.state_dict()
moe_sd = __test_model.state_dict()

print("Starting the upcycling process...")

# 1. Copy the non-layer weights (embeddings, final norm, LM head);
#    these keys have the same names in both state dicts
moe_sd['model.embed_tokens.weight'] = dense_sd['model.embed_tokens.weight']
moe_sd['model.norm.weight'] = dense_sd['model.norm.weight']
moe_sd['lm_head.weight'] = dense_sd['lm_head.weight']


# 2. Map the dense attention projection names to the MoE names (the MLP is handled separately below)
attention_mapping = {
    "q_proj": "W_query",
    "k_proj": "W_key",
    "v_proj": "W_value",
    "o_proj": "W_output",
}

# 3. Iterate through each layer to copy weights
for i in range(config.num_hidden_layers):
    print(f"Upcycling layer {i+1}/{config.num_hidden_layers}...")

    # --- Copy Self-Attention weights ---
    for dense_name, moe_name in attention_mapping.items():
        dense_key = f'model.layers.{i}.self_attn.{dense_name}.weight'
        moe_key = f'model.layers.{i}.self_attn.{moe_name}.weight'
        moe_sd[moe_key] = dense_sd[dense_key]

    # --- Copy LayerNorm weights ---
    moe_sd[f'model.layers.{i}.pre_attn_rmsnorm.weight'] = dense_sd[f'model.layers.{i}.input_layernorm.weight']
    moe_sd[f'model.layers.{i}.pre_moe_rmsnorm.weight'] = dense_sd[f'model.layers.{i}.post_attention_layernorm.weight']

    # --- Upcycle the dense MLP to the MoE expert banks ---

    # nn.Linear stores weights as (out, in); transpose them to the (in, out) layout of the expert banks
    gate_weight = dense_sd[f'model.layers.{i}.mlp.gate_proj.weight'].transpose(-2, -1)
    up_weight = dense_sd[f'model.layers.{i}.mlp.up_proj.weight'].transpose(-2, -1)
    down_weight = dense_sd[f'model.layers.{i}.mlp.down_proj.weight'].transpose(-2, -1)

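    # Replicate the dense MLP into every expert; expand() returns a view, and load_state_dict copies it into each parameter
    moe_sd[f'model.layers.{i}.moe.gate_bank'] = gate_weight.unsqueeze(0).expand(config.num_experts, -1, -1)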
    moe_sd[f'model.layers.{i}.moe.up_bank'] = up_weight.unsqueeze(0).expand(config.num_experts, -1, -1)
    moe_sd[f'model.layers.{i}.moe.down_bank'] = down_weight.unsqueeze(0).expand(config.num_experts, -1, -1)

    # --- Zero the router weights so every expert starts with equal routing probability ---
    nn.init.zeros_(moe_sd[f'model.layers.{i}.moe.gate.weight'])


# Load the newly created state dictionary into our MoE model
__test_model.load_state_dict(moe_sd)

print("\nUpcycling complete!")




Starting the upcycling process...
Upcycling layer 1/30...
Upcycling layer 2/30...
Upcycling layer 3/30...
Upcycling layer 4/30...
Upcycling layer 5/30...
Upcycling layer 6/30...
Upcycling layer 7/30...
Upcycling layer 8/30...
Upcycling layer 9/30...
Upcycling layer 10/30...
Upcycling layer 11/30...
Upcycling layer 12/30...
Upcycling layer 13/30...
Upcycling layer 14/30...
Upcycling layer 15/30...
Upcycling layer 16/30...
Upcycling layer 17/30...
Upcycling layer 18/30...
Upcycling layer 19/30...
Upcycling layer 20/30...
Upcycling layer 21/30...
Upcycling layer 22/30...
Upcycling layer 23/30...
Upcycling layer 24/30...
Upcycling layer 25/30...
Upcycling layer 26/30...
Upcycling layer 27/30...
Upcycling layer 28/30...
Upcycling layer 29/30...
Upcycling layer 30/30...

Upcycling complete!
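The transposes in the upcycling cell follow from two storage conventions. `nn.Linear` keeps its weight as `(out_features, in_features)` and computes `x @ W.T`. The expert banks, judging by the transpose needed above, hold one `(in, out)` matrix per expert and are applied as `x @ bank[e]`. The NumPy sketch below checks that both conventions give the same result after transposing. It also uses `np.broadcast_to` as a stand-in for `torch.Tensor.expand`, to show that the expanded bank is only a view of the dense weight.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, inter, num_experts = 4, 6, 3
x = rng.standard_normal((2, hidden))

# nn.Linear convention: weight is (out_features, in_features) and y = x @ W.T
w_linear = rng.standard_normal((inter, hidden))
y_linear = x @ w_linear.T

# Assumed expert-bank convention: one (in, out) matrix per expert and y = x @ bank[e]
bank = np.repeat(w_linear.T[None], num_experts, axis=0)  # (num_experts, hidden, inter)
y_bank = x @ bank[1]

# broadcast_to mirrors torch's expand(): a read-only view sharing the dense weight's memory
view = np.broadcast_to(w_linear.T, (num_experts, hidden, inter))
print(bank.shape, np.allclose(y_linear, y_bank), np.shares_memory(view, w_linear))
```

Sharing memory is harmless in the upcycling cell because `load_state_dict` copies each tensor into the model's existing parameters. Every expert therefore ends up with its own storage and can diverge during training.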


### 3. Test

In [51]:
######################################################################################################################
############################################## DO NOT CHANGE[START] ##################################################
######################################################################################################################


#💡 If you upcycled correctly, you will output the exact same generation as the dense model!
generation_compare(
    prompt=TEST_PROMPT,
    tokenizer=tokenizer,
    num_tokens=50,
    model_A= dense_model,
    model_B=__test_model
)


######################################################################################################################
############################################### DO NOT CHANGE[END] ###################################################
######################################################################################################################


>>>>>>>>>>>>>>>>>>>>
	Prompt
<<<<<<<<<<<<<<<<<<<<
Where is the Great Wall?


time=5.754s
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
	Model_A Generation
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

The Great Wall of China is the longest wall in the world. It stretches over 13,000 miles and is 13,000 feet high. It is located in the northern part of China, in the country



time=9.340s
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
	Model_B Generation
<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

The Great Wall of China is the longest wall in the world. It stretches over 13,000 miles and is 13,000 feet high. It is located in the northern part of China, in the country


# **Coding Challenge Part 3: Continued Pretraining 📚💪 [🚨 7 points]**

**Note** :
*   For this section, make sure that the model you are using is still the same `__test_model` you upcycled in the previous section.
*   We recommend using a GPU for this section. We have provided the settings below and ensured that they run on the free T4 GPUs on Colab. Make sure you manage your free GPU usage wisely :)


Now that we have an upcycled MoE, let's continue pretraining on a small subset of data to train the expert router.

You will be required to:
1. Write a simple training loop (*and implement the functions it relies on*).
2. **Propose an MoE-specific metric** to track whether the MoE is actually learning as expected, implement it, and provide a two-line description of your metric in the space provided.

### 1. Setup

In [None]:
######################################################################################################################
############################################## DO NOT CHANGE[START] ##################################################
######################################################################################################################

STEPS = 100
REPORT_AFTER_N_STEPS = 10
BATCH_SIZE = 4
BF16 = True


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
__test_model.to(device) ### Note: This should be the upcycled model as a result of completing Part 2.
print(f"Using Device : {device}")

scaler_enabled = (device=="cuda" and BF16)
autocast_dtype = torch.bfloat16 if scaler_enabled else None

def build_dataset(
    dataset_id,
    subset,
    split,
    tokenizer,
    block_size,
    max_samples=1000,
    text_column="text",
    val_fraction=None,
    seed=42,
):
    ds = load_dataset(dataset_id, subset, split=split) if subset else load_dataset(dataset_id, split=split)
    ds = ds.select(range(max_samples))

    EOS = tokenizer.eos_token_id
    def tok(batch):
        out = tokenizer(batch[text_column],
                        add_special_tokens=False,
                        return_attention_mask=True)
        out["input_ids"]      = [ids + [EOS] for ids in out["input_ids"]]
        out["attention_mask"] = [m   + [1]   for m   in out["attention_mask"]]
        return {"input_ids": out["input_ids"], "attention_mask": out["attention_mask"]}

    ds = ds.map(tok, batched=True,remove_columns=[c for c in ds.column_names if c not in ("input_ids", "attention_mask")])

    def group_per_doc(batch):
        out_ids = []
        for ids in batch["input_ids"]:
            L = len(ids)
            n = (L // block_size) * block_size
            for i in range(0, n, block_size):
                out_ids.append(ids[i:i+block_size])
        return {"input_ids": out_ids, "attention_mask": [[1]*len(o) for o in out_ids]}

    ds = ds.map(group_per_doc, batched=True)

    if val_fraction and 0.0 < val_fraction < 1.0:
        ds = ds.train_test_split(test_size=val_fraction, seed=seed, shuffle=True)
        train_ds, val_ds = ds["train"], ds["test"]
        train_ds.set_format(type="torch", columns=["input_ids","attention_mask"])
        val_ds.set_format(type="torch", columns=["input_ids","attention_mask"])
        return train_ds, val_ds

    ds.set_format(type="torch", columns=["input_ids","attention_mask"])
    return ds


train_ds, val_ds = build_dataset(dataset_id="HuggingFaceTB/cosmopedia-100k",
                                 subset=None,
                                 split="train",
                                 tokenizer=tokenizer,
                                 block_size=256, # This is intentionally a small number. DO NOT change this number.
                                 val_fraction=0.2,   # 20% as validation
                                 max_samples=1000, # This only picks the first 1000 examples from the dataset. Do NOT change this number.
                                 seed=789)

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=(device=="cuda")
)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, pin_memory=(device=="cuda")
)


print(f"Train Dataset Batches : {len(train_loader)}")
print(f"Validation Dataset Batches : {len(val_loader)}")

######################################################################################################################
############################################### DO NOT CHANGE[END] ###################################################
######################################################################################################################
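`group_per_doc` splits each document into non-overlapping blocks of `block_size` tokens and drops the remainder. As a result, blocks never cross document boundaries, and any document shorter than `block_size` (counting the appended EOS) contributes nothing. The pure-Python sketch below reproduces that logic on toy token lists and counts how many tokens are discarded. The helper name `chunk_per_doc` and the `-1` EOS id are illustrative.

```python
def chunk_per_doc(docs, block_size, eos_id):
    """Split each tokenized document into full blocks of block_size tokens.

    Follows the same idea as group_per_doc: EOS is appended to every document,
    blocks never span two documents, and the trailing remainder is dropped.
    """
    blocks = []
    for ids in docs:
        ids = ids + [eos_id]
        n_full = (len(ids) // block_size) * block_size
        for start in range(0, n_full, block_size):
            blocks.append(ids[start:start + block_size])
    return blocks


docs = [list(range(10)), list(range(100, 103))]
blocks = chunk_per_doc(docs, block_size=4, eos_id=-1)

# Tokens lost to the per-document remainder (EOS included)
total_tokens = sum(len(d) + 1 for d in docs)
dropped = total_tokens - sum(len(b) for b in blocks)
print(blocks, dropped)
```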

### 2. Continued Pretraining (for Implementation)

In [None]:
############################################
############# Training Settings ############
############################################
LEARNING_RATE = None


############################################
############## OPTIMIZER ###################
############################################

opt = None
scheduler = None
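`LEARNING_RATE`, `opt` and `scheduler` are left for you to fill in. One common choice for short continued-pretraining runs is linear warmup followed by cosine decay. The pure-Python sketch below computes the learning-rate multiplier for a given step. The 10 warmup steps and the 10% floor are hypothetical defaults, not values tuned for this notebook.

```python
import math


def warmup_cosine(step, total_steps, warmup_steps, min_ratio=0.1):
    """Multiplier on the base learning rate: linear warmup, then cosine decay to min_ratio."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    return min_ratio + (1.0 - min_ratio) * cosine


print([round(warmup_cosine(s, 100, 10), 3) for s in (0, 9, 10, 55, 100)])
```

A multiplier function like this can be passed to `torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda s: warmup_cosine(s, STEPS, 10))`, with `scheduler.step()` called once per optimizer step.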

#### Helper Functions (🚨 4 points):
* *causal_lm_loss* (🚨 1 point)
* *evaluate_loss* (🚨 1 point)
* *custom_moe_metric* (🚨 2 points)
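For `causal_lm_loss`, the key step is the one-position shift. The logits at position `t` are scored against the token at position `t + 1`, so the last logit and the first label are dropped.

For `evaluate_loss`, a plain mean of per-batch losses can be skewed when batches contain different numbers of tokens. This happens with the last, smaller validation batch, since `drop_last=False`. A token-weighted average avoids that.

The NumPy sketch below illustrates both points. `causal_lm_loss_np` and `aggregate_eval` are hypothetical names, and the arrays stand in for model logits and a tokenized batch.

```python
import numpy as np


def causal_lm_loss_np(logits, input_ids, attention_mask):
    """Mean next-token cross-entropy, where position t predicts token t + 1.

    logits: (batch, seq, vocab); input_ids and attention_mask: (batch, seq).
    Returns (mean loss, number of scored tokens).
    """
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    shift_mask = attention_mask[:, 1:].astype(bool)
    # Numerically stable log-softmax over the vocabulary
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    token_nll = -np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return float(token_nll[shift_mask].mean()), int(shift_mask.sum())


def aggregate_eval(batch_results):
    """Token-weighted mean over (mean_loss, n_tokens) pairs from several batches."""
    total_nll = sum(loss * n for loss, n in batch_results)
    total_tokens = sum(n for _, n in batch_results)
    return total_nll / total_tokens


# Uniform logits over a 5-token vocabulary: every prediction costs log(5)
logits = np.zeros((2, 4, 5))
input_ids = np.array([[1, 2, 3, 4], [0, 1, 2, 3]])
loss, n_tokens = causal_lm_loss_np(logits, input_ids, np.ones_like(input_ids))
print(loss, n_tokens, aggregate_eval([(2.0, 10), (4.0, 30)]))
```

In PyTorch, the same shifted loss is typically written as `F.cross_entropy(logits[:, :-1].reshape(-1, V), input_ids[:, 1:].reshape(-1))`. Because `build_dataset` only emits full blocks with an all-ones attention mask, no padding needs to be ignored here.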

In [None]:
############################################
############## HELPER FNS ##################
############################################

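def causal_lm_loss(*args, **kwargs):
    # Implement this
    raise NotImplementedError

@torch.no_grad()
def evaluate_loss(*args, **kwargs):
    # Implement this
    raise NotImplementedError

@labelthis('Name This Metric')
@torch.no_grad()
def custom_moe_metric(*args, **kwargs):
    # Implement this
    raise NotImplementedError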
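One possible `custom_moe_metric` is the mean per-token entropy of the router distribution, expressed as a percentage of its maximum, `log(num_experts)`. Right after upcycling, the zeroed router assigns equal probability to every expert, so the metric should read 100%. If the router is learning, the value should fall as routing becomes more confident, while the load-balancing loss guards against every token choosing the same expert.

The sketch assumes you can collect the router's softmax probabilities, for example with `register_forward_hook` on each layer's `moe.gate` module (assuming that module returns router logits). `routing_entropy_pct` is a hypothetical name.

```python
import numpy as np


def routing_entropy_pct(router_probs):
    """Mean per-token router entropy as a percentage of the maximum, log(num_experts).

    router_probs: (tokens, num_experts) softmax outputs of the router.
    100% means uniform routing; values near 0% mean each token is sent to one expert.
    """
    p = np.clip(router_probs, 1e-12, 1.0)  # avoid log(0)
    entropy = -(p * np.log(p)).sum(axis=-1)
    return float(100.0 * entropy.mean() / np.log(router_probs.shape[-1]))


uniform_pct = routing_entropy_pct(np.full((4, 3), 1 / 3))  # zeroed router right after upcycling
confident_pct = routing_entropy_pct(np.eye(3))              # every token fully committed to one expert
print(f"{uniform_pct:.1f}% {confident_pct:.1f}%")
```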


####  Training loop (🚨 2 points)

In [None]:
######################################################################
############### Write "Continued Pretraining" code here ##############
######################################################################

moe_metric = custom_moe_metric(__test_model)
print(f"[Before Training : Sanity Check] {custom_moe_metric.label}: {moe_metric:.1f}%\n")

t0 = time.time()

loss = None

training_metrics = {'Train Loss': [], 'Eval Loss': [], 'Load Balancing Loss': []}
moe_metrics = {custom_moe_metric.label: []}

for step in range(1, STEPS+1):

    #########################################################
    ############## Eval/Reporting Section ###################
    #########################################################
    if step % REPORT_AFTER_N_STEPS == 0:
        val_loss = ...

        training_metrics['Train Loss'].append()
        training_metrics['Eval Loss'].append()
        training_metrics['Load Balancing Loss'].append()

        moe_metric = custom_moe_metric(__test_model)
        moe_metrics[custom_moe_metric.label].append(moe_metric)

        time_taken= (time.time()-t0)
        # KEEP THE SAME FORMATTING
        print(f"Step {step}/{STEPS} | Train Loss: {...:.3f} | Eval Loss: {...:.3f} | LB Loss: {...:.3f} | Time Taken: {pretty_dt(time_taken)}")
        print("***"*30)
        t0 = time.time()

    ###################################################
    ############## Training Section ###################
    ###################################################



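The loop needs two pieces regardless of the model details:

* a batch source that keeps yielding after one pass over `train_loader` (whether `STEPS` exceeds the number of training batches depends on the dataset sizes printed above);
* an objective that adds the load-balancing loss to the language-modelling loss.

The sketch below shows both in plain Python. `infinite_batches`, `total_loss` and the coefficient `0.01` are hypothetical choices, not values prescribed by the assignment.

```python
from itertools import islice


def infinite_batches(loader):
    """Yield batches forever; each new pass re-iterates the loader (a DataLoader with shuffle=True reshuffles)."""
    while True:
        for batch in loader:
            yield batch


def total_loss(lm_loss, lb_loss, lb_coef=0.01):
    """Language-modelling loss plus a scaled auxiliary load-balancing term."""
    return lm_loss + lb_coef * lb_loss


# Toy usage: a 3-batch "loader" consumed for 7 steps wraps around
batches = list(islice(infinite_batches(["b0", "b1", "b2"]), 7))
print(batches, total_loss(2.0, 1.0))
```

Inside each step, the usual PyTorch order is `opt.zero_grad()`, the forward pass, `loss.backward()`, `opt.step()`, then `scheduler.step()`. Switching to `__test_model.eval()` for the reporting block and back to `__test_model.train()` afterwards keeps evaluation separate from training.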
### 3. Test

In [None]:
######################################################################################################################
############################################## DO NOT CHANGE[START] ##################################################
######################################################################################################################

# Verify plots
x_vals = [REPORT_AFTER_N_STEPS * i for i in range(1, len(training_metrics['Train Loss'])+1)]
plot_metrics(training_metrics, x_vals=x_vals, suptitle="Training Metrics")

#### Plot MoE Metric with Explanation (🚨 1 point)

In [None]:
######################################################################################################################
############################################## PLOT YOUR CUSTOM MOE METRICS ##################################################
######################################################################################################################

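import matplotlib.pyplot as plt

def plot_custom_metric(metrics: dict, suptitle=None):
    # One panel per tracked MoE metric, plotted against the reporting steps
    fig, axes = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 4), squeeze=False)
    for ax, (name, values) in zip(axes[0], metrics.items()):
        steps = [REPORT_AFTER_N_STEPS * i for i in range(1, len(values) + 1)]
        ax.plot(steps, values, marker="o")
        ax.set_xlabel("Step")
        ax.set_ylabel(name)
    fig.suptitle(suptitle)
    plt.show()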

plot_custom_metric(moe_metrics, suptitle="SOME TITLE HERE")

### Why I chose `CUSTOM MOE METRIC`
....

In [None]:
#💡 Verify that the model didn't collapse and can still generate coherent text.
#   You don't expect this to be the same as the dense model, but it should still be coherent.
__test_model.to('cpu')
__test_model.eval()

generation_compare(
    prompt=TEST_PROMPT,
    tokenizer=tokenizer,
    num_tokens=50,
    model_A= dense_model,
    model_B=__test_model
)

######################################################################################################################
############################################### DO NOT CHANGE[END] ###################################################
######################################################################################################################

# **Coding Challenge Part 4:  Exploring The Unknown 🧙 ✨ [🚨 5 points]**

In this part, you can choose any one of the provided questions below.

Both questions are open-ended, and there is no single solution. You can follow any paper you find related to the question you picked, or be fully creative.

We want to see how you will approach the problem and how you will show that your approach is working.    

1. **Make training more efficient with dataset intervention:** You can now process the whole dataset ([cosmopedia-100k](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia-100k)), but you can still only sample the same number of examples (1000). How would you modify or filter the original dataset to make training more efficient?

2. **Explore methods to increase expert specialization for given datasets:** You are given these 3 datasets inside [Nemotron-Post-Training-Dataset (SFT partition)](https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset): [chat, math and code subsets](https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset#filtering-the-data). Develop training methods/pipelines that increase expert specialization for each dataset, so that each expert focuses on one of these datasets rather than routing being distributed uniformly.

**NOTE:** If your MoE implementation does not work, you can pick one question and show the effectiveness of your method on dense model training.
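For question 2, one simple way to show whether specialization is increasing is to count, per domain, which expert each token is routed to, then check how concentrated those counts are. The NumPy sketch below assumes top-1 routing counts have already been collected into a `(domains, experts)` matrix for the chat, math and code subsets. The counts in the example are made-up illustrations, not measured results, and `domain_specialization` is a hypothetical name.

```python
import numpy as np


def domain_specialization(counts):
    """counts[d, e] = number of tokens from domain d routed (top-1) to expert e.

    Returns, per domain, the share of its tokens sent to its most-used expert,
    that expert's index, and whether every domain has a distinct dominant expert.
    """
    counts = np.asarray(counts, dtype=float)
    shares = counts / counts.sum(axis=1, keepdims=True)
    dominant = shares.argmax(axis=1)
    top_share = shares.max(axis=1)
    distinct = len(set(dominant.tolist())) == counts.shape[0]
    return top_share, dominant, distinct


domains = ["chat", "math", "code"]
specialised = [[90, 5, 5], [10, 80, 10], [5, 5, 90]]  # illustrative counts
uniform = [[100, 100, 100]] * 3                        # illustrative counts

top_share, dominant, distinct = domain_specialization(specialised)
_, _, uniform_distinct = domain_specialization(uniform)
for name, share, expert in zip(domains, top_share, dominant):
    print(f"{name}: {share:.0%} of tokens -> expert {expert}")
print(distinct, uniform_distinct)
```

Tracking `top_share` per domain over training, alongside the load-balancing loss, would show whether the experts are specializing or whether routing stays close to uniform.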
