62 changes: 48 additions & 14 deletions test/quantization/pt2e/test_x86inductor_fusion.py
@@ -139,10 +139,22 @@ def forward(self, input):


class FP8QDQConv2d(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
super().__init__()
self.qtype = torch.float8_e4m3fn
self.weight = torch.randn((out_channels, in_channels // groups, *kernel_size)).to(self.qtype)
self.weight = torch.randn(
(out_channels, in_channels // groups, *kernel_size)
).to(self.qtype)
self.weight_scale = 2.0
self.scale = 2.0
self.bias = None
@@ -170,7 +182,16 @@ def forward(self, input):
output_dtype=torch.float,
)

return torch.nn.functional.conv2d(dq_input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(
dq_input,
weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)


def qdq(input, scale):
dtype = input.dtype
Expand Down Expand Up @@ -205,9 +226,7 @@ def create_mod_info_recursion(parent):
parent_child_mod_dict = generate_model_info(model)
for name, mod in model.named_modules():
mod_type_str = mod.__class__.__name__
if mod_type_str not in [
"Linear", "Conv2d"
]:
if mod_type_str not in ["Linear", "Conv2d"]:
continue
param = mod.weight
xmax = torch.max(param)
@@ -225,7 +244,16 @@ def create_mod_info_recursion(parent):
patched_mod.weight_scale = weight_scale.item()
patched_mod.weight.data = q_param
elif mod_type_str in ["Conv2d"]:
patched_mod = FP8QDQConv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, False)
patched_mod = FP8QDQConv2d(
mod.in_channels,
mod.out_channels,
mod.kernel_size,
mod.stride,
mod.padding,
mod.dilation,
mod.groups,
False,
)
patched_mod.bias = mod.bias
patched_mod.weight_scale = weight_scale.item()
patched_mod.weight.data = q_param
@@ -610,7 +638,9 @@ def test_qconv2d_relu6_fp8_cpu(self):
r"""
This testcase will quantize Conv2d->ReLU6 pattern.
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True)
self._qconv2d_unary_test_helper(
device="cpu", unary_op=torch.nn.ReLU6(), is_fp8=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@@ -627,7 +657,9 @@ def test_qconv2d_hardtanh_fp8_cpu(self):
r"""
This testcase will quantize Conv2d->Hardtanh pattern.
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True)
self._qconv2d_unary_test_helper(
device="cpu", unary_op=torch.nn.Hardtanh(), is_fp8=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@@ -678,7 +710,9 @@ def test_qconv2d_hardswish_fp8_cpu(self):
r"""
This testcase will quantize Conv2d->Hardswish pattern.
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True)
self._qconv2d_unary_test_helper(
device="cpu", unary_op=torch.nn.Hardswish(), is_fp8=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@@ -731,7 +765,9 @@ def test_qconv2d_silu_fp8_cpu(self):
r"""
This testcase will quantize Conv2d->SiLU pattern.
"""
self._qconv2d_unary_test_helper(device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True)
self._qconv2d_unary_test_helper(
device="cpu", unary_op=torch.nn.SiLU(), is_fp8=True
)

@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@@ -911,9 +947,7 @@ def forward(self, x, x2, x3):
add_fn_list = quantization_add_fn_list
if not is_fp8:
add_fn_list = add_fn_list + quantization_inplace_add_fn_list
for add_fn, swap_inputs in itertools.product(
add_fn_list, [False, True]
):
for add_fn, swap_inputs in itertools.product(add_fn_list, [False, True]):
mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
x = torch.randn(
(1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
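For context, the FP8QDQConv2d module in the test diff above emulates FP8 quantization with a quantize-dequantize (QDQ) round trip followed by a plain conv2d. Below is a minimal, hedged sketch of that round trip using only core torch casts (it does not call the torchao dequantize ops the test actually exercises); the helper name fp8_qdq and the tensor shapes are illustrative, not from the PR, and the sketch assumes a PyTorch build with float8_e4m3fn support.

import torch

def fp8_qdq(x, scale, qtype=torch.float8_e4m3fn):
    # Quantize to float8 with a per-tensor scale, then immediately dequantize
    # back to float32 -- the numeric round trip applied to the conv input.
    q = (x / scale).to(qtype)
    return q.to(torch.float32) * scale

# Mirror FP8QDQConv2d.forward: dequantize a randomly initialized fp8 weight and
# the QDQ'd input, then run a regular conv2d. The 2.0 scales follow the
# hard-coded scale and weight_scale values in the test module.
weight_fp8 = torch.randn(8, 3, 3, 3).to(torch.float8_e4m3fn)
dq_weight = weight_fp8.to(torch.float32) * 2.0
dq_input = fp8_qdq(torch.randn(1, 3, 8, 8), scale=2.0)
out = torch.nn.functional.conv2d(dq_input, dq_weight, None, 1, 0, 1, 1)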
60 changes: 44 additions & 16 deletions torchao/quantization/pt2e/inductor_passes/x86.py
@@ -174,21 +174,26 @@ def get_dequantize_per_tensor_activation_pattern(
output_dtype=KeywordArg("w_dtype"),
)


def get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern):
return _may_generate_pattern_with_dtype_convert(
dequant_wgt_pattern,
KeywordArg("autocast_wgt_dtype"),
)


def get_dequantize_clone_weight_pattern(dequant_wgt_pattern):
return CallFunction(
aten.clone.default,
dequant_wgt_pattern,
memory_format=KeywordArg("memory_format"),
)


def get_dequantize_to_bf16_clone_weight_pattern(dequant_wgt_pattern):
return get_dequantize_clone_weight_pattern(get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern))
return get_dequantize_clone_weight_pattern(
get_dequantize_to_bf16_weight_pattern(dequant_wgt_pattern)
)


def get_qconv_pt2e_pattern(x_scale_zp_are_tensors=False, users=1):
@@ -450,14 +455,18 @@ def fn(match):
break
assert extra_input_of_binary_node is not None
# Extra input of binary node comes from dequant pattern
if not is_fp8 and extra_input_from_dequant and (
(not isinstance(extra_input_of_binary_node, torch.fx.Node))
or (
extra_input_of_binary_node.target
not in [
quantized_decomposed.dequantize_per_tensor.default,
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
]
if (
not is_fp8
and extra_input_from_dequant
and (
(not isinstance(extra_input_of_binary_node, torch.fx.Node))
or (
extra_input_of_binary_node.target
not in [
quantized_decomposed.dequantize_per_tensor.default,
torch.ops.torchao.dequantize_affine_float8_non_decomposed.default,
]
)
)
):
return False
@@ -692,7 +701,9 @@ def _inner(match):
return _inner


def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32, is_fp8=False):
def _register_qconv_weight_prepack_pass(
pattern, pass_number, dtype=torch.float32, is_fp8=False
):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_dequant_conv_pattern(dtype),
@@ -776,7 +787,10 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
if is_fp8:
# For float8, we assume the scales are from aten.full.default instead of
# a constant buffer to avoid constant folding of q/dq before fusion passes.
assert w_scale.target is torch.ops.aten.full.default and x_scale.target is torch.ops.aten.full.default
assert (
w_scale.target is torch.ops.aten.full.default
and x_scale.target is torch.ops.aten.full.default
)
with torch.utils._python_dispatch._disable_current_modes():
w_scale_tensor = torch.tensor([w_scale.args[1]])
match.graph.owning_module.register_buffer("w_scale", w_scale_tensor)
@@ -1446,8 +1460,12 @@ def _register_dequant_promotion():


def _register_qconv_weight_prepack():
for dtype, is_fp8 in itertools.product([torch.float32, torch.bfloat16], [True, False]):
weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype, is_fp8=is_fp8)
for dtype, is_fp8 in itertools.product(
[torch.float32, torch.bfloat16], [True, False]
):
weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(
dtype, is_fp8=is_fp8
)
for weight_prepack_pattern in weight_prepack_patterns:
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
_register_qconv_weight_prepack_pass(
@@ -2050,7 +2068,13 @@ def qconv(match: Match, *args, **kwargs):
kwargs["groups"],
)
output_dtype = _get_pattern_output_dtype(match)
assert output_dtype in [torch.int8, torch.uint8, torch.float8_e4m3fn, torch.float32, torch.bfloat16]
assert output_dtype in [
torch.int8,
torch.uint8,
torch.float8_e4m3fn,
torch.float32,
torch.bfloat16,
]
# Output QParams
if output_dtype == torch.float8_e4m3fn:
# For float8, we assume the scale is from aten.full.default instead of
@@ -2297,7 +2321,9 @@ def _register_qconv_unary_fusion():


def _register_qconv_binary_fusion():
for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product([False, True], [False, True]):
for int8_mixed_bf16_with_inplace_add, x_scale_zp_are_tensors in itertools.product(
[False, True], [False, True]
):
qconv_binary_op = (
torch.ops.onednn.qconv2d_pointwise.binary_tensor
if x_scale_zp_are_tensors
@@ -2306,7 +2332,9 @@ def _register_qconv_binary_fusion():
# Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
swap_binary_inputs_list = [False, True]
binary_replace_patterns = {}
for swap_inputs, is_fp8 in itertools.product(swap_binary_inputs_list, [False, True]):
for swap_inputs, is_fp8 in itertools.product(
swap_binary_inputs_list, [False, True]
):
binary_replace_patterns.update(
{
PostOpAttr(
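One note on the condition reformatted inside fn(match) in the hunk at -450 above: an int8 match that expects its extra add input to come from a dequantize node is rejected unless that input is an fx.Node whose target is one of the known dequant ops, while fp8 matches skip the check entirely. A hedged standalone sketch of that predicate follows; the function name and the dequant_targets parameter are illustrative (in the pass itself the targets are quantized_decomposed.dequantize_per_tensor.default and torch.ops.torchao.dequantize_affine_float8_non_decomposed.default).

import torch.fx

def extra_input_check_ok(node, is_fp8, extra_input_from_dequant, dequant_targets):
    # fp8 patterns never require the extra binary input to be a dequant node,
    # and neither do patterns built without a dequant extra input.
    if is_fp8 or not extra_input_from_dequant:
        return True
    # int8 path: accept only an fx.Node produced by a known dequantize op.
    return isinstance(node, torch.fx.Node) and node.target in dequant_targets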