In [0]:
import time
import itertools
import numpy as np

In [0]:
def get_ve_graph():
    """Build the factor graph consumed by the variable-elimination routine.

    Returns:
        dict: factor id -> ``[scope, table]`` where ``scope`` is the list of
        variable ids the factor touches and ``table`` is a float array with
        one row per binary assignment of the scope (in lexicographic order);
        the last column of each row holds the factor value for that
        assignment.  Fresh lists and arrays are created on every call, so the
        caller is free to modify them in place.
    """
    # (scope, values) per factor.  Values are listed in lexicographic order of
    # the binary assignment to the scope: 00..0, 00..1, ..., 11..1.
    specs = {
        1: ([1], [0.999, 0.001]),
        2: ([2], [0.998, 0.002]),
        3: ([1, 2, 3], [0.999, 0.001, 0.71, 0.29, 0.06, 0.94, 0.05, 0.95]),
        4: ([3, 4], [0.95, 0.05, 0.1, 0.9]),
        5: ([3, 5], [0.99, 0.01, 0.3, 0.7]),
    }

    factor_dict = {}
    for factor_id, (scope, values) in specs.items():
        assignments = itertools.product([0, 1], repeat=len(scope))
        # Each row is the assignment followed by its factor value.
        rows = [list(bits) + [value] for bits, value in zip(assignments, values)]
        factor_dict[factor_id] = [scope, np.array(rows)]
    return factor_dict

# Probability tables for the brute-force method.  Each array is indexed
# directly by the 0/1 values of the variables named in get_naive_graph below.
PT1 = np.array([0.999, 0.001])
PT2 = np.array([0.998, 0.002])
PT3 = np.array([
    [[0.999, 0.001], [0.71, 0.29]],
    [[0.06, 0.94], [0.05, 0.95]],
])
PT4 = np.array([[0.95, 0.05], [0.1, 0.9]])
PT5 = np.array([[0.99, 0.01], [0.3, 0.7]])


def get_naive_graph(x1, x2, x3, x4, x5):
    """Return the product of the five table entries selected by one assignment.

    Args:
        x1, x2, x3, x4, x5: value (0 or 1) of each variable.

    Returns:
        PT1[x1] * PT2[x2] * PT3[x1, x2, x3] * PT4[x3, x4] * PT5[x3, x5].
    """
    terms = (PT1[x1], PT2[x2], PT3[x1, x2, x3], PT4[x3, x4], PT5[x3, x5])
    joint = terms[0]
    # Multiply left to right so the floating-point result is reproducible.
    for term in terms[1:]:
        joint = joint * term
    return joint


In [0]:
def topological_sort(factor_dict, visited_list, return_list, current_point):
    """Collect an elimination order by depth-first traversal of the factor graph.

    Despite the name this is a DFS post-order: two variables are neighbours
    when they appear in the scope of the same factor, and every newly reached
    variable is appended to ``return_list`` only after everything reachable
    through it has been appended.  The starting variable itself is never
    appended.

    Args:
        factor_dict: factor id -> sequence whose first item is the factor's
            scope (list of variable ids).
        visited_list: variables already reached; extended in place.
        return_list: output order; extended in place.
        current_point: variable id to expand from.
    """
    visited_list.append(current_point)
    for entry in factor_dict.values():
        scope = entry[0]
        if current_point not in scope:
            continue
        for neighbour in scope:
            if neighbour in visited_list:
                continue
            topological_sort(factor_dict, visited_list, return_list, neighbour)
            # Post-order: the neighbour goes in after its own subtree.
            return_list.append(neighbour)

In [0]:
def variable_elimination(post_id, post_val, prev_id_list, prev_val_list):
    """Compute P(X_post_id = post_val | evidence) by variable elimination.

    The result and the wall-clock time are printed; the probability is also
    returned so callers can compare it with other inference methods.

    Args:
        post_id: id of the query variable.
        post_val: queried value, 0 or 1.
        prev_id_list: ids of the observed variables.
        prev_val_list: observed values (0 or 1), parallel to ``prev_id_list``.

    Returns:
        float: the normalised probability of ``post_val``.

    Raises:
        ValueError: if ``post_val`` is neither 0 nor 1.
    """
    start = time.time()
    # Validate up front: a negative index would otherwise silently pick the
    # wrong entry of the result vector.
    if post_val not in (0, 1):
        raise ValueError('Second value must be either 0 or 1!')
    conds = list(zip(prev_id_list, prev_val_list))
    evidence = dict(conds)

    if post_id in evidence:
        # The query variable is never eliminated, so evidence on it would be
        # ignored by the loop below; an observed variable is a point mass.
        result_value = np.zeros(2)
        result_value[evidence[post_id]] = 1.0
    else:
        factor_dict = get_ve_graph()

        # variable id -> ids of the factors whose scope contains it
        node_to_factor = {}
        for factor_id, factor_info in factor_dict.items():
            for each in factor_info[0]:
                node_to_factor.setdefault(each, []).append(factor_id)

        # Order in which the non-query variables are eliminated.
        eliminate_list = []
        topological_sort(factor_dict, [], eliminate_list, post_id)

        for node_id in eliminate_list:
            factor_ids = node_to_factor[node_id]
            # Factors that currently involve only this variable; they are
            # folded into the multi-variable factors instead of being reduced.
            # NOTE(review): every multi-variable factor is reduced on its own,
            # which is only a valid elimination when at most one of them
            # contains node_id at this point -- confirm before reusing this
            # with a different graph or elimination order.
            factors_to_multiply = [idx for idx in factor_ids
                                   if len(factor_dict[idx][0]) == 1]
            for factor_idx in factor_ids:
                if factor_idx in factors_to_multiply:
                    continue
                scope, table = factor_dict[factor_idx]

                # Move node_id to column 0 and sort the rows so that the first
                # half has node_id == 0 and the second half node_id == 1, with
                # the remaining variables in the same order in both halves.
                factor_pos = scope.index(node_id)
                scope[0], scope[factor_pos] = scope[factor_pos], scope[0]
                table[:, [0, factor_pos]] = table[:, [factor_pos, 0]]
                table = table[np.lexsort(table.T[::-1])]
                num_rows = len(table) // 2

                if node_id in evidence:
                    # Condition: keep only the rows matching the observation.
                    if evidence[node_id] == 0:
                        table = table[:num_rows, 1:]
                    else:
                        table = table[num_rows:, 1:]
                else:
                    # Multiply in the single-variable factors on node_id ...
                    for idx in factors_to_multiply:
                        unary = factor_dict[idx][1]
                        table[:num_rows, -1] *= unary[0][1]
                        table[num_rows:, -1] *= unary[1][1]
                    # ... then sum node_id out (last column holds the values).
                    table[:num_rows, -1] += table[num_rows:, -1]
                    table = table[:num_rows, 1:]
                factor_dict[factor_idx] = [scope[1:], table]

        # The factors that contained the query variable are combined here;
        # their value columns are multiplied entry-wise.
        result_value = np.ones(2)
        for factor_id in node_to_factor[post_id]:
            result_value = result_value * factor_dict[factor_id][1][:, 1]

    end = time.time()
    output = float(result_value[int(post_val)] / result_value.sum())
    cond_str = ','.join(['X_%d=%d' % (var, val) for var, val in conds])
    query_str = "P(X_%d=%s%s%s)" % (post_id, post_val, '|' if cond_str else '', cond_str)
    print('== Variable Elimination ==')
    print("Result: %s=%.6f" % (query_str, output))
    print('Execution time: {:.8f}\n'.format(end-start))
    return output

