55 changes: 7 additions & 48 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -306,31 +306,6 @@ def get_args_and_kwargs_conv(
 
     (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)
 
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the weight zero point
-    weight_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], weight_zero_point),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the bias scale
-    bias_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], bias_scale),
-        {"dtype": torch.float32},
-    )
-
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
@@ -339,12 +314,12 @@ def get_args_and_kwargs_conv(
         "dilation": dilation,
         "groups": groups,
         "input_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_tensor,
-        "bias_scale": bias_scale_tensor,
+        "weight_zero_point": weight_zero_point,
+        "bias_scale": bias_scale,
         "out_scale": quant_node.args[1],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

@@ -365,27 +340,11 @@ def get_args_and_kwargs_relu(
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)
 
-    X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     kwargs = {
-        "X_zero_point": X_zero_point,
+        "X_zero_point": dequants_inputs[0].args[2],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs
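
Both helpers now feed scalars straight into the fused op's kwargs instead of materializing `aten.full` nodes in the FX graph. A minimal sketch of the before/after, with stand-in tensor values rather than real fusion-pass state:

```python
import torch

# Stand-in values; in the real pass these come from quantize_tensor_multiplier.
out_multiplier = torch.tensor([1073741824], dtype=torch.int32)
out_shift = torch.tensor([0], dtype=torch.int32)

# Before: each scalar became an aten.full graph node, i.e. a [1]-shaped int32
# tensor argument for the .default overload of the fused op:
#   graph_module.graph.call_function(
#       torch.ops.aten.full.default, ([1], out_shift[0].item()), {"dtype": torch.int32})

# After: the .per_tensor overloads accept plain Python ints, so the values are
# read out once at fusion time and embedded directly in the kwargs.
kwargs = {
    "out_multiplier": out_multiplier[0].item(),
    "out_shift": out_shift[0].item(),
}
print(kwargs)  # {'out_multiplier': 1073741824, 'out_shift': 0}
```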

8 changes: 4 additions & 4 deletions backends/cadence/aot/quantizer/patterns.py
@@ -265,7 +265,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 class Conv2dPattern(QuantizationPattern):
@@ -307,7 +307,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 class LayerNormPattern(QuantizationPattern):
@@ -437,7 +437,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_relu.default
+        return torch.ops.cadence.quantized_relu.per_tensor
 
 
 # Regular relu op
@@ -496,7 +496,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 # Conv1d + regular relu op fusion
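
All four patterns now return the `.per_tensor` overload of the same op packet, matching the scalar kwargs built in the fusion pass. For readers unfamiliar with overload selection on `torch.ops`, here is a small sketch using `aten.full` (the `cadence` namespace is only registered inside ExecuTorch, so it is not used here):

```python
import torch

packet = torch.ops.aten.full    # an OpOverloadPacket groups the overloads
overload = packet.default       # attribute access selects one OpOverload
print(type(overload).__name__)  # OpOverload
print(overload._schema)         # full(SymInt[] size, Scalar fill_value, ...) -> Tensor

# torch.ops.cadence.quantized_conv2d_nchw.per_tensor selects its overload the
# same way: the variant whose zero points and requantization params are scalars.
```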
53 changes: 12 additions & 41 deletions backends/cadence/aot/ref_implementations.py
@@ -330,8 +330,8 @@ def variant(
             if out_shift.numel() != 1:
                 raise ValueError("out_shift must be a scalar")
 
-            if out_shift.dtype != torch.int64:
-                raise ValueError("out_shift must be an int64")
+            if out_shift.dtype != torch.int32:
+                raise ValueError("out_shift must be an int32")
 
             _out_shift = int(out_shift.item())
             _out_multiplier = int(out_multiplier[0].item())
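
The dtype check follows the quantizer, which now emits `int32` `out_shift` tensors (see the test updates below). Assuming the usual fixed-point decomposition `scale ≈ out_multiplier / 2**31 * 2**out_shift`, a quick sanity check of a representative pair:

```python
import torch

out_multiplier = torch.tensor([1073741824], dtype=torch.int32)  # 0.5 * 2**31
out_shift = torch.tensor([0], dtype=torch.int32)                # int32 now passes the check

scale = out_multiplier.item() / 2**31 * 2 ** out_shift.item()
assert scale == 0.5
```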
@@ -1125,75 +1125,46 @@ def quantized_relu_common(
 
 
 def quantized_relu_variant(
-    per_tensor: bool,
     dtype: torch.dtype | None = None,
 ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
     """Create a quantized relu variant with type checking."""
 
     def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
         def variant(
             X: torch.Tensor,
-            X_zero_point: torch.Tensor | int,
+            X_zero_point: int,
             out_zero_point: int,
-            out_multiplier: torch.Tensor | int,
-            out_shift: torch.Tensor | int,
+            out_multiplier: int,
+            out_shift: int,
         ) -> torch.Tensor:
-            if per_tensor:
-                if dtype and X.dtype != dtype:
-                    raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
-
-                assert isinstance(out_shift, int)
-                assert isinstance(out_multiplier, int)
-                _out_shift = out_shift
-                _out_multiplier = out_multiplier
-            else:
-                assert isinstance(out_multiplier, torch.Tensor)
-                if out_multiplier.numel() > 1:
-                    raise ValueError("Only scalar out_multiplier is supported")
-
-                assert isinstance(out_shift, torch.Tensor)
-                if out_shift.numel() > 1:
-                    raise ValueError("Only scalar out_shift is supported")
-
-                assert isinstance(X_zero_point, torch.Tensor)
-                if X_zero_point.shape != X.shape:
-                    raise ValueError(
-                        f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
-                    )
-
-                _out_multiplier = int(out_multiplier.item())
-                _out_shift = int(out_shift.item())
+            if dtype and X.dtype != dtype:
+                raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")
 
             return quantized_relu_common(
                 X,
                 X_zero_point,
                 out_zero_point,
-                _out_multiplier,
-                _out_shift,
+                out_multiplier,
+                out_shift,
             )
 
         return variant
 
     return decorator
 
 
-@impl(m, "quantized_relu")
-@quantized_relu_variant(False)
-def quantized_relu() -> torch.Tensor: ...
-
-
 @impl(m, "quantized_relu.per_tensor")
-@quantized_relu_variant(True)
+@quantized_relu_variant()
 def quantized_relu_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
-@quantized_relu_variant(True, torch.int8)
+@quantized_relu_variant(torch.int8)
 def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
 @impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
-@quantized_relu_variant(True, torch.uint8)
+@quantized_relu_variant(torch.uint8)
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
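With the tensor-argument branch removed, every registration goes through the scalar per-tensor path. As a rough, illustrative model of what such a per-tensor quantized relu computes — assuming dequantize→relu→requantize semantics and ignoring rounding details, so this is not a drop-in for `quantized_relu_common`:

```python
import torch

def quantized_relu_sketch(
    X: torch.Tensor,
    X_zero_point: int,
    out_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    # ReLU in the quantized domain: inputs at or below the input zero point
    # collapse to the output zero point.
    acc = (X.to(torch.int64) - X_zero_point).clamp(min=0)
    # Fixed-point requantize, assuming scale ~= out_multiplier * 2**out_shift / 2**31.
    acc = (acc * out_multiplier * 2**out_shift) >> 31
    info = torch.iinfo(X.dtype)
    return (acc + out_zero_point).clamp(info.min, info.max).to(X.dtype)

x = torch.tensor([-10, 0, 64], dtype=torch.int8)
print(quantized_relu_sketch(x, 0, 0, 1 << 30, 0))  # tensor([ 0,  0, 32], dtype=torch.int8)
```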
2 changes: 1 addition & 1 deletion backends/cadence/aot/replace_ops.py
@@ -1644,7 +1644,7 @@ def call_operator(self, op, args, kwargs, meta):
         ceil_mode = args[4] if len(args) >= 5 else False
         count_include_pad = args[5] if len(args) >= 6 else True
         divisor_override = args[6] if len(args) >= 7 else None
-        zero_point = torch.tensor(0, dtype=torch.int32)
+        zero_point = args[7] if len(args) >= 8 else None
 
         # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d
         # tensor.
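
The zero point is now read from the optional eighth positional argument instead of being hard-coded to a zero tensor, so callers that omit it get `None`. A tiny illustration of the unpacking idiom, with a made-up argument tuple:

```python
def unpack_zero_point(args: tuple):
    # Previously: zero_point = torch.tensor(0, dtype=torch.int32), always.
    # Now it is only present when the caller actually passed it.
    return args[7] if len(args) >= 8 else None

# (input, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override[, zero_point])
print(unpack_zero_point(("x", [2], [1], [0], False, True, None)))     # None
print(unpack_zero_point(("x", [2], [1], [0], False, True, None, 3)))  # 3
```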
16 changes: 8 additions & 8 deletions backends/cadence/aot/tests/test_ref_implementations.py
@@ -172,7 +172,7 @@ def test_quantized_add(
             torch.tensor(
                 [1073741824], dtype=torch.int32
             ),  # out_multiplier (0.5 * 2^31)
-            torch.tensor([0], dtype=torch.int64),  # out_shift
+            torch.tensor([0], dtype=torch.int32),  # out_shift
             0,  # out_zero_point
             torch.tensor([[0]], dtype=dtype),  # expected_output
             per_tensor,
@@ -197,7 +197,7 @@ def test_quantized_add(
             torch.tensor(
                 [1073741824], dtype=torch.int32
             ),  # out_multiplier (0.5 * 2^31)
-            torch.tensor([0], dtype=torch.int64),  # out_shift
+            torch.tensor([0], dtype=torch.int32),  # out_shift
             0,  # out_zero_point
             torch.tensor([[-2, -8]], dtype=dtype),  # expected_output
             per_tensor,
@@ -220,7 +220,7 @@ def test_quantized_add(
             torch.tensor(
                 [1073741824], dtype=torch.int32
             ),  # out_multiplier (0.5 * 2^31)
-            torch.tensor([0], dtype=torch.int64),  # out_shift
+            torch.tensor([0], dtype=torch.int32),  # out_shift
             0,  # out_zero_point
             torch.tensor([[0, 0]], dtype=dtype),  # expected_output
             per_tensor,
@@ -244,7 +244,7 @@ def test_quantized_add(
             torch.tensor(
                 [1073741824], dtype=torch.int32
             ),  # out_multiplier (0.5 * 2^31)
-            torch.tensor([0], dtype=torch.int64),  # out_shift
+            torch.tensor([0], dtype=torch.int32),  # out_shift
             0,  # out_zero_point
             torch.tensor(
                 [[[0, -2, -4], [-2, -7, -12]]], dtype=dtype
@@ -270,7 +270,7 @@ def test_quantized_add(
             torch.tensor(
                 [268435456], dtype=torch.int32
             ),  # out_multiplier (0.125 * 2^31)
-            torch.tensor([0], dtype=torch.int64),  # out_shift
+            torch.tensor([0], dtype=torch.int32),  # out_shift
             1,  # out_zero_point
             torch.tensor([[1, 1]], dtype=dtype),  # expected_output
             per_tensor,
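
The multiplier annotations decode as fractions of `2**31`, which is worth checking directly:

```python
assert 1073741824 == int(0.5 * 2**31)
assert 268435456 == int(0.125 * 2**31)
```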
@@ -295,7 +295,7 @@ def test_quantized_add(
             torch.tensor(
                 [268435456], dtype=torch.int32
             ),  # out_multiplier (0.125 * 2^31)
-            torch.tensor([0], dtype=torch.int64),  # out_shift
+            torch.tensor([0], dtype=torch.int32),  # out_shift
             1,  # out_zero_point
             torch.tensor([[1, 1]], dtype=dtype),  # expected_output
             False,
@@ -317,7 +317,7 @@ def test_quantized_add(
                 [268435456], dtype=torch.int32
             ),  # out_multiplier (0.125 * 2^31)
             torch.tensor(
-                [1], dtype=torch.int64
+                [1], dtype=torch.int32
             ),  # out_shift (shift=1, doubles the scale)
             1,  # out_zero_point
             torch.tensor([[1, 2]], dtype=dtype),  # expected_output
@@ -339,7 +339,7 @@ def test_quantized_add(
                 [268435456], dtype=torch.int32
             ),  # out_multiplier (0.125 * 2^31)
             torch.tensor(
-                [1], dtype=torch.int64
+                [1], dtype=torch.int32
             ),  # out_shift (shift=1, doubles the scale)
             1,  # out_zero_point
             torch.tensor([[1, 2]], dtype=dtype),  # expected_output
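
In the last two cases the positive shift doubles the effective scale, as the inline comment says; under the same assumed decomposition:

```python
multiplier, shift = 268435456, 1
assert multiplier / 2**31 * 2**shift == 0.25  # 0.125 doubled by shift=1
```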