In [1]:
from accelerate import init_empty_weights
from transformers import AutoConfig

from fusion_bench.models.modeling_s2_moe_llama import (
    S2MoELlamaConfig,
    S2MoELlamaForCausalLM,
)

MODEL_PATH = "meta-llama/Llama-3.2-3B-Instruct"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config = AutoConfig.from_pretrained(MODEL_PATH)
config = S2MoELlamaConfig(
    num_experts_per_tok=1,
    rank_of_router=16,
    num_local_experts=5,
    use_sparse_expert=True,  # TODO
    sparsity_ratio=0.8,  # ?
    **config.to_dict(),
)

In [6]:
with init_empty_weights():
    model = S2MoELlamaForCausalLM(config)

model.to_empty(device="cpu")

S2MoELlamaForCausalLM(
  (model): S2MoELlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x S2MoELlamaDecoderLayer(
        (self_attn): S2MoELlamaAttention(
          (q_proj): SingularMoELinear(in_features=3072, out_features=3072, num_experts=5, top_k=1, rank_of_router=16, )
          (k_proj): SingularMoELinear(in_features=3072, out_features=1024, num_experts=5, top_k=1, rank_of_router=16, )
          (v_proj): SingularMoELinear(in_features=3072, out_features=1024, num_experts=5, top_k=1, rank_of_router=16, )
          (o_proj): SingularMoELinear(in_features=3072, out_features=3072, num_experts=5, top_k=1, rank_of_router=16, )
        )
        (mlp): S2MoELlamaMLP(
          (gate_proj): SingularMoELinear(in_features=3072, out_features=8192, num_experts=5, top_k=1, rank_of_router=16, )
          (up_proj): SingularMoELinear(in_features=3072, out_features=8192, num_experts=5, top_k=1, rank_of_router=16, )
          (down_proj): Singula

In [7]:
model.save_pretrained("outputs/s2_moe_llama")

In [8]:
model.lm_head.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

In [9]:
model = S2MoELlamaForCausalLM.from_pretrained("outputs/s2_moe_llama")

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 41.26it/s]


In [10]:
model.lm_head.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)