<a href="https://colab.research.google.com/github/swaekaa/llm_quant_sense/blob/master/Copy_of_moe_layer_sensitivity_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Layer-wise Quantization Sensitivity in Mixture-of-Experts Models
# NOTE(review): this is a code cell holding notebook prose. The triple-quoted
# string below is a bare expression, so Jupyter echoes its repr as the cell's
# output; moving the text into a markdown cell would render it properly.

## Goal
'''Analyze the sensitivity of router (gating) layers versus expert layers in Mixture-of-Experts (MoE) models
under simulated sub-4-bit quantization.

## Hypothesis
Quantization noise in MoE routing networks causes disproportionate performance degradation due to
discrete expert selection instability, unlike dense transformers where degradation is continuous.

## Constraints
- Inference only
- No fine-tuning
- Consumer hardware'''


'Analyze the sensitivity of router (gating) layers versus expert layers in Mixture-of-Experts (MoE) models\nunder simulated sub-4-bit quantization.\n\n## Hypothesis\nQuantization noise in MoE routing networks causes disproportionate performance degradation due to\ndiscrete expert selection instability, unlike dense transformers where degradation is continuous.\n\n## Constraints\n- Inference only\n- No fine-tuning\n- Consumer hardware'

In [None]:
import torch
import numpy as np
import random
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import matplotlib.pyplot as plt
from tqdm import tqdm

def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG, PyTorch on CPU, and every CUDA device.

    Args:
        seed: integer seed applied identically to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(42)


In [None]:
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# Inference-only notebook: turn autograd off globally.
torch.set_grad_enabled(False)
# Disallow TF32 for CUDA matmuls so results are full FP32 precision where FP32 is used.
torch.backends.cuda.matmul.allow_tf32 = False


Device: cpu


In [None]:
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

# TEXT_SAMPLES caps the number of *raw rows* read. WikiText-2 contains many blank
# rows, so fewer texts survive the non-empty filter (157 of 256 in the saved run).
TEXT_SAMPLES = 256
texts = [line for line in dataset["text"][:TEXT_SAMPLES] if line.strip()]

print("Total loaded texts:", len(texts))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Total loaded texts: 157


In [None]:
MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# `dtype=` replaces the deprecated `torch_dtype=` keyword (the saved run warns about it).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.float16,
    device_map="auto",
)
# NOTE(review): with device_map="auto", weights that do not fit on the accelerator may
# be offloaded and show up as meta tensors; the expert cell below skips meta weights.

# Trailing `;` suppresses the (very long) module repr in the cell output.
model.eval();


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
        

In [None]:
# Small slice of the WikiText texts kept for quick evaluation.
# NOTE(review): texts_small is not used by any later cell shown here; the NLL
# evaluation below scores the hand-written eval_prompts instead.
texts_small = texts[:64]   # 32–64 is ideal
print("Eval samples:", len(texts_small))


Eval samples: 64


In [None]:
# Sanity check that the notebook state this analysis relies on is present.
print(
    *(f"{var} exists: {var in globals()}" for var in ("texts", "texts_small", "model", "baseline_nll")),
    sep="\n",
)


texts exists: True
texts_small exists: True
model exists: True
baseline_nll exists: False


In [None]:
# Short, topically diverse English prompts scored with compute_prompt_nll.
# NOTE(review): each prompt is only a few words long, so the NLL is driven by a
# handful of next-token predictions per prompt; a longer held-out text would
# give a more stable metric.
eval_prompts = [
    "The capital of France is",
    "Machine learning models are trained by",
    "In physics, energy is defined as",
    "The movie was absolutely",
    "A good restaurant should",
    "The theory of relativity states that",
    "Artificial intelligence systems can",
    "The book was interesting because",
    "In mathematics, a prime number is",
    "The weather today is",
    "The government announced that",
    "Neural networks learn by adjusting",
    "The experiment failed due to",
    "The purpose of education is",
    "The company reported earnings of",
    "In biology, cells are",
    "The main character decided to",
    "Economic growth depends on",
    "The scientist discovered that",
    "The product was disappointing because"
]


