diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py
index f628da59b7..ae28fbf15c 100644
--- a/pytensor/link/basic.py
+++ b/pytensor/link/basic.py
@@ -283,6 +283,9 @@ class PerformLinker(LocalLinker):
 
     """
 
+    required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
+    incompatible_rewrites: tuple[str, ...] = ("cxx",)
+
     def __init__(
         self, allow_gc: bool | None = None, schedule: Callable | None = None
     ) -> None:
@@ -584,6 +587,9 @@ class JITLinker(PerformLinker):
 
     """
 
+    required_rewrites: tuple[str, ...] = ("minimum_compile",)
+    incompatible_rewrites: tuple[str, ...] = ()
+
     @abstractmethod
     def fgraph_convert(
         self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py
index bb5f519b01..7e1a779c2e 100644
--- a/pytensor/link/c/basic.py
+++ b/pytensor/link/c/basic.py
@@ -1787,8 +1787,6 @@ class OpWiseCLinker(LocalLinker):
 
     """
 
-    __cache__: dict = {}
-
     def __init__(
         self, fallback_on_perform=True, allow_gc=None, nice_errors=True, schedule=None
     ):
diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py
index a2126855e4..8c9a2dc90f 100644
--- a/pytensor/link/vm.py
+++ b/pytensor/link/vm.py
@@ -812,6 +812,10 @@ class VMLinker(LocalLinker):
 
     """
 
+    # We can only set these correctly in `__init__`, as they depend on `c_thunks`
+    required_rewrites: tuple[str, ...] = ("minimum_compile",)
+    incompatible_rewrites: tuple[str, ...] = ()
+
     def __init__(
         self,
         allow_gc=None,
@@ -834,6 +838,9 @@ def __init__(
         self.lazy = lazy
         if c_thunks is None:
             c_thunks = bool(config.cxx)
+        if not c_thunks:
+            self.required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
+            self.incompatible_rewrites: tuple[str, ...] = ("cxx",)
         self.c_thunks = c_thunks
         self.allow_partial_eval = allow_partial_eval
         self.updated_vars = {}
diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 7e2d8186fd..7521fd3828 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -76,6 +76,7 @@
 from pytensor.graph.type import HasShape
 from pytensor.graph.utils import InconsistencyError, MissingInputError
 from pytensor.link.c.basic import CLinker
+from pytensor.link.vm import VMLinker
 from pytensor.printing import op_debug_information
 from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
 from pytensor.tensor.basic import as_tensor_variable
@@ -884,16 +885,24 @@ def tensorConstructor(shape, dtype):
         self.nit_sot_arg_offset = (
             self.untraced_sit_sot_arg_offset + info.n_untraced_sit_sot_outs
         )
-        # XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
+        # Note: This doesn't include `info.n_nit_sot`s, so it's really a count
         # of the number of outputs generated by taps with inputs
         self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
        self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
 
-        # TODO: These can be moved to thunk/function compilation
-        (
-            _,
-            self.mitmots_preallocated,
-        ) = self._mitmot_preallocations()
+        # Python and Cython perform methods provide the VM with the array location where a mitmot output
+        # should be stored, as a symbolic update. This helper variable is used in the perform method for validation
+        mitmots_preallocated = [False] * info.n_mit_mot_outs
+        if config.scan__allow_output_prealloc:
+            for mitmot_idx in range(info.n_mit_mot):
+                for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
+                    if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
+                        # Figure out the index of the corresponding output
+                        output_idx = sum(
+                            len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
+                        ) + info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
+                        mitmots_preallocated[output_idx] = True
+        self.mitmots_preallocated = tuple(mitmots_preallocated)
 
         self.n_outer_inputs = info.n_outer_inputs
         self.n_outer_outputs = info.n_outer_outputs
@@ -908,39 +917,6 @@ def tensorConstructor(shape, dtype):
         )
         self._hash_inner_graph = hash(self._cmodule_key)
 
-    def _mitmot_preallocations(self):
-        if config.scan__allow_output_prealloc:
-            preallocated_mitmot_outs = []
-
-            info = self.info
-            input_idx = info.n_seqs
-            for mitmot_idx in range(info.n_mit_mot):
-                for inp_tap in info.mit_mot_in_slices[mitmot_idx]:
-                    if inp_tap in info.mit_mot_out_slices[mitmot_idx]:
-                        # Figure out the index of the corresponding output
-                        output_idx = sum(
-                            len(m) for m in info.mit_mot_out_slices[:mitmot_idx]
-                        )
-                        output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap)
-                        preallocated_mitmot_outs.append(output_idx)
-
-                input_idx += 1
-
-            preallocated_mitmot_outs.sort()
-
-        else:
-            # Output preallocation is not activated. Mark every mitmot output
-            # tap as not being preallocated
-            preallocated_mitmot_outs = []
-
-        # Store the list of mitmot output taps that have been altered so they
-        # can be preallocated
-        mitmots_preallocated = [
-            i in preallocated_mitmot_outs for i in range(info.n_mit_mot_outs)
-        ]
-
-        return preallocated_mitmot_outs, mitmots_preallocated
-
     def __setstate__(self, d):
         self.__dict__.update(d)
         # Ensure that the graph associated with the inner function is valid.
@@ -999,8 +975,8 @@ def make_node(self, *inputs):
 
         if n_outer_ins != n_inner_ins:
             raise ValueError(
-                "The number of inputs given to the inner function of scan"
-                " does not match the number of inputs given to scan."
+                f"The number of inputs given to the inner function of scan ({n_inner_ins}) "
+                f"does not match the number of inputs given to scan ({n_outer_ins})."
             )
 
         # Force the inputs to be on the CPU
@@ -1483,11 +1459,26 @@ def fn(self):
 
         # Clone mode_instance, altering "allow_gc" for the linker,
         # and adding a message if we profile
-        mode_instance = get_mode(self.mode).clone(
-            link_kwargs=dict(allow_gc=self.allow_gc),
-            message=f"{self.name or 'Scan'} sub profile",
-        )
-
+        mode = self.mode
+        if mode in (None, "FAST_RUN"):
+            mode_instance = Mode("cvm", "fast_run")
+        elif mode == "FAST_COMPILE":
+            mode_instance = Mode(
+                VMLinker(use_cloop=False, c_thunks=False), "fast_compile"
+            )
+        else:
+            mode_instance = get_mode(mode).clone(
+                link_kwargs=dict(allow_gc=self.allow_gc),
+                message=f"{self.name or 'Scan'} sub profile",
+            )
+        # Scan's Python and Cython perform methods rely on the VM being able to set updates
+        # for preallocated MIT-MOTs, which only the VMs produced by VMLinker do
+        if any(self.mitmots_preallocated) and not isinstance(
+            mode_instance.linker, VMLinker
+        ):
+            raise NotImplementedError(
+                f"Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker, got {mode_instance.linker}"
+            )
         self._fn = pfunc(
             wrapped_inputs,
             wrapped_outputs,
@@ -2007,6 +1998,9 @@ def perform(self, node, inputs, output_storage):
                 new_var = inner_input_storage[inner_inp_idx].storage[0]
                 if old_var is new_var:
                     old_data = old_mitmot_input_data[mitmot_inp_idx]
+                    # This check is only valid if the VM performs the updates.
+                    # Otherwise the output value may remain the same as the input,
+                    # but that does not mean it has been set up correctly.
                     same_data = new_var.data == old_data
                 else:
                     same_data = False
@@ -2051,14 +2045,8 @@ def perform(self, node, inputs, output_storage):
                     old_data = old_inner_output_data[offset_out + j]
                     if old_data is None:
                         output_reused = False
-                    elif isinstance(
-                        self.fn.maker.fgraph.outputs[offset_out + j], TensorVariable
-                    ):
-                        output_reused = new_var.data == old_data
                     else:
-                        raise RuntimeError(
-                            "FIXME: output_reused = new_var.gpudata == old_data"
-                        )
+                        output_reused = new_var.data == old_data
             else:
                 output_reused = False
 
diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py
index 48c2a6c4a0..e5d8bf35e7 100644
--- a/tests/compile/test_mode.py
+++ b/tests/compile/test_mode.py
@@ -56,11 +56,12 @@ def test_NoOutputFromInplace():
 
 def test_including():
     mode = Mode(linker="py", optimizer="merge")
-    assert set(mode._optimizer.include) == {"minimum_compile", "merge"}
+    assert set(mode._optimizer.include) == {"minimum_compile", "py_only", "merge"}
 
     new_mode = mode.including("fast_compile")
     assert set(new_mode._optimizer.include) == {
         "minimum_compile",
+        "py_only",
         "merge",
         "fast_compile",
     }
diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py
index b34e6ced28..d4f3e1bde1 100644
--- a/tests/scan/test_basic.py
+++ b/tests/scan/test_basic.py
@@ -34,10 +34,12 @@
 from pytensor.graph.rewriting.basic import MergeOptimizer
 from pytensor.graph.traversal import ancestors
 from pytensor.graph.utils import MissingInputError
+from pytensor.link.vm import VMLinker
 from pytensor.raise_op import assert_op
 from pytensor.scan.basic import scan
-from pytensor.scan.op import Scan
+from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import until
+from pytensor.tensor import as_tensor
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
 from pytensor.tensor.math import sum as pt_sum
@@ -4308,3 +4310,91 @@ def test_return_updates_api_change():
 
     with pytest.raises(ValueError, match=err_msg):
         scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)
+
+
+@pytest.mark.parametrize(
+    "scan_mode",
+    [
+        None,
+        "FAST_RUN",
+        "FAST_COMPILE",
+        Mode("cvm", optimizer=None),
+        Mode("vm", optimizer=None),
+        Mode("c", optimizer=None),
+        Mode("py", optimizer=None),
+    ],
+)
+def test_scan_mode_compatibility(scan_mode):
+    # Regression test for case where using Scan with a non-updating VM failed
+
+    # Build a scan with one sequence and two MIT-MOTs
+    info = ScanInfo(
+        n_seqs=1,
+        mit_mot_in_slices=((0, 1), (0, 1)),
+        mit_mot_out_slices=((1,), (1,)),
+        mit_sot_in_slices=(),
+        sit_sot_in_slices=(),
+        n_nit_sot=0,
+        n_untraced_sit_sot_outs=0,
+        n_non_seqs=0,
+        as_while=False,
+    )
+    bool_seq = pt.scalar(dtype="bool")
+    mitmot_A0, mitmot_A1, mitmot_B0, mitmot_B1 = [
+        pt.matrix(shape=(2, 2)) for i in range(4)
+    ]
+    inputs = [
+        bool_seq,
+        mitmot_A0,
+        mitmot_A1,
+        mitmot_B0,
+        mitmot_B1,
+    ]
+    outputs = [
+        pt.add(bool_seq + mitmot_A0, mitmot_A1),
+        pt.add(bool_seq * mitmot_B0, mitmot_B1),
+    ]
+
+    scan_op = Scan(
+        inputs,
+        outputs,
+        info=info,
+        mode=scan_mode,
+    )
+
+    n_steps = 5
+    numerical_inputs = [
+        np.array(n_steps, dtype="int64"),
+        np.array([1, 1, 0, 1, 0], dtype="bool"),
+        np.zeros(n_steps + 1)[:, None, None] * np.eye(2),
+        np.arange(n_steps + 1)[:, None, None] * np.eye(2),
+    ]
+    tensor_inputs = [as_tensor(inp, dtype=inp.dtype).type() for inp in numerical_inputs]
+    tensor_outputs = [o.sum() for o in scan_op(*tensor_inputs)]
+
+    no_opt_mode = Mode(linker="py", optimizer=None)
+    # NotImplementedError should only be triggered when we try to compile the function
+    if (
+        # Abstract modes should never fail
+        scan_mode not in (None, "FAST_RUN", "FAST_COMPILE")
+        # Only if the user tries something specific and incompatible
+        and not isinstance(get_mode(scan_mode).linker, VMLinker)
+    ):
+        with pytest.raises(
+            NotImplementedError,
+            match="Python/Cython implementation of Scan with preallocated MIT-MOT outputs requires a VMLinker",
+        ):
+            function(tensor_inputs, tensor_outputs, mode=no_opt_mode)
+        return
+
+    fn = function(tensor_inputs, tensor_outputs, mode=no_opt_mode)
+
+    # Check we have the expected Scan in the compiled function
+    [fn_scan_op] = [
+        node.op for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
+    ]
+    assert fn_scan_op.info == info
+    assert fn_scan_op.mitmots_preallocated == (True, True)
+
+    # Expected value computed by running correct Scan once
+    np.testing.assert_allclose(fn(*numerical_inputs), [44, 38])
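
Note (illustration only, not part of the patch): the MIT-MOT preallocation bookkeeping inlined into `Scan.__init__` above marks every input tap that is also an output tap, mapping it to a flat output index. Below is a minimal standalone sketch of that index arithmetic, reusing the slice data from `test_scan_mode_compatibility`; the `enumerate(zip(...))` restructuring is the sketch's own, not the patch's code.

# Standalone sketch of the mitmots_preallocated bookkeeping (plain Python, no pytensor).
mit_mot_in_slices = ((0, 1), (0, 1))  # input taps per MIT-MOT, as in the test
mit_mot_out_slices = ((1,), (1,))     # output taps per MIT-MOT, as in the test
n_mit_mot_outs = sum(len(s) for s in mit_mot_out_slices)

mitmots_preallocated = [False] * n_mit_mot_outs
for mitmot_idx, (in_taps, out_taps) in enumerate(
    zip(mit_mot_in_slices, mit_mot_out_slices)
):
    for inp_tap in in_taps:
        if inp_tap in out_taps:
            # Flat index = number of outputs of earlier MIT-MOTs
            # plus the position of this tap within the current MIT-MOT
            output_idx = sum(
                len(s) for s in mit_mot_out_slices[:mitmot_idx]
            ) + out_taps.index(inp_tap)
            mitmots_preallocated[output_idx] = True

print(tuple(mitmots_preallocated))  # (True, True), as asserted in the test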