5 changes: 3 additions & 2 deletions examples/models/llama/attention.py
@@ -325,7 +325,7 @@ def update(

@register_attention("mha")
class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Optional[Rope]):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
@@ -412,7 +412,8 @@ def forward(
        k = self.k_norm_fn(k)

        # RoPE relative positional embeddings
-        q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
+        if self.rope:
+            q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
4 changes: 3 additions & 1 deletion examples/models/llama/export_llama_lib.py
@@ -109,6 +109,7 @@
"qwen3_4b",
"phi_4_mini",
"smollm2",
"smollm3",
]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
HUGGING_FACE_REPO_IDS = {
@@ -118,6 +119,7 @@
"qwen3_0_6b": "Qwen/Qwen3-0.6B",
"qwen3_1_7b": "Qwen/Qwen3-1.7B",
"qwen3_4b": "Qwen/Qwen3-4B",
"smollm3": "HuggingFaceTB/SmolLM3-3B",
}


@@ -605,7 +607,7 @@ def export_llama(
            from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                convert_weights,
            )
-        elif model_name == "smollm2":
+        elif model_name == "smollm2" or model_name == "smollm3":
            from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
                convert_weights,
            )
5 changes: 4 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -255,7 +255,10 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
    layers = torch.nn.ModuleList()
    cls = ATTENTION_REGISTRY[model_args.attention_type]
    for layer_id in range(model_args.n_layers):
-        attention = cls(model_args, layer_id, rope)
+        if model_args.no_rope_layer_interval and (layer_id + 1) % model_args.no_rope_layer_interval == 0:
+            attention = cls(model_args, layer_id, None)
+        else:
+            attention = cls(model_args, layer_id, rope)
        transformer_block = TransformerBlock(model_args, attention)
        layers.append(transformer_block)

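Note: a minimal sketch (not part of this diff) of which layers the rule above turns into NoPE layers, using the SmolLM3 values added later in this change set (n_layers = 36, no_rope_layer_interval = 4). The helper name is invented for illustration.

from typing import List, Optional

def nope_layer_ids(n_layers: int, interval: Optional[int]) -> List[int]:
    # Mirrors the (layer_id + 1) % interval == 0 check in construct_transformer above.
    if not interval:
        return []
    return [layer_id for layer_id in range(n_layers) if (layer_id + 1) % interval == 0]

print(nope_layer_ids(36, 4))
# [3, 7, 11, 15, 19, 23, 27, 31, 35] -- every 4th block gets rope=None, i.e. 9 of 36 layers skip RoPE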
1 change: 1 addition & 0 deletions examples/models/llama/model_args.py
@@ -40,6 +40,7 @@ class ModelArgs:
    use_qk_norm: bool = False  # apply normalization to q and k in the attention
    qk_norm_before_rope: bool = False  # when to apply qk norm
    use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
+    no_rope_layer_interval: Optional[int] = None  # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795).
    partial_rotary_factor: float = 1.0
    rope_theta: Optional[float] = (
        None  # The official name to override self.rope_freq_base.
84 changes: 56 additions & 28 deletions examples/models/smollm2/convert_weights.py
@@ -1,29 +1,30 @@
import argparse
+import json
+import os
from typing import Dict

+from safetensors.torch import load_file
import torch

from torchtune.models.convert_weights import get_mapped_key

-from torchtune.training import FullModelHFCheckpointer
-
-# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
-_SMOLLM_FROM_META = {
-    "tok_embeddings.weight": "tok_embeddings.weight",
-    "norm.weight": "norm.scale",
-    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
-    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
-    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
-    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
-    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
-    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
-    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
-    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
-    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+
+_SMOLLM_TO_META = {
+    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.norm.weight": "norm.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
+    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
}


-def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+def smollm_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
@@ -36,9 +37,8 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
-    inverted_mapping_dict = {v: k for k, v in _SMOLLM_FROM_META.items()}
    for key, value in state_dict.items():
-        new_key = get_mapped_key(key, inverted_mapping_dict)
+        new_key = get_mapped_key(key, _SMOLLM_TO_META)
        converted_state_dict[new_key] = value
    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
@@ -47,19 +47,47 @@ def smollm_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.
    return converted_state_dict


-def convert_weights(input_dir: str, output_file: str) -> None:
-    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
-    checkpointer = FullModelHFCheckpointer(
-        checkpoint_dir=input_dir,
-        checkpoint_files=["model.safetensors"],
-        output_dir=".",
-        model_type="LLAMA3",
-    )
+def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
+    index_path = os.path.join(input_dir, "model.safetensors.index.json")
+    if os.path.exists(index_path):
+        # Sharded checkpoint.
+        with open(index_path, "r") as f:
+            index = json.load(f)
+        weight_map = index["weight_map"]
+        checkpoint_shards = sorted(set(weight_map.values()))
+
+        # Load all the shards into memory
+        shard_to_weights = {}
+        for shard in checkpoint_shards:
+            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))
+
+        # Merge tensors into consolidated state dict.
+        merged_state_dict = {}
+        for weight_name, shard in weight_map.items():
+            tensor = shard_to_weights[shard][weight_name]
+            merged_state_dict[weight_name] = tensor
+        return merged_state_dict
+    else:
+        # Single checkpoint.
+        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
+        return state_dict
+
+
+def load_checkpoint(input_dir: str) -> Dict:
+    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
+    if os.path.exists(pytorch_path):
+        print("Loading checkpoint from PyTorch .bin file")
+        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
+    print("Loading checkpoint from safetensors directory")
+    return load_checkpoint_from_safetensors(input_dir)
+
+
+def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
-    sd = checkpointer.load_checkpoint()
+    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
-    sd = smollm_tune_to_meta(sd["model"])
+    sd = smollm_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")
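For reference, a small sketch (not part of this diff) of how the key mapping above is applied, assuming torchtune's get_mapped_key behaves as it does in the other conversion scripts in this repo: it abstracts the layer index in the incoming key to the "{}" placeholder, looks the template up in the mapping, and substitutes the index back into the result.

from torchtune.models.convert_weights import get_mapped_key

mapping = {
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
}

# Expected: "layers.7.attention.wk.weight"
print(get_mapped_key("model.layers.7.self_attn.k_proj.weight", mapping))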
15 changes: 15 additions & 0 deletions examples/models/smollm3/config/params.json
@@ -0,0 +1,15 @@
{
  "dim": 2048,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 11008,
  "n_heads": 16,
  "n_kv_heads": 4,
  "n_layers": 36,
  "norm_eps": 1e-06,
  "rope_theta": 5000000.0,
  "use_scaled_rope": false,
  "vocab_size": 128256,
  "use_hf_rope": true,
  "no_rope_layer_interval": 4,
  "attention_qkv_bias": false
}
16 changes: 16 additions & 0 deletions examples/models/smollm3/config/smollm3_xnnpack_q8da4w.yaml
@@ -0,0 +1,16 @@
base:
  model_class: smollm3
  metadata: '{"get_eos_ids":[128012]}'

model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
  dtype_override: fp32

quantization:
  qmode: 8da4w

backend:
  xnnpack:
    enabled: True
    extended_ops: True
21 changes: 21 additions & 0 deletions examples/models/smollm3/config/smollm3_xnnpack_q8da4w_qe.yaml
@@ -0,0 +1,21 @@
base:
  model_class: smollm3
  metadata: '{"get_eos_ids":[128012]}'

export:
  output_name: smollm3_xnnpack_q8da4w_64_qe8_0_norope.pte

model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
  dtype_override: fp32

quantization:
  qmode: 8da4w
  group_size: 64
  embedding_quantize: 8,0

backend:
  xnnpack:
    enabled: True
    extended_ops: True
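Both recipes mirror the existing XNNPACK export configs in this repo; the second additionally pins a weight group size of 64 and 8-bit embedding quantization. Presumably they are passed to the LLM export entry point via its config option (something like `python -m extension.llm.export.export_llm --config examples/models/smollm3/config/smollm3_xnnpack_q8da4w.yaml`); the exact invocation is an assumption based on the existing export flow and is not stated in this diff.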
1 change: 1 addition & 0 deletions extension/llm/export/config/llm_config.py
@@ -43,6 +43,7 @@ class ModelType(str, Enum):
qwen3_4b = "qwen3_4b"
phi_4_mini = "phi_4_mini"
smollm2 = "smollm2"
smollm3 = "smollm3"


class PreqMode(str, Enum):