In [None]:
def compute_prompt_nll(model, prompts, max_length=32, tok=None, token_weighted=False):
    """Average causal-LM negative log-likelihood of ``model`` over ``prompts``.

    Each prompt is tokenized on its own (no padding) and scored with the model's
    built-in loss, using the prompt's own tokens as labels.

    Args:
        model: causal LM whose call accepts ``labels=`` and returns an object with ``.loss``.
        prompts: sequence of prompt strings.
        max_length: tokenizer truncation length.
        tok: tokenizer to use; defaults to the notebook-global ``tokenizer``.
        token_weighted: if False (default), return the plain mean of per-prompt
            losses. If True, weight each prompt by its number of predicted tokens
            (``n_tokens - 1``), which gives the corpus-level NLL.

    Returns:
        float: the averaged NLL.

    Raises:
        ValueError: if no prompt has at least two tokens, so there is nothing to score.
    """
    if tok is None:
        tok = tokenizer
    target_device = next(model.parameters()).device

    total_loss = 0.0
    total_weight = 0
    for p in prompts:
        enc = tok(
            p,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        ).to(target_device)

        # Causal-LM loss predicts token t+1 from tokens <= t. A single-token prompt
        # has no target, so its loss is NaN and would poison the average: skip it.
        n_targets = enc["input_ids"].shape[-1] - 1
        if n_targets < 1:
            continue

        with torch.no_grad():
            loss = model(**enc, labels=enc["input_ids"]).loss.item()

        weight = n_targets if token_weighted else 1
        total_loss += loss * weight
        total_weight += weight

    if total_weight == 0:
        raise ValueError("compute_prompt_nll: no prompt had at least two tokens to score")
    return total_loss / total_weight


In [None]:
# Reference NLL of the unmodified model on eval_prompts (mean over prompts).
baseline_nll = compute_prompt_nll(model, eval_prompts)
print("Baseline prompt NLL:", baseline_nll)


Baseline prompt NLL: 3.888915538787842


In [None]:
# List every module whose name mentions a router or a gate, to locate the MoE
# routing layers. The match is loose: each expert's `gate_proj` also matches,
# which is why the output is long.
for name, module in model.named_modules():
    lname = name.lower()
    if any(tag in lname for tag in ("router", "gate")):
        print(name, "->", type(module))


model.layers.0.mlp.gate -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.0.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.1.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.2.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.3.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.4.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.5.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.6.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.7.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.8.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.9.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp.experts.10.gate_proj -> <class 'torch.nn.modules.linear.Linear'>
model.layers.0.mlp

In [None]:
# Split the model's Linear layers into routers (each layer's `mlp.gate`) and
# expert projections (anything under `mlp.experts.`).
# NOTE(review): a Linear matching neither pattern lands in neither list — confirm
# that is intended for this architecture.
router_layers = []
expert_layers = []

for name, module in model.named_modules():
    if not isinstance(module, torch.nn.Linear):
        continue
    if name.endswith(".mlp.gate"):
        router_layers.append(name)
    elif ".mlp.experts." in name:
        expert_layers.append(name)

print("Number of router layers:", len(router_layers))
print("Number of expert layers:", len(expert_layers))
print("\nExample router layers:")
print(router_layers[:5])
print("\nExample expert layers:")
print(expert_layers[:5])


Number of router layers: 24
Number of expert layers: 4320

Example router layers:
['model.layers.0.mlp.gate', 'model.layers.1.mlp.gate', 'model.layers.2.mlp.gate', 'model.layers.3.mlp.gate', 'model.layers.4.mlp.gate']

Example expert layers:
['model.layers.0.mlp.experts.0.gate_proj', 'model.layers.0.mlp.experts.0.up_proj', 'model.layers.0.mlp.experts.0.down_proj', 'model.layers.0.mlp.experts.1.gate_proj', 'model.layers.0.mlp.experts.1.up_proj']


In [None]:
def force_module_to_device(module, device):
    """Move every parameter and buffer of ``module`` to ``device`` in place.

    Each tensor's ``.data`` is rebound, so the existing Parameter objects are kept.

    Args:
        module: the ``torch.nn.Module`` to move.
        device: target device, as a ``torch.device`` or a string such as ``"cpu"``.

    Raises:
        RuntimeError: if any parameter or buffer is a meta tensor, which has no
            data to copy. This is checked up front so nothing is moved half-way.
    """
    target = torch.device(device)
    tensors = [*module.parameters(), *module.buffers()]
    if any(t.is_meta for t in tensors):
        raise RuntimeError(
            f"force_module_to_device: {type(module).__name__} holds meta tensors "
            "(no data to move, e.g. offloaded weights)"
        )
    for t in tensors:
        if t.device != target:
            t.data = t.data.to(target)


