In [1]:
import numpy as np
from copy import deepcopy

In [2]:
# Factor-graph definition.
# Variable nodes are named x0 .. x{n_vars-1}.
n_vars = 3
variables = ['x{}'.format(i) for i in range(n_vars)]

# Number of factor nodes (f0 .. f4 below).
n_funcs = 5

# Finite domain shared by every variable; marginals sum over these values.
var_space = [1., 0., 25., -25.]


# Example factor potentials. Each receives a dict mapping variable names to
# values and reads only the variables listed for it in `deps` below.
def f0(x):
    return x['x0'] + x['x1']


def f1(x):
    return x['x1']


def f2(x):
    return x['x2']


def f3(x):
    return x['x1'] + x['x0']


def f4(x):
    return x['x0']


# Factor name -> potential function (a dict, keyed in f0..f4 order).
funcs = dict(f0=f0, f1=f1, f2=f2, f3=f3, f4=f4)

# Factor name -> names of the variables it depends on (its graph neighbours).
deps = dict(
    f0=['x0', 'x1'],
    f1=['x1'],
    f2=['x2'],
    f3=['x0', 'x1'],
    f4=['x0'],
)

# Sanity check: every factor has both a potential and a dependency list.
assert len(deps) == n_funcs
assert len(funcs) == n_funcs

# All node names: factors first, then variables.
nodes = list(funcs) + variables

# Sanity check: node names are unique across factors and variables.
assert len(nodes) == n_vars + n_funcs

In [3]:
# dictionaries merger
# https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression
def merge_two_dicts(x, y):
    """Return a new dict holding x's items, overridden by y's on key clashes.

    Neither input is modified.
    """
    return {**x, **y}

# brute force marginalization
def brute_marginalize1(sum_over, assignments):
    """Build a callable that sums the product of all factors over ``sum_over``.

    Args:
        sum_over: list of variable names still to be summed out; each one
            ranges over ``var_space``.
        assignments: dict of variable values already fixed by the outer
            levels of the recursion.

    Returns:
        A function of a dict ``x`` (values for the variables not being summed)
        that returns, over every assignment of ``sum_over``, the sum of the
        product of every factor in ``funcs`` evaluated on ``x`` merged with
        that assignment. On key clashes the fixed assignments override ``x``.
        The whole recursion tree is built eagerly here; only the arithmetic
        is deferred to call time.
    """
    if not sum_over:
        # Leaf: every summed variable is fixed; multiply all the factors.
        # Snapshot the factor functions now, as the original closures did.
        factors = list(funcs.values())

        def joint(x):
            product = 1.
            for factor in factors:
                product = factor(merge_two_dicts(x, assignments)) * product
            return product
        return joint

    head, rest = sum_over[0], sum_over[1:]
    branches = []
    for candidate in var_space:
        # Each branch gets its own copy so sibling branches never share state.
        fixed = deepcopy(assignments)
        fixed[head] = candidate
        branches.append(brute_marginalize1(rest, fixed))

    def summed(x):
        total = 0.
        for branch in branches:
            total = branch(x) + total
        return total
    return summed

def brute_marginalize(variable):
    """Return the unnormalised marginal of ``variable`` computed by brute force.

    Every other entry of ``variables`` is summed out over ``var_space``. The
    returned function takes the value of ``variable`` itself (a scalar, not a
    dict). ``list.remove`` raises ValueError if ``variable`` is not in
    ``variables``.
    """
    others = list(variables)
    others.remove(variable)
    summed = brute_marginalize1(others, {})

    def marginal(x):
        return summed({variable: x})
    return marginal

# get siblings of a node in the factor graph
def get_siblings(node):
    """Return a fresh list of the graph neighbours of ``node``.

    The first character of the name selects the node kind:
    - ``'f...'`` is a factor; its neighbours are the variables in ``deps[node]``.
    - ``'x...'`` is a variable; its neighbours are the factors whose ``deps``
      entry contains it, in ``deps`` order.
    Any other name yields None. The list is a copy, so callers may mutate it.
    """
    prefix = node[0]
    if prefix == 'f':
        return deepcopy(deps[node])
    if prefix == 'x':
        return [factor for factor, scope in deps.items() if node in scope]
    return None

In [59]:
# messages[src][dst] is the current message from node src to node dst: a
# callable taking a dict of variable assignments. Every entry starts out as
# the constant-1 message (a distinct lambda per entry).
messages = {
    src: {dst: (lambda z: 1.) for dst in nodes}
    for src in nodes
}

def message_v_a(v, a):
    """Compute and store the variable-to-factor message v -> a.

    The message is the product of the factor-to-variable messages that reach
    ``v`` from every neighbouring factor except ``a``. Those incoming messages
    are read from ``messages`` now, so later updates to ``messages`` do not
    change the returned callable.

    Args:
        v: variable node name.
        a: factor node name; must be a neighbour of ``v``.

    Returns:
        A callable taking a dict of variable assignments. It is also stored
        in ``messages[v][a]``.
    """
    others = get_siblings(v)
    others.remove(a)
    incoming = [messages[factor][v] for factor in others]

    def outgoing(x):
        product = 1.
        for msg in incoming:
            product = product * msg(x)
        return product

    messages[v][a] = outgoing
    return outgoing

