
NNPACK Conv2D operation gives wrong result for non-contiguous weights #55781

Closed
octavianmm opened this issue Apr 12, 2021 · 5 comments
Assignees: malfet
Labels: high priority, module: convolution, module: correctness (silent), triage review

Comments


octavianmm commented Apr 12, 2021

🐛 Bug

Hello,
I'm getting wrong results for a Conv2D operation on an ARM CPU compared to the correct result I get for the same code on x86_64. The output tensors are identical in some regions but differ in large blocks elsewhere. A quick view of the differences: https://github.com/octavianmm/torch_nn_functional_conv2d_problem/blob/main/results/difference.png

To Reproduce

Steps to reproduce the behavior:
A minimal working example that reproduces this issue is available in this Git repository:
https://github.com/octavianmm/torch_nn_functional_conv2d_problem

Expected behavior

The expected (and correct) result can be found in the repo above, in results/output_tensor_x86_64.pt (and .txt), whereas the incorrect result I got on the ARM CPU is in results/output_tensor_arm.pt (and .txt).
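For reference, a quick way to diff the two saved tensors (a rough sketch, assuming both .pt files from the repo's results folder are available locally):

import torch

# Load the reference output (x86_64) and the suspect output (ARM) saved in the repo.
expected = torch.load("results/output_tensor_x86_64.pt")
actual = torch.load("results/output_tensor_arm.pt")
# A large max-abs difference corresponds to the blocks visible in difference.png.
print((expected - actual).abs().max())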

Environment

PyTorch version: 1.8.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (aarch64)
GCC version: (Ubuntu/Linaro 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_adv_train.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_train.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.0.0
/usr/lib/aarch64-linux-gnu/libcudnn_ops_train.so.8.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.8.0
[pip3] torchvision==0.9.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @malfet

ngimel added the module: convolution, module: arm, module: correctness (silent), and high priority labels on Apr 13, 2021

ngimel commented Apr 13, 2021

High pri as it's a silent correctness issue


malfet commented Apr 13, 2021

I think it was fixed earlier today by #55794


ngimel commented Apr 13, 2021

I don't think it was: the input size in this example is (64, 16, 16, 16) and groups=1, so use_cpu_depthwise3x3_winograd would have returned false even before the fix (input.size(1) == groups evaluates to false).
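For context, the eligibility check being discussed can be sketched in Python roughly as follows (a simplification on my part; the real use_cpu_depthwise3x3_winograd lives in C++ and checks more conditions):

import torch

def is_depthwise_candidate(input, groups):
    # Depthwise convolution requires one group per input channel.
    return input.size(1) == groups

# With the input from this issue, the check fails: 16 channels != 1 group.
input = torch.ones(64, 16, 16, 16)
print(is_depthwise_candidate(input, groups=1))  # False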

malfet self-assigned this on Apr 13, 2021
malfet removed the module: arm label on Apr 21, 2021

malfet commented Apr 21, 2021

I can more or less reproduce the problem on x86 by comparing the convolution result with MKL enabled and disabled:

% cat repro.py
import torch

input = torch.ones(64, 16, 16, 16)
weight = torch.load("weight.pt")
output = torch.nn.functional.conv2d(input, weight, None)
with torch.backends.mkldnn.flags(enabled=False):
  output_nomkl = torch.nn.functional.conv2d(input, weight, None)
print((output_nomkl-output).abs().max())
% python3 repro.py 
tensor(9.7794, grad_fn=<MaxBackward1>)
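Note the weight.pt here comes from the linked repo. A self-contained variant of the repro (an assumption on my part: any non-contiguous weight of a compatible shape should exercise the same code path) can build such a weight with permute:

import torch

input = torch.ones(64, 16, 16, 16)
# permute() returns a non-contiguous view with the expected (out, in, kH, kW) layout.
weight = torch.randn(3, 3, 16, 96).permute(3, 2, 0, 1)
assert not weight.is_contiguous()

out_noncontig = torch.nn.functional.conv2d(input, weight, None)
out_contig = torch.nn.functional.conv2d(input, weight.contiguous(), None)
# On a build that dispatches the broken NNPACK path the results diverge;
# on a fixed build this prints tensor(0.).
print((out_noncontig - out_contig).abs().max())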


malfet commented Apr 21, 2021

But the problem goes away if the weight is contiguous...
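That observation suggests a workaround until the fix lands (a sketch, not an official recommendation): force the weight contiguous before the call.

import torch

input = torch.ones(64, 16, 16, 16)
weight = torch.randn(3, 3, 16, 96).permute(3, 2, 0, 1)  # non-contiguous view

# .contiguous() materializes a standard-layout copy, sidestepping the broken path.
output = torch.nn.functional.conv2d(input, weight.contiguous(), None)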

malfet changed the title from "Conv2D operation on ARM architecture gives wrong result" to "NNPACK Conv2D operation gives wrong result for non-contiguous weights" on Apr 21, 2021
malfet added a commit that referenced this issue Apr 21, 2021
krshrimali pushed a commit to krshrimali/pytorch that referenced this issue May 19, 2021
Summary:
Added TestNN.test_conv2d_discontiguous_weight to prevent further regressions

Fixes pytorch#55781

Pull Request resolved: pytorch#56569

Reviewed By: ngimel

Differential Revision: D27926509

Pulled By: malfet

fbshipit-source-id: fa5ce943c3e4db4aa4de1b1cba35bd399fb3c54d
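The regression test added by the commit is not quoted in the thread; a minimal sketch of what such a test might look like (shapes and structure are my assumptions, only the name comes from the commit message):

import torch

def test_conv2d_discontiguous_weight():
    # Output with a non-contiguous weight must match output with its contiguous copy.
    input = torch.ones(2, 3, 8, 8)
    weight = torch.randn(3, 3, 3, 4).permute(3, 2, 0, 1)  # (4, 3, 3, 3), non-contiguous
    assert not weight.is_contiguous()
    expected = torch.nn.functional.conv2d(input, weight.contiguous(), None)
    actual = torch.nn.functional.conv2d(input, weight, None)
    assert torch.allclose(actual, expected)

test_conv2d_discontiguous_weight()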