Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/link/numba/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytensor.link.numba.dispatch.sparse
import pytensor.link.numba.dispatch.subtensor
import pytensor.link.numba.dispatch.tensor_basic
import pytensor.link.numba.dispatch.typed_list


# isort: on
8 changes: 7 additions & 1 deletion pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytensor.tensor.random.type import RandomGeneratorType
from pytensor.tensor.type import TensorType
from pytensor.tensor.utils import hash_from_ndarray
from pytensor.typed_list import TypedListType


def _filter_numba_warnings():
Expand Down Expand Up @@ -132,6 +133,8 @@ def get_numba_type(
return CSCMatrixType(numba_dtype)
elif isinstance(pytensor_type, RandomGeneratorType):
return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
elif isinstance(pytensor_type, TypedListType):
return numba.types.List(get_numba_type(pytensor_type.ttype))
else:
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")

Expand Down Expand Up @@ -260,7 +263,10 @@ def numba_typify(data, dtype=None, **kwargs):


def generate_fallback_impl(op, node, storage_map=None, **kwargs):
"""Create a Numba compatible function from a Pytensor `Op`."""
"""Create a Numba compatible function from a Pytensor `Op`.

Note limitations: https://numba.pydata.org/numba-doc/dev/user/withobjmode.html#the-objmode-context-manager
"""

warnings.warn(
f"Numba will use object mode to run {op}'s perform method. "
Expand Down
47 changes: 34 additions & 13 deletions pytensor/link/numba/dispatch/compile_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from copy import deepcopy
from hashlib import sha256

import numba
import numpy as np

from pytensor.compile.builders import OpFromGraph
Expand All @@ -15,7 +17,34 @@
register_funcify_default_op_cache_key,
)
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.type import TensorType


def numba_deepcopy(x):
return deepcopy(x)


@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_tensor(x):
if isinstance(x, numba.types.Number):

def number_deepcopy(x):
return x

return number_deepcopy

if isinstance(x, numba.types.Array):

def array_deepcopy(x):
return np.copy(x)

return array_deepcopy

if isinstance(x, numba.types.UnicodeType):

def string_deepcopy(x):
return x

return string_deepcopy


@register_funcify_and_cache_key(OpFromGraph)
Expand Down Expand Up @@ -64,19 +93,11 @@ def identity(x):

@register_funcify_default_op_cache_key(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType):

@numba_basic.numba_njit
def deepcopy(x):
return np.copy(x)

else:

@numba_basic.numba_njit
def deepcopy(x):
return x
@numba_basic.numba_njit
def deepcopy(x):
return numba_deepcopy(x)

return deepcopy
return deepcopy, 1


@register_funcify_default_op_cache_key(IfElse)
Expand Down
21 changes: 11 additions & 10 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from copy import copy, deepcopy
from copy import deepcopy
from functools import singledispatch
from hashlib import sha256
from textwrap import dedent
Expand All @@ -20,6 +20,7 @@
numba_funcify,
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
Expand All @@ -35,16 +36,16 @@
from pytensor.tensor.utils import _parse_gufunc_signature


@overload(copy)
def copy_NumPyRandomGenerator(rng):
def impl(rng):
# TODO: Open issue on Numba?
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(rng)
@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_random_generator(x):
if isinstance(x, numba.types.NumPyRandomGeneratorType):

return new_rng
def random_generator_deepcopy(x):
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(x)
return new_rng

return impl
return random_generator_deepcopy


@singledispatch
Expand Down Expand Up @@ -449,7 +450,7 @@ def random(core_shape, rng, size, *dist_params):
def ov_random(core_shape, rng, size, *dist_params):
def impl(core_shape, rng, size, *dist_params):
if not inplace:
rng = copy(rng)
rng = numba_deepcopy(rng)

draws = _vectorized(
core_op_fn,
Expand Down
13 changes: 13 additions & 0 deletions pytensor/link/numba/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
Expand Down Expand Up @@ -104,6 +105,18 @@ def in_seq_empty_tuple(x, y):
enable_slice_boxing()


@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_slice(x):
if isinstance(x, types.SliceType):

def deepcopy_slice(x):
return slice(
numba_deepcopy(x.start), numba_deepcopy(x.stop), numba_deepcopy(x.step)
)

return deepcopy_slice


@register_funcify_default_op_cache_key(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_basic.numba_njit
Expand Down
Loading
Loading