Enable oneDNN QLinear FP32/BF16 output (pytorch#112126)
**Summary**
- PR 2 for enabling Int8-Mixed-BF16 PT2E PTQ Quantization with Inductor pytorch#111640.
- Enable QLinear (relu) with BFloat16 or Float32 output.

**Test Plan**
```
python -u -m pytest -s -v test_quantized_op.py -k test_qlinear_pt2e
```

Pull Request resolved: pytorch#112126
Approved by: https://github.com/jerryzh168, https://github.com/jgong5
ghstack dependencies: pytorch#112010
leslie-fang-intel authored and xuhancn committed Nov 8, 2023
1 parent 84e4069 commit 9aa0962
Showing 7 changed files with 60 additions and 41 deletions.
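
To make the summary above concrete, here is a minimal, illustrative sketch of the intended numerics. It is not code from this PR: `qlinear_pointwise_ref` is a hypothetical helper, it ignores per-channel weight scales and oneDNN rounding details, and the fp32/bf16 branch reflects the behavior this commit enables (the result is returned without requantization, with `output_scale=1` and `output_zero_point=0`).

```python
# Hypothetical reference semantics for onednn::qlinear_pointwise after this PR.
# Not the PR's code; per-channel weight scales and rounding details are ignored.
import torch

def qlinear_pointwise_ref(qx, qw, bias, output_scale, output_zero_point,
                          output_dtype=None, post_op="none"):
    # int8 x int8 linear, written here via dequantize for clarity
    y = torch.nn.functional.linear(qx.dequantize(), qw.dequantize(), bias)
    if post_op == "relu":
        y = torch.relu(y)
    if output_dtype in (torch.float32, torch.bfloat16):
        # New in this PR: return the result directly in fp32 / bf16
        # (the op requires output_scale == 1 and output_zero_point == 0 here).
        return y.to(output_dtype)
    # Default path: requantize to quint8 with the given scale / zero point.
    return torch.quantize_per_tensor(y, output_scale, output_zero_point, torch.quint8)
```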
23 changes: 13 additions & 10 deletions aten/src/ATen/native/quantized/cpu/qlinear.cpp
@@ -911,7 +911,7 @@ static at::Tensor linear_int8_with_onednn_weight(
     c10::optional<at::Tensor> bias, // plain tensor
     double output_scale,
     int64_t output_zero_point,
-    bool fp32_output,
+    c10::optional<c10::ScalarType> output_dtype,
     std::string& post_op_name, // e.g. "none", "relu"
     torch::List<c10::optional<at::Scalar>>& post_op_args,
     std::string& post_op_algorithm) {
@@ -924,7 +924,9 @@ static at::Tensor linear_int8_with_onednn_weight(
       "qlinear with mkldnn tensor: data type of weight should be int8 (char).");
   TORCH_CHECK(
       weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float.");
-  if (fp32_output) {
+  bool fp32_output = output_dtype.has_value() && (output_dtype.value() == c10::kFloat);
+  bool bfloat16_output = output_dtype.has_value() && (output_dtype.value() == c10::kBFloat16);
+  if (fp32_output || bfloat16_output) {
     TORCH_CHECK(
         output_scale == 1.0f && output_zero_point == 0, "onednn qlinear: expect scale=1 and zero point=0 for fp32 output");
   }
@@ -935,21 +937,22 @@ static at::Tensor linear_int8_with_onednn_weight(
   int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1);
   c10::optional<ideep::tensor> onednn_bias{c10::nullopt};
   bool with_bias = bias.has_value();
+  at::Tensor bias_val_float;
   if (with_bias) {
-    if (bias.value().dim() == 1) {
-      auto b_reshape = bias.value().reshape({1, bias.value().size(0)});
+    bias_val_float = bias.value().to(at::kFloat);
+    if (bias_val_float.dim() == 1) {
+      auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)});
       onednn_bias = at::native::itensor_view_from_dense(b_reshape);
     } else {
-      onednn_bias = at::native::itensor_view_from_dense(bias.value());
+      onednn_bias = at::native::itensor_view_from_dense(bias_val_float);
     }
   }
   std::vector<int64_t> src_dims = {M, K};
   std::vector<int64_t> dst_dims = {M, N};
-  auto output_dtype = fp32_output ? c10::kFloat : c10::kByte;
   at::Tensor output = at::empty(
     dst_dims,
     device(c10::kCPU)
-        .dtype(output_dtype)
+        .dtype(fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : c10::kByte))
   );
   if (output.numel() == 0) {
     return output;
@@ -959,7 +962,7 @@
   // Create onednn primitive
   auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
   auto weights_desc = packed_weight.get_desc();
-  auto dst_dtype = fp32_output ? ideep::data_type::f32 : ideep::data_type::u8;
+  auto dst_dtype = fp32_output ? ideep::data_type::f32 : (bfloat16_output ? ideep::tensor::data_type::bf16 : ideep::data_type::u8);
   auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
   auto bias_desc = with_bias ?
       tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) :
@@ -1117,15 +1120,15 @@ class QLinearOnednn final {
       c10::optional<Tensor> bias,
       double output_scale,
       int64_t output_zero_point,
-      bool fp32_output,
+      c10::optional<c10::ScalarType> output_dtype,
       std::string post_op_name,
       torch::List<c10::optional<at::Scalar>> post_op_args,
       std::string post_op_algorithm) {
 #if AT_MKLDNN_ENABLED()
     return linear_int8_with_onednn_weight(
         act, act_scale, act_zero_point,
         onednn_weight, weight_scales, weight_zero_points,
-        bias, output_scale, output_zero_point, fp32_output,
+        bias, output_scale, output_zero_point, output_dtype,
         post_op_name, post_op_args, post_op_algorithm
     );
 #endif
3 changes: 1 addition & 2 deletions aten/src/ATen/native/quantized/library.cpp
@@ -269,6 +269,5 @@ TORCH_LIBRARY(onednn, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor"));

   // Linear with unary postop
-  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, bool fp32_output, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
-
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
 }
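
For reference, a minimal sketch of exercising the updated schema from Python, modeled on the test change below. The shapes, scales, and the dtypes of the per-tensor weight scale/zero-point tensors are assumptions rather than values from this PR; note that for the quantized (quint8) path the updated test passes `1.0 / y_scale` as `output_scale`, while the fp32/bf16 path requires scale 1.0 and zero point 0.

```python
# Sketch only: assumed shapes/scales; mirrors the updated test_qlinear_pt2e below.
import torch

qlinear_prepack = torch.ops.onednn.qlinear_prepack
qlinear = torch.ops.onednn.qlinear_pointwise

ic, oc = 4, 16
x_scale, x_zp = 1.2, 1
w_scale, w_zp = 0.8, 0

x = torch.rand(2, ic) * 10
w = torch.rand(oc, ic) * 10
b = torch.rand(oc) * 10
qx = torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8)
qw = torch.quantize_per_tensor(w, w_scale, w_zp, torch.qint8)

# Pack the plain int8 weight for oneDNN.
qw_packed = qlinear_prepack(qw.int_repr(), x.shape)

# Run with a BFloat16 output (new in this PR): scale must be 1.0, zero point 0.
y_bf16 = qlinear(
    qx.int_repr(), x_scale, x_zp, qw_packed,
    torch.tensor([w_scale]),                  # w_scale (assumed per-tensor, float32)
    torch.tensor([w_zp], dtype=torch.int64),  # w_zero_point (assumed int64)
    b,
    1.0, 0,                                   # output_scale, output_zero_point
    torch.bfloat16,                           # output_dtype: None / float32 / bfloat16
    "relu", [], "",                           # post_op_name, post_op_args, post_op_algorithm
)
assert y_bf16.dtype == torch.bfloat16
```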
39 changes: 27 additions & 12 deletions test/quantization/core/test_quantized_op.py
@@ -4161,18 +4161,22 @@ def test_qlinear_pt2e(self):
         use_bias_list = [True, False]
         supported_post_ops = ['none', 'relu']
         weight_quant_per_channel_list = [True, False]
-        fp32_output_list = [True, False]
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
         x_scale, x_zp = 1.2, 1
         w_scale, w_zp = 0.8, 0
         y_scale, y_zp = 4.7, 2
         post_op_args = []
         cases = itertools.product(
             in_channels_list, out_channels_list, use_bias_list,
-            supported_post_ops, weight_quant_per_channel_list, fp32_output_list)
+            supported_post_ops, weight_quant_per_channel_list, output_dtype_list)
         with override_quantized_engine('onednn'):
-            for ic, oc, use_bias, post_op, weight_quant_per_channel, fp32_out in cases:
-                if fp32_out:
-                    y_scale, y_zp = 1.0, 0
+            for ic, oc, use_bias, post_op, weight_quant_per_channel, output_dtype in cases:
+                used_y_scale = y_scale
+                used_y_zp = y_zp
+                fp32_out = output_dtype == torch.float32
+                bfloat16_out = output_dtype == torch.bfloat16
+                if fp32_out or bfloat16_out:
+                    used_y_scale, used_y_zp = 1.0, 0
                 x = torch.rand(batch_size, ic) * 10
                 w = torch.rand(oc, ic) * 10
                 qx = torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8)
@@ -4194,19 +4198,30 @@ def test_qlinear_pt2e(self):
                 qw_cpu = qw.int_repr()
                 qw_packed = qlinear_prepack(qw_cpu, x.shape)
                 qy_cpu = qlinear(qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
-                                 b, y_scale, y_zp, fp32_out, post_op, post_op_args, post_op_algorithm)
+                                 b, 1.0 / used_y_scale, used_y_zp, output_dtype, post_op, post_op_args, post_op_algorithm)

                 # Reference
                 qw_packed_ref = qlinear_prepack_ref(qw, b)
                 qlinear_ref = post_op_to_qlinear_ref_dict[post_op]
-                qy_ref = qlinear_ref(qx, qw_packed_ref, y_scale, y_zp)
+                qy_ref = qlinear_ref(qx, qw_packed_ref, used_y_scale, used_y_zp)

                 # Compare results
-                if fp32_out:
-                    qy_cpu = torch.quantize_per_tensor(qy_cpu, y_scale, y_zp, dtype=torch.quint8).dequantize()
-                    self.assertEqual(qy_cpu, qy_ref.dequantize(), "Results not equal!")
-                else:
-                    self.assertEqual(qy_cpu, qy_ref.int_repr(), "Results not equal!")
+                if fp32_out or bfloat16_out:
+                    qy_cpu = torch.quantize_per_tensor(
+                        qy_cpu.to(torch.float32),
+                        used_y_scale,
+                        used_y_zp, dtype=torch.quint8
+                    ).int_repr()
+
+                np.testing.assert_array_almost_equal(
+                    qy_ref.int_repr().cpu().numpy(),
+                    qy_cpu.cpu().numpy(),
+                    decimal=0,
+                    err_msg=f"""X: {x}, W: {w}, b: {b},
+                    x_s: {x_scale}, x_zp: {x_zp},
+                    w_s: {w_scale}, w_zp: {w_zp},
+                    y_s: {y_scale}, y_zp: {y_zp}""",
+                )

 @unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
 class TestQuantizedEmbeddingOps(TestCase):
12 changes: 6 additions & 6 deletions torch/_inductor/fx_passes/quantization.py
@@ -86,7 +86,7 @@
     KeywordArg("b"),
     KeywordArg("output_scale"),
     KeywordArg("output_zero_point"),
-    KeywordArg("fp32_output"),
+    KeywordArg("output_dtype"),
     KeywordArg("postop_name"),
     KeywordArg("postop_args"),
     KeywordArg("postop_algorithm"),
@@ -228,7 +228,7 @@ def _register_quantized_linear_lowering(
     pattern,
     pass_number,
     computation_op,
-    fp32_output,
+    output_dtype,
     unary_attr,
 ):
     @register_lowering_pattern(pattern, pass_number=pass_number)
@@ -255,7 +255,7 @@ def qlinear(match: Match, *args, **kwargs):
             kwargs["o_zp"],
         )
         assert (
-            kwargs["fp32_output"] is True
+            kwargs["output_dtype"] is torch.float32
         )  # Expected int8-in fp32-out qlinear in weight prepack phase
         assert (
             kwargs["postop_name"] == "none"
@@ -271,7 +271,7 @@ def qlinear(match: Match, *args, **kwargs):
             b,
             o_inv_scale,
             o_zero_point,
-            fp32_output,
+            output_dtype,
             unary_attr.op_name,
             unary_attr.scalars_attr,
             unary_attr.algorithm_attr,
@@ -384,7 +384,7 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
             patterns,
             1 if unary_attr.op_name != "none" else 2,  # pass_number
             torch.ops.onednn.qlinear_pointwise,  # computation_op
-            False,  # fp32_output
+            None,  # output_dtype
             unary_attr,  # unary_attr
         )

@@ -967,7 +967,7 @@ def qlinear_weight_prepack(match: Match, *args, **kwargs):
                bias,
                1.0,  # output_scale
                0,  # output_zero_point
-               True,  # fp32_output
+               torch.float32,  # output_dtype
                "none",  # post op name
                [],  # post op args
                "",  # post op algorithm
15 changes: 8 additions & 7 deletions torch/_inductor/ir.py
@@ -5797,7 +5797,7 @@ def __init__(
                 c10::optional<at::Tensor> bias,
                 double inv_output_scale,
                 int64_t output_zero_point,
-                bool fp32_output,
+                c10::optional<c10::ScalarType> output_dtype,
                 std::string post_op_name,
                 torch::List<c10::optional<at::Scalar>> post_op_args,
                 std::string post_op_algorithm)"""
@@ -5817,7 +5817,7 @@ def codegen(self, wrapper):
             x_zp,
             o_inv_scale,
             o_zp,
-            fp32_output,
+            output_dtype,
             unary_attr,
             unary_scalars,
             unary_algorithm,
@@ -5833,7 +5833,7 @@ def codegen(self, wrapper):
             bias,
             o_inv_scale,
             o_zp,
-            fp32_output,
+            output_dtype,
             unary_attr,
             unary_scalars,
             unary_algorithm,
@@ -5860,7 +5860,7 @@ def create(
         bias: "TensorBox",
         o_inv_scale: float,
         output_zero_point: int,
-        fp32_output,
+        output_dtype,
         unary_attr,
         unary_scalars,
         unary_algorithm,
@@ -5880,16 +5880,17 @@ def create(
             x_zp,
             o_inv_scale,
             output_zero_point,
-            fp32_output,
+            output_dtype,
             unary_attr,
             may_convert_to_optional(unary_scalars),
             unary_algorithm,
         ]

-        if fp32_output:
+        if output_dtype is not None:
+            assert output_dtype in [torch.float32, torch.bfloat16]
             # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout
             # if we set fp32_output, the output buf should be dtype float32 instead of uint8.
-            kernel_layout.dtype = torch.float32
+            kernel_layout.dtype = output_dtype

         return QLinearPointwisePT2E(
             layout=kernel_layout,
4 changes: 2 additions & 2 deletions torch/_inductor/lowering.py
@@ -1535,7 +1535,7 @@ def qlinear_unary(
     bias: TensorBox,
     o_inv_scale,
     o_zero_point,
-    fp32_output,
+    output_dtype,
     attr,
     scalars,
     algorithm,
@@ -1551,7 +1551,7 @@ def qlinear_unary(
             bias,
             o_inv_scale,
             o_zero_point,
-            fp32_output,
+            output_dtype,
             attr,
             scalars,
             algorithm,
5 changes: 3 additions & 2 deletions torch/_meta_registrations.py
@@ -2155,14 +2155,15 @@ def meta_qlinear_pointwise(
     bias,
     output_scale,
     output_zero_point,
-    fp32_output,
+    output_dtype,
     post_op_name,
     post_op_args,
     post_op_algorithm,
 ):
     output_shape = list(x.shape)
     output_shape[-1] = w.shape[0]
-    out = x.new_empty(output_shape, dtype=(torch.float32 if fp32_output else None))
+    assert output_dtype in [torch.float32, torch.bfloat16]
+    out = x.new_empty(output_shape, dtype=output_dtype)
     return out

 _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
