# Sum Product Network

In [1]:
from collections import OrderedDict

import jax
import numpy as np

import funsor
import funsor.jax.distributions as dist
import funsor.ops as ops

funsor.set_backend("jax")

### network

In [2]:
# sum_op = +, prod_op = *
# alternatively, we can use rewrite_ops as in
# https://github.com/pyro-ppl/funsor/pull/456
# and switch to sum_op = logsumexp, prod_op = +
# FIXME: what is the best way to set constraints to the weights
spn = 0.4 * (dist.Categorical(np.array([0.2, 0.8]), value="v0").exp() *
             (0.3 * (dist.Categorical(np.array([0.3, 0.7]), value="v1").exp() *
                     dist.Categorical(np.array([0.4, 0.6]), value="v2").exp())
            + 0.7 * (dist.Categorical(np.array([0.5, 0.5]), value="v1").exp() *
                     dist.Categorical(np.array([0.6, 0.4]), value="v2").exp()))) \
    + 0.6 * (dist.Categorical(np.array([0.2, 0.8]), value="v0").exp() *
             dist.Categorical(np.array([0.3, 0.7]), value="v1").exp() *
             dist.Categorical(np.array([0.4, 0.6]), value="v2").exp())
spn

Tensor([[[0.03408    0.03712   ]
  [0.05712    0.07167999]]

 [[0.13632001 0.14848001]
  [0.22848003 0.28672004]]], OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')

### marginalize

In [3]:
spn_marg = spn.reduce(ops.add, "v0")
print(spn_marg)

Tensor([[0.17040001 0.18560001]
 [0.28560004 0.35840005]], OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))


### likelihood

In [4]:
test_data = {"v0": 1, "v1": 0, "v2": 1}

In [5]:
ll_exp = spn(**test_data)
print(ll_exp.log(), ll_exp)

-1.9073049 0.14848001


In [6]:
llm_exp = spn_marg(**test_data)
print(llm_exp.log(), llm_exp)

-1.6841614 0.18560001


In [7]:
test_data2 = {"v1": 0, "v2": 1}
llom_exp = spn(**test_data2).reduce(ops.add)
print(llom_exp.log(), llom_exp)

-1.6841614 0.18560001


### sample

In [8]:
sample_inputs = OrderedDict(particle=funsor.Bint[5])
spn(v1=0, v2=0).sample(frozenset({"v0"}), sample_inputs, jax.random.PRNGKey(0))

Delta((('v0', (Tensor([1 0 1 1 0], OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor([-0.8297847 -0.8297847 -0.8297847 -0.8297847 -0.8297847], OrderedDict([('particle', Bint[5, ])]), 'real').reduce(nullop, set())

what is `-0.8297847`? a normalization factor? why the latter term is a constant in torch but it is an array in jax

### train parameters

### parameter optimization

### most probable explanation

### multivariate leaf

### cutset networks

### expectations and moments

In [9]:
# Integrate(q, x, q_vars)

### pareto

In [11]:
spn = 0.3 * dist.Pareto(1., 2., value="v0").exp() + 0.7 * dist.Pareto(1., 3., value="v0").exp()
print(spn(v0=1.5).log())

-0.523248