# returns list of all joint assignments of the variables in sum_over
def all_assignments(sum_over, assignments=None, values=None):
    """Enumerate every joint assignment of the variables in ``sum_over``.

    Each returned dict is an independent copy of ``assignments`` extended
    with one value per variable in ``sum_over``. The first variable varies
    slowest.

    Args:
        sum_over: list of variable names to enumerate.
        assignments: fixed assignments copied into every result. It is never
            mutated or aliased. Defaults to a fresh empty dict per call; the
            old ``{}`` default was shared across calls and returned by alias.
        values: candidate values for each variable. ``None`` means the
            module-level ``var_space``, looked up at call time.

    Returns:
        A list of ``len(values) ** len(sum_over)`` dicts.
    """
    if assignments is None:
        assignments = {}
    if not sum_over:
        # Return a copy so callers can never mutate the input (or a shared
        # default) through the result.
        return [deepcopy(assignments)]
    if values is None:
        values = var_space
    head, rest = sum_over[0], sum_over[1:]
    result = []
    for value in values:
        extended = deepcopy(assignments)
        extended[head] = value
        result.extend(all_assignments(rest, extended, values))
    return result
    
def message_a_v(v, a):
    """Compute and store the factor-to-variable message a -> v.

    The message is a callable taking a dict ``x`` of variable assignments. It
    loops over every joint assignment of ``a``'s other neighbours, as
    produced by ``all_assignments``. For each one it multiplies the factor
    value ``funcs[a]`` by the incoming variable-to-factor messages from those
    neighbours, and it sums the results. The assignment overrides ``x`` on
    key clashes.

    The incoming messages and the factor function are read now, so later
    updates to ``messages`` do not change the returned callable.

    Args:
        v: variable node name; must be a neighbour of ``a``.
        a: factor node name.

    Returns:
        The message callable. It is also stored in ``messages[a][v]``.
    """
    others = get_siblings(a)
    others.remove(v)
    sum_assignments = all_assignments(others, {})
    # The incoming messages do not depend on the assignment being summed, so
    # collect them once rather than rebuilding a closure chain per assignment.
    incoming = [messages[v1][a] for v1 in others]
    factor = funcs[a]

    def outgoing(x):
        total = 0.
        for assignment in sum_assignments:
            point = merge_two_dicts(x, assignment)
            weight = 1.
            for msg in incoming:
                weight = weight * msg(point)
            total = total + weight * factor(point)
        return total

    messages[a][v] = outgoing
    return outgoing

def set_messages(n_sweeps=1):
    """Run ``n_sweeps`` passes of message passing over every graph edge.

    Edges are visited with variables in the outer loop and factors in the
    inner loop. For each edge, the variable-to-factor message is recomputed
    first, then the factor-to-variable message. Both are stored in
    ``messages``. Each helper reads the messages that exist when it is
    called, so a later sweep builds on the messages from earlier ones.
    ``set_messages(2)`` is equivalent to calling ``set_messages()`` twice.

    Args:
        n_sweeps: number of full passes. The default of 1 matches the
            original behaviour.
            NOTE(review): messages are never normalised, and each sweep wraps
            the previous message closures, so evaluation cost grows with
            n_sweeps.
    """
    for _ in range(n_sweeps):
        for v in variables:
            for a in funcs.keys():
                if v in get_siblings(a):
                    messages[v][a] = message_v_a(v, a)
                    messages[a][v] = message_a_v(v, a)
                
def bp_marginalize(variable):
    """Return the unnormalised BP belief of ``variable``.

    The belief is the product of the factor-to-variable messages stored in
    ``messages`` for every factor adjacent to ``variable``. They are read
    when this function is called. The returned callable takes a dict of
    variable assignments, which presumably must contain ``variable`` itself.
    """
    incoming = [messages[factor][variable] for factor in get_siblings(variable)]

    def belief(x):
        product = 1.
        for msg in incoming:
            product = msg(x) * product
        return product
    return belief
# Run one sweep of message passing so that `messages` holds every
# variable->factor and factor->variable message that bp_marginalize reads.
# Calling set_messages() again runs another sweep on top of these messages.
set_messages()

In [69]:
# Unnormalised BP belief of x0, evaluated at x0 = 3.
print(bp_marginalize('x0')({'x0': 3}))

507.0


In [70]:
# Brute-force reference for the same quantity (x0 = 3).
# NOTE(review): the recorded outputs disagree with the BP value (22548.0 vs
# 507.0). f0 and f3 both connect x0 and x1, so the factor graph contains a
# cycle, and BP after a single sweep is presumably not expected to be exact.
brute_marginalize('x0')(3)

22548.0