# stft + abs is non-deterministic in backward path #54093
## Comments
cc @peterbell10

Is stft backward expected to be non-deterministic?

I do not know. I tried to isolate the problem and see if…
I was able to minimize the reproducer a bit more:

```python
import torch
from torch.autograd import gradgradcheck

def pad_complex_abs(tensor):
    tensor = torch.nn.functional.pad(tensor, [128, 128], 'reflect')
    tensor = tensor.transpose(0, -1).contiguous()
    tensor = torch.view_as_complex(tensor)  # last dim is 2 after the transpose
    tensor = torch.abs(tensor)
    return tensor

for _ in range(100):
    tensor = torch.randn(2, 1, 250, dtype=torch.float64, device='cuda', requires_grad=True)
    gradgradcheck(pad_complex_abs, [tensor])
```
@ngimel and I took a look. This seems to be expected; the analysis is below. What do you think, @mthrok?

The complex reproduction can be simplified to a reflection pad followed by an elementwise op (see the sketch after this comment; the same pattern also reproduces in double precision). The difference between runs is ~1.1102e-16 for sin() and ~2.7756e-17 for abs(). I cannot reproduce this on CPU. It triggers for 'replicate' and 'reflect' padding but not for 'constant' or 'circular'. The operation performed after the pad does not appear to be special: multiplying the tensor by itself will trigger it too, but multiplying by a scalar will not. Running with use_deterministic_algorithms(True) flags the pad's CUDA backward as lacking a deterministic implementation, so I guess this is expected.
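The code blocks from this comment were not preserved in the extract; a minimal sketch of the simplified reproducer it describes, assuming reflection padding followed by an elementwise op under a double-backward check (the pad widths and shapes are illustrative), might look like:

```python
import torch
from torch.autograd import gradgradcheck

def pad_then_op(tensor):
    # 'reflect' and 'replicate' padding trigger the mismatch;
    # 'constant' and 'circular' do not
    tensor = torch.nn.functional.pad(tensor, [128, 128], 'reflect')
    return torch.sin(tensor)  # sin(), abs(), or tensor * tensor all trigger it

x = torch.randn(2, 1, 250, dtype=torch.float64, device='cuda', requires_grad=True)
gradgradcheck(pad_then_op, [x])
```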
Follow-up thought: since gradcheck is effectively checking determinism here, there may be an opportunity to hook into determinism metadata in the OpInfo-generated tests, or to use the use_deterministic_algorithms() flag.
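As an illustration (not part of the original comment): the flag makes ops that lack a deterministic implementation raise an error instead of silently varying between runs, and based on the analysis above the reflection pad's CUDA backward is expected to raise here:

```python
import torch

# Globally require deterministic implementations.
torch.use_deterministic_algorithms(True)

x = torch.randn(2, 1, 250, device='cuda', requires_grad=True)
y = torch.nn.functional.pad(x, [128, 128], 'reflect')
# Per the analysis above, this backward should raise a RuntimeError,
# since the CUDA backward of reflection padding is non-deterministic.
y.sum().backward()
```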
Thanks for looking into this and for the detailed explanation. It makes sense.
## 🐛 Bug
It seems that when `torch.stft(return_complex=True)` is followed by `torch.abs`, `gradgradcheck` fails, but individually neither fails.

## To Reproduce

Steps to reproduce the behavior (see the sketch below):
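The original repro script is collapsed in the issue and not preserved in this extract; a minimal sketch of what it plausibly looked like, assuming an `stft(return_complex=True)` followed by `abs` under `gradgradcheck` (the shapes, `n_fft`, and window choice are illustrative):

```python
import torch
from torch.autograd import gradgradcheck

def test_stft_with_abs():
    n_fft = 400

    def stft_abs(waveform):
        window = torch.hann_window(n_fft, dtype=waveform.dtype, device=waveform.device)
        spec = torch.stft(waveform, n_fft=n_fft, window=window, return_complex=True)
        return spec.abs()

    # double precision is needed for gradgradcheck's numerical comparisons
    waveform = torch.randn(2, 1000, dtype=torch.float64, device='cuda', requires_grad=True)
    gradgradcheck(stft_abs, [waveform])

test_stft_with_abs()
```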
The `test_stft_with_abs()` test fails with the following message; I also tried with `return_complex=False`, but `gradgradcheck` did not fail.

## Expected behavior

`gradgradcheck` should pass for the `stft` + `abs` case.

## Environment
## Additional context
In pytorch/audio#1340, I was adding a test that runs `gradgradcheck` on `torchaudio.transforms.Spectrogram`. The CI reported a non-deterministic error.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @ngimel @mruberry @kurtamohler