Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
)
from pymc.util import (
UNSET,
VarName,
WithMemoization,
_UnsetType,
get_transformed_name,
Expand Down Expand Up @@ -1968,7 +1967,7 @@ def debug_parameters(rv):
def to_graphviz(
self,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
save: str | None = None,
figsize: tuple[int, int] | None = None,
Expand Down Expand Up @@ -2172,7 +2171,7 @@ def compile_fn(
)


def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]:
def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]:
"""Build a point.

Uses same args as dict() does.
Expand Down
36 changes: 17 additions & 19 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pymc.model.core import modelcontext
from pymc.pytensorf import _cheap_eval_mode
from pymc.util import VarName, get_default_varnames, get_var_name
from pymc.util import get_default_varnames, get_var_name

__all__ = (
"ModelGraph",
Expand Down Expand Up @@ -173,7 +173,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs:
}


def get_node_type(var_name: VarName, model) -> NodeType:
def get_node_type(var_name: str, model) -> NodeType:
"""Return the node type of the variable in the model."""
v = model[var_name]

Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(self, model):
self._all_vars = {model[var_name] for var_name in self._all_var_names}
self.var_list = self.model.named_vars.values()

def get_parent_names(self, var: Variable) -> set[VarName]:
def get_parent_names(self, var: Variable) -> set[str]:
if var.owner is None:
return set()

Expand All @@ -261,12 +261,12 @@ def _expand(x):
return x.owner.inputs

return {
cast(VarName, ancestor.name) # type: ignore[union-attr]
cast(str, ancestor.name) # type: ignore[union-attr]
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
if ancestor in named_vars
}

def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]:
if var_names is None:
return self._all_var_names

Expand Down Expand Up @@ -296,13 +296,11 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
# ordering of self._all_var_names is important
return [get_var_name(var) for var in selected_ancestors]

def make_compute_graph(
self, var_names: Iterable[VarName] | None = None
) -> dict[VarName, set[VarName]]:
def make_compute_graph(self, var_names: Iterable[str] | None = None) -> dict[str, set[str]]:
"""Get map of var_name -> set(input var names) for the model."""
model = self.model
named_vars = self._all_vars
input_map: dict[VarName, set[VarName]] = defaultdict(set)
input_map: dict[str, set[str]] = defaultdict(set)

var_names_to_plot = self.vars_to_plot(var_names)
for var_name in var_names_to_plot:
Expand All @@ -319,15 +317,15 @@ def make_compute_graph(
for ancestor in ancestors([obs_var]):
if ancestor not in named_vars:
continue
obs_name = cast(VarName, ancestor.name)
obs_name = cast(str, ancestor.name)
input_map[var_name].discard(obs_name)
input_map[obs_name].add(var_name)

return input_map

def get_plates(
self,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
) -> list[Plate]:
"""Rough but surprisingly accurate plate detection.

Expand All @@ -337,7 +335,7 @@ def get_plates(
Returns
-------
dict
Maps plate labels to the set of ``VarName``s inside the plate.
Maps plate labels to the set of ``str``s inside the plate.
"""
plates = defaultdict(set)

Expand Down Expand Up @@ -389,8 +387,8 @@ def get_plates(

def edges(
self,
var_names: Iterable[VarName] | None = None,
) -> list[tuple[VarName, VarName]]:
var_names: Iterable[str] | None = None,
) -> list[tuple[str, str]]:
"""Get edges between the variables in the model.

Parameters
Expand All @@ -405,7 +403,7 @@ def edges(

"""
return [
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&")))
(str(child.replace(":", "&")), str(parent.replace(":", "&")))
for child, parents in self.make_compute_graph(var_names=var_names).items()
for parent in parents
]
Expand All @@ -422,7 +420,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
def make_graph(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
edges: list[tuple[str, str]],
formatting: str = "plain",
save=None,
figsize=None,
Expand Down Expand Up @@ -496,7 +494,7 @@ def make_graph(
def make_networkx(
name: str,
plates: list[Plate],
edges: list[tuple[VarName, VarName]],
edges: list[tuple[str, str]],
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
Expand Down Expand Up @@ -566,7 +564,7 @@ def make_networkx(
def model_to_networkx(
model=None,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
node_formatters: NodeTypeFormatterMapping | None = None,
include_dim_lengths: bool = True,
Expand Down Expand Up @@ -660,7 +658,7 @@ def model_to_networkx(
def model_to_graphviz(
model=None,
*,
var_names: Iterable[VarName] | None = None,
var_names: Iterable[str] | None = None,
formatting: str = "plain",
save: str | None = None,
figsize: tuple[int, int] | None = None,
Expand Down
8 changes: 3 additions & 5 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections import namedtuple
from collections.abc import Sequence
from copy import deepcopy
from typing import NewType, cast
from typing import cast

import arviz
import cloudpickle
Expand All @@ -31,8 +31,6 @@

from pymc.exceptions import BlockModelAccessError

VarName = NewType("VarName", str)


class _UnsetType:
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
Expand Down Expand Up @@ -214,9 +212,9 @@ def get_default_varnames(var_iterator, include_transformed):
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]


def get_var_name(var) -> VarName:
def get_var_name(var) -> str:
"""Get an appropriate, plain variable name for a variable."""
return VarName(str(getattr(var, "name", var)))
return var.name if var.name is not None else str(var)


def get_transformed(z):
Expand Down