1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -425,6 +425,7 @@ python_unittest(
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:lib",
+        ":ref_implementations",
     ],
 )
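For context, the only change to this file is the new `:ref_implementations` dep on an existing `python_unittest` target, presumably so the updated fusion tests can exercise the reference kernels. A hedged sketch of the rule shape follows; the target name and srcs are illustrative, only the quoted deps appear in the hunk:

python_unittest(
    name = "test_fusion_ops",  # assumed name, not from the diff
    srcs = ["test_fusion_ops.py"],  # assumed
    deps = [
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:lib",
        ":ref_implementations",  # new: reference kernels used by the tests
    ],
)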
177 changes: 30 additions & 147 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -65,33 +65,18 @@ def get_args_and_kwargs_add(
     dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
 ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
-    X_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[1]),
-        {"dtype": torch.float},
-    )
-    X_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    Y_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[1]),
-        {"dtype": torch.float},
-    )
-    Y_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[2]),
-        {"dtype": torch.int32},
-    )
+    X_scale = dequants_inputs[0].args[1]
+
+    X_zero_point = dequants_inputs[0].args[2]
+    Y_scale = dequants_inputs[1].args[1]
+    Y_zero_point = dequants_inputs[1].args[2]
     args = (
         inputs_inputs[0],
-        X_scale_,
-        X_zero_point_,
+        X_scale,
+        X_zero_point,
         inputs_inputs[1],
-        Y_scale_,
-        Y_zero_point_,
+        Y_scale,
+        Y_zero_point,
         quant_node.args[1],
         quant_node.args[2],
     )
@@ -129,31 +114,12 @@ def get_args_and_kwargs_linear(
     else:
         bias = bias_inputs[0]
 
-    # Create single element tensors for weight_zero_point, out_multiplier, out_shift.
-    # Note that the function expects int32_t, when it would default to int64_t, so
-    # we explicitly require that type.
-    weight_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_weights[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
         "src_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_,
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "weight_zero_point": dequants_weights[0].args[2],
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
         "out_zero_point": quant_node.args[2],
         "offset": None,
     }
@@ -178,22 +144,8 @@ def get_args_and_kwargs_layer_norm(
     ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"
 
     # Make the scale and zero_point tensors
-    scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    scale = dequants_inputs[0].args[1]
+    zero_point = dequants_inputs[0].args[2]
 
     weight = other_inputs[1] if len(other_inputs) > 1 else None
 
@@ -220,7 +172,7 @@ def get_args_and_kwargs_layer_norm(
     )
 
     # Make the args and kwargs for the replacement op
-    args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor])
+    args = tuple(inputs_inputs + [scale, zero_point])
     kwargs = {
         "normalized_shape": other_inputs[0],
         "weight": weight,
@@ -308,31 +260,6 @@ def get_args_and_kwargs_conv(
 
     (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)
 
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the weight zero point
-    weight_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], weight_zero_point),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the bias scale
-    bias_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], bias_scale),
-        {"dtype": torch.float32},
-    )
-
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
@@ -341,12 +268,12 @@
         "dilation": dilation,
         "groups": groups,
         "input_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_tensor,
-        "bias_scale": bias_scale_tensor,
+        "weight_zero_point": weight_zero_point,
+        "bias_scale": bias_scale,
         "out_scale": quant_node.args[1],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

@@ -367,27 +294,11 @@ def get_args_and_kwargs_relu(
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)
 
-    X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     kwargs = {
-        "X_zero_point": X_zero_point,
+        "X_zero_point": dequants_inputs[0].args[2],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

@@ -435,48 +346,20 @@ def get_args_and_kwargs_softmax(
         {"dtype": torch.int32},
     )
     # Make the scale and zero_point tensors
-    in_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    in_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
-    out_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    out_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    in_scale = dequants_inputs[0].args[1]
+    in_zero_point = dequants_inputs[0].args[2]
+    out_scale = quant_node.args[1]
+    out_zero_point = quant_node.args[2]
 
     # Make the args and kwargs for the replacement op
     args = (
         inputs_inputs[0],
         mask_tensor,
         op_node.args[1],
-        in_scale_tensor,
-        in_zero_point_tensor,
-        out_scale_tensor,
-        out_zero_point_tensor,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
     )
     kwargs = {}

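The deletions in this file are all one pattern: each scalar quantization parameter (scale, zero point, out_multiplier, out_shift) used to be materialized as a [1]-shaped constant tensor node in the FX graph, with the dtype pinned explicitly because aten.full infers int64 from an integer fill value while the kernels expect int32; the .per_tensor op variants take plain scalars, so those extra nodes disappear. Below is a minimal sketch of the two calling conventions, assuming only vanilla torch.fx; the scalar values are stand-ins for what the pass reads off the dequantize node's args, not the Cadence pass's actual state.

import torch
from torch import fx

# Stand-ins for dequants_inputs[0].args[1] and dequants_inputs[0].args[2].
scale, zero_point = 0.05, 128

graph = fx.Graph()
x = graph.placeholder("x")

# Old convention: wrap the scalar in a [1]-shaped graph constant. The dtype
# must be forced to int32 because aten.full would otherwise infer int64
# from an integer fill value.
zp_node = graph.call_function(
    torch.ops.aten.full.default, ([1], zero_point), {"dtype": torch.int32}
)

# New convention: the .per_tensor overloads accept the scalar directly, so
# it is threaded straight into args/kwargs with no extra graph nodes.
kwargs = {"X_zero_point": zero_point}

graph.output(x)
print(graph)  # shows the aten.full node that the new convention avoids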
23 changes: 12 additions & 11 deletions backends/cadence/aot/quantizer/patterns.py
@@ -112,7 +112,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
+        return torch.ops.cadence.quantized_linear.per_tensor
 
 
 class AddPattern(QuantizationPattern):
@@ -150,7 +150,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_add.default
+        return torch.ops.cadence.quantized_add.per_tensor
 
 
 class BmmPattern(QuantizationPattern):
@@ -265,7 +265,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 class Conv2dPattern(QuantizationPattern):
@@ -307,7 +307,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 class LayerNormPattern(QuantizationPattern):
@@ -345,7 +345,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_layer_norm.default
+        return torch.ops.cadence.quantized_layer_norm.per_tensor
 
 
 class LinearPattern(QuantizationPattern):
@@ -387,7 +387,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
+        return torch.ops.cadence.quantized_linear.per_tensor
 
 
 class MatmulPattern(QuantizationPattern):
@@ -411,6 +411,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
+        # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op
         return torch.ops.cadence.quantized_matmul.default
 
 
@@ -437,7 +438,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_relu.default
+        return torch.ops.cadence.quantized_relu.per_tensor
 
 
 # Regular relu op
@@ -496,7 +497,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
 
 
 # Conv1d + regular relu op fusion
@@ -544,7 +545,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_softmax.default
+        return torch.ops.cadence.quantized_softmax.per_tensor
 
 
 class MixedW8A32LinearPattern(QuantizationPattern):
@@ -598,7 +599,7 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_w8a32_linear.default
+        return torch.ops.cadence.quantized_w8a32_linear.per_tensor
 
 
 class MixedW8A32ConvPattern(QuantizationPattern):
@@ -660,4 +661,4 @@ def get_anchors(
         )
 
     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_w8a32_conv.default
+        return torch.ops.cadence.quantized_w8a32_conv.per_tensor
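The patterns.py side of the change is mechanical: each replacement_op now resolves the per_tensor overload of its Cadence op instead of default, with quantized_matmul the one holdout per the TODO. A hedged sketch of the overload-resolution mechanism follows, using a toy torch.library namespace since the cadence op library may not be registered in a plain PyTorch build; the op name and schemas are illustrative.

import torch

# A toy namespace standing in for `cadence`.
lib = torch.library.Library("toy_cadence", "DEF")

# Two overloads of one op, mirroring the tensor-arg vs. scalar-arg split:
# .default takes the zero point as a Tensor, .per_tensor as a plain int.
lib.define("quantized_relu(Tensor x, Tensor zero_point) -> Tensor")
lib.define("quantized_relu.per_tensor(Tensor x, int zero_point) -> Tensor")

# replacement_op simply returns one of these OpOverload objects; the
# attribute lookups below are the same resolution torch.ops.cadence.*
# performs in the real patterns.
packet = torch.ops.toy_cadence.quantized_relu
print(packet.default)     # toy_cadence::quantized_relu
print(packet.per_tensor)  # toy_cadence::quantized_relu.per_tensor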