
Commit

sketch for sum product network
fehiepsi committed Feb 21, 2021
1 parent 80283a2 commit 274841d
Showing 40 changed files with 502 additions and 316 deletions.
10 changes: 3 additions & 7 deletions docs/source/conf.py
@@ -101,11 +101,7 @@
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = [
-    ".ipynb_checkpoints",
-    "examples/*ipynb",
-    "examples/*py",
-]
+exclude_patterns = [".ipynb_checkpoints", "examples/*ipynb", "examples/*py"]
 
 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = "sphinx"
@@ -226,7 +222,7 @@
 # (source start file, target name, title,
 # author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, "Funsor.tex", u"Funsor Documentation", u"Uber AI Labs", "manual"),
+    (master_doc, "Funsor.tex", u"Funsor Documentation", u"Uber AI Labs", "manual")
 ]
 
 # -- Options for manual page output ------------------------------------------
@@ -249,7 +245,7 @@
         "Funsor",
         "Functional analysis + tensors + symbolic algebra.",
         "Miscellaneous",
-    ),
+    )
 ]


2 changes: 1 addition & 1 deletion examples/eeg_slds.py
@@ -155,7 +155,7 @@ def get_tensors_and_dists(self):
             self.observation_matrix, obs_mvn, event_dims, "x", "y"
         )
 
-        return trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist
+        return (trans_logits, trans_probs, trans_mvn, obs_mvn, x_trans_dist, y_dist)
 
     # compute the marginal log probability of the observed data using a moment-matching approximation
     @funsor.interpretation(funsor.terms.moment_matching)
12 changes: 4 additions & 8 deletions examples/mixed_hmm/model.py
@@ -24,10 +24,7 @@ def __init__(self, config):
     def initialize_params(self):
 
         # dictionary of guide random effect parameters
-        params = {
-            "eps_g": {},
-            "eps_i": {},
-        }
+        params = {"eps_g": {}, "eps_i": {}}
 
         N_state = self.config["sizes"]["state"]
 
@@ -153,8 +150,7 @@ def initialize_params(self):
         )
 
         params["eps_g"]["scale"] = Tensor(
-            torch.ones((N_state, N_state)),
-            OrderedDict([("y_prev", Bint[N_state])]),
+            torch.ones((N_state, N_state)), OrderedDict([("y_prev", Bint[N_state])])
         )
 
         # initialize individual-level random effect parameters
@@ -164,7 +160,7 @@
         params["e_i"]["probs"] = Tensor(
             pyro.param(
                 "probs_e_i",
-                lambda: torch.randn((N_c, N_v,)).abs(),
+                lambda: torch.randn((N_c, N_v)).abs(),
                 constraint=constraints.simplex,
             ),
             OrderedDict([("g", Bint[N_c])]),  # different value per group
@@ -324,7 +320,7 @@ def __call__(self):
 
         # initialize gamma to uniform
         gamma = Tensor(
-            torch.zeros((N_state, N_state)), OrderedDict([("y_prev", Bint[N_state])]),
+            torch.zeros((N_state, N_state)), OrderedDict([("y_prev", Bint[N_state])])
         )
 
         N_v = self.config["sizes"]["random"]
2 changes: 1 addition & 1 deletion examples/slds.py
@@ -25,7 +25,7 @@ def main(args):
     )
     trans_noise = funsor.Tensor(
         torch.tensor(
-            [0.1, 1.0,],  # low noise component  # high noisy component
+            [0.1, 1.0],  # low noise component  # high noisy component
             requires_grad=True,
         )
     )
6 changes: 1 addition & 5 deletions funsor/affine.py
@@ -159,8 +159,4 @@ def extract_affine(fn):
     return const, coeffs
 
 
-__all__ = [
-    "affine_inputs",
-    "extract_affine",
-    "is_affine",
-]
+__all__ = ["affine_inputs", "extract_affine", "is_affine"]
6 changes: 1 addition & 5 deletions funsor/cnf.py
@@ -594,11 +594,7 @@ def unary_contract(op, arg):
     )
 
 
