From 085558c605f3a61e4c3f3b542964f4b251595f94 Mon Sep 17 00:00:00 2001
From: Jack-Khuu
Date: Wed, 31 Jul 2024 16:01:40 -0700
Subject: [PATCH 1/3] CLI: Remove unsafe access of unused args

---
 build/builder.py |  34 ++++++++------
 cli.py           | 119 ++++++++++++++++++++++++++---------------------
 generate.py      |   6 ++-
 3 files changed, 89 insertions(+), 70 deletions(-)

diff --git a/build/builder.py b/build/builder.py
index b69fcaf20..2cc6e3ecc 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -103,6 +103,9 @@ def from_args(cls, args):  # -> BuilderArgs:
             model_config.transformer_params_key or model_config.name.split("/")[-1]
         )

+        dso_path = getattr(args, "dso_path", None)
+        pte_path = getattr(args, "pte_path", None)
+
         is_chat_model = False
         if args.is_chat_model:
             is_chat_model = True
@@ -110,8 +113,8 @@ def from_args(cls, args):  # -> BuilderArgs:
             for path in [
                 checkpoint_path,
                 checkpoint_dir,
-                args.dso_path,
-                args.pte_path,
+                dso_path,
+                pte_path,
                 args.gguf_path,
             ]:
                 if path is not None:
@@ -125,7 +128,10 @@ def from_args(cls, args):  # -> BuilderArgs:
                     if "chat" in path_basename or "instruct" in path_basename:
                         is_chat_model = True

-        if args.output_pte_path and args.dtype.startswith("fast"):
+
+        output_pte_path = getattr(args, "output_pte_path", None)
+        output_dso_path = getattr(args, "output_dso_path", None)
+        if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
                 # As per Kimish, float32 should be faster on ET XNNPACK
                 # (because fp16 is implemented as upcast to fp32 for several
@@ -144,11 +150,11 @@ def from_args(cls, args):  # -> BuilderArgs:
             params_table=params_table,
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
-            dso_path=args.dso_path,
-            pte_path=args.pte_path,
+            dso_path=dso_path,
+            pte_path=pte_path,
             device=args.device,
             precision=dtype,
-            setup_caches=(args.output_dso_path or args.output_pte_path),
+            setup_caches=(output_dso_path or output_pte_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
         )
@@ -355,27 +361,27 @@ def _maybe_init_distributed(
     builder_args: BuilderArgs,
 ) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
     """
-    Initialize distributed related setups if the user specified
+    Initialize distributed related setups if the user specified
     using distributed inference. If not, this is a no-op.

     Args:
         builder_args (:class:`BuilderArgs`):
             Command args for model building.
     Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
+        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
+            - The first element is an optional DeviceMesh object,
             which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
+            - The second element is an optional ParallelDims object,
             which represents the parallel dimensions configuration.
     """
     if not builder_args.use_distributed:
         return None, None
     dist_config = 'llama3_8B.toml'  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
+
+    world_mesh, parallel_dims = launch_distributed(dist_config)
+
     assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}"
-
+
     return world_mesh, parallel_dims
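The builder.py hunks above all lean on a single idiom: getattr with a
default is the safe way to read an argparse.Namespace attribute that may
never have been registered, since direct attribute access raises
AttributeError. A minimal standalone sketch of the idiom (the flag name is
illustrative, not torchchat's full CLI wiring):

    import argparse

    parser = argparse.ArgumentParser()
    # Suppose --dso-path is only registered for some subcommands, so the
    # attribute may be missing from the parsed Namespace entirely.
    args = parser.parse_args([])

    # args.dso_path would raise AttributeError here.
    dso_path = getattr(args, "dso_path", None)  # safe: falls back to None
    print(dso_path)  # prints: None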
""" if not builder_args.use_distributed: return None, None dist_config = 'llama3_8B.toml' # TODO - integrate with chat cmd line - - world_mesh, parallel_dims = launch_distributed(dist_config) - + + world_mesh, parallel_dims = launch_distributed(dist_config) + assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}" - + return world_mesh, parallel_dims diff --git a/cli.py b/cli.py index fad69f837..bd0c5eb6f 100644 --- a/cli.py +++ b/cli.py @@ -29,7 +29,7 @@ INVENTORY_VERBS = ["download", "list", "remove", "where"] # Subcommands related to generating inference output based on user prompts -GENERATION_VERBS = ["browser", "chat", "generate", "server"] +GENERATION_VERBS = ["browser", "chat", "generate", "server"] # List of all supported subcommands in torchchat KNOWN_VERBS = GENERATION_VERBS + ["eval", "export"] + INVENTORY_VERBS @@ -49,9 +49,6 @@ def check_args(args, verb: str) -> None: # Given a arg parser and a subcommand (verb), add the appropriate arguments # for that subcommand. -# -# Note the use of argparse.SUPPRESS to hide arguments from --help due to -# legacy CLI arg parsing. See https://github.com/pytorch/torchchat/issues/932 def add_arguments_for_verb(parser, verb: str) -> None: # Argument closure for inventory related subcommands if verb in INVENTORY_VERBS: @@ -62,16 +59,15 @@ def add_arguments_for_verb(parser, verb: str) -> None: # Add argument groups for model specification (what base model to use) _add_model_specification_args(parser) - # Add argument groups for exported model path IO - _add_exported_input_path_args(parser, verb) - _add_export_output_path_args(parser, verb) - # Add argument groups for model configuration (compilation, quant, etc) _add_model_config_args(parser, verb) # Add thematic argument groups based on the subcommand - if verb in ["browser", "chat", "generate", "server"]: + if verb in GENERATION_VERBS: + _add_exported_input_path_args(parser) _add_generation_args(parser, verb) + if verb == "export": + _add_export_output_path_args(parser) if verb == "eval": _add_evaluation_args(parser) @@ -89,8 +85,13 @@ def add_arguments_for_verb(parser, verb: str) -> None: # Add CLI Args related to model specification (what base model to use) def _add_model_specification_args(parser) -> None: - model_specification_parser = parser.add_argument_group("Model Specification", "(REQUIRED) Specify the base model. Args are mutually exclusive.") - exclusive_parser = model_specification_parser.add_mutually_exclusive_group(required=True) + model_specification_parser = parser.add_argument_group( + "Model Specification", + "(REQUIRED) Specify the base model. Args are mutually exclusive.", + ) + exclusive_parser = model_specification_parser.add_mutually_exclusive_group( + required=True + ) exclusive_parser.add_argument( "model", type=str, @@ -120,20 +121,25 @@ def _add_model_specification_args(parser) -> None: help=argparse.SUPPRESS, ) + # Add CLI Args related to model configuration (compilation, quant, etc) def _add_model_config_args(parser, verb: str) -> None: - is_not_export = verb != "export" - model_config_parser = parser.add_argument_group("Model Configuration", "Specify model configurations") - model_config_parser.add_argument( - "--compile", - action="store_true", - help="Whether to compile the model with torch.compile" if is_not_export else argparse.SUPPRESS, - ) - model_config_parser.add_argument( - "--compile-prefill", - action="store_true", - help="Whether to compile the prefill. 
@@ -157,54 +163,55 @@ def _add_model_config_args(parser, verb: str) -> None:
         help="Hardware device to use. Options: cpu, cuda, mps",
     )

-# Add CLI Args representing output paths of exported model files
-def _add_export_output_path_args(parser, verb: str) -> None:
-    is_export = verb == "export"
+# Add CLI Args representing output paths of exported model files
+def _add_export_output_path_args(parser) -> None:
     output_path_parser = parser.add_argument_group(
-        "Export Output Path" if is_export else None,
-        "Specify the output path for the exported model files" if is_export else None,
+        "Export Output Path",
+        "Specify the output path for the exported model files",
     )
     exclusive_parser = output_path_parser.add_mutually_exclusive_group()
     exclusive_parser.add_argument(
         "--output-pte-path",
         type=str,
         default=None,
-        help="Output to the specified ExecuTorch .pte model file" if is_export else argparse.SUPPRESS,
+        help="Output to the specified ExecuTorch .pte model file",
     )
     exclusive_parser.add_argument(
         "--output-dso-path",
         type=str,
         default=None,
-        help="Output to the specified AOT Inductor .dso model file" if is_export else argparse.SUPPRESS,
+        help="Output to the specified AOT Inductor .dso model file",
     )

 # Add CLI Args representing user provided exported model files
-def _add_exported_input_path_args(parser, verb: str) -> None:
-    is_generation_verb = verb in GENERATION_VERBS
-
+def _add_exported_input_path_args(parser) -> None:
     exported_model_path_parser = parser.add_argument_group(
-        "Exported Model Path" if is_generation_verb else None,
-        "Specify the path of the exported model files to ingest" if is_generation_verb else None,
+        "Exported Model Path",
+        "Specify the path of the exported model files to ingest",
     )
     exclusive_parser = exported_model_path_parser.add_mutually_exclusive_group()
     exclusive_parser.add_argument(
         "--dso-path",
         type=Path,
         default=None,
-        help="Use the specified AOT Inductor .dso model file" if is_generation_verb else argparse.SUPPRESS,
+        help="Use the specified AOT Inductor .dso model file",
     )
     exclusive_parser.add_argument(
         "--pte-path",
         type=Path,
         default=None,
-        help="Use the specified ExecuTorch .pte model file" if is_generation_verb else argparse.SUPPRESS,
+        help="Use the specified ExecuTorch .pte model file",
     )

+
 # Add CLI Args related to JIT downloading of model artifacts
 def _add_jit_downloading_args(parser) -> None:
-    jit_downloading_parser = parser.add_argument_group("Model Downloading", "Specify args for model downloading (if model is not downloaded)",)
+    jit_downloading_parser = parser.add_argument_group(
+        "Model Downloading",
+        "Specify args for model downloading (if model is not downloaded)",
+    )
     jit_downloading_parser.add_argument(
         "--hf-token",
         type=str,
@@ -217,7 +224,8 @@ def _add_jit_downloading_args(parser) -> None:
         default=default_model_dir,
         help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
     )
-
+
+
 # Add CLI Args that are general to subcommand cli execution
 def _add_cli_metadata_args(parser) -> None:
     parser.add_argument(
@@ -274,12 +282,21 @@ def _add_generation_args(parser, verb: str) -> None:
     generator_parser = parser.add_argument_group(
         "Generation", "Configs for generating output based on provided prompt"
     )
-    generator_parser.add_argument(
-        "--prompt",
-        type=str,
-        default="Hello, my name is",
-        help="Input prompt for manual output generation" if verb == "generate" else argparse.SUPPRESS,
-    )
+
+    if verb == "generate":
+        generator_parser.add_argument(
+            "--prompt",
+            type=str,
+            default="Hello, my name is",
+            help="Input prompt for manual output generation",
+        )
+        generator_parser.add_argument(
+            "--num-samples",
+            type=int,
+            default=1,
+            help="Number of samples",
+        )
+
     generator_parser.add_argument(
         "--chat",
         action="store_true",
@@ -292,12 +309,6 @@ def _add_generation_args(parser, verb: str) -> None:
         # help="Whether to use a web UI for an interactive chat session",
         help=argparse.SUPPRESS,
     )
-    generator_parser.add_argument(
-        "--num-samples",
-        type=int,
-        default=1,
-        help="Number of samples" if verb == "generate" else argparse.SUPPRESS,
-    )
     generator_parser.add_argument(
         "--max-new-tokens",
         type=int,
@@ -441,7 +452,7 @@ def arg_init(args):
         # if we specify dtype in quantization recipe, replicate it as args.dtype
         args.dtype = args.quantize.get("precision", {}).get("dtype", args.dtype)

-    if args.output_pte_path:
+    if getattr(args, "output_pte_path", None):
         if args.device not in ["cpu", "fast"]:
             raise RuntimeError("Device not supported by ExecuTorch")
         args.device = "cpu"
@@ -451,12 +462,12 @@ def arg_init(args):
         )

     if "mps" in args.device:
-        if args.compile or args.compile_prefill:
+        if hasattr(args, "compile") and hasattr(args, "compile_prefill"):
             print(
                 "Warning: compilation is not available with device MPS, ignoring option to engage compilation"
             )
-            args.compile = False
-            args.compile_prefill = False
+            vars(args)["compile"] = False
+            vars(args)["compile_prefill"] = False

     if hasattr(args, "seed") and args.seed:
         torch.manual_seed(args.seed)
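The arg_init hunk writes through vars(args) instead of plain attribute
assignment. vars() returns the Namespace's underlying __dict__, so the two
spellings are equivalent for ordinary attributes; a small sketch with
fabricated values:

    import argparse

    args = argparse.Namespace(device="mps", compile=True, compile_prefill=True)

    if "mps" in args.device and hasattr(args, "compile"):
        # Same effect as args.compile = False, written via the __dict__.
        vars(args)["compile"] = False
        vars(args)["compile_prefill"] = False

    print(args.compile, args.compile_prefill)  # prints: False False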
diff --git a/generate.py b/generate.py
index 21d54373c..d38ad80b4 100644
--- a/generate.py
+++ b/generate.py
@@ -103,8 +103,10 @@ def validate_build(

     @classmethod
     def from_args(cls, args):
-        sequential_prefill = (
-            args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
+        dso_path = getattr(args, "dso_path", None)
+        pte_path = getattr(args, "pte_path", None)
+        sequential_prefill = args(
+            args.sequential_prefill or bool(dso_path) or bool(pte_path)
         )

         return cls(
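Note that the generate.py hunk above introduces "sequential_prefill = args("
where a bare "(" was intended; patch 3 at the end of this series corrects
it. The intended logic, sketched standalone with illustrative values around
the attribute names the patch actually uses:

    import argparse

    args = argparse.Namespace(sequential_prefill=False)
    dso_path = getattr(args, "dso_path", None)
    pte_path = getattr(args, "pte_path", None)

    # Fall back to sequential prefill whenever an exported artifact is used.
    sequential_prefill = (
        args.sequential_prefill or bool(dso_path) or bool(pte_path)
    )
    print(sequential_prefill)  # prints: False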
"Model Configuration", "Specify model configurations" @@ -278,6 +280,7 @@ def _configure_artifact_inventory_args(parser, verb: str) -> None: # Add CLI Args specific to user prompted generation +# Include prompt and num_sample args when the subcommand is generate def _add_generation_args(parser, verb: str) -> None: generator_parser = parser.add_argument_group( "Generation", "Configs for generating output based on provided prompt" From 1c77ad1d4733f5e142107c2e57d015e1d32ffbad Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Wed, 31 Jul 2024 17:38:06 -0700 Subject: [PATCH 3/3] Typo in generate.py --- generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generate.py b/generate.py index d38ad80b4..69d817723 100644 --- a/generate.py +++ b/generate.py @@ -105,7 +105,7 @@ def validate_build( def from_args(cls, args): dso_path = getattr(args, "dso_path", None) pte_path = getattr(args, "pte_path", None) - sequential_prefill = args( + sequential_prefill = ( args.sequential_prefill or bool(dso_path) or bool(pte_path) )