In [0]:
def naive_inference(post_id, post_val, prev_id_list, prev_val_list):
    """Compute P(X_post_id = post_val | evidence) by brute-force enumeration.

    Every assignment of the five binary variables that agrees with the
    evidence is scored with ``get_naive_graph``; the scores are accumulated
    per value of the query variable and normalised.  The result and the
    wall-clock time are printed; the probability is also returned.

    Args:
        post_id: id (1-5) of the query variable.
        post_val: queried value, 0 or 1.
        prev_id_list: ids of the observed variables.
        prev_val_list: observed values (0 or 1), parallel to ``prev_id_list``.

    Returns:
        float: the normalised probability of ``post_val``.

    Raises:
        ValueError: if ``post_val`` is neither 0 nor 1.
    """
    start = time.time()
    # Reject a bad query before doing any enumeration work.
    if post_val not in (0, 1):
        raise ValueError('Second value must be either 0 or 1!')
    query_var = post_id
    conds = list(zip(prev_id_list, prev_val_list))

    # Domain of each variable: both values unless it is observed.
    possible_vars = [[0, 1] for _ in range(5)]
    for var_idx, value in conds:
        possible_vars[var_idx-1] = [value]

    if len(possible_vars[query_var-1]) == 1:
        # The query variable itself is observed: unpack the single value
        # (subtracting the one-element list itself raises TypeError).
        value = possible_vars[query_var-1][0]
        result = [1.0 - value, 0.0 + value]
    else:
        result = [0.0, 0.0]
        for i in range(2):
            # Work on a copy so the shared domain list is not modified.
            domains = list(possible_vars)
            domains[query_var-1] = [i]
            for perm in itertools.product(*domains):
                result[i] += get_naive_graph(*perm)
        total = sum(result)
        result = [score / total for score in result]

    if post_val == 1:
        result = result[1]
    else:
        result = 1 - result[1]
    result = float(result)
    end = time.time()
    cond_str = ','.join(['X_%d=%d' % (var, val) for var, val in conds])
    query_str = "P(X_%d=%s%s%s)" % (query_var, post_val, '|' if cond_str else '', cond_str)
    print('=== Naive Inference ===')
    print("Result: %s=%.6f" % (query_str, result))
    print('Execution time: {:.8f}'.format(end-start))
    return result

In [0]:
def experiment(post_id, post_val, prev_id_list, prev_val_list):
    """Answer one query with both methods: brute force first, then elimination.

    Args:
        post_id: id of the query variable.
        post_val: queried value.
        prev_id_list: ids of the observed variables.
        prev_val_list: observed values, parallel to ``prev_id_list``.
    """
    query = (post_id, post_val, prev_id_list, prev_val_list)
    naive_inference(*query)
    variable_elimination(*query)

In [0]:
# Query with no evidence.
experiment(post_id=4, post_val=0, prev_id_list=[], prev_val_list=[])

# Query with one observed variable.
experiment(post_id=4, post_val=1, prev_id_list=[3], prev_val_list=[1])

# Query with two observed variables.
experiment(post_id=3, post_val=1, prev_id_list=[1, 2], prev_val_list=[1, 0])


=== Naive Inference ===
Result: P(X_1=0|X_2=1)=0.999000
Execution time: 0.00010276
== Variable Elimination ==
Result: P(X_1=0|X_2=1)=0.999000
Execution time: 0.00137854

=== Naive Inference ===
Result: P(X_4=1|X_3=1)=0.900000
Execution time: 0.00007153
== Variable Elimination ==
Result: P(X_4=1|X_3=1)=0.900000
Execution time: 0.00032520

=== Naive Inference ===
Result: P(X_3=1|X_1=1,X_2=0)=0.940000
Execution time: 0.00009656
== Variable Elimination ==
Result: P(X_3=1|X_1=1,X_2=0)=0.940000
Execution time: 0.00028515

