Commit

Add check for 0 to 1 inclusive for elements of target tensor in BCE loss (#97814)

TODO for @mikaylagawarecki: add BC breaking description

Fixes #87373

Pull Request resolved: #97814
Approved by: https://github.com/mikaylagawarecki
kiersten-stokes authored and ZainRizvi committed Apr 19, 2023
1 parent be8c563 commit 6b10504
Showing 5 changed files with 23 additions and 3 deletions.
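
As a quick illustration of the user-visible effect (a minimal sketch assuming a build that includes this commit; the tensor values are made up):

```python
import torch
import torch.nn.functional as F

probs = torch.rand(4)                         # valid predictions in [0, 1]
target = torch.tensor([1.5, -0.2, 0.3, 0.8])  # 1.5 and -0.2 fall outside [0, 1]

try:
    F.binary_cross_entropy(probs, target)     # CPU path now hits the new TORCH_CHECK
except RuntimeError as err:
    print(err)  # "all elements of target should be between 0 and 1"

# Thresholding yields a valid {0., 1.} target, mirroring the test updates below.
loss = F.binary_cross_entropy(probs, target.gt(0).float())
```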
4 changes: 4 additions & 0 deletions aten/src/ATen/native/Loss.cpp
@@ -282,6 +282,10 @@ Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target,
     (input_val >= 0) && (input_val <= 1),
     "all elements of input should be between 0 and 1"
 );
+TORCH_CHECK(
+    (target_val >= 0) && (target_val <= 1),
+    "all elements of target should be between 0 and 1"
+);

 // Binary cross entropy tensor is defined by the equation:
 // L = -w (y ln(x) + (1-y) ln(1-x))
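
The comment in this hunk gives the elementwise definition L = -w (y ln(x) + (1-y) ln(1-x)). A short sketch (made-up values, weight w = 1) of why targets must stay in [0, 1]:

```python
import torch
import torch.nn.functional as F

# With w = 1 the manual formula matches binary_cross_entropy for in-range x, y.
x = torch.tensor([0.9, 0.1])
y = torch.tensor([1.0, 0.0])
manual = -(y * x.log() + (1 - y) * (1 - x).log())
assert torch.allclose(manual, F.binary_cross_entropy(x, y, reduction='none'))

# For y outside [0, 1] the formula still evaluates, but it is no longer a
# proper loss: with y = 2, x = 0.99 it gives roughly -4.59, and pushing x
# toward 1 drives it to -inf -- which is what the new check guards against.
y_bad, x_near_one = 2.0, torch.tensor(0.99)
print(-(y_bad * x_near_one.log() + (1 - y_bad) * (1 - x_near_one).log()))
```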
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/Loss.cu
@@ -90,6 +90,7 @@ Tensor& binary_cross_entropy_out_cuda(const Tensor& input, const Tensor& target,
 const scalar_t neg_100 = -100;

 CUDA_KERNEL_ASSERT(input_val >= zero && input_val <= one);
+CUDA_KERNEL_ASSERT(target_val >= zero && target_val <= one);

 scalar_t log_input_val = std::log(input_val);
 scalar_t log_1_minus_input_val = std::log1p(-input_val);
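
On the CUDA path the check is a CUDA_KERNEL_ASSERT rather than a TORCH_CHECK, so a bad target surfaces as a device-side assertion (reported asynchronously by default) instead of a Python RuntimeError. A hypothetical repro sketch:

```python
import os
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")  # surface kernel asserts at the
                                                    # launching call; must be set
                                                    # before CUDA initializes
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    probs = torch.rand(4, device="cuda")
    bad_target = torch.tensor([1.5, -0.2, 0.3, 0.8], device="cuda")
    # Trips CUDA_KERNEL_ASSERT(target_val >= zero && target_val <= one) as a
    # device-side assertion; once it fires, the CUDA context is unusable.
    F.binary_cross_entropy(probs, bad_target)
```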
2 changes: 1 addition & 1 deletion test/test_nn.py
@@ -9108,7 +9108,7 @@ def v(fn):
 v(lambda: F.hinge_embedding_loss(input, input, reduction=reduction))
 v(lambda: F.poisson_nll_loss(input, input, reduction=reduction))
 v(lambda: F.gaussian_nll_loss(input, input, var, reduction=reduction))
-v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input, reduction=reduction))
+v(lambda: F.binary_cross_entropy(torch.sigmoid(input), input.gt(0).double(), reduction=reduction))
 v(lambda: F.binary_cross_entropy_with_logits(input, input, reduction=reduction))

 zeros = torch.zeros_like(input).to(torch.int64)
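
The test previously reused the raw `randn` input as the BCE target, which the new check would reject; `input.gt(0).double()` thresholds it into a valid target (BCE expects a floating-point target, hence the cast from bool):

```python
import torch

t = torch.randn(5)       # e.g. tensor([ 0.3, -1.2,  2.1, -0.4,  0.9])
print(t.gt(0))           # tensor([ True, False,  True, False,  True])
print(t.gt(0).double())  # tensor([1., 0., 1., 0., 1.], dtype=torch.float64)
```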
15 changes: 15 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -13369,6 +13369,21 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     "TestJit",
     "test_variant_consistency_jit",
 ),
+# https://github.com/pytorch/pytorch/issues/98431
+# ROCM fails with Device-side assertion `target_val >= zero && target_val <= one' failed
+# even though sample inputs for target are generated with low=0 and high=1
+DecorateInfo(
+    unittest.skip("Skipped!"),
+    "TestFwdGradients",
+    "test_fn_fwgrad_bwgrad",
+    active_if=TEST_WITH_ROCM,
+),
+DecorateInfo(
+    unittest.skip("Skipped!"),
+    "TestBwdGradients",
+    "test_fn_grad",
+    active_if=TEST_WITH_ROCM,
+)
 )),
 skips=(
     # RuntimeError: expected int at position 0, but got: Tensor
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_nn.py
@@ -226,7 +226,7 @@ def bceloss_weights_no_reduce_test():


 def bceloss_weights_no_reduce_scalar_test():
-    t = torch.randn(()).double()
+    t = torch.randn(()).gt(0).double()
     weights = torch.rand(())
     return dict(
         fullname='BCELoss_weights_no_reduce_scalar',
@@ -3930,7 +3930,7 @@ def flatten(xs):
 # Check that classification criterion work with no batch dimensions
 # List of tuples of (name, input_fn, target_fn)
 classification_criterion_no_batch = [
-    ('BCELoss', lambda: torch.sigmoid(torch.randn(9)), lambda: torch.randn(9)),
+    ('BCELoss', lambda: torch.sigmoid(torch.randn(9)), lambda: torch.randn(9).gt(0).double()),
     ('BCEWithLogitsLoss', lambda: torch.randn(9), lambda: torch.randn(9)),
     ('HingeEmbeddingLoss', lambda: torch.randn(9), lambda: torch.tensor([-1, 1, 1] * 3)),
     ('MultiLabelMarginLoss', lambda: torch.randn(4), lambda: torch.tensor([3, 0, -1, 1])),
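
For reference, what the updated `classification_criterion_no_batch` entry exercises: BCELoss on inputs with no batch dimension, now with a valid thresholded target (a sketch using the same generators as the entry above):

```python
import torch
import torch.nn as nn

input = torch.sigmoid(torch.randn(9))   # probabilities, no batch dimension
target = torch.randn(9).gt(0).double()  # valid {0., 1.} target
loss = nn.BCELoss()(input.double(), target)  # cast keeps dtypes consistent
print(loss)
```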
