
Wrong conv2d output after using F.pad + mps #80856

@Zhangyanbo


🐛 Describe the bug

When F.pad with mode='circular' is applied to a tensor on the MPS device, the subsequent F.conv2d returns a wrong result.
This bug is hard to see unless you plot the output with pyplot:

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

If you run on the CPU, the result looks correct:

torch.manual_seed(1)
dev = 'cpu'

k = torch.ones(3, 3, 9, 9).to(dev)

x = torch.rand(1, 3, 32, 32).to(dev)
x = F.pad(x, (2, 2, 2, 2), mode='circular')

y = F.conv2d(x, k)
plt.imshow(y[0, 0].detach().cpu())
plt.show()

However, on MPS the result is completely different from the CPU output and does not make sense:

torch.manual_seed(1)
dev = 'mps'

k = torch.ones(3, 3, 9, 9).to(dev)

x = torch.rand(1, 3, 32, 32).to(dev)
x = F.pad(x, (2, 2, 2, 2), mode='circular')

y = F.conv2d(x, k)
plt.imshow(y[0, 0].detach().cpu())
plt.show()
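For reference, because k is all ones with shape (3, 3, 9, 9), every output channel of a correct F.conv2d is identical: each value is the sum of a 9x9 window over all three channels of the circularly padded input. The pure-Python sketch below spells out that wrap-around rule and the box sum on a tiny grid, so individual values of y can be spot-checked on either device. The helper names circular_pad_2d and box_sum_all_channels are hypothetical, not PyTorch APIs; they mirror the index semantics of mode='circular', not PyTorch's implementation.

```python
def circular_pad_2d(img, pad):
    """Circularly pad a 2D list-of-lists by `pad` on every side.

    Padded index (i, j) maps to input index ((i - pad) % h, (j - pad) % w),
    i.e. values wrap around from the opposite edge.
    """
    h, w = len(img), len(img[0])
    return [
        [img[(i - pad) % h][(j - pad) % w] for j in range(w + 2 * pad)]
        for i in range(h + 2 * pad)
    ]


def box_sum_all_channels(channels, ksize):
    """Valid (unpadded) convolution of a C x H x W input with an all-ones
    C x ksize x ksize kernel: each output is the sum of one ksize x ksize
    window taken across all channels."""
    h, w = len(channels[0]), len(channels[0][0])
    out_h, out_w = h - ksize + 1, w - ksize + 1
    return [
        [
            sum(
                ch[i + di][j + dj]
                for ch in channels
                for di in range(ksize)
                for dj in range(ksize)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


# Tiny example: a 3x3 grid padded by 1 on each side becomes 5x5.
img = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
padded = circular_pad_2d(img, 1)
print(padded[0])                          # [9, 7, 8, 9, 7]
print(box_sum_all_channels([padded], 3))  # [[45, 45, 45], [45, 45, 45], [45, 45, 45]]
```

In the 3x3 example every 3x3 window of the wrapped grid contains each input value exactly once, so every output equals 45 (the sum of 1..9). In the issue's setup the same functions would give the expected y[0, 0] for a 32x32 input padded by 2 and a 9x9 window, i.e. a 28x28 result.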

Results:

Further experiments

My experiments show that the problem comes from x = F.pad(x, (2, 2, 2, 2), mode='circular'):

  • If I remove this line, the CPU and MPS results are consistent.
  • More interestingly, if I change circular to reflect, run the code, and then change it back to circular, the output looks normal. Even so, the MPS and CPU outputs still differ. See the following figure:

The full code is available here: https://github.com/Zhangyanbo/torch_bugs
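To narrow down which operation is at fault, the reproduction can be split into stages and compared numerically instead of by eye. A minimal sketch, assuming PyTorch 1.12 or later; compare_cpu_mps is a hypothetical helper name, and the tolerances are illustrative choices for float32 sums of 243 values. It uses the CPU result as the reference, then checks (1) circular padding on MPS, (2) conv2d on MPS fed the CPU-padded tensor, and (3) the full pipeline from the issue. When MPS is not available it only reports the reference shapes.

```python
import torch
import torch.nn.functional as F


def compare_cpu_mps(seed=1, pad=2):
    """Run the issue's reproduction on CPU and, if available, on MPS,
    and report which stage (padding or convolution) diverges."""
    torch.manual_seed(seed)
    # Generated on CPU, so both devices receive identical input data.
    x = torch.rand(1, 3, 32, 32)
    k = torch.ones(3, 3, 9, 9)

    # CPU reference.
    pad_cpu = F.pad(x, (pad, pad, pad, pad), mode="circular")
    y_cpu = F.conv2d(pad_cpu, k)
    report = {"pad_shape": tuple(pad_cpu.shape), "out_shape": tuple(y_cpu.shape)}

    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        dev = torch.device("mps")
        pad_mps = F.pad(x.to(dev), (pad, pad, pad, pad), mode="circular")
        # Stage 1: does circular padding alone agree with the CPU?
        report["pad_matches"] = torch.allclose(pad_mps.cpu(), pad_cpu)
        # Stage 2: conv2d on MPS fed the CPU-padded tensor isolates conv2d.
        y_conv_only = F.conv2d(pad_cpu.to(dev), k.to(dev)).cpu()
        report["conv_matches"] = torch.allclose(y_conv_only, y_cpu, rtol=1e-4, atol=1e-3)
        # Stage 3: the full pad + conv2d pipeline, as in the issue.
        y_mps = F.conv2d(pad_mps, k.to(dev)).cpu()
        report["full_matches"] = torch.allclose(y_mps, y_cpu, rtol=1e-4, atol=1e-3)
        report["max_abs_diff"] = (y_mps - y_cpu).abs().max().item()
    return report


print(compare_cpu_mps())
```

If pad_matches is False while conv_matches is True, that would point at the circular padding path, consistent with the observation above; if only full_matches is False, the interaction between the two ops is the more likely suspect.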

Versions

Collecting environment information...
PyTorch version: 1.12.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.1.6 (clang-1316.0.21.2.5)
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:13:24) [Clang 11.1.0 ] (64-bit runtime)
Python platform: macOS-12.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.6
[pip3] pytorch-lightning==1.6.3
[pip3] torch==1.12.0
[pip3] torch-cluster==1.6.0
[pip3] torch-geometric==2.0.4
[pip3] torch-scatter==2.0.9
[pip3] torch-sparse==0.6.13
[pip3] torch-spline-conv==1.2.1
[pip3] torchmetrics==0.8.2
[pip3] torchvision==0.13.0
[conda] numpy 1.21.6 py38hf29d37f_0 conda-forge
[conda] pytorch 1.12.0 py3.8_0 pytorch
[conda] pytorch-lightning 1.6.3 pypi_0 pypi
[conda] torch-cluster 1.6.0 pypi_0 pypi
[conda] torch-geometric 2.0.4 pypi_0 pypi
[conda] torch-scatter 2.0.9 pypi_0 pypi
[conda] torch-sparse 0.6.13 pypi_0 pypi
[conda] torch-spline-conv 1.2.1 pypi_0 pypi
[conda] torchmetrics 0.8.2 pypi_0 pypi
[conda] torchvision 0.13.0 py38_cpu pytorch

cc @ezyang @gchanan @zou3519 @kulinseth @albanD

Metadata

Assignees: No one assigned

Labels: high priority; module: mps (Related to Apple Metal Performance Shaders framework); triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)