From a3be71f743a98d27542d1a1c4e0fc29f7d942c28 Mon Sep 17 00:00:00 2001 From: lohani2280 Date: Fri, 27 Oct 2017 02:23:02 +0530 Subject: [PATCH] DynamicBayesianNetwork.py: Added get_cardinality method closes https://github.com/pgmpy/pgmpy/issues/921 --- pgmpy/models/DynamicBayesianNetwork.py | 47 +++++++++++++++++++ .../test_DynamicBayesianNetwork.py | 18 +++++++ 2 files changed, 65 insertions(+) diff --git a/pgmpy/models/DynamicBayesianNetwork.py b/pgmpy/models/DynamicBayesianNetwork.py index a6ee95134..54ab99fbf 100644 --- a/pgmpy/models/DynamicBayesianNetwork.py +++ b/pgmpy/models/DynamicBayesianNetwork.py @@ -440,6 +440,53 @@ def remove_cpds(self, *cpds): cpd = self.get_cpds(cpd) self.cpds.remove(cpd) + def get_cardinality(self, node=None): + """ + Returns the cardinality of the node.If node=None returns all the cardinalities. + + Parameters + --------- + node: tuple (node_name, time_slice) + The node should be in the following form (node_name, time_slice). + Here, node_name is the node that is inserted while the time_slice is + an integer value, which denotes the index of the time_slice that the + node belongs to. + + Returns + ------- + int or dict : If node is specified returns the cardinality of the node. + If node is not specified returns a dictionary with the given + variable as keys and their respective cardinality as values. + + Examples: + ------- + >>> from pgmpy.models import DynamicBayesianNetwork as DBN + >>> from pgmpy.factors.discrete import TabularCPD + >>> dbn = DBN() + >>> dbn.add_edges_from([(('D',0),('G',0)),(('I',0),('G',0)),(('D',0),('D',1)),(('I',0),('I',1))]) + >>> grade_cpd = TabularCPD(('G',0), 3, [[0.3,0.05,0.9,0.5], + ... [0.4,0.25,0.8,0.03], + ... [0.3,0.7,0.02,0.2]], + evidence=[('I', 0),('D', 0)], + evidence_card=[2,2]) + >>> dbn.add_cpds(grade_cpd) + >>> dbn.get_cardinality(node=('D',0)) + 2 + >>> dbn.get_cardinality() + defaultdict(int, {('D', 0): 2, ('D', 1): 2, ('G', 0): 3, ('I', 0): 2, ('I', 1): 2}) + """ + if node: + if node not in super(DynamicBayesianNetwork, self).nodes(): + raise ValueError('Node not present in the model.') + else: + temp_node = (node[0], 0) if node[1] else node + return self.get_cpds(temp_node).cardinality[0] + else: + cardinalities = defaultdict(int) + for cpd in self.cpds: + cardinalities[cpd.variable] = cpd.cardinality[0] + return cardinalities + def check_model(self): """ Check the model for various errors. This method checks for the following diff --git a/pgmpy/tests/test_models/test_DynamicBayesianNetwork.py b/pgmpy/tests/test_models/test_DynamicBayesianNetwork.py index 2795e2c9e..d103c2632 100644 --- a/pgmpy/tests/test_models/test_DynamicBayesianNetwork.py +++ b/pgmpy/tests/test_models/test_DynamicBayesianNetwork.py @@ -125,6 +125,24 @@ def test_add_multiple_cpds(self): self.assertEqual(self.network.get_cpds(('I', 0)).variable, ('I', 0)) self.assertEqual(self.network.get_cpds(('I', 1)).variable, ('I', 1)) + def test_get_cardinality(self): + self.network.add_edges_from( + [(('D', 0), ('G', 0)), (('I', 0), ('G', 0)), (('D', 0), ('D', 1)), (('I', 0), ('I', 1))]) + self.network.add_cpds(self.grade_cpd, self.d_i_cpd, self.diff_cpd, self.intel_cpd, self.i_i_cpd) + self.assertDictEqual(self.network.get_cardinality(), + {('D', 0): 2, ('D', 1): 2, ('G', 0): 3, ('I', 0): 2, ('I', 1): 2}) + + def test_get_cardinality_with_node(self): + self.network.add_edges_from( + [(('D', 0), ('G', 0)), (('I', 0), ('G', 0)), (('D', 0), ('D', 1)), (('I', 0), ('I', 1))]) + self.network.add_cpds(self.grade_cpd, self.d_i_cpd, self.diff_cpd, self.intel_cpd, self.i_i_cpd) + self.assertRaises(ValueError, self.network.get_cardinality, ('D',2)) + self.assertEqual(self.network.get_cardinality(('D',0)), 2) + self.assertEqual(self.network.get_cardinality(('D',1)), 2) + self.assertEqual(self.network.get_cardinality(('G',0)), 3) + self.assertEqual(self.network.get_cardinality(('I',0)), 2) + self.assertEqual(self.network.get_cardinality(('I',1)), 2) + def test_initialize_initial_state(self): self.network.add_nodes_from(['D', 'G', 'I', 'S', 'L'])