In [1]:
import numpy as np
from copy import deepcopy

In [134]:
# number of variables in the factor graph
n_vars = 6

# variable names: 'x0', 'x1', ..., one per variable
variables = ['x' + str(i) for i in range(n_vars)]

# number of factors in the factor graph
n_funcs = 4

# discrete domain shared by every variable: each variable is summed over these values
var_space = [-0.1, 1.2]

# example factors; each takes a dict mapping variable names to values
f0 = lambda x : x['x0'] + x['x1'] * x['x2']
f1 = lambda x : x['x0'] + np.sin(x['x3']) + np.cos(x['x5'])
f2 = lambda x : np.exp(x['x3'])
f3 = lambda x : x['x3'] * x['x4']

# factor name -> factor function
funcs = {'f0': f0, 'f1': f1, 'f2': f2, 'f3': f3}

# factor name -> names of the variables it reads (its neighbours in the graph)
deps = {'f0': ['x0', 'x1', 'x2'], 'f1': ['x0', 'x3', 'x5'], 'f2': ['x3'], 'f3': ['x3', 'x4']}

# sanity checks: funcs and deps describe the same factors, every factor only
# reads declared variables, and node kinds can be told apart by the first
# character ('f' for factors, 'x' for variables), which later cells dispatch on
assert len(deps) == n_funcs
assert len(funcs) == n_funcs
assert set(deps) == set(funcs)
assert all(v in variables for args in deps.values() for v in args)
assert all(name.startswith('f') for name in funcs)
assert all(name.startswith('x') for name in variables)

# all nodes of the factor graph: factors first, then variables
nodes = list(funcs.keys()) + variables

# sanity check
assert len(nodes) == n_vars + n_funcs

In [135]:
# dictionaries merger
# https://stackoverflow.com/questions/38987/how-to-merge-two-dictionaries-in-a-single-expression
def merge_two_dicts(x, y):
    """Return a new dict holding x's items, overridden by y's where keys clash.

    Neither input is modified.
    """
    merged = x.copy()
    for key, value in y.items():
        merged[key] = value
    return merged

# brute force marginalization
def brute_marginalize1(sum_over, assignments):
    """Build a function that sums the product of all factors over ``sum_over``.

    The returned callable takes an assignment dict ``x``. It returns the sum,
    over every combination of ``var_space`` values for the variables in
    ``sum_over``, of the product of every factor in ``funcs``. Each factor is
    evaluated at ``x`` updated with ``assignments`` plus the enumerated values.
    The cost is exponential in len(sum_over); this is the reference used to
    check belief propagation.
    """
    if not sum_over:
        factors = list(funcs.values())

        def product_of_factors(x):
            point = merge_two_dicts(x, assignments)
            acc = 1.
            for factor in factors:
                acc = factor(point) * acc
            return acc

        return product_of_factors

    head, rest = sum_over[0], sum_over[1:]
    branches = []
    for value in var_space:
        # fix the first summed variable, then recurse on the remaining ones
        fixed = deepcopy(assignments)
        fixed[head] = value
        branches.append(brute_marginalize1(rest, fixed))

    def summed(x):
        acc = 0.
        for branch in branches:
            acc = branch(x) + acc
        return acc

    return summed

def brute_marginalize(variable):
    """Return the brute-force marginal of ``variable`` as a function of its value.

    Every other entry of ``variables`` is summed over ``var_space``. The
    returned callable takes a plain value for ``variable``, not a dict.
    Raises ValueError (from list.remove) if ``variable`` is not in ``variables``.
    """
    others = list(variables)
    others.remove(variable)
    joint = brute_marginalize1(others, {})

    def marginal(x):
        return joint({variable: x})

    return marginal

# get the neighbours of a node in the factor graph
def get_siblings(node):
    """Return a new list with the neighbours of ``node`` in the factor graph.

    The node kind is read from its first character:
    - 'f...': a factor; its neighbours are the variables in ``deps[node]``.
    - 'x...': a variable; its neighbours are the factors whose ``deps`` entry contains it.
    Any other prefix yields None.
    """
    kind = node[0]
    if kind == 'f':
        return deepcopy(deps[node])
    if kind == 'x':
        return [factor for factor, args in deps.items() if node in args]
    return None

In [136]:
# messages[sender][receiver] holds the message passed from node ``sender`` to
# node ``receiver``: a callable taking an assignment dict. Every slot starts as
# a placeholder returning None until the message-building functions fill it in.
messages = {
    sender: {receiver: (lambda z: None) for receiver in nodes}
    for sender in nodes
}

def get_descendants(node):
    """Return the children of ``node`` in the rooted tree ``tree``.

    A fresh list is returned, so callers may filter or modify it without
    altering ``tree`` itself (get_siblings likewise returns copies).
    Raises KeyError if ``node`` is not in ``tree``.
    """
    return list(tree[node])

def message_v_a(v, a):
    """Build, store in messages[v][a], and return the message from variable ``v`` to factor ``a``.

    v -- variable node
    a -- factor node

    The message maps an assignment dict ``x`` to the product of the messages
    already stored from ``v``'s children in ``tree`` (excluding ``a``), each
    evaluated at ``x``. With no such children it is the constant 1.

    Incoming messages are looked up when this is called, so the children's
    messages must already be stored.
    """
    # filter instead of list.remove(): get_descendants may hand back the very
    # list stored in ``tree``, and removing from it would corrupt the tree
    incoming = [messages[f][v] for f in get_descendants(v) if f != a]

    def res(x):
        acc = 1.
        for m in incoming:
            acc = acc * m(x)
        return acc

    messages[v][a] = res
    return res

# returns the list of all joint assignments needed for a sum over sum_over
def all_assignments(sum_over, assignments=None, values=None):
    """Enumerate every joint assignment of the variables in ``sum_over``.

    sum_over    -- list of variable names to enumerate.
    assignments -- optional dict of fixed values copied into every result
                   (keys that also appear in ``sum_over`` are overwritten).
    values      -- optional domain for each variable; defaults to ``var_space``.

    Returns a list of len(values) ** len(sum_over) freshly copied dicts. The
    first variable in ``sum_over`` varies slowest, as in nested for-loops. With
    an empty ``sum_over`` the result is a single copy of ``assignments``.
    """
    from itertools import product

    # None defaults avoid sharing (and handing out) one mutable dict across calls
    if assignments is None:
        assignments = {}
    if values is None:
        values = var_space
    result = []
    for combo in product(values, repeat=len(sum_over)):
        assignment = deepcopy(assignments)
        assignment.update(zip(sum_over, combo))
        result.append(assignment)
    return result
    
def message_a_v(a, v):
    """Build, store in messages[a][v], and return the message from factor ``a`` to variable ``v``.

    v -- variable node
    a -- factor node

    The message maps an assignment dict ``x`` to a sum over every joint
    assignment (from all_assignments) of ``a``'s children in ``tree`` other
    than ``v``. Each term is funcs[a] times the product of the messages those
    children sent to ``a``, all evaluated at ``x`` updated with that assignment.

    Incoming messages are looked up when this is called, so the children's
    messages must already be stored.
    """
    # filter instead of list.remove(): get_descendants may hand back the very
    # list stored in ``tree``, and removing from it would corrupt the tree
    others = [n for n in get_descendants(a) if n != v]
    sum_assignments = all_assignments(others, {})
    # the incoming messages are the same for every assignment: collect them once
    incoming = [messages[v1][a] for v1 in others]
    factor = funcs[a]

    def res(x):
        total = 0.
        for assignment in sum_assignments:
            point = merge_two_dicts(x, assignment)
            prod = 1.
            for m in incoming:
                prod = prod * m(point)
            total = total + prod * factor(point)
        return total

    messages[a][v] = res
    return res

