diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index dc3d2a6841b..223b068375f 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -229,14 +229,29 @@ def get_default_8bit_qnn_ptq_config( ) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-12} - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) + if act_symmetric: + # If zero_point is 128, htp can do optimizations. + # If we keep quant_min and quant_max none, observer will default use 128 as zero_point. + # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + else: + # PyTorch will remove redundant observers based on attributes such as: + # dtype, quant_min, quant_max, ch_axis, etc. + # Providing values like quant_min and quant_max can help observers compare + # and further reduce the number of observers. + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) weight_quantization_spec = QuantizationSpec( dtype=torch.int8, @@ -409,6 +424,7 @@ def get_ptq_per_channel_quant_config( quant_min=torch.iinfo(act_dtype).min, quant_max=torch.iinfo(act_dtype).max, qscheme=torch.per_tensor_affine, + ch_axis=0, observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), ) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 06225be2d1c..ae5444023a4 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -348,7 +348,9 @@ def histogram(golden, predict): return (pa, mpa, miou, cls_iou) -def get_imagenet_dataset(dataset_path, data_size, image_shape, crop_size=None): +def get_imagenet_dataset( + dataset_path, data_size, image_shape, crop_size=None, shuffle=True +): from torchvision import datasets, transforms def get_data_loader(): @@ -365,7 +367,7 @@ def get_data_loader(): imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) return torch.utils.data.DataLoader( imagenet_data, - shuffle=True, + shuffle=shuffle, ) # prepare input data