From e7dc9d1db7637ed29c9ea8341a987360be607ccd Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sun, 21 Feb 2021 10:17:42 -0600 Subject: [PATCH] revert black changes --- funsor/affine.py | 6 +- funsor/cnf.py | 6 +- funsor/delta.py | 5 +- funsor/distribution.py | 7 +- funsor/gaussian.py | 7 +- funsor/instrument.py | 9 +- funsor/integrate.py | 4 +- funsor/interpreter.py | 3 +- funsor/jax/__init__.py | 8 +- funsor/jax/ops.py | 5 +- funsor/joint.py | 4 +- funsor/memoize.py | 4 +- funsor/montecarlo.py | 4 +- funsor/registry.py | 4 +- funsor/syntax.py | 23 +- funsor/terms.py | 2 +- funsor/testing.py | 6 +- test/examples/test_bart.py | 92 ++++- test/examples/test_sensor_fusion.py | 11 +- test/pyro/test_hmm.py | 145 ++++++- test/test_adjoint.py | 7 +- test/test_distribution.py | 5 +- test/test_distribution_generic.py | 8 +- test/test_domains.py | 10 +- test/test_factory.py | 22 +- test/test_gaussian.py | 15 +- test/test_memoize.py | 4 +- test/test_minipyro.py | 17 +- test/test_optimizer.py | 28 +- test/test_samplers.py | 111 +++++- test/test_sum_product.py | 573 ++++++++++++++++++++++++---- test/test_tensor.py | 182 ++++++++- test/test_terms.py | 29 +- 33 files changed, 1172 insertions(+), 194 deletions(-) diff --git a/funsor/affine.py b/funsor/affine.py index 1d6c9a2d1..c9e181911 100644 --- a/funsor/affine.py +++ b/funsor/affine.py @@ -159,4 +159,8 @@ def extract_affine(fn): return const, coeffs -__all__ = ["affine_inputs", "extract_affine", "is_affine"] +__all__ = [ + "affine_inputs", + "extract_affine", + "is_affine", +] diff --git a/funsor/cnf.py b/funsor/cnf.py index 0e570df50..2200a6e74 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -594,7 +594,11 @@ def unary_contract(op, arg): ) -BACKEND_TO_EINSUM_BACKEND = {"numpy": "numpy", "torch": "torch", "jax": "jax.numpy"} +BACKEND_TO_EINSUM_BACKEND = { + "numpy": "numpy", + "torch": "torch", + "jax": "jax.numpy", +} # NB: numpy_log, numpy_map is backend-agnostic so they also work for torch backend; # however, we might need to profile to make a switch BACKEND_TO_LOGSUMEXP_BACKEND = { diff --git a/funsor/delta.py b/funsor/delta.py index 341280050..f87359b8f 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -248,4 +248,7 @@ def eager_independent_delta(delta, reals_var, bint_var, diag_var): return None -__all__ = ["Delta", "solve"] +__all__ = [ + "Delta", + "solve", +] diff --git a/funsor/distribution.py b/funsor/distribution.py index f67ceb5ef..b179905a5 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -184,7 +184,7 @@ def eager_log_prob(cls, *params): params, value = params[:-1], params[-1] params = params + (Variable("value", value.output),) instance = reflect.interpret(cls, *params) - (raw_dist, value_name, value_output, dim_to_name) = instance._get_raw_dist() + raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist() assert value.output == value_output name_to_dim = {v: k for k, v in dim_to_name.items()} dim_to_name.update( @@ -379,7 +379,10 @@ def dist_init(self, **kwargs): dist_class = DistributionMeta( backend_dist_class.__name__.split("Wrapper_")[-1], (Distribution,), - {"dist_class": backend_dist_class, "__init__": dist_init}, + { + "dist_class": backend_dist_class, + "__init__": dist_init, + }, ) if generate_eager: diff --git a/funsor/gaussian.py b/funsor/gaussian.py index 13edad9d1..caa278c9c 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -779,4 +779,9 @@ def eager_neg(op, arg): return Gaussian(info_vec, precision, arg.inputs) -__all__ = ["BlockMatrix", "BlockVector", "Gaussian", 
"align_gaussian"] +__all__ = [ + "BlockMatrix", + "BlockVector", + "Gaussian", + "align_gaussian", +] diff --git a/funsor/instrument.py b/funsor/instrument.py index 569c771a4..71f797cf9 100644 --- a/funsor/instrument.py +++ b/funsor/instrument.py @@ -108,4 +108,11 @@ def print_counters(): print("-" * 80) -__all__ = ["DEBUG", "PROFILE", "STACK_SIZE", "debug_logged", "get_indent", "profile"] +__all__ = [ + "DEBUG", + "PROFILE", + "STACK_SIZE", + "debug_logged", + "get_indent", + "profile", +] diff --git a/funsor/integrate.py b/funsor/integrate.py index 7212628af..b75b6f50e 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -230,4 +230,6 @@ def eager_integrate(log_measure, integrand, reduced_vars): return None # defer to default implementation -__all__ = ["Integrate"] +__all__ = [ + "Integrate", +] diff --git a/funsor/interpreter.py b/funsor/interpreter.py index 5059672b4..759d75587 100644 --- a/funsor/interpreter.py +++ b/funsor/interpreter.py @@ -80,7 +80,8 @@ def interpret(cls, *args): def interpretation(new): warnings.warn( - "'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning + "'with interpretation(x)' should be replaced by 'with x'", + DeprecationWarning, ) return new diff --git a/funsor/jax/__init__.py b/funsor/jax/__init__.py index dae7b7435..ffefedd58 100644 --- a/funsor/jax/__init__.py +++ b/funsor/jax/__init__.py @@ -18,7 +18,13 @@ @adjoint_ops.register( - Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object + Tensor, + AssociativeOp, + AssociativeOp, + Funsor, + (DeviceArray, Tracer), + tuple, + object, ) def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype): return {} diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index 0570b92b6..07e38c9d6 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -257,7 +257,10 @@ def _triangular_solve(x, y, upper=False, transpose=False): x_new_shape = batch_shape[:prepend_ndim] for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]): x_new_shape += (sx // sy, sy) - x_new_shape += (n, m) + x_new_shape += ( + n, + m, + ) x = np.reshape(x, x_new_shape) # Permute y to make it have shape (..., 1, j, m, i, 1, n) batch_ndim = x.ndim - 2 diff --git a/funsor/joint.py b/funsor/joint.py index 8658248db..cdc2a092c 100644 --- a/funsor/joint.py +++ b/funsor/joint.py @@ -104,7 +104,9 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss discrete += gaussian.log_normalizer new_discrete = discrete.reduce(ops.logaddexp, approx_vars & discrete.input_vars) num_elements = reduce( - ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1 + ops.mul, + [v.output.num_elements for v in approx_vars - discrete.input_vars], + 1, ) if num_elements != 1: new_discrete -= math.log(num_elements) diff --git a/funsor/memoize.py b/funsor/memoize.py index baf45a471..90fea683f 100644 --- a/funsor/memoize.py +++ b/funsor/memoize.py @@ -40,4 +40,6 @@ def interpret(self, cls, *args): return value -__all__ = ["memoize"] +__all__ = [ + "memoize", +] diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index 3ef205b4f..b533c0f1a 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -40,4 +40,6 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): return Integrate(sample, integrand, reduced_vars) -__all__ = ["MonteCarlo"] +__all__ = [ + "MonteCarlo", +] diff --git a/funsor/registry.py b/funsor/registry.py index 353693861..07f9f5542 100644 --- a/funsor/registry.py +++ b/funsor/registry.py @@ -84,4 +84,6 @@ def 
dispatch(self, key, *args): return self[key].partial_call(*args) -__all__ = ["KeyedRegistry"] +__all__ = [ + "KeyedRegistry", +] diff --git a/funsor/syntax.py b/funsor/syntax.py index 7b4291548..ed2b7d6eb 100644 --- a/funsor/syntax.py +++ b/funsor/syntax.py @@ -59,7 +59,12 @@ def visit_UnaryOp(self, node): var = self.prefix.get(type(node.op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load()), args=[node.operand], keywords=[] + func=ast.Name( + id=var, + ctx=ast.Load(), + ), + args=[node.operand], + keywords=[], ) return node @@ -68,7 +73,10 @@ def visit_BinOp(self, node): var = self.infix.get(type(node.op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load()), + func=ast.Name( + id=var, + ctx=ast.Load(), + ), args=[node.left, node.right], keywords=[], ) @@ -90,7 +98,10 @@ def visit_Compare(self, node): var = self.infix.get(type(node_op)) if var is not None: node = ast.Call( - func=ast.Name(id=var, ctx=ast.Load()), + func=ast.Name( + id=var, + ctx=ast.Load(), + ), args=[node.left, node_right], keywords=[], ) @@ -161,4 +172,8 @@ def decorator(fn): return decorator -__all__ = ["INFIX_OPERATORS", "PREFIX_OPERATORS", "rewrite_ops"] +__all__ = [ + "INFIX_OPERATORS", + "PREFIX_OPERATORS", + "rewrite_ops", +] diff --git a/funsor/terms.py b/funsor/terms.py index efc1cc33c..3b13ced73 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1520,7 +1520,7 @@ def eager_subs(self, subs): n -= size assert False elif isinstance(value, Slice): - start, stop, step = (value.slice.start, value.slice.stop, value.slice.step) + start, stop, step = value.slice.start, value.slice.stop, value.slice.step new_parts = [] pos = 0 for part in self.parts: diff --git a/funsor/testing.py b/funsor/testing.py index aefcb52d8..ef5dc7ef9 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -114,9 +114,9 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): n for n, p in expected.terms ) actual = actual.align(tuple(n for n, p in expected.terms)) - for ( - (actual_name, (actual_point, actual_log_density)), - (expected_name, (expected_point, expected_log_density)), + for (actual_name, (actual_point, actual_log_density)), ( + expected_name, + (expected_point, expected_log_density), ) in zip(actual.terms, expected.terms): assert actual_name == expected_name assert_close(actual_point, expected_point, atol=atol, rtol=rtol) diff --git a/test/examples/test_bart.py b/test/examples/test_bart.py index c6f1762f1..afcdb7a27 100644 --- a/test/examples/test_bart.py +++ b/test/examples/test_bart.py @@ -52,7 +52,10 @@ def unpack_gate_rate(gate_rate): @pytest.mark.parametrize( "analytic_kl", - [False, xfail_param(True, reason="missing pattern")], + [ + False, + xfail_param(True, reason="missing pattern"), + ], ids=["monte-carlo-kl", "analytic-kl"], ) def test_bart(analytic_kl): @@ -93,7 +96,16 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b4", Bint[2]), ("_event_1_b2", Bint[8])), + ( + ( + "time_b4", + Bint[2], + ), + ( + "_event_1_b2", + Bint[8], + ), + ), "real", ), Gaussian( @@ -148,9 +160,18 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b4", Bint[2]), - ("_event_1_b2", Bint[8]), - ("value_b1", Real), + ( + "time_b4", + Bint[2], + ), + ( + "_event_1_b2", + Bint[8], + ), + ( + "value_b1", + Real, + ), ), ), ), @@ -220,8 +241,14 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("state_b7", Reals[2]), - ("state(time=1)_b8", Reals[2]), + ( + "state_b7", + Reals[2], + ), + ( + "state(time=1)_b8", + Reals[2], + ), 
), ), Subs( @@ -281,7 +308,12 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b9", Bint[2]),), + ( + ( + "time_b9", + Bint[2], + ), + ), "real", ), Tensor( @@ -310,7 +342,12 @@ def test_bart(analytic_kl): ], dtype=torch.float32, ), # noqa - (("time_b9", Bint[2]),), + ( + ( + "time_b9", + Bint[2], + ), + ), "real", ), Variable("state(time=1)_b8", Reals[2]), @@ -352,7 +389,12 @@ def test_bart(analytic_kl): ), Variable("value_b5", Reals[2]), ), - (("value_b5", Variable("state_b10", Reals[2])),), + ( + ( + "value_b5", + Variable("state_b10", Reals[2]), + ), + ), ), ), ) @@ -449,9 +491,18 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b17", Bint[2]), - ("origin_b15", Bint[2]), - ("destin_b16", Bint[2]), + ( + "time_b17", + Bint[2], + ), + ( + "origin_b15", + Bint[2], + ), + ( + "destin_b16", + Bint[2], + ), ), "real", ), @@ -476,9 +527,18 @@ def test_bart(analytic_kl): dtype=torch.float32, ), # noqa ( - ("time_b17", Bint[2]), - ("origin_b15", Bint[2]), - ("destin_b16", Bint[2]), + ( + "time_b17", + Bint[2], + ), + ( + "origin_b15", + Bint[2], + ), + ( + "destin_b16", + Bint[2], + ), ), "real", ), diff --git a/test/examples/test_sensor_fusion.py b/test/examples/test_sensor_fusion.py index f52e483a0..f8fc8b77f 100644 --- a/test/examples/test_sensor_fusion.py +++ b/test/examples/test_sensor_fusion.py @@ -142,7 +142,16 @@ def test_affine_subs(): ], dtype=torch.float32, ), # noqa - (("state_1_b6", Reals[3]), ("obs_b2", Reals[2])), + ( + ( + "state_1_b6", + Reals[3], + ), + ( + "obs_b2", + Reals[2], + ), + ), ), ( ( diff --git a/test/pyro/test_hmm.py b/test/pyro/test_hmm.py index 9db9dcad7..69d645ca5 100644 --- a/test/pyro/test_hmm.py +++ b/test/pyro/test_hmm.py @@ -245,19 +245,134 @@ def test_gaussian_mrf_log_prob(init_shape, trans_shape, obs_shape, hidden_dim, o ] ) SLHMM_SHAPES = [ - ((2,), (), (1, 2), (1, 3, 3), (1,), (1, 3, 4), (1,)), - ((2,), (), (5, 1, 2), (1, 3, 3), (1,), (1, 3, 4), (1,)), - ((2,), (), (1, 2), (5, 1, 3, 3), (1,), (1, 3, 4), (1,)), - ((2,), (), (1, 2), (1, 3, 3), (5, 1), (1, 3, 4), (1,)), - ((2,), (), (1, 2), (1, 3, 3), (1,), (5, 1, 3, 4), (1,)), - ((2,), (), (1, 2), (1, 3, 3), (1,), (1, 3, 4), (5, 1)), - ((2,), (), (5, 1, 2), (5, 1, 3, 3), (5, 1), (5, 1, 3, 4), (5, 1)), - ((2,), (2,), (5, 2, 2), (5, 2, 3, 3), (5, 2), (5, 2, 3, 4), (5, 2)), - ((7, 2), (), (7, 5, 1, 2), (7, 5, 1, 3, 3), (7, 5, 1), (7, 5, 1, 3, 4), (7, 5, 1)), ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (1,), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 5, + 1, + 2, + ), + (1, 3, 3), + (1,), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (5, 1, 3, 3), + (1,), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (5, 1), + (1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (1,), + (5, 1, 3, 4), + (1,), + ), + ( + (2,), + (), + ( + 1, + 2, + ), + (1, 3, 3), + (1,), + (1, 3, 4), + (5, 1), + ), + ( + (2,), + (), + ( + 5, + 1, + 2, + ), + (5, 1, 3, 3), + (5, 1), + (5, 1, 3, 4), + (5, 1), + ), + ( + (2,), + (2,), + ( + 5, + 2, + 2, + ), + (5, 2, 3, 3), + (5, 2), + (5, 2, 3, 4), + (5, 2), + ), + ( + ( + 7, + 2, + ), + (), + ( + 7, + 5, + 1, + 2, + ), + (7, 5, 1, 3, 3), + (7, 5, 1), + (7, 5, 1, 3, 4), + (7, 5, 1), + ), + ( + ( + 7, + 2, + ), (7, 2), - (7, 2), - (7, 5, 2, 2), + ( + 7, + 5, + 2, + 2, + ), (7, 5, 2, 3, 3), (7, 5, 2), (7, 5, 2, 3, 4), @@ -403,7 +518,13 @@ def test_switching_linear_hmm_log_prob_alternating(exact, num_steps, num_compone -1, num_components, -1, -1 ) - trans_mvn = 
random_mvn((num_steps, num_components), hidden_dim) + trans_mvn = random_mvn( + ( + num_steps, + num_components, + ), + hidden_dim, + ) hmm_obs_matrix = torch.randn(num_steps, hidden_dim, obs_dim) switching_obs_matrix = hmm_obs_matrix.unsqueeze(-3).expand( -1, num_components, -1, -1 diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 4756479eb..b098c1c95 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -201,7 +201,12 @@ def test_optimized_plated_einsum_adjoint(equation, plates, backend): ids=lambda d: ",".join(d.keys()), ) @pytest.mark.parametrize( - "impl", [sequential_sum_product, naive_sequential_sum_product, MarkovProduct] + "impl", + [ + sequential_sum_product, + naive_sequential_sum_product, + MarkovProduct, + ], ) def test_sequential_sum_product_adjoint( impl, sum_op, prod_op, batch_inputs, state_domain, num_steps diff --git a/test/test_distribution.py b/test/test_distribution.py index c73da8567..d0794a1eb 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1459,7 +1459,10 @@ def test_power_transform(shape): @pytest.mark.parametrize("shape", [(10,), (4, 3)], ids=str) @pytest.mark.parametrize( "to_event", - [True, xfail_param(False, reason="bug in to_funsor(TransformedDistribution)")], + [ + True, + xfail_param(False, reason="bug in to_funsor(TransformedDistribution)"), + ], ) def test_haar_transform(shape, to_event): try: diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index db1c4c263..5ffa99bf6 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -186,7 +186,9 @@ def __hash__(self): # Chi2 DistTestCase( - "dist.Chi2(df=case.df)", (("df", f"rand({batch_shape})"),), funsor.Real + "dist.Chi2(df=case.df)", + (("df", f"rand({batch_shape})"),), + funsor.Real, ) # ContinuousBernoulli @@ -368,7 +370,9 @@ def __hash__(self): # Poisson DistTestCase( - "dist.Poisson(rate=case.rate)", (("rate", f"rand({batch_shape})"),), funsor.Real + "dist.Poisson(rate=case.rate)", + (("rate", f"rand({batch_shape})"),), + funsor.Real, ) # RelaxedBernoulli diff --git a/test/test_domains.py b/test/test_domains.py index 29bfc5cdd..d721ee03e 100644 --- a/test/test_domains.py +++ b/test/test_domains.py @@ -9,7 +9,15 @@ from funsor.domains import Bint, Real, Reals # noqa F401 -@pytest.mark.parametrize("expr", ["Bint[2]", "Real", "Reals[4]", "Reals[3, 2]"]) +@pytest.mark.parametrize( + "expr", + [ + "Bint[2]", + "Real", + "Reals[4]", + "Reals[3, 2]", + ], +) def test_pickle(expr): x = eval(expr) f = io.BytesIO() diff --git a/test/test_factory.py b/test/test_factory.py index e23f50e20..a9fe8b78c 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -19,7 +19,9 @@ def test_lambda_lambda(): @make_funsor def LambdaLambda( - i: Bound, j: Bound, x: Funsor + i: Bound, + j: Bound, + x: Funsor, ) -> Fresh[lambda i, j, x: Array[x.dtype, (i.size, j.size) + x.shape]]: assert i in x.inputs assert j in x.inputs @@ -49,7 +51,10 @@ def GetitemGetitem( def test_flatten(): @make_funsor def Flatten21( - x: Funsor, i: Bound, j: Bound, ij: Fresh[lambda i, j: Bint[i.size * j.size]] + x: Funsor, + i: Bound, + j: Bound, + ij: Fresh[lambda i, j: Bint[i.size * j.size]], ) -> Fresh[lambda x: x.dtype]: m = to_funsor(i, x.inputs.get(i, None)).output.size n = to_funsor(j, x.inputs.get(j, None)).output.size @@ -115,7 +120,9 @@ def Cat2( def test_normal(): @make_funsor def Normal( - loc: Funsor, scale: Funsor, value: Fresh[lambda loc: loc] + loc: Funsor, + scale: Funsor, + value: Fresh[lambda loc: loc], ) -> 
Fresh[Real]: return None @@ -140,7 +147,11 @@ def _(loc, scale, value): def test_matmul(): @make_funsor - def MatMul(x: Funsor, y: Funsor, i: Bound) -> Fresh[lambda x: x]: + def MatMul( + x: Funsor, + y: Funsor, + i: Bound, + ) -> Fresh[lambda x: x]: return (x * y).reduce(ops.add, i) x = random_tensor(OrderedDict(a=Bint[3], b=Bint[4])) @@ -171,7 +182,8 @@ def Scatter1( def test_value_dependence(): @make_funsor def Sum( - x: Funsor, dim: Value[int] + x: Funsor, + dim: Value[int], ) -> Fresh[lambda x, dim: Array[x.dtype, x.shape[:dim] + x.shape[dim + 1 :]]]: return None diff --git a/test/test_gaussian.py b/test/test_gaussian.py index f3c5ac636..d9e66af02 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -576,10 +576,21 @@ def test_reduce_logsumexp(int_inputs, real_inputs): ) -@pytest.mark.parametrize("int_inputs", [{}, {"i": Bint[2]}], ids=id_from_inputs) +@pytest.mark.parametrize( + "int_inputs", + [ + {}, + {"i": Bint[2]}, + ], + ids=id_from_inputs, +) @pytest.mark.parametrize( "real_inputs", - [{"x": Real}, {"x": Reals[4]}, {"x": Reals[2, 3]}], + [ + {"x": Real}, + {"x": Reals[4]}, + {"x": Reals[2, 3]}, + ], ids=id_from_inputs, ) def test_integrate_variable(int_inputs, real_inputs): diff --git a/test/test_memoize.py b/test/test_memoize.py index e54b18cb2..14b11b3aa 100644 --- a/test/test_memoize.py +++ b/test/test_memoize.py @@ -169,10 +169,10 @@ def test_nested_einsum_complete_sharing( eqn1, eqn2, einsum_impl1, einsum_impl2, backend1, backend2 ): - (inputs1, outputs1, sizes1, operands1, funsor_operands1) = make_einsum_example( + inputs1, outputs1, sizes1, operands1, funsor_operands1 = make_einsum_example( eqn1, sizes=(3,) ) - (inputs2, outputs2, sizes2, operands2, funsor_operands2) = make_einsum_example( + inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example( eqn2, sizes=(3,) ) diff --git a/test/test_minipyro.py b/test/test_minipyro.py index b9c1eb937..5224ab25c 100644 --- a/test/test_minipyro.py +++ b/test/test_minipyro.py @@ -36,9 +36,8 @@ def Vindex(x): def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4): # copied from pyro - expected_loss, actual_loss = ( - funsor.to_data(expected_loss), - funsor.to_data(actual_loss), + expected_loss, actual_loss = funsor.to_data(expected_loss), funsor.to_data( + actual_loss ) assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol) names = pyro.get_param_store().keys() @@ -302,7 +301,11 @@ def guide(): @pytest.mark.parametrize( - "backend", ["pyro", xfail_param("funsor", reason="missing patterns")] + "backend", + [ + "pyro", + xfail_param("funsor", reason="missing patterns"), + ], ) def test_mean_field_ok(backend): def model(): @@ -320,7 +323,11 @@ def guide(): @pytest.mark.parametrize( - "backend", ["pyro", xfail_param("funsor", reason="missing patterns")] + "backend", + [ + "pyro", + xfail_param("funsor", reason="missing patterns"), + ], ) def test_mean_field_warn(backend): def model(): diff --git a/test/test_optimizer.py b/test/test_optimizer.py index 7c5399622..ee16c9d75 100644 --- a/test/test_optimizer.py +++ b/test/test_optimizer.py @@ -45,9 +45,19 @@ @pytest.mark.parametrize("equation", OPTIMIZED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map"] + "backend", + [ + "pyro.ops.einsum.torch_log", + "pyro.ops.einsum.torch_map", + ], +) +@pytest.mark.parametrize( + "einsum_impl", + [ + naive_einsum, + naive_contract_einsum, + ], ) -@pytest.mark.parametrize("einsum_impl", [naive_einsum, naive_contract_einsum]) 
def test_optimized_einsum(equation, backend, einsum_impl): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) expected = pyro_einsum(equation, *operands, backend=backend)[0] @@ -69,7 +79,11 @@ def test_optimized_einsum(equation, backend, einsum_impl): @pytest.mark.parametrize( - "eqn1,eqn2", [("a,ab->b", "bc->"), ("ab,bc,cd->d", "de,ef,fg->")] + "eqn1,eqn2", + [ + ("a,ab->b", "bc->"), + ("ab,bc,cd->d", "de,ef,fg->"), + ], ) @pytest.mark.parametrize("optimize1", [False, True]) @pytest.mark.parametrize("optimize2", [False, True]) @@ -84,7 +98,7 @@ def test_nested_einsum( eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl ): inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,)) - (inputs2, outputs2, sizes2, operands2, funsor_operands2) = make_einsum_example( + inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example( eqn2, sizes=(3,) ) @@ -137,7 +151,11 @@ def test_nested_einsum( @pytest.mark.parametrize("equation,plates", PLATED_EINSUM_EXAMPLES) @pytest.mark.parametrize( - "backend", ["pyro.ops.einsum.torch_log", "pyro.ops.einsum.torch_map"] + "backend", + [ + "pyro.ops.einsum.torch_log", + "pyro.ops.einsum.torch_map", + ], ) def test_optimized_plated_einsum(equation, plates, backend): inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation) diff --git a/test/test_samplers.py b/test/test_samplers.py index d30aca38e..aafc29a21 100644 --- a/test/test_samplers.py +++ b/test/test_samplers.py @@ -37,17 +37,28 @@ @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7]))], + [ + (), + (("s", Bint[6]),), + (("s", Bint[6]), ("t", Bint[7])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[4]),), (("b", Bint[4]), ("c", Bint[5]))], + [ + (), + (("b", Bint[4]),), + (("b", Bint[4]), ("c", Bint[5])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Bint[2]),), (("e", Bint[2]), ("f", Bint[3]))], + [ + (("e", Bint[2]),), + (("e", Bint[2]), ("f", Bint[3])), + ], ids=id_from_inputs, ) def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @@ -81,16 +92,30 @@ def test_tensor_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4]))], + [ + (), + (("s", Bint[3]),), + (("s", Bint[3]), ("t", Bint[4])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real))], + [ + (), + (("b", Bint[2]),), + (("c", Real),), + (("b", Bint[2]), ("c", Real)), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs + "event_inputs", + [ + (("e", Real),), + (("e", Real), ("f", Reals[2])), + ], + ids=id_from_inputs, ) def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) @@ -130,16 +155,30 @@ def test_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[3]),), (("s", Bint[3]), ("t", Bint[4]))], + [ + (), + (("s", Bint[3]),), + (("s", Bint[3]), ("t", Bint[4])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[2]),), (("c", Real),), (("b", Bint[2]), ("c", Real))], + [ + (), + (("b", Bint[2]),), + (("c", Real),), + (("b", Bint[2]), ("c", Real)), + ], ids=id_from_inputs, ) 
@pytest.mark.parametrize( - "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs + "event_inputs", + [ + (("e", Real),), + (("e", Real), ("f", Reals[2])), + ], + ids=id_from_inputs, ) def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): be_inputs = OrderedDict(batch_inputs + event_inputs) @@ -187,17 +226,28 @@ def test_transformed_gaussian_shape(sample_inputs, batch_inputs, event_inputs): @pytest.mark.parametrize( "sample_inputs", - [(), (("s", Bint[6]),), (("s", Bint[6]), ("t", Bint[7]))], + [ + (), + (("s", Bint[6]),), + (("s", Bint[6]), ("t", Bint[7])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "int_event_inputs", - [(), (("d", Bint[2]),), (("d", Bint[2]), ("e", Bint[3]))], + [ + (), + (("d", Bint[2]),), + (("d", Bint[2]), ("e", Bint[3])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "real_event_inputs", - [(("g", Real),), (("g", Real), ("h", Reals[4]))], + [ + (("g", Real),), + (("g", Real), ("h", Reals[4])), + ], ids=id_from_inputs, ) def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @@ -239,12 +289,19 @@ def test_joint_shape(sample_inputs, int_event_inputs, real_event_inputs): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[4]),), (("b", Bint[2]), ("c", Bint[2]))], + [ + (), + (("b", Bint[4]),), + (("b", Bint[2]), ("c", Bint[2])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Bint[3]),), (("e", Bint[2]), ("f", Bint[2]))], + [ + (("e", Bint[3]),), + (("e", Bint[2]), ("f", Bint[2])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize("test_grad", [False, True], ids=["value", "grad"]) @@ -267,7 +324,7 @@ def diff_fn(p_data): _, (p_data, mq_data) = align_tensors(p, mq) assert p_data.shape == mq_data.shape - return ((ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq) + return (ops.exp(mq_data) * probe).sum() - (ops.exp(p_data) * probe).sum(), mq if test_grad: if get_backend() == "jax": @@ -290,11 +347,20 @@ def diff_fn(p_data): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[4]))], + [ + (), + (("b", Bint[3]),), + (("b", Bint[3]), ("c", Bint[4])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( - "event_inputs", [(("e", Real),), (("e", Real), ("f", Reals[2]))], ids=id_from_inputs + "event_inputs", + [ + (("e", Real),), + (("e", Real), ("f", Reals[2])), + ], + ids=id_from_inputs, ) def test_gaussian_distribution(event_inputs, batch_inputs): num_samples = 100000 @@ -330,12 +396,19 @@ def test_gaussian_distribution(event_inputs, batch_inputs): @pytest.mark.parametrize( "batch_inputs", - [(), (("b", Bint[3]),), (("b", Bint[3]), ("c", Bint[2]))], + [ + (), + (("b", Bint[3]),), + (("b", Bint[3]), ("c", Bint[2])), + ], ids=id_from_inputs, ) @pytest.mark.parametrize( "event_inputs", - [(("e", Real), ("f", Bint[3])), (("e", Reals[2]), ("f", Bint[2]))], + [ + (("e", Real), ("f", Bint[3])), + (("e", Reals[2]), ("f", Bint[2])), + ], ids=id_from_inputs, ) def test_gaussian_mixture_distribution(batch_inputs, event_inputs): diff --git a/test/test_sum_product.py b/test/test_sum_product.py index ff2fb5ee8..8977cc22c 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -100,7 +100,13 @@ def test_partition(inputs, dims, expected_num_components): ("abcij", ""), ], ) -@pytest.mark.parametrize("impl", [partial_sum_product, modified_partial_sum_product]) +@pytest.mark.parametrize( + "impl", + [ + partial_sum_product, + modified_partial_sum_product, + ], +) def 
test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2): inputs = inputs.split(",") factors = [random_tensor(OrderedDict((d, Bint[2]) for d in ds)) for ds in inputs] @@ -138,7 +144,14 @@ def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2 (frozenset({"time", "x_0", "x_prev", "x_curr"}), frozenset()), ], ) -@pytest.mark.parametrize("x_dim,time", [(3, 1), (1, 5), (3, 5)]) +@pytest.mark.parametrize( + "x_dim,time", + [ + (3, 1), + (1, 5), + (3, 5), + ], +) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] ) @@ -146,10 +159,22 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) factors = [f1, f2, f3] @@ -182,7 +207,13 @@ def test_modified_partial_sum_product_0(sum_op, prod_op, vars1, vars2, x_dim, ti ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -193,16 +224,41 @@ def test_modified_partial_sum_product_1( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( - OrderedDict({"time": Bint[time], "x_curr": Bint[x_dim], "y_curr": Bint[y_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_curr": Bint[x_dim], + "y_curr": Bint[y_dim], + } + ) ) factors = [f1, f2, f3, f4, f5] @@ -240,7 +296,13 @@ def test_modified_partial_sum_product_1( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -251,16 +313,40 @@ def test_modified_partial_sum_product_2( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( - OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "y_curr": Bint[y_dim]}) + OrderedDict( + { + "time": Bint[time], + "y_prev": Bint[y_dim], + "y_curr": Bint[y_dim], + } + ) ) factors = [f1, f2, f3, f4, f5] @@ -300,7 +386,13 @@ def test_modified_partial_sum_product_2( ], ) @pytest.mark.parametrize( - 
"x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -311,13 +403,32 @@ def test_modified_partial_sum_product_3( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( OrderedDict( @@ -398,7 +509,12 @@ def test_modified_partial_sum_product_3( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], + [ + (2, 3, 2, 5, 4), + (1, 3, 2, 5, 4), + (2, 1, 2, 5, 4), + (2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -409,7 +525,14 @@ def test_modified_partial_sum_product_4( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -424,7 +547,11 @@ def test_modified_partial_sum_product_4( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "tones": Bint[tones], "y_0": Bint[y_dim]} + { + "sequences": Bint[sequences], + "tones": Bint[tones], + "y_0": Bint[y_dim], + } ) ) @@ -530,7 +657,12 @@ def test_modified_partial_sum_product_4( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,days,weeks,tones", - [(2, 3, 2, 5, 4, 3), (1, 3, 2, 5, 4, 3), (2, 1, 2, 5, 4, 3), (2, 3, 2, 1, 4, 3)], + [ + (2, 3, 2, 5, 4, 3), + (1, 3, 2, 5, 4, 3), + (2, 1, 2, 5, 4, 3), + (2, 3, 2, 1, 4, 3), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -543,7 +675,11 @@ def test_modified_partial_sum_product_5( f2 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "tones": Bint[tones], "x_0": Bint[x_dim]} + { + "sequences": Bint[sequences], + "tones": Bint[tones], + "x_0": Bint[x_dim], + } ) ) @@ -559,7 +695,14 @@ def test_modified_partial_sum_product_5( ) ) - f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( OrderedDict( @@ -643,7 +786,12 @@ def test_modified_partial_sum_product_5( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], + [ + (2, 3, 2, 5, 4), + (1, 3, 2, 5, 4), + (2, 1, 2, 5, 4), + (2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -654,7 +802,14 @@ def test_modified_partial_sum_product_6( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -760,7 +915,12 @@ def 
test_modified_partial_sum_product_6( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 5, 4), (1, 3, 2, 5, 4), (2, 1, 2, 5, 4), (2, 3, 2, 1, 4)], + [ + (2, 3, 2, 5, 4), + (1, 3, 2, 5, 4), + (2, 1, 2, 5, 4), + (2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -771,7 +931,14 @@ def test_modified_partial_sum_product_7( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -811,7 +978,12 @@ def test_modified_partial_sum_product_7( factors = [f1, f2, f3, f4, f5] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr")}), + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("y_0", "y_prev", "y_curr"), + } + ), "tones": {}, } @@ -900,7 +1072,12 @@ def test_modified_partial_sum_product_7( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -911,7 +1088,14 @@ def test_modified_partial_sum_product_8( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -924,7 +1108,14 @@ def test_modified_partial_sum_product_8( ) ) - f4 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) + f4 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f5 = random_tensor( OrderedDict( @@ -965,7 +1156,12 @@ def test_modified_partial_sum_product_8( factors = [f1, f2, f3, f4, f5, f6, f7] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", "w_curr")}), + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("w_0", "w_prev", "w_curr"), + } + ), "tones": {}, } @@ -1063,7 +1259,12 @@ def test_modified_partial_sum_product_8( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1074,7 +1275,14 @@ def test_modified_partial_sum_product_9( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -1089,7 +1297,11 @@ def test_modified_partial_sum_product_9( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim]} + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + "x_0": Bint[x_dim], + } ) ) @@ -1133,7 +1345,12 @@ def test_modified_partial_sum_product_9( factors = [f1, f2, f3, f4, f5, f6, f7] plate_to_step = { "sequences": {}, - "time": frozenset({("x_0", "x_prev", "x_curr"), ("w_0", "w_prev", 
"w_curr")}), + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("w_0", "w_prev", "w_curr"), + } + ), "tones": {}, } @@ -1220,7 +1437,12 @@ def test_modified_partial_sum_product_9( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1231,17 +1453,32 @@ def test_modified_partial_sum_product_10( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim]} + { + "sequences": Bint[sequences], + "time": Bint[time], + "w_curr": Bint[w_dim], + } ) ) f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "w_0": Bint[w_dim], "x_0": Bint[x_dim]} + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + "x_0": Bint[x_dim], + } ) ) @@ -1417,13 +1654,30 @@ def test_modified_partial_sum_product_11( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"a": Bint[a_dim]})) + f2 = random_tensor( + OrderedDict( + { + "a": Bint[a_dim], + } + ) + ) - f3 = random_tensor(OrderedDict({"sequences": Bint[sequences], "b": Bint[b_dim]})) + f3 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "b": Bint[b_dim], + } + ) + ) f4 = random_tensor( OrderedDict( - {"a": Bint[a_dim], "sequences": Bint[sequences], "w_0": Bint[w_dim]} + { + "a": Bint[a_dim], + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } ) ) @@ -1575,7 +1829,12 @@ def test_modified_partial_sum_product_11( ) @pytest.mark.parametrize( "w_dim,x_dim,y_dim,sequences,time,tones", - [(3, 2, 3, 2, 5, 4), (3, 1, 3, 2, 5, 4), (3, 2, 1, 2, 5, 4), (3, 2, 3, 2, 1, 4)], + [ + (3, 2, 3, 2, 5, 4), + (3, 1, 3, 2, 5, 4), + (3, 2, 1, 2, 5, 4), + (3, 2, 3, 2, 1, 4), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1586,11 +1845,22 @@ def test_modified_partial_sum_product_12( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "w_0": Bint[w_dim]})) + f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "w_0": Bint[w_dim], + } + ) + ) f3 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "time": Bint[time], "w_curr": Bint[w_dim]} + { + "sequences": Bint[sequences], + "time": Bint[time], + "w_curr": Bint[w_dim], + } ) ) @@ -1799,7 +2069,11 @@ def test_modified_partial_sum_product_13( f4 = random_tensor( OrderedDict( - {"w": Bint[w_dim], "sequences": Bint[sequences], "y_0": Bint[y_dim]} + { + "w": Bint[w_dim], + "sequences": Bint[sequences], + "y_0": Bint[y_dim], + } ) ) @@ -1920,7 +2194,12 @@ def test_modified_partial_sum_product_13( ) @pytest.mark.parametrize( "x_dim,y_dim,sequences,time,tones", - [(2, 3, 2, 3, 2), (1, 3, 2, 3, 2), (2, 1, 2, 3, 2), (2, 3, 2, 1, 2)], + [ + (2, 3, 2, 3, 2), + (1, 3, 2, 3, 2), + (2, 1, 2, 3, 2), + (2, 3, 2, 1, 2), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -1931,7 +2210,14 @@ def test_modified_partial_sum_product_14( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"sequences": Bint[sequences], "x_0": Bint[x_dim]})) 
+ f2 = random_tensor( + OrderedDict( + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( OrderedDict( @@ -1946,7 +2232,11 @@ def test_modified_partial_sum_product_14( f4 = random_tensor( OrderedDict( - {"sequences": Bint[sequences], "x_0": Bint[x_dim], "y0_0": Bint[y_dim]} + { + "sequences": Bint[sequences], + "x_0": Bint[x_dim], + "y0_0": Bint[y_dim], + } ) ) @@ -1991,7 +2281,10 @@ def test_modified_partial_sum_product_14( "sequences": {}, "time": frozenset({("x_0", "x_prev", "x_curr")}), "tones": frozenset( - {("y0_0", "y0_prev", "y0_curr"), ("ycurr_0", "ycurr_prev", "ycurr_curr")} + { + ("y0_0", "y0_prev", "y0_curr"), + ("ycurr_0", "ycurr_prev", "ycurr_curr"), + } ), } @@ -2027,7 +2320,13 @@ def test_modified_partial_sum_product_14( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,time", [(2, 3, 5), (1, 3, 5), (2, 1, 5), (2, 3, 1)] + "x_dim,y_dim,time", + [ + (2, 3, 5), + (1, 3, 5), + (2, 1, 5), + (2, 3, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2038,21 +2337,50 @@ def test_modified_partial_sum_product_16( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "y_prev": Bint[y_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "y_prev": Bint[y_dim], + "x_curr": Bint[x_dim], + } + ) ) - f4 = random_tensor(OrderedDict({"y_0": Bint[y_dim]})) + f4 = random_tensor( + OrderedDict( + { + "y_0": Bint[y_dim], + } + ) + ) f5 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "y_curr": Bint[y_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "y_curr": Bint[y_dim], + } + ) ) factors = [f1, f2, f3, f4, f5] plate_to_step = { - "time": frozenset({("x_0", "x_prev", "x_curr"), ("y_0", "y_prev", "y_curr")}) + "time": frozenset( + { + ("x_0", "x_prev", "x_curr"), + ("y_0", "y_prev", "y_curr"), + } + ), } factors1 = modified_partial_sum_product( @@ -2122,7 +2450,13 @@ def test_modified_partial_sum_product_16( ], ) @pytest.mark.parametrize( - "x_dim,y_dim,z_dim,time", [(2, 3, 2, 5), (1, 3, 2, 5), (2, 1, 2, 5), (2, 3, 2, 1)] + "x_dim,y_dim,z_dim,time", + [ + (2, 3, 2, 5), + (1, 3, 2, 5), + (2, 1, 2, 5), + (2, 3, 2, 1), + ], ) @pytest.mark.parametrize( "sum_op,prod_op", [(ops.logaddexp, ops.add), (ops.add, ops.mul)] @@ -2133,10 +2467,22 @@ def test_modified_partial_sum_product_17( f1 = random_tensor(OrderedDict({})) - f2 = random_tensor(OrderedDict({"x_0": Bint[x_dim]})) + f2 = random_tensor( + OrderedDict( + { + "x_0": Bint[x_dim], + } + ) + ) f3 = random_tensor( - OrderedDict({"time": Bint[time], "x_prev": Bint[x_dim], "x_curr": Bint[x_dim]}) + OrderedDict( + { + "time": Bint[time], + "x_prev": Bint[x_dim], + "x_curr": Bint[x_dim], + } + ) ) f4 = random_tensor( @@ -2186,7 +2532,13 @@ def test_modified_partial_sum_product_17( ) f8 = random_tensor( - OrderedDict({"x_0": Bint[x_dim], "y_0": Bint[y_dim], "z2_0": Bint[z_dim]}) + OrderedDict( + { + "x_0": Bint[x_dim], + "y_0": Bint[y_dim], + "z2_0": Bint[z_dim], + } + ) ) f9 = random_tensor( @@ -2201,7 +2553,9 @@ def test_modified_partial_sum_product_17( ) factors = [f1, f2, f3, f4, f5, f6, f7, f8, f9] - plate_to_step = {"time": frozenset({("x_0", "x_prev", "x_curr")})} + plate_to_step = { + "time": frozenset({("x_0", "x_prev", "x_curr")}), + } with (lazy if use_lazy else eager): factors1 = modified_partial_sum_product( @@ -2301,7 
+2655,11 @@ def test_sequential_sum_product( ) @pytest.mark.parametrize( "x_domain,y_domain", - [(Bint[2], Bint[3]), (Real, Reals[2, 2]), (Bint[2], Reals[2])], + [ + (Bint[2], Bint[3]), + (Real, Reals[2, 2]), + (Bint[2], Reals[2]), + ], ids=str, ) @pytest.mark.parametrize( @@ -2370,15 +2728,29 @@ def test_sequential_sum_product_multi( @pytest.mark.parametrize("dim", [1, 2, 3]) def test_sequential_sum_product_bias_1(num_steps, dim): time = Variable("time", Bint[num_steps]) - bias_dist = random_gaussian(OrderedDict([("bias", Reals[dim])])) + bias_dist = random_gaussian( + OrderedDict( + [ + ("bias", Reals[dim]), + ] + ) + ) trans = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_prev", Reals[dim]), + ("x_curr", Reals[dim]), + ] ) ) obs = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_curr", Reals[dim]), + ("bias", Reals[dim]), + ] ) ) factor = trans + obs + bias_dist @@ -2397,15 +2769,29 @@ def test_sequential_sum_product_bias_1(num_steps, dim): def test_sequential_sum_product_bias_2(num_steps, num_sensors, dim): time = Variable("time", Bint[num_steps]) bias = Variable("bias", Reals[num_sensors, dim]) - bias_dist = random_gaussian(OrderedDict([("bias", Reals[num_sensors, dim])])) + bias_dist = random_gaussian( + OrderedDict( + [ + ("bias", Reals[num_sensors, dim]), + ] + ) + ) trans = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_prev", Reals[dim]), ("x_curr", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_prev", Reals[dim]), + ("x_curr", Reals[dim]), + ] ) ) obs = random_gaussian( OrderedDict( - [("time", Bint[num_steps]), ("x_curr", Reals[dim]), ("bias", Reals[dim])] + [ + ("time", Bint[num_steps]), + ("x_curr", Reals[dim]), + ("bias", Reals[dim]), + ] ) ) @@ -2451,9 +2837,18 @@ def _check_sarkka_bilmes(trans, expected_inputs, global_vars, num_periods=1): @pytest.mark.parametrize("duration", [2, 3, 4, 5, 6]) def test_sarkka_bilmes_example_0(duration): - trans = random_tensor(OrderedDict({"time": Bint[duration], "a": Bint[3]})) + trans = random_tensor( + OrderedDict( + { + "time": Bint[duration], + "a": Bint[3], + } + ) + ) - expected_inputs = {"a": Bint[3]} + expected_inputs = { + "a": Bint[3], + } _check_sarkka_bilmes(trans, expected_inputs, frozenset()) @@ -2463,11 +2858,20 @@ def test_sarkka_bilmes_example_1(duration): trans = random_tensor( OrderedDict( - {"time": Bint[duration], "a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2]} + { + "time": Bint[duration], + "a": Bint[3], + "b": Bint[2], + "_PREV_b": Bint[2], + } ) ) - expected_inputs = {"a": Bint[3], "b": Bint[2], "_PREV_b": Bint[2]} + expected_inputs = { + "a": Bint[3], + "b": Bint[2], + "_PREV_b": Bint[2], + } _check_sarkka_bilmes(trans, expected_inputs, frozenset()) @@ -2553,11 +2957,20 @@ def test_sarkka_bilmes_example_5(duration): trans = random_tensor( OrderedDict( - {"time": Bint[duration], "a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2]} + { + "time": Bint[duration], + "a": Bint[3], + "_PREV_a": Bint[3], + "x": Bint[2], + } ) ) - expected_inputs = {"a": Bint[3], "_PREV_a": Bint[3], "x": Bint[2]} + expected_inputs = { + "a": Bint[3], + "_PREV_a": Bint[3], + "x": Bint[2], + } global_vars = frozenset(["x"]) @@ -2593,7 +3006,13 @@ def test_sarkka_bilmes_example_6(duration): @pytest.mark.parametrize("time_input", [("time", Bint[t]) for t in range(6, 11)]) -@pytest.mark.parametrize("global_inputs", [(), (("x", 
Bint[2]),)]) +@pytest.mark.parametrize( + "global_inputs", + [ + (), + (("x", Bint[2]),), + ], +) @pytest.mark.parametrize( "local_inputs", [ diff --git a/test/test_tensor.py b/test/test_tensor.py index ff06ce771..2050fb105 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -126,7 +126,15 @@ def test_indexing(): def test_advanced_indexing_shape(): I, J, M, N = 4, 4, 2, 3 - x = Tensor(randn((I, J)), OrderedDict([("i", Bint[I]), ("j", Bint[J])])) + x = Tensor( + randn((I, J)), + OrderedDict( + [ + ("i", Bint[I]), + ("j", Bint[J]), + ] + ), + ) m = Tensor(numeric_array([2, 3]), OrderedDict([("m", Bint[M])]), I) n = Tensor(numeric_array([0, 1, 1]), OrderedDict([("n", Bint[N])]), J) assert x.data.shape == (I, J) @@ -223,17 +231,54 @@ def test_advanced_indexing_tensor(output_shape): # x output = Reals[output_shape] x = random_tensor( - OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]), output + OrderedDict( + [ + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] + ), + output, + ) + i = random_tensor( + OrderedDict( + [ + ("u", Bint[5]), + ] + ), + Bint[2], + ) + j = random_tensor( + OrderedDict( + [ + ("v", Bint[6]), + ("u", Bint[5]), + ] + ), + Bint[3], + ) + k = random_tensor( + OrderedDict( + [ + ("v", Bint[6]), + ] + ), + Bint[4], ) - i = random_tensor(OrderedDict([("u", Bint[5])]), Bint[2]) - j = random_tensor(OrderedDict([("v", Bint[6]), ("u", Bint[5])]), Bint[3]) - k = random_tensor(OrderedDict([("v", Bint[6])]), Bint[4]) expected_data = empty((5, 6) + output_shape) for u in range(5): for v in range(6): expected_data[u, v] = x.data[i.data[u], j.data[v, u], k.data[v]] - expected = Tensor(expected_data, OrderedDict([("u", Bint[5]), ("v", Bint[6])])) + expected = Tensor( + expected_data, + OrderedDict( + [ + ("u", Bint[5]), + ("v", Bint[6]), + ] + ), + ) assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -258,7 +303,13 @@ def test_advanced_indexing_tensor(output_shape): def test_advanced_indexing_lazy(output_shape): x = Tensor( randn((2, 3, 4) + output_shape), - OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]), + OrderedDict( + [ + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] + ), ) u = Variable("u", Bint[2]) v = Variable("v", Bint[3]) @@ -274,7 +325,15 @@ def test_advanced_indexing_lazy(output_shape): for u in range(2): for v in range(3): expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]] - expected = Tensor(expected_data, OrderedDict([("u", Bint[2]), ("v", Bint[3])])) + expected = Tensor( + expected_data, + OrderedDict( + [ + ("u", Bint[2]), + ("v", Bint[3]), + ] + ), + ) assert_equiv(expected, x(i, j, k)) assert_equiv(expected, x(i=i, j=j, k=k)) @@ -304,7 +363,18 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("dims", [(), ("a",), ("a", "b")]) @pytest.mark.parametrize( "symbol", - ["~", "-", "abs", "atanh", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh"], + [ + "~", + "-", + "abs", + "atanh", + "sqrt", + "exp", + "log", + "log1p", + "sigmoid", + "tanh", + ], ) def test_unary(symbol, dims): sizes = {"a": 3, "b": 4} @@ -837,7 +907,14 @@ def test_function_of_numeric_array(): def test_align(): x = Tensor( - randn((2, 3, 4)), OrderedDict([("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])]) + randn((2, 3, 4)), + OrderedDict( + [ + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] + ), ) y = x.align(("j", "k", "i")) assert isinstance(y, Tensor) @@ -950,13 +1027,41 @@ def test_tensor_stack(n, shape, dim): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) 
def test_funsor_stack(output): - x = random_tensor(OrderedDict([("i", Bint[2])]), output) - y = random_tensor(OrderedDict([("j", Bint[3])]), output) - z = random_tensor(OrderedDict([("i", Bint[2]), ("k", Bint[4])]), output) + x = random_tensor( + OrderedDict( + [ + ("i", Bint[2]), + ] + ), + output, + ) + y = random_tensor( + OrderedDict( + [ + ("j", Bint[3]), + ] + ), + output, + ) + z = random_tensor( + OrderedDict( + [ + ("i", Bint[2]), + ("k", Bint[4]), + ] + ), + output, + ) xy = Stack("t", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict([("t", Bint[2]), ("i", Bint[2]), ("j", Bint[3])]) + assert xy.inputs == OrderedDict( + [ + ("t", Bint[2]), + ("i", Bint[2]), + ("j", Bint[3]), + ] + ) assert xy.output == output for j in range(3): assert_close(xy(t=0, j=j), x) @@ -966,7 +1071,12 @@ def test_funsor_stack(output): xyz = Stack("t", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [("t", Bint[3]), ("i", Bint[2]), ("j", Bint[3]), ("k", Bint[4])] + [ + ("t", Bint[3]), + ("i", Bint[2]), + ("j", Bint[3]), + ("k", Bint[4]), + ] ) assert xy.output == output for j in range(3): @@ -981,9 +1091,32 @@ def test_funsor_stack(output): @pytest.mark.parametrize("output", [Bint[2], Real, Reals[4], Reals[2, 3]], ids=str) def test_cat_simple(output): - x = random_tensor(OrderedDict([("i", Bint[2])]), output) - y = random_tensor(OrderedDict([("i", Bint[3]), ("j", Bint[4])]), output) - z = random_tensor(OrderedDict([("i", Bint[5]), ("k", Bint[6])]), output) + x = random_tensor( + OrderedDict( + [ + ("i", Bint[2]), + ] + ), + output, + ) + y = random_tensor( + OrderedDict( + [ + ("i", Bint[3]), + ("j", Bint[4]), + ] + ), + output, + ) + z = random_tensor( + OrderedDict( + [ + ("i", Bint[5]), + ("k", Bint[6]), + ] + ), + output, + ) assert Cat("i", (x,)) is x assert Cat("i", (y,)) is y @@ -991,13 +1124,22 @@ def test_cat_simple(output): xy = Cat("i", (x, y)) assert isinstance(xy, Tensor) - assert xy.inputs == OrderedDict([("i", Bint[2 + 3]), ("j", Bint[4])]) + assert xy.inputs == OrderedDict( + [ + ("i", Bint[2 + 3]), + ("j", Bint[4]), + ] + ) assert xy.output == output xyz = Cat("i", (x, y, z)) assert isinstance(xyz, Tensor) assert xyz.inputs == OrderedDict( - [("i", Bint[2 + 3 + 5]), ("j", Bint[4]), ("k", Bint[6])] + [ + ("i", Bint[2 + 3 + 5]), + ("j", Bint[4]), + ("k", Bint[6]), + ] ) assert xy.output == output diff --git a/test/test_terms.py b/test/test_terms.py index ab33c2566..daa5f49a5 100644 --- a/test/test_terms.py +++ b/test/test_terms.py @@ -260,7 +260,18 @@ def unary_eval(symbol, x): @pytest.mark.parametrize("data", [0, 0.5, 1]) @pytest.mark.parametrize( "symbol", - ["~", "-", "atanh", "abs", "sqrt", "exp", "log", "log1p", "sigmoid", "tanh"], + [ + "~", + "-", + "atanh", + "abs", + "sqrt", + "exp", + "log", + "log1p", + "sigmoid", + "tanh", + ], ) def test_unary(symbol, data): dtype = "real" @@ -276,7 +287,21 @@ def test_unary(symbol, data): check_funsor(actual, {}, Array[dtype, ()], expected_data) -BINARY_OPS = ["+", "-", "*", "/", "**", "==", "!=", "<", "<=", ">", ">=", "min", "max"] +BINARY_OPS = [ + "+", + "-", + "*", + "/", + "**", + "==", + "!=", + "<", + "<=", + ">", + ">=", + "min", + "max", +] BOOLEAN_OPS = ["&", "|", "^"]
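Note on the formatting idiom this patch restores: throughout these hunks, single-line calls and literals are replaced by exploded, one-element-per-line layouts ending in a trailing comma. This matches black's "magic trailing comma" convention, under which black keeps a bracketed expression exploded rather than collapsing it onto one line whenever a trailing comma is already present. A minimal sketch of the two equivalent layouts (the helper `describe` is hypothetical, for illustration only; only the formatting differs):

    # Hypothetical helper; the call formatting, not the logic, is the point.
    def describe(name, sizes, dtype):
        return f"{name}: {sizes} ({dtype})"

    # No trailing comma inside the parentheses: black collapses the call.
    short = describe("x", (2, 3), "real")

    # Trailing comma after the last argument: black preserves the exploded,
    # one-argument-per-line layout, which is the form this patch reinstates.
    long = describe(
        "x",
        (2, 3),
        "real",
    )

    assert short == long  # same value either way; only layout changes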