8 changes: 4 additions & 4 deletions backends/xnnpack/operators/node_visitor.py
@@ -232,7 +232,7 @@ def get_per_channel_dtype(
     if quant_params.dtype == torch.int32:
         return XNNDatatype.xnn_datatype_qcint32
     elif quant_params.dtype == torch.int8:
-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             # 4-bit per channel group quantized weights
             # No 8-bit support yet
             assert (
@@ -282,7 +282,7 @@ def get_quant_params(
     buffer_idx = len(xnn_graph.constant_data)
     num_scales = scale.numel()
 
-    if quant_params.is_per_channel_group:
+    if quant_params.per_channel_group:
         scale = scale.to(torch.bfloat16)
 
     num_bytes = scale.untyped_storage().nbytes()
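For context on the hunk above: per-channel-group scales are narrowed to bfloat16 before they are written into the constant data buffer. A minimal standalone sketch of that conversion, with illustrative shapes (not taken from this PR):

import torch

# Hypothetical qb4w linear weight layout: 8 output channels, 64 input
# channels, group size 32 -> scale shape [oc, ic // group_size] = [8, 2].
scale = torch.rand(8, 2, dtype=torch.float32)

# Mirrors `scale = scale.to(torch.bfloat16)` in the hunk above; bfloat16
# halves the serialized size relative to float32.
scale = scale.to(torch.bfloat16)

# Raw byte count handed to the constant buffer (2 bytes per bfloat16 value).
num_bytes = scale.untyped_storage().nbytes()
assert num_bytes == scale.numel() * 2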
@@ -300,7 +300,7 @@
         scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
     )
 
-    if quant_params.is_per_channel_group:
+    if quant_params.per_channel_group:
         return PerChannelGroupQuant(
             scale=[],
             channel_dim=quant_params.axis,
@@ -335,7 +335,7 @@ def _check_per_channel_group_params(
 ) -> None:
     # Make sure things are lining up for per_channel_group quantization case
     # Has to be done this late because we don't have clean access to the actual tensor
-    assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
+    assert quant_params.per_channel_group, "Not per_channel_group quantization"
     # linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0
     num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
     assert (
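The shape invariant that _check_per_channel_group_params enforces is compact enough to state on its own. A sketch with made-up sizes; only the [oc, ic] layout and axis-0 convention come from the code above:

import torch

# Linear weights are [oc, ic] and per-channel quantization is on axis 0,
# so per-channel-group scales carry one value per (row, input group) pair.
oc, ic, group_size = 8, 64, 32
weight = torch.zeros(oc, ic)
scale = torch.zeros(oc, ic // group_size)  # [oc, num_groups]

num_groups = scale.shape[1]
# The groups must exactly tile the weight's input channels.
assert weight.shape[1] == num_groups * group_size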
63 changes: 58 additions & 5 deletions backends/xnnpack/operators/quant_params.py
@@ -89,19 +89,39 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
             ), "Only per channel quantization supports groupwise quantization"
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
-            assert group_size > 0, "Group size must be greater than 0"
-        self.is_per_channel_group = self.per_channel and self.group_size > 0
+            # Assumed scale shape - [out_channels, in_channels/group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2d weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels
+            ), "Invalid input channels for groupwise quant"
+            # Prefer per_channel over per_channel_group when group_size == input_channels for non int4 cases only
+            # int4 case need more fixes to map qb4w to qc4w. Incorrect scales being passed down to xnnpack.
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )
+
+            if not self.per_channel_group:
+                if cast(torch.Tensor, scale).ndim == 2:
+                    # TODO: don't reshape scale for per_channel cases
+                    assert (
+                        cast(torch.Tensor, scale).shape[1] == 1
+                    ), "Invalid scale shape for per channel quantization"
+                    scale = cast(torch.Tensor, scale).squeeze(1)
 
-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -110,6 +130,39 @@ def __init__(
                 tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
             ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"
 
+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
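The new __str__ deliberately prints tensor-valued scale/zp fields as shapes rather than full contents. A small sketch of that summarization logic using stand-in values (not a real QuantParams instance):

import torch

# Stand-ins for QuantParams.scale / QuantParams.zp: either floats or tensors.
scale = torch.rand(64, 4)  # per-channel-group scale tensor
zp = 0.0                   # scalar zero-point

# Same summarization as __str__ above: tensors render as their shape
# instead of dumping every element into the log line.
scale_str = f"{scale}" if isinstance(scale, float) else f"tensor{tuple(scale.shape)}"
zp_str = f"{zp}" if isinstance(zp, float) else f"tensor{tuple(zp.shape)}"

print(scale_str)  # tensor(64, 4)
print(zp_str)     # 0.0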