Skip to content

Commit

Permalink
DynamicBayesianNetwork.py: Added get_cardinality method
Browse files Browse the repository at this point in the history
closes #921
  • Loading branch information
lohani2280 committed Oct 27, 2017
1 parent 0f62927 commit 33fe43a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
45 changes: 45 additions & 0 deletions pgmpy/models/DynamicBayesianNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,51 @@ def remove_cpds(self, *cpds):
cpd = self.get_cpds(cpd)
self.cpds.remove(cpd)

def get_cardinality(self, node=None):
"""
Returns the cardinality of the node.
Parameter
---------
node: tuple (node_name, time_slice)
The node should be in the following form (node_name, time_slice).
Here, node_name is the node that is inserted while the time_slice is
an integer value, which denotes the index of the time_slice that the
node belongs to.
Returns
-------
int or dict : If node is specified returns the cardinality of the node.
If node is not specified returns a dictionary with the given
variable as keys and their respective cardinality as values.
Examples:
-------
>>> from pgmpy.models import DynamicBayesianNetwork as DBN
>>> from pgmpy.factors.discrete import TabularCPD
>>> dbn = DBN()
>>> dbn.add_edges_from([(('D',0),('G',0)),(('I',0),('G',0)),(('D',0),('D',1)),(('I',0),('I',1))])
>>> grade_cpd = TabularCPD(('G',0), 3, [[0.3,0.05,0.9,0.5],
... [0.4,0.25,0.8,0.03],
... [0.3,0.7,0.02,0.2]], [('I', 0),('D', 0)],[2,2])
>>> dbn.add_cpds(grade_cpd)
>>> dbn.get_cardinality(('D',0))
2
>>> dbn.get_cardinality()
defaultdict(int, {('D', 0): 2, ('D', 1): 2, ('G', 0): 3, ('I', 0): 2, ('I', 1): 2})
"""
if node:
if node not in super(DynamicBayesianNetwork, self).nodes():
raise ValueError('Node not present in the model.')
else:
temp_node = (node[0], 1 - node[1]) if node[1] else node
return self.get_cpds(temp_node).cardinality[0]
else:
cardinalities = defaultdict(int)
for cpd in self.cpds:
cardinalities[cpd.variable] = cpd.cardinality[0]
return cardinalities

def check_model(self):
"""
Check the model for various errors. This method checks for the following
Expand Down
17 changes: 17 additions & 0 deletions pgmpy/tests/test_models/test_DynamicBayesianNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ def test_add_multiple_cpds(self):
self.assertEqual(self.network.get_cpds(('I', 0)).variable, ('I', 0))
self.assertEqual(self.network.get_cpds(('I', 1)).variable, ('I', 1))

def test_get_cardinality(self):
self.network.add_edges_from(
[(('D', 0), ('G', 0)), (('I', 0), ('G', 0)), (('D', 0), ('D', 1)), (('I', 0), ('I', 1))])
self.network.add_cpds(self.grade_cpd, self.d_i_cpd, self.diff_cpd, self.intel_cpd, self.i_i_cpd)
self.assertDictEqual(self.network.get_cardinality(),
{('D', 0): 2, ('D', 1): 2, ('G', 0): 3, ('I', 0): 2, ('I', 1): 2})

def test_get_cardinality_with_node(self):
self.network.add_edges_from(
[(('D', 0), ('G', 0)), (('I', 0), ('G', 0)), (('D', 0), ('D', 1)), (('I', 0), ('I', 1))])
self.network.add_cpds(self.grade_cpd, self.d_i_cpd, self.diff_cpd, self.intel_cpd, self.i_i_cpd)
self.assertEqual(self.network.get_cardinality(('D',0)), 2)
self.assertEqual(self.network.get_cardinality(('D',1)), 2)
self.assertEqual(self.network.get_cardinality(('G',0)), 3)
self.assertEqual(self.network.get_cardinality(('I',0)), 2)
self.assertEqual(self.network.get_cardinality(('I',1)), 2)

def test_initialize_initial_state(self):

self.network.add_nodes_from(['D', 'G', 'I', 'S', 'L'])
Expand Down

0 comments on commit 33fe43a

Please sign in to comment.