Skip to content

Commit

Permalink
Extend partial_sum_product() to cases tractable for Gaussians (#584)
Browse files Browse the repository at this point in the history
* Handle intractable case for Gaussians in sum_product

* Strengthen test

* Add NotImplementedError at point in code

* Fix tests

* NotImplementedError -> assert

* Simplify

* lint

* Add pedantic kwarg

* Relax test tolerance
  • Loading branch information
fritzo committed Mar 2, 2022
1 parent 56fa967 commit ecbe439
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 12 deletions.
83 changes: 72 additions & 11 deletions funsor/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import funsor
import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.domains import Bint
from funsor.domains import Bint, Reals
from funsor.interpreter import gensym
from funsor.ops import UNITS, AssociativeOp
from funsor.terms import (
Cat,
Expand Down Expand Up @@ -116,7 +117,7 @@ def _unroll_plate(factors, var_to_ordinal, sum_vars, plate, step):
**{
prev: "{}_{}".format(var, i)
for prev, var in prev_to_var.items()
}
},
)
for i in range(size)
]
Expand Down Expand Up @@ -202,7 +203,7 @@ def partial_unroll(factors, eliminate=frozenset(), plate_to_step=dict()):


def partial_sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
):
"""
Performs partial sum-product contraction of a collection of factors.
Expand All @@ -216,8 +217,21 @@ def partial_sum_product(
assert all(isinstance(f, Funsor) for f in factors)
assert isinstance(eliminate, frozenset)
assert isinstance(plates, frozenset)
sum_vars = eliminate - plates

if pedantic:
var_to_errors = defaultdict(lambda: eliminate)
for f in factors:
ordinal = plates.intersection(f.inputs)
for var in set(f.inputs) - plates - eliminate:
var_to_errors[var] &= ordinal
for var, errors in var_to_errors.items():
for plate in errors:
raise ValueError(
f"Cannot eliminate plate {plate} containing preserved var {var}"
)

plates &= eliminate
sum_vars = eliminate - plates
var_to_ordinal = {}
ordinal_to_factors = defaultdict(list)
for f in factors:
Expand All @@ -231,12 +245,15 @@ def partial_sum_product(
ordinal_to_vars[ordinal].add(var)

results = []

while ordinal_to_factors:
leaf = max(ordinal_to_factors, key=len)
leaf = max(ordinal_to_factors, key=len) # CHOICE
leaf_factors = ordinal_to_factors.pop(leaf)
leaf_reduce_vars = ordinal_to_vars[leaf]
for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars):
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars)
for (group_factors, group_vars) in _partition(
leaf_factors, leaf_reduce_vars
): # CHOICE
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate)
remaining_sum_vars = sum_vars.intersection(f.inputs)
if not remaining_sum_vars:
results.append(f.reduce(prod_op, leaf & eliminate))
Expand All @@ -245,8 +262,50 @@ def partial_sum_product(
*(var_to_ordinal[v] for v in remaining_sum_vars)
)
if new_plates == leaf:
raise ValueError("intractable!")
f = f.reduce(prod_op, leaf - new_plates)
# Choose the smallest plate to eliminate.
plate = min(
(f.inputs[plate].size, plate) for plate in leaf & eliminate
)[-1]
new_plates = leaf - {plate}
plate_shape = (f.inputs[plate].size,)
subs = {}
for v in remaining_sum_vars:
if plate in var_to_ordinal[v]:
if f.inputs[v].dtype != "real":
raise ValueError("intractable!")
v_ = Variable(
gensym(v), Reals[plate_shape + f.inputs[v].shape]
)
v_ordinal = var_to_ordinal[v] - {plate}
var_to_ordinal[v_.name] = v_ordinal
ordinal_to_vars[v_ordinal].add(v_.name)
sum_vars = sum_vars - {v} | {v_.name}
eliminate = eliminate - {v} | {v_.name}
subs[v] = v_[plate]
# This will only work for terms implementing substituting
# {var1: ops.getitem(var2, var3)}, e.g. Gaussian but not Tensor.
f = f(**subs)
for o, gs in list(ordinal_to_factors.items()):
if plate not in o:
assert all(set(g.inputs).isdisjoint(subs) for g in gs)
continue # nothing to do below
remaining = []
for g in gs:
if set(subs).intersection(g.inputs):
g = g(**subs)
assert all(
plate not in var_to_ordinal[u]
for u in g.inputs
if u in sum_vars
)
g = g.reduce(prod_op, plate)
ordinal_to_factors[o - {plate}].append(g)
else:
remaining.append(g)
ordinal_to_factors[o] = remaining
reduced_plates = leaf - new_plates
assert reduced_plates.issubset(eliminate)
f = f.reduce(prod_op, reduced_plates)
ordinal_to_factors[new_plates].append(f)

return results
Expand Down Expand Up @@ -511,14 +570,16 @@ def modified_partial_sum_product(
return results


def sum_product(sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset()):
def sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
):
"""
Performs sum-product contraction of a collection of factors.
:return: a single contracted Funsor.
:rtype: :class:`~funsor.terms.Funsor`
"""
factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates)
factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates, pedantic)
return reduce(prod_op, factors, Number(UNITS[prod_op]))


Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ filterwarnings = error
ignore::DeprecationWarning
ignore:CUDA initialization:UserWarning
ignore:floor_divide is deprecated:UserWarning
ignore:__floordiv__ is deprecated:UserWarning
ignore:torch.cholesky is deprecated:UserWarning
ignore:torch.symeig is deprecated:UserWarning
once::DeprecationWarning
Expand Down
2 changes: 1 addition & 1 deletion test/test_minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def hand_guide(data):
elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
elbo = elbo.differentiable_loss if backend == "pyro" else elbo
hand_loss = elbo(hand_model, hand_guide, data)
_check_loss_and_grads(hand_loss, auto_loss)
_check_loss_and_grads(hand_loss, auto_loss, rtol=1e-3, atol=1e-3)


@pytest.mark.xfail(reason="missing patterns")
Expand Down
208 changes: 208 additions & 0 deletions test/test_sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,214 @@ def test_partial_sum_product(impl, sum_op, prod_op, inputs, plates, vars1, vars2
assert_close(actual, unrolled_expected)


def test_partial_sum_product_batch_1():
    """Eliminate plate ``i`` and real ``x`` from one Gaussian, keeping plate ``j``."""
    gaussian = random_gaussian(OrderedDict(i=Bint[2], j=Bint[2], x=Real))
    to_eliminate = frozenset("ix")
    all_plates = frozenset("ij")

    result = sum_product(
        ops.logaddexp,
        ops.add,
        [gaussian],
        eliminate=to_eliminate,
        plates=all_plates,
    )

    # Only the plate that was not eliminated should remain as an input.
    assert result.inputs == OrderedDict(j=Bint[2])


def test_partial_sum_product_batch_2():
    """Eliminate ``i`` and ``x`` from two Gaussians, one of which also has plate ``k``."""
    gaussian_i = random_gaussian(OrderedDict(i=Bint[2], x=Real))
    gaussian_ik = random_gaussian(OrderedDict(i=Bint[2], k=Bint[2], x=Real))
    to_eliminate = frozenset("ix")
    all_plates = frozenset("ik")

    result = sum_product(
        ops.logaddexp,
        ops.add,
        [gaussian_i, gaussian_ik],
        eliminate=to_eliminate,
        plates=all_plates,
    )

    # Plate k is not eliminated, so it is the only remaining input.
    assert result.inputs == OrderedDict(k=Bint[2])


def test_intractable_1():
    """Check plated elimination over ``i`` and ``j`` against unrolling ``j`` by hand."""
    factor_i = random_gaussian(OrderedDict(i=Bint[2], x=Real))
    factor_j = random_gaussian(OrderedDict(j=Bint[2], y=Real))
    factor_ij = random_gaussian(OrderedDict(i=Bint[2], j=Bint[2], x=Real, y=Real))

    plated = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_i, factor_j, factor_ij],
        eliminate=frozenset("ijxy"),
        plates=frozenset("ij"),
    )
    assert isinstance(plated, Tensor)
    assert not plated.inputs

    # Manually unroll plate j: one copy of each j-dependent factor per index,
    # with y renamed to a per-index variable (y0, y1).
    unrolled_j = [factor_j(j=n, y=f"y{n}") for n in range(2)]
    unrolled_ij = [factor_ij(j=n, y=f"y{n}") for n in range(2)]
    reference = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_i, *unrolled_j, *unrolled_ij],
        eliminate=frozenset(["i", "x", "y0", "y1"]),
        plates=frozenset("i"),
    )

    assert_close(plated, reference, atol=5e-4, rtol=5e-4)


def test_intractable_2():
    """Same comparison as the one-variable case, with two real variables per plate."""
    factor_i = random_gaussian(OrderedDict(i=Bint[2], x=Real, zx=Real))
    factor_j = random_gaussian(OrderedDict(j=Bint[2], y=Real, zy=Real))
    factor_ij = random_gaussian(
        OrderedDict(i=Bint[2], j=Bint[2], x=Real, y=Real, zx=Real, zy=Real)
    )

    plated = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_i, factor_j, factor_ij],
        eliminate=frozenset(["i", "j", "x", "y", "zx", "zy"]),
        plates=frozenset("ij"),
    )
    assert isinstance(plated, Tensor)
    assert not plated.inputs

    # Manually unroll plate j, renaming both y and zy to per-index variables.
    unrolled_j = [factor_j(j=n, y=f"y{n}", zy=f"zy{n}") for n in range(2)]
    unrolled_ij = [factor_ij(j=n, y=f"y{n}", zy=f"zy{n}") for n in range(2)]
    reference = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_i, *unrolled_j, *unrolled_ij],
        eliminate=frozenset(["i", "x", "y0", "y1", "zx", "zy0", "zy1"]),
        plates=frozenset("i"),
    )

    assert_close(plated, reference, atol=5e-4, rtol=5e-4)


def test_intractable_3():
    """Compare against a hand-unrolled ``j`` when a global real ``w`` is preserved."""
    factor_ = random_gaussian(OrderedDict(w=Real))
    factor_i = random_gaussian(OrderedDict(i=Bint[2], w=Real, x=Real))
    factor_j = random_gaussian(OrderedDict(j=Bint[2], y=Real))
    factor_ij = random_gaussian(OrderedDict(i=Bint[2], j=Bint[2], x=Real, y=Real))

    plated = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_, factor_i, factor_j, factor_ij],
        eliminate=frozenset("ijxy"),
        plates=frozenset("ij"),
    )
    # w is not in the eliminate set, so it must survive as the only input.
    assert set(plated.inputs) == {"w"}

    # Manually unroll plate j.
    unrolled_j = [factor_j(j=n, y=f"y{n}") for n in range(2)]
    unrolled_ij = [factor_ij(j=n, y=f"y{n}") for n in range(2)]
    reference = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_, factor_i, *unrolled_j, *unrolled_ij],
        eliminate=frozenset(["i", "x", "y0", "y1"]),
        plates=frozenset("i"),
    )
    assert set(reference.inputs) == {"w"}

    assert_close(plated, reference, atol=5e-4, rtol=5e-4)


def test_intractable_4():
    """Compare against a hand-unrolled ``j`` when the ``j`` factor also has plate ``k``."""
    factor_i = random_gaussian(OrderedDict(i=Bint[2], x=Real))
    factor_jk = random_gaussian(OrderedDict(j=Bint[2], k=Bint[2], y=Real))
    factor_ij = random_gaussian(OrderedDict(i=Bint[2], j=Bint[2], x=Real, y=Real))

    plated = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_i, factor_jk, factor_ij],
        eliminate=frozenset("ijxy"),
        plates=frozenset("ijk"),
    )
    # k is a plate but is not eliminated, so it remains an input.
    assert set(plated.inputs) == {"k"}

    # Manually unroll plate j.
    unrolled_jk = [factor_jk(j=n, y=f"y{n}") for n in range(2)]
    unrolled_ij = [factor_ij(j=n, y=f"y{n}") for n in range(2)]
    reference = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_i, *unrolled_jk, *unrolled_ij],
        eliminate=frozenset(["i", "x", "y0", "y1"]),
        plates=frozenset("ik"),
    )
    assert set(reference.inputs) == {"k"}

    assert_close(plated, reference, atol=5e-4, rtol=5e-4)


def test_intractable_5():
    """Compare against a hand-unrolled ``j`` when the ``i`` factor also has plate ``k``."""
    factor_ik = random_gaussian(OrderedDict(i=Bint[2], k=Bint[2], x=Real))
    factor_j = random_gaussian(OrderedDict(j=Bint[2], y=Real))
    factor_ij = random_gaussian(OrderedDict(i=Bint[2], j=Bint[2], x=Real, y=Real))

    plated = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_ik, factor_j, factor_ij],
        eliminate=frozenset("ijxy"),
        plates=frozenset("ijk"),
    )
    # k is a plate but is not eliminated, so it remains an input.
    assert set(plated.inputs) == {"k"}

    # Manually unroll plate j.
    unrolled_j = [factor_j(j=n, y=f"y{n}") for n in range(2)]
    unrolled_ij = [factor_ij(j=n, y=f"y{n}") for n in range(2)]
    reference = sum_product(
        ops.logaddexp,
        ops.add,
        [factor_ik, *unrolled_j, *unrolled_ij],
        eliminate=frozenset(["i", "x", "y0", "y1"]),
        plates=frozenset("ik"),
    )
    assert set(reference.inputs) == {"k"}

    assert_close(plated, reference, atol=5e-4, rtol=5e-4)


def test_var_in_plate_error():
    """With ``pedantic=True``, eliminating a plate that holds a preserved var must fail.

    The factor depends on plate ``i`` and reals ``x`` and ``z``.  Plate ``i`` is
    eliminated while ``z`` is preserved, so the pedantic check should raise.
    """
    factor_i = random_gaussian(OrderedDict(i=Bint[2], x=Real, z=Real))
    # Match the pedantic message so that an unrelated ValueError (for example
    # "intractable!") cannot make this test pass by accident.
    with pytest.raises(
        ValueError, match="Cannot eliminate plate i containing preserved var z"
    ):
        sum_product(
            ops.logaddexp,
            ops.add,
            [factor_i],
            eliminate=frozenset("ix"),
            plates=frozenset("i"),
            pedantic=True,
        )


@pytest.mark.xfail(reason="unclear semantics; incorrect computation of var_to_ordinal?")
def test_var_in_plate_ok():
    """Expected-failure case: a preserved var indexed by the eliminated plate.

    ``z`` is replaced by ``zs["i"]`` so that the preserved quantity ``zs`` lives
    outside plate ``i``; the result is expected to depend only on ``zs``.
    """
    zs = Variable("zs", Reals[2])
    base = random_gaussian(OrderedDict(i=Bint[2], x=Real, z=Real))
    factor_i = base(z=zs["i"])

    with reflect:
        result = sum_product(
            ops.logaddexp,
            ops.add,
            [factor_i],
            eliminate=frozenset("ix"),
            plates=frozenset("i"),
        )

    assert set(result.inputs) == {"zs"}
    assert result.inputs["zs"] == factor_i.inputs["zs"]


@pytest.mark.parametrize(
"vars1,vars2",
[
Expand Down

0 comments on commit ecbe439

Please sign in to comment.