Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added ExhaustiveSearch class for score-based structure learning
- Loading branch information
1 parent
398953d
commit e910c23
Showing
3 changed files
with
220 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#!/usr/bin/env python | ||
import numpy as np | ||
import pandas as pd | ||
import networkx as nx | ||
from warnings import warn | ||
from itertools import combinations | ||
from pgmpy.estimators import StructureEstimator, BayesianScore | ||
from pgmpy.utils.mathext import powerset | ||
from pgmpy.models import BayesianModel | ||
|
||
|
||
class ExhaustiveSearch(StructureEstimator): | ||
def __init__(self, data, scoring_method=None, **kwargs): | ||
""" | ||
Search class for exhaustive searches over all BayesianModels with a given set of variables. | ||
Takes a `StructureScore`-Instance as parameter; `estimate` finds the model with maximal score. | ||
Parameters | ||
---------- | ||
model: pgmpy.models.BayesianModel or pgmpy.models.MarkovModel or pgmpy.models.NoisyOrModel | ||
model for which parameter estimation is to be done | ||
data: pandas DataFrame object | ||
datafame object where each column represents one variable. | ||
(If some values in the data are missing the data cells should be set to `numpy.NaN`. | ||
Note that pandas converts each column containing `numpy.NaN`s to dtype `float`.) | ||
scoring_method: Instance of a `StructureScore`-subclass (`BayesianScore` is used if not set) | ||
An instance of either `BayesianScore` or `BicScore`. | ||
This score is optimized during structure estimation by the `estimate`-method. | ||
state_names: dict (optional) | ||
A dict indicating, for each variable, the discrete set of states (or values) | ||
that the variable can take. If unspecified, the observed values in the data set | ||
are taken to be the only possible states. | ||
complete_samples_only: bool (optional, default `True`) | ||
Specifies how to deal with missing data, if present. If set to `True` all rows | ||
that contain `np.Nan` somewhere are ignored. If `False` then, for each variable, | ||
every row where neither the variable nor its parents are `np.NaN` is used. | ||
This sets the behavior of the `state_count`-method. | ||
Returns | ||
------- | ||
state_counts: pandas.DataFrame | ||
Table with state counts for 'variable' | ||
Examples | ||
-------- | ||
>>> import pandas as pd | ||
>>> from pgmpy.models import BayesianModel | ||
>>> from pgmpy.estimators import ParameterEstimator | ||
>>> model = BayesianModel([('A', 'C'), ('B', 'C')]) | ||
>>> data = pd.DataFrame(data={'A': ['a1', 'a1', 'a2'], | ||
'B': ['b1', 'b2', 'b1'], | ||
'C': ['c1', 'c1', 'c2']}) | ||
>>> estimator = ParameterEstimator(model, data) | ||
>>> estimator.state_counts('A') | ||
A | ||
a1 2 | ||
a2 1 | ||
>>> estimator.state_counts('C') | ||
A a1 a2 | ||
B b1 b2 b1 b2 | ||
C | ||
c1 1 1 0 0 | ||
c2 0 0 1 0 | ||
""" | ||
if scoring_method is not None: | ||
self.scoring_method = scoring_method | ||
else: | ||
from pgmpy.estimators import BayesianScore | ||
self.scoring_method = BayesianScore(data, **kwargs) | ||
|
||
super(ExhaustiveSearch, self).__init__(data, **kwargs) | ||
|
||
def all_dags(self, nodes=None): | ||
"Generates all possible DAGs with a given set of nodes; sparse ones first" | ||
if nodes is None: | ||
nodes = sorted(self.state_names.keys()) | ||
if len(nodes) > 6: | ||
warn("Generating all DAGs of n nodes likely not feasible for n>6") | ||
|
||
edges = list(combinations(nodes, 2)) # n*(n-1) possible directed edges | ||
edges.extend([(y, x) for x, y in edges]) | ||
all_graphs = powerset(edges) # 2^(n*(n-1)) graphs | ||
|
||
for graph_edges in all_graphs: | ||
graph = nx.DiGraph() | ||
graph.add_nodes_from(nodes) | ||
graph.add_edges_from(graph_edges) | ||
if nx.is_directed_acyclic_graph(graph): | ||
yield graph | ||
|
||
def all_scores(self): | ||
"Computes an list of DAGs and their structure scores, ordered by score" | ||
|
||
scored_dags = sorted([(self.scoring_method.score(dag), dag) for dag in self.all_dags()], | ||
key=lambda x: x[0]) | ||
return scored_dags | ||
|
||
def estimate(self): | ||
""" | ||
Estimates the `BayesianModel` structure that fits best to the given data set, | ||
according to the scoring method supplied in the constructor. | ||
Exhaustively searches through all models. Only estimates network structure, no parametrization. | ||
Returns | ||
------- | ||
model: `BayesianModel` instance | ||
A `BayesianModel` with maximal score. | ||
""" | ||
|
||
best_dag = max(self.all_dags(), key=self.scoring_method.score) | ||
|
||
best_model = BayesianModel() | ||
best_model.add_nodes_from(sorted(best_dag.nodes())) | ||
best_model.add_edges_from(sorted(best_dag.edges())) | ||
return best_model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import unittest | ||
|
||
import pandas as pd | ||
import numpy as np | ||
from pgmpy.estimators import ExhaustiveSearch | ||
from pgmpy.factors import TabularCPD | ||
from pgmpy.extern import six | ||
from pgmpy.models import BayesianModel | ||
|
||
|
||
class TestBaseEstimator(unittest.TestCase): | ||
def setUp(self): | ||
self.rand_data = pd.DataFrame(np.random.randint(0, 5, size=(5000, 2)), columns=list('AB')) | ||
self.rand_data['C'] = self.rand_data['B'] | ||
self.est_rand = ExhaustiveSearch(self.rand_data) | ||
|
||
# link to dataset: "https://www.kaggle.com/c/titanic/download/train.csv" | ||
self.titanic_data = pd.read_csv('pgmpy/tests/test_estimators/testdata/titanic_train.csv') | ||
self.titanic_data2 = self.titanic_data[["Survived", "Sex", "Pclass"]] | ||
self.est_titanic = ExhaustiveSearch(self.titanic_data2) | ||
|
||
def test_all_dags(self): | ||
self.assertEqual(len(list(self.est_rand.all_dags(['A', 'B', 'C', 'D']))), 543) | ||
# self.assertEqual(len(list(self.est_rand.all_dags(nodes=range(5)))), 29281) # takes ~30s | ||
|
||
abc_dags = set(six.moves.map(tuple, [sorted(dag.edges()) for dag in self.est_rand.all_dags()])) | ||
abc_dags_ref = set([(('A', 'B'), ('C', 'A'), ('C', 'B')), (('A', 'C'), ('B', 'C')), | ||
(('B', 'A'), ('B', 'C')), (('C', 'B'),), (('A', 'C'), ('B', 'A')), | ||
(('B', 'C'), ('C', 'A')), (('A', 'B'), ('B', 'C')), (('A', 'C'), | ||
('B', 'A'), ('B', 'C')), (('A', 'B'),), (('A', 'B'), ('C', 'A')), | ||
(('B', 'A'), ('C', 'A'), ('C', 'B')), (('A', 'C'), ('C', 'B')), | ||
(('A', 'B'), ('A', 'C'), ('C', 'B')), (('B', 'A'), ('C', 'B')), | ||
(('A', 'B'), ('A', 'C')), (('C', 'A'), ('C', 'B')), (('A', 'B'), | ||
('A', 'C'), ('B', 'C')), (('C', 'A'),), (('B', 'A'), ('B', 'C'), ('C', 'A')), | ||
(('B', 'A'),), (('A', 'B'), ('C', 'B')), (), (('B', 'A'), ('C', 'A')), | ||
(('A', 'C'),), (('B', 'C'),)]) | ||
self.assertSetEqual(abc_dags, abc_dags_ref) | ||
|
||
def test_estimate_rand(self): | ||
est = self.est_rand.estimate() | ||
self.assertSetEqual(set(est.nodes()), set(['A', 'B', 'C'])) | ||
self.assertTrue(est.edges() == [('B', 'C')] or est.edges() == [('C', 'B')]) | ||
|
||
def test_estimate_titanic(self): | ||
e1 = self.est_titanic.estimate() | ||
self.assertSetEqual(set(e1.edges()), set([('Survived', 'Pclass'), ('Sex', 'Pclass'), ('Sex', 'Survived')])) | ||
|
||
def test_all_scores(self): | ||
scores = self.est_titanic.all_scores() | ||
scores_ref = [(-2072.9132364404695, []), | ||
(-2069.071694164769, [('Pclass', 'Sex')]), | ||
(-2069.0144197068785, [('Sex', 'Pclass')]), | ||
(-2025.869489762676, [('Survived', 'Pclass')]), | ||
(-2025.8559302273054, [('Pclass', 'Survived')]), | ||
(-2022.0279474869753, [('Pclass', 'Sex'), ('Survived', 'Pclass')]), | ||
(-2022.0143879516047, [('Pclass', 'Sex'), ('Pclass', 'Survived')]), | ||
(-2021.9571134937144, [('Pclass', 'Survived'), ('Sex', 'Pclass')]), | ||
(-2017.5258065853768, [('Sex', 'Pclass'), ('Survived', 'Pclass')]), | ||
(-1941.3075053892837, [('Survived', 'Sex')]), | ||
(-1941.2720031713893, [('Sex', 'Survived')]), | ||
(-1937.4304608956886, [('Pclass', 'Sex'), ('Sex', 'Survived')]), | ||
(-1937.4086886556927, [('Sex', 'Pclass'), ('Survived', 'Sex')]), | ||
(-1937.3731864377983, [('Sex', 'Pclass'), ('Sex', 'Survived')]), | ||
(-1934.1344850608882, [('Pclass', 'Sex'), ('Survived', 'Sex')]), | ||
(-1894.2637587114903, [('Survived', 'Pclass'), ('Survived', 'Sex')]), | ||
(-1894.2501991761198, [('Pclass', 'Survived'), ('Survived', 'Sex')]), | ||
(-1894.2282564935958, [('Sex', 'Survived'), ('Survived', 'Pclass')]), | ||
(-1891.0630673606006, [('Pclass', 'Survived'), ('Sex', 'Survived')]), | ||
(-1887.2215250849, [('Pclass', 'Sex'), ('Pclass', 'Survived'), ('Sex', 'Survived')]), | ||
(-1887.1642506270096, [('Pclass', 'Survived'), ('Sex', 'Pclass'), ('Sex', 'Survived')]), | ||
(-1887.0907383830947, [('Pclass', 'Sex'), ('Survived', 'Pclass'), ('Survived', 'Sex')]), | ||
(-1887.0771788477243, [('Pclass', 'Sex'), ('Pclass', 'Survived'), ('Survived', 'Sex')]), | ||
(-1885.9200755341915, [('Sex', 'Pclass'), ('Survived', 'Pclass'), ('Survived', 'Sex')]), | ||
(-1885.884573316297, [('Sex', 'Pclass'), ('Sex', 'Survived'), ('Survived', 'Pclass')])] | ||
|
||
self.assertEqual([sorted(model.edges()) for score, model in scores], | ||
[edges for score, edges in scores_ref]) | ||
# use assertAlmostEqual pointwise to avoid rounding issues | ||
six.moves.map(lambda x, y: self.assertAlmostEqual(x, y), | ||
[score for score, model in scores], | ||
[score for score, edges in scores_ref]) | ||
|
||
def tearDown(self): | ||
del self.rand_data | ||
del self.est_rand | ||
del self.titanic_data | ||
del self.est_titanic |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters