Description
🐛 Describe the bug
When using F.pad with mode='circular' on MPS, the subsequent convolution (F.conv2d) returns a wrong result.
The bug is hard to see unless you plot the output with pyplot:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
On the CPU, the result looks correct:
torch.manual_seed(1)
dev = 'cpu'
k = torch.ones(3, 3, 9, 9).to(dev)
x = torch.rand(1, 3, 32, 32).to(dev)
x = F.pad(x, (2, 2, 2, 2), mode='circular')
y = F.conv2d(x, k)
plt.imshow(y[0, 0].detach().cpu())
plt.show()
However, on MPS the result is completely different from the CPU result and doesn't make sense.
torch.manual_seed(1)
dev = 'mps'
k = torch.ones(3, 3, 9, 9).to(dev)
x = torch.rand(1, 3, 32, 32).to(dev)
x = F.pad(x, (2, 2, 2, 2), mode='circular')
y = F.conv2d(x, k)
plt.imshow(y[0, 0].detach().cpu())
plt.show()
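Because the difference is hard to judge by eye, a numeric comparison may help. The sketch below reruns the same reproduction on each device and reports the largest absolute difference. padded_conv is a hypothetical helper name, and the tolerance passed to torch.allclose is an assumption, not a documented threshold. The MPS comparison only runs when torch.backends.mps.is_available() returns True.

```python
import torch
import torch.nn.functional as F


def padded_conv(dev: str, mode: str = "circular") -> torch.Tensor:
    """Run the reproduction above on `dev` and return the result as a CPU tensor."""
    torch.manual_seed(1)  # same seed, so x is identical on every device
    k = torch.ones(3, 3, 9, 9).to(dev)
    x = torch.rand(1, 3, 32, 32).to(dev)  # generated on the CPU, then moved
    x = F.pad(x, (2, 2, 2, 2), mode=mode)
    return F.conv2d(x, k).cpu()


ref = padded_conv("cpu")
if torch.backends.mps.is_available():
    out = padded_conv("mps")
    print("max |cpu - mps|:", (ref - out).abs().max().item())
    # Assumed tolerance: loose enough for float32 summation-order differences.
    print("allclose:", torch.allclose(ref, out, rtol=1e-4, atol=1e-4))
else:
    print("MPS not available; only the CPU reference was computed")
```

Passing mode="reflect" to padded_conv gives the same comparison for the other padding mode.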
Further experiments
My experiments show that the problem comes from x = F.pad(x, (2, 2, 2, 2), mode='circular').
- If I remove this line, the results are consistent.
- More interestingly, if I change circular to reflect, run the code, then change it back to circular and run it again, everything looks normal. Even then, however, the MPS and CPU outputs are still different. See the following figure:
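To isolate the padding step from the convolution, F.pad can be checked against a hand-written circular pad built only from slicing and torch.cat. manual_circular_pad and check_pad are hypothetical helpers. Calling mode='reflect' once before mode='circular' in the same process is only an assumed approximation of the edit-and-rerun experiment described above. It has not been confirmed to trigger the same behaviour.

```python
import torch
import torch.nn.functional as F


def manual_circular_pad(x: torch.Tensor, p: int) -> torch.Tensor:
    """Wrap-around padding of the last two dims by p, using only slicing and torch.cat.

    Intended to match F.pad(x, (p, p, p, p), mode='circular') for 0 < p <= size.
    """
    # Width (last dim): prepend the last p columns, append the first p columns.
    x = torch.cat([x[..., -p:], x, x[..., :p]], dim=-1)
    # Height (second-to-last dim): same idea; corners wrap around automatically.
    x = torch.cat([x[..., -p:, :], x, x[..., :p, :]], dim=-2)
    return x


def check_pad(dev: str, p: int = 2) -> bool:
    """Return True if F.pad(mode='circular') on `dev` matches the CPU reference."""
    torch.manual_seed(1)
    x = torch.rand(1, 3, 32, 32).to(dev)
    expected = manual_circular_pad(x.cpu(), p)  # reference computed on the CPU
    # Assumption: a prior reflect call approximates "run with reflect first".
    F.pad(x, (p, p, p, p), mode="reflect")  # result intentionally discarded
    got = F.pad(x, (p, p, p, p), mode="circular").cpu()
    return torch.equal(got, expected)


print("cpu:", check_pad("cpu"))
if torch.backends.mps.is_available():
    print("mps:", check_pad("mps"))
```

If the MPS check prints False while the CPU check prints True, the wrong result comes from the padding itself rather than from F.conv2d.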
The full code is available here: https://github.com/Zhangyanbo/torch_bugs
Versions
Collecting environment information...
PyTorch version: 1.12.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:13:24) [Clang 11.1.0 ] (64-bit runtime)
Python platform: macOS-12.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.6
[pip3] pytorch-lightning==1.6.3
[pip3] torch==1.12.0
[pip3] torch-cluster==1.6.0
[pip3] torch-geometric==2.0.4
[pip3] torch-scatter==2.0.9
[pip3] torch-sparse==0.6.13
[pip3] torch-spline-conv==1.2.1
[pip3] torchmetrics==0.8.2
[pip3] torchvision==0.13.0
[conda] numpy 1.21.6 py38hf29d37f_0 conda-forge
[conda] pytorch 1.12.0 py3.8_0 pytorch
[conda] pytorch-lightning 1.6.3 pypi_0 pypi
[conda] torch-cluster 1.6.0 pypi_0 pypi
[conda] torch-geometric 2.0.4 pypi_0 pypi
[conda] torch-scatter 2.0.9 pypi_0 pypi
[conda] torch-sparse 0.6.13 pypi_0 pypi
[conda] torch-spline-conv 1.2.1 pypi_0 pypi
[conda] torchmetrics 0.8.2 pypi_0 pypi
[conda] torchvision 0.13.0 py38_cpu pytorch