In [18]:
from mbi import Dataset, FactoredInference, Factor, GraphicalModel, Domain, CliqueVector, RegionGraph, FactorGraph
from mbi.region_graph import estimate_kikuchi_marginal
import numpy as np
import itertools
from scipy import sparse
import pandas as pd
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def get_random_cliques(attrs, size, number, prng=np.random):
    allcl = list(itertools.combinations(attrs, size))
    idx = prng.choice(len(allcl), number, replace=False)
    return [allcl[i] for i in idx]

In [3]:
domain = Domain(['A', 'B', 'C'], [2,3,4])
cliques = [('A','B'), ('B','C')]
model = RegionGraph(domain, cliques, iters=1000, convex=True, minimal=True)
potentials = CliqueVector.random(domain, model.cliques)*3
marginals = model.belief_propagation(potentials)
marginals1 = model.optimize_kikuchi(potentials)

CHECKPT (7, 21) 7
{'status': 'optimal', 'x': <21x1 matrix, tc='d'>, 'y': <7x1 matrix, tc='d'>, 'znl': <0x1 matrix, tc='d'>, 'zl': <21x1 matrix, tc='d'>, 'snl': <0x1 matrix, tc='d'>, 'sl': <21x1 matrix, tc='d'>, 'gap': 5.4417377103023725e-08, 'relative gap': 5.827579472300036e-09, 'primal objective': -9.337903903616128, 'dual objective': -9.337903944407225, 'primal slack': 8.38590482748788e-11, 'dual slack': 7.475669814810973e-11, 'primal infeasibility': 1.5579665647026412e-09, 'dual infeasibility': 7.490353314969441e-09}


In [26]:
model.potentials = potentials
model.marginals = marginals
model.project(['A','B','C'], iters=1000).project(['A','C']).datavector()

array([0.04153231, 0.01857051, 0.04704931, 0.05632712, 0.24184426,
       0.0795644 , 0.23660759, 0.2785045 ])

In [5]:
marginals[('A','B')].datavector(), marginals1[('A','B')].datavector()

(array([0.06362768, 0.06229905, 0.03755252, 0.31009945, 0.22310793,
        0.30331337]),
 array([0.06362769, 0.06229905, 0.03755253, 0.31009944, 0.22310793,
        0.30331336]))

In [25]:
model.project(['A','C'], iters=5).datavector()

array([0.04632619, 0.01604302, 0.04637202, 0.05473802, 0.23705038,
       0.08209189, 0.23728488, 0.2800936 ])

In [17]:
(model.project(['A'])*model.project(['C'])).datavector()

array([0.04632619, 0.01604302, 0.04637202, 0.05473802, 0.23705038,
       0.08209189, 0.23728488, 0.2800936 ])

In [21]:
tmp = {}
tmp[('A',)] = model.project(['A'])
tmp[('C',)] = model.project(['C'])
ans = estimate_kikuchi_marginal(domain.project(['A','C']), 1.0, tmp)
ans.datavector()

array([0.04632619, 0.01604302, 0.04637202, 0.05473802, 0.23705038,
       0.08209189, 0.23728488, 0.2800936 ])

In [9]:
cp = 1e-6

new = {}
p = ('A','B','C')
for r in [('A','B'), ('B','C')]:
    new[p,r] = Factor.zeros(domain.project(r))
    new[r,p] = Factor.zeros(domain.project(r))

for r in [('A','B'), ('B','C')]:
    comp = [c for c in cliques if c != r]
    diff = tuple(set(p) - set(r))
    tmp = Factor.zeros(domain.project(p)) + sum(new[c,p] for c in comp)
    new[r,p] = potentials[r] + model.messages[(('B',),r)]

logP = Factor.zeros(domain.project(p)) + sum(new[r,p] for r in cliques)
P = (logP - logP.logsumexp()).exp()

P.datavector()

array([0.05689502, 0.00997681, 0.07154035, 0.0096694 , 0.03251866,
       0.00829868, 0.03287898, 0.01929192, 0.08391789, 0.02180309,
       0.03153647, 0.0987724 , 0.05534567, 0.00970512, 0.06959218,
       0.00940609, 0.04923967, 0.01256584, 0.04978528, 0.02921179,
       0.08463566, 0.02198958, 0.03180621, 0.09961723])

In [10]:
P.project(['A','B']).datavector(), marginals[('A','B')].datavector()

(array([0.14808158, 0.09298824, 0.23602986, 0.14404906, 0.14080258,
        0.23804869]),
 array([0.15998187, 0.11229814, 0.2001705 , 0.15562528, 0.17004159,
        0.20188261]))

In [218]:
potentials[('A','C')] = Factor.zeros(domain.project(['A','C']))
model2 = RegionGraph(domain, list(potentials.keys()), iters=10, convex=True, minimal=True)
model2.messages.update(model.messages)
model2.counting_numbers = {('B',): 1.0, 
                           ('C',): 0.01, 
                           ('A',): 0.01, 
                          ('A', 'B'): 1.0, 
                          ('A', 'C'): 0.01, 
                          ('B', 'C'): 1.0}

marginals2 = model2.hazan_peng_shashua(potentials)
marginals3 = model2.wiegerinck(potentials)
#marginals4 = model.loh_wibisono(potentials)
marginals5 = model2.optimize_kikuchi(potentials, backend='cvxopt')

CHECKPT (18, 35) 18
{'status': 'optimal', 'x': <35x1 matrix, tc='d'>, 'y': <18x1 matrix, tc='d'>, 'znl': <0x1 matrix, tc='d'>, 'zl': <35x1 matrix, tc='d'>, 'snl': <0x1 matrix, tc='d'>, 'sl': <35x1 matrix, tc='d'>, 'gap': 4.020451145108476e-08, 'relative gap': 4.367918945694654e-09, 'primal objective': -9.204500346947443, 'dual objective': -9.20450037680028, 'primal slack': 2.1873703004273418e-10, 'dual slack': 1.3983654319453492e-10, 'primal infeasibility': 1.1183458203270308e-09, 'dual infeasibility': 7.032728461791809e-09}


In [219]:
print(marginals2[('A','B')].datavector())
print(marginals3[('A','B')].datavector())
print(marginals5[('A','B')].datavector())

[0.26556084 0.08700729 0.25203784 0.10000487 0.25334595 0.0420432 ]
[0.26562855 0.08696464 0.25206253 0.10004171 0.25325048 0.04205209]
[0.26562854 0.08696464 0.25206253 0.10004171 0.25325048 0.04205209]


In [202]:
print(marginals2[('A','B','C')].project(['A','B']).datavector())
print(marginals2[('A','B')].datavector())
print(marginals3[('A','B','C')].project(['A','B']).datavector())
#print(marginals4[('A','B','C')].project(['A','B']).datavector())
print(marginals5[('A','B','C')].project(['A','B']).datavector())

KeyError: ('A', 'B', 'C')

In [92]:
marginals2[('B',)] = marginals2[('B','C')].project('B')
marginals3[('B',)] = marginals3[('B','C')].project('B')

model.energy_functional(potentials, marginals2)[0] < model.energy_functional(potentials, marginals3)[0]

True

In [169]:
model2.messages[(('B','C'),('B',))].datavector()

array([-1.04868313, -0.96022106, -1.32128267])

In [41]:
P = marginals[('A','B')]*marginals[('B','C')]/marginals[('A','B')].project('B')
print(P.project(['A','B']).datavector())

[0.17676168 0.02332185 0.19762599 0.10444025 0.2502634  0.24758683]
