AdaptiveLogSoftmaxWithLoss no_batch_dim support #69054
[ghstack-poisoned]
💊 CI failures summary and remediations
As of commit e54a712 (more details on the Dr. CI page): Looks good so far! There are no failures yet.
This comment was automatically generated by Dr. CI.
# test no_batch_dim support
asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)
x = torch.randn(1, 16)
y = torch.tensor([17])
x2 = x.squeeze(0)
y2 = y.squeeze(0)
self.assertEqual(asfm(x, y).output.squeeze(0), asfm(x2, y2).output)
Thanks for adding the tests! Note that there is some precedent for using a reference function alongside the old-style dictionary-based tests (see here for an example).
It won't work directly here because this module returns multiple outputs, but perhaps this can be made a bit more consistent with what's been done before.
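One way the reference-function idea could be adapted to a module with multiple outputs is sketched below. This is only an illustration, not the helper used in the PyTorch test suite: `no_batch_dim_reference` is a hypothetical name, and the sketch assumes a PyTorch build that already includes the no_batch_dim support added by this PR. The helper adds a batch dimension, runs the module, and removes the batch dimension from the per-sample `output` field while passing the scalar `loss` through unchanged.

```python
import torch
import torch.nn as nn


def no_batch_dim_reference(module, x, target):
    """Hypothetical reference: run the batched path on a batch of size one.

    Returns the per-sample output with the batch dimension removed and the
    loss, which is already a scalar in the batched case.
    """
    batched = module(x.unsqueeze(0), target.unsqueeze(0))
    return batched.output.squeeze(0), batched.loss


torch.manual_seed(0)
asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, [5, 10, 15], div_value=2.)

# Unbatched input of shape (in_features,) and a 0-dim target.
x = torch.randn(16)
y = torch.tensor(17)

ref_output, ref_loss = no_batch_dim_reference(asfm, x, y)
result = asfm(x, y)

# Both fields of the returned named tuple should match the reference.
assert torch.allclose(result.output, ref_output)
assert torch.allclose(result.loss, ref_loss)
```

Comparing each field separately keeps the check explicit about which outputs carry a batch dimension and which do not.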
Thanks for tackling this! I left a few nitpicky comments below but it's looking pretty good :)
I think we'll also need some tests to verify the logic on the C++ side (both functional and module forms), since that logic is duplicated there.
Looking pretty good, just need a test or two on the C++ side for the C++ module changes :)
@george-qi has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Differential Revision: [D33200166](https://our.internmc.facebook.com/intern/diff/D33200166) [ghstack-poisoned]
LGTM! Thanks :)
Summary:
Pull Request resolved: #69054

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D33200166

Pulled By: george-qi

fbshipit-source-id: 9d953744351a25f372418d2a64e8402356d1e9b7
Stack from ghstack:
Differential Revision: D33200166