
Add autograd test for T.Spectrogram/T.MelSpectrogram #1340

Merged: 4 commits merged into pytorch:master from autograd-test-spectrogram on Mar 31, 2021

Conversation

@mthrok (Collaborator) commented Mar 2, 2021

Add tests for checking autograd. Part of #1337.



@skipIfNoCuda
class AutogradCUDATest(AutogradTestCase, PytorchTestCase):


I think it would be nice to have a device-generic class similar to what we have in PyTorch for autograd tests: https://github.com/pytorch/pytorch/blob/a3a2150409472fe6fa66f3abe9e795303786252c/test/test_autograd.py#L7931-L7937
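For reference, a minimal self-contained sketch of that device-generic pattern (this is not the instantiate_device_type_tests machinery PyTorch core uses, and the class and test names here are made up for illustration):

import unittest

import torch


class AutogradDeviceTestBase:
    """Device-generic base; one concrete unittest class per device is generated below."""
    device = None

    def test_requires_grad_is_preserved(self):
        # Placeholder check standing in for the real gradcheck-based tests.
        x = torch.randn(2, 4, dtype=torch.float64, device=self.device, requires_grad=True)
        self.assertTrue(x.requires_grad)


# Generate AutogradTestCPU / AutogradTestCUDA depending on what is available.
_devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
for _device in _devices:
    _name = f'AutogradTest{_device.upper()}'
    globals()[_name] = type(_name, (AutogradDeviceTestBase, unittest.TestCase), {'device': _device})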

@mthrok (Collaborator, Author) replied Mar 3, 2021


Sure, but let's defer changing the test infrastructure. This is the common pattern used in torchaudio, so if we are going to change it, I would like to change them all in one go.

@mthrok mthrok force-pushed the autograd-test-spectrogram branch from 7454333 to b3fbe41 on March 3, 2021 18:23
@mthrok mthrok changed the title Add autograd test for T.Spectrogram Add autograd test for T.Spectrogram/T.MelSpectrogram Mar 3, 2021
@mthrok (Collaborator, Author) commented Mar 3, 2021

@anjali411 I have also added T.MelSpectrogram. Let me know what you think.

        for i in inputs:
            i.requires_grad = True
            inputs_.append(i.to(dtype=self.dtype, device=self.device))
        assert gradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)


self.assertTrue(gradcheck(transform, inputs_))



class AutogradTestCase(TestBaseMixin):
    def assert_grad(self, transform, *inputs, eps=1e-06, atol=1e-05, rtol=0.001):


You don't need to define the eps, atol, and rtol values here.

            i.requires_grad = True
            inputs_.append(i.to(dtype=self.dtype, device=self.device))
        assert gradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)
        assert gradgradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)


self.assertTrue(gradgradcheck(transform, inputs_))
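Putting the two suggestions in this thread together, the helper would look roughly like this (a sketch only, not the exact final code; TestBaseMixin is the existing torchaudio test mixin that supplies self.dtype and self.device):

from torch.autograd import gradcheck, gradgradcheck


class AutogradTestCase(TestBaseMixin):
    def assert_grad(self, transform, *inputs):
        inputs_ = []
        for i in inputs:
            i.requires_grad = True
            inputs_.append(i.to(dtype=self.dtype, device=self.device))
        # Rely on gradcheck/gradgradcheck defaults for eps/atol/rtol and report
        # failures through unittest so the runner can attribute them properly.
        self.assertTrue(gradcheck(transform, inputs_))
        self.assertTrue(gradgradcheck(transform, inputs_))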

test/torchaudio_unittest/transforms/autograd_test_impl.py (outdated conversation; resolved)
        for i in inputs:
            i.requires_grad = True
            inputs_.append(i.to(dtype=self.dtype, device=self.device))
        assert gradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)

nit: use self.assertTrue()

@mthrok (Collaborator, Author) replied:

Is there a reason that assertTrue is preferred? I always find the message returned by assertTrue (AssertionError: False is not true) to be nonsense, so I always use assert.

Reviewer reply:

I was mostly suggesting using the self. variant of the assert instead of the built-in one. The reason is that it makes the test framework aware of the check, which works better than it catching the error from a bare assert, and it allows the framework to give better messages when running the whole test suite.
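For illustration, the unittest-aware variant also lets you attach a readable failure message that the runner surfaces (the msg string here is made up):

self.assertTrue(
    gradcheck(transform, inputs_),
    msg=f'gradcheck failed for {type(transform).__name__}',
)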

Two further outdated conversations on test/torchaudio_unittest/transforms/autograd_test_impl.py (resolved)
@mthrok (Collaborator, Author) commented Mar 4, 2021

@albanD @anjali411

What does this mean? How much should we be worried?

RuntimeError: Backward is not reentrant, i.e., running backward with same input and grad_output multiple times gives different values, although analytical gradient matches numerical gradient. The tolerance for nondeterminism was 0.0.

https://app.circleci.com/pipelines/github/pytorch/audio/5301/workflows/782e2d49-6f99-4e98-a240-b6597557d855/jobs/176159

@albanD commented Mar 4, 2021

That means that your backward function is not deterministic.
If you expect it to be deterministic, then you might want to double-check.
If this is expected, then you can increase the threshold by passing nondet_tol (a float) to gradcheck with the difference that you expect to see there (it defaults to 0.0). You should also document that the backward is not deterministic (at least, that is what we do in core: https://pytorch.org/docs/stable/generated/torch.set_deterministic.html#torch.set_deterministic).
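A sketch of that suggestion, using the nondet_tol keyword of gradcheck (the 1e-10 value is the one this PR eventually settles on below; the right tolerance is op-specific):

from torch.autograd import gradcheck

self.assertTrue(gradcheck(transform, inputs_, nondet_tol=1e-10))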

@mthrok (Collaborator, Author) commented Mar 4, 2021

> That means that your backward function is not deterministic.
> If you expect it to be deterministic, then you might want to double-check.
> If this is expected, then you can increase the threshold by passing nondet_tol (a float) to gradcheck with the difference that you expect to see there (it defaults to 0.0). You should also document that the backward is not deterministic (at least, that is what we do in core: https://pytorch.org/docs/stable/generated/torch.set_deterministic.html#torch.set_deterministic).

Thanks for the info. This is an interesting finding. Let me dig into that.

@mthrok (Collaborator, Author) commented Mar 17, 2021

@albanD Following the finding in pytorch/pytorch#54093, I added nondet_tol=1e-10. Do you think this is tight enough?

@albanD commented Mar 18, 2021

Interesting investigation! In that case, yes, 1e-10 should be good enough!

        self.assert_grad(transform, [waveform], nondet_tol=1e-10)

    def test_melspectrogram(self):
        # replication_pad1d_backward_cuda is not deteministic and


nit: "is not deteministic" should read "is nondeterministic"

@anjali411 left a comment

LGTM!

@mthrok mthrok merged commit e4a0bd2 into pytorch:master Mar 31, 2021
@mthrok mthrok deleted the autograd-test-spectrogram branch March 31, 2021 01:10
@mthrok (Collaborator, Author) commented Mar 31, 2021

Thanks!

@mthrok mthrok mentioned this pull request Apr 2, 2021
@mthrok mthrok modified the milestones: Complex Tensor Migration, v0.9 Apr 5, 2021
mthrok pushed a commit to mthrok/audio that referenced this pull request Dec 13, 2022
Co-authored-by: Brian Johnson <brianjo@fb.com>