def set_message(u, v):
    """Compute the message from node ``u`` to node ``v`` and store it in messages[u][v].

    The direction is chosen from the first character of each name:
    factor -> variable uses message_a_v, and variable -> factor uses message_v_a.

    Raises ValueError for any other pair (factor -> factor or variable ->
    variable), which cannot be an edge of a factor graph.
    """
    if u[0] == 'f' and v[0] == 'x':
        messages[u][v] = message_a_v(u, v)
    elif u[0] == 'x' and v[0] == 'f':
        messages[u][v] = message_v_a(u, v)
    else:
        # previously only printed "Error" and carried on, leaving the stale
        # placeholder message in place
        raise ValueError(
            "messages only flow between a factor and a variable, got {!r} -> {!r}".format(u, v))
                
def bp_marginalize(variable):
    """Return the BP marginal of ``variable`` as a function of an assignment dict.

    The result is the product of the messages stored from ``variable``'s
    children in ``tree``. Only the upward (child-to-parent) pass is run, by
    set_messages_tree, so this is the exact marginal only for the root of
    ``tree``. A non-root variable would silently miss its parent's message,
    so it is rejected.

    Raises ValueError if ``variable`` is a child of some node in ``tree``.
    """
    if any(variable in children for children in tree.values()):
        raise ValueError(
            "{!r} is not the root of the message-passing tree; only the root's "
            "marginal is available after the upward pass".format(variable))
    incoming = [messages[f][variable] for f in get_descendants(variable)]

    def marginal(x):
        acc = 1.
        for m in incoming:
            acc = m(x) * acc
        return acc

    return marginal

def get_tree(node, visited=None):
    """Root the factor graph at ``node`` by depth-first search.

    node    -- name of the root node.
    visited -- nodes on the current root-to-node path that must not be added
               as children again; defaults to an empty path.

    Returns a dict mapping every reached node to the list of its children
    (leaves map to []).

    NOTE(review): only the current path is tracked, not every node seen so far.
    On a graph with cycles, a node can therefore be reached along several
    paths, and the later dict entry overwrites the earlier one. The graph
    described by ``deps`` in this notebook has no cycles.
    """
    if visited is None:
        visited = []
    children = [v for v in get_siblings(node) if v not in visited]
    res = {node: children}
    for v in children:
        res.update(get_tree(v, visited + [node, v]))
    return res

# Root the factor graph at x0: tree maps every node to its list of children.
tree = get_tree(node='x0')

def set_messages_tree(node):
    """Upward pass: compute every child-to-parent message in the subtree of ``node``.

    Edges are handled in post-order: a child's whole subtree is finished
    before the child sends its own message to its parent. This order matters
    because the message builders read already-stored incoming messages. An
    explicit stack replaces recursion.
    """
    finished = object()  # sentinel for an exhausted child iterator
    stack = [(node, iter(get_descendants(node)))]
    while stack:
        current, pending = stack[-1]
        child = next(pending, finished)
        if child is not finished:
            stack.append((child, iter(get_descendants(child))))
            continue
        stack.pop()
        if stack:
            # ``current`` is complete: send its message up to its parent
            set_message(current, stack[-1][0])
        
# Upward pass towards the root x0; afterwards every child-to-parent message is stored.
set_messages_tree(node='x0')

In [139]:
# Marginal of the root x0 at x0 = 0.2 from the upward BP pass. Only the root's
# marginal is available, because no downward pass is run. A bare expression
# shows the full-precision repr; print() truncates to 12 digits on Python 2.
bp_marginalize('x0')({'x0': 0.2})

31.5893957643


In [140]:
# Brute-force reference: sum the full joint over every variable except x0.
brute_x0 = brute_marginalize('x0')(0.2)
# BP rooted at x0 must agree with brute force up to floating-point rounding.
assert np.isclose(brute_x0, bp_marginalize('x0')({'x0': 0.2}))
brute_x0

31.589395764297642