Fix ruff not formatting the changes (#1761)
* fix code that ruff did not format previously

* apply further ruff formatting
fehiepsi committed Mar 15, 2024
1 parent 5da6fa5 commit c988ad0
Showing 38 changed files with 346 additions and 198 deletions.
5 changes: 3 additions & 2 deletions Makefile
@@ -1,14 +1,15 @@
 all: test

 lint: FORCE
-    ruff .
+    ruff check .
+    ruff format . --check
     python scripts/update_headers.py --check

 license: FORCE
     python scripts/update_headers.py

 format: license FORCE
-    ruff . --fix
+    ruff format .

 install: FORCE
     pip install -e .[dev,doc,test,examples]
9 changes: 7 additions & 2 deletions examples/hmm_enum.py
@@ -303,8 +303,13 @@ def transition_fn(carry, y):
with numpyro.plate("sequences", num_sequences, dim=-2):
with mask(mask=(t < lengths)[..., None]):
probs_x_t = Vindex(probs_x)[x_prev, x_curr]
x_prev, x_curr = x_curr, numpyro.sample(
"x", dist.Categorical(probs_x_t), infer={"enumerate": "parallel"}
x_prev, x_curr = (
x_curr,
numpyro.sample(
"x",
dist.Categorical(probs_x_t),
infer={"enumerate": "parallel"},
),
)
with numpyro.plate("tones", data_dim, dim=-1):
probs_y_t = probs_y[x_curr.squeeze(-1)]
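The hunk above is a pure reformat: ruff format wraps an overlong tuple assignment in its own parentheses, puts one element per line, and adds a trailing comma so the layout sticks. A minimal runnable sketch of the same reflow, with hypothetical names standing in for the numpyro calls:

def rotate(x_prev, x_curr, advance):
    # ruff format renders an overlong right-hand-side tuple like this:
    # parentheses around the tuple, one element per line, trailing comma.
    x_prev, x_curr = (
        x_curr,
        advance(x_curr),
    )
    return x_prev, x_curr


print(rotate(0, 1, lambda x: x + 1))  # prints (1, 2)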
1 change: 1 addition & 0 deletions examples/hsgp.py
@@ -46,6 +46,7 @@
"""

import argparse
import os

1 change: 1 addition & 0 deletions examples/prodlda.py
@@ -30,6 +30,7 @@
 .. image:: ../_static/img/examples/prodlda.png
     :align: center
 """
+
 import argparse

 import matplotlib.pyplot as plt
1 change: 0 additions & 1 deletion examples/proportion_test.py
@@ -16,7 +16,6 @@
 density interval for the effect of making a call.
 """

-
 import argparse
 import os

1 change: 1 addition & 0 deletions examples/stein_dmm.py
@@ -18,6 +18,7 @@
 .. image:: ../_static/img/examples/stein_dmm.png
     :align: center
 """
+
 import argparse

 import numpy as np
2 changes: 1 addition & 1 deletion numpyro/compat/infer.py
@@ -126,7 +126,7 @@ def __init__(
        loss_and_grads=None,
        num_samples=10,
        num_steps=0,
-        **kwargs
+        **kwargs,
    ):
        super(SVI, self).__init__(model=model, guide=guide, optim=optim, loss=loss)
        self.svi_state = None
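The only change in this hunk is the comma after `**kwargs`. ruff format, following black, adds a trailing comma to every multi-line parameter list (which Python 3 accepts even after `**kwargs`), and that "magic" trailing comma in turn pins the list to one parameter per line. A sketch with a made-up signature, for illustration only:

def configure(  # hypothetical function, not part of numpyro
    model,
    guide,
    num_samples=10,
    num_steps=0,
    **kwargs,  # trailing comma keeps the formatter from collapsing this list
):
    return {"model": model, "guide": guide, "num_samples": num_samples, **kwargs}


print(configure("m", "g", seed=0)["num_samples"])  # prints 10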
1 change: 0 additions & 1 deletion numpyro/contrib/ecs_proxies.py
@@ -185,7 +185,6 @@ def log_likelihood_sum(params_flat, subsample_indices=None):
    ref_sum_log_lik_hessians = hessian(log_likelihood_sum)(ref_params_flat)

    def gibbs_init(rng_key, gibbs_sites):
-
        ref_subsamples_taylor = [
            log_likelihood(ref_params_flat, gibbs_sites),
            jacobian(log_likelihood)(ref_params_flat, gibbs_sites),
12 changes: 5 additions & 7 deletions numpyro/contrib/einstein/steinvi.py
@@ -217,13 +217,11 @@ def local_trace(key):

    def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
        # 0. Separate model and guide parameters, since only guide parameters are updated using Stein
-        non_mixture_uparams = (
-            {  # Includes any marked guide parameters and all model parameters
-                p: v
-                for p, v in unconstr_params.items()
-                if p not in self.guide_sites or self.non_mixture_params_fn(p)
-            }
-        )
+        non_mixture_uparams = {  # Includes any marked guide parameters and all model parameters
+            p: v
+            for p, v in unconstr_params.items()
+            if p not in self.guide_sites or self.non_mixture_params_fn(p)
+        }
        stein_uparams = {
            p: v for p, v in unconstr_params.items() if p not in non_mixture_uparams
        }
4 changes: 2 additions & 2 deletions numpyro/contrib/module.py
@@ -254,7 +254,7 @@ def random_flax_module(
    input_shape=None,
    apply_rng=None,
    mutable=None,
-    **kwargs
+    **kwargs,
):
    """
    A primitive to place a prior over the parameters of the Flax module `nn_module`.
@@ -372,7 +372,7 @@ def __call__(self, x):
        input_shape=input_shape,
        apply_rng=apply_rng,
        mutable=mutable,
-        **kwargs
+        **kwargs,
    )
    params = nn.args[0]
    new_params = deepcopy(params)
8 changes: 2 additions & 6 deletions numpyro/contrib/tfp/distributions.py
@@ -314,9 +314,7 @@ def kl_divergence(p, q):  # noqa: F811
    _PyroDist.__doc__ = """
    Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/distributions/{}>`_
    with :class:`~numpyro.contrib.tfp.distributions.TFPDistribution`.
-    """.format(
-        _Dist.__module__, _Dist.__name__, _Dist.__name__
-    )
+    """.format(_Dist.__module__, _Dist.__name__, _Dist.__name__)

    __all__.append(_name)
@@ -328,9 +326,7 @@ def kl_divergence(p, q):  # noqa: F811
        {0}
        ----------------------------------------------------------------
        .. autoclass:: numpyro.contrib.tfp.distributions.{0}
-        """.format(
-            _name
-        )
+        """.format(_name)
        for _name in __all__[:_len_all]
    ]
)
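This file's two hunks (and the matching pair in numpyro/contrib/tfp/mcmc.py below) collapse a `.format(...)` call onto the docstring's closing quotes once its arguments fit within the line limit; the exploded form was a leftover from an older formatter layout. A self-contained sketch of the same collapse, using a dummy class in place of the wrapped TFP class:

class _Dummy:  # stand-in, for illustration only
    pass


# Old layout:
#     """.format(
#         _Dummy.__module__, _Dummy.__name__, _Dummy.__name__
#     )
# ruff format now keeps the call on one line when it fits:
doc = """
Wraps `{}.{}` with a thin adapter (see {}).
""".format(_Dummy.__module__, _Dummy.__name__, _Dummy.__name__)

print(doc.strip())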
8 changes: 2 additions & 6 deletions numpyro/contrib/tfp/mcmc.py
@@ -236,9 +236,7 @@ def sample(self, state, model_args, model_kwargs):
    Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/mcmc/{}>`_
    with :class:`~numpyro.contrib.tfp.mcmc.TFPKernel`. The first argument `target_log_prob_fn`
    in TFP kernel construction is replaced by either `model` or `potential_fn`.
-    """.format(
-        _Kernel.__module__, _Kernel.__name__, _Kernel.__name__
-    )
+    """.format(_Kernel.__module__, _Kernel.__name__, _Kernel.__name__)

    __all__.append(_name)
@@ -250,9 +248,7 @@ def sample(self, state, model_args, model_kwargs):
        {0}
        ----------------------------------------------------------------
        .. autoclass:: numpyro.contrib.tfp.mcmc.{0}
-        """.format(
-            _name
-        )
+        """.format(_name)
        for _name in __all__[:1] + sorted(__all__[1:])
    ]
)
3 changes: 3 additions & 0 deletions numpyro/distributions/conjugate.py
@@ -36,6 +36,7 @@ class BetaBinomial(Distribution):
    Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """
+
    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
@@ -107,6 +108,7 @@ class DirichletMultinomial(Distribution):
    Dirichlet distribution.
    :param numpy.ndarray total_count: number of Categorical trials.
    """
+
    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1),
        "total_count": constraints.nonnegative_integer,
@@ -182,6 +184,7 @@ class GammaPoisson(Distribution):
    :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
    :param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
    """
+
    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
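Every hunk in this file makes the same one-line change: a blank line between the class docstring and the first attribute, a black-24-style convention that ruff format applies here (the example scripts above got the module-docstring analogue). A minimal sketch with toy attributes:

class ToyDistribution:
    """Toy stand-in for BetaBinomial / DirichletMultinomial / GammaPoisson.

    ruff format inserts one blank line between this docstring and the
    attribute assignments below.
    """

    arg_constraints = {"concentration": "positive", "rate": "positive"}
    support = "nonnegative_integer"


print(ToyDistribution.arg_constraints["rate"])  # prints positive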
20 changes: 11 additions & 9 deletions numpyro/distributions/continuous.py
@@ -284,9 +284,7 @@ def mean(self):
    @property
    def variance(self):
        con0 = jnp.sum(self.concentration, axis=-1, keepdims=True)
-        return (
-            self.concentration * (con0 - self.concentration) / (con0**2 * (con0 + 1))
-        )
+        return self.concentration * (con0 - self.concentration) / (con0**2 * (con0 + 1))

    @staticmethod
    def infer_shapes(concentration):
@@ -909,6 +907,7 @@ def model(y):  # y has dimension N x d
    [1] `Generating random correlation matrices based on vines and extended onion method`,
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe
    """
+
    arg_constraints = {"concentration": constraints.positive}
    reparametrized_params = ["concentration"]
    support = constraints.corr_matrix
@@ -985,6 +984,7 @@ def model(y):  # y has dimension N x d
    [1] `Generating random correlation matrices based on vines and extended onion method`,
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe
    """
+
    arg_constraints = {"concentration": constraints.positive}
    reparametrized_params = ["concentration"]
    support = constraints.corr_cholesky
@@ -1961,9 +1961,10 @@ def scale_tril(self):

@lazy_property
def covariance_matrix(self):
covariance_matrix = add_diag(jnp.matmul(
self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)
), self.cov_diag)
covariance_matrix = add_diag(
jnp.matmul(self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)),
self.cov_diag,
)
return covariance_matrix

@lazy_property
@@ -1976,7 +1977,7 @@ def precision_matrix(self):
        )
        A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True)
        inverse_cov_diag = jnp.reciprocal(self.cov_diag)
-        return add_diag(- jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag)
+        return add_diag(-jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag)

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
@@ -2068,8 +2069,9 @@ class Pareto(TransformedDistribution):
    def __init__(self, scale, alpha, *, validate_args=None):
        self.scale, self.alpha = promote_shapes(scale, alpha)
        batch_shape = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(alpha))
-        scale, alpha = jnp.broadcast_to(scale, batch_shape), jnp.broadcast_to(
-            alpha, batch_shape
+        scale, alpha = (
+            jnp.broadcast_to(scale, batch_shape),
+            jnp.broadcast_to(alpha, batch_shape),
        )
        base_dist = Exponential(alpha)
        transforms = [ExpTransform(), AffineTransform(loc=0, scale=scale)]
1 change: 1 addition & 0 deletions numpyro/distributions/discrete.py
@@ -679,6 +679,7 @@ class Poisson(Distribution):
    :param bool is_sparse: Whether to assume value is mostly zero when computing
        :meth:`log_prob`, which can speed up computation when data is sparse.
    """
+
    arg_constraints = {"rate": constraints.positive}
    support = constraints.nonnegative_integer
    pytree_aux_fields = ("is_sparse",)
8 changes: 7 additions & 1 deletion numpyro/distributions/transforms.py
@@ -400,6 +400,7 @@ class CholeskyTransform(ParameterFreeTransform):
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    positive definite matrix.
    """
+
    domain = constraints.positive_definite
    codomain = constraints.lower_cholesky

@@ -444,6 +445,7 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a
    c. Applies :math:`s_i = StickBreakingTransform(z_i)`.
    d. Transforms back into signed domain: :math:`y_i = (sign(r_i), 1) * \sqrt{s_i}`.
    """
+
    domain = constraints.real_vector
    codomain = constraints.corr_cholesky

@@ -493,6 +495,7 @@ class CorrMatrixCholeskyTransform(CholeskyTransform):
    Transform via the mapping :math:`y = cholesky(x)`, where `x` is a
    correlation matrix.
    """
+
    domain = constraints.corr_matrix
    codomain = constraints.corr_cholesky

@@ -624,6 +627,7 @@ class L1BallTransform(ParameterFreeTransform):
r"""
Transforms a uncontrained real vector :math:`x` into the unit L1 ball.
"""

domain = constraints.real_vector
codomain = constraints.l1_ball

@@ -687,6 +691,7 @@ class LowerCholeskyAffine(Transform):
    >>> affine(base)
    Array([0.3, 1.5], dtype=float32)
    """
+
    domain = constraints.real_vector
    codomain = constraints.real_vector

@@ -786,6 +791,7 @@ class ScaledUnitLowerCholeskyTransform(LowerCholeskyTransform):
    and :math:`scale\_diag` is a diagonal matrix with all positive
    entries that is parameterized with a softplus transform.
    """
+
    domain = constraints.real_vector
    codomain = constraints.scaled_unit_lower_cholesky

@@ -995,6 +1001,7 @@ class SoftplusTransform(ParameterFreeTransform):
    Transform from unconstrained space to positive domain via softplus :math:`y = \log(1 + \exp(x))`.
    The inverse is computed as :math:`x = \log(\exp(y) - 1)`.
    """
+
    domain = constraints.real
    codomain = constraints.softplus_positive

@@ -1200,7 +1207,6 @@ def __eq__(self, other):
        )


-
 ##########################################################
 # CONSTRAINT_REGISTRY
 ##########################################################
4 changes: 1 addition & 3 deletions numpyro/distributions/truncated.py
@@ -263,9 +263,7 @@ def mean(self):
        if isinstance(self.base_dist, Normal):
            low_prob = jnp.exp(self.log_prob(self.low))
            high_prob = jnp.exp(self.log_prob(self.high))
-            return (
-                self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
-            )
+            return self.base_dist.loc + (low_prob - high_prob) * self.base_dist.scale**2
        elif isinstance(self.base_dist, Cauchy):
            return jnp.full(self.batch_shape, jnp.nan)
        else:
22 changes: 16 additions & 6 deletions numpyro/infer/autoguide.py
@@ -1858,7 +1858,7 @@ def _setup_prototype(self, *args, **kwargs):
f"Expected {self.batch_ndim} batch dimensions, but site "
f"`{site['name']}` only has shape {shape}."
)
shape = shape[:self.batch_ndim]
shape = shape[: self.batch_ndim]
if batch_shape is None:
batch_shape = shape
elif shape != batch_shape:
@@ -1884,7 +1884,9 @@ def _get_posterior(self):
        )

    def _get_reshape_transform(self) -> ReshapeTransform:
-        return ReshapeTransform((self.latent_dim,), self._batch_shape + self._event_shape)
+        return ReshapeTransform(
+            (self.latent_dim,), self._batch_shape + self._event_shape
+        )


class AutoBatchedMultivariateNormal(AutoBatchedMixin, AutoContinuous):
@@ -1914,7 +1916,10 @@ def __init__(
            raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
        self._init_scale = init_scale
        super().__init__(
-            model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim,
+            model,
+            prefix=prefix,
+            init_loc_fn=init_loc_fn,
+            batch_ndim=batch_ndim,
        )

    def _get_batched_posterior(self):
@@ -2044,16 +2049,21 @@ def __init__(
        self._init_scale = init_scale
        self.rank = rank
        super().__init__(
-            model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim,
+            model,
+            prefix=prefix,
+            init_loc_fn=init_loc_fn,
+            batch_ndim=batch_ndim,
        )

    def _get_batched_posterior(self):
-        rank = int(round(self._event_shape[0]**0.5)) if self.rank is None else self.rank
+        rank = (
+            int(round(self._event_shape[0] ** 0.5)) if self.rank is None else self.rank
+        )
        init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape)
        loc = numpyro.param("{}_loc".format(self.prefix), init_latent)
        cov_factor = numpyro.param(
            "{}_cov_factor".format(self.prefix),
-            jnp.zeros(self._batch_shape + self._event_shape + (rank,))
+            jnp.zeros(self._batch_shape + self._event_shape + (rank,)),
        )
        scale = numpyro.param(
            "{}_scale".format(self.prefix),
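Two spacing rules are visible in the autoguide hunks: when a slice bound is a compound expression such as `self.batch_ndim`, the formatter pads the colon (`shape[: self.batch_ndim]`), and when a `**` operand is compound, it pads the operator (`self._event_shape[0] ** 0.5`); simple names like `con0**2` keep the tight form. A small sketch of both rules, assuming nothing beyond the standard library:

class GuideLike:  # toy stand-in for the autoguide classes above
    batch_ndim = 2
    event_shape = (9,)

    def batch_part(self, shape):
        # compound slice bound -> space after the colon
        return shape[: self.batch_ndim]

    def default_rank(self):
        # compound base for ** -> spaces around the operator
        return int(round(self.event_shape[0] ** 0.5))


g = GuideLike()
print(g.batch_part((4, 3, 9)), g.default_rank())  # prints (4, 3) 3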