-BACKEND_TO_EINSUM_BACKEND = {
-    "numpy": "numpy",
-    "torch": "torch",
-    "jax": "jax.numpy",
-}
+BACKEND_TO_EINSUM_BACKEND = {"numpy": "numpy", "torch": "torch", "jax": "jax.numpy"}
 # NB: numpy_log, numpy_map is backend-agnostic so they also work for torch backend;
 # however, we might need to profile to make a switch
 BACKEND_TO_LOGSUMEXP_BACKEND = {
5 changes: 1 addition & 4 deletions funsor/delta.py
@@ -248,7 +248,4 @@ def eager_independent_delta(delta, reals_var, bint_var, diag_var):
     return None
 
 
-__all__ = [
-    "Delta",
-    "solve",
-]
+__all__ = ["Delta", "solve"]
4 changes: 2 additions & 2 deletions funsor/distribution.py
@@ -184,7 +184,7 @@ def eager_log_prob(cls, *params):
         params, value = params[:-1], params[-1]
         params = params + (Variable("value", value.output),)
         instance = reflect.interpret(cls, *params)
-        raw_dist, value_name, value_output, dim_to_name = instance._get_raw_dist()
+        (raw_dist, value_name, value_output, dim_to_name) = instance._get_raw_dist()
         assert value.output == value_output
         name_to_dim = {v: k for k, v in dim_to_name.items()}
         dim_to_name.update(
@@ -379,7 +379,7 @@ def dist_init(self, **kwargs):
     dist_class = DistributionMeta(
         backend_dist_class.__name__.split("Wrapper_")[-1],
         (Distribution,),
-        {"dist_class": backend_dist_class, "__init__": dist_init,},
+        {"dist_class": backend_dist_class, "__init__": dist_init},
     )
 
     if generate_eager:
7 changes: 1 addition & 6 deletions funsor/gaussian.py
@@ -779,9 +779,4 @@ def eager_neg(op, arg):
     return Gaussian(info_vec, precision, arg.inputs)
 
 
-__all__ = [
-    "BlockMatrix",
-    "BlockVector",
-    "Gaussian",
-    "align_gaussian",
-]
+__all__ = ["BlockMatrix", "BlockVector", "Gaussian", "align_gaussian"]
9 changes: 1 addition & 8 deletions funsor/instrument.py
@@ -108,11 +108,4 @@ def print_counters():
     print("-" * 80)
 
 
-__all__ = [
-    "DEBUG",
-    "PROFILE",
-    "STACK_SIZE",
-    "debug_logged",
-    "get_indent",
-    "profile",
-]
+__all__ = ["DEBUG", "PROFILE", "STACK_SIZE", "debug_logged", "get_indent", "profile"]
4 changes: 1 addition & 3 deletions funsor/integrate.py
@@ -230,6 +230,4 @@ def eager_integrate(log_measure, integrand, reduced_vars):
     return None  # defer to default implementation
 
 
-__all__ = [
-    "Integrate",
-]
+__all__ = ["Integrate"]
2 changes: 1 addition & 1 deletion funsor/interpreter.py
@@ -80,7 +80,7 @@ def interpret(cls, *args):
 
 def interpretation(new):
     warnings.warn(
-        "'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning,
+        "'with interpretation(x)' should be replaced by 'with x'", DeprecationWarning
     )
     return new
 
2 changes: 1 addition & 1 deletion funsor/jax/__init__.py
@@ -18,7 +18,7 @@
 
 
 @adjoint_ops.register(
-    Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object,
+    Tensor, AssociativeOp, AssociativeOp, Funsor, (DeviceArray, Tracer), tuple, object
 )
 def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
     return {}
5 changes: 1 addition & 4 deletions funsor/jax/ops.py
@@ -257,10 +257,7 @@ def _triangular_solve(x, y, upper=False, transpose=False):
     x_new_shape = batch_shape[:prepend_ndim]
     for (sy, sx) in zip(y.shape[:-2], batch_shape[prepend_ndim:]):
         x_new_shape += (sx // sy, sy)
-    x_new_shape += (
-        n,
-        m,
-    )
+    x_new_shape += (n, m)
     x = np.reshape(x, x_new_shape)
     # Permute y to make it have shape (..., 1, j, m, i, 1, n)
     batch_ndim = x.ndim - 2
2 changes: 1 addition & 1 deletion funsor/joint.py
@@ -104,7 +104,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss
     discrete += gaussian.log_normalizer
     new_discrete = discrete.reduce(ops.logaddexp, approx_vars & discrete.input_vars)
     num_elements = reduce(
-        ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1,
+        ops.mul, [v.output.num_elements for v in approx_vars - discrete.input_vars], 1
     )
     if num_elements != 1:
         new_discrete -= math.log(num_elements)
4 changes: 1 addition & 3 deletions funsor/memoize.py
@@ -40,6 +40,4 @@ def interpret(self, cls, *args):
         return value
 
 
-__all__ = [
-    "memoize",
-]
+__all__ = ["memoize"]
4 changes: 1 addition & 3 deletions funsor/montecarlo.py
@@ -40,6 +40,4 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars):
     return Integrate(sample, integrand, reduced_vars)
 
 
-__all__ = [
-    "MonteCarlo",
-]
+__all__ = ["MonteCarlo"]
4 changes: 1 addition & 3 deletions funsor/registry.py
@@ -84,6 +84,4 @@ def dispatch(self, key, *args):
         return self[key].partial_call(*args)
 
 
-__all__ = [
-    "KeyedRegistry",
-]
+__all__ = ["KeyedRegistry"]
14 changes: 4 additions & 10 deletions funsor/syntax.py
@@ -59,9 +59,7 @@ def visit_UnaryOp(self, node):
         var = self.prefix.get(type(node.op))
         if var is not None:
             node = ast.Call(
-                func=ast.Name(id=var, ctx=ast.Load(),),
-                args=[node.operand],
-                keywords=[],
+                func=ast.Name(id=var, ctx=ast.Load()), args=[node.operand], keywords=[]
             )
         return node
 
@@ -70,7 +68,7 @@ def visit_BinOp(self, node):
         var = self.infix.get(type(node.op))
         if var is not None:
             node = ast.Call(
-                func=ast.Name(id=var, ctx=ast.Load(),),
+                func=ast.Name(id=var, ctx=ast.Load()),
                 args=[node.left, node.right],
                 keywords=[],
             )
@@ -92,7 +90,7 @@ def visit_Compare(self, node):
         var = self.infix.get(type(node_op))
         if var is not None:
             node = ast.Call(
-                func=ast.Name(id=var, ctx=ast.Load(),),
+                func=ast.Name(id=var, ctx=ast.Load()),
                 args=[node.left, node_right],
                 keywords=[],
             )
@@ -163,8 +161,4 @@ def decorator(fn):
     return decorator
 
 
-__all__ = [
-    "INFIX_OPERATORS",
-    "PREFIX_OPERATORS",
-    "rewrite_ops",
-]
+__all__ = ["INFIX_OPERATORS", "PREFIX_OPERATORS", "rewrite_ops"]
2 changes: 1 addition & 1 deletion funsor/terms.py
@@ -1520,7 +1520,7 @@ def eager_subs(self, subs):
                 n -= size
             assert False
         elif isinstance(value, Slice):
-            start, stop, step = value.slice.start, value.slice.stop, value.slice.step
+            start, stop, step = (value.slice.start, value.slice.stop, value.slice.step)
             new_parts = []
             pos = 0
             for part in self.parts:
2 changes: 1 addition & 1 deletion funsor/testing.py
@@ -116,7 +116,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
         actual = actual.align(tuple(n for n, p in expected.terms))
         for (
             (actual_name, (actual_point, actual_log_density)),
-            (expected_name, (expected_point, expected_log_density),),
+            (expected_name, (expected_point, expected_log_density)),
         ) in zip(actual.terms, expected.terms):
             assert actual_name == expected_name
             assert_close(actual_point, expected_point, atol=atol, rtol=rtol)
5 changes: 1 addition & 4 deletions scripts/update_headers.py
@@ -8,10 +8,7 @@
 
 root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 blacklist = ["/build/", "/dist/"]
-file_types = [
-    ("*.py", "# {}"),
-    ("*.cpp", "// {}"),
-]
+file_types = [("*.py", "# {}"), ("*.cpp", "// {}")]
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--check", action="store_true")
13 changes: 4 additions & 9 deletions setup.py
@@ -27,18 +27,13 @@
     description="A tensor-like library for functions and distributions",
     packages=find_packages(include=["funsor", "funsor.*"]),
     url="https://github.com/pyro-ppl/funsor",
-    project_urls={"Documentation": "https://funsor.pyro.ai",},
+    project_urls={"Documentation": "https://funsor.pyro.ai"},
     author="Uber AI Labs",
     python_requires=">=3.6",
-    install_requires=[
-        "makefun",
-        "multipledispatch",
-        "numpy>=1.7",
-        "opt_einsum>=2.3.2",
-    ],
+    install_requires=["makefun", "multipledispatch", "numpy>=1.7", "opt_einsum>=2.3.2"],
     extras_require={
-        "torch": ["pyro-ppl>=1.5.2", "torch>=1.7.0",],
-        "jax": ["numpyro>=0.2.4", "jax>=0.1.57", "jaxlib>=0.1.37",],
+        "torch": ["pyro-ppl>=1.5.2", "torch>=1.7.0"],
+        "jax": ["numpyro>=0.2.4", "jax>=0.1.57", "jaxlib>=0.1.37"],
         "test": [
             "black",
             "flake8",
