Skip to content

Commit

Permalink
Add reduction string test for ctc_loss.
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
gchanan committed Aug 31, 2020
1 parent 5937951 commit 432a9da
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/test_nn.py
Expand Up @@ -9769,6 +9769,12 @@ def v(fn):
v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))

log_probs = torch.randn(50, 16, 20, requires_grad=True, device=device).log_softmax(2)
targets = torch.randint(1, 20, (16, 30), dtype=torch.long, device=device)
input_lengths = torch.full((16,), 50, dtype=torch.long, device=device)
target_lengths = torch.randint(10, 30, (16,), dtype=torch.long, device=device)
v(lambda: F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction=reduction))

# FIXME: should we allow derivatives on these?
v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input.detach(), reduction=reduction))
v(lambda: F.soft_margin_loss(input, input.sign().detach(), reduction=reduction))
Expand Down

0 comments on commit 432a9da

Please sign in to comment.