From 4e2bee82e76a569eb5bfeddec8670b7229c72972 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Sun, 29 Sep 2024 14:39:50 -0700 Subject: [PATCH 1/2] Show a8wxdq load error only when the quant is used --- torchchat/utils/quantize.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 77b03fcba..28567afab 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -51,6 +51,9 @@ ) +# Flag for whether the a8wxdq quantizer is available. +a8wxdq_loaded = False + ######################################################################### ### torchchat quantization API ### @@ -97,6 +100,9 @@ def quantize_model( try: if quantizer == "linear:a8wxdq": + if not a8wxdq_loaded: + raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {e}") + quant_handler = ao_quantizer_class_dict[quantizer]( device=device, precision=precision, @@ -898,5 +904,8 @@ def quantized_model(self) -> nn.Module: print("Failed to torchao ops library with error: ", e) print("Slow fallback kernels will be used.") + # Mark the Quant option as available + a8wxdq_loaded = True + except Exception as e: - print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}") + pass From 6e1cdd5369bd8de886acde2058582db9ccf1eee2 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Sun, 29 Sep 2024 14:52:31 -0700 Subject: [PATCH 2/2] Update Error check --- torchchat/utils/quantize.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 28567afab..abca48d25 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -52,7 +52,7 @@ # Flag for whether the a8wxdq quantizer is available. -a8wxdq_loaded = False +a8wxdq_load_error: Optional[Exception] = None ######################################################################### ### torchchat quantization API ### @@ -79,6 +79,10 @@ def quantize_model( quantize_options = json.loads(quantize_options) for quantizer, q_kwargs in quantize_options.items(): + # Test if a8wxdq quantizer is available; Surface error if not. + if quantizer == "linear:a8wxdq" and a8wxdq_load_error is not None: + raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}") + if ( quantizer not in quantizer_class_dict and quantizer not in ao_quantizer_class_dict @@ -100,9 +104,6 @@ def quantize_model( try: if quantizer == "linear:a8wxdq": - if not a8wxdq_loaded: - raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {e}") - quant_handler = ao_quantizer_class_dict[quantizer]( device=device, precision=precision, @@ -904,8 +905,5 @@ def quantized_model(self) -> nn.Module: print("Failed to torchao ops library with error: ", e) print("Slow fallback kernels will be used.") - # Mark the Quant option as available - a8wxdq_loaded = True - except Exception as e: - pass + a8wxdq_load_error = e