From dab15cc78b86968401bf0c082b870f75527a7fad Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 27 Nov 2025 14:28:31 +0100
Subject: [PATCH 1/4] Change TypedList Count and Index output to int64

---
 pytensor/typed_list/basic.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py
index 08ac20e5a2..49e83cbebe 100644
--- a/pytensor/typed_list/basic.py
+++ b/pytensor/typed_list/basic.py
@@ -2,11 +2,10 @@
 
 import pytensor.tensor as pt
 from pytensor.compile.debugmode import _lessbroken_deepcopy
-from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
-from pytensor.tensor.type import scalar
+from pytensor.tensor.type import lscalar
 from pytensor.tensor.type_other import SliceType
 from pytensor.tensor.variable import TensorVariable
 from pytensor.typed_list.type import TypedListType
@@ -508,7 +507,7 @@ class Index(Op):
     def make_node(self, x, elem):
         assert isinstance(x.type, TypedListType)
         assert x.ttype == elem.type
-        return Apply(self, [x, elem], [scalar()])
+        return Apply(self, [x, elem], [lscalar()])
 
     def perform(self, node, inputs, outputs):
         """
@@ -520,7 +519,7 @@ def perform(self, node, inputs, outputs):
         (out,) = outputs
         for y in range(len(x)):
             if node.inputs[0].ttype.values_eq(x[y], elem):
-                out[0] = np.asarray(y, dtype=config.floatX)
+                out[0] = np.asarray(y, dtype="int64")
                 break
 
     def __str__(self):
@@ -537,7 +536,7 @@ class Count(Op):
     def make_node(self, x, elem):
         assert isinstance(x.type, TypedListType)
         assert x.ttype == elem.type
-        return Apply(self, [x, elem], [scalar()])
+        return Apply(self, [x, elem], [lscalar()])
 
     def perform(self, node, inputs, outputs):
         """
@@ -551,7 +550,7 @@ def perform(self, node, inputs, outputs):
         for y in range(len(x)):
             if node.inputs[0].ttype.values_eq(x[y], elem):
                 out[0] += 1
-        out[0] = np.asarray(out[0], dtype=config.floatX)
+        out[0] = np.asarray(out[0], "int64")
 
     def __str__(self):
         return self.__class__.__name__
@@ -583,7 +582,7 @@ class Length(COp):
 
     def make_node(self, x):
         assert isinstance(x.type, TypedListType)
-        return Apply(self, [x], [scalar(dtype="int64")])
+        return Apply(self, [x], [lscalar()])
 
     def perform(self, node, x, outputs):
         (out,) = outputs

From 3a65fab2c271d76b0ec2f66bcfaa5cb503a0ef28 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 27 Nov 2025 15:06:33 +0100
Subject: [PATCH 2/4] len(TypedList) cannot be symbolic

---
 pytensor/typed_list/basic.py   | 13 +++++++++----
 tests/typed_list/test_basic.py |  9 +++++++--
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/pytensor/typed_list/basic.py b/pytensor/typed_list/basic.py
index 49e83cbebe..188581b2c9 100644
--- a/pytensor/typed_list/basic.py
+++ b/pytensor/typed_list/basic.py
@@ -12,12 +12,14 @@
 
 
 class _typed_list_py_operators:
+    def __len__(self):
+        raise TypeError(
+            "Cannot call len(TypedList). Use `.length()` method instead for the symbolic equivalent."
+        )
+
     def __getitem__(self, index):
         return getitem(self, index)
 
-    def __len__(self):
-        return length(self)
-
     def append(self, toAppend):
         return append(self, toAppend)
 
@@ -36,10 +38,13 @@ def reverse(self):
     def count(self, elem):
         return count(self, elem)
 
-    # name "index" is already used by an attribute
+    # name "index" is claimed as an attribute of PyTensor Variable(s)
     def ind(self, elem):
         return index_(self, elem)
 
+    def length(self):
+        return length(self)
+
     ttype = property(lambda self: self.type.ttype)
     dtype = property(lambda self: self.type.ttype.dtype)
     ndim = property(lambda self: self.type.ttype.ndim + 1)
diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py
index 8a9db1ae4b..feae2361a9 100644
--- a/tests/typed_list/test_basic.py
+++ b/tests/typed_list/test_basic.py
@@ -1,3 +1,5 @@
+import re
+
 import numpy as np
 import pytest
 import scipy
@@ -551,8 +553,11 @@ def test_interface(self):
         mySymbolicMatricesList = TypedListType(
             TensorType(pytensor.config.floatX, shape=(None, None))
         )()
-        z = mySymbolicMatricesList.__len__()
-
+        with pytest.raises(
+            TypeError, match=re.escape("Use `.length()` method instead")
+        ):
+            len(mySymbolicMatricesList)
+        z = mySymbolicMatricesList.length()
         f = pytensor.function([mySymbolicMatricesList], z)
 
         x = rand_ranged_matrix(-1000, 1000, [100, 101])

From c54bef1e76700b8cf87f9291422d4842615468f0 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 25 Nov 2025 13:32:52 +0100
Subject: [PATCH 3/4] Implement extensible deepcopy in numba

---
 pytensor/link/numba/dispatch/compile_ops.py | 47 +++++++++++++++------
 pytensor/link/numba/dispatch/random.py      | 21 ++++-----
 pytensor/link/numba/dispatch/subtensor.py   | 13 ++++++
 3 files changed, 58 insertions(+), 23 deletions(-)

diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py
index 266fa07d74..8eb73d0111 100644
--- a/pytensor/link/numba/dispatch/compile_ops.py
+++ b/pytensor/link/numba/dispatch/compile_ops.py
@@ -1,5 +1,7 @@
+from copy import deepcopy
 from hashlib import sha256
 
+import numba
 import numpy as np
 
 from pytensor.compile.builders import OpFromGraph
@@ -15,7 +17,34 @@
     register_funcify_default_op_cache_key,
 )
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.type import TensorType
+
+
+def numba_deepcopy(x):
+    return deepcopy(x)
+
+
+@numba.extending.overload(numba_deepcopy)
+def numba_deepcopy_tensor(x):
+    if isinstance(x, numba.types.Number):
+
+        def number_deepcopy(x):
+            return x
+
+        return number_deepcopy
+
+    if isinstance(x, numba.types.Array):
+
+        def array_deepcopy(x):
+            return np.copy(x)
+
+        return array_deepcopy
+
+    if isinstance(x, numba.types.UnicodeType):
+
+        def string_deepcopy(x):
+            return x
+
+        return string_deepcopy
 
 
 @register_funcify_and_cache_key(OpFromGraph)
@@ -64,19 +93,11 @@ def identity(x):
 
 @register_funcify_default_op_cache_key(DeepCopyOp)
 def numba_funcify_DeepCopyOp(op, node, **kwargs):
-    if isinstance(node.inputs[0].type, TensorType):
-
-        @numba_basic.numba_njit
-        def deepcopy(x):
-            return np.copy(x)
-
-    else:
-
-        @numba_basic.numba_njit
-        def deepcopy(x):
-            return x
+    @numba_basic.numba_njit
+    def deepcopy(x):
+        return numba_deepcopy(x)
 
-    return deepcopy
+    return deepcopy, 1
 
 
 @register_funcify_default_op_cache_key(IfElse)
diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index a20881db7a..bb3ffdff5c 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from copy import copy, deepcopy
+from copy import deepcopy
 from functools import singledispatch
 from hashlib import sha256
 from textwrap import dedent
@@ -20,6 +20,7 @@
     numba_funcify,
     register_funcify_and_cache_key,
 )