In [None]:
def quantize_router_safe(router_module, bits=2, target_device=None):
    """Fake-quantize ``router_module.weight`` in place and return a backup of it.

    Args:
        router_module: module exposing a ``.weight`` tensor (the MoE gate Linear here).
        bits: bit width forwarded to ``simulated_quantize_weight``.
        target_device: optional device to move the module to first, via
            ``force_module_to_device``. Defaults to None, which leaves the module
            where it is. Hard-coding CUDA would fail on CPU-only runtimes; the
            saved run reports ``Device: cpu``.

    Returns:
        A clone of the original weight, to be passed to ``restore_router_safe``.

    Raises:
        ValueError: if the weight is a meta tensor (no data to quantize).
    """
    if router_module.weight.is_meta:
        raise ValueError("quantize_router_safe: router weight is on the meta device")
    if target_device is not None:
        force_module_to_device(router_module, torch.device(target_device))
    backup = router_module.weight.data.clone()
    router_module.weight.data = simulated_quantize_weight(
        router_module.weight.data, bits
    )
    return backup

def restore_router_safe(router_module, backup):
    """Put back a weight previously returned by ``quantize_router_safe``."""
    router_module.weight.data = backup


In [None]:
model.config.output_router_logits = True


def get_router_logits(model, prompt, max_length=24):
    """Run one forward pass on ``prompt`` and return the model output's ``router_logits``.

    The prompt is truncated to ``max_length`` tokens and placed on the device of
    the model's first parameter.
    """
    target = next(model.parameters()).device
    batch = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length).to(target)
    with torch.no_grad():
        return model(**batch, output_router_logits=True).router_logits


In [None]:
import torch.nn.functional as F

def routing_entropy(logits):
    """Mean Shannon entropy (nats) of ``softmax(logits)`` over the last dimension.

    The computation runs in float32 through ``log_softmax``. The naive
    ``p * log(p + 1e-9)`` form breaks in float16: ``1e-9`` rounds to 0, so any
    probability that underflows to zero produces ``0 * -inf = NaN``.

    Args:
        logits: tensor of unnormalized scores; the last dim is assumed to index
            experts — TODO confirm for every caller.

    Returns:
        float: entropy averaged over all leading positions.
    """
    log_p = F.log_softmax(logits.float(), dim=-1)
    return -(log_p.exp() * log_p).sum(dim=-1).mean().item()


In [None]:
def get_module_by_name(model, name):
    """Resolve a dotted attribute path such as ``"model.layers.0.mlp.gate"`` below ``model``.

    Resolves one path component at a time with ``getattr``; an unknown or empty
    component raises ``AttributeError``.
    """
    head, sep, rest = name.partition(".")
    child = getattr(model, head)
    return get_module_by_name(child, rest) if sep else child


In [None]:
def simulated_quantize_weight(w, bits=2):
    """Asymmetric (min/max) per-tensor fake quantization of ``w`` to ``bits`` bits.

    The range ``[w.min(), w.max()]`` is split into ``2**bits - 1`` equal steps,
    and each value snaps to the nearest grid point. A tensor whose range is
    below 1e-8 is returned as an unchanged copy.

    NOTE: a later cell redefines this name with a symmetric scheme; that
    definition is the one in effect for cells run after it.
    """
    levels = (1 << bits) - 1
    lo = w.min()
    hi = w.max()
    span = hi - lo
    if span < 1e-8:
        return w.clone()

    step = span / levels
    codes = ((w - lo) / step).round().clamp(0, levels)
    return codes * step + lo


In [None]:
def quantize_router_safe(router_module, bits=2):
    """Fake-quantize ``router_module.weight`` in place; return a clone of the original weight."""
    weight = router_module.weight
    saved = weight.data.clone()
    weight.data = simulated_quantize_weight(weight.data, bits)
    return saved


def restore_router_safe(router_module, backup):
    """Put back a weight previously returned by ``quantize_router_safe``."""
    param = router_module.weight
    param.data = backup


In [None]:
def get_router_logits(model, prompt, max_length=24):
    """Return the ``router_logits`` from a single no-grad forward pass on ``prompt``.

    The prompt is truncated to ``max_length`` tokens and moved to the device of
    the model's first parameter before the call.
    """
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    encoded = encoded.to(next(model.parameters()).device)
    with torch.no_grad():
        outputs = model(**encoded, output_router_logits=True)
    return outputs.router_logits


In [None]:
import torch.nn.functional as F

