Add reparameterization support to OneHotCategorical #46610

Closed
24 changes: 21 additions & 3 deletions test/distributions/test_distributions.py
@@ -47,8 +47,9 @@
                                  Independent, Laplace, LogisticNormal,
                                  LogNormal, LowRankMultivariateNormal,
                                  MixtureSameFamily, Multinomial, MultivariateNormal,
-                                 NegativeBinomial, Normal, OneHotCategorical, Pareto,
-                                 Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
+                                 NegativeBinomial, Normal,
+                                 OneHotCategorical, OneHotCategoricalStraightThrough,
+                                 Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                  StudentT, TransformedDistribution, Uniform,
                                  VonMises, Weibull, constraints, kl_divergence)
 from torch.distributions.constraint_registry import transform_to
@@ -335,6 +336,11 @@ def is_all_nan(tensor):
         {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
         {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
+    Example(OneHotCategoricalStraightThrough, [
+        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
+        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
+        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
+    ]),
     Example(Pareto, [
         {
             'scale': 1.0,
@@ -604,6 +610,10 @@ def is_all_nan(tensor):
         {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
         {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
+    Example(OneHotCategoricalStraightThrough, [
+        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
+        {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
+    ]),
     Example(Pareto, [
         {
             'scale': 0.0,
@@ -3755,13 +3765,21 @@ def test_entropy_exponential_family(self):
 
 class TestConstraints(TestCase):
     def test_params_constraints(self):
+        normalize_probs_dists = (
+            Categorical,
+            Multinomial,
+            OneHotCategorical,
+            OneHotCategoricalStraightThrough,
+            RelaxedOneHotCategorical
+        )
+
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
                 for name, value in param.items():
                     if isinstance(value, numbers.Number):
                         value = torch.tensor([value])
-                    if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
+                    if Dist in normalize_probs_dists and name == 'probs':
                         # These distributions accept positive probs, but elsewhere we
                         # use a stricter constraint to the simplex.
                         value = value / value.sum(-1, True)
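A note on the simplex normalization above: these `probs`-based distributions accept any positive weight vector and rescale it internally along the last dimension, so the test only needs to project the raw parameters onto the simplex before checking the stricter `simplex` constraint. A minimal sketch of that behavior (illustrative, not part of the PR):

import torch
from torch.distributions import OneHotCategorical

# Positive but unnormalized probs are accepted; the underlying
# Categorical rescales them to sum to one along the last dim.
d = OneHotCategorical(probs=torch.tensor([1.0, 2.0, 1.0]))
print(d.probs)  # tensor([0.2500, 0.5000, 0.2500])

# The test's projection does the same normalization up front:
value = torch.tensor([1.0, 2.0, 1.0])
value = value / value.sum(-1, True)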
3 changes: 2 additions & 1 deletion torch/distributions/__init__.py
@@ -100,7 +100,7 @@
 from .multivariate_normal import MultivariateNormal
 from .negative_binomial import NegativeBinomial
 from .normal import Normal
-from .one_hot_categorical import OneHotCategorical
+from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
 from .pareto import Pareto
 from .poisson import Poisson
 from .relaxed_bernoulli import RelaxedBernoulli
@@ -142,6 +142,7 @@
     'NegativeBinomial',
     'Normal',
     'OneHotCategorical',
+    'OneHotCategoricalStraightThrough',
     'Pareto',
     'RelaxedBernoulli',
     'RelaxedOneHotCategorical',
15 changes: 15 additions & 0 deletions torch/distributions/one_hot_categorical.py
@@ -96,3 +96,18 @@ def enumerate_support(self, expand=True):
         if expand:
             values = values.expand((n,) + self.batch_shape + (n,))
         return values
+
+class OneHotCategoricalStraightThrough(OneHotCategorical):
+    r"""
+    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight-
+    through gradient estimator from [1].
+
+    [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
+    (Bengio et al., 2013)
+    """
+    has_rsample = True
+
+    def rsample(self, sample_shape=torch.Size()):
+        samples = self.sample(sample_shape)
+        probs = self._categorical.probs  # cached via @lazy_property
+        return samples + (probs - probs.detach())
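To see what the straight-through `rsample` buys you: the forward value is an ordinary discrete one-hot sample (the `probs - probs.detach()` term is identically zero in value), while the backward pass routes the sample's gradient directly into `probs`, and from there into `logits`. A minimal usage sketch (the toy objective is invented for illustration):

import torch
from torch.distributions import OneHotCategoricalStraightThrough

logits = torch.zeros(3, requires_grad=True)
dist = OneHotCategoricalStraightThrough(logits=logits)

sample = dist.rsample()  # one-hot in value, e.g. [0., 1., 0.]
loss = (sample * torch.tensor([1.0, 2.0, 3.0])).sum()  # toy objective
loss.backward()

# The gradient reaches `logits` even though sampling is discrete:
# backward treats the sampling step as the identity on probs = softmax(logits).
print(logits.grad)

The resulting estimator is biased (the backward pass pretends the sampling step was the identity on `probs`), which is the usual straight-through trade-off against the high variance of score-function estimators.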