+from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
 from pytensor.link.numba.dispatch.vectorize_codegen import (
     _jit_options,
     _vectorized,
@@ -35,16 +36,16 @@
 from pytensor.tensor.utils import _parse_gufunc_signature
 
 
-@overload(copy)
-def copy_NumPyRandomGenerator(rng):
-    def impl(rng):
-        # TODO: Open issue on Numba?
-        with numba.objmode(new_rng=types.npy_rng):
-            new_rng = deepcopy(rng)
+@numba.extending.overload(numba_deepcopy)
+def numba_deepcopy_random_generator(x):
+    if isinstance(x, numba.types.NumPyRandomGeneratorType):
 
-        return new_rng
+        def random_generator_deepcopy(x):
+            with numba.objmode(new_rng=types.npy_rng):
+                new_rng = deepcopy(x)
+            return new_rng
 
-    return impl
+        return random_generator_deepcopy
 
 
 @singledispatch
@@ -449,7 +450,7 @@ def random(core_shape, rng, size, *dist_params):
     def ov_random(core_shape, rng, size, *dist_params):
         def impl(core_shape, rng, size, *dist_params):
             if not inplace:
-                rng = copy(rng)
+                rng = numba_deepcopy(rng)
 
             draws = _vectorized(
                 core_op_fn,
diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index c7cc4cfd8e..51787daf41 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -18,6 +18,7 @@
     register_funcify_and_cache_key,
     register_funcify_default_op_cache_key,
 )
+from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
 from pytensor.tensor import TensorType
 from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.subtensor import (
@@ -104,6 +105,18 @@ def in_seq_empty_tuple(x, y):
 enable_slice_boxing()
 
 
+@numba.extending.overload(numba_deepcopy)
+def numba_deepcopy_slice(x):
+    if isinstance(x, types.SliceType):
+
+        def deepcopy_slice(x):
+            return slice(
+                numba_deepcopy(x.start), numba_deepcopy(x.stop), numba_deepcopy(x.step)
+            )
+
+        return deepcopy_slice
+
+
 @register_funcify_default_op_cache_key(MakeSlice)
 def numba_funcify_MakeSlice(op, **kwargs):
     @numba_basic.numba_njit

From 88857674993910222b9ae0013da8fcca8fa62093 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 25 Nov 2025 13:34:05 +0100
Subject: [PATCH 4/4] Add support for TypedList in numba backend

Note: Numba object mode fallback is not safe with lists
---
 pytensor/link/numba/dispatch/__init__.py   |   1 +
 pytensor/link/numba/dispatch/basic.py      |   8 +-
 pytensor/link/numba/dispatch/typed_list.py | 219 +++++++++++++++++++++
 tests/link/numba/test_typed_list.py        |  46 +++++
 4 files changed, 273 insertions(+), 1 deletion(-)
 create mode 100644 pytensor/link/numba/dispatch/typed_list.py
 create mode 100644 tests/link/numba/test_typed_list.py

diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py
index 76751deb31..17d630bd2f 100644
--- a/pytensor/link/numba/dispatch/__init__.py
+++ b/pytensor/link/numba/dispatch/__init__.py
@@ -17,6 +17,7 @@
 import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
 import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.typed_list
 
 # isort: on
 
diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index 07fa376699..5f18b9561f 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -23,6 +23,7 @@
 from pytensor.tensor.random.type import RandomGeneratorType
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.utils import hash_from_ndarray
+from pytensor.typed_list import TypedListType
 
 
 def _filter_numba_warnings():
@@ -132,6 +133,8 @@ def get_numba_type(
         return CSCMatrixType(numba_dtype)
     elif isinstance(pytensor_type, RandomGeneratorType):
         return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
+    elif isinstance(pytensor_type, TypedListType):
+        return numba.types.List(get_numba_type(pytensor_type.ttype))
     else:
         raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
 
@@ -260,7 +263,10 @@ def numba_typify(data, dtype=None, **kwargs):
 
 
 def generate_fallback_impl(op, node, storage_map=None, **kwargs):
-    """Create a Numba compatible function from a Pytensor `Op`."""
+    """Create a Numba compatible function from a Pytensor `Op`.
+
+    Note limitations: https://numba.pydata.org/numba-doc/dev/user/withobjmode.html#the-objmode-context-manager
+    """
 
     warnings.warn(
         f"Numba will use object mode to run {op}'s perform method. "
diff --git a/pytensor/link/numba/dispatch/typed_list.py b/pytensor/link/numba/dispatch/typed_list.py
new file mode 100644
index 0000000000..7ab45c1306
--- /dev/null
+++ b/pytensor/link/numba/dispatch/typed_list.py
@@ -0,0 +1,219 @@
+import numba
+import numpy as np
+
+import pytensor.link.numba.dispatch.basic as numba_basic
+from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
+from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
+from pytensor.tensor.type_other import SliceType
+from pytensor.typed_list import (
+    Append,
+    Count,
+    Extend,
+    GetItem,
+    Index,
+    Insert,
+    Length,
+    MakeList,
+    Remove,
+    Reverse,
+)
+
+
+def numba_all_equal(x, y):
+    if isinstance(x, np.ndarray) or isinstance(y, np.ndarray):
+        if not (isinstance(x, np.ndarray) and isinstance(y, np.ndarray)):
+            return False
+        return (x == y).all()
+    if isinstance(x, list) or isinstance(y, list):
+        if not (isinstance(x, list) and isinstance(y, list)):
+            return False
+        if len(x) != len(y):
+            return False
+        return all(numba_all_equal(xi, yi) for xi, yi in zip(x, y))
+    return x == y
+
+
+@numba.extending.overload(numba_all_equal)
+def list_all_equal(x, y):
+    all_equal = None
+
+    if isinstance(x, numba.types.List) and isinstance(y, numba.types.List):
+
+        def all_equal(x, y):
+            if len(x) != len(y):
+                return False
+            for xi, yi in zip(x, y):
+                if not numba_all_equal(xi, yi):
+                    return False
+            return True
+
+    if isinstance(x, numba.types.Array) and isinstance(y, numba.types.Array):
+
+        def all_equal(x, y):
+            return (x == y).all()
+
+    if isinstance(x, numba.types.Number) and isinstance(y, numba.types.Number):
+
+        def all_equal(x, y):
+            return x == y
+
+    return all_equal
+
+
+@numba.extending.overload(numba_deepcopy)
+def numba_deepcopy_list(x):
+    if isinstance(x, numba.types.List):
+
+        def deepcopy_list(x):
+            return [numba_deepcopy(xi) for xi in x]
+
+        return deepcopy_list
+
+
+@register_funcify_default_op_cache_key(MakeList)
+def numba_funcify_make_list(op, node, **kwargs):
+    @numba_basic.numba_njit
+    def make_list(*args):
+        return [numba_deepcopy(arg) for arg in args]
+
+    return make_list
+
+
+@register_funcify_default_op_cache_key(Length)
+def numba_funcify_list_length(op, node, **kwargs):
+    @numba_basic.numba_njit
+    def list_length(x):
+        return np.array(len(x), dtype=np.int64)
+
+    return list_length
+
+
+@register_funcify_default_op_cache_key(GetItem)
+def numba_funcify_list_get_item(op, node, **kwargs):
+    if isinstance(node.inputs[1].type, SliceType):
+
+        @numba_basic.numba_njit
+        def list_get_item_slice(x, index):
+            return x[index]
+
+        return list_get_item_slice
+
+    else:
+
+        @numba_basic.numba_njit
+        def list_get_item_index(x, index):
+            return x[index.item()]
+
+        return list_get_item_index
+
+
+@register_funcify_default_op_cache_key(Reverse)
+def numba_funcify_list_reverse(op, node, **kwargs):
+    inplace = op.inplace
+
+    @numba_basic.numba_njit
+    def list_reverse(x):
+        if inplace:
+            z = x
+        else:
+            z = numba_deepcopy(x)
+        z.reverse()
+        return z
+
+    return list_reverse
+
+
+@register_funcify_default_op_cache_key(Append)
+def numba_funcify_list_append(op, node, **kwargs):
+    inplace = op.inplace
+
+    @numba_basic.numba_njit
+    def list_append(x, to_append):
+        if inplace:
+            z = x
+        else:
+            z = numba_deepcopy(x)
+        z.append(numba_deepcopy(to_append))
+        return z
+
+    return list_append
+
+
+@register_funcify_default_op_cache_key(Extend)
+def numba_funcify_list_extend(op, node, **kwargs):
+    inplace = op.inplace
+
+    @numba_basic.numba_njit
+    def list_extend(x, to_append):
+        if inplace:
+            z = x
+        else:
+            z = numba_deepcopy(x)
+        z.extend(numba_deepcopy(to_append))
+        return z
+
+    return list_extend
+
+
+@register_funcify_default_op_cache_key(Insert)
+def numba_funcify_list_insert(op, node, **kwargs):
+    inplace = op.inplace
+
+    @numba_basic.numba_njit
+    def list_insert(x, index, to_insert):
+        if inplace:
+            z = x
+        else:
+            z = numba_deepcopy(x)
+        z.insert(index.item(), numba_deepcopy(to_insert))
+        return z
+
+    return list_insert
+
+
+@register_funcify_default_op_cache_key(Index)
+def numba_funcify_list_index(op, node, **kwargs):
+    @numba_basic.numba_njit
+    def list_index(x, elem):
+        for idx, xi in enumerate(x):
+            if numba_all_equal(xi, elem):
+                return np.array(idx, dtype=np.int64)
+        raise ValueError("list.index(x): x not in list")
+
+    return list_index
+
+
+@register_funcify_default_op_cache_key(Count)
+def numba_funcify_list_count(op, node, **kwargs):
+    @numba_basic.numba_njit
+    def list_count(x, elem):
+        c = 0
+        for xi in x:
+            if numba_all_equal(xi, elem):
+                c += 1
+        return np.array(c, dtype=np.int64)
+
+    return list_count
+
+
+@register_funcify_default_op_cache_key(Remove)
+def numba_funcify_list_remove(op, node, **kwargs):
+    inplace = op.inplace
+
+    @numba_basic.numba_njit
+    def list_remove(x, to_remove):
+        if inplace:
+            z = x
+        else:
+            z = numba_deepcopy(x)
+        index_to_remove = -1
+        for i, zi in enumerate(z):
+            if numba_all_equal(zi, to_remove):
+                index_to_remove = i
+                break
+        if index_to_remove == -1:
+            raise ValueError("list.remove(x): x not in list")
+        z.pop(index_to_remove)
+        return z
+
+    return list_remove
diff --git a/tests/link/numba/test_typed_list.py b/tests/link/numba/test_typed_list.py
new file mode 100644
index 0000000000..78fa09f30c
--- /dev/null
+++ b/tests/link/numba/test_typed_list.py
@@ -0,0 +1,46 @@
+import numpy as np
+
+from pytensor.tensor import matrix
+from pytensor.typed_list import make_list
+from tests.link.numba.test_basic import compare_numba_and_py
+
+
+def test_list_basic_ops():
+    x = matrix("x", shape=(3, None), dtype="int64")
+    l = make_list([x[0], x[2]])
+
+    x_test = np.arange(12).reshape(3, 4)
+    compare_numba_and_py([x], [l, l.length()], [x_test])
+
+    # Test nested list
+    ll = make_list([l, l, l])
+    compare_numba_and_py([x], [ll, ll.length()], [x_test])
+
+
+def test_make_list_index_ops():
+    x = matrix("x", shape=(3, None), dtype="int64")
+    l = make_list([x[0], x[2]])
+
+    x_test = np.arange(12).reshape(3, 4)
+    compare_numba_and_py([x], [l[-1], l[:-1], l.reverse()], [x_test])
+
+
+def test_make_list_extend_ops():
+    x = matrix("x", shape=(3, None), dtype="int64")
+    l = make_list([x[0], x[2]])
+
+    x_test = np.arange(12).reshape(3, 4)
+    compare_numba_and_py(
+        [x], [l.append(x[1]), l.extend(l), l.insert(0, x[1])], [x_test]
+    )
+
+
+def test_make_list_find_ops():
+    # Remove requires to first find it
+    x = matrix("x", shape=(3, None), dtype="int64")
+    y = x[0].type("y")
+    l = make_list([x[0], x[2], x[0], x[2]])
+
+    x_test = np.arange(12).reshape(3, 4)
+    test_y = x_test[2]
+    compare_numba_and_py([x, y], [l.ind(y), l.count(y), l.remove(y)], [x_test, test_y])