From 9e8a0f009e7e79608f48bcc1a659d247b641f35f Mon Sep 17 00:00:00 2001 From: atheendre130505 Date: Tue, 28 Oct 2025 14:36:07 +0530 Subject: [PATCH 1/2] Remove VarName from codebase - Remove VarName NewType definition from util.py - Replace all VarName type hints with str - Simplify get_var_name function to use var.name directly - Update imports in model_graph.py and model/core.py - Fix all type annotations and function signatures Resolves #7843 --- pymc/model/core.py | 5 ++--- pymc/model_graph.py | 36 ++++++++++++++++++------------------ pymc/util.py | 8 ++++---- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 18f4190fa4..69e1fbed72 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -67,7 +67,6 @@ ) from pymc.util import ( UNSET, - VarName, WithMemoization, _UnsetType, get_transformed_name, @@ -1968,7 +1967,7 @@ def debug_parameters(rv): def to_graphviz( self, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", save: str | None = None, figsize: tuple[int, int] | None = None, @@ -2172,7 +2171,7 @@ def compile_fn( ) -def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]: +def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]: """Build a point. Uses same args as dict() does. diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 188e40df28..b021baabfe 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -27,7 +27,7 @@ from pymc.model.core import modelcontext from pymc.pytensorf import _cheap_eval_mode -from pymc.util import VarName, get_default_varnames, get_var_name +from pymc.util import get_default_varnames, get_var_name __all__ = ( "ModelGraph", @@ -173,7 +173,7 @@ def default_data(var: Variable) -> GraphvizNodeKwargs: } -def get_node_type(var_name: VarName, model) -> NodeType: +def get_node_type(var_name: str, model) -> NodeType: """Return the node type of the variable in the model.""" v = model[var_name] @@ -242,7 +242,7 @@ def __init__(self, model): self._all_vars = {model[var_name] for var_name in self._all_var_names} self.var_list = self.model.named_vars.values() - def get_parent_names(self, var: Variable) -> set[VarName]: + def get_parent_names(self, var: Variable) -> set[str]: if var.owner is None: return set() @@ -261,12 +261,12 @@ def _expand(x): return x.owner.inputs return { - cast(VarName, ancestor.name) # type: ignore[union-attr] + cast(str, ancestor.name) # type: ignore[union-attr] for ancestor in walk(nodes=var.owner.inputs, expand=_expand) if ancestor in named_vars } - def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]: + def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]: if var_names is None: return self._all_var_names @@ -297,12 +297,12 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa return [get_var_name(var) for var in selected_ancestors] def make_compute_graph( - self, var_names: Iterable[VarName] | None = None - ) -> dict[VarName, set[VarName]]: + self, var_names: Iterable[str] | None = None + ) -> dict[str, set[str]]: """Get map of var_name -> set(input var names) for the model.""" model = self.model named_vars = self._all_vars - input_map: dict[VarName, set[VarName]] = defaultdict(set) + input_map: dict[str, set[str]] = defaultdict(set) var_names_to_plot = self.vars_to_plot(var_names) for var_name in var_names_to_plot: @@ -319,7 +319,7 @@ def make_compute_graph( for ancestor in ancestors([obs_var]): if ancestor not in named_vars: continue - obs_name = cast(VarName, ancestor.name) + obs_name = cast(str, ancestor.name) input_map[var_name].discard(obs_name) input_map[obs_name].add(var_name) @@ -327,7 +327,7 @@ def make_compute_graph( def get_plates( self, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, ) -> list[Plate]: """Rough but surprisingly accurate plate detection. @@ -337,7 +337,7 @@ def get_plates( Returns ------- dict - Maps plate labels to the set of ``VarName``s inside the plate. + Maps plate labels to the set of ``str``s inside the plate. """ plates = defaultdict(set) @@ -389,8 +389,8 @@ def get_plates( def edges( self, - var_names: Iterable[VarName] | None = None, - ) -> list[tuple[VarName, VarName]]: + var_names: Iterable[str] | None = None, + ) -> list[tuple[str, str]]: """Get edges between the variables in the model. Parameters @@ -405,7 +405,7 @@ def edges( """ return [ - (VarName(child.replace(":", "&")), VarName(parent.replace(":", "&"))) + (str(child.replace(":", "&")), str(parent.replace(":", "&"))) for child, parents in self.make_compute_graph(var_names=var_names).items() for parent in parents ] @@ -422,7 +422,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]: def make_graph( name: str, plates: list[Plate], - edges: list[tuple[VarName, VarName]], + edges: list[tuple[str, str]], formatting: str = "plain", save=None, figsize=None, @@ -496,7 +496,7 @@ def make_graph( def make_networkx( name: str, plates: list[Plate], - edges: list[tuple[VarName, VarName]], + edges: list[tuple[str, str]], formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length, @@ -566,7 +566,7 @@ def make_networkx( def model_to_networkx( model=None, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", node_formatters: NodeTypeFormatterMapping | None = None, include_dim_lengths: bool = True, @@ -660,7 +660,7 @@ def model_to_networkx( def model_to_graphviz( model=None, *, - var_names: Iterable[VarName] | None = None, + var_names: Iterable[str] | None = None, formatting: str = "plain", save: str | None = None, figsize: tuple[int, int] | None = None, diff --git a/pymc/util.py b/pymc/util.py index 3f108b8b03..a17c773dd4 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -18,7 +18,7 @@ from collections import namedtuple from collections.abc import Sequence from copy import deepcopy -from typing import NewType, cast +from typing import cast import arviz import cloudpickle @@ -31,7 +31,7 @@ from pymc.exceptions import BlockModelAccessError -VarName = NewType("VarName", str) + class _UnsetType: @@ -214,9 +214,9 @@ def get_default_varnames(var_iterator, include_transformed): return [var for var in var_iterator if not is_transformed_name(get_var_name(var))] -def get_var_name(var) -> VarName: +def get_var_name(var) -> str: """Get an appropriate, plain variable name for a variable.""" - return VarName(str(getattr(var, "name", var))) + return var.name if var.name is not None else str(var) def get_transformed(z): From cdf9e6804ca48baf864d9d9c0e1da895f26b0501 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:46:03 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc/model_graph.py | 4 +--- pymc/util.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pymc/model_graph.py b/pymc/model_graph.py index b021baabfe..221fb715eb 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -296,9 +296,7 @@ def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]: # ordering of self._all_var_names is important return [get_var_name(var) for var in selected_ancestors] - def make_compute_graph( - self, var_names: Iterable[str] | None = None - ) -> dict[str, set[str]]: + def make_compute_graph(self, var_names: Iterable[str] | None = None) -> dict[str, set[str]]: """Get map of var_name -> set(input var names) for the model.""" model = self.model named_vars = self._all_vars diff --git a/pymc/util.py b/pymc/util.py index a17c773dd4..648e478f11 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -32,8 +32,6 @@ from pymc.exceptions import BlockModelAccessError - - class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs."""