def routing_entropy(logits):
    """Mean Shannon entropy (nats) of ``softmax(logits)`` over the last dimension.

    The computation runs in float32 through ``log_softmax``. The naive
    ``p * log(p + 1e-9)`` form breaks in float16: ``1e-9`` rounds to 0, so any
    probability that underflows to zero produces ``0 * -inf = NaN``.

    Args:
        logits: tensor of unnormalized scores; the last dim is assumed to index
            experts — TODO confirm for every caller.

    Returns:
        float: entropy averaged over all leading positions.
    """
    log_p = F.log_softmax(logits.float(), dim=-1)
    return -(log_p.exp() * log_p).sum(dim=-1).mean().item()


In [None]:
# Inference settings for the routing experiments below:
# - use_cache=False: the notebook only runs single forward passes (no generation),
#   so no KV cache is needed.
# - output_router_logits=True: make returning router logits the config default
#   (get_router_logits also requests them per call).
model.config.use_cache = False
model.config.output_router_logits = True


In [None]:
def is_meta_tensor(t):
    """Return ``t.is_meta`` when that attribute exists, otherwise False (e.g. non-tensor inputs)."""
    try:
        return t.is_meta
    except AttributeError:
        return False

In [None]:
# Capture the hidden state fed into each router (gate) for one probe prompt, so
# routing can later be recomputed offline with perturbed router weights.
router_inputs = {}

def capture_router_inputs(module, inp, out, name):
    """Forward-hook body: store the router's first positional input (detached, on CPU) under ``name``."""
    # inp[0] is the hidden state entering the router
    router_inputs[name] = inp[0].detach().cpu()

hooks = []
try:
    for name in router_layers:
        router = get_module_by_name(model, name)
        hooks.append(
            router.register_forward_hook(
                lambda m, i, o, n=name: capture_router_inputs(m, i, o, n)
            )
        )

    # Single forward pass
    _ = model(**tokenizer("The meaning of life is", return_tensors="pt").to(model.device))
finally:
    # Always detach the hooks, even if registration or the forward pass fails;
    # otherwise they keep firing and overwriting router_inputs on later passes.
    for h in hooks:
        h.remove()

print("Captured router inputs:", len(router_inputs))


Captured router inputs: 24


In [None]:
def simulated_quantize_weight(W, bits=2):
    """Symmetric per-tensor fake quantization of ``W`` to signed ``bits``-bit integers.

    Values are scaled so that ``max|W|`` maps to ``qmax = 2**(bits-1) - 1``, then
    rounded, clamped to ``[qmin, qmax]`` and de-quantized. Because
    ``|W| <= max|W|``, the extra negative level ``qmin`` is never reached; at
    2 bits the grid is effectively ternary ``{-s, 0, +s}``.

    Arithmetic runs in float32 and the result is cast back to ``W.dtype``. In
    float16 an additive ``+ 1e-8`` guard underflows to zero, so an all-zero
    tensor would produce ``0/0 = NaN``.

    Args:
        W: weight tensor of any shape.
        bits: total bit width; must be >= 2 (``bits=1`` would give ``qmax=0``).

    Returns:
        Tensor with the same shape and dtype as ``W``.

    Raises:
        ValueError: if ``bits < 2``.
    """
    if bits < 2:
        raise ValueError(f"symmetric quantization needs bits >= 2, got {bits}")
    qmin = -(2 ** (bits - 1))
    qmax = (2 ** (bits - 1)) - 1
    W32 = W.float()
    # Floor the scale rather than adding an epsilon, so it can never be zero.
    scale = (W32.abs().max() / qmax).clamp_min(1e-8)
    return ((W32 / scale).round().clamp(qmin, qmax) * scale).to(W.dtype)

def router_logits(router, hidden):
    """Router logits ``hidden @ router.weight.T`` (no bias term is added).

    ``hidden`` is moved to the weight's device and dtype first. The hook in this
    notebook stores router inputs on CPU, while the weight may sit on an accelerator.
    """
    weight = router.weight
    return hidden.to(device=weight.device, dtype=weight.dtype) @ weight.T


In [None]:
# Effect of 2-bit router-weight quantization on top-1 routing for the first few layers.
# NOTE(review): an MoE layer may route each token to several experts (top-k); the
# top-1 flip rate here is only a proxy for routing change — confirm k from the config.
N_PROBE_ROUTERS = 6  # number of router layers analysed

results = []

