
Wrong output of single-channel channels_last convolution with channels_first input #82060

Closed
emilyfy opened this issue Jul 23, 2022 · 5 comments
Labels
high priority module: convolution Problems related to convolutions (THNN, THCUNN, CuDNN) module: correctness (silent) issue that returns an incorrect result silently module: cpu CPU specific problem (e.g., perf, algorithm) module: memory format Memory format/layout related issues/changes (channels_last, nhwc) module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration triaged This issue has been looked at by a team member and triaged and prioritized into an appropriate module

Comments

emilyfy commented Jul 23, 2022

🐛 Describe the bug

I asked about some unexpected behavior on this forum thread and was asked to open an issue here.

I'm trying to run a convolution with the weights in torch.channels_last memory format but a contiguous input, and the output is wrong when the number of input channels is 1.

import torch

input = torch.randn(1, 1, 100, 100)
conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)

with torch.no_grad():
    out_ref = conv(input)

conv.to(memory_format=torch.channels_last)
with torch.no_grad():
    out = conv(input)

print(torch.mean(torch.abs(out - out_ref)))

This prints a non-zero value.

Note:

  • if the number of input channels is not 1, the output is correct.
  • this is observed on CPU only; if input and conv are both moved to .cuda(), the error is zero.
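For contrast, the same repro with more than one input channel (a hypothetical control case, not from the original report — here 3 channels) agrees across memory formats, up to floating-point rounding from different kernel choices:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 3, 100, 100)  # 3 input channels instead of 1
conv = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)

with torch.no_grad():
    out_ref = conv(x)

# Converting only the weights to channels_last, as in the report above.
conv.to(memory_format=torch.channels_last)
with torch.no_grad():
    out = conv(x)

# With C > 1 the two outputs match (any residual is rounding noise).
print(torch.mean(torch.abs(out - out_ref)))
```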

Versions

PyTorch version: 1.12.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 11.2.3 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.27)
CMake version: version 3.16.2
Libc version: N/A

Python version: 3.10.4 (main, Mar 31 2022, 03:38:35) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] torch==1.12.0
[pip3] torchaudio==0.12.0
[conda] numpy 1.23.1 pypi_0 pypi
[conda] torch 1.12.0 pypi_0 pypi
[conda] torchaudio 0.12.0 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @VitalyFedyunin @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jamesr66a

@ezyang ezyang added high priority module: convolution Problems related to convolutions (THNN, THCUNN, CuDNN) triaged This issue has been looked at by a team member and triaged and prioritized into an appropriate module module: correctness (silent) issue that returns an incorrect result silently labels Jul 24, 2022
@ezyang ezyang added module: cpu CPU specific problem (e.g., perf, algorithm) module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration module: memory format Memory format/layout related issues/changes (channels_last, nhwc) labels Jul 24, 2022
jingxu10 (Collaborator) commented:

Thanks for reporting this issue. We will look into it.

jbschlosser (Contributor) commented:

Is this an mkldnn issue? I can repro this even with torch.backends.mkldnn.flags(enabled=False):
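For reference, a minimal sketch (not from the thread) of how oneDNN can be disabled with the public torch.backends.mkldnn.flags context manager to rule it out:

```python
import torch

x = torch.randn(1, 1, 100, 100)
conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)
conv.to(memory_format=torch.channels_last)

# Disabling mkldnn forces the native (non-oneDNN) CPU convolution path,
# so a repro under this flag points away from oneDNN itself.
with torch.backends.mkldnn.flags(enabled=False), torch.no_grad():
    out = conv(x)

print(out.shape)  # torch.Size([1, 64, 100, 100])
```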

jingxu10 (Collaborator) commented Jul 27, 2022

I tested the following script with the latest commit 24d702d. It worked correctly on Ubuntu.
Could you verify this in your environment?

import torch

input = torch.randn(1, 1, 100, 100)
conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)

with torch.no_grad():
    out_ref = conv(input)

input = input.to(memory_format=torch.channels_last)  # without this line, a non-zero value is printed
conv.to(memory_format=torch.channels_last)
with torch.no_grad():
    out = conv(input)

print(torch.mean(torch.abs(out - out_ref)))

mingfeima added a commit that referenced this issue Jul 28, 2022
…channels last"


To fix #82060

When `input` is not explicitly converted to channels last while `conv` has been, the output should still be in channels last. The root cause is that when the input has an IC of 1, `compute_columns2d` in `aten/src/ATen/native/ConvolutionMM2d.cpp` treats it as channels first: for an N1HW tensor, `.contiguous(MemoryFormat::ChannelsLast)` does not change its strides, but its `suggest_memory_format()` still returns `MemoryFormat::Contiguous`.

Also updated the corresponding test cases; without this patch, the new test case fails on the forward path and hits a runtime error on the backward path.


[ghstack-poisoned]
mingfeima (Collaborator) commented Jul 28, 2022

Fixed with #82392

When input is not explicitly converted to channels last while conv has been, the output should still be in channels last. The root cause is that when the input has an IC of 1, compute_columns2d in aten/src/ATen/native/ConvolutionMM2d.cpp treats it as channels first.

We do have logic to make sure both input and weight end up with the same memory format even if they are given differently, for example:

const Tensor input = self.contiguous(memory_format);

But for an N1HW input, .contiguous(MemoryFormat::ChannelsLast) does not change its strides, and its suggest_memory_format() still returns MemoryFormat::Contiguous. That's how it went wrong.
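The ambiguity is easy to observe from Python (a sketch; the concrete stride values assume a freshly allocated contiguous 1x1x100x100 tensor):

```python
import torch

t = torch.randn(1, 1, 100, 100)
print(t.stride())  # (10000, 10000, 100, 1) for a contiguous tensor

# With both N and C of size 1, the channels_last contiguity check is
# already satisfied, so .contiguous() is effectively a no-op and the
# strides do not change.
t_cl = t.contiguous(memory_format=torch.channels_last)
print(t_cl.stride())

# The tensor satisfies BOTH contiguity predicates, which is why the
# intended memory format cannot be recovered from the strides alone.
print(t_cl.is_contiguous())                                   # True
print(t_cl.is_contiguous(memory_format=torch.channels_last))  # True
```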

facebook-github-bot pushed a commit that referenced this issue Jul 29, 2022
…st (#82392)

Summary:
To fix #82060

When `input` is not explicitly converted to channels last while `conv` has been, the output should still be in channels last. The root cause is that when the input has an IC of 1, `compute_columns2d` in `aten/src/ATen/native/ConvolutionMM2d.cpp` treats it as channels first.

We do have logic to make sure both input and weight end up with the same memory format even if they are given differently, for example:
```
auto input = self.contiguous(memory_format);
auto weight = weight_.contiguous(memory_format);
```

But for an N1HW input, `.contiguous(MemoryFormat::ChannelsLast)` does not change its strides, and its `suggest_memory_format()` still returns `MemoryFormat::Contiguous`. That's how it went wrong.

Also updated the corresponding test cases; without this patch, the new test case fails on the forward path and hits a runtime error on the backward path.

attach old fail log on forward path:
```
FAIL: test_conv_thnn_nhwc_cpu_float32 (__main__.TestNNDeviceTypeCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 377, in instantiated_test
    result = test(self, **param_kwargs)
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 974, in only_fn
    return fn(slf, *args, **kwargs)
  File "test/test_nn.py", line 19487, in test_conv_thnn_nhwc
    input_format=torch.contiguous_format, weight_format=torch.channels_last)
  File "test/test_nn.py", line 19469, in helper
    self.assertEqual(out, ref_out, exact_dtype=False)
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 2376, in assertEqual
    msg=(lambda generated_msg: f"{generated_msg} : {msg}") if isinstance(msg, str) and self.longMessage else msg,
  File "/home/mingfeim/anaconda3/envs/pytorch-test-cpu/lib/python3.7/site-packages/torch/testing/_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 988 / 1024 (96.5%)
Greatest absolute difference: 42.0 at index (1, 2, 6, 6) (up to 1e-05 allowed)
Greatest relative difference: inf at index (0, 0, 2, 1) (up to 1.3e-06 allowed)
```

Pull Request resolved: #82392
Approved by: https://github.com/jbschlosser

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b019a416741c1a78ebed89516c1f9fdbfb276898

Reviewed By: osalpekar

Differential Revision: D38252506

Pulled By: osalpekar

fbshipit-source-id: 178fc0b64877cfb371993671d746137955ce6c5e
pytorchmergebot pushed a commit that referenced this issue Aug 25, 2022
Fixes #82060 (the N > 1 case goes down the oneDNN path) and #80837. Both issues were introduced because the definition of channels last differs between the PyTorch framework side and the ideep side; this PR closes that gap by having ideep use the memory format flag given by the framework side.
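As a sketch of the other affected case (assuming, per the commit message above, that a batch size greater than 1 is what routes the call through oneDNN on CPU), the batched variant of the original repro looks like:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 1, 100, 100)  # batch size > 1: takes the oneDNN CPU path
conv = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False)

with torch.no_grad():
    out_ref = conv(x)

conv.to(memory_format=torch.channels_last)
with torch.no_grad():
    out = conv(x)

# On a build containing the fix, the two outputs agree.
print(torch.mean(torch.abs(out - out_ref)))
```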

Pull Request resolved: #83653
Approved by: https://github.com/mingfeima, https://github.com/malfet
ILoveSE commented Feb 23, 2023

Hello, I have a small question about this fix. @mingfeima
In PyTorch 1.12.0, if the number of input channels is 1, LazyConv2d with the channels_last memory format also returns wrong results.

import torch

torch.manual_seed(0)
input = torch.randn(1, 1, 100, 100)
def test():
    tmp_result = torch.nn.LazyConv2d(out_channels=64, kernel_size=(3, 3), padding=(1, 1), bias=False)
    return tmp_result
conv = test()

with torch.no_grad():
    out_ref = conv(input)

conv.to(memory_format=torch.channels_last)
with torch.no_grad():
    out = conv(input)

result = torch.mean(torch.abs(out - out_ref))

print(result)
# tensor(0.6507)

In version 1.13.0, it returns tensor(0.), which seems to be the expected result.
I want to confirm whether this bug in LazyConv2d is also related to aten/src/ATen/native/ConvolutionMM2d.cpp and was solved by #82392.
Thank you very much!
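For what it's worth, LazyConv2d materializes into a plain Conv2d on its first forward pass (inferring in_channels from the input), so from then on it goes through the same convolution dispatch as the original repro. A quick way to check (a sketch, not from the thread):

```python
import torch

lazy = torch.nn.LazyConv2d(out_channels=64, kernel_size=(3, 3), padding=(1, 1), bias=False)
x = torch.randn(1, 1, 100, 100)

with torch.no_grad():
    lazy(x)  # first forward infers in_channels and materializes the weight

# After materialization the module *is* a Conv2d, so a fix in the
# Conv2d/ConvolutionMM2d path would cover the lazy variant as well.
print(type(lazy).__name__)  # Conv2d
print(lazy.in_channels)     # 1
```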
