Add reparameterization support to OneHotCategorical
#46610
Conversation
💊 CI failures summary and remediations — as of commit 5e9278b (more details on the Dr. CI page): 3 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages: pytorch_xla_linux_bionic_py3_6_clang9_test (1/3), step "Run tests".
@neerajprad, @fritzo, could you take a look at this please?
Hmm, the idea is sound, but I believe this may break a lot of software if we make `OneHotCategorical` itself reparameterized.
@fritzo Personally I think that depends on whether there is more than one way to do reparameterization for `OneHotCategorical`.
Since this is approximate and since there are multiple ways to approximate the gradients [1], I believe it would be best to create a subclass:

class OneHotCategoricalStraightThrough(OneHotCategorical):
    has_rsample = True

    def rsample(self, sample_shape=torch.Size()):
        samples = self.sample(sample_shape)
        probs = self._categorical.probs  # note this is cached via @lazy_property
        return samples + (probs - probs.detach())
@martinjankowiak @eb8680 @karalets does this seem reasonable to you?
[1] Bengio et al (2013) https://arxiv.org/abs/1308.3432
Another estimator of the expected gradient through stochastic neurons was proposed by Hinton (2012) in his lecture 15b. The idea is simply to back-propagate through the hard threshold function (1 if the argument is positive, 0 otherwise) as if it had been the identity function. It is clearly a biased estimator, but when considering a single layer of neurons, it has the right sign (this is not guaranteed anymore when back-propagating through more hidden layers). We call it the straight-through (ST) estimator. A possible variant investigated here multiplies the gradient on hi by the derivative of the sigmoid. Better results were actually obtained without multiplying by the derivative of the sigmoid.
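As a toy illustration of the straight-through trick described above (a standalone sketch, not part of this PR; the tensors here are illustrative): the expression `hard + (probs - probs.detach())` keeps the exact hard one-hot value in the forward pass, while the backward pass routes gradients to `probs` with an identity Jacobian, as if the sample had been the probabilities themselves.

import torch

# Toy illustration of the straight-through trick at tensor level.
probs = torch.tensor([0.1, 0.2, 0.7], requires_grad=True)
hard = torch.tensor([0.0, 0.0, 1.0])        # a hard one-hot sample drawn elsewhere
st = hard + (probs - probs.detach())        # forward value equals `hard` exactly

loss = (st * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()

print(st)          # tensor([0., 0., 1.], grad_fn=<AddBackward0>)
print(probs.grad)  # tensor([1., 2., 3.]) -- the loss gradient w.r.t. the sample passes straight to probs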
yes, i think subclassing is definitely the way to go here 👍
Hi @lqf96! Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
Force-pushed from 7084968 to 1bf4edf
Force-pushed from 6ec518e to 5e9278b
@fritzo I applied your suggestions and I think this is ready for review again. There are some spurious test failures and hopefully they're not related to this PR.
LGTM. Test failures appear unrelated. Thanks for your patience.
@pytorchbot merge this please
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Add reparameterization support to the `OneHotCategorical` distribution. Samples are reparameterized based on the straight-through gradient estimator, which is proposed in the paper [Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation](https://arxiv.org/abs/1308.3432).

Pull Request resolved: pytorch#46610
Reviewed By: neerajprad
Differential Revision: D25272883
Pulled By: ezyang
fbshipit-source-id: 8364408fe108a29620694caeac377a06f0dcdd84
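For reference, a minimal usage sketch of the merged feature (assuming the subclass is exposed from torch.distributions under the name suggested in review; the variable names below are illustrative):

import torch
from torch.distributions import OneHotCategoricalStraightThrough  # name added by this PR

# End-to-end check that gradients reach the logits behind a hard one-hot sample.
logits = torch.zeros(5, requires_grad=True)
dist = OneHotCategoricalStraightThrough(logits=logits)

code = dist.rsample()              # exact one-hot vector in the forward pass
weights = torch.arange(5.0)
loss = (code * weights).sum()
loss.backward()

print(code)          # hard one-hot sample
print(logits.grad)   # straight-through gradient w.r.t. the logits (biased, per Bengio et al. 2013)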