-
Notifications
You must be signed in to change notification settings - Fork 698
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Restructured pgmpy/estimators/ for structure learning
- Loading branch information
1 parent
258b5bd
commit d052b42
Showing
6 changed files
with
188 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
from pgmpy.estimators.base import BaseEstimator | ||
from pgmpy.estimators.base import BaseEstimator, ParameterEstimator, StructureEstimator | ||
from pgmpy.estimators.MLE import MaximumLikelihoodEstimator | ||
from pgmpy.estimators.BayesianEstimator import BayesianEstimator | ||
from pgmpy.estimators.StructureScore import StructureScore | ||
from pgmpy.estimators.BayesianScore import BayesianScore | ||
from pgmpy.estimators.ExhaustiveSearch import ExhaustiveSearch | ||
|
||
__all__ = ['BaseEstimator', | ||
'MaximumLikelihoodEstimator', | ||
'BayesianEstimator'] | ||
'ParameterEstimator', 'MaximumLikelihoodEstimator', 'BayesianEstimator', | ||
'StructureEstimator', 'ExhaustiveSearch', | ||
'StructureScore', 'BayesianScore'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import unittest | ||
|
||
import pandas as pd | ||
from numpy import NaN | ||
from pgmpy.models import BayesianModel | ||
from pgmpy.estimators import ParameterEstimator | ||
from pgmpy.factors import TabularCPD | ||
|
||
|
||
class TestParameterEstimator(unittest.TestCase): | ||
def setUp(self): | ||
self.m1 = BayesianModel([('A', 'C'), ('B', 'C'), ('D', 'B')]) | ||
self.d1 = pd.DataFrame(data={'A': [0, 0, 1], 'B': [0, 1, 0], 'C': [1, 1, 0], 'D': ['X', 'Y', 'Z']}) | ||
self.d2 = pd.DataFrame(data={'A': [0, NaN, 1], 'B': [0, 1, 0], 'C': [1, 1, NaN], 'D': [NaN, 'Y', NaN]}) | ||
|
||
def test_state_count(self): | ||
e = ParameterEstimator(self.m1, self.d1) | ||
self.assertEqual(e.state_counts('A').values.tolist(), [[2], [1]]) | ||
self.assertEqual(e.state_counts('C').values.tolist(), | ||
[[0., 0., 1., 0.], [1., 1., 0., 0.]]) | ||
|
||
def test_missing_data(self): | ||
e = ParameterEstimator(self.m1, self.d2, state_names={'C': [0, 1]}, complete_samples_only=False) | ||
self.assertEqual(e.state_counts('A', complete_samples_only=True).values.tolist(), [[0], [0]]) | ||
self.assertEqual(e.state_counts('A').values.tolist(), [[1], [1]]) | ||
self.assertEqual(e.state_counts('C', complete_samples_only=True).values.tolist(), | ||
[[0, 0, 0, 0], [0, 0, 0, 0]]) | ||
self.assertEqual(e.state_counts('C').values.tolist(), | ||
[[0, 0, 0, 0], [1, 0, 0, 0]]) | ||
|
||
def tearDown(self): | ||
del self.m1 | ||
del self.d1 |