Enable oneDNN QConv FP32/BF16 output #112010

Closed
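Summary of the change: the boolean fp32_output flag on the oneDNN qconv ops is replaced by an optional output_dtype argument, so the fused convolution can produce quantized uint8 (default), fp32, or bf16 output. The sketch below is a standalone illustration of that dispatch, not code from the PR; the helper name resolve_qconv_output_dtype is made up, while the flag logic mirrors the diff that follows.

// Standalone sketch only: how the optional output_dtype maps onto the
// fp32_output / bfloat16_output flags used in qconv.cpp below.
#include <c10/core/ScalarType.h>
#include <c10/util/Optional.h>

inline c10::ScalarType resolve_qconv_output_dtype(
    c10::optional<c10::ScalarType> output_dtype) {
  const bool fp32_output =
      output_dtype.has_value() && output_dtype.value() == c10::kFloat;
  const bool bfloat16_output =
      output_dtype.has_value() && output_dtype.value() == c10::kBFloat16;
  if (fp32_output) {
    return c10::kFloat;     // dequantized fp32 output
  }
  if (bfloat16_output) {
    return c10::kBFloat16;  // dequantized bf16 output
  }
  return c10::kByte;        // default: quantized uint8 output
}

When fp32 or bf16 output is requested, the kernel additionally requires inv_output_scale == 1.0 and output_zero_point == 0, so the oneDNN scale and zero-point attributes are skipped.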
aten/src/ATen/native/quantized/cpu/qconv.cpp (61 changes: 41 additions & 20 deletions)
@@ -1388,7 +1388,7 @@ static at::Tensor _quantized_convolution_onednn(
c10::optional<at::Tensor> accum, // accum to fused with conv add
double accum_scale,
int64_t accum_zero_point,
- bool fp32_output,
+ c10::optional<c10::ScalarType> output_dtype,
c10::optional<c10::string_view> binary_attr,
c10::optional<at::Scalar> binary_alpha,
c10::optional<c10::string_view> unary_attr,
@@ -1402,13 +1402,15 @@ static at::Tensor _quantized_convolution_onednn(
// inv_scale = 1.0 / scale will be folded.
// So, we can only get inv_scale from quant node which is used as
// output_scale of this op.
- if (fp32_output) {
- // When fp32_output, oneDNN expects op_attr doesn't set_scales and set_zero_points.
+ bool fp32_output = output_dtype.has_value() && (output_dtype.value() == c10::kFloat);
+ bool bfloat16_output = output_dtype.has_value() && (output_dtype.value() == c10::kBFloat16);
+ if (fp32_output || bfloat16_output) {
+ // When fp32 or bf16 output, oneDNN expects op_attr doesn't set_scales and set_zero_points.
// So, we will use default inv_output_scale as 1.0 and output_zero_point as 0, since
// when inv_output_scale is 1.0, we will skip invoking of op_attr.set_scales in ideep;
// when output_zero_point is 0, we will skip invoking of op_attr.set_zero_points in ideep.
- TORCH_CHECK(inv_output_scale == 1.0, " (ONEDNN): fp32 output, inv_output_scale must be 1.0.");
- TORCH_CHECK(output_zero_point == 0, " (ONEDNN): fp32 output, output_zero_point must be 0");
+ TORCH_CHECK(inv_output_scale == 1.0, " (ONEDNN): fp32 or bf16 output, inv_output_scale must be 1.0.");
+ TORCH_CHECK(output_zero_point == 0, " (ONEDNN): fp32 or bf16 output, output_zero_point must be 0");
}

int kSpatialDim = act.dim() - 2;
@@ -1417,7 +1419,13 @@ static at::Tensor _quantized_convolution_onednn(
bool has_binary_post_op = binary_attr.has_value() && binary_attr.value() != "none";
bool has_unary_post_op = unary_attr.has_value() && unary_attr.value() != "none";
// has_accum_postop_sum: extra input besides the conv to do conv add fusion with post op sum.
- bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "add" && !fp32_output;
+ bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "add";
+
+ if (has_accum_postop_sum && (fp32_output || bfloat16_output)) {
+ TORCH_CHECK(accum_scale == 1.0, " (ONEDNN): fp32 or bf16 output, accum_scale must be 1.0.");
+ TORCH_CHECK(accum_zero_point == 0, " (ONEDNN): fp32 or bf16 output, accum_zero_point must be 0");
+ }
+
std::string func_name = "quantized::packed_weights_conv";
func_name += std::to_string(kSpatialDim) + "d";
if (has_binary_post_op) {
@@ -1523,14 +1531,17 @@ static at::Tensor _quantized_convolution_onednn(
ideep::tensor onednn_bias;
const int output_channels = weight.size(0);
bool with_bias = bias.has_value();
+
+ at::Tensor bias_val_float;
if (with_bias) {
- at::Tensor bias_val = bias.value();
- TORCH_CHECK(bias_val.dim() == 1, "bias should be a vector (1D Tensor)");
+ // For int8-mixed-bf16, we will also use float32 bias
+ bias_val_float = bias.value().to(at::kFloat);
+ TORCH_CHECK(bias_val_float.dim() == 1, "bias should be a vector (1D Tensor)");
TORCH_CHECK(
- bias_val.size(0) == output_channels,
+ bias_val_float.size(0) == output_channels,
"bias should have K elements: " + std::to_string(output_channels));
- auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32);
- onednn_bias.init(bias_desc, bias.value().data_ptr());
+ auto bias_desc = ideep::tensor::desc(bias_val_float.sizes().vec(), dnnl::memory::data_type::f32);
+ onednn_bias.init(bias_desc, bias_val_float.data_ptr());
}

const auto& expected_bias = with_bias ? onednn_bias : ideep::tensor();
@@ -1556,11 +1567,11 @@ static at::Tensor _quantized_convolution_onednn(
ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()});
// Output is not a quantized tensor but data type is uint8
at::Tensor output;
- if (fp32_output) {
+ if (fp32_output || bfloat16_output) {
output = at::empty(
dst_dims,
device(c10::kCPU)
- .dtype(c10::kFloat)
+ .dtype(fp32_output ? c10::kFloat : c10::kBFloat16)
.memory_format(kSpatialDim == 2 ?
c10::MemoryFormat::ChannelsLast :
c10::MemoryFormat::ChannelsLast3d),
@@ -1581,16 +1592,26 @@ static at::Tensor _quantized_convolution_onednn(
ideep::tensor dst;
at::Tensor accum_contig;
if (has_accum_postop_sum) {
- auto dst_desc = ideep::tensor::desc(dst_dims, src_data_type,
+ auto dst_desc = ideep::tensor::desc(dst_dims, fp32_output ? ideep::tensor::data_type::f32 : (
+ bfloat16_output ? ideep::tensor::data_type::bf16 : src_data_type),
kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
accum_contig = accum.value().contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
+ if (fp32_output || bfloat16_output) {
+ TORCH_CHECK((accum_contig.scalar_type() == c10::kFloat) || (accum_contig.scalar_type() == c10::kBFloat16), "The accum_contig tensor should be kFloat or kBFloat16.");
+ TORCH_CHECK((output.scalar_type() == c10::kFloat) || (output.scalar_type() == c10::kBFloat16), "The output tensor should be kFloat or kBFloat16.");
+ if (accum_contig.scalar_type() != output.scalar_type()) {
+ // accum_contig is kFloat and we expect a kBFloat16 output,
+ // or accum_contig is kBFloat16 and we expect a kFloat output
+ accum_contig = accum_contig.to(output.scalar_type());
+ }
+ }
TORCH_CHECK(accum_contig.dtype() == output.dtype(), "The output tensor should have same dtype as the accum tensor.");
// When fused with sum, the dst tensor will share the data ptr as the accum tensor.
dst.init(dst_desc, accum_contig.data_ptr());
} else {
- if (fp32_output) {
+ if (fp32_output || bfloat16_output) {
// Conv without add: int8-in, fp32-output
- dst = ideep::tensor({dst_dims, ideep::tensor::data_type::f32, {output.strides().cbegin(), output.strides().cend()}},
+ dst = ideep::tensor({dst_dims, fp32_output ? ideep::tensor::data_type::f32 : ideep::tensor::data_type::bf16, {output.strides().cbegin(), output.strides().cend()}},
output.data_ptr());
} else {
dst = ideep::tensor({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
@@ -1782,7 +1803,7 @@ class QConvoneDNN final {
int64_t groups,
double inv_output_scale, // inv_output_scale is the reciprocal of scale in fake quant
int64_t output_zero_point,
- bool fp32_output,
+ c10::optional<c10::ScalarType> output_dtype,
c10::string_view attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm) {
Expand Down Expand Up @@ -1810,7 +1831,7 @@ class QConvoneDNN final {
bias, stride, padding, dilation, /*transposed*/false,
groups, inv_output_scale, output_zero_point,
/*accum*/c10::nullopt, /*accum_scale*/0.0, /*accum_zero_point*/0,
- /*fp32_output*/fp32_output, /*binary_attr*/c10::nullopt, /*binary_alpha*/c10::nullopt,
+ /*output_dtype*/output_dtype, /*binary_attr*/c10::nullopt, /*binary_alpha*/c10::nullopt,
/*unary_attr*/attr, /*unary_scalars*/scalars, /*unary_algorithm*/algorithm
);
#else
@@ -1834,7 +1855,7 @@
int64_t groups,
double inv_output_scale, // inv_output_scale is the reciprocal of scale in fake quant
int64_t output_zero_point,
- bool fp32_output,
+ c10::optional<c10::ScalarType> output_dtype,
c10::string_view binary_attr,
c10::optional<at::Scalar> alpha,
c10::optional<c10::string_view> unary_attr,
Expand Down Expand Up @@ -1862,7 +1883,7 @@ class QConvoneDNN final {
bias, stride, padding, dilation, /*transposed*/false,
groups, inv_output_scale, output_zero_point,
accum, accum_scale, accum_zero_point,
- /*fp32_output*/false, binary_attr, alpha,
+ /*output_dtype*/output_dtype, binary_attr, alpha,
unary_attr, unary_scalars, unary_algorithm
);
#else
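For the fused sum path above, a hedged standalone sketch of the accum/output dtype reconciliation (2D channels-last case only; prepare_accum_for_sum_fusion is a hypothetical helper, not a function added by this PR):

#include <ATen/ATen.h>

// Make accum usable as the in-place destination of the oneDNN "sum" post-op:
// channels-last contiguous and with the same dtype as the conv output, so the
// dst tensor can alias its buffer.
at::Tensor prepare_accum_for_sum_fusion(const at::Tensor& accum,
                                        const at::Tensor& output) {
  auto accum_contig = accum.contiguous(at::MemoryFormat::ChannelsLast);
  if (accum_contig.scalar_type() != output.scalar_type()) {
    // e.g. fp32 accum with a bf16 output (or vice versa)
    accum_contig = accum_contig.to(output.scalar_type());
  }
  return accum_contig;
}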
aten/src/ATen/native/quantized/library.cpp (8 changes: 4 additions & 4 deletions)
@@ -258,12 +258,12 @@ TORCH_LIBRARY(onednn, m) {
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_prepack(Tensor weight, Tensor w_scales, float x_scale, int x_zp, int[] stride, int[] padding, int[] dilation, int groups, int[]? x_shape=None) -> Tensor"));

// Conv1D/2D/3D with unary postop
- m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv1d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
- m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
- m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
+ m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv1d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
+ m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
+ m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));

// Conv2D with binary postop
- m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qaccum, float accum_scale, int accum_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor"));
+ m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qaccum, float accum_scale, int accum_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor"));

// Linear prepack
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor"));
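The ScalarType? output_dtype argument in the schemas above arrives on the C++ side as c10::optional<c10::ScalarType>. A toy registration illustrating that mapping (the toy_dtype_demo namespace and toy_identity op are invented for illustration and are not part of this PR):

#include <ATen/ATen.h>
#include <torch/library.h>

// Toy op: pass the input through, optionally cast to the requested dtype,
// mirroring the optional-dtype pattern of the qconv schemas.
at::Tensor toy_identity(const at::Tensor& x,
                        c10::optional<c10::ScalarType> output_dtype) {
  return output_dtype.has_value() ? x.to(output_dtype.value()) : x;
}

TORCH_LIBRARY(toy_dtype_demo, m) {
  m.def("toy_identity(Tensor x, ScalarType? output_dtype) -> Tensor");
  m.impl("toy_identity", TORCH_FN(toy_identity));
}

Calling it with output_dtype omitted keeps the input dtype; passing at::kFloat or at::kBFloat16 selects the corresponding floating-point output, analogous to the qconv ops above.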