Skip to content

Commit

Permalink
fix bug with node id precedence that is required for maping back to p…
Browse files Browse the repository at this point in the history
…robabilities
  • Loading branch information
vangj committed May 22, 2020
1 parent de72595 commit 6f5947f
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pybbn/sampling/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def __init__(self, node, parents=[]):
if len(parents) == 0:
self.probs = np.array([node.probs]).cumsum()
else:
cartesian = itertools.product(*[node.variable.values for node in parents])
cartesian = list(itertools.product(*[node.variable.values for node in self.parents]))
get_kv = lambda i, v: f'{i}={v}'
keys = [','.join([get_kv(node.id, val) for node, val in zip(parents, values)]) for values in cartesian]
n = len(keys)
keys = [','.join([get_kv(node.id, val) for node, val in zip(self.parents, values)]) for values in cartesian]
n = len(node.variable.values)

probs = [node.probs[i:i + n] for i in range(0, len(node.probs), n)]
probs = [np.array(p).cumsum() for p in probs]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='pybbn',
version='1.1.0',
version='1.1.1',
author='Jee Vang',
author_email='vangjee@gmail.com',
packages=find_packages(),
Expand Down
34 changes: 33 additions & 1 deletion tests/sampling/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_table():
@with_setup(setup, teardown)
def test_pa_ch_table():
"""
Tests create table with parent.
Tests create table with a single parent.
:return: None.
"""
a = BbnNode(Variable(0, 'a', ['on', 'off']), [0.5, 0.5])
Expand All @@ -61,6 +61,38 @@ def test_pa_ch_table():
assert 'off' == table.get_value(0.6, sample={0: 'off'})


@with_setup(setup, teardown)
def test_multiple_pa_ch_table():
"""
Tests create table with multiple parent.
:return: None.
"""
d_probs = [0.23323615160349853, 0.7667638483965015,
0.7563025210084033, 0.24369747899159663]
r_probs = [0.31000000000000005, 0.69,
0.27, 0.73,
0.13, 0.87,
0.06999999999999995, 0.93]
g_probs = [0.49, 0.51]

g = BbnNode(Variable(0, 'gender', ['female', 'male']), g_probs)
d = BbnNode(Variable(1, 'drug', ['false', 'true']), d_probs)
r = BbnNode(Variable(2, 'recovery', ['false', 'true']), r_probs)

table = Table(r, parents=[d, g])

assert table.has_parents()
lhs = np.array(list(table.probs.values()))
rhs = np.array([[0.31, 1.0], [0.27, 1.0], [0.13, 1.0], [0.07, 1.0]])
assert_almost_equal(lhs, rhs)

lhs = list(table.probs.keys())
rhs = ['0=female,1=false', '0=female,1=true', '0=male,1=false', '0=male,1=true']
assert len(lhs) == len(rhs)
for l, r in zip(lhs, rhs):
assert l == r


@with_setup(setup, teardown)
def test_toplogical_sort_huang():
"""
Expand Down

0 comments on commit 6f5947f

Please sign in to comment.