Skip to content

Commit

Permalink
Merge pull request #14 from phinate/feat/add-adj
Browse files Browse the repository at this point in the history
Add link matrix, still need to use for checking cyclic behaviour
  • Loading branch information
phinate committed Jul 14, 2021
2 parents f4f9201 + fd9fd21 commit e73d906
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/clarinet/nets/bayesnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
import numpy as np
from immutables import Map
from pydantic import BaseModel, root_validator, validator
from scipy.sparse import csr_matrix

from ..modelstring import Modelstring
from ..nodes import BaseDiscrete, DiscreteNode, Node


class BayesNet(BaseModel):
nodes: Map[str, Node]
link_matrix: csr_matrix
link_ordering: Map[str, int]
modelstring: Modelstring = Modelstring("")

# for pydantic
Expand All @@ -24,7 +27,9 @@ class Config:
json_encoders = {
Map: lambda t: {name: node for name, node in t.items()},
np.ndarray: lambda t: t.tolist(),
csr_matrix: lambda t: None,
}
fields = {"link_matrix": {"exclude": True}}
keep_untouched = (singledispatchmethod,)

def __getitem__(self, item: str) -> Node:
Expand Down Expand Up @@ -187,6 +192,24 @@ def init_modelstring(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["modelstring"] += f"|{':'.join(parents)}]" if parents else "]"
return values

@root_validator(skip_on_failure=True, pre=True)
def init_link_matrix(cls, values: Dict[str, Any]) -> Dict[str, Any]:
nodes = values["nodes"].values()

ordering = {node.name: i for i, node in enumerate(nodes)}

m = csr_matrix((len(ordering), len(ordering)), dtype=int)

for i, node in enumerate(nodes):
for parent in node.parents:
assert parent in ordering.keys(), f"{parent} is not declared as a node!"
m[ordering[parent], i] = 1

values["link_matrix"] = m
values["link_ordering"] = Map(ordering)

return values

@classmethod
def from_dict(
cls,
Expand Down

0 comments on commit e73d906

Please sign in to comment.