Skip to content

Commit

Permalink
Merge pull request #13 from phinate/feat/add-probs
Browse files Browse the repository at this point in the history
Add validation to check if conditional prob tables have correct dimensions and properties
  • Loading branch information
phinate committed Jul 13, 2021
2 parents 1f1fba7 + b8ba90b commit f4f9201
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/clarinet/nets/bayesnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,32 @@ def __getitem__(self, item: str) -> Node:
def dict_to_map(cls, dct: Dict[str, Node]) -> Map[str, Node]:
return Map(dct)

@staticmethod
def _validate_prob_table(
nodes: Dict[str, Dict[str, Any]],
name: str,
prob_table: np.ndarray,
has_parents: bool,
) -> None:
target = nodes[name]
table = np.array(prob_table)
num_states = len(target["states"])
assert (
table.shape[-1] == num_states
), f"{name} should have a probability table with last dimension of size {num_states} ({table.shape[-1]} given)"

if has_parents:
parent_states_sizes = [len(nodes[p]["states"]) for p in target["parents"]]
assert set(parent_states_sizes) == set(
table.shape[:-1]
), f"{name} has incorrect shape, needs to be (*len(parent states), ..., len(states))"

# by fixing the values of the parent nodes, we define a distribution, so we need to check
# it sums to unit probability in all cases
assert np.isclose(
table.sum(axis=-1).prod(), 1
), f"Probability over states of node'{name}' doesn't sum to 1!" # to account for possible truncation -- needed?

# this doesn't pick up cycles that occur when searching for node-centric cycles
# not to worry -- I think this is done easier through the link matrix impl
@staticmethod
Expand Down Expand Up @@ -127,6 +153,11 @@ def _validate_node(
)
# check recursively to see if any child links back to this node
cycle_check(name, children)
if "prob_table" in node_dict.keys() and "states" in node_dict.keys():
if len(node_dict["prob_table"]):
BayesNet._validate_prob_table(
network_dict, name, node_dict["prob_table"], has_parents
)

@staticmethod
def _nodes_to_dict(nodes: Map[str, Node]) -> Dict[str, Dict[str, Any]]:
Expand Down Expand Up @@ -214,6 +245,7 @@ def convert_nodes(
BayesNet._validate_node(name, net_dct[name], net_dct)
return self.copy(update={"nodes": nodes})

# TODO: add test
@no_type_check
@singledispatchmethod
def add_prob_tables(
Expand Down

0 comments on commit f4f9201

Please sign in to comment.