diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index 68f5922753a3..fabb53ed0094 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -47,8 +47,9 @@
                                  Independent, Laplace, LogisticNormal,
                                  LogNormal, LowRankMultivariateNormal,
                                  MixtureSameFamily, Multinomial, MultivariateNormal,
-                                 NegativeBinomial, Normal, OneHotCategorical, Pareto,
-                                 Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
+                                 NegativeBinomial, Normal,
+                                 OneHotCategorical, OneHotCategoricalStraightThrough,
+                                 Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                  StudentT, TransformedDistribution, Uniform,
                                  VonMises, Weibull, constraints, kl_divergence)
 from torch.distributions.constraint_registry import transform_to
@@ -335,6 +336,11 @@ def is_all_nan(tensor):
         {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
         {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
+    Example(OneHotCategoricalStraightThrough, [
+        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
+        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
+        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
+    ]),
     Example(Pareto, [
         {
             'scale': 1.0,
@@ -604,6 +610,10 @@ def is_all_nan(tensor):
         {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
         {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
     ]),
+    Example(OneHotCategoricalStraightThrough, [
+        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
+        {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
+    ]),
     Example(Pareto, [
         {
             'scale': 0.0,
@@ -3755,13 +3765,21 @@ def test_entropy_exponential_family(self):
 
 class TestConstraints(TestCase):
     def test_params_constraints(self):
+        normalize_probs_dists = (
+            Categorical,
+            Multinomial,
+            OneHotCategorical,
+            OneHotCategoricalStraightThrough,
+            RelaxedOneHotCategorical
+        )
+
         for Dist, params in EXAMPLES:
             for i, param in enumerate(params):
                 dist = Dist(**param)
                 for name, value in param.items():
                     if isinstance(value, numbers.Number):
                         value = torch.tensor([value])
-                    if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
+                    if Dist in normalize_probs_dists and name == 'probs':
                         # These distributions accept positive probs, but elsewhere we
                         # use a stricter constraint to the simplex.
                         value = value / value.sum(-1, True)
diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py
index ffcf75695d2f..adbfa89d82e3 100644
--- a/torch/distributions/__init__.py
+++ b/torch/distributions/__init__.py
@@ -100,7 +100,7 @@
 from .multivariate_normal import MultivariateNormal
 from .negative_binomial import NegativeBinomial
 from .normal import Normal
-from .one_hot_categorical import OneHotCategorical
+from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
 from .pareto import Pareto
 from .poisson import Poisson
 from .relaxed_bernoulli import RelaxedBernoulli
@@ -142,6 +142,7 @@
     'NegativeBinomial',
     'Normal',
     'OneHotCategorical',
+    'OneHotCategoricalStraightThrough',
     'Pareto',
     'RelaxedBernoulli',
     'RelaxedOneHotCategorical',
diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py
index bd23f2344df5..c661a245f716 100644
--- a/torch/distributions/one_hot_categorical.py
+++ b/torch/distributions/one_hot_categorical.py
@@ -96,3 +96,18 @@ def enumerate_support(self, expand=True):
         if expand:
             values = values.expand((n,) + self.batch_shape + (n,))
         return values
+
+class OneHotCategoricalStraightThrough(OneHotCategorical):
+    r"""
+    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight-
+    through gradient estimator from [1].
+
+    [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
+    (Bengio et al, 2013)
+    """
+    has_rsample = True
+
+    def rsample(self, sample_shape=torch.Size()):
+        samples = self.sample(sample_shape)
+        probs = self._categorical.probs  # cached via @lazy_property
+        return samples + (probs - probs.detach())
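
Below is a minimal usage sketch, not part of the patch itself, illustrating what the new rsample buys you: the forward value stays a hard one-hot sample, while the probs - probs.detach() term routes gradients straight through to the distribution's parameters. The loss weights and variable names are illustrative only.

# Minimal sketch (assumes a build that includes OneHotCategoricalStraightThrough).
# Forward pass: a hard one-hot sample. Backward pass: gradients reach the
# parameters via the straight-through term `probs - probs.detach()`.
import torch
from torch.distributions import OneHotCategoricalStraightThrough

logits = torch.zeros(3, requires_grad=True)
dist = OneHotCategoricalStraightThrough(logits=logits)

sample = dist.rsample()  # exactly one-hot, e.g. tensor([0., 1., 0.])
loss = (sample * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()

print(sample)       # hard one-hot values in the forward pass
print(logits.grad)  # non-None: gradients flowed through rsample() to the logits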