diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index c1e69a47d3d..7933fe9ed3c 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1226,7 +1226,7 @@ def _load_llama_model(
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            llm_config=llm_config,
+            model_args={"llm_config": llm_config},
         )
     )
 
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index d11cf06352b..7a3e5b80e79 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -36,19 +36,18 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, **kwargs):
+    def __init__(self, llm_config):
         resource_dir = get_default_model_resource_dir(__file__)
 
+        self.llm_config = llm_config
+
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", None)
+        checkpoint_path = self.llm_config.base.checkpoint
         # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        checkpoint_dir = self.llm_config.base.checkpoint_dir
         # Params file.
-        params_path = kwargs.get("params", None)
-
-        self.llm_config = kwargs.get("llm_config")
-        assert self.llm_config is not None, "llm_config must be provided"
+        params_path = self.llm_config.base.params
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
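
Usage sketch (not part of the diff): a minimal illustration of the new calling convention, assuming `llm_config` is an already-populated config object whose `base` and `model` sub-configs carry the fields referenced above; the import paths and the variables `module_name`, `model_class_name`, and `llm_config` are assumptions standing in for whatever the caller already has.

# Sketch only. `Llama2Model` now takes the config object as its single
# constructor argument instead of picking values out of **kwargs.
from executorch.examples.models.llama.model import Llama2Model  # import path assumed

model = Llama2Model(llm_config)

# Through the factory (mirroring the export_llama_lib.py change above),
# constructor arguments are now forwarded explicitly via `model_args`
# rather than as loose keyword arguments.
from executorch.examples.models.model_factory import EagerModelFactory  # import path assumed

created = EagerModelFactory.create_model(
    module_name,
    model_class_name,
    model_args={"llm_config": llm_config},
)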