Rerun new version of black on the repo (#605)
* rerun new black

* Use Python 3.8 for jax tests
ordabayevy committed Feb 6, 2023
1 parent 77500f3 commit ff5e410
Showing 28 changed files with 15 additions and 107 deletions.
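
The deletions below are almost all blank lines: the newer black release applied here (likely black 23.x, judging by the Feb 2023 date) removes empty lines that directly follow a block opener such as a def, if, or with line, and this commit also picks up the removal of redundant parentheses around for-loop targets and around a single conditional context manager, as seen in the hunks for funsor/jax/ops.py, funsor/optimizer.py, and funsor/pyro/hmm.py. A minimal before/after sketch on made-up code; the eager and moment_matching names below are stand-ins for funsor's interpretation contexts, not the real objects:

```python
from contextlib import nullcontext

# Stand-ins for funsor's `eager` / `moment_matching` interpretations
# (assumption: any reusable context manager illustrates the point).
eager = nullcontext("eager")
moment_matching = nullcontext("moment_matching")


def initialize(sizes, exact=True):
    # black 23.x deletes the blank line older code kept right after "def ...:"
    params = {}
    # before: for (name, size) in sizes.items():
    for name, size in sizes.items():
        params[name] = [0.0] * size
    # before: with (eager if exact else moment_matching):
    with eager if exact else moment_matching:
        return params


print(initialize({"state": 3}))  # {'state': [0.0, 0.0, 0.0]}
```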
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -67,7 +67,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.7,3.8]
+ python-version: [3.8]
env:
CI: 1
FUNSOR_BACKEND: jax
1 change: 0 additions & 1 deletion examples/eeg_slds.py
@@ -56,7 +56,6 @@ def __init__(
fine_observation_noise=False, # controls whether the observation noise depends on s_t
moment_matching_lag=1,
): # controls the expense of the moment matching approximation

self.num_components = num_components
self.hidden_dim = hidden_dim
self.obs_dim = obs_dim
1 change: 0 additions & 1 deletion examples/mixed_hmm/experiment.py
@@ -187,7 +187,6 @@ def closure():


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", default="seal", type=str)
parser.add_argument("-g", "--group", default="none", type=str)
10 changes: 0 additions & 10 deletions examples/mixed_hmm/model.py
@@ -22,15 +22,13 @@ def __init__(self, config):
self.params = self.initialize_params()

def initialize_params(self):

# dictionary of guide random effect parameters
params = {"eps_g": {}, "eps_i": {}}

N_state = self.config["sizes"]["state"]

# initialize group-level parameters
if self.config["group"]["random"] == "continuous":

params["eps_g"]["loc"] = Tensor(
pyro.param("loc_group", lambda: torch.zeros((N_state, N_state))),
OrderedDict([("y_prev", Bint[N_state])]),
@@ -48,7 +46,6 @@ def initialize_params(self):
# initialize individual-level random effect parameters
N_c = self.config["sizes"]["group"]
if self.config["individual"]["random"] == "continuous":

params["eps_i"]["loc"] = Tensor(
pyro.param(
"loc_individual", lambda: torch.zeros((N_c, N_state, N_state))
@@ -69,7 +66,6 @@ def initialize_params(self):
return self.params

def __call__(self):

# calls pyro.param so that params are exposed and constraints applied
# should not create any new torch.Tensors after __init__
self.initialize_params()
@@ -104,7 +100,6 @@ def __init__(self, config):
self.observations = self.initialize_observations()

def initialize_params(self):

# return a dict of per-site params as funsor.tensor.Tensors
params = {
"e_g": {},
@@ -126,7 +121,6 @@ def initialize_params(self):

# initialize group-level random effect parameters
if self.config["group"]["random"] == "discrete":

params["e_g"]["probs"] = Tensor(
pyro.param(
"probs_e_g",
@@ -142,7 +136,6 @@ def initialize_params(self):
)

elif self.config["group"]["random"] == "continuous":

# note these are prior values, trainable versions live in guide
params["eps_g"]["loc"] = Tensor(
torch.zeros((N_state, N_state)),
@@ -156,7 +149,6 @@ def initialize_params(self):
# initialize individual-level random effect parameters
N_c = self.config["sizes"]["group"]
if self.config["individual"]["random"] == "discrete":

params["e_i"]["probs"] = Tensor(
pyro.param(
"probs_e_i",
@@ -176,7 +168,6 @@ def initialize_params(self):
)

elif self.config["individual"]["random"] == "continuous":

params["eps_i"]["loc"] = Tensor(
torch.zeros((N_c, N_state, N_state)),
OrderedDict([("g", Bint[N_c]), ("y_prev", Bint[N_state])]),
@@ -311,7 +302,6 @@ def initialize_raggedness_masks(self):
return self.raggedness_masks

def __call__(self):

# calls pyro.param so that params are exposed and constraints applied
# should not create any new torch.Tensors after __init__
self.initialize_params()
1 change: 0 additions & 1 deletion examples/mixed_hmm/seal_data.py
@@ -18,7 +18,6 @@ def download_seal_data(filename):


def prepare_seal(filename, random_effects):

if not os.path.exists(filename):
download_seal_data(filename)

3 changes: 0 additions & 3 deletions funsor/adjoint.py
@@ -84,7 +84,6 @@ def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=set()):

# reverse the effects of alpha-renaming
with reflect:

lazy_output = self._eager_to_lazy[output]
lazy_fn = type(lazy_output)
lazy_inputs = lazy_output._ast_values
@@ -242,7 +241,6 @@ def adjoint_contract(
adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs
):
if prod_op is adj_prod_op and sum_op in (ops.null, adj_sum_op):

# the only change is here:
out_adj = Approximate(
adj_sum_op,
@@ -280,7 +278,6 @@ def adjoint_cat(adj_sum_op, adj_prod_op, out_adj, name, parts, part_name):

@adjoint_ops.register(Subs, AssociativeOp, AssociativeOp, Funsor, Funsor, tuple)
def adjoint_subs(adj_sum_op, adj_prod_op, out_adj, arg, subs):

# detect fresh variable collisions that should be relabeled and reduced
relabel = {k: interpreter.gensym(k) for k, v in subs}
relabeled_subs = tuple((relabel[k], v) for k, v in subs)
3 changes: 0 additions & 3 deletions funsor/cnf.py
@@ -311,7 +311,6 @@ def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, term):

@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor)
def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs):

if not reduced_vars.issubset(lhs.input_vars & rhs.input_vars):
args = red_op, bin_op, reduced_vars, (lhs, rhs)
result = eager.dispatch(Contraction, *args)(*args)
@@ -463,7 +462,6 @@ def normalize_trivial(red_op, bin_op, reduced_vars, term):

@normalize.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

if not reduced_vars and red_op is not ops.null:
return Contraction(ops.null, bin_op, reduced_vars, *terms)

@@ -490,7 +488,6 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
return Contraction(red_op, bin_op, reduced_vars, *new_terms)

for i, v in enumerate(terms):

if not isinstance(v, Contraction):
continue

2 changes: 0 additions & 2 deletions funsor/distribution.py
@@ -213,7 +213,6 @@ def eager_log_prob(cls, *params):
return log_prob.align(tuple(inputs))

def _sample(self, sampled_vars, sample_inputs, rng_key):

# note this should handle transforms correctly via distribution_to_data
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
@@ -498,7 +497,6 @@ def indepdist_to_funsor(backend_dist, output=None, dim_to_name=None):


def expandeddist_to_funsor(backend_dist, output=None, dim_to_name=None):

funsor_base_dist = to_funsor(
backend_dist.base_dist, output=output, dim_to_name=dim_to_name
)
2 changes: 0 additions & 2 deletions funsor/domains.py
@@ -190,7 +190,6 @@ def bint(size):


class ProductDomain(Domain):

_type_cache = WeakValueDictionary()

def __getitem__(cls, arg_domains):
@@ -453,7 +452,6 @@ def _find_domain_matmul(op, lhs, rhs):

@find_domain.register(ops.AssociativeOp)
def _find_domain_associative_generic(op, *domains):

assert 1 <= len(domains) <= 2

if len(domains) == 1:
1 change: 0 additions & 1 deletion funsor/integrate.py
@@ -217,7 +217,6 @@ def eager_integrate_gaussian_gaussian(log_measure, integrand, reduced_vars):
reduced_names = frozenset(v.name for v in reduced_vars)
real_vars = frozenset(v.name for v in reduced_vars if v.dtype == "real")
if real_vars:

lhs_reals = frozenset(
k for k, d in log_measure.inputs.items() if d.dtype == "real"
)
1 change: 1 addition & 0 deletions funsor/jax/distributions.py
@@ -213,6 +213,7 @@ def _infer_param_domain(cls, name, raw_shape):
# Converting distribution funsors to NumPyro distributions
###########################################################


# Convert Delta **distribution** to raw data
@to_data.register(Delta) # noqa: F821
def deltadist_to_data(funsor_dist, name_to_dim=None):
2 changes: 1 addition & 1 deletion funsor/jax/ops.py
@@ -319,7 +319,7 @@ def _triangular_solve(x, y, upper=False, transpose=False):
prepend_ndim = dx - y.ndim # ndim of ... part
# Reshape x with the shape (..., 1, i, j, 1, n, m)
x_new_shape = batch_shape[:prepend_ndim]
- for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
+ for sy, sx in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
x_new_shape += (sx // sy, sy)
x_new_shape += (n, m)
x = np.reshape(x, x_new_shape)
4 changes: 1 addition & 3 deletions funsor/optimizer.py
@@ -27,9 +27,7 @@

@unfold.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):

for i, v in enumerate(terms):

if not isinstance(v, Contraction):
continue

@@ -121,7 +119,7 @@ def optimize_contract_finitary_funsor(red_op, bin_op, reduced_vars, terms):
reduce_dim_counter.update({d: 1 for d in input})

operands = list(terms)
- for (a, b) in path:
+ for a, b in path:
b, a = tuple(sorted((a, b), reverse=True))
tb = operands.pop(b)
ta = operands.pop(a)
4 changes: 2 additions & 2 deletions funsor/pyro/hmm.py
@@ -529,7 +529,7 @@ def log_prob(self, value):
seq_sum_prod = (
naive_sequential_sum_product if self.exact else sequential_sum_product
)
- with (eager if self.exact else moment_matching):
+ with eager if self.exact else moment_matching:
result = self._trans + self._obs(value=value)
result = seq_sum_prod(
ops.logaddexp,
@@ -589,7 +589,7 @@ def filter(self, value):
seq_sum_prod = (
naive_sequential_sum_product if self.exact else sequential_sum_product
)
- with (eager if self.exact else moment_matching):
+ with eager if self.exact else moment_matching:
logp = self._trans + self._obs(value=value)
logp = seq_sum_prod(
ops.logaddexp,
1 change: 0 additions & 1 deletion funsor/registry.py
@@ -20,7 +20,6 @@ def __init__(self, default=None, name="PartialDispatcher"):
self.add(([object],), self.default)

def add(self, signature, func):

# Handle annotations
if not signature:
annotations = get_type_hints(func)
8 changes: 3 additions & 5 deletions funsor/sum_product.py
@@ -250,7 +250,7 @@ def partial_sum_product(
leaf = max(ordinal_to_factors, key=len) # CHOICE
leaf_factors = ordinal_to_factors.pop(leaf)
leaf_reduce_vars = ordinal_to_vars[leaf]
- for (group_factors, group_vars) in _partition(
+ for group_factors, group_vars in _partition(
leaf_factors, leaf_reduce_vars
): # CHOICE
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate)
@@ -400,7 +400,7 @@ def dynamic_partial_sum_product(
leaf = max(ordinal_to_factors, key=len)
leaf_factors = ordinal_to_factors.pop(leaf)
leaf_reduce_vars = ordinal_to_vars[leaf]
- for (group_factors, group_vars) in _partition(
+ for group_factors, group_vars in _partition(
leaf_factors, leaf_reduce_vars | markov_prod_vars
):
# eliminate non markov vars
@@ -529,7 +529,7 @@ def modified_partial_sum_product(
leaf = max(ordinal_to_factors, key=len)
leaf_factors = ordinal_to_factors.pop(leaf)
leaf_reduce_vars = ordinal_to_vars[leaf]
- for (group_factors, group_vars) in _partition(
+ for group_factors, group_vars in _partition(
leaf_factors, leaf_reduce_vars | markov_prod_vars
):
# eliminate non markov vars
@@ -780,7 +780,6 @@ def _shift_funsor(f, t, global_vars):
def naive_sarkka_bilmes_product(
sum_op, prod_op, trans, time_var, global_vars=frozenset()
):

assert isinstance(global_vars, frozenset)

time = time_var.name
@@ -818,7 +817,6 @@ def naive_sarkka_bilmes_product(
def sarkka_bilmes_product(
sum_op, prod_op, trans, time_var, global_vars=frozenset(), num_periods=1
):

assert isinstance(global_vars, frozenset)

time = time_var.name
1 change: 0 additions & 1 deletion funsor/testing.py
@@ -468,7 +468,6 @@ def random_mvn(batch_shape, dim, diag=False):


def make_plated_hmm_einsum(num_steps, num_obs_plates=1, num_hidden_plates=0):

assert num_obs_plates >= num_hidden_plates
t0 = num_obs_plates + 1

2 changes: 0 additions & 2 deletions test/test_affine.py
@@ -36,7 +36,6 @@

@pytest.mark.parametrize("expr,expected_type", SMOKE_TESTS)
def test_smoke(expr, expected_type):

t = Tensor(randn(2, 3), OrderedDict([("i", Bint[2]), ("j", Bint[3])]))
assert isinstance(t, Tensor)

@@ -70,7 +69,6 @@ def test_smoke(expr, expected_type):

@pytest.mark.parametrize("expr,expected_type,expected_inputs", SUBS_TESTS)
def test_affine_subs(expr, expected_type, expected_inputs):

expected_output = Real

t = Tensor(randn(2, 3), OrderedDict([("i", Bint[2]), ("j", Bint[3])]))
1 change: 0 additions & 1 deletion test/test_alpha_conversion.py
@@ -46,7 +46,6 @@ def test_subs_reduce():
@pytest.mark.parametrize("lhs_vars", [(), ("i",), ("j",), ("i", "j")])
@pytest.mark.parametrize("rhs_vars", [(), ("i",), ("j",), ("i", "j")])
def test_distribute_reduce(lhs_vars, rhs_vars):

lhs_vars, rhs_vars = frozenset(lhs_vars), frozenset(rhs_vars)
lhs = random_tensor(OrderedDict([("i", Bint[3]), ("j", Bint[2])]), Real)
rhs = random_tensor(OrderedDict([("i", Bint[3]), ("j", Bint[2])]), Real)
3 changes: 0 additions & 3 deletions test/test_approximations.py
@@ -144,7 +144,6 @@ def test_gaussian_linear(approximate):


def test_backward_argmax_simple_reduce():

x = random_tensor(OrderedDict(i=Bint[2], j=Bint[3]))

with reflect:
@@ -165,7 +164,6 @@ def test_backward_argmax_simple_reduce():


def test_backward_argmax_simple_binary():

x1 = random_tensor(OrderedDict(i=Bint[2], j=Bint[3]))
x2 = random_tensor(OrderedDict(j=Bint[3], k=Bint[4]))
approx_vars = x1.input_vars | x2.input_vars
@@ -191,7 +189,6 @@ def test_backward_argmax_simple_binary():


def test_backward_argmax_simple_contraction():

x1 = random_tensor(OrderedDict(i=Bint[2], j=Bint[3]))
x2 = random_tensor(OrderedDict(j=Bint[3], k=Bint[4]))
approx_vars = x1.input_vars | x2.input_vars
