@@ -123,26 +123,19 @@ def verbose_export():
123123
124124
def build_model(
    model: str,
    checkpoint: str,
    params: str,
    output_dir: Optional[str] = ".",
    extra_opts: Optional[str] = "",
) -> str:
    """Run the llama export with arguments assembled from Python values.

    The arguments are turned into a command-line style token list, parsed
    with the parser returned by ``build_args_parser()``, and the resulting
    namespace is handed to ``export_llama()``.

    Args:
        model: Value passed for ``--model``.
        checkpoint: Value passed for ``--checkpoint``.
        params: Value passed for ``--params``.
        output_dir: Value passed for ``--output-dir``. ``None`` falls back
            to ``"."``, the same default the parser itself uses.
        extra_opts: Additional options written as a single shell-style
            string (e.g. ``"-c ckpt.pth --some-flag"``). It is tokenised
            with ``shlex.split``. ``None`` or ``""`` adds nothing.

    Returns:
        Whatever ``export_llama(args)`` returns (annotated as ``str``).
    """
    # Pass each value as its own argv entry instead of interpolating it into
    # one string and re-splitting: paths containing whitespace, quotes or
    # backslashes then reach argparse unchanged.
    argv = ["--model", model, "--checkpoint", checkpoint, "--params", params]

    # Only extra_opts is genuinely a shell-style string that needs splitting.
    # `or ""` keeps a None value from turning into the literal token "None".
    argv.extend(shlex.split(extra_opts or ""))

    # --output-dir stays last so it takes precedence over any --output-dir
    # that might appear inside extra_opts, matching the original ordering.
    argv.extend(["--output-dir", "." if output_dir is None else output_dir])

    parser = build_args_parser()
    args = parser.parse_args(argv)
    return export_llama(args)
142136
143137
144138def build_args_parser() -> argparse.ArgumentParser:
145- ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
146139 parser = argparse.ArgumentParser()
147140 parser.add_argument("-o", "--output-dir", default=".", help="output directory")
148141 # parser.add_argument(
@@ -191,8 +184,8 @@ def build_args_parser() -> argparse.ArgumentParser:
191184 parser.add_argument(
192185 "-c",
193186 "--checkpoint",
194- default=f"{ckpt_dir}/params/demo_rand_params.pth" ,
195- help="checkpoint path ",
187+ required=False ,
188+ help="Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights. ",
196189 )
197190
198191 parser.add_argument(
@@ -273,8 +266,8 @@ def build_args_parser() -> argparse.ArgumentParser:
273266 parser.add_argument(
274267 "-p",
275268 "--params",
276- default=f"{ckpt_dir}/params/demo_config.json" ,
277- help="config.json ",
269+ required=False ,
270+ help="Config file for model parameters. When not provided, the model will fallback on default values defined in examples/models/llama/model_args.py. ",
278271 )
279272 parser.add_argument(
280273 "--optimized_rotation_path",
@@ -561,7 +554,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
561554 checkpoint_dir = (
562555 canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
563556 )
564- params_path = canonical_path(args.params)
557+ params_path = canonical_path(args.params) if args.params else None
565558 output_dir_path = canonical_path(args.output_dir, dir=True)
566559 weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
567560
@@ -960,7 +953,7 @@ def _load_llama_model(
960953 *,
961954 checkpoint: Optional[str] = None,
962955 checkpoint_dir: Optional[str] = None,
963- params_path: str,
956+ params_path: Optional[ str] = None ,
964957 use_kv_cache: bool = False,
965958 use_sdpa_with_kv_cache: bool = False,
966959 generate_full_logits: bool = False,
@@ -987,13 +980,6 @@ def _load_llama_model(
987980 An instance of LLMEdgeManager which contains the eager mode model.
988981 """
989982
990- assert (
991- checkpoint or checkpoint_dir
992- ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
993- logging.info(
994- f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
995- )
996-
997983 if modelname in EXECUTORCH_DEFINED_MODELS:
998984 module_name = "llama"
999985 model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
0 commit comments