for name in list(router_inputs.keys())[:N_PROBE_ROUTERS]:
    router = get_module_by_name(model, name)
    hidden = router_inputs[name]

    # The hook stored `hidden` as a real CPU tensor; it is the router *weight* that
    # can be a meta placeholder when offloaded, so check both.
    if hidden.is_meta or router.weight.is_meta:
        print(name, "| skipped (meta)")
        continue

    # Compute in float32 on the weight's device: avoids CPU/GPU device mismatches
    # and fp16 rounding in the logits.
    weight = router.weight.data
    hidden32 = hidden.to(device=weight.device, dtype=torch.float32)

    # Baseline routing
    base_logits = hidden32 @ weight.float().T
    base_top = base_logits.argmax(dim=-1)
    base_ent = routing_entropy(base_logits)

    # Quantizing router weights
    Wq = simulated_quantize_weight(weight, bits=2)
    quant_logits = hidden32 @ Wq.float().T
    q_top = quant_logits.argmax(dim=-1)
    q_ent = routing_entropy(quant_logits)

    flip_rate = (base_top != q_top).float().mean().item()
    delta_ent = q_ent - base_ent

    results.append({
        "layer": name,
        "flip_rate": flip_rate,
        # Key name kept for compatibility; the value is the entropy *change*
        # (positive = flatter routing distribution after quantization).
        "entropy_drop": delta_ent
    })

    print(
        name,
        "| flip_rate:", round(flip_rate, 3),
        "| Δentropy:", round(delta_ent, 3)
    )


model.layers.0.mlp.gate | flip_rate: 0.6 | Δentropy: 0.014
model.layers.1.mlp.gate | flip_rate: 0.6 | Δentropy: -0.006
model.layers.2.mlp.gate | flip_rate: 0.8 | Δentropy: -0.465
model.layers.3.mlp.gate | flip_rate: 0.8 | Δentropy: 0.109
model.layers.4.mlp.gate | flip_rate: 1.0 | Δentropy: 0.58
model.layers.5.mlp.gate | flip_rate: 1.0 | Δentropy: 0.766


In [None]:
# Compare 2-bit vs 4-bit router-weight quantization on the first few router layers.
# NOTE(review): only top-1 expert changes are counted; if the model routes to top-k
# experts this is a proxy — confirm k from the model config.
N_PROBE_ROUTERS = 6  # number of router layers analysed

def _routing_shift(hidden32, weight, bits, base_top, base_ent):
    """Top-1 flip rate and entropy change after ``bits``-bit quantization of ``weight``.

    ``hidden32`` must already be float32 on ``weight``'s device.
    """
    q_logits = hidden32 @ simulated_quantize_weight(weight, bits=bits).float().T
    flip = (base_top != q_logits.argmax(dim=-1)).float().mean().item()
    return flip, routing_entropy(q_logits) - base_ent

results = []

for name in list(router_inputs.keys())[:N_PROBE_ROUTERS]:
    router = get_module_by_name(model, name)
    hidden = router_inputs[name]

    # The hook stored `hidden` as a CPU tensor; the router *weight* is what may be
    # a meta placeholder when offloaded, so check both.
    if hidden.is_meta or router.weight.is_meta:
        print(name, "| skipped (meta)")
        continue

    # float32 on the weight's device avoids CPU/GPU mismatches and fp16 rounding.
    weight = router.weight.data
    hidden32 = hidden.to(device=weight.device, dtype=torch.float32)

    # Baseline
    base_logits = hidden32 @ weight.float().T
    base_top = base_logits.argmax(dim=-1)
    base_ent = routing_entropy(base_logits)

    flip_2, dH_2 = _routing_shift(hidden32, weight, 2, base_top, base_ent)
    flip_4, dH_4 = _routing_shift(hidden32, weight, 4, base_top, base_ent)

    results.append({
        "layer": name,
        "flip_2b": flip_2,
        "flip_4b": flip_4,
        "Δentropy_2b": dH_2,
        "Δentropy_4b": dH_4
    })

    print(
        name,
        "| flip@2b:", round(flip_2, 2),
        "| ΔH@2b:", round(dH_2, 3),
        "| flip@4b:", round(flip_4, 2),
        "| ΔH@4b:", round(dH_4, 3),
    )


model.layers.0.mlp.gate | flip@2b: 0.6 | ΔH@2b: 0.014 | flip@4b: 0.2 | ΔH@4b: 0.01
model.layers.1.mlp.gate | flip@2b: 0.6 | ΔH@2b: -0.006 | flip@4b: 0.0 | ΔH@4b: -0.012
model.layers.2.mlp.gate | flip@2b: 0.8 | ΔH@2b: -0.465 | flip@4b: 0.6 | ΔH@4b: 0.027
model.layers.3.mlp.gate | flip@2b: 0.8 | ΔH@2b: 0.109 | flip@4b: 0.2 | ΔH@4b: -0.023
model.layers.4.mlp.gate | flip@2b: 1.0 | ΔH@2b: 0.58 | flip@4b: 0.4 | ΔH@4b: 0.129
model.layers.5.mlp.gate | flip@2b: 1.0 | ΔH@2b: 0.766 | flip@4b: 0.0 | ΔH@4b: 0.148


In [None]:
import random
# Re-seed Python's global RNG here so the expert sample below is reproducible
# even when this cell is run on its own.
random.seed(42)

# 3 random samples with seed 42
sample_experts = random.sample(expert_layers, 3)
print(sample_experts)


['model.layers.5.mlp.experts.4.gate_proj', 'model.layers.1.mlp.experts.8.gate_proj', 'model.layers.12.mlp.experts.31.gate_proj']


In [None]:
# Does quantizing a single expert change routing *downstream*?
# Expert weights never enter a router's logits directly, so any effect must
# propagate through the hidden state. The procedure is:
#   1. quantize the expert,
#   2. re-run the probe prompt and re-capture every router's input,
#   3. compare top-1 routing with the baseline stored in `router_inputs`.
# NOTE(review): if no token of the short probe prompt is routed to the sampled
# expert, quantizing it cannot change anything and 0.0 is expected — consider
# also reporting whether the expert was selected at all.
PROBE_PROMPT = "The meaning of life is"  # must match the prompt used to fill `router_inputs`

def _layer_index(module_name):
    """Decoder-layer index parsed from names like 'model.layers.<i>.mlp...'."""
    return int(module_name.split(".")[2])

def _capture_all_router_inputs(prompt):
    """Run one forward pass on ``prompt``; return {router name: its input, detached, on CPU}."""
    captured = {}
    handles = []
    try:
        for rname in router_layers:
            handles.append(
                get_module_by_name(model, rname).register_forward_hook(
                    lambda m, i, o, n=rname: captured.__setitem__(n, i[0].detach().cpu())
                )
            )
        model(**tokenizer(prompt, return_tensors="pt").to(model.device))
    finally:
        for h in handles:
            h.remove()
    return captured

def _top1(router, hidden):
    """Top-1 expert per row of ``hidden``, computed in float32 on the router weight's device."""
    w = router.weight.data.float()
    return (hidden.to(device=w.device, dtype=torch.float32) @ w.T).argmax(dim=-1)

expert_results = []

for expert_name in sample_experts:
    expert = get_module_by_name(model, expert_name)

    # skipping meta
    if expert.weight.is_meta:
        print(expert_name, "| skipped (meta)")
        continue

    # Quantize the expert, capture router inputs, and always restore the weight.
    backup = expert.weight.data.clone()
    expert.weight.data = simulated_quantize_weight(backup, bits=2)
    try:
        quant_inputs = _capture_all_router_inputs(PROBE_PROMPT)
    finally:
        expert.weight.data = backup

    # Routers in the same or earlier layers see their input before this expert
    # runs, so only later layers can be affected.
    expert_layer = _layer_index(expert_name)
    flips = []
    for rname in router_layers:
        router = get_module_by_name(model, rname)
        if (
            _layer_index(rname) <= expert_layer
            or router.weight.is_meta
            or rname not in router_inputs
            or rname not in quant_inputs
        ):
            continue
        base_top = _top1(router, router_inputs[rname])
        quant_top = _top1(router, quant_inputs[rname])
        flips.append((base_top != quant_top).float().mean().item())

    # Mean over evaluable downstream routers; NaN when there are none
    # (e.g. an expert in the last layer, or all downstream routers offloaded).
    flip_rate = sum(flips) / len(flips) if flips else float("nan")

    expert_results.append({
        "expert": expert_name,
        "flip_rate": flip_rate
    })

    print(expert_name, "| routing flip rate:", flip_rate)


model.layers.5.mlp.experts.4.gate_proj | routing flip rate: 0.0
model.layers.1.mlp.experts.8.gate_proj | routing flip rate: 0.0
model.layers.12.mlp.experts.31.gate_proj | skipped (meta)
