Skip to content

Commit

Permalink
Fixes XMLBeliefNetwork reader and writer classes; adds write_xbn meth…
Browse files Browse the repository at this point in the history
…od (#1763)
  • Loading branch information
ankurankan committed May 17, 2024
1 parent a553d96 commit e6fa742
Showing 1 changed file with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions pgmpy/readwrite/XMLBeliefNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,13 @@ def get_static_properties(self):
>>> reader.get_static_properties()
{'FORMAT': 'MSR DTAS XML', 'VERSION': '0.2', 'CREATOR': 'Microsoft Research DTAS'}
"""
return {
tags.tag: tags.get("VALUE")
for tags in self.bnmodel.find("STATICPROPERTIES")
}
if self.bnmodel.find("STATICPROPERTIES") is not None:
return {
tags.tag: tags.get("VALUE")
for tags in self.bnmodel.find("STATICPROPERTIES")
}
else:
return {}

def get_variables(self):
"""
Expand Down Expand Up @@ -368,17 +371,17 @@ def set_variables(self, data):
"VAR",
attrib={
"NAME": var,
"TYPE": data[var]["TYPE"],
"XPOS": data[var]["XPOS"],
"YPOS": data[var]["YPOS"],
"TYPE": data[var].get("TYPE", ""),
"XPOS": data[var].get("XPOS", ""),
"YPOS": data[var].get("YPOS", ""),
},
)
etree.SubElement(
variable,
"DESCRIPTION",
attrib={"DESCRIPTION": data[var]["DESCRIPTION"]},
attrib={"DESCRIPTION": data[var].get("DESCRIPTION", "")},
)
for state in data[var]["STATES"]:
for state in self.model.states[var]:
etree.SubElement(variable, "STATENAME").text = state

def set_edges(self, edge_list):
Expand Down Expand Up @@ -420,7 +423,9 @@ def set_distributions(self):
cpd_values = cpd.get_values().transpose()
var = cpd.variable
dist = etree.SubElement(
distributions, "DIST", attrib={"TYPE": self.model.nodes[var]["TYPE"]}
distributions,
"DIST",
attrib={"TYPE": self.model.nodes[var].get("TYPE", "")},
)
etree.SubElement(dist, "PRIVATE", attrib={"NAME": var})
dpis = etree.SubElement(dist, "DPIS")
Expand All @@ -442,3 +447,23 @@ def set_distributions(self):
etree.SubElement(dpis, "DPI").text = (
" " + " ".join(map(str, cpd_values[0])) + " "
)

def write_xbn(self, filename):
"""
Writes the BIF data into a file
Parameters
----------
filename : Name of the file
Example
-------
>>> from pgmpy.utils import get_example_model
>>> from pgmpy.readwrite import XBNReader, XBNWriter
>>> asia = get_example_model('asia')
>>> writer = XBNWriter(asia)
>>> writer.write_xbn(filename='asia.xbn')
"""
writer = self.__str__()
with open(filename, "wb") as fout:
fout.write(writer)

0 comments on commit e6fa742

Please sign in to comment.