Fix typing errors in the torch.distributions module (#45689)
Summary:
Fixes #42979.

Pull Request resolved: #45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
xuzhao9 authored and facebook-github-bot committed Oct 12, 2020
1 parent 6a001de commit 146721f
Showing 13 changed files with 42 additions and 27 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
@@ -81,9 +81,6 @@ ignore_errors = True
[mypy-torch.quantization.fx.*]
ignore_errors = True

[mypy-torch.distributions.*]
ignore_errors = True

[mypy-torch._tensor_str]
ignore_errors = True

1 change: 1 addition & 0 deletions torch/distributions/__init__.py
@@ -111,6 +111,7 @@
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from . import transforms

__all__ = [
'Bernoulli',
4 changes: 2 additions & 2 deletions torch/distributions/beta.py
@@ -1,4 +1,4 @@
from numbers import Number
from numbers import Real, Number

import torch
from torch.distributions import constraints
@@ -28,7 +28,7 @@ class Beta(ExponentialFamily):
has_rsample = True

def __init__(self, concentration1, concentration0, validate_args=None):
if isinstance(concentration1, Number) and isinstance(concentration0, Number):
if isinstance(concentration1, Real) and isinstance(concentration0, Real):
concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
else:
concentration1, concentration0 = broadcast_all(concentration1, concentration0)
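For context on the Number → Real switch here (and the matching one in normal.py below): `numbers.Real` guarantees a `__float__` method while `numbers.Number` does not, so the narrower check is what lets `float(concentration1)` type-check, and it also rules out complex inputs at runtime. A minimal sketch, independent of the Beta class (`as_float` is an illustrative name, not torch code):

    from numbers import Number, Real

    def as_float(x):
        # float() is only guaranteed for real scalars: a complex value is a
        # Number but not a Real, so the narrower isinstance check is safer.
        if isinstance(x, Real):
            return float(x)
        raise TypeError("expected a real scalar, got {}".format(type(x).__name__))

    print(isinstance(1 + 2j, Number))  # True
    print(isinstance(1 + 2j, Real))    # False
    print(as_float(0.5))               # 0.5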
9 changes: 4 additions & 5 deletions torch/distributions/distribution.py
@@ -2,6 +2,7 @@
import warnings
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from typing import Dict, Optional, Any


class Distribution(object):
@@ -12,8 +13,6 @@ class Distribution(object):
has_rsample = False
has_enumerate_support = False
_validate_args = False
support = None
arg_constraints = {}

@staticmethod
def set_default_validate_args(value):
@@ -72,7 +71,7 @@ def event_shape(self):
return self._event_shape

@property
def arg_constraints(self):
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
"""
Returns a dictionary from argument names to
:class:`~torch.distributions.constraints.Constraint` objects that
@@ -82,7 +81,7 @@ def arg_constraints(self):
raise NotImplementedError

@property
def support(self):
def support(self) -> Optional[Any]:
"""
Returns a :class:`~torch.distributions.constraints.Constraint` object
representing this distribution's support.
@@ -248,7 +247,7 @@ def _validate_sample(self, value):
if i != 1 and j != 1 and i != j:
raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
format(actual_shape, expected_shape))

assert self.support is not None
if not self.support.check(value).all():
raise ValueError('The value argument must be within the support')

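A likely reason for dropping the class-level `support = None` and `arg_constraints = {}` defaults: once mypy checks the module, an attribute initialized to `None` is inferred as having type `None`, which clashes with subclass overrides, whereas the annotated properties give subclasses a concrete contract (`Dict[str, Constraint]`, `Optional[Any]`); the added `assert self.support is not None` then narrows the optional value before `.check()` is called. A minimal sketch of that inference behavior with made-up names (`Constraint`, `BadBase`, `GoodBase` are illustrative, not the torch classes):

    from typing import Optional

    class Constraint:            # stand-in for constraints.Constraint
        pass

    class BadBase:
        support = None           # mypy infers the attribute type as None,

    class BadSub(BadBase):
        support = Constraint()   # ...so it flags this override as incompatible

    class GoodBase:
        @property
        def support(self) -> Optional[Constraint]:
            raise NotImplementedError

    class GoodSub(GoodBase):
        @property
        def support(self) -> Optional[Constraint]:
            return Constraint()

    print(type(GoodSub().support).__name__)  # Constraint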
4 changes: 2 additions & 2 deletions torch/distributions/independent.py
@@ -2,7 +2,7 @@
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _sum_rightmost

from typing import Dict

class Independent(Distribution):
r"""
@@ -31,7 +31,7 @@ class Independent(Distribution):
reinterpreted_batch_ndims (int): the number of batch dims to
reinterpret as event dims
"""
arg_constraints = {}
arg_constraints: Dict[str, constraints.Constraint] = {}

def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
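The explicit `Dict[str, constraints.Constraint]` annotation on the empty dict matters because mypy cannot infer an element type for `{}` and reports "Need type annotation"; the same pattern recurs in MixtureSameFamily and TransformedDistribution below. A standalone sketch of the behavior (`Registry` and `handlers` are hypothetical names):

    from typing import Dict

    class Registry:
        # A bare `handlers = {}` would make mypy ask for a type annotation;
        # spelling out the type also documents what the mapping holds.
        handlers: Dict[str, int] = {}

    Registry.handlers["retries"] = 3
    print(Registry.handlers)  # {'retries': 3}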
3 changes: 2 additions & 1 deletion torch/distributions/kl.py
@@ -1,6 +1,7 @@
import math
import warnings
from functools import total_ordering
from typing import Type, Dict, Callable, Tuple

import torch
from torch._six import inf
@@ -33,7 +34,7 @@
from .utils import _sum_rightmost

_KL_REGISTRY = {} # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE = {} # Memoized version mapping many specific (type, type) pairs to functions.
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {} # Memoized version mapping many specific (type, type) pairs to functions.


def register_kl(type_p, type_q):
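Module-level caches start out empty too, hence the `Dict[Tuple[Type, Type], Callable]` annotation on the memoization table. A rough sketch of the pattern, using an illustrative dispatch registry rather than the real KL machinery:

    from typing import Callable, Dict, Tuple, Type

    # Maps a pair of argument types to the function registered for that pair.
    _DISPATCH: Dict[Tuple[Type, Type], Callable] = {}

    def register(type_a: Type, type_b: Type):
        def decorator(fn: Callable) -> Callable:
            _DISPATCH[(type_a, type_b)] = fn
            return fn
        return decorator

    @register(int, float)
    def add(a: int, b: float) -> float:
        return a + b

    print(_DISPATCH[(int, float)](2, 0.5))  # 2.5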
3 changes: 2 additions & 1 deletion torch/distributions/mixture_same_family.py
@@ -2,6 +2,7 @@
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from torch.distributions import constraints
from typing import Dict


class MixtureSameFamily(Distribution):
@@ -45,7 +46,7 @@ class MixtureSameFamily(Distribution):
component_distribution: `torch.distributions.Distribution`-like
instance. Right-most batch dimension indexes component.
"""
arg_constraints = {}
arg_constraints: Dict[str, constraints.Constraint] = {}
has_rsample = False

def __init__(self,
4 changes: 2 additions & 2 deletions torch/distributions/multinomial.py
@@ -2,7 +2,6 @@
from torch._six import inf
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from numbers import Number
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

@@ -40,6 +39,7 @@ class Multinomial(Distribution):
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real}
total_count: int

@property
def mean(self):
@@ -50,7 +50,7 @@ def variance(self):
return self.total_count * self.probs * (1 - self.probs)

def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
if not isinstance(total_count, Number):
if not isinstance(total_count, int):
raise NotImplementedError('inhomogeneous total_count is not supported')
self.total_count = total_count
self._categorical = Categorical(probs=probs, logits=logits)
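`total_count: int` in the class body is a bare PEP 526 annotation: it declares the instance attribute's type for mypy without creating a class attribute, which pairs with the stricter `isinstance(total_count, int)` check in `__init__`. A quick sketch with a hypothetical `Counter` class:

    class Counter:
        total_count: int   # annotation only; no class attribute is created

        def __init__(self, total_count: int = 1) -> None:
            if not isinstance(total_count, int):
                raise TypeError("total_count must be an int")
            self.total_count = total_count

    print(Counter(3).total_count)             # 3
    print("total_count" in Counter.__dict__)  # False: only the annotation exists
    print(Counter.__annotations__)            # {'total_count': <class 'int'>}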
3 changes: 2 additions & 1 deletion torch/distributions/normal.py
@@ -1,4 +1,5 @@
import math
from numbers import Real
from numbers import Number

import torch
@@ -72,7 +73,7 @@ def log_prob(self, value):
self._validate_sample(value)
# compute the variance
var = (self.scale ** 2)
log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log()
log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

def cdf(self, value):
3 changes: 2 additions & 1 deletion torch/distributions/transformed_distribution.py
@@ -3,6 +3,7 @@
from torch.distributions.distribution import Distribution
from torch.distributions.transforms import Transform
from torch.distributions.utils import _sum_rightmost
from typing import Dict


class TransformedDistribution(Distribution):
@@ -38,7 +39,7 @@ class TransformedDistribution(Distribution):
:class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
:class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
"""
arg_constraints = {}
arg_constraints: Dict[str, constraints.Constraint] = {}

def __init__(self, base_distribution, transforms, validate_args=None):
self.base_dist = base_distribution
20 changes: 16 additions & 4 deletions torch/distributions/transforms.py
@@ -9,6 +9,7 @@
lazy_property)
from torch.nn.functional import pad
from torch.nn.functional import softplus
from typing import List

__all__ = [
'AbsTransform',
@@ -77,6 +78,7 @@ class Transform(object):
transforms that act jointly on matrices, etc.
"""
bijective = False
codomain: constraints.Constraint
event_dim = 0

def __init__(self, cache_size=0):
@@ -185,40 +187,49 @@ def __init__(self, transform):

@constraints.dependent_property
def domain(self):
assert self._inv is not None
return self._inv.codomain

@constraints.dependent_property
def codomain(self):
assert self._inv is not None
return self._inv.domain

@property
def bijective(self):
assert self._inv is not None
return self._inv.bijective

@property
def sign(self):
assert self._inv is not None
return self._inv.sign

@property
def event_dim(self):
assert self._inv is not None
return self._inv.event_dim

@property
def inv(self):
return self._inv

def with_cache(self, cache_size=1):
assert self._inv is not None
return self.inv.with_cache(cache_size).inv

def __eq__(self, other):
if not isinstance(other, _InverseTransform):
return False
assert self._inv is not None
return self._inv == other._inv

def __call__(self, x):
assert self._inv is not None
return self._inv._inv_call(x)

def log_abs_det_jacobian(self, x, y):
assert self._inv is not None
return -self._inv.log_abs_det_jacobian(y, x)
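The repeated `assert self._inv is not None` lines above exist for the type checker: the attribute can be `None`, and an assert narrows it to the non-None type for the rest of the method. A minimal sketch of that narrowing, unrelated to the Transform classes (`Wrapper` is a made-up name):

    from typing import Optional

    class Wrapper:
        def __init__(self, inner: Optional[str] = None) -> None:
            self._inner = inner

        def shout(self) -> str:
            # Without the assert, mypy reports:
            #   Item "None" of "Optional[str]" has no attribute "upper"
            assert self._inner is not None
            return self._inner.upper()

    print(Wrapper("abc").shout())  # ABC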


@@ -500,8 +511,8 @@ def __eq__(self, other):

@property
def sign(self):
if isinstance(self.scale, numbers.Number):
return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
if isinstance(self.scale, numbers.Real):
return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
return self.scale.sign()

def _call(self, x):
@@ -513,7 +524,7 @@ def _inverse(self, y):
def log_abs_det_jacobian(self, x, y):
shape = x.shape
scale = self.scale
if isinstance(scale, numbers.Number):
if isinstance(scale, numbers.Real):
result = torch.full_like(x, math.log(abs(scale)))
else:
result = torch.abs(scale).log()
@@ -575,7 +586,7 @@ def _call(self, x):
offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
z = _clipped_sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
return y

def _inverse(self, y):
@@ -619,6 +630,7 @@ def _inverse(self, y):


class CatTransform(Transform):
tseq: List[numbers.Number]

fritzo (Collaborator) commented on Mar 3, 2022:
Note: this annotation, added above the docstring, hides the docstring. Fixing in #73747.

"""
Transform functor that applies a sequence of transforms `tseq`
component-wise to each submatrix at `dim`, of length `lengths[dim]`,
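Regarding the inline review comment above: only a string literal that is the very first statement of a class body becomes `__doc__`, so placing the `tseq: List[numbers.Number]` annotation before the docstring does hide it. A quick illustration with throwaway class names:

    class Documented:
        """First statement, so this is the class docstring."""
        tseq: list

    class Undocumented:
        tseq: list
        """Not the first statement, so this string is ignored as documentation."""

    print(Documented.__doc__ is None)    # False
    print(Undocumented.__doc__ is None)  # True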
10 changes: 6 additions & 4 deletions torch/distributions/utils.py
@@ -2,6 +2,7 @@
from numbers import Number
import torch
import torch.nn.functional as F
from typing import Dict, Any


def broadcast_all(*values):
@@ -23,13 +24,14 @@ def broadcast_all(*values):
if not all(isinstance(v, torch.Tensor) or isinstance(v, Number) for v in values):
raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
if not all([isinstance(v, torch.Tensor) for v in values]):
options = dict(dtype=torch.get_default_dtype())
options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
for value in values:
if isinstance(value, torch.Tensor):
options = dict(dtype=value.dtype, device=value.device)
break
values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
for v in values]
new_values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options)
for v in values]
return torch.broadcast_tensors(*new_values)
return torch.broadcast_tensors(*values)


@@ -94,7 +96,7 @@ class lazy_property(object):
"""
def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped)
update_wrapper(self, wrapped) # type: ignore[arg-type]

def __get__(self, instance, obj_type=None):
if instance is None:
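Two patterns in the broadcast_all hunk: `options` gets an explicit `Dict[str, Any]` so the later reassignment with differently typed values is accepted, and the converted list is bound to a fresh name (`new_values`) rather than rebinding the `*values` parameter to a different type, which mypy would reject. A rough standalone sketch (the `render` function and its names are illustrative, not the torch code):

    from typing import Any, Dict, List

    def render(*values: float) -> List[str]:
        # Rebinding `values` itself to a list of strings would change its
        # inferred type (Tuple[float, ...] -> List[str]); a new name avoids that.
        new_values = [format(v, ".2f") for v in values]
        return new_values

    options: Dict[str, Any] = {"dtype": "float32"}
    options = {"dtype": "float64", "device": 0}  # mixed value types are fine under Any

    print(render(1.0, 2.5), options)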
2 changes: 1 addition & 1 deletion torch/functional.py
@@ -510,7 +510,7 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
signal_dim = input.dim()
extended_shape = [1] * (3 - signal_dim) + list(input.size())
pad = int(n_fft // 2)
input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
input = input.view(input.shape[-signal_dim:])
return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore
normalized, onesided, return_complex)
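The `(pad, pad)` to `[pad, pad]` change here, like the tuple-to-list change in transforms.py above, presumably follows the annotation on `torch.nn.functional.pad`, which declares its padding argument as a list of ints; tuples still work at runtime, but the list form matches the declared type. A small usage sketch:

    import torch
    import torch.nn.functional as F

    x = torch.ones(2, 3)
    # Pad the last dimension by one element on each side.
    y = F.pad(x, [1, 1], mode="constant", value=0.0)
    print(y.shape)  # torch.Size([2, 5])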
