diff --git a/pymc/model/core.py b/pymc/model/core.py index 87b4df23f..b85cc802f 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -34,7 +34,7 @@ from pytensor.compile import DeepCopyOp, Function, get_mode from pytensor.compile.sharedvalue import SharedVariable -from pytensor.graph.basic import Constant, Variable, graph_inputs +from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -1241,15 +1241,13 @@ def register_rv( self.add_named_variable(rv_var, dims) self.set_initval(rv_var, initval) else: - if ( - isinstance(observed, TensorVariable) - and observed.owner is not None - and isinstance(observed.owner.op, MinibatchOp) - and total_size is None - ): - warnings.warn( - f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" - ) + if total_size is None and isinstance(observed, TensorVariable): + for node in ancestors([observed]): + if node.owner is not None and isinstance(node.owner.op, MinibatchOp): + warnings.warn( + f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" + ) + break if not is_valid_observed(observed): raise TypeError( "Variables that depend on other nodes cannot be used for observed data."