
stft + abs is non-deterministic in backward path #54093

Closed
mthrok opened this issue Mar 16, 2021 · 8 comments
Labels
module: autograd Related to torch.autograd, and the autograd engine in general module: cuda Related to torch.cuda, and CUDA support in general module: determinism module: padding triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mthrok
Contributor

mthrok commented Mar 16, 2021

🐛 Bug

It seems that when torch.stft(return_complex=True) is followed by torch.abs, gradgradcheck fails, while each operation on its own passes.

To Reproduce


Steps to reproduce the behavior:

import torch

from torch.autograd import gradgradcheck


def stft_with_abs(tensor):
    tensor = torch.stft(input=tensor, n_fft=256, return_complex=True)
    tensor = tensor.abs()
    return tensor


def abs_(tensor):
    return tensor.abs()


def stft(tensor):
    return torch.stft(tensor, n_fft=256, return_complex=True)


def test_stft_with_abs():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor.requires_grad = True

        tensor = tensor.to(dtype=torch.float64, device='cuda')
        assert gradgradcheck(stft_with_abs, [tensor])


def test_stft_only():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor.requires_grad = True

        tensor = tensor.to(dtype=torch.float64, device='cuda')
        assert gradgradcheck(stft, [tensor])


def test_abs_only():
    for i in range(100):
        print(i, '\r', end='')
        tensor = torch.randn([2, 250])
        tensor = tensor.to(dtype=torch.float64, device='cuda')
        tensor = torch.stft(input=tensor, n_fft=256, return_complex=True)

        tensor.requires_grad = True
        assert gradgradcheck(abs_, [tensor])


# test_stft_only()  # does not fail
# test_abs_only()  # does not fail
test_stft_with_abs()

test_stft_with_abs() fails with the following message:

Traceback (most recent call last):
  File "foo.py", line 63, in <module>
    test_stft_with_abs()
  File "foo.py", line 39, in test_stft_with_abs
    assert gradgradcheck(stft_with_abs, [tensor])
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 674, in gradgradcheck
    return gradcheck(
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 479, in gradcheck
    return not_reentrant_error()
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 476, in not_reentrant_error
    return fail_test(error_msg)
  File "/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 367, in fail_test
    raise RuntimeError(msg)
RuntimeError: Backward is not reentrant, i.e., running backward with same input and grad_output multiple times gives different values, although analytical gradient matches numerical gradient. The tolerance for nondeterminism was 0.0.

I also tried with return_complex=False but gradgradcheck did not fail.
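
For reference, a minimal sketch of that variant (illustrative only, not necessarily the exact code I ran). With return_complex=False, stft returns a real tensor whose last dimension holds the real and imaginary parts, so abs() is an elementwise real abs rather than a complex magnitude:

import torch
from torch.autograd import gradgradcheck


def stft_with_abs_real(tensor):
    # return_complex=False: real output of shape (..., n_fft // 2 + 1, n_frames, 2)
    tensor = torch.stft(input=tensor, n_fft=256, return_complex=False)
    return tensor.abs()


tensor = torch.randn([2, 250], dtype=torch.float64, device='cuda', requires_grad=True)
assert gradgradcheck(stft_with_abs_real, [tensor])  # did not fail in my runs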

Expected behavior

gradgradcheck should pass for the stft + abs case.

Environment

Collecting environment information...
PyTorch version: 1.9.0.dev20210316
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.18.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 450.80.02
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.9.0.dev20210316
[pip3] torchaudio==0.9.0a0+ba61c9b
[pip3] torchtext==0.9.0a0+c072ba6
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] magma-cuda101             2.5.2                         1    pytorch
[conda] mkl                       2020.2                      256
[conda] mkl-include               2020.4             h726a3e6_304    conda-forge
[conda] mkl-service               2.3.0            py38he904b0f_0
[conda] mkl_fft                   1.3.0            py38h54f3939_0
[conda] mkl_random                1.1.1            py38h0573a6f_0
[conda] numpy                     1.19.2           py38h54aff64_0
[conda] numpy-base                1.19.2           py38hfa32c7d_0
[conda] pytorch                   1.9.0.dev20210316 py3.8_cuda10.1_cudnn7.6.3_0    pytorch-nightly
[conda] pytorch-sphinx-theme      0.0.24                    dev_0    <develop>
[conda] torch                     1.7.1                    pypi_0    pypi
[conda] torchaudio                0.9.0a0+ba61c9b           dev_0    <develop>
[conda] torchtext                 0.9.0a0+c072ba6           dev_0    <develop>

Additional context

In pytorch/audio#1340, I was adding a test that runs gradgradcheck on torchaudio.transforms.Spectrogram. The CI reported a non-determinism error.
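
For context, a minimal sketch of the kind of test involved (illustrative only; the transform arguments here are placeholders, not the exact values used in that PR):

import torch
import torchaudio
from torch.autograd import gradgradcheck

# power=1.0 gives the magnitude spectrogram, i.e. stft followed by abs.
transform = torchaudio.transforms.Spectrogram(n_fft=256, power=1.0)
transform = transform.to(device='cuda', dtype=torch.float64)

waveform = torch.randn(2, 250, dtype=torch.float64, device='cuda', requires_grad=True)
# On CI this intermittently tripped gradgradcheck's non-determinism check
# with the "Backward is not reentrant" error shown above.
gradgradcheck(transform, [waveform])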

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @ngimel @mruberry @kurtamohler

@mthrok
Contributor Author

mthrok commented Mar 16, 2021

cc @albanD @anjali411

@anjali411 anjali411 added module: autograd Related to torch.autograd, and the autograd engine in general module: determinism triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Mar 16, 2021
@anjali411
Contributor

cc. @peterbell10

@albanD
Collaborator

albanD commented Mar 16, 2021

Is stft backward expected to be non-deterministic?

@mthrok
Contributor Author

mthrok commented Mar 16, 2021

Is stft backward expected to be non-deterministic?

I do not know. I tried to isolate the problem by checking whether stft alone would fail gradgradcheck, but I could not make it fail.

@peterbell10
Collaborator

peterbell10 commented Mar 17, 2021

I was able to minimize the reproducer a bit more:

import torch
from torch.autograd import gradgradcheck


def pad_complex_abs(tensor):
    # Reflection-pad the last dimension, move the size-2 leading dimension to
    # the end, reinterpret the pairs as complex values, and take the magnitude.
    tensor = torch.nn.functional.pad(tensor, [128, 128], 'reflect')
    tensor = tensor.transpose(0, -1).contiguous()
    tensor = torch.view_as_complex(tensor)
    tensor = torch.abs(tensor)
    return tensor

for _ in range(100):
    tensor = torch.randn(2, 1, 250, dtype=torch.float64, device='cuda', requires_grad=True)
    gradgradcheck(pad_complex_abs, [tensor])

@mruberry
Collaborator

mruberry commented Mar 17, 2021

@ngimel and I took a look. This seems to be expected. Analysis is below. What do you think, @mthrok?

The complex reproduction (see below for a reproduction in double) can be simplified to:

import torch
from torch.autograd import gradgradcheck


def foo(tensor):
    tensor = torch.nn.functional.pad(tensor, [128, 128], 'reflect')
    tensor = torch.sin(tensor)
    return tensor


tensor = torch.randn(2, 1, 250, dtype=torch.complex128, device='cuda', requires_grad=True)
gradgradcheck(foo, [tensor])

The difference is ~1.1102e-16 for sin() and ~2.7756e-17 for abs(). I cannot reproduce this on CPU. It triggers for 'replicate' and 'reflect' padding but not for 'constant' or 'circular.'

The operation performed after the pad doesn't appear to be that special. Multiplying the tensor with itself will do it, too, but multiplying by a scalar will not.

Running with use_deterministic_algorithms(True) reveals:

RuntimeError: replication_pad1d_backward_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.

So I guess this is expected. I can also replicate this in double precision with:

def foo(tensor):
    tensor = torch.nn.functional.pad(tensor, [128, 128], 'replicate')
    tensor = tensor * tensor
    return tensor


tensor = torch.randn(4, 32, 32, dtype=torch.double, device='cuda', requires_grad=True)
gradgradcheck(foo, [tensor])
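
For completeness, a sketch of running that same double-precision reproducer with deterministic algorithms enforced (a reconstruction, not the exact snippet from this thread); instead of producing slightly different results from run to run, the pad backward then raises the RuntimeError quoted above:

import torch
from torch.autograd import gradgradcheck

torch.use_deterministic_algorithms(True)


def foo(tensor):
    tensor = torch.nn.functional.pad(tensor, [128, 128], 'replicate')
    return tensor * tensor


tensor = torch.randn(4, 32, 32, dtype=torch.double, device='cuda', requires_grad=True)
# The backward pass dispatches to replication_pad1d_backward_cuda, which has no
# deterministic implementation, so this raises the RuntimeError quoted above.
gradgradcheck(foo, [tensor])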

@mruberry mruberry added complex_autograd module: cuda Related to torch.cuda, and CUDA support in general module: padding and removed complex_autograd labels Mar 17, 2021
@mruberry
Collaborator

Follow-up thought: if gradcheck is checking determinism, there may be an opportunity to hook into determinism metadata in the OpInfo generated tests or use the use_deterministic_algorithms() flag.
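
For reference, gradcheck and gradgradcheck already expose a nondet_tol argument, which is the "tolerance for nondeterminism" of 0.0 quoted in the error above. A minimal sketch of loosening it for this case (illustrative only, not a proposed fix; the tolerance value is an assumption chosen to sit well above the ~1e-16 differences observed):

import torch
from torch.autograd import gradgradcheck


def stft_with_abs(tensor):
    return torch.stft(tensor, n_fft=256, return_complex=True).abs()


tensor = torch.randn(2, 250, dtype=torch.float64, device='cuda', requires_grad=True)
# Allow tiny run-to-run differences in the CUDA pad backward instead of
# requiring bitwise-identical results across repeated backward calls.
gradgradcheck(stft_with_abs, [tensor], nondet_tol=1e-10)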

@mthrok
Contributor Author

mthrok commented Mar 17, 2021

@mruberry

Thanks for looking into this and for the detailed explanation. It makes sense.
This is good enough for me, as the source of the non-deterministic behavior has been identified. Thanks!
