From 175c967c7fd57e59c3b7c3c1f5329092f8d1903d Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 25 Sep 2024 01:15:04 -0700 Subject: [PATCH 1/2] update rope class --- torchchat/cli/builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 38ffa9174..1cea524c1 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -32,6 +32,7 @@ from torchchat.model import Model, ModelArgs, ModelType from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( @@ -402,7 +403,7 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: max_seq_len = decoder_config['max_seq_len'] rope_base = decoder_config['rope_base'] for submodule in model.modules(): - if isinstance(submodule, RotaryPositionalEmbeddings): + if isinstance(submodule, Llama3ScaledRoPE): submodule.__init__(head_dim, max_seq_len, rope_base) state_dict = flamingo_meta_to_tune(checkpoint) model.model.load_state_dict(state_dict, assign=True, strict=False) From 078ac23a2acac0f478e6edd61ff12f5dbf2c4741 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Wed, 25 Sep 2024 01:17:05 -0700 Subject: [PATCH 2/2] remove old rope class --- torchchat/cli/builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index 1cea524c1..1049b346f 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -31,7 +31,6 @@ from torchchat.model import Model, ModelArgs, ModelType -from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE from torchchat.model_config.model_config import resolve_model_config