From 304308d79a77efa212d77331434c0129ee88b771 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Thu, 14 Nov 2024 18:31:01 -0800 Subject: [PATCH 1/2] Bug fix: Enable fast to override quantize json --- torchchat/cli/cli.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 3a7c85937..09f5b3338 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -533,15 +533,16 @@ def arg_init(args): # Localized import to minimize expensive imports from torchchat.utils.build_utils import get_device_str - if args.device is None or args.device == "fast": + if args.device is None: args.device = get_device_str( args.quantize.get("executor", {}).get("accelerator", default_device) ) else: + args.device = get_device_str(args.device) executor_handler = args.quantize.get("executor", None) if executor_handler: if executor_handler["accelerator"] != args.device: - print('overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}') + print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}') executor_handler["accelerator"] = args.device if "mps" in args.device: From 8db4b728279a1acbbbd0e10d0dc8b0f7f22ed411 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 19 Nov 2024 21:34:34 -0500 Subject: [PATCH 2/2] collapse conditional --- torchchat/cli/cli.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 09f5b3338..a7f7bbba2 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -540,10 +540,9 @@ def arg_init(args): else: args.device = get_device_str(args.device) executor_handler = args.quantize.get("executor", None) - if executor_handler: - if executor_handler["accelerator"] != args.device: - print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}') - executor_handler["accelerator"] = args.device + if executor_handler and executor_handler["accelerator"] != args.device: + print(f'overriding json-specified device {executor_handler["accelerator"]} with cli device {args.device}') + executor_handler["accelerator"] = args.device if "mps" in args.device: if getattr(args, "compile", False) or getattr(args, "compile_prefill", False):