Fix docstring to clarify logits usage for multiclass case #51053

17 changes: 11 additions & 6 deletions torch/distributions/categorical.py
@@ -16,14 +16,19 @@ class Categorical(Distribution):

Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.

- If :attr:`probs` is 1-dimensional with length-`K`, each element is the relative
- probability of sampling the class at that index.
+ If `probs` is 1-dimensional with length-`K`, each element is the relative probability
+ of sampling the class at that index.

- If :attr:`probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
+ If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of
relative probability vectors.

- .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
-     and it will be normalized to sum to 1 along the last dimension.
+ .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
+     and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
+     will return this normalized value.
+     The `logits` argument will be interpreted as unnormalized log probabilities
+     and can therefore be any real number. It will likewise be normalized so that
+     the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
+     will return this normalized value.

See also: :func:`torch.multinomial`

@@ -35,7 +40,7 @@ class Categorical(Distribution):

Args:
    probs (Tensor): event probabilities
-     logits (Tensor): event log-odds
+     logits (Tensor): event log probabilities (unnormalized)
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
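To make the note's normalization behavior concrete, here is a minimal sketch (illustrative values; assumes a PyTorch build that includes this patch):

    import torch
    from torch.distributions import Categorical

    # `logits` may be any real numbers; the distribution normalizes them.
    d = Categorical(logits=torch.tensor([2.0, 0.0, -1.0]))

    print(d.probs.sum(-1))         # tensor(1.) -- normalized along the last dim
    print(d.logits.logsumexp(-1))  # ~tensor(0.) -- normalized log probabilities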
2 changes: 1 addition & 1 deletion torch/distributions/half_cauchy.py
@@ -43,7 +43,7 @@ def scale(self):

@property
def mean(self):
-     return self.base_dist.mean
+     return torch.full(self._extended_shape(), math.inf, dtype=self.scale.dtype, device=self.scale.device)

@property
def variance(self):
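A quick sketch of what the patched property returns (the half-Cauchy has no finite mean, so an inf-filled tensor of the extended shape is the honest answer; assumes this patch is applied):

    import torch
    from torch.distributions import HalfCauchy

    d = HalfCauchy(scale=torch.tensor([1.0, 2.0]))
    print(d.mean)  # tensor([inf, inf]), matching the dtype/device of `scale`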
11 changes: 8 additions & 3 deletions torch/distributions/multinomial.py
@@ -15,8 +15,13 @@ class Multinomial(Distribution):
Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
called (see example below)

- .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
-     and it will be normalized to sum to 1.
+ .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
+     and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
+     will return this normalized value.
+     The `logits` argument will be interpreted as unnormalized log probabilities
+     and can therefore be any real number. It will likewise be normalized so that
+     the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
+     will return this normalized value.

- :meth:`sample` requires a single shared `total_count` for all
parameters and samples.
@@ -35,7 +40,7 @@ class Multinomial(Distribution):
Args:
    total_count (int): number of trials
    probs (Tensor): event probabilities
-     logits (Tensor): event log probabilities
+     logits (Tensor): event log probabilities (unnormalized)
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
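As with Categorical, a short sketch of passing unnormalized logits to Multinomial (illustrative values; assumes this patch is applied):

    import torch
    from torch.distributions import Multinomial

    m = Multinomial(total_count=10, logits=torch.tensor([1.0, 2.0, 3.0]))
    counts = m.sample()        # counts over 3 events, summing to total_count
    print(counts.sum())        # tensor(10.)
    print(m.probs.sum(-1))     # tensor(1.) -- normalized probabilities
    print(m.log_prob(counts))  # log probability of the drawn counts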
11 changes: 8 additions & 3 deletions torch/distributions/one_hot_categorical.py
@@ -11,8 +11,13 @@ class OneHotCategorical(Distribution):

Samples are one-hot coded vectors of size ``probs.size(-1)``.

- .. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
-     and it will be normalized to sum to 1.
+ .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
+     and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
+     will return this normalized value.
+     The `logits` argument will be interpreted as unnormalized log probabilities
+     and can therefore be any real number. It will likewise be normalized so that
+     the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
+     will return this normalized value.

See also: :func:`torch.distributions.Categorical` for specifications of
:attr:`probs` and :attr:`logits`.
@@ -25,7 +30,7 @@ class OneHotCategorical(Distribution):

Args:
    probs (Tensor): event probabilities
-     logits (Tensor): event log probabilities
+     logits (Tensor): event log probabilities (unnormalized)
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
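A brief sketch of the one-hot case (illustrative values; assumes this patch is applied):

    import torch
    from torch.distributions import OneHotCategorical

    d = OneHotCategorical(logits=torch.tensor([0.5, -1.0, 3.0]))
    s = d.sample()                 # one-hot vector of size probs.size(-1)
    print(s)                       # e.g. tensor([0., 0., 1.])
    print(d.logits.exp().sum(-1))  # tensor(1.) -- logits are stored normalized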
10 changes: 5 additions & 5 deletions torch/distributions/relaxed_categorical.py
@@ -21,7 +21,7 @@ class ExpRelaxedCategorical(Distribution):
Args:
    temperature (Tensor): relaxation temperature
    probs (Tensor): event probabilities
-     logits (Tensor): the log probability of each event.
+     logits (Tensor): unnormalized log probability for each event

[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
(Maddison et al, 2017)
Expand All @@ -30,8 +30,8 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al, 2017)
"""
arg_constraints = {'probs': constraints.simplex,
-                    'logits': constraints.real}
- support = constraints.real
+                    'logits': constraints.real_vector}
+ support = constraints.real_vector
Contributor Author commented:

cc. @fritzo.

Collaborator commented:

LGTM. I'm not sure whether real_vector is the optimal constraint, but I believe it is the tightest valid constraint that we currently implement. I think that's fine since most users will use the wrapper class with the exact simplex constraint.

has_rsample = True

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
@@ -101,10 +101,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
Args:
    temperature (Tensor): relaxation temperature
    probs (Tensor): event probabilities
-     logits (Tensor): the log probability of each event.
+     logits (Tensor): unnormalized log probability for each event
"""
arg_constraints = {'probs': constraints.simplex,
-                    'logits': constraints.real}
+                    'logits': constraints.real_vector}
support = constraints.simplex
has_rsample = True

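Finally, a sketch consistent with the relaxed real_vector constraint above (illustrative values; the temperature and logits are arbitrary choices):

    import torch
    from torch.distributions import RelaxedOneHotCategorical

    # `logits` is a real-valued vector (real_vector constraint); samples are
    # reparameterized (has_rsample = True) and lie on the simplex.
    d = RelaxedOneHotCategorical(temperature=torch.tensor(0.5),
                                 logits=torch.tensor([1.0, 0.0, -2.0]))
    y = d.rsample()   # differentiable relaxed one-hot sample
    print(y.sum(-1))  # tensor(1.) -- on the simplex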