From 146721f1dfe29786f5d25a9e2a2c2227dfa85f76 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Mon, 12 Oct 2020 10:23:42 -0700
Subject: [PATCH] Fix typing errors in the torch.distributions module (#45689)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/42979.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45689

Reviewed By: agolynski

Differential Revision: D24229870

Pulled By: xuzhao9

fbshipit-source-id: 5fc87cc428170139962ab65b71cacba494d46130
---
 mypy.ini                                      |  3 ---
 torch/distributions/__init__.py               |  1 +
 torch/distributions/beta.py                   |  4 ++--
 torch/distributions/distribution.py           |  9 ++++-----
 torch/distributions/independent.py            |  4 ++--
 torch/distributions/kl.py                     |  3 ++-
 torch/distributions/mixture_same_family.py    |  3 ++-
 torch/distributions/multinomial.py            |  4 ++--
 torch/distributions/normal.py                 |  3 ++-
 .../distributions/transformed_distribution.py |  3 ++-
 torch/distributions/transforms.py             | 20 +++++++++++++++----
 torch/distributions/utils.py                  | 10 ++++++----
 torch/functional.py                           |  2 +-
 13 files changed, 42 insertions(+), 27 deletions(-)

diff --git a/mypy.ini b/mypy.ini
index af39fd619732..33f876fae4f7 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -81,9 +81,6 @@ ignore_errors = True
 [mypy-torch.quantization.fx.*]
 ignore_errors = True

-[mypy-torch.distributions.*]
-ignore_errors = True
-
 [mypy-torch._tensor_str]
 ignore_errors = True

diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py
index 4d7a4bff96af..ffcf75695d2f 100644
--- a/torch/distributions/__init__.py
+++ b/torch/distributions/__init__.py
@@ -111,6 +111,7 @@
 from .uniform import Uniform
 from .von_mises import VonMises
 from .weibull import Weibull
+from . import transforms

 __all__ = [
     'Bernoulli',
diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py
index 7c017c133b32..76cb6ae7029a 100644
--- a/torch/distributions/beta.py
+++ b/torch/distributions/beta.py
@@ -1,4 +1,4 @@
-from numbers import Number
+from numbers import Real, Number

 import torch
 from torch.distributions import constraints
@@ -28,7 +28,7 @@ class Beta(ExponentialFamily):
     has_rsample = True

     def __init__(self, concentration1, concentration0, validate_args=None):
-        if isinstance(concentration1, Number) and isinstance(concentration0, Number):
+        if isinstance(concentration1, Real) and isinstance(concentration0, Real):
             concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
         else:
             concentration1, concentration0 = broadcast_all(concentration1, concentration0)
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index d1e3a3924712..f16eb154e2dd 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -2,6 +2,7 @@
 import warnings
 from torch.distributions import constraints
 from torch.distributions.utils import lazy_property
+from typing import Dict, Optional, Any


 class Distribution(object):
@@ -12,8 +13,6 @@ class Distribution(object):
     has_rsample = False
     has_enumerate_support = False
     _validate_args = False
-    support = None
-    arg_constraints = {}

     @staticmethod
     def set_default_validate_args(value):
@@ -72,7 +71,7 @@ def event_shape(self):
         return self._event_shape

     @property
-    def arg_constraints(self):
+    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
         """
         Returns a dictionary from argument names to
         :class:`~torch.distributions.constraints.Constraint` objects that
@@ -82,7 +81,7 @@ def arg_constraints(self):
         raise NotImplementedError

     @property
-    def support(self):
+    def support(self) -> Optional[Any]:
         """
         Returns a :class:`~torch.distributions.constraints.Constraint` object
         representing this distribution's support.
@@ -248,7 +247,7 @@ def _validate_sample(self, value):
             if i != 1 and j != 1 and i != j:
                 raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
                                  format(actual_shape, expected_shape))
-
+        assert self.support is not None
         if not self.support.check(value).all():
             raise ValueError('The value argument must be within the support')

diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py
index cbec92dfd9c6..de34bb604774 100644
--- a/torch/distributions/independent.py
+++ b/torch/distributions/independent.py
@@ -2,7 +2,7 @@
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.utils import _sum_rightmost
-
+from typing import Dict

 class Independent(Distribution):
     r"""
@@ -31,7 +31,7 @@ class Independent(Distribution):
         reinterpreted_batch_ndims (int): the number of batch dims to
             reinterpret as event dims
     """
-    arg_constraints = {}
+    arg_constraints: Dict[str, constraints.Constraint] = {}

     def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
         if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index a30054d370cc..c7e079b1f57a 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -1,6 +1,7 @@
 import math
 import warnings
 from functools import total_ordering
+from typing import Type, Dict, Callable, Tuple

 import torch
 from torch._six import inf
@@ -33,7 +34,7 @@
 from .utils import _sum_rightmost

 _KL_REGISTRY = {}  # Source of truth mapping a few general (type, type) pairs to functions.
-_KL_MEMOIZE = {}  # Memoized version mapping many specific (type, type) pairs to functions.
+_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {}  # Memoized version mapping many specific (type, type) pairs to functions.


 def register_kl(type_p, type_q):
diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py
index c0f778c5ed53..716bfbd8c7a3 100644
--- a/torch/distributions/mixture_same_family.py
+++ b/torch/distributions/mixture_same_family.py
@@ -2,6 +2,7 @@
 from torch.distributions.distribution import Distribution
 from torch.distributions import Categorical
 from torch.distributions import constraints
+from typing import Dict


 class MixtureSameFamily(Distribution):
@@ -45,7 +46,7 @@ class MixtureSameFamily(Distribution):
         component_distribution: `torch.distributions.Distribution`-like
             instance. Right-most batch dimension indexes component.
""" - arg_constraints = {} + arg_constraints: Dict[str, constraints.Constraint] = {} has_rsample = False def __init__(self, diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 6d61578237cd..9162dd4713d4 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -2,7 +2,6 @@ from torch._six import inf from torch.distributions.distribution import Distribution from torch.distributions import Categorical -from numbers import Number from torch.distributions import constraints from torch.distributions.utils import broadcast_all @@ -40,6 +39,7 @@ class Multinomial(Distribution): """ arg_constraints = {'probs': constraints.simplex, 'logits': constraints.real} + total_count: int @property def mean(self): @@ -50,7 +50,7 @@ def variance(self): return self.total_count * self.probs * (1 - self.probs) def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): - if not isinstance(total_count, Number): + if not isinstance(total_count, int): raise NotImplementedError('inhomogeneous total_count is not supported') self.total_count = total_count self._categorical = Categorical(probs=probs, logits=logits) diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index a125806108e8..2468e2f225dc 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -1,4 +1,5 @@ import math +from numbers import Real from numbers import Number import torch @@ -72,7 +73,7 @@ def log_prob(self, value): self._validate_sample(value) # compute the variance var = (self.scale ** 2) - log_scale = math.log(self.scale) if isinstance(self.scale, Number) else self.scale.log() + log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log() return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) def cdf(self, value): diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 46c4fbccb43f..d6bb4de75c6b 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -3,6 +3,7 @@ from torch.distributions.distribution import Distribution from torch.distributions.transforms import Transform from torch.distributions.utils import _sum_rightmost +from typing import Dict class TransformedDistribution(Distribution): @@ -38,7 +39,7 @@ class TransformedDistribution(Distribution): :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical` """ - arg_constraints = {} + arg_constraints: Dict[str, constraints.Constraint] = {} def __init__(self, base_distribution, transforms, validate_args=None): self.base_dist = base_distribution diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 09e00d55e8d9..f4de4b15b0bb 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -9,6 +9,7 @@ lazy_property) from torch.nn.functional import pad from torch.nn.functional import softplus +from typing import List __all__ = [ 'AbsTransform', @@ -77,6 +78,7 @@ class Transform(object): transforms that act jointly on matrices, etc. 
""" bijective = False + codomain: constraints.Constraint event_dim = 0 def __init__(self, cache_size=0): @@ -185,22 +187,27 @@ def __init__(self, transform): @constraints.dependent_property def domain(self): + assert self._inv is not None return self._inv.codomain @constraints.dependent_property def codomain(self): + assert self._inv is not None return self._inv.domain @property def bijective(self): + assert self._inv is not None return self._inv.bijective @property def sign(self): + assert self._inv is not None return self._inv.sign @property def event_dim(self): + assert self._inv is not None return self._inv.event_dim @property @@ -208,17 +215,21 @@ def inv(self): return self._inv def with_cache(self, cache_size=1): + assert self._inv is not None return self.inv.with_cache(cache_size).inv def __eq__(self, other): if not isinstance(other, _InverseTransform): return False + assert self._inv is not None return self._inv == other._inv def __call__(self, x): + assert self._inv is not None return self._inv._inv_call(x) def log_abs_det_jacobian(self, x, y): + assert self._inv is not None return -self._inv.log_abs_det_jacobian(y, x) @@ -500,8 +511,8 @@ def __eq__(self, other): @property def sign(self): - if isinstance(self.scale, numbers.Number): - return 1 if self.scale > 0 else -1 if self.scale < 0 else 0 + if isinstance(self.scale, numbers.Real): + return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0 return self.scale.sign() def _call(self, x): @@ -513,7 +524,7 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): shape = x.shape scale = self.scale - if isinstance(scale, numbers.Number): + if isinstance(scale, numbers.Real): result = torch.full_like(x, math.log(abs(scale))) else: result = torch.abs(scale).log() @@ -575,7 +586,7 @@ def _call(self, x): offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) z = _clipped_sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) - y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1) + y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1) return y def _inverse(self, y): @@ -619,6 +630,7 @@ def _inverse(self, y): class CatTransform(Transform): + tseq: List[numbers.Number] """ Transform functor that applies a sequence of transforms `tseq` component-wise to each submatrix at `dim`, of length `lengths[dim]`, diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 65636ab3f30a..0fd623086562 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -2,6 +2,7 @@ from numbers import Number import torch import torch.nn.functional as F +from typing import Dict, Any def broadcast_all(*values): @@ -23,13 +24,14 @@ def broadcast_all(*values): if not all(isinstance(v, torch.Tensor) or isinstance(v, Number) for v in values): raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.') if not all([isinstance(v, torch.Tensor) for v in values]): - options = dict(dtype=torch.get_default_dtype()) + options: Dict[str, Any] = dict(dtype=torch.get_default_dtype()) for value in values: if isinstance(value, torch.Tensor): options = dict(dtype=value.dtype, device=value.device) break - values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options) - for v in values] + new_values = [v if isinstance(v, torch.Tensor) else torch.tensor(v, **options) + for v in values] + return torch.broadcast_tensors(*new_values) return torch.broadcast_tensors(*values) @@ -94,7 +96,7 @@ class lazy_property(object): """ def __init__(self, wrapped): 
         self.wrapped = wrapped
-        update_wrapper(self, wrapped)
+        update_wrapper(self, wrapped)  # type: ignore[arg-type]

     def __get__(self, instance, obj_type=None):
         if instance is None:
diff --git a/torch/functional.py b/torch/functional.py
index 43882ca57c0e..526e3656bd5e 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -510,7 +510,7 @@ def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
         signal_dim = input.dim()
         extended_shape = [1] * (3 - signal_dim) + list(input.size())
         pad = int(n_fft // 2)
-        input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)
+        input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
         input = input.view(input.shape[-signal_dim:])
     return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore
                     normalized, onesided, return_complex)
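
Note: the snippet below is an illustrative sketch, not part of the patch. It assumes a PyTorch build that includes this change; with the [mypy-torch.distributions.*] override removed from mypy.ini, annotations introduced above (for example Multinomial.total_count and the Dict-valued arg_constraints) are visible to mypy and behave the same at runtime.

    # Illustrative only; the names used here come from torch.distributions as patched above.
    from typing import Dict
    import torch
    from torch.distributions import Multinomial, constraints

    m = Multinomial(total_count=10, probs=torch.tensor([0.2, 0.3, 0.5]))
    cs: Dict[str, constraints.Constraint] = m.arg_constraints  # per-argument constraints
    n: int = m.total_count                                     # annotated as int by this patch
    print(cs, n)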