Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds n_jobs and state_names arguments to DBN.fit method #1620

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions pgmpy/models/DynamicBayesianNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def get_constant_bn(self, t_slice=0):
bn.add_cpds(*new_cpds)
return bn

def fit(self, data, estimator="MLE"):
def fit(self, data, estimator="MLE", state_names={}, n_jobs=-1):
"""
Learns the CPD of the model from data.

Expand All @@ -831,6 +831,16 @@ def fit(self, data, estimator="MLE"):
estimator: str
Currently only Maximum Likelihood Estimator is supported.

state_names: dict (optional)
A dict indicating, for each variable, the discrete set of states
that the variable can take. If unspecified, the observed values
in the data set are taken to be the only possible states.

n_jobs: int (default: -1)
Number of threads/processes to use for estimation. It improves speed only
for large networks (>100 nodes). For smaller networks might reduce
performance.

Returns
-------
None: The CPDs are added to the model instance.
Expand Down Expand Up @@ -895,21 +905,37 @@ def fit(self, data, estimator="MLE"):

# Fit or fit_update with df_slice depending on the time slice
if t_slice == 0:
const_bn.fit(df_slice)
if state_names != {}:
state_names = {
**{
str(var) + "_" + str(0): s for var, s in state_names.items()
},
**{
str(var) + "_" + str(1): s for var, s in state_names.items()
},
}
const_bn.fit(df_slice, state_names=state_names, n_jobs=n_jobs)
else:
const_bn.fit_update(df_slice, n_prev_samples=t_slice * n_samples)
const_bn.fit_update(
df_slice, n_prev_samples=t_slice * n_samples, n_jobs=n_jobs
)

cpds = []
for cpd in const_bn.cpds:
var_tuples = [var.rsplit("_", 1) for var in cpd.variables]
new_vars = [DynamicNode(var, int(t)) for var, t in var_tuples]
state_names = {
var: cpd.state_names[str(var.node) + "_" + str(var.time_slice)]
for var in new_vars
}
cpds.append(
TabularCPD(
variable=new_vars[0],
variable_card=cpd.variable_card,
values=cpd.get_values(),
evidence=new_vars[1:],
evidence_card=cpd.cardinality[1:],
state_names=state_names,
)
)

Expand Down
20 changes: 20 additions & 0 deletions pgmpy/tests/test_models/test_DynamicBayesianNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,26 @@ def test_fit(self):
df = pd.DataFrame(data, columns=colnames)
model.fit(df)

self.assertTrue(model.check_model())
self.assertEqual(len(model.cpds), 8)
for cpd in model.cpds:
np_test.assert_almost_equal(cpd.values, 0.5, decimal=1)

model = DBN(
[
(("A", 0), ("B", 0)),
(("A", 0), ("C", 0)),
(("B", 0), ("D", 0)),
(("C", 0), ("D", 0)),
(("A", 0), ("A", 1)),
(("B", 0), ("B", 1)),
(("C", 0), ("C", 1)),
(("D", 0), ("D", 1)),
]
)
model.fit(
df, state_names={"A": [0, 1, 2], "B": [0, 1, 2], "C": [0, 1], "D": [0, 1]}
)
self.assertTrue(model.check_model())
self.assertEqual(len(model.cpds), 8)
for cpd in model.cpds:
Expand Down