add slope == 0 case into standard leaky relu nn test
ghstack-source-id: d7b20a65968c320da2d4de1e76d1578bf06d4eea
Pull Request resolved: #37559
lixinyu committed Apr 30, 2020
1 parent f09eb39 commit 9f36f81
Showing 2 changed files with 13 additions and 2 deletions.
7 changes: 5 additions & 2 deletions test/test_autograd.py
@@ -5624,7 +5624,7 @@ def test_ctc_loss_cudnn(self, device):
         self.assertEqual(grad_cudnn, grad_native, atol=1e-4)
 
     @skipCUDAIfRocm
-    def test_leaky_relu_inplace_with_zero_or_neg_slope(self, device):
+    def test_leaky_relu_inplace_with_neg_slope(self, device):
         a = torch.tensor([-1., 1.], device=device, requires_grad=True)
         b = torch.nn.functional.leaky_relu_(a.clone(), -2)
         with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
@@ -5635,10 +5635,13 @@ def test_leaky_relu_inplace_with_zero_or_neg_slope(self, device):
         with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
             b.backward(torch.ones(2, device=device))
 
+    @skipCUDAIfRocm
+    def test_leaky_relu_inplace_with_zero_slope(self, device):
         a = torch.tensor([-2., 0., 2.], device=device, requires_grad=True)
         b = torch.nn.functional.leaky_relu_(a.clone(), 0.0)
         b.backward(torch.ones(3, device=device))
-        self.assertEqual(a.grad, torch.tensor([0., 0., 1.], device=device))
+        expected = torch.tensor([0., 0., 1.], device=device)
+        self.assertEqual(a.grad, expected)
 
     @onlyCUDA
     def test_free_unneeded_tensor(self, device):
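For reference, the new zero-slope autograd test above boils down to this standalone sketch (the device-generic test harness is omitted and CPU is assumed): with negative_slope=0 the in-place op's backward is well defined, and the gradient is 0 for non-positive inputs and 1 for positive ones, while a negative slope still raises the "call out-of-place version" error.

import torch
import torch.nn.functional as F

# Standalone version of test_leaky_relu_inplace_with_zero_slope (CPU only).
a = torch.tensor([-2., 0., 2.], requires_grad=True)
b = F.leaky_relu_(a.clone(), 0.0)  # in-place on a clone keeps the graph back to `a`
b.backward(torch.ones(3))
print(a.grad)  # tensor([0., 0., 1.])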
8 changes: 8 additions & 0 deletions torch/testing/_internal/common_nn.py
@@ -225,6 +225,14 @@ def get_weight(m):
         check_inplace=True,
         desc='with_negval'
     ),
+    dict(
+        module_name='LeakyReLU',
+        constructor_args=(0.0,),
+        cpp_constructor_args='torch::nn::LeakyReLUOptions().negative_slope(0.0)',
+        input_fn=lambda: torch.randn(10, 10),
+        check_inplace=True,
+        desc='with_zero_negval'
+    ),
     dict(
         module_name='LogSigmoid',
         input_size=(2, 3, 4),
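The new common_nn.py entry runs LeakyReLU with a zero negative slope through the standard module tests; with slope 0 the forward pass reduces to ReLU. A minimal sketch of the configuration it exercises, outside the ModuleTest machinery (the assertions below are illustrative, not part of the harness):

import torch
import torch.nn as nn

# Same constructor arguments and input shape as the 'with_zero_negval' entry.
m = nn.LeakyReLU(0.0)                    # constructor_args=(0.0,)
x = torch.randn(10, 10)                  # input_fn=lambda: torch.randn(10, 10)
y = m(x)
assert torch.allclose(y, torch.relu(x))  # slope 0 clamps negative inputs to zero

# check_inplace=True additionally exercises the in-place variant of the module.
m_inplace = nn.LeakyReLU(0.0, inplace=True)
assert torch.allclose(m_inplace(x.clone()), y)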
