Add reparameterization support to OneHotCategorical (#46610)
Summary:
Add reparameterization support to the `OneHotCategorical` distribution via a new `OneHotCategoricalStraightThrough` subclass. Samples are reparameterized using the straight-through gradient estimator proposed in [Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation](https://arxiv.org/abs/1308.3432) (Bengio et al., 2013).
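
For intuition, here is a minimal standalone sketch of the straight-through trick (illustrative only, not this PR's code; the probabilities and weights are made up). The forward pass emits the hard one-hot sample, while the backward pass routes gradients as if the sample were `probs`:

```python
import torch
import torch.nn.functional as F

# Hypothetical probabilities, just for illustration.
probs = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)

# Hard (non-differentiable) one-hot sample.
index = torch.multinomial(probs.detach(), 1).squeeze()
hard = F.one_hot(index, num_classes=3).float()

# Straight-through: the value equals `hard` because (probs - probs.detach())
# is zero, but the expression keeps `probs` in the autograd graph.
st_sample = hard + (probs - probs.detach())

loss = (st_sample * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()
print(st_sample.detach())  # an exact one-hot vector
print(probs.grad)          # tensor([1., 2., 3.]): gradient of the surrogate
```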

Pull Request resolved: #46610

Reviewed By: neerajprad

Differential Revision: D25272883

Pulled By: ezyang

fbshipit-source-id: 8364408fe108a29620694caeac377a06f0dcdd84
lqf96 authored and facebook-github-bot committed Dec 2, 2020
1 parent de46369 · commit b006c7a
Showing 3 changed files with 38 additions and 4 deletions.
24 changes: 21 additions & 3 deletions test/distributions/test_distributions.py
@@ -47,8 +47,9 @@
 Independent, Kumaraswamy, Laplace, LogisticNormal,
 LogNormal, LowRankMultivariateNormal,
 MixtureSameFamily, Multinomial, MultivariateNormal,
-NegativeBinomial, Normal, OneHotCategorical, Pareto,
-Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
+NegativeBinomial, Normal,
+OneHotCategorical, OneHotCategoricalStraightThrough,
+Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
 StudentT, TransformedDistribution, Uniform,
 VonMises, Weibull, constraints, kl_divergence)
 from torch.distributions.constraint_registry import transform_to
@@ -345,6 +346,11 @@ def is_all_nan(tensor):
         {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
         {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
+    Example(OneHotCategoricalStraightThrough, [
+        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
+        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
+        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
+    ]),
     Example(Pareto, [
         {
             'scale': 1.0,
@@ -614,6 +620,10 @@ def is_all_nan(tensor):
         {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
         {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
+    Example(OneHotCategoricalStraightThrough, [
+        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
+        {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
+    ]),
     Example(Pareto, [
         {
             'scale': 0.0,
@@ -3830,13 +3840,21 @@ def test_entropy_exponential_family(self):
 
 class TestConstraints(TestCase):
     def test_params_constraints(self):
+        normalize_probs_dists = (
+            Categorical,
+            Multinomial,
+            OneHotCategorical,
+            OneHotCategoricalStraightThrough,
+            RelaxedOneHotCategorical
+        )
+
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
                 for name, value in param.items():
                     if isinstance(value, numbers.Number):
                         value = torch.tensor([value])
-                    if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
+                    if Dist in normalize_probs_dists and name == 'probs':
                         # These distributions accept positive probs, but elsewhere we
                         # use a stricter constraint to the simplex.
                         value = value / value.sum(-1, True)
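
A short illustration of why the test normalizes `probs` (a hedged sketch, not part of the diff; the raw values are made up): these distributions accept unnormalized positive `probs` and renormalize internally, so the constraint check must compare against the normalized values.

```python
import torch
from torch.distributions import OneHotCategoricalStraightThrough

raw = torch.tensor([0.1, 0.2, 0.3])  # positive but unnormalized
dist = OneHotCategoricalStraightThrough(probs=raw)

# The distribution renormalizes internally...
print(dist.probs)               # tensor([0.1667, 0.3333, 0.5000])
# ...matching the normalization applied in the test above.
print(raw / raw.sum(-1, True))  # tensor([0.1667, 0.3333, 0.5000])
```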
3 changes: 2 additions & 1 deletion torch/distributions/__init__.py
@@ -101,7 +101,7 @@
 from .multivariate_normal import MultivariateNormal
 from .negative_binomial import NegativeBinomial
 from .normal import Normal
-from .one_hot_categorical import OneHotCategorical
+from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
 from .pareto import Pareto
 from .poisson import Poisson
 from .relaxed_bernoulli import RelaxedBernoulli
@@ -144,6 +144,7 @@
     'NegativeBinomial',
     'Normal',
     'OneHotCategorical',
+    'OneHotCategoricalStraightThrough',
     'Pareto',
     'RelaxedBernoulli',
     'RelaxedOneHotCategorical',
15 changes: 15 additions & 0 deletions torch/distributions/one_hot_categorical.py
@@ -96,3 +96,18 @@ def enumerate_support(self, expand=True):
         if expand:
             values = values.expand((n,) + self.batch_shape + (n,))
         return values
+
+class OneHotCategoricalStraightThrough(OneHotCategorical):
+    r"""
+    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the
+    straight-through gradient estimator from [1].
+    [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
+    (Bengio et al, 2013)
+    """
+    has_rsample = True
+
+    def rsample(self, sample_shape=torch.Size()):
+        samples = self.sample(sample_shape)
+        probs = self._categorical.probs  # cached via @lazy_property
+        return samples + (probs - probs.detach())
+
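As a closing sanity check, a usage sketch assuming a PyTorch build containing this commit (the objective below is arbitrary): `rsample` on the new class yields a hard one-hot sample whose backward pass reaches the logits.

```python
import torch
from torch.distributions import OneHotCategoricalStraightThrough

logits = torch.zeros(3, requires_grad=True)
dist = OneHotCategoricalStraightThrough(logits=logits)

sample = dist.rsample()                    # hard one-hot in the forward pass
loss = (sample * torch.arange(3.0)).sum()  # arbitrary downstream objective
loss.backward()

print(sample)       # e.g. tensor([0., 1., 0.], grad_fn=<AddBackward0>)
print(logits.grad)  # non-None: gradients arrive via the straight-through path
```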