From f6e99d3c8c5970fb22252dd562bd7ffd39f65ece Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 21 Dec 2023 10:10:54 -0800 Subject: [PATCH] [quant][pt2e] Relax constraints on dtype and qscheme to allow for customizations Summary: As the title says: remove the SUPPORTED_DTYPES and SUPPORTED_QSCHEMES allow-lists and the QuantizationSpec.__post_init__ checks that used them, so that custom dtypes and qschemes are no longer rejected. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torch/ao/quantization/quantizer/quantizer.py | 26 -------------------- 1 file changed, 26 deletions(-) diff --git a/torch/ao/quantization/quantizer/quantizer.py b/torch/ao/quantization/quantizer/quantizer.py index da9a66337cbe..21b03c294405 100644 --- a/torch/ao/quantization/quantizer/quantizer.py +++ b/torch/ao/quantization/quantizer/quantizer.py @@ -19,24 +19,6 @@ "QuantizationAnnotation", ] -# TODO: maybe remove torch.float32 -SUPPORTED_DTYPES = [ - torch.uint8, - torch.int8, - torch.int16, - torch.int32, - torch.float16, - torch.float32, -] -SUPPORTED_QSCHEMES = [ - torch.per_tensor_affine, - torch.per_tensor_symmetric, - torch.per_channel_affine, - torch.per_channel_symmetric, - torch.per_channel_affine_float_qparams, -] - - class QuantizationSpecBase(ABC): # noqa: B024 """Base class for different types of quantization specs that allows users to specify how to quantize a Tensor (input/output of a Node) in the model @@ -64,10 +46,6 @@ class QuantizationSpec(QuantizationSpecBase): is_dynamic: bool = False def __post_init__(self): - # check dtype is one of the supported types - if self.dtype not in SUPPORTED_DTYPES: - raise TypeError(f"Unsupported dtype {self.dtype}.") - # quant_min must be less than quant_max if ( self.quant_min is not None @@ -78,10 +56,6 @@ def __post_init__(self): f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}." ) - # check qscheme is on of the supported ones - if self.qscheme is not None and self.qscheme not in SUPPORTED_QSCHEMES: - raise ValueError(f"Unsupported qscheme {self.qscheme}.") - # ch_axis must be less than the number of channels # but no way to check here. Just check that it is not < 0. 
if self.ch_axis is not None and self.ch_axis < 0: