|
| 1 | +import torch |
| 2 | +from .quant_primitives import ( |
| 3 | + _get_reduction_params, |
| 4 | + choose_qparams_affine_with_min_max, |
| 5 | + MappingType, |
| 6 | + ZeroPointDomain, |
| 7 | +) |
| 8 | + |
| 9 | +from abc import ABCMeta, abstractmethod |
| 10 | +from dataclasses import dataclass |
| 11 | +from typing import Callable, List, Tuple, Optional, Any |
| 12 | +from functools import partial |
| 13 | +import logging |
| 14 | +logger = logging.getLogger(__name__) |
| 15 | + |
| 16 | + |
| 17 | +@dataclass(frozen=True) |
| 18 | +class GranularityType: |
| 19 | + pass |
| 20 | + |
| 21 | +@dataclass(frozen=True) |
| 22 | +class PerTensor(GranularityType): |
| 23 | + pass |
| 24 | + |
| 25 | +@dataclass(frozen=True) |
| 26 | +class PerAxis(GranularityType): |
| 27 | + axis: int |
| 28 | + |
| 29 | +# borrowed from torch.ao.quantization.observer |
| 30 | +class _PartialWrapper: |
| 31 | + def __init__(self, p): |
| 32 | + self.p = p |
| 33 | + |
| 34 | + def __call__(self, *args, **keywords): |
| 35 | + return self.p(*args, **keywords) |
| 36 | + |
| 37 | + def __repr__(self): |
| 38 | + return self.p.__repr__() |
| 39 | + |
| 40 | + def with_args(self, *args, **kwargs): |
| 41 | + return _with_args(self, *args, **kwargs) |
| 42 | + |
| 43 | +def _with_args(cls_or_self, *args, **kwargs): |
| 44 | + r"""Wrapper that allows creation of class factories. |
| 45 | +
|
| 46 | + This can be useful when there is a need to create classes with the same |
| 47 | + constructor arguments, but different instances. |
| 48 | +
|
| 49 | + Example:: |
| 50 | +
|
| 51 | + >>> # xdoctest: +SKIP("Undefined vars") |
| 52 | + >>> Foo.with_args = classmethod(_with_args) |
| 53 | + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) |
| 54 | + >>> foo_instance1 = foo_builder() |
| 55 | + >>> foo_instance2 = foo_builder() |
| 56 | + >>> id(foo_instance1) == id(foo_instance2) |
| 57 | + False |
| 58 | + """ |
| 59 | + r = _PartialWrapper(partial(cls_or_self, *args, **kwargs)) |
| 60 | + return r |
| 61 | + |
| 62 | +def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]: |
| 63 | + if isinstance(granularity_type, PerTensor): |
| 64 | + return input_shape |
| 65 | + elif isinstance(granularity_type, PerAxis): |
| 66 | + block_size = list(input_shape) |
| 67 | + block_size[granularity_type.axis] = 1 |
| 68 | + return tuple(block_size) |
| 69 | + raise ValueError(f"Unsupported GranularityType: {granularity_type}") |
| 70 | + |
| 71 | +ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: |
| 72 | + |
| 73 | +class AffineQuantizedObserverBase(ABC, torch.nn.Module): |
| 74 | + """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization) |
| 75 | +
|
| 76 | + Args: |
| 77 | + `granularity_type` and `block_size`: The granularity of the quantization, |
| 78 | + must specify at least one, if both are specified `block_size` takes precedence |
| 79 | + Current supported granularity type are `PerTensor` and `PerAxis` |
| 80 | + other args: please see `:class:torchao.dtypes.AffineQuantizedTensor` |
| 81 | + """ |
| 82 | + with_args = classmethod(_with_args) |
| 83 | + |
| 84 | + def __init__(self, |
| 85 | + mapping_type: MappingType, |
| 86 | + target_dtype: torch.dtype, |
| 87 | + block_size: Optional[Tuple[int, ...]] = None, |
| 88 | + granularity_type: Optional[GranularityType] = None, |
| 89 | + quant_min: Optional[int] = None, |
| 90 | + quant_max: Optional[int] = None, |
| 91 | + eps: Optional[float] = None, |
| 92 | + scale_dtype: Optional[torch.dtype] = None, |
| 93 | + zero_point_dtype: Optional[torch.dtype] = None, |
| 94 | + preserve_zero: bool = True, |
| 95 | + zero_point_domain = ZeroPointDomain.INT, |
| 96 | + ): |
| 97 | + super().__init__() |
| 98 | + assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type" |
| 99 | + if block_size is not None and granularity_type is not None: |
| 100 | + logger.warning("Both block_size and granularity_type are specified, ignoring granularity_type. block_size: {block_size}, granularity_type: {granularity_type}") |
| 101 | + self.mapping_type = mapping_type |
| 102 | + self.target_dtype = target_dtype |
| 103 | + self.block_size = block_size |
| 104 | + self.granularity_type = granularity_type |
| 105 | + self.quant_min = quant_min |
| 106 | + self.quant_max = quant_max |
| 107 | + self.eps = eps |
| 108 | + self.scale_dtype = scale_dtype |
| 109 | + self.zero_point_dtype = zero_point_dtype |
| 110 | + self.preserve_zero = preserve_zero |
| 111 | + self.zero_point_domain = zero_point_domain |
| 112 | + |
| 113 | + @abstractmethod |
| 114 | + def forward(self, input: torch.Tensor) -> torch.Tensor: |
| 115 | + """ forward function should take the input tensor |
| 116 | + and updates internal stats and return the original input Tensor |
| 117 | + """ |
| 118 | + pass |
| 119 | + |
| 120 | + @abstractmethod |
| 121 | + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: |
| 122 | + """Calculate quantization parameter based on the stats attached to the observer module |
| 123 | + and returns a tuple of scale and zero_point Tensor |
| 124 | + """ |
| 125 | + pass |
| 126 | + |
| 127 | +class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase): |
| 128 | + def forward(self, input: torch.Tensor): |
| 129 | + if input.numel() == 0: |
| 130 | + return input |
| 131 | + |
| 132 | + input_detached = input.detach() |
| 133 | + if self.block_size is None: |
| 134 | + self.block_size = get_block_size(input_detached.shape, self.granularity_type) |
| 135 | + |
| 136 | + shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input_detached.size()) |
| 137 | + input_detached = input_detached.view(shape_for_reduction) |
| 138 | + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) |
| 139 | + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) |
| 140 | + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): |
| 141 | + self.min_val = min_val |
| 142 | + self.max_val = max_val |
| 143 | + else: |
| 144 | + min_val = torch.min(self.min_val, min_val) |
| 145 | + max_val = torch.max(self.max_val, max_val) |
| 146 | + self.min_val.copy_(min_val) |
| 147 | + self.max_val.copy_(max_val) |
| 148 | + # returning original input |
| 149 | + return input |
| 150 | + |
| 151 | + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: |
| 152 | + assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" |
| 153 | + return choose_qparams_affine_with_min_max( |
| 154 | + self.min_val, |
| 155 | + self.max_val, |
| 156 | + self.mapping_type, |
| 157 | + self.block_size, |
| 158 | + self.target_dtype, |
| 159 | + self.quant_min, |
| 160 | + self.quant_max, |
| 161 | + self.eps, |
| 162 | + self.scale_dtype, |
| 163 | + self.zero_point_dtype, |
| 164 | + self.preserve_zero, |
| 165 | + self.zero_point_domain |
| 166 | + ) |
0 commit comments