From 898552932f681d39222b23203d53845ddff9bd63 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 30 Jul 2025 12:44:26 -0700
Subject: [PATCH] Add smollm3 to ET

---
 examples/models/llama/attention.py         |  5 +-
 examples/models/llama/export_llama_lib.py  |  4 +-
 examples/models/llama/llama_transformer.py |  5 +-
 examples/models/llama/model_args.py        |  1 +
 examples/models/smollm2/convert_weights.py | 85 +++++++++++++-------
 examples/models/smollm3/config/params.json | 15 ++++
 .../config/smollm3_xnnpack_q8da4w.yaml      | 16 ++++
 .../config/smollm3_xnnpack_q8da4w_qe.yaml   | 21 +++++
 extension/llm/export/config/llm_config.py  |  1 +
 9 files changed, 120 insertions(+), 33 deletions(-)
 create mode 100644 examples/models/smollm3/config/params.json
 create mode 100644 examples/models/smollm3/config/smollm3_xnnpack_q8da4w.yaml
 create mode 100644 examples/models/smollm3/config/smollm3_xnnpack_q8da4w_qe.yaml

diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index aa53b330837..bef21da7ed5 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -325,7 +325,7 @@ def update(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Optional[Rope]):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -412,7 +412,8 @@ def forward(
             k = self.k_norm_fn(k)
 
         # RoPE relative positional embeddings
-        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+        if self.rope:
+            q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 39f5f2ec0cd..8ff4ca3eeaa 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -109,6 +109,7 @@
     "qwen3_4b",
     "phi_4_mini",
     "smollm2",
+    "smollm3",
 ]
 TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
 HUGGING_FACE_REPO_IDS = {
@@ -118,6 +119,7 @@
     "qwen3_0_6b": "Qwen/Qwen3-0.6B",
     "qwen3_1_7b": "Qwen/Qwen3-1.7B",
     "qwen3_4b": "Qwen/Qwen3-4B",
+    "smollm3": "HuggingFaceTB/SmolLM3-3B",
 }
 
 
@@ -605,7 +607,7 @@ def export_llama(
         from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
             convert_weights,
         )
-    elif model_name == "smollm2":
+    elif model_name == "smollm2" or model_name == "smollm3":
         from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
             convert_weights,
         )
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index a53e1716375..2ee79cb829b 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -255,7 +255,10 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
     layers = torch.nn.ModuleList()
     cls = ATTENTION_REGISTRY[model_args.attention_type]
     for layer_id in range(model_args.n_layers):
-        attention = cls(model_args, layer_id, rope)
+        if model_args.no_rope_layer_interval and (layer_id + 1) % model_args.no_rope_layer_interval == 0:
+            attention = cls(model_args, layer_id, None)
+        else:
+            attention = cls(model_args, layer_id, rope)
         transformer_block = TransformerBlock(model_args, attention)
         layers.append(transformer_block)
 
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 5734cd66ef7..4920b6ee237 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -40,6 +40,7 @@ class ModelArgs:
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
     use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
+    no_rope_layer_interval: Optional[int] = None  # Interval at which to skip RoPE. From "Rope to Nope and Back Again: A New Hybrid Attention Strategy" (https://huggingface.co/papers/2501.18795).
     partial_rotary_factor: float = 1.0
     rope_theta: Optional[float] = (
         None  # The official name to override self.rope_freq_base.
diff --git a/examples/models/smollm2/convert_weights.py b/examples/models/smollm2/convert_weights.py
index 59b83d3e3a3..2b70c147f36 100644
--- a/examples/models/smollm2/convert_weights.py
+++ b/examples/models/smollm2/convert_weights.py
@@ -1,29 +1,30 @@
 import argparse
+import json
+import os
 
 from typing import Dict
 
+from safetensors.torch import load_file
 import torch
 
 from torchtune.models.convert_weights import get_mapped_key
-from torchtune.training import FullModelHFCheckpointer
-
-# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
-_SMOLLM_FROM_META = {
-    "tok_embeddings.weight": "tok_embeddings.weight",
-    "norm.weight": "norm.scale",
-    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
-    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
-    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
-    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
-    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
-    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
-    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
-    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
-    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+
+_SMOLLM_TO_META = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
 }
 
 
-def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+def smollm_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     """
-    Convert a state dict from torchtune's format to Meta's format. This function
+    Convert a state dict from HuggingFace's format to Meta's format. This function
     doesn't handle any sharding or splitting of state dicts. It follows the
@@ -36,9 +37,8 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.
         Dict[str, torch.Tensor]: State dict in Meta's format.
""" converted_state_dict = {} - inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()} for key, value in state_dict.items(): - new_key = get_mapped_key(key, inverted_mapping_dict) + new_key = get_mapped_key(key, _SMOLLM_TO_META) converted_state_dict[new_key] = value converted_state_dict["output.weight"] = converted_state_dict[ "tok_embeddings.weight" @@ -47,19 +47,47 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch. return converted_state_dict -def convert_weights(input_dir: str, output_file: str) -> None: - # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. - checkpointer = FullModelHFCheckpointer( - checkpoint_dir=input_dir, - checkpoint_files=["model.safetensors"], - output_dir=".", - model_type="LLAMA3", - ) +def load_checkpoint_from_safetensors(input_dir: str) -> Dict: + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + # Sharded checkpoint. + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + # Load all the shards into memory + shard_to_weights = {} + for shard in checkpoint_shards: + shard_to_weights[shard] = load_file(os.path.join(input_dir, shard)) + + # Merge tensors into consolidated state dict. + merged_state_dict = {} + for weight_name, shard in weight_map.items(): + tensor = shard_to_weights[shard][weight_name] + merged_state_dict[weight_name] = tensor + return merged_state_dict + else: + # Single checkpoint. + state_dict = load_file(os.path.join(input_dir, "model.safetensors")) + return state_dict + + +def load_checkpoint(input_dir: str) -> Dict: + pytorch_path = os.path.join(input_dir, "pytorch_model.bin") + if os.path.exists(pytorch_path): + print("Loading checkpoint from PyTorch .bin file") + return torch.load(pytorch_path, map_location="cpu", weights_only=True) + print("Loading checkpoint from safetensors directory") + return load_checkpoint_from_safetensors(input_dir) + +def convert_weights(input_dir: str, output_file: str) -> None: print("Loading checkpoint...") - sd = checkpointer.load_checkpoint() + sd = load_checkpoint(input_dir) print("Converting checkpoint...") - sd = smollm_tune_to_meta(sd["model"]) + breakpoint() + sd = smollm_to_meta(sd) print("Saving checkpoint...") torch.save(sd, output_file) print("Done.") diff --git a/examples/models/smollm3/config/params.json b/examples/models/smollm3/config/params.json new file mode 100644 index 00000000000..fb228b25f1e --- /dev/null +++ b/examples/models/smollm3/config/params.json @@ -0,0 +1,15 @@ +{ + "dim": 2048, + "ffn_dim_multiplier": 1, + "hidden_dim": 11008, + "n_heads": 16, + "n_kv_heads": 4, + "n_layers": 36, + "norm_eps": 1e-06, + "rope_theta": 5000000.0, + "use_scaled_rope": false, + "vocab_size": 128256, + "use_hf_rope": true, + "no_rope_layer_interval": 4, + "attention_qkv_bias": false +} diff --git a/examples/models/smollm3/config/smollm3_xnnpack_q8da4w.yaml b/examples/models/smollm3/config/smollm3_xnnpack_q8da4w.yaml new file mode 100644 index 00000000000..c3d0cbf2646 --- /dev/null +++ b/examples/models/smollm3/config/smollm3_xnnpack_q8da4w.yaml @@ -0,0 +1,16 @@ +base: + model_class: smollm3 + metadata: '{"get_eos_ids":[128012]}' + +model: + use_kv_cache: True + use_sdpa_with_kv_cache: True + dtype_override: fp32 + +quantization: + qmode: 8da4w + +backend: + xnnpack: + enabled: True + extended_ops: True diff --git 
diff --git a/examples/models/smollm3/config/smollm3_xnnpack_q8da4w_qe.yaml b/examples/models/smollm3/config/smollm3_xnnpack_q8da4w_qe.yaml
new file mode 100644
index 00000000000..b0d973b2894
--- /dev/null
+++ b/examples/models/smollm3/config/smollm3_xnnpack_q8da4w_qe.yaml
@@ -0,0 +1,21 @@
+base:
+  model_class: smollm3
+  metadata: '{"get_eos_ids":[128012]}'
+
+export:
+  output_name: smollm3_xnnpack_q8da4w_64_qe8_0_norope.pte
+
+model:
+  use_kv_cache: True
+  use_sdpa_with_kv_cache: True
+  dtype_override: fp32
+
+quantization:
+  qmode: 8da4w
+  group_size: 64
+  embedding_quantize: 8,0
+
+backend:
+  xnnpack:
+    enabled: True
+    extended_ops: True
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index 94bbb2d8b2e..11aa0e276e2 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -43,6 +43,7 @@ class ModelType(str, Enum):
     qwen3_4b = "qwen3_4b"
     phi_4_mini = "phi_4_mini"
     smollm2 = "smollm2"
+    smollm3 = "smollm3"
 
 
 class PreqMode(str, Enum):