diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index 52a4e6f305..3b35cc2cd5 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -1331,7 +1331,11 @@ def printstuff(self): # the external requirements of the .linker attribute of a mode # 1) it's a class instance # 2) it a has a .clone() method +# 3) it has required_rewrites and incompatible_rewrites class attributes class _DummyLinker: + required_rewrites = () + incompatible_rewrites = () + # This is not a real linker anyway def clone(self, allow_gc=None): return self diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 7d8afb9136..8cf612bff9 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -352,7 +352,14 @@ def __setstate__(self, state): if isinstance(optimizer, str) or optimizer is None: optimizer = predefined_optimizers[optimizer] if isinstance(optimizer, RewriteDatabaseQuery): + # TODO: From the __init__ signature this should always be the case + # But some tests and internal logic allow passing a GraphRewriter directly as optimizer + # Cleanup! 
self.provided_optimizer = optimizer + if r := linker.required_rewrites: + optimizer = optimizer.including(*r) + if r := linker.incompatible_rewrites: + optimizer = optimizer.excluding(*r) self._optimizer = optimizer self.call_time = 0 self.fn_time = 0 @@ -365,14 +372,13 @@ def __str__(self): f"optdb={self.optdb})" ) - def __get_optimizer(self): + @property + def optimizer(self): if isinstance(self._optimizer, RewriteDatabaseQuery): return self.optdb.query(self._optimizer) else: return self._optimizer - optimizer = property(__get_optimizer) - def get_linker_optimizer(self, linker, optimizer): if isinstance(linker, str) or linker is None: linker = predefined_linkers[linker] @@ -466,61 +472,21 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): NUMBA = Mode( NumbaLinker(), - RewriteDatabaseQuery( - include=["fast_run", "numba"], - exclude=[ - "cxx_only", - "BlasOpt", - "local_careduce_fusion", - "scan_save_mem_prealloc", - ], - ), + RewriteDatabaseQuery(include=["fast_run", "numba"]), ) JAX = Mode( JAXLinker(), - RewriteDatabaseQuery( - include=["fast_run", "jax"], - exclude=[ - "cxx_only", - "BlasOpt", - "fusion", - "inplace", - "scan_save_mem_prealloc", - # There are specific variants for the LU decompositions supported by JAX - "reuse_lu_decomposition_multiple_solves", - "scan_split_non_sequence_lu_decomposition_solve", - ], - ), + RewriteDatabaseQuery(include=["fast_run", "jax"]), ) PYTORCH = Mode( PytorchLinker(), - RewriteDatabaseQuery( - include=["fast_run"], - exclude=[ - "cxx_only", - "BlasOpt", - "fusion", - "inplace", - "scan_save_mem_prealloc", - "reuse_lu_decomposition_multiple_solves", - "scan_split_non_sequence_lu_decomposition_solve", - ], - ), + RewriteDatabaseQuery(include=["fast_run"]), ) MLX = Mode( MLXLinker(), - RewriteDatabaseQuery( - include=["fast_run"], - exclude=[ - "cxx_only", - "BlasOpt", - "fusion", - "inplace", - "scan_save_mem_prealloc", - ], - ), + RewriteDatabaseQuery(include=["fast_run"]), ) diff --git 
a/pytensor/link/basic.py b/pytensor/link/basic.py index 1c79d28289..f628da59b7 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -157,6 +157,9 @@ class Linker(ABC): the FunctionGraph. """ + required_rewrites: tuple[str, ...] = ("minimum_compile",) + incompatible_rewrites: tuple[str, ...] = () + def __init__( self, *, @@ -656,21 +659,37 @@ def create_jitable_thunk( thunk_outputs = [storage_map[n] for n in self.fgraph.outputs] fgraph_jit = self.jit_compile(converted_fgraph) - def thunk( - fgraph_jit=fgraph_jit, - thunk_inputs=thunk_inputs, - thunk_outputs=thunk_outputs, - ): - try: - outputs = fgraph_jit(*(x[0] for x in thunk_inputs)) - except Exception: - # TODO: Should we add a fake node that combines all outputs, - # since the error may come from any of them? - raise_with_op(self.fgraph, output_nodes[0], thunk) + if thunk_outputs: - # zip strict not specified because we are in a hot loop - for o_storage, o_val in zip(thunk_outputs, outputs): - o_storage[0] = o_val + def thunk( + fgraph_jit=fgraph_jit, + thunk_inputs=thunk_inputs, + thunk_outputs=thunk_outputs, + ): + try: + outputs = fgraph_jit(*(x[0] for x in thunk_inputs)) + except Exception: + # TODO: Should we add a fake node that combines all outputs, + # since the error may come from any of them? 
+ raise_with_op(self.fgraph, output_nodes[0], thunk) + + # zip strict not specified because we are in a hot loop + for o_storage, o_val in zip(thunk_outputs, outputs): + o_storage[0] = o_val + + else: + # Edge case - functions without outputs + def thunk( + fgraph_jit=fgraph_jit, + thunk_inputs=thunk_inputs, + thunk_outputs=thunk_outputs, + ): + try: + res = fgraph_jit(*(x[0] for x in thunk_inputs)) + except Exception: + raise_with_op(self.fgraph, output_nodes[0], thunk) + assert res is None + return thunk_outputs thunk.inputs = thunk_inputs thunk.outputs = thunk_outputs @@ -714,3 +733,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None): thunks, nodes, ) + + def __repr__(self): + # Assumes no subclass needs init arguments + return f"{self.__class__.__name__}()" diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index dd634e630c..6abf467824 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -9,6 +9,22 @@ class JAXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using JAX.""" + required_rewrites = ( + "minimum_compile", + "jax", + ) # TODO: Distinguish between optional "jax" and "minimum_compile_jax" + incompatible_rewrites = ( + "cxx", + "BlasOpt", + "local_careduce_fusion", + "scan_save_mem_prealloc", + # JAX does its own inplace optimization + "inplace", + # There are specific variants for the LU decompositions supported by JAX + "reuse_lu_decomposition_multiple_solves", + "scan_split_non_sequence_lu_decomposition_solve", + ) + + scalar_shape_inputs: tuple[int, ...]
def __init__(self, *args, **kwargs): diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index b2f8674ea5..bccb38de7d 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -4,6 +4,14 @@ class MLXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX.""" + incompatible_rewrites = ( + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "scan_save_mem_prealloc", + ) + def __init__(self, use_compile=True, *args, **kwargs): super().__init__(*args, **kwargs) self.gen_functors = [] diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 7cc6f34454..068df4a95b 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -2,6 +2,17 @@ class NumbaLinker: """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" + required_rewrites = ( + "minimum_compile", + "numba", + ) # TODO: Distinguish between optional "numba" and "minimum_compile_numba" + incompatible_rewrites = ( + "cxx", + "BlasOpt", + "local_careduce_fusion", + "scan_save_mem_prealloc", + ) + def fgraph_convert(self, fgraph, **kwargs): diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 18824a5b71..8c62e3577f 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -5,6 +5,16 @@ class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" + incompatible_rewrites = ( + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "scan_save_mem_prealloc", + "reuse_lu_decomposition_multiple_solves", + "scan_split_non_sequence_lu_decomposition_solve", + ) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.gen_functors = [] diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index 291eac0782..48c2a6c4a0 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -55,11 +55,15 @@ def test_NoOutputFromInplace(): def
test_including(): - mode = Mode(optimizer="merge") - assert set(mode._optimizer.include) == {"merge"} + mode = Mode(linker="py", optimizer="merge") + assert set(mode._optimizer.include) == {"minimum_compile", "merge"} new_mode = mode.including("fast_compile") - assert set(new_mode._optimizer.include) == {"merge", "fast_compile"} + assert set(new_mode._optimizer.include) == { + "minimum_compile", + "merge", + "fast_compile", + } class TestBunchOfModes: