From fc07cc20b4cd835b9a70bb083c48b4b06e82b999 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:58:17 -0800 Subject: [PATCH 01/19] Add qwen 2.5 --- examples/models/llama/attention.py | 7 ++- examples/models/llama/model.py | 3 +- examples/models/llama/rope.py | 53 ++++++++++++++-- examples/models/qwen2_5/convert_weights.py | 73 ++++++++++++++++++++++ 4 files changed, 128 insertions(+), 8 deletions(-) create mode 100644 examples/models/qwen2_5/convert_weights.py diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 91168a388d3..3d127e47f3d 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -175,9 +175,10 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): self.max_batch_size = args.max_batch_size self.max_context_len = args.max_context_len self.dim = args.dim - self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + # TODO: parametrize bias for attention and feedforward. + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=True) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True) self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.layer_id = layer_id diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 90582af4856..ac9d30c7e1b 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -150,6 +150,7 @@ def __init__(self, **kwargs): input_prune_map=input_prune_map, output_prune_map=output_prune_map, enable_dynamic_shape=self.enable_dynamic_shape, + use_hf_rope=True, **params, ) @@ -170,7 +171,7 @@ def __init__(self, **kwargs): # Within the device="meta" context, tensors that are created do not carry data. # They possess all other metadata a tensor carries such as size, stride, requires_grad. - with torch.device("meta"): + with torch.device("cpu"): self.model_ = Transformer(model_args) if "int8" in str(checkpoint_path): diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index 01352f404df..caca5907a0c 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -114,6 +114,7 @@ def apply_rotary_emb_to_k( return xk_out.type_as(xk) +# Wrap apply_rotary_emb in a module to enable it to be module swapped out. class RotaryEmbedding(torch.nn.Module): def __init__(self): super().__init__() @@ -209,18 +210,66 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1): return k_embed +# ======================= Qwen2 Implementation ======================== + + +def qwen_precompute_freqs_cis(dim: int, end: int, theta: float = 1_000_000.0): + """ + Precompute frequency tensor for Qwen2-style RoPE. + """ + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim) + ) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin + + +def qwen_apply_rotary_emb( + q: torch.Tensor, k: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply Qwen2-style RoPE to query and key tensors. + """ + def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + # Reshape cos and sin for broadcasting + cos = freqs_cos.unsqueeze(1) # [seq_len, 1, head_dim] + sin = freqs_sin.unsqueeze(1) # [seq_len, 1, head_dim] + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() self.params = params + + # Choose the appropriate RoPE implementation if self.params.use_hf_rope: self.precompute_freqs_cis = hf_precompute_freqs_cis + self.apply_rotary_emb = hf_apply_rotary_emb + # elif self.params.use_qwen_rope: + # self.precompute_freqs_cis = qwen_precompute_freqs_cis + # self.apply_rotary_emb = qwen_apply_rotary_emb else: self.precompute_freqs_cis = partial( precompute_freqs_cis, use_scaled=self.params.use_scaled_rope, scale_factor=self.params.rope_scale_factor, ) + self.apply_rotary_emb = RotaryEmbedding() + + # Precompute frequencies freqs_cos, freqs_sin = self.precompute_freqs_cis( self.params.head_dim, ( @@ -232,10 +281,6 @@ def __init__(self, params: ModelArgs): ) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) - if self.params.use_hf_rope: - self.apply_rotary_emb = hf_apply_rotary_emb - else: - self.apply_rotary_emb = RotaryEmbedding() def forward( self, diff --git a/examples/models/qwen2_5/convert_weights.py b/examples/models/qwen2_5/convert_weights.py new file mode 100644 index 00000000000..4341d260831 --- /dev/null +++ b/examples/models/qwen2_5/convert_weights.py @@ -0,0 +1,73 @@ +from typing import Dict + +from torchtune.training import FullModelHFCheckpointer +# from torchtune.models import convert_weights +from torchtune.models.convert_weights import get_mapped_key +import torch + +# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings. +_QWEN_2_FROM_META = { + "tok_embeddings.weight": "tok_embeddings.weight", + "norm.weight": "norm.scale", + "output.weight": "output.weight", + "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", + "layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias", + "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", + "layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias", + "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight", + "layers.{}.attention.wv.bias": "layers.{}.attn.v_proj.bias", + "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight", + "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale", + "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale", + "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", + "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", + "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", + +} + +def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from torchtune's format to Meta's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + + Returns: + Dict[str, torch.Tensor]: State dict in Meta's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _QWEN_2_FROM_META.items()} + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + converted_state_dict[new_key] = value + + return converted_state_dict + +# TODO: no need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. +checkpointer = FullModelHFCheckpointer( + checkpoint_dir='/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/', + checkpoint_files=['model.safetensors'], + output_dir='.' , + model_type='QWEN2' +) + +print("Loading checkpoint") +sd = checkpointer.load_checkpoint() + +print("HF weights:") +for weight in sd["model"].keys(): + print(weight) +print() + +# Convert from TorchTune to Meta (PyTorch native) +sd = qwen_2_tune_to_meta(sd['model']) + +print("Meta weights:") +for weight in sd.keys(): + print(weight) + +print("Saving checkpoint") +torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth") From 110abd0117f824f591004719ab084d640df51f99 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:08:15 -0800 Subject: [PATCH 02/19] Fix output embedding --- examples/models/qwen2_5/convert_weights.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/models/qwen2_5/convert_weights.py b/examples/models/qwen2_5/convert_weights.py index 4341d260831..6478d4e8b2c 100644 --- a/examples/models/qwen2_5/convert_weights.py +++ b/examples/models/qwen2_5/convert_weights.py @@ -9,7 +9,6 @@ _QWEN_2_FROM_META = { "tok_embeddings.weight": "tok_embeddings.weight", "norm.weight": "norm.scale", - "output.weight": "output.weight", "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight", "layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias", "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight", @@ -22,7 +21,6 @@ "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight", "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight", "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", - } def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: @@ -44,6 +42,9 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch. new_key = get_mapped_key(key, inverted_mapping_dict) converted_state_dict[new_key] = value + # 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733. + converted_state_dict["output.weight"] = converted_state_dict["tok_embeddings.weight"] + return converted_state_dict # TODO: no need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. From 42fdb0da1f9c7f4af09705ed84a9cf9691303bd8 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:10:33 -0800 Subject: [PATCH 03/19] Comment / lint --- examples/models/llama/rope.py | 1 + examples/models/qwen2_5/convert_weights.py | 36 ++++++++++------------ 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index caca5907a0c..cfe74c172fa 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -233,6 +233,7 @@ def qwen_apply_rotary_emb( """ Apply Qwen2-style RoPE to query and key tensors. """ + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] diff --git a/examples/models/qwen2_5/convert_weights.py b/examples/models/qwen2_5/convert_weights.py index 6478d4e8b2c..0481799e974 100644 --- a/examples/models/qwen2_5/convert_weights.py +++ b/examples/models/qwen2_5/convert_weights.py @@ -1,10 +1,11 @@ from typing import Dict -from torchtune.training import FullModelHFCheckpointer -# from torchtune.models import convert_weights -from torchtune.models.convert_weights import get_mapped_key import torch +from torchtune.models.convert_weights import get_mapped_key + +from torchtune.training import FullModelHFCheckpointer + # Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings. _QWEN_2_FROM_META = { "tok_embeddings.weight": "tok_embeddings.weight", @@ -23,6 +24,7 @@ "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight", } + def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Convert a state dict from torchtune's format to Meta's format. This function @@ -43,32 +45,26 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch. converted_state_dict[new_key] = value # 0.5b and 1.5b models share the same weights for tok_embeddings and output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733. - converted_state_dict["output.weight"] = converted_state_dict["tok_embeddings.weight"] + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] return converted_state_dict + # TODO: no need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. checkpointer = FullModelHFCheckpointer( - checkpoint_dir='/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/', - checkpoint_files=['model.safetensors'], - output_dir='.' , - model_type='QWEN2' + checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/", + checkpoint_files=["model.safetensors"], + output_dir=".", + model_type="QWEN2", ) print("Loading checkpoint") sd = checkpointer.load_checkpoint() -print("HF weights:") -for weight in sd["model"].keys(): - print(weight) -print() - -# Convert from TorchTune to Meta (PyTorch native) -sd = qwen_2_tune_to_meta(sd['model']) - -print("Meta weights:") -for weight in sd.keys(): - print(weight) +# Convert from TorchTune to Meta (PyTorch native). +sd = qwen_2_tune_to_meta(sd["model"]) print("Saving checkpoint") -torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth") +torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth") From 3ab0bd994971f9994f61b969d1f69af12dff32ec Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 16:41:11 -0800 Subject: [PATCH 04/19] Add 1.5 config --- examples/models/qwen2_5/1_5b_config.json | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 examples/models/qwen2_5/1_5b_config.json diff --git a/examples/models/qwen2_5/1_5b_config.json b/examples/models/qwen2_5/1_5b_config.json new file mode 100644 index 00000000000..6ef6f3cc27e --- /dev/null +++ b/examples/models/qwen2_5/1_5b_config.json @@ -0,0 +1,12 @@ +{ + "dim": 1536, + "ffn_dim_multiplier": 1, + "hidden_dim": 8960, + "n_heads": 12, + "n_kv_heads": 2, + "n_layers": 28, + "norm_eps": 1e-06, + "rope_theta": 1000000.0, + "use_scaled_rope": false, + "vocab_size": 151936 +} From 0a17e3b30fa8878a7612b88110f9f2547b804cda Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 17:15:41 -0800 Subject: [PATCH 05/19] Comment --- examples/models/qwen2_5/convert_weights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/qwen2_5/convert_weights.py b/examples/models/qwen2_5/convert_weights.py index 0481799e974..ce914539cf3 100644 --- a/examples/models/qwen2_5/convert_weights.py +++ b/examples/models/qwen2_5/convert_weights.py @@ -52,7 +52,7 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch. return converted_state_dict -# TODO: no need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. +# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. checkpointer = FullModelHFCheckpointer( checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/", checkpoint_files=["model.safetensors"], From a27ed67e5ce681bbd2a7b4b376bd9fa227183eef Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 12 Feb 2025 17:45:09 -0800 Subject: [PATCH 06/19] Remove qwen rope, use hf rope instead --- examples/models/llama/rope.py | 43 ----------------------------------- 1 file changed, 43 deletions(-) diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index cfe74c172fa..e081c442032 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -210,46 +210,6 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1): return k_embed -# ======================= Qwen2 Implementation ======================== - - -def qwen_precompute_freqs_cis(dim: int, end: int, theta: float = 1_000_000.0): - """ - Precompute frequency tensor for Qwen2-style RoPE. - """ - freqs = 1.0 / ( - theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim) - ) - t = torch.arange(end, device=freqs.device) - freqs = torch.outer(t, freqs).float() - freqs_cos = torch.cos(freqs) - freqs_sin = torch.sin(freqs) - return freqs_cos, freqs_sin - - -def qwen_apply_rotary_emb( - q: torch.Tensor, k: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply Qwen2-style RoPE to query and key tensors. - """ - - def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - # Reshape cos and sin for broadcasting - cos = freqs_cos.unsqueeze(1) # [seq_len, 1, head_dim] - sin = freqs_sin.unsqueeze(1) # [seq_len, 1, head_dim] - - # Apply rotation - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() @@ -259,9 +219,6 @@ def __init__(self, params: ModelArgs): if self.params.use_hf_rope: self.precompute_freqs_cis = hf_precompute_freqs_cis self.apply_rotary_emb = hf_apply_rotary_emb - # elif self.params.use_qwen_rope: - # self.precompute_freqs_cis = qwen_precompute_freqs_cis - # self.apply_rotary_emb = qwen_apply_rotary_emb else: self.precompute_freqs_cis = partial( precompute_freqs_cis, From 8aadf4510450a1aa98ce05897b14f9b136c961a4 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 13 Feb 2025 11:02:49 -0800 Subject: [PATCH 07/19] Back to meta --- examples/models/llama/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index ac9d30c7e1b..f239952be79 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -171,7 +171,7 @@ def __init__(self, **kwargs): # Within the device="meta" context, tensors that are created do not carry data. # They possess all other metadata a tensor carries such as size, stride, requires_grad. - with torch.device("cpu"): + with torch.device("meta"): self.model_ = Transformer(model_args) if "int8" in str(checkpoint_path): From 8b0b9f9eef516fa7c7a5b8c19a591711fb7816df Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 13 Feb 2025 11:36:36 -0800 Subject: [PATCH 08/19] Parametrize qkv bias --- examples/models/llama/attention.py | 14 ++++++++++---- examples/models/llama/model_args.py | 1 + examples/models/qwen2_5/1_5b_config.json | 3 ++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 3d127e47f3d..66eeb10989f 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -175,10 +175,16 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): self.max_batch_size = args.max_batch_size self.max_context_len = args.max_context_len self.dim = args.dim - # TODO: parametrize bias for attention and feedforward. - self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=True) - self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True) - self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=True) + self.attention_qkv_bias = args.attention_qkv_bias + self.wq = nn.Linear( + self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wk = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) + self.wv = nn.Linear( + self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias + ) self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) self.layer_id = layer_id diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index e1c4edb8e93..28804839815 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -21,6 +21,7 @@ class ModelArgs: num_experts: int = 8 # Number of experts num_activated_experts: int = 2 # Number of experts to activate attention_type: str = "mha" # Attention type, registered in attention.py + attention_qkv_bias: bool = False use_kv_cache: bool = False # Use key/value cache use_sdpa_with_kv_cache_op: bool = ( False # Use custom sdpa op that updates kv cache in-place diff --git a/examples/models/qwen2_5/1_5b_config.json b/examples/models/qwen2_5/1_5b_config.json index 6ef6f3cc27e..95de1f62dfc 100644 --- a/examples/models/qwen2_5/1_5b_config.json +++ b/examples/models/qwen2_5/1_5b_config.json @@ -8,5 +8,6 @@ "norm_eps": 1e-06, "rope_theta": 1000000.0, "use_scaled_rope": false, - "vocab_size": 151936 + "vocab_size": 151936, + "attention_qkv_bias": true } From 52d7a1178f5181aa07b645d5a940214274652463 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 13 Feb 2025 12:35:05 -0800 Subject: [PATCH 09/19] Parametrize use hf rope --- examples/models/llama/model.py | 1 - examples/models/qwen2_5/1_5b_config.json | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index f239952be79..90582af4856 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -150,7 +150,6 @@ def __init__(self, **kwargs): input_prune_map=input_prune_map, output_prune_map=output_prune_map, enable_dynamic_shape=self.enable_dynamic_shape, - use_hf_rope=True, **params, ) diff --git a/examples/models/qwen2_5/1_5b_config.json b/examples/models/qwen2_5/1_5b_config.json index 95de1f62dfc..64daca5a7cd 100644 --- a/examples/models/qwen2_5/1_5b_config.json +++ b/examples/models/qwen2_5/1_5b_config.json @@ -9,5 +9,6 @@ "rope_theta": 1000000.0, "use_scaled_rope": false, "vocab_size": 151936, + "use_hf_rope": true, "attention_qkv_bias": true } From 9258a682db1180cafd26ba32e59f27665dfc283d Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 18 Feb 2025 10:20:21 -0800 Subject: [PATCH 10/19] Add ci tests --- .ci/scripts/test_model.sh | 9 +++++++++ examples/models/__init__.py | 1 + examples/models/qwen2_5/__init__.py | 14 ++++++++++++++ 3 files changed, 24 insertions(+) create mode 100644 examples/models/qwen2_5/__init__.py diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh index 157449c0717..f33474a35b7 100755 --- a/.ci/scripts/test_model.sh +++ b/.ci/scripts/test_model.sh @@ -91,6 +91,15 @@ test_model() { # Install requirements for llama vision. bash examples/models/llama3_2_vision/install_requirements.sh fi + if [[ "${MODEL_NAME}" == "qwen2_5" ]]; then + # Install requirements for export_llama + bash examples/models/llama/install_requirements.sh + # Test export_llama script: python3 -m examples.models.llama.export_llama. + # Use Llama random checkpoint with Qwen 2.5 1.5b model configuration. + "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/qwen2_5/1_5b_config.json + run_portable_executor_runner + rm "./${MODEL_NAME}.pte" + fi # python3 -m examples.portable.scripts.export --model_name="llama2" should works too "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}" run_portable_executor_runner diff --git a/examples/models/__init__.py b/examples/models/__init__.py index 822d55fc09d..55f5c449ca2 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -34,6 +34,7 @@ "resnet50": ("resnet", "ResNet50Model"), "llava": ("llava", "LlavaModel"), "efficient_sam": ("efficient_sam", "EfficientSAM"), + "qwen2_5": ("qwen2_5", "Qwen2_5Model"), } __all__ = [ diff --git a/examples/models/qwen2_5/__init__.py b/examples/models/qwen2_5/__init__.py new file mode 100644 index 00000000000..d86a97a114d --- /dev/null +++ b/examples/models/qwen2_5/__init__.py @@ -0,0 +1,14 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.example.models.llama.model import Llama2Model + + +class Qwen2_5Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + +__all__ = [ + "Qwen2_5Model", +] From 1b7de2fbb29146315f82102022689f58c898ffef Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 18 Feb 2025 11:16:25 -0800 Subject: [PATCH 11/19] Test ci pull --- .ci/scripts/gather_test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/scripts/gather_test_models.py b/.ci/scripts/gather_test_models.py index d02213b9faf..515bc97cca3 100755 --- a/.ci/scripts/gather_test_models.py +++ b/.ci/scripts/gather_test_models.py @@ -90,7 +90,7 @@ def model_should_run_on_event(model: str, event: str) -> bool: We put higher priority and fast models to pull request and rest to push. """ if event == "pull_request": - return model in ["mv3", "vit"] + return model in ["mv3", "vit", "qwen2_5"] # TODO: remove, just to test the ci elif event == "push": # These are super slow. Only run it periodically return model not in ["dl3", "edsr", "emformer_predict"] From 5422420055bbbf7e4f9a49b8f40bf85d308e0abe Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 18 Feb 2025 12:32:08 -0800 Subject: [PATCH 12/19] Add qwen to export_llama --models --- examples/models/llama/export_llama_lib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 4ad92903534..6d9ba750431 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -84,6 +84,7 @@ verbosity_setting = None +# All models that leverage the transformer architecture defined in llama_transformer.py. EXECUTORCH_DEFINED_MODELS = [ "stories110m", "llama2", @@ -91,6 +92,7 @@ "llama3_1", "llama3_2", "static_llama", + "qwen2_5", ] TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] From 12d4073959f23f3a7c88414292ab6d7a34af0970 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 18 Feb 2025 13:14:47 -0800 Subject: [PATCH 13/19] Leave weights uninitialized for checkopint load fail --- examples/models/llama/model.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 90582af4856..489952682e4 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -236,14 +236,21 @@ def __init__(self, **kwargs): eviction_batch_size=eviction_batch_size, ) - # assign=True: load params/buffers by assignment instead of performing an in-place copy. - # Because we are using device="meta", tensors do not have memory associated with them - # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario. - missing, unexpected = self.model_.load_state_dict( - checkpoint, - strict=False, - assign=True, - ) # self.model_ = Transformer(gptconf) + missing, unexpected = None, None + try: + # assign=True: load params/buffers by assignment instead of performing an in-place copy. + # Because we are using device="meta", tensors do not have memory associated with them + # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario. + missing, unexpected = self.model_.load_state_dict( + checkpoint, + strict=False, + assign=True, + ) # self.model_ = Transformer(gptconf) + except RuntimeError as e: + print( + "Could not load checkpoint into mode, defaulting to random uninitialized weights." + ) + print(f"Error: {e}") if missing: missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")] From 955b991101b3216132fec88930638876f28260db Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 18 Feb 2025 14:26:20 -0800 Subject: [PATCH 14/19] Meta -> cpu for uninitialized weights --- examples/models/llama/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py index 489952682e4..bc4fd6ccb11 100644 --- a/examples/models/llama/model.py +++ b/examples/models/llama/model.py @@ -251,6 +251,8 @@ def __init__(self, **kwargs): "Could not load checkpoint into mode, defaulting to random uninitialized weights." ) print(f"Error: {e}") + # Need to provide concrete (empty) values for meta-initialized tensors for quantization. + self.model_.to_empty(device="cpu") if missing: missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")] From 9b5516baf1b2208b4d1b0010ac85106479fdc0db Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 18 Feb 2025 21:06:47 -0800 Subject: [PATCH 15/19] Skip executor runner for qwen2 test --- .ci/scripts/test_model.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh index f33474a35b7..054ac02bc07 100755 --- a/.ci/scripts/test_model.sh +++ b/.ci/scripts/test_model.sh @@ -97,10 +97,11 @@ test_model() { # Test export_llama script: python3 -m examples.models.llama.export_llama. # Use Llama random checkpoint with Qwen 2.5 1.5b model configuration. "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/qwen2_5/1_5b_config.json - run_portable_executor_runner rm "./${MODEL_NAME}.pte" + return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears. fi - # python3 -m examples.portable.scripts.export --model_name="llama2" should works too + + # Export a basic .pte and run the model. "${PYTHON_EXECUTABLE}" -m examples.portable.scripts.export --model_name="${MODEL_NAME}" "${STRICT}" run_portable_executor_runner } From 347c6fbc6c79e92443fa71e9360abd481d5ec564 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 19 Feb 2025 14:44:19 -0800 Subject: [PATCH 16/19] Clean up convert_weights --- examples/models/qwen2_5/convert_weights.py | 46 ++++++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/examples/models/qwen2_5/convert_weights.py b/examples/models/qwen2_5/convert_weights.py index ce914539cf3..6b6c0bbdfe2 100644 --- a/examples/models/qwen2_5/convert_weights.py +++ b/examples/models/qwen2_5/convert_weights.py @@ -1,3 +1,4 @@ +import argparse from typing import Dict import torch @@ -52,19 +53,38 @@ def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch. return converted_state_dict -# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. -checkpointer = FullModelHFCheckpointer( - checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/", - checkpoint_files=["model.safetensors"], - output_dir=".", - model_type="QWEN2", -) +def main(): + parser = argparse.ArgumentParser( + description="Convert Qwen2 weights to Meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing checkpoint files", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") -print("Loading checkpoint") -sd = checkpointer.load_checkpoint() + args = parser.parse_args() -# Convert from TorchTune to Meta (PyTorch native). -sd = qwen_2_tune_to_meta(sd["model"]) + # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. + checkpointer = FullModelHFCheckpointer( + # checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/", + checkpoint_dir=args.input_dir, + checkpoint_files=["model.safetensors"], + output_dir=".", + model_type="QWEN2", + ) -print("Saving checkpoint") -torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth") + print("Loading checkpoint...") + sd = checkpointer.load_checkpoint() + + print("Converting checkpoint...") + sd = qwen_2_tune_to_meta(sd["model"]) + # torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth") + + torch.save(sd, args.output) + print(f"Checkpoint saved to {args.output}") + + +if __name__ == "__main__": + main() From 44aa34d37b3d59ce05974840d73d6f39561de94e Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 19 Feb 2025 14:44:30 -0800 Subject: [PATCH 17/19] Add README.md --- examples/models/qwen2_5/README.md | 66 +++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 examples/models/qwen2_5/README.md diff --git a/examples/models/qwen2_5/README.md b/examples/models/qwen2_5/README.md new file mode 100644 index 00000000000..aede40ec4dc --- /dev/null +++ b/examples/models/qwen2_5/README.md @@ -0,0 +1,66 @@ +## Summary +Qwen 2.5 is the latest iteration of the Qwen series of large language models (LLMs) developed by Alibaba. At the moment, 1.5b is currently supporting, with plans in the future for adding the 0.5b and 3b versions. + +## Instructions + +Qwen 2.5 uses the same example code as Llama, while the checkpoint, model params, and tokenizer are different. Please see the [Llama README page](../llama/README.md) for details. + +All commands for exporting and running Llama on various backends should also be applicable to Qwen 2.5, by swapping the following args: +``` +--model qwen2_5 +--params examples/models/qwen2_5/1_5b_config.json +--checkpoint +``` + +### Generate the Checkpoint +The original checkpoint can be obtained from HuggingFace: +``` +huggingface-cli download Qwen/Qwen2.5-1.5B +``` + +We then convert it to Meta's checkpoint format: +``` +python examples/models/qwen2_5/convert_weights.py +``` + +### Example export and run +Here is an basic example for exporting and running Qwen 2.5, although please refer to [Llama README page](../llama/README.md) for more advanced usage. + +Export to XNNPack, no quantization: +``` +# No quantization +# Set these paths to point to the downloaded files +QWEN_CHECKPOINT=path/to/checkpoint.pth + +python -m examples.models.llama.export_llama \ + --model "qwen2_5" \ + --checkpoint "${QWEN_CHECKPOINT:?}" \ + --params examples/models/qwen2_5/1_5b_config.json \ + -kv \ + --use_sdpa_with_kv_cache \ + -d fp32 \ + -X \ + --metadata '{"get_bos_id":151643, "get_eos_ids":[151643]}' \ + --output_name="qwen2_5-1_5b.pte" + --verbose +``` + +Run using the executor runner: +``` +# Currently a work in progress, just need to enable HuggingFace json tokenizer in C++. +# In the meantime, can run with an example Python runner with pybindings: + +python -m examples.models.llama.runner.native + --model qwen2_5 + --pte + -kv + --tokenizer /tokenizer.json + --tokenizer_config /tokenizer_config.json + --prompt "Who is the founder of Meta?" + --params examples/models/qwen2_5/1_5b_config.json + --max_len 64 + --temperature 0 +``` + + + From 93064d24827193c1e80348ff06e04c9d51a7aa07 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Fri, 21 Feb 2025 08:08:36 -0800 Subject: [PATCH 18/19] Bias for static attention --- examples/models/llama/static_attention.py | 7 ++++--- examples/models/qwen2_5/README.md | 3 --- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 8b341a3aafd..72ed4e1dfff 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -145,22 +145,23 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope): self.dim = config.dim self.head_dim = config.head_dim self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5) + self.attention_qkv_bias = config.attention_qkv_bias self.wqs = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=False) + nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) for _ in range(self.n_heads) ] ) self.wks = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=False) + nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) for _ in range(self.n_kv_heads) ] ) self.wvs = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=False) + nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) for _ in range(self.n_kv_heads) ] ) diff --git a/examples/models/qwen2_5/README.md b/examples/models/qwen2_5/README.md index aede40ec4dc..9bf791a35ed 100644 --- a/examples/models/qwen2_5/README.md +++ b/examples/models/qwen2_5/README.md @@ -61,6 +61,3 @@ python -m examples.models.llama.runner.native --max_len 64 --temperature 0 ``` - - - From 8b519597655ccd9ee288335e927718acf9832f0c Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 24 Feb 2025 16:28:21 -0800 Subject: [PATCH 19/19] Remove comments --- examples/models/qwen2_5/convert_weights.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/models/qwen2_5/convert_weights.py b/examples/models/qwen2_5/convert_weights.py index 6b6c0bbdfe2..9aada5b3e90 100644 --- a/examples/models/qwen2_5/convert_weights.py +++ b/examples/models/qwen2_5/convert_weights.py @@ -68,7 +68,6 @@ def main(): # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves. checkpointer = FullModelHFCheckpointer( - # checkpoint_dir="/home/jackzhxng/.cache/huggingface/hub/models--Qwen--Qwen2.5-1.5B/snapshots/8faed761d45a263340a0528343f099c05c9a4323/", checkpoint_dir=args.input_dir, checkpoint_files=["model.safetensors"], output_dir=".", @@ -80,7 +79,6 @@ def main(): print("Converting checkpoint...") sd = qwen_2_tune_to_meta(sd["model"]) - # torch.save(sd, "/home/jackzhxng/models/qwen2_5-1_5b.pth") torch.save(sd, args.output) print(f"Checkpoint saved to {args.output}")