From 556a6e20e4751a4309f3674882fc7f1863d48d91 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Thu, 3 Apr 2025 23:09:20 +0200 Subject: [PATCH 1/2] Check all ancestors for MinibatchOp --- pymc/model/core.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 87b4df23f..c6e53fa36 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -34,7 +34,7 @@ from pytensor.compile import DeepCopyOp, Function, get_mode from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Constant, Variable, graph_inputs +from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -1241,15 +1241,13 @@ def register_rv( self.add_named_variable(rv_var, dims) self.set_initval(rv_var, initval) else: - if ( - isinstance(observed, TensorVariable) - and observed.owner is not None - and isinstance(observed.owner.op, MinibatchOp) - and total_size is None - ): - warnings.warn( - f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" - ) + if total_size is None and isinstance(node, TensorVariable): + for node in ancestors([observed]): + if node.owner is not None and isinstance(node.owner.op, MinibatchOp): + warnings.warn( + f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" + ) + break if not is_valid_observed(observed): raise TypeError( "Variables that depend on other nodes cannot be used for observed data." From e955f4473d6a440cc7d5d1e1f65d876d983b9c93 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Thu, 3 Apr 2025 23:45:24 +0200 Subject: [PATCH 2/2] Fix --- pymc/model/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index c6e53fa36..b85cc802f 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1241,7 +1241,7 @@ def register_rv( self.add_named_variable(rv_var, dims) self.set_initval(rv_var, initval) else: - if total_size is None and isinstance(node, TensorVariable): + if total_size is None and isinstance(observed, TensorVariable): for node in ancestors([observed]): if node.owner is not None and isinstance(node.owner.op, MinibatchOp): warnings.warn(