diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index a49d33d0..b9a902ab 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -407,7 +407,9 @@ def get_observer(self) -> str:
 
 
 def round_to_quantized_type_dtype(
-    tensor: torch.Tensor, dtype: torch.dtype
+    tensor: torch.Tensor,
+    dtype: torch.dtype,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
     Rounds an input tensor to the nearest quantized representation given a dtype.
@@ -415,6 +417,8 @@ def round_to_quantized_type_dtype(
 
     :param tensor: tensor to round
     :param dtype: dtype to use for rounding
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
     :return: rounded tensor
     """
     original_dtype = tensor.dtype
@@ -425,7 +429,9 @@ def round_to_quantized_type_dtype(
         iinfo = torch.iinfo(dtype)
         rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max))
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
 
 
 def round_to_quantized_type_args(
@@ -433,6 +439,7 @@ def round_to_quantized_type_args(
     tensor: torch.Tensor,
     args: QuantizationArgs,
     min: torch.Tensor,
     max: torch.Tensor,
+    cast_to_original_dtype: Optional[bool] = True,
 ) -> torch.Tensor:
     """
     Rounds an input tensor to the nearest quantized representation given
@@ -442,6 +449,8 @@ def round_to_quantized_type_args(
     :param tensor: tensor to round
     :param args: quantization args to use for rounding
     :param min: min value to use for clamping
     :param max: max value to use for clamping
+    :param cast_to_original_dtype: whether or not we cast the rounded tensor to
+        the original dtype
     :return: rounded tensor
     """
@@ -459,4 +468,6 @@ def round_to_quantized_type_args(
     else:
         raise ValueError(f"Invalid quantization type {args.type}")
 
-    return rounded.to(original_dtype)
+    if cast_to_original_dtype:
+        return rounded.to(original_dtype)
+    return rounded
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 5e728dd5..45a4ef83 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -127,7 +127,7 @@ def calculate_qparams(
 
     # 5. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
-        zero_points, dtype=quantization_args.zp_dtype
+        zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
     )
 
     if scales.ndim == 0:
diff --git a/tests/test_quantization/lifecycle/test_enabled.py b/tests/test_quantization/lifecycle/test_enabled.py
index 24be64bf..87e25d55 100644
--- a/tests/test_quantization/lifecycle/test_enabled.py
+++ b/tests/test_quantization/lifecycle/test_enabled.py
@@ -26,8 +26,8 @@
 
 
 def test_quantization_enabled_disabled():
-    inp = torch.randn(16, dtype=torch.bfloat16)
-    model = Linear(16, 16, dtype=torch.bfloat16)
+    inp = torch.randn(16)
+    model = Linear(16, 16)
     quantized_model = deepcopy(model)
     apply_quantization_config(
         model=quantized_model,
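
Usage sketch (not part of the patch): how the new cast_to_original_dtype flag behaves, assuming round_to_quantized_type_dtype is importable from the module path shown in the diff; the example tensor and dtype values are illustrative only.

    import torch
    from compressed_tensors.quantization.quant_args import round_to_quantized_type_dtype

    zero_points = torch.randn(8, dtype=torch.bfloat16) * 10

    # Default (unchanged behavior): the rounded values are cast back to the
    # tensor's original dtype, bfloat16 here.
    rounded_bf16 = round_to_quantized_type_dtype(zero_points, dtype=torch.int8)
    assert rounded_bf16.dtype == torch.bfloat16

    # New flag: skip the cast back to the original dtype, which is how
    # calculate_qparams now rounds zero points to zp_dtype in helpers.py.
    rounded_raw = round_to_quantized_type_dtype(
        zero_points, dtype=torch.int8, cast_to_original_dtype=False
    )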