Skip to content

Commit

Permalink
fix oneDNN channels_last path issue (#83653)
Browse files Browse the repository at this point in the history
Fix #82060 (where N > 1 dispatches into the oneDNN path) and #80837. Both issues were introduced because the definition of channels-last differs between the PyTorch framework side and the ideep side; this PR closes that gap by having ideep use the format flag provided by the framework side.

Pull Request resolved: #83653
Approved by: https://github.com/mingfeima, https://github.com/malfet
  • Loading branch information
XiaobingSuper authored and pytorchmergebot committed Aug 25, 2022
1 parent b6ba419 commit a013597
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 30 deletions.
89 changes: 59 additions & 30 deletions aten/src/ATen/native/mkldnn/Conv.cpp
Expand Up @@ -155,9 +155,17 @@ static void check_shape_forward(const Tensor& input,
// but weight/bias and grad_weight/grad_bias are always CPU tensor.
//

// Maps the channels-last flag onto the at::MemoryFormat a convolution of the
// given dimensionality should use: 4-D tensors get ChannelsLast, 5-D tensors
// get ChannelsLast3d, and anything not flagged channels-last stays Contiguous.
static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {
  if (!is_channels_last) {
    return at::MemoryFormat::Contiguous;
  }
  return dims == 4 ? at::MemoryFormat::ChannelsLast
                   : at::MemoryFormat::ChannelsLast3d;
}

Tensor mkldnn_convolution(
const Tensor& input,
const Tensor& weight,
const Tensor& input_t,
const Tensor& weight_t,
const c10::optional<Tensor>& bias_opt,
IntArrayRef padding,
IntArrayRef stride,
Expand All @@ -167,15 +175,18 @@ Tensor mkldnn_convolution(
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
const Tensor& bias = *bias_maybe_owned;

if (input.scalar_type() == ScalarType::BFloat16) {
if (input_t.scalar_type() == ScalarType::BFloat16) {
TORCH_CHECK(mkldnn_bf16_device_check(),
"mkldnn_convolution: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
}

check_shape_forward(input, weight, bias, padding, stride, dilation, groups);
check_shape_forward(input_t, weight_t, bias, padding, stride, dilation, groups);

bool is_channels_last = input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t);
auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last);

auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
auto output_sizes = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation);
auto output = at::empty({0}, input.options());

Expand All @@ -184,12 +195,12 @@ Tensor mkldnn_convolution(

ideep::tensor y;
if (is_channels_last) {
output.resize_(output_sizes, input.suggest_memory_format());
output.resize_(output_sizes, memory_format);
y = itensor_from_tensor(output);
}
if (bias.defined()) {
const ideep::tensor b = itensor_from_tensor(bias);
ideep::convolution_forward::compute(
ideep::convolution_forward::compute_v3(
x,
w,
b,
Expand All @@ -199,9 +210,10 @@ Tensor mkldnn_convolution(
{dilation.begin(), dilation.end()},
{padding.begin(), padding.end()},
{padding.begin(), padding.end()},
groups);
groups,
is_channels_last);
} else {
ideep::convolution_forward::compute(
ideep::convolution_forward::compute_v3(
x,
w,
{output_sizes.cbegin(), output_sizes.cend()},
Expand All @@ -210,7 +222,8 @@ Tensor mkldnn_convolution(
{dilation.begin(), dilation.end()},
{padding.begin(), padding.end()},
{padding.begin(), padding.end()},
groups);
groups,
is_channels_last);
}

if (input.is_mkldnn()) {
Expand All @@ -224,21 +237,27 @@ Tensor mkldnn_convolution(
}

Tensor mkldnn_convolution_backward_input(
IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
bool is_channels_last = grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
IntArrayRef input_size,
const Tensor& grad_output,
const Tensor& weight,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool bias_defined,
bool is_channels_last) {
auto grad_input = at::empty({0}, grad_output.options());

auto grad_y = itensor_from_tensor(grad_output);
auto w = itensor_view_from_dense(weight);

ideep::tensor grad_x;
if (is_channels_last) {
grad_input.resize_(input_size, grad_output.suggest_memory_format());
auto memory_format = mkldnn_convolution_memory_format(grad_output.ndimension(), is_channels_last);
grad_input.resize_(input_size, memory_format);
grad_x = itensor_from_tensor(grad_input);
}
ideep::convolution_backward_data::compute(
ideep::convolution_backward_data::compute_v2(
grad_y,
w,
input_size.vec(),
Expand All @@ -247,7 +266,8 @@ Tensor mkldnn_convolution_backward_input(
dilation.vec(),
padding.vec(),
padding.vec(),
groups);
groups,
is_channels_last);

if (grad_output.is_mkldnn()) {
return MKLDNNTensor(grad_x, grad_output.options());
Expand All @@ -260,17 +280,21 @@ Tensor mkldnn_convolution_backward_input(
}

std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
{
bool is_channels_last = grad_output.suggest_memory_format() == at::MemoryFormat::ChannelsLast;

IntArrayRef weight_size,
const Tensor& grad_output,
const Tensor& input,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool bias_defined,
bool is_channels_last) {
const ideep::tensor grad_y = itensor_from_tensor(grad_output);
const ideep::tensor x = itensor_from_tensor(input);

ideep::tensor grad_w, grad_b;
if (bias_defined) {
ideep::convolution_backward_weights::compute(
ideep::convolution_backward_weights::compute_v2(
x,
grad_y,
weight_size.vec(),
Expand All @@ -280,9 +304,10 @@ std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
dilation.vec(),
padding.vec(),
padding.vec(),
groups);
groups,
is_channels_last);
} else {
ideep::convolution_backward_weights::compute(
ideep::convolution_backward_weights::compute_v2(
x,
grad_y,
weight_size.vec(),
Expand All @@ -291,7 +316,8 @@ std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
dilation.vec(),
padding.vec(),
padding.vec(),
groups);
groups,
is_channels_last);
}

if (!is_channels_last) {
Expand All @@ -306,20 +332,23 @@ std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
const Tensor& input_t, const Tensor& grad_output_t, const Tensor& weight_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array<bool,3> output_mask)
{
auto memory_format = input.suggest_memory_format();
bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t);
auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last);
Tensor grad_output = grad_output_t.is_mkldnn() ? grad_output_t : grad_output_t.contiguous(memory_format);

Tensor input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format);
Tensor weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
Tensor grad_input, grad_weight, grad_bias;
if (output_mask[0]) {
grad_input = mkldnn_convolution_backward_input(
input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2], is_channels_last);
}
if (output_mask[1] || output_mask[2]) {
std::tie(grad_weight, grad_bias) = mkldnn_convolution_backward_weights(
weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2], is_channels_last);
}

return std::make_tuple(grad_input, grad_weight, grad_bias);
Expand Down
40 changes: 40 additions & 0 deletions test/test_nn.py
Expand Up @@ -13644,6 +13644,46 @@ def _make_noncontiguous(inp):
self.assertTrue(gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol))


@onlyCPU
def test_conv_contiguous_for_oneDNN(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/80837:
    # a non-contiguous input (built via transpose + indexing) must produce
    # the same result on the oneDNN path as with MKLDNN disabled.
    for dtype in [torch.float, torch.bfloat16]:
        conv = nn.Conv2d(
            1,
            128,
            kernel_size=(5, 2),
            stride=(2, 1),
            padding=(0, 1),
            dilation=(1, 1),
            groups=1,
            bias=True,
            padding_mode='zeros').to(dtype=dtype)

        # Transpose a 5-D tensor and drop the trailing dim to get a 4-D
        # view with non-standard strides.
        x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype)
        x = torch.transpose(x, 1, 4)
        x2 = x[..., 0]
        # NOTE: removed the unused `inputs` list that was built here but
        # never referenced (leftover reproduction scaffolding).
        if torch.backends.mkldnn.is_available():
            y = conv(x2)
            # Disable MKLDNN explicitly to compute the reference output.
            with torch.backends.mkldnn.flags(enabled=False):
                y_ = conv(x2)
            self.assertEqual(y, y_)

@onlyCPU
def test_conv_ic1_channels_last_for_oneDNN(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/82060
    # (a batch size N > 1 sends the convolution down the oneDNN path).
    for dtype in [torch.float, torch.bfloat16]:
        layer = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)
        layer = layer.to(memory_format=torch.channels_last).to(dtype=dtype)
        batch = torch.rand(2, 1, 100, 100).to(dtype=dtype)
        if torch.backends.mkldnn.is_available():
            out_mkldnn = layer(batch)
            # Recompute with MKLDNN explicitly disabled as the reference.
            with torch.backends.mkldnn.flags(enabled=False):
                out_ref = layer(batch)
            self.assertEqual(out_mkldnn, out_ref)

def test_Dropout(self, device):
input = torch.empty(1000)
self._test_dropout(nn.Dropout, device, input)
Expand Down

0 comments on commit a013597

Please sign in to comment.