
torch.nn.functional.binary_cross_entropy and torch.nn.functional.binary_cross_entropy_with_logits documentation is wrong in its description of target #99151

Closed
cheyennee opened this issue Apr 14, 2023 · 4 comments
Labels
module: docs – Related to our documentation, both in docs/ and docblocks
module: nn – Related to torch.nn
needs research – We need to decide whether or not this merits inclusion, based on research
triaged – This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


cheyennee commented Apr 14, 2023

πŸ“š The doc issue

The binary_cross_entropy documentation says: "target – Tensor of the same shape as input with values between 0 and 1." However, the values of target do not actually have to lie between 0 and 1; it is the values of input that must be between 0 and 1.

import torch

# Both tensors are filled with 2.0 (randint with low=2, high=3), i.e. values outside [0, 1].
arg_1_tensor = torch.randint(2, 3, [5, 5], dtype=torch.float32)
arg_2_tensor = torch.randint(2, 3, [5, 5], dtype=torch.float32)
arg_3 = None
arg_4 = "mean"
res = torch.nn.functional.binary_cross_entropy(input=arg_1_tensor, target=arg_2_tensor, weight=arg_3, reduction=arg_4)
print(res)
# The call raises before print is reached:
# RuntimeError: all elements of input should be between 0 and 1
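For contrast, here is a minimal sketch (not part of the original report) of a call that satisfies the documented constraint on input: the probabilities come from a sigmoid and the targets are ordinary 0/1 labels.

import torch
import torch.nn.functional as F

probs = torch.sigmoid(torch.randn(5, 5))                    # values strictly inside (0, 1)
labels = torch.randint(0, 2, (5, 5), dtype=torch.float32)   # 0/1 class labels
loss = F.binary_cross_entropy(input=probs, target=labels, reduction="mean")
print(loss)  # a non-negative scalar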

The binary_cross_entropy_with_logits documentation likewise says: "target – Tensor of the same shape as input with values between 0 and 1." However, the values of target do not have to lie between 0 and 1; the call below accepts targets outside that range without error.

import torch

# Both tensors hold values in {2, 3}; for *_with_logits the input is unconstrained
# (it is interpreted as logits), but here the targets are also outside [0, 1].
arg_1_tensor = torch.randint(2, 4, [10, 64], dtype=torch.float32)
arg_2_tensor = torch.randint(2, 4, [10, 64], dtype=torch.float32)
arg_3 = None
arg_4_tensor = torch.rand([64], dtype=torch.float32)
arg_4 = arg_4_tensor.clone()  # pos_weight, broadcast over the last dimension
arg_5 = "mean"
res = torch.nn.functional.binary_cross_entropy_with_logits(input=arg_1_tensor, target=arg_2_tensor, weight=arg_3, pos_weight=arg_4, reduction=arg_5)
print(res)
# res: tensor(-3.8518)  (the exact value varies with the random tensors; note the loss is negative)
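Why the loss can go negative with out-of-range targets: the per-element loss is -(pos_weight * t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x))), and once t > 1 the factor (1 - t) is negative, so the second term flips sign. Below is a minimal sketch (an assumed example, not taken from the report) that recomputes the built-in result by hand for a single element.

import torch
import torch.nn.functional as F

x = torch.tensor([3.0])   # a logit
t = torch.tensor([2.0])   # a "target" outside [0, 1]
# log(sigmoid(x)) and log(1 - sigmoid(x)) via logsigmoid for numerical stability
manual = -(t * F.logsigmoid(x) + (1 - t) * F.logsigmoid(-x))
builtin = F.binary_cross_entropy_with_logits(x, t)  # default pos_weight=1, reduction="mean"
print(manual.item(), builtin.item())  # both roughly -2.95, i.e. negative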

Suggest a potential alternative/fix

No response

cc @svekars @carljparker @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki


MylesJP commented Apr 14, 2023

Mind if I take a shot at this one for a school project?

@cpuhrsch added the "module: nn" and "triaged" labels on Apr 14, 2023
@jbschlosser (Contributor)

IIRC we've tried to introduce a [0, 1] restriction for target in the past but ran into BC issues from people using values outside the range. I'm not sure we want to encourage this by changing the docs, however.

@jbschlosser added the "module: docs" and "needs research" labels on Apr 14, 2023
@jbschlosser (Contributor)

I take it back; it looks like this was addressed and actually landed in #97814. Closing as addressed but feel free to reopen if there's more I missed.

@cheyennee (Author)

> Mind if I take a shot at this one for a school project?

Sure!
