From 432a9da3cfd5c0ea869eafeced392888ca693cda Mon Sep 17 00:00:00 2001
From: Gregory Chanan
Date: Mon, 31 Aug 2020 11:10:28 -0700
Subject: [PATCH] Add reduction string test for ctc_loss.

[ghstack-poisoned]
---
 test/test_nn.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/test/test_nn.py b/test/test_nn.py
index 2a350b19a3ab..ea09b6f0ca83 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9769,6 +9769,12 @@ def v(fn):
             v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
             v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))
 
+            log_probs = torch.randn(50, 16, 20, requires_grad=True, device=device).log_softmax(2)
+            targets = torch.randint(1, 20, (16, 30), dtype=torch.long, device=device)
+            input_lengths = torch.full((16,), 50, dtype=torch.long, device=device)
+            target_lengths = torch.randint(10, 30, (16,), dtype=torch.long, device=device)
+            v(lambda: F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction=reduction))
+
             # FIXME: should we allow derivatives on these?
             v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input.detach(), reduction=reduction))
             v(lambda: F.soft_margin_loss(input, input.sign().detach(), reduction=reduction))