
Add reparameterization support to OneHotCategorical #46610

Closed

Conversation

lqf96
Copy link
Contributor

@lqf96 lqf96 commented Oct 20, 2020

Add reparameterization support to the OneHotCategorical distribution. Samples are reparameterized based on the straight-through gradient estimator, which is proposed in the paper Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation.
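The core of the straight-through trick can be sketched in a few lines (a minimal illustration, not the PR's actual implementation): the forward value is the hard one-hot sample, while the `probs - probs.detach()` term is numerically zero but lets gradients flow into the underlying probabilities.

```python
import torch

# Toy probabilities we want gradients for.
probs = torch.tensor([0.1, 0.2, 0.7], requires_grad=True)
dist = torch.distributions.OneHotCategorical(probs=probs)

sample = dist.sample()                       # hard one-hot, non-differentiable
rsample = sample + (probs - probs.detach())  # same value, but differentiable

assert torch.equal(rsample.detach(), sample)  # identical in the forward pass
rsample.sum().backward()                      # gradient flows through `probs`
assert torch.allclose(probs.grad, torch.ones(3))
```

This is a biased estimator (the true gradient of a discrete sample is zero almost everywhere), but it is cheap and often works well in practice.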

@dr-ci

dr-ci bot commented Oct 20, 2020

💊 CI failures summary and remediations

As of commit 5e9278b (more details on the Dr. CI page):


  • 3/3 failures introduced in this PR

🕵️ 3 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/3)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Nov 11 04:28:23 FAIL [0.169s]: test_all_any_vs_numpy_xla_uint8 (__main__.TestTorchDeviceTypeXLA)
Nov 11 04:28:23     return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs) 
Nov 11 04:28:23   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1037, in assertEqual 
Nov 11 04:28:23     exact_dtype=exact_dtype, exact_device=exact_device) 
Nov 11 04:28:23   File "/var/lib/jenkins/workspace/xla/test/pytorch_test_base.py", line 553, in assertEqual 
Nov 11 04:28:23     return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs) 
Nov 11 04:28:23   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1151, in assertEqual 
Nov 11 04:28:23     super().assertEqual(x, y, msg=msg) 
Nov 11 04:28:23 AssertionError: True != 46 
Nov 11 04:28:23  
Nov 11 04:28:23 ====================================================================== 
Nov 11 04:28:23 FAIL [0.169s]: test_all_any_vs_numpy_xla_uint8 (__main__.TestTorchDeviceTypeXLA) 
Nov 11 04:28:23 ---------------------------------------------------------------------- 
Nov 11 04:28:23 Traceback (most recent call last): 
Nov 11 04:28:23   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 272, in instantiated_test 
Nov 11 04:28:23     result = test_fn(self, *args) 
Nov 11 04:28:23   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 19631, in test_all_any_vs_numpy 
Nov 11 04:28:23     _test_all_any(x) 
Nov 11 04:28:23   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 19616, in _test_all_any 
Nov 11 04:28:23     self.compare_with_numpy(torch.all, np.all, x) 
Nov 11 04:28:23   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 913, in compare_with_numpy 
Nov 11 04:28:23     self.assertEqual(np_result, torch_result, **kwargs) 

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (2/3)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Nov 11 03:29:12 sccache: error: couldn't connect to server
Nov 11 03:29:12 ++++ trap -p EXIT 
Nov 11 03:29:12 +++ eval 'extract_trap_cmd ' 
Nov 11 03:29:12 ++++ extract_trap_cmd 
Nov 11 03:29:12 ++++ printf '%s\n' '' 
Nov 11 03:29:12 +++ printf '%s\n' cleanup 
Nov 11 03:29:12 ++ trap -- ' 
Nov 11 03:29:12 cleanup' EXIT 
Nov 11 03:29:12 ++ [[ pytorch-linux-xenial-py3.6-gcc5.4-test != *pytorch-win-* ]] 
Nov 11 03:29:12 ++ which sccache 
Nov 11 03:29:12 ++ sccache --stop-server 
Nov 11 03:29:12 sccache: error: couldn't connect to server 
Nov 11 03:29:12 sccache: caused by: Connection refused (os error 111) 
Nov 11 03:29:12 Stopping sccache server... 
Nov 11 03:29:12 ++ true 
Nov 11 03:29:12 ++ rm /var/lib/jenkins/sccache_error.log 
Nov 11 03:29:12 ++ [[ pytorch-linux-xenial-py3.6-gcc5.4-test == *rocm* ]] 
Nov 11 03:29:12 ++ SCCACHE_ERROR_LOG=/var/lib/jenkins/sccache_error.log 
Nov 11 03:29:12 ++ SCCACHE_IDLE_TIMEOUT=1200 
Nov 11 03:29:12 ++ RUST_LOG=sccache::server=error 
Nov 11 03:29:12 ++ sccache --start-server 
Nov 11 03:29:12 sccache: Starting the server... 

See CircleCI build pytorch_linux_backward_compatibility_check_test (3/3)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Nov 11 03:29:46 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not.
Nov 11 03:29:46 processing existing schema:  __getstate__(__torch__.torch.classes.xnnpack.LinearOpContext _0) -> ((Tensor, Tensor?, Scalar?, Scalar?) _0) 
Nov 11 03:29:46 processing existing schema:  __setstate__(__torch__.torch.classes.xnnpack.LinearOpContext _0, (Tensor, Tensor?, Scalar?, Scalar?) _1) -> (None _0) 
Nov 11 03:29:46 processing existing schema:  __getstate__(__torch__.torch.classes.xnnpack.Conv2dOpContext _0) -> ((Tensor, Tensor?, int[], int[], int[], int, Scalar?, Scalar?) _0) 
Nov 11 03:29:46 processing existing schema:  __setstate__(__torch__.torch.classes.xnnpack.Conv2dOpContext _0, (Tensor, Tensor?, int[], int[], int[], int, Scalar?, Scalar?) _1) -> (None _0) 
Nov 11 03:29:46 processing existing schema:  __getstate__(__torch__.torch.classes.xnnpack.TransposeConv2dOpContext _0) -> ((Tensor, Tensor?, int[], int[], int[], int[], int, Scalar?, Scalar?) _0) 
Nov 11 03:29:46 processing existing schema:  __setstate__(__torch__.torch.classes.xnnpack.TransposeConv2dOpContext _0, (Tensor, Tensor?, int[], int[], int[], int[], int, Scalar?, Scalar?) _1) -> (None _0) 
Nov 11 03:29:46 processing existing schema:  __init__(__torch__.torch.classes._nnapi.Compilation _0) -> (None _0) 
Nov 11 03:29:46 processing existing schema:  init(__torch__.torch.classes._nnapi.Compilation _0, Tensor _1, Tensor[] _2) -> (None _0) 
Nov 11 03:29:46 processing existing schema:  run(__torch__.torch.classes._nnapi.Compilation _0, Tensor[] _1, Tensor[] _2) -> (None _0) 
Nov 11 03:29:46 processing existing schema:  __init__(__torch__.torch.classes.dist_rpc.WorkerInfo _0, str _1, int _2) -> (None _0) 
Nov 11 03:29:46 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not.  
Nov 11 03:29:46  
Nov 11 03:29:46 Broken ops: [ 
Nov 11 03:29:46 	aten::_foreach_log(Tensor[] tensors) -> (Tensor[]) 
Nov 11 03:29:46 	aten::_foreach_round(Tensor[] tensors) -> (Tensor[]) 
Nov 11 03:29:46 	aten::_foreach_sinh(Tensor[] tensors) -> (Tensor[]) 
Nov 11 03:29:46 	aten::_foreach_lgamma_(Tensor[] self) -> () 
Nov 11 03:29:46 	aten::_foreach_lgamma(Tensor[] tensors) -> (Tensor[]) 
Nov 11 03:29:46 	aten::_foreach_log10(Tensor[] tensors) -> (Tensor[]) 
Nov 11 03:29:46 	aten::_foreach_round_(Tensor[] self) -> () 
Nov 11 03:29:46 	aten::_foreach_sin(Tensor[] tensors) -> (Tensor[]) 

This comment was automatically generated by Dr. CI.

@zou3519 zou3519 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Oct 21, 2020
@zou3519
Contributor

zou3519 commented Oct 21, 2020

@neerajprad, @fritzo, could you take a look at this please?

@fritzo
Collaborator

fritzo commented Oct 21, 2020

Hmm, the idea is sound, but I believe this may break a lot of software if we make OneHotCategorical reparameterized by default. I think it would be safe to either (1) add a has_rsample kwarg to the constructor and default it to False, or (2) create a separate distribution, say OneHotCategoricalStraightThrough, that inherits from OneHotCategorical. I believe @karalets has done something similar in Pyro. Does anyone else have an interface opinion?

@lqf96
Contributor Author

lqf96 commented Oct 23, 2020

@fritzo Personally I think that depends on whether there is more than one way to do reparameterization for OneHotCategorical. If there is only one straightforward way to do it (i.e. the straight-through gradient), then it probably makes sense to simply add an enable_rsample parameter. If there is more than one common way to do it, we should go with the OneHotCategoricalStraightThrough API instead.

Collaborator

@fritzo fritzo left a comment


Since this is approximate, and since there are multiple ways to approximate the gradients [1], I believe it would be best to create a subclass:

class OneHotCategoricalStraightThrough(OneHotCategorical):
    has_rsample = True

    def rsample(self, sample_shape=torch.Size()):
        samples = self.sample(sample_shape)
        probs = self._categorical.probs  # note: this is cached via @lazy_property
        return samples + (probs - probs.detach())

@martinjankowiak @eb8680 @karalets does this seem reasonable to you?

[1] Bengio et al (2013) https://arxiv.org/abs/1308.3432

Another estimator of the expected gradient through stochastic neurons was proposed by Hinton (2012) in his lecture 15b. The idea is simply to back-propagate through the hard threshold function (1 if the argument is positive, 0 otherwise) as if it had been the identity function. It is clearly a biased estimator, but when considering a single layer of neurons, it has the right sign (this is not guaranteed anymore when back-propagating through more hidden layers). We call it the straight-through (ST) estimator. A possible variant investigated here multiplies the gradient on h_i by the derivative of the sigmoid. Better results were actually obtained without multiplying by the derivative of the sigmoid.
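For reference, the subclass proposed above was eventually merged and is exposed as `torch.distributions.OneHotCategoricalStraightThrough`. A minimal usage sketch (the toy objective below is illustrative, not from the PR):

```python
import torch
from torch.distributions import OneHotCategoricalStraightThrough

# Hard one-hot samples that still carry gradients back to the logits.
logits = torch.randn(4, 3, requires_grad=True)
dist = OneHotCategoricalStraightThrough(logits=logits)

z = dist.rsample()                    # shape (4, 3); rows are exact one-hot vectors
loss = (z * torch.arange(3.0)).sum()  # toy downstream objective
loss.backward()                       # gradients flow via the straight-through trick

assert z.shape == (4, 3)
assert torch.equal(z.sum(-1), torch.ones(4))  # each row sums to exactly 1
assert logits.grad is not None
```

Note that `rsample` here still returns hard samples in the forward pass; only the backward pass is approximated.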

torch/distributions/one_hot_categorical.py (outdated review thread, resolved)
@fritzo fritzo added the module: distributions Related to torch.distributions label Oct 27, 2020
@martinjankowiak

yes, i think subclassing is definitely the way to go here 👍

@facebook-github-bot
Contributor

Hi @lqf96!

Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@lqf96 lqf96 force-pushed the one-hot-categorical-reparameterization branch from 6ec518e to 5e9278b Compare November 11, 2020 02:54
@lqf96
Contributor Author

lqf96 commented Nov 11, 2020

@fritzo I applied your suggestions, and I think this is ready for review again. There are some spurious test failures; hopefully they're not related to this PR.

Collaborator

@fritzo fritzo left a comment


LGTM. Test failures appear unrelated. Thanks for your patience.

@fritzo
Collaborator

fritzo commented Nov 30, 2020

@pytorchbot merge this please

@pytorchbot pytorchbot added the merge-this-please Was marked for merge with @pytorchbot merge this please label Nov 30, 2020
Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ezyang merged this pull request in b006c7a.

@lqf96 lqf96 deleted the one-hot-categorical-reparameterization branch December 3, 2020 01:48
shaibagon pushed a commit to shaibagon/pytorch that referenced this pull request Dec 3, 2020
Summary:
Add reparameterization support to the `OneHotCategorical` distribution. Samples are reparameterized based on the straight-through gradient estimator, which is proposed in the paper [Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation](https://arxiv.org/abs/1308.3432).

Pull Request resolved: pytorch#46610

Reviewed By: neerajprad

Differential Revision: D25272883

Pulled By: ezyang

fbshipit-source-id: 8364408fe108a29620694caeac377a06f0dcdd84