Skip to content

Commit

Permalink
Scale factors across plate dims in partial_sum_product (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Aug 31, 2023
1 parent ff5e410 commit 349c038
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: [3.9]
env:
CI: 1
FUNSOR_BACKEND: jax
Expand Down
4 changes: 4 additions & 0 deletions funsor/ops/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .op import (
BINARY_INVERSES,
DISTRIBUTIVE_OPS,
PRODUCT_TO_POWER,
SAFE_BINARY_INVERSES,
UNARY_INVERSES,
UNITS,
Expand Down Expand Up @@ -287,6 +288,9 @@ def sigmoid_log_abs_det_jacobian(x, y):
UNARY_INVERSES[mul] = reciprocal
UNARY_INVERSES[add] = neg

PRODUCT_TO_POWER[add] = mul
PRODUCT_TO_POWER[mul] = pow

__all__ = [
"AssociativeOp",
"ComparisonOp",
Expand Down
2 changes: 2 additions & 0 deletions funsor/ops/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def log_abs_det_jacobian(x, y, fn):
BINARY_INVERSES = {} # binary op -> inverse binary op
SAFE_BINARY_INVERSES = {} # binary op -> numerically safe inverse binary op
UNARY_INVERSES = {} # binary op -> inverse unary op
PRODUCT_TO_POWER = {} # product op -> power op

__all__ = [
"BINARY_INVERSES",
Expand All @@ -430,6 +431,7 @@ def log_abs_det_jacobian(x, y, fn):
"LogAbsDetJacobianOp",
"NullaryOp",
"Op",
"PRODUCT_TO_POWER",
"SAFE_BINARY_INVERSES",
"TernaryOp",
"TransformOp",
Expand Down
49 changes: 44 additions & 5 deletions funsor/sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from funsor.cnf import Contraction
from funsor.domains import Bint, Reals
from funsor.interpreter import gensym
from funsor.ops import UNITS, AssociativeOp
from funsor.ops import PRODUCT_TO_POWER, UNITS, AssociativeOp
from funsor.terms import (
Cat,
Funsor,
Expand Down Expand Up @@ -203,7 +203,14 @@ def partial_unroll(factors, eliminate=frozenset(), plate_to_step=dict()):


def partial_sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
sum_op,
prod_op,
factors,
eliminate=frozenset(),
plates=frozenset(),
pedantic=False,
pow_op=None,
plate_to_scale=None, # dict
):
"""
Performs partial sum-product contraction of a collection of factors.
Expand All @@ -218,6 +225,10 @@ def partial_sum_product(
assert isinstance(eliminate, frozenset)
assert isinstance(plates, frozenset)

if plate_to_scale:
if pow_op is None:
pow_op = PRODUCT_TO_POWER[prod_op]

if pedantic:
var_to_errors = defaultdict(lambda: eliminate)
for f in factors:
Expand Down Expand Up @@ -256,7 +267,17 @@ def partial_sum_product(
f = reduce(prod_op, group_factors).reduce(sum_op, group_vars & eliminate)
remaining_sum_vars = sum_vars.intersection(f.inputs)
if not remaining_sum_vars:
results.append(f.reduce(prod_op, leaf & eliminate))
f = f.reduce(prod_op, leaf & eliminate)
if plate_to_scale:
f_scales = [
plate_to_scale[plate]
for plate in leaf & eliminate
if plate in plate_to_scale
]
if f_scales:
scale = reduce(ops.mul, f_scales)
f = pow_op(f, scale)
results.append(f)
else:
new_plates = frozenset().union(
*(var_to_ordinal[v] for v in remaining_sum_vars)
Expand Down Expand Up @@ -306,6 +327,15 @@ def partial_sum_product(
reduced_plates = leaf - new_plates
assert reduced_plates.issubset(eliminate)
f = f.reduce(prod_op, reduced_plates)
if plate_to_scale:
f_scales = [
plate_to_scale[plate]
for plate in reduced_plates
if plate in plate_to_scale
]
if f_scales:
scale = reduce(ops.mul, f_scales)
f = pow_op(f, scale)
ordinal_to_factors[new_plates].append(f)

return results
Expand Down Expand Up @@ -571,15 +601,24 @@ def modified_partial_sum_product(


def sum_product(
sum_op, prod_op, factors, eliminate=frozenset(), plates=frozenset(), pedantic=False
sum_op,
prod_op,
factors,
eliminate=frozenset(),
plates=frozenset(),
pedantic=False,
pow_op=None,
plate_to_scale=None, # dict
):
"""
Performs sum-product contraction of a collection of factors.
:return: a single contracted Funsor.
:rtype: :class:`~funsor.terms.Funsor`
"""
factors = partial_sum_product(sum_op, prod_op, factors, eliminate, plates, pedantic)
factors = partial_sum_product(
sum_op, prod_op, factors, eliminate, plates, pedantic, pow_op, plate_to_scale
)
return reduce(prod_op, factors, Number(UNITS[prod_op]))


Expand Down
4 changes: 2 additions & 2 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def id_from_inputs(inputs):

@dispatch(object, object, Variadic[float])
def allclose(a, b, rtol=1e-05, atol=1e-08):
if type(a) != type(b):
if type(a) is not type(b):
return False
return ops.abs(a - b) < rtol + atol * ops.abs(b)

Expand Down Expand Up @@ -125,7 +125,7 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
elif isinstance(actual, Gaussian):
assert isinstance(expected, Gaussian)
else:
assert type(actual) == type(expected), msg
assert type(actual) is type(expected), msg

if isinstance(actual, Funsor):
assert isinstance(expected, Funsor), msg
Expand Down
96 changes: 95 additions & 1 deletion test/test_sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
sum_product,
)
from funsor.tensor import Tensor, get_default_prototype
from funsor.terms import Variable
from funsor.terms import Cat, Variable
from funsor.testing import assert_close, random_gaussian, random_tensor
from funsor.util import get_backend

Expand Down Expand Up @@ -2899,3 +2899,97 @@ def test_mixed_sequential_sum_product(duration, num_segments):
)

assert_close(actual, expected)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale", [1, 2])
def test_partial_sum_product_scale_1(sum_op, prod_op, scale):
f1 = random_tensor(OrderedDict(a=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3]))

eliminate = frozenset("ai")
plates = frozenset("i")

# Actual result based on applying scaling
factors = [f1, f2]
scales = {"i": scale}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f3 = Cat("i", (f2,) * scale)
factors = [f1, f3]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale_i", [1, 2])
@pytest.mark.parametrize("scale_j", [1, 3])
def test_partial_sum_product_scale_2(sum_op, prod_op, scale_i, scale_j):
f1 = random_tensor(OrderedDict(a=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[3]))
f3 = random_tensor(OrderedDict(a=Bint[2], j=Bint[4]))

eliminate = frozenset("aij")
plates = frozenset("ij")

# Actual result based on applying scaling
factors = [f1, f2, f3]
scales = {"i": scale_i, "j": scale_j}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f4 = Cat("i", (f2,) * scale_i)
f5 = Cat("j", (f3,) * scale_j)
factors = [f1, f4, f5]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)


@pytest.mark.parametrize(
"sum_op,prod_op",
[(ops.logaddexp, ops.add), (ops.add, ops.mul)],
)
@pytest.mark.parametrize("scale_i", [1, 2])
@pytest.mark.parametrize("scale_j", [1, 3])
@pytest.mark.parametrize("scale_k", [1, 4])
def test_partial_sum_product_scale_3(sum_op, prod_op, scale_i, scale_j, scale_k):
f1 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2]))
f2 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3]))
f3 = random_tensor(OrderedDict(a=Bint[2], i=Bint[2], j=Bint[3], k=Bint[3]))

eliminate = frozenset("aijk")
plates = frozenset("ijk")

# Actual result based on applying scaling
factors = [f1, f2, f3]
scales = {"i": scale_i, "j": scale_j, "k": scale_k}
actual = sum_product(
sum_op, prod_op, factors, eliminate, plates, plate_to_scale=scales
)

# Expected result based on concatenating factors
f4 = Cat("i", (f1,) * scale_i)
# concatenate across multiple dims
f5 = Cat("i", (f2,) * scale_i)
f5 = Cat("j", (f5,) * scale_j)
# concatenate across multiple dims
f6 = Cat("i", (f3,) * scale_i)
f6 = Cat("j", (f6,) * scale_j)
f6 = Cat("k", (f6,) * scale_k)
factors = [f4, f5, f6]
expected = sum_product(sum_op, prod_op, factors, eliminate, plates)

assert_close(actual, expected, atol=1e-4, rtol=1e-4)
4 changes: 2 additions & 2 deletions test/test_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_to_funsor_error(x):
def test_to_data():
actual = to_data(Number(0.0))
expected = 0.0
assert type(actual) == type(expected)
assert type(actual) is type(expected)
assert actual == expected


Expand Down Expand Up @@ -569,7 +569,7 @@ def test_stack_slice(start, stop, step):
xs = tuple(map(Number, range(10)))
actual = Stack("i", xs)(i=Slice("j", start, stop, step, dtype=10))
expected = Stack("j", xs[start:stop:step])
assert type(actual) == type(expected)
assert type(actual) is type(expected)
assert actual.name == expected.name
assert actual.parts == expected.parts

Expand Down

0 comments on commit 349c038

Please sign in to comment.