# Sum Product Network

Some text

In [1]:
from collections import OrderedDict

import torch

import funsor
import funsor.torch.distributions as dist
import funsor.ops as ops

funsor.set_backend("torch")

### network

In [2]:
# sum_op = +, prod_op = *
# alternatively, we can use rewrite_ops as in
# https://github.com/pyro-ppl/funsor/pull/456
# and switch to sum_op = logsumexp, prod_op = +
spn = 0.4 * (dist.Categorical(torch.tensor([0.2, 0.8]), value="v0").exp() *
             (0.3 * (dist.Categorical(torch.tensor([0.3, 0.7]), value="v1").exp() *
                     dist.Categorical(torch.tensor([0.4, 0.6]), value="v2").exp())
            + 0.7 * (dist.Categorical(torch.tensor([0.5, 0.5]), value="v1").exp() *
                     dist.Categorical(torch.tensor([0.6, 0.4]), value="v2").exp()))) \
    + 0.6 * (dist.Categorical(torch.tensor([0.2, 0.8]), value="v0").exp() *
             dist.Categorical(torch.tensor([0.3, 0.7]), value="v1").exp() *
             dist.Categorical(torch.tensor([0.4, 0.6]), value="v2").exp())
spn

Tensor(tensor([[[0.0341, 0.0371],
         [0.0571, 0.0717]],

        [[0.1363, 0.1485],
         [0.2285, 0.2867]]]), OrderedDict([('v0', Bint[2, ]), ('v1', Bint[2, ]), ('v2', Bint[2, ])]), 'real')

### marginalize

In [3]:
spn_marg = spn.reduce(ops.add, "v0")
print(spn_marg)

Tensor(tensor([[0.1704, 0.1856],
        [0.2856, 0.3584]]), OrderedDict([('v1', Bint[2, ]), ('v2', Bint[2, ])]))


### likelihood

In [4]:
test_data = {"v0": 1, "v1": 0, "v2": 1}

In [5]:
ll_exp = spn(**test_data)
print(ll_exp.log(), ll_exp)

tensor(-1.9073) tensor(0.1485)


In [6]:
llm_exp = spn_marg(**test_data)
print(llm_exp.log(), llm_exp)

tensor(-1.6842) tensor(0.1856)


In [7]:
test_data2 = {"v1": 0, "v2": 1}
llom_exp = spn(**test_data2).reduce(ops.add)
print(llom_exp.log(), llom_exp)

tensor(-1.6842) tensor(0.1856)


### sample

In [8]:
sample_inputs = OrderedDict(particle=funsor.Bint[5])
spn(v1=0, v2=0).sample(frozenset({"v0"}), sample_inputs)

Delta((('v0', (Tensor(tensor([1, 1, 1, 0, 1]), OrderedDict([('particle', Bint[5, ])]), 2), Number(0.0))),)) + Tensor(-0.8297846913337708, OrderedDict(), 'real').reduce(nullop, set())

what is `-0.8297846913337708`? a normalization factor?

### train parameters

In [9]:
-torch.nn.functional.softplus(-torch.tensor(20.))

tensor(-2.0612e-09)

### parameter optimization

### most probable explanation

### multivariate leaf

### cutset networks

### expectations and moments

In [10]:
# Integrate(q, x, q_vars)

### pareto

In [11]:
spn = 0.3 * dist.Pareto(1., 2., value="v0").exp() + 0.7 * dist.Pareto(1., 3., value="v0").exp()
print(spn(v0=1.5).log())

tensor(-0.5232)
