
Autograd through avg_pool2d gives inconsistent results when using CUDA + float64 [EDIT: minimal code] #98927

@p-wol

Description


🐛 Describe the bug

EDIT

Gradients backpropagated through avg_pool2d and avg_pool1d are wrong when using CUDA + float64. Here is a minimal example:

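import torch

devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices.append(torch.device('cuda:0'))

for device in devices:
    for dtype in [torch.float32, torch.float64]:
        # Input of shape (batch=1, channels=1, length=2)
        x = torch.nn.Parameter(torch.tensor([[[-.3, 1.]]], device=device, dtype=dtype))

        # avg_pool1d with kernel size 2 returns the mean of the two inputs,
        # so the expected gradient is [0.5, 0.5] for every device and dtype
        y = torch.nn.functional.avg_pool1d(x, 2).reshape(x.size(0), -1)
        y.backward()  # y has a single element, so no explicit grad is needed
        print('x.grad = {}'.format(x.grad))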

Output:

x.grad = tensor([[[0.5000, 0.5000]]])
x.grad = tensor([[[0.5000, 0.5000]]], dtype=torch.float64)
x.grad = tensor([[[0.5000, 0.5000]]], device='cuda:0')
x.grad = tensor([[[0.0078, 0.0000]]], device='cuda:0', dtype=torch.float64)
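As an additional check (not part of the original report), `torch.autograd.gradcheck` can compare the autograd gradient of the pooling ops against a finite-difference estimate in float64 on each available device. This is a sketch: the input shapes and the kernel size 2 are arbitrary choices, and the helper `check_pool_grads` is hypothetical. On an affected setup one would expect the CUDA entries to be `False`, but that has not been verified here.

```python
import torch
import torch.nn.functional as F


def check_pool_grads(device: torch.device) -> dict:
    """Run gradcheck on avg_pool1d / avg_pool2d in float64 on `device`."""
    x1 = torch.randn(2, 3, 8, device=device, dtype=torch.float64, requires_grad=True)
    x2 = torch.randn(2, 3, 8, 8, device=device, dtype=torch.float64, requires_grad=True)
    # gradcheck compares the analytical (autograd) Jacobian with a numerical
    # finite-difference Jacobian; raise_exception=False makes it return a bool
    return {
        "avg_pool1d": torch.autograd.gradcheck(
            lambda t: F.avg_pool1d(t, 2), (x1,), raise_exception=False),
        "avg_pool2d": torch.autograd.gradcheck(
            lambda t: F.avg_pool2d(t, 2), (x2,), raise_exception=False),
    }


devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda:0"))

for device in devices:
    print(device, check_pool_grads(device))
```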

INITIAL ISSUE

I have a toy model containing an avg_pool2d layer, a toy dataset, and NLLLoss as the loss. I get inconsistent results when computing the derivatives of the loss with respect to the model parameters using CUDA + float64 tensors.

Here is the procedure: I compute the Hessian $H$ of the loss with respect to the parameters and apply it to two vectors $u$ and $v$. That is, I compute $u^T H v$, which should equal $v^T H u$ because the Hessian is symmetric.
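The $u^T H v$ procedure can be sketched with double backward through `torch.autograd.grad`: differentiate the loss once with `create_graph=True`, take the dot product of the gradient with $v$, and differentiate again to obtain $Hv$. The model (`make_toy_model`: conv, tanh, `nn.AvgPool2d`, linear, log-softmax), the random data, and the helper names are hypothetical stand-ins, not the code from the gist.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_toy_model() -> nn.Module:
    # Hypothetical toy classifier for 1x8x8 inputs and 3 classes
    return nn.Sequential(
        nn.Conv2d(1, 4, kernel_size=3, padding=1),  # -> 4x8x8
        nn.Tanh(),
        nn.AvgPool2d(2),                             # -> 4x4x4
        nn.Flatten(),
        nn.Linear(4 * 4 * 4, 3),
        nn.LogSoftmax(dim=1),
    )


def hessian_bilinear_forms(loss, params, u, v):
    """Return (u^T H v, v^T H u), with H the Hessian of `loss` w.r.t. `params`."""
    # First backward pass; create_graph=True keeps the gradient differentiable
    grads = torch.autograd.grad(loss, params, create_graph=True)

    def hvp(w):
        # The gradient of (grad L . w) with respect to params is H w
        gw = sum((g * wi).sum() for g, wi in zip(grads, w))
        hw = torch.autograd.grad(gw, params, retain_graph=True, allow_unused=True)
        return [torch.zeros_like(p) if h is None else h for p, h in zip(params, hw)]

    def dot(a, b):
        return sum((ai * bi).sum() for ai, bi in zip(a, b))

    return dot(u, hvp(v)), dot(v, hvp(u))


def symmetry_error(device, dtype, seed=0):
    """Relative error |u^T H v - v^T H u| / |u^T H v| for one random (u, v)."""
    torch.manual_seed(seed)
    model = make_toy_model().to(device=device, dtype=dtype)
    x = torch.randn(8, 1, 8, 8, device=device, dtype=dtype)
    target = torch.randint(0, 3, (8,), device=device)
    params = list(model.parameters())
    loss = F.nll_loss(model(x), target)
    u = [torch.randn_like(p) for p in params]
    v = [torch.randn_like(p) for p in params]
    uHv, vHu = hessian_bilinear_forms(loss, params, u, v)
    return ((uHv - vHu).abs() / uHv.abs()).item()


configs = [(torch.device("cpu"), torch.float64)]
if torch.cuda.is_available():
    configs += [(torch.device("cuda:0"), torch.float32),
                (torch.device("cuda:0"), torch.float64)]

for device, dtype in configs:
    errs = [symmetry_error(device, dtype, seed=s) for s in range(10)]
    print(device, dtype, "max relative error:", max(errs))
```

Computing both $Hv$ and $Hu$ from the same first-order `grads` (with `retain_graph=True`) avoids rebuilding the forward graph for each product.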

I tested whether $u^T H v = v^T H u$ in several settings: CPU or GPU, float32 or float64 tensors, and avg_pool2d or max_pool2d. The results are:

cpu f64 avg => True
gpu f64 avg => False
gpu f32 avg => True
gpu f64 max => True

Clearly, avg_pool2d behaves inconsistently when CUDA and float64 are used together.
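One possible way to cross-check the CUDA float64 gradient (a suggestion, not from the original report) is to compare it against an equivalent formulation that does not go through the avg_pool kernels. Assuming stride equal to the kernel size and no padding, `avg_pool2d(x, k)` equals a depthwise `conv2d` with a uniform `1/(k*k)` kernel. The helpers `avg_pool2d_via_conv` and `compare_pool_grads` are hypothetical names used only for this sketch.

```python
import torch
import torch.nn.functional as F


def avg_pool2d_via_conv(x: torch.Tensor, k: int) -> torch.Tensor:
    # Depthwise conv (groups = channels) with a uniform 1/(k*k) kernel and
    # stride k; matches avg_pool2d(x, k) when there is no padding
    c = x.shape[1]
    weight = torch.full((c, 1, k, k), 1.0 / (k * k), device=x.device, dtype=x.dtype)
    return F.conv2d(x, weight, stride=k, groups=c)


def compare_pool_grads(device, dtype=torch.float64, k=2):
    """Max abs difference between avg_pool2d and conv-based input gradients."""
    torch.manual_seed(0)
    x = torch.randn(2, 3, 8, 8, device=device, dtype=dtype, requires_grad=True)
    # Random upstream weights so the incoming gradient is not uniform
    w = torch.randn(2, 3, 8 // k, 8 // k, device=device, dtype=dtype)
    (g_pool,) = torch.autograd.grad((F.avg_pool2d(x, k) * w).sum(), x)
    (g_conv,) = torch.autograd.grad((avg_pool2d_via_conv(x, k) * w).sum(), x)
    return (g_pool - g_conv).abs().max().item()


devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda:0"))

for device in devices:
    print(device, "max |grad difference|:", compare_pool_grads(device))
```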

The code and the results are here: https://gist.github.com/p-wol/aacfec2f29e6c4ba4e5cbfcb978ada1f

The results show the relative error $|u^T H v - v^T H u| / |u^T H v|$ for 100 sampled pairs $(u, v)$. In the gpu f64 avg case the relative error is around $10^{-2}$, whereas it is below $10^{-5}$ with float32 and around $10^{-15}$ with float64 on CPU. In my use case, a relative error of $10^{-2}$ is not acceptable.

Versions

Collecting environment information...
PyTorch version: 1.13.0
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux release 8.6 (Ootpa) (x86_64)
GCC version: (GCC) 8.5.0 20210514 (Red Hat 8.5.0-10)
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.28

Python version: 3.10.8 (main, Nov  4 2022, 13:48:29) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.18.0-372.41.1.el8_6.x86_64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: 11.2.67
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 525.60.13
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              80
On-line CPU(s) list: 0-79
Thread(s) per core:  2
Core(s) per socket:  20
Socket(s):           2
NUMA node(s):        2
Vendor ID:           GenuineIntel
CPU family:          6
Model:               85
Model name:          Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
Stepping:            7
CPU MHz:             2500.000
CPU max MHz:         3900.0000
CPU min MHz:         1000.0000
BogoMIPS:            5000.00
Virtualization:      VT-x
L1d cache:           32K
L1i cache:           32K
L2 cache:            1024K
L3 cache:            28160K
NUMA node0 CPU(s):   0-19,40-59
NUMA node1 CPU(s):   20-39,60-79
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req pku ospke avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] efficientnet-pytorch==0.7.1
[pip3] flake8==5.0.4
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.3
[pip3] numpydoc==1.5.0
[pip3] nvidia-dlprof-pytorch-nvtx==1.8.0
[pip3] pytorch-fast-transformers==0.4.0
[pip3] pytorch-ignite==0.4.10
[pip3] pytorch-lightning==1.8.3
[pip3] pytorch-msssim==0.2.1
[pip3] pytorch-pfn-extras==0.6.2
[pip3] pytorch3d==0.7.2
[pip3] pytorchvideo==0.1.5
[pip3] segmentation-models-pytorch==0.3.0
[pip3] torch==1.13.0
[pip3] torch-cluster==1.6.0
[pip3] torch-geometric==2.1.0.post1
[pip3] torch-scatter==2.0.9
[pip3] torch-sparse==0.6.15
[pip3] torch-spline-conv==1.2.1
[pip3] torch-tb-profiler==0.4.0
[pip3] torchaudio==0.13.0+bc8640b
[pip3] torchio==0.18.86
[pip3] torchmetrics==0.10.3
[pip3] torchsparse==1.4.0
[pip3] torchtext==0.14.0a0+e2b27f9
[pip3] torchvision==0.14.0a0+5ce4506
[pip3] torchviz==0.0.2
[conda] blas                      1.0                         mkl
[conda] efficientnet-pytorch      0.7.1                    pypi_0    pypi
[conda] mkl                       2021.4.0           h06a4308_640
[conda] mkl-service               2.4.0           py310h7f8727e_0
[conda] mkl_fft                   1.3.1           py310hd6ae3a3_0
[conda] mkl_random                1.2.2           py310h00e6091_0
[conda] numpy                     1.23.3          py310hd5efca6_1
[conda] numpy-base                1.23.3          py310h8e6c178_1
[conda] numpydoc                  1.5.0                    pypi_0    pypi
[conda] nvidia-dlprof-pytorch-nvtx 1.8.0                    pypi_0    pypi
[conda] pytorch-fast-transformers 0.4.0                    pypi_0    pypi
[conda] pytorch-ignite            0.4.10                   pypi_0    pypi
[conda] pytorch-lightning         1.8.3                    pypi_0    pypi
[conda] pytorch-msssim            0.2.1                    pypi_0    pypi
[conda] pytorch-pfn-extras        0.6.2                    pypi_0    pypi
[conda] pytorch3d                 0.7.2                    pypi_0    pypi
[conda] pytorchvideo              0.1.5                    pypi_0    pypi
[conda] segmentation-models-pytorch 0.3.0                    pypi_0    pypi
[conda] torch                     1.13.0                   pypi_0    pypi
[conda] torch-cluster             1.6.0                    pypi_0    pypi
[conda] torch-geometric           2.1.0.post1              pypi_0    pypi
[conda] torch-scatter             2.0.9                    pypi_0    pypi
[conda] torch-sparse              0.6.15                   pypi_0    pypi
[conda] torch-spline-conv         1.2.1                    pypi_0    pypi
[conda] torch-tb-profiler         0.4.0                    pypi_0    pypi
[conda] torchaudio                0.13.0+bc8640b           pypi_0    pypi
[conda] torchio                   0.18.86                  pypi_0    pypi
[conda] torchmetrics              0.10.3                   pypi_0    pypi
[conda] torchsparse               1.4.0                    pypi_0    pypi
[conda] torchtext                 0.14.0a0+e2b27f9          pypi_0    pypi
[conda] torchvision               0.14.0a0+5ce4506          pypi_0    pypi
[conda] torchviz                  0.0.2                    pypi_0    pypi

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7

Metadata


Assignees: No one assigned
Labels: module: autograd, module: pooling, needs reproduction, triaged
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
