Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce both input tensor shapes of CosineEmbeddingLoss to be equal. #112782

Closed

Conversation

tringwald
Copy link
Collaborator

…Added a test to prevent regressions.

Fixes #112732.

@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Nov 2, 2023
Copy link

pytorch-bot bot commented Nov 2, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112782

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit e4de008 with merge base 68dead4 (image):

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code did allow broadcasting. Perhaps we want to allow broadcasting as well?

@tringwald
Copy link
Collaborator Author

The previous code did allow broadcasting. Perhaps we want to allow broadcasting as well?

I've changed the checking logic a bit. Looks like we just need to make sure the second tensor has the same number of dimensions as the first one. Things like this never worked anyway:

>>> torch.__version__
'2.1.0'
>>> torch.nn.functional.cosine_embedding_loss(torch.ones(2, 10), torch.ones(10), target=torch.ones(1))
[...]
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
>>> torch.nn.functional.cosine_embedding_loss(torch.ones(10), torch.ones(2, 10), target=torch.ones(1))
RuntimeError: 1D target tensor expects 2D input tensors, but found inputs with sizes [10] and [2, 10].

Should we mention broadcasting semantics in the docs too? Right now it says Input2: (N,D) or (D), same shape as Input1.

Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Updating the docs to shortly mention broadcasting would also be good.

@tringwald
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 3, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: nn release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Segmentation fault in torch.nn.functional.cosine_embedding_loss with empty tensors and mixed dtypes
5 participants