2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/basic.py
@@ -281,7 +281,7 @@ def numba_typify(data, dtype=None, **kwargs):
return data


def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
def generate_fallback_impl(op, node, storage_map=None, **kwargs):
Member
Why?

Member Author
Because it's not really optional; the code immediately assumes `node` is passed (and is not None).

"""Create a Numba compatible function from a Pytensor `Op`."""

warnings.warn(
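Note on the pattern involved: the removed object-mode branch in the Repeat dispatcher below builds its output signature from node.outputs, and the fallback needs the same information, which is why `node` can no longer default to None. A toy illustration of that object-mode pattern (not PyTensor's actual implementation), using only standard Numba APIs:

import numpy as np
import numba

# Minimal sketch: drop to object mode for a call that Numba's nopython mode does
# not support (np.repeat with an axis); the output type must be declared up front.
@numba.njit
def repeat_axis0(x, repeats):
    with numba.objmode(out="float64[:, :]"):  # signature known in advance
        out = np.repeat(x, repeats, axis=0)
    return out

print(repeat_axis0(np.ones((2, 3)), 2).shape)  # (4, 3)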
48 changes: 13 additions & 35 deletions pytensor/link/numba/dispatch/extra_ops.py
@@ -6,7 +6,11 @@

from pytensor.graph import Apply
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
get_numba_type,
numba_funcify,
)
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import (
@@ -200,45 +204,19 @@ def ravelmultiindex(*inp):
@numba_funcify.register(Repeat)
def numba_funcify_Repeat(op, node, **kwargs):
axis = op.axis
a, _ = node.inputs

use_python = False

if axis is not None:
use_python = True

if use_python:
warnings.warn(
(
"Numba will use object mode to allow the "
"`axis` argument to `numpy.repeat`."
),
UserWarning,
)

ret_sig = get_numba_type(node.outputs[0].type)
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
if axis == 0 and a.type.ndim == 1:

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def repeatop(x, repeats):
with numba.objmode(ret=ret_sig):
ret = np.repeat(x, repeats, axis)
return ret
return np.repeat(x, repeats)

else:
repeats_ndim = node.inputs[1].ndim
return repeatop

if repeats_ndim == 0:

@numba_basic.numba_njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats.item())

else:

@numba_basic.numba_njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats)

return repeatop
else:
return generate_fallback_impl(op, node)


@numba_funcify.register(Unique)
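Quick check of the fast path above: for a 1-D input, repeating along axis=0 is identical to the flattened (axis=None) behaviour, which per the comment in the diff is the only form Numba's np.repeat supports in nopython mode.

import numpy as np

x = np.array([1, 2, 3])
repeats = np.array([2, 0, 3])
# axis=0 on a vector is the same as the default axis=None behaviour
assert np.array_equal(np.repeat(x, repeats, axis=0), np.repeat(x, repeats))
print(np.repeat(x, repeats))  # [1 1 3 3 3]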
173 changes: 76 additions & 97 deletions pytensor/tensor/extra_ops.py
@@ -668,53 +668,72 @@ class Repeat(Op):

__props__ = ("axis",)

def __init__(self, axis: int | None = None):
if axis is not None:
if not isinstance(axis, int) or axis < 0:
def __init__(self, axis: int):
if isinstance(axis, int):
if axis < 0:
raise ValueError(
f"Repeat only accepts positive integer axis or None, got {axis}"
f"Repeat Op only accepts positive integer axis, got {axis}. "
"Use the helper `pt.repeat` to handle negative axis."
)
elif axis is None:
raise ValueError(
"Repeat Op only accepts positive integer axis. "
"Use the helper `pt.repeat` to handle axis=None."
)
else:
raise TypeError(
f"Invalid type for axis {axis}, expected int got {type(axis)}"
)

self.axis = axis

def make_node(self, x, repeats):
x = ptb.as_tensor_variable(x)
repeats = ptb.as_tensor_variable(repeats, dtype="int64")

if repeats.dtype not in integer_dtypes:
raise TypeError("repeats.dtype must be an integer.")
if repeats.type.ndim != 1:
if repeats.type.ndim == 0:
raise ValueError(
f"repeats {repeats} must have 1 dimension, got 0. Use the helper `pt.repeat` to handle scalar repeats."
)
else:
raise ValueError(
f"repeats {repeats} must have 1 dimension, got {repeats.type.ndim}"
)

if repeats.type.dtype not in integer_dtypes:
raise TypeError(
f"repeats {repeats} dtype must be an integer, got {repeats.type.dtype}."
)

# Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
ptr_bitwidth = LOCAL_BITWIDTH
if ptr_bitwidth == 64:
numpy_unsupported_dtypes = ("uint64",)
if ptr_bitwidth == 32:
numpy_unsupported_dtypes = ("uint32", "int64", "uint64")

if repeats.dtype in numpy_unsupported_dtypes:
numpy_unsupported_dtypes = (
("uint64",) if LOCAL_BITWIDTH == 64 else ("uint64", "uint32", "int64")
)
if repeats.type.dtype in numpy_unsupported_dtypes:
raise TypeError(
(
f"dtypes {numpy_unsupported_dtypes!s} are not supported by numpy.repeat "
"for the 'repeats' parameter, "
),
repeats.dtype,
f"repeats {repeats} dtype {repeats.type.dtype} are not supported by numpy.repeat"
)

if self.axis is None:
out_shape = [None]
else:
shape = list(x.type.shape)
axis_input_dim_length = shape[self.axis]
axis_output_dim_length = None

if axis_input_dim_length is not None:
# If we have a static dim and constant repeats we can infer the length of the output dim
# Right now we only support homogeneous constant repeats
try:
const_reps = ptb.get_scalar_constant_value(repeats)
const_reps = ptb.get_underlying_scalar_constant_value(repeats)
except NotScalarConstantError:
const_reps = None
if const_reps == 1:
out_shape = x.type.shape
pass
else:
out_shape = list(x.type.shape)
out_shape[self.axis] = None
axis_output_dim_length = int(const_reps * axis_input_dim_length)

shape[self.axis] = axis_output_dim_length

out_type = TensorType(x.dtype, shape=out_shape)
out_type = TensorType(x.dtype, shape=shape)
return Apply(self, [x, repeats], [out_type()])
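The arithmetic behind the new static-shape inference, checked numerically: with homogeneous constant repeats r along an axis of known length n, the output length along that axis is r * n, so it can be fixed at graph-construction time.

import numpy as np

x = np.arange(4)
r = 3
# Homogeneous repeats: output length is r * n
assert np.repeat(x, np.full(x.shape[0], r)).shape[0] == r * x.shape[0]  # 12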

def perform(self, node, inputs, output_storage):
@@ -728,36 +747,16 @@ def grad(self, inputs, gout):
(x, repeats) = inputs
(gz,) = gout
axis = self.axis
if repeats.ndim == 0:
# When axis is a scalar (same number of reps for all elements),
# We can split the repetitions into their own axis with reshape and sum them back
# to the original element location
sum_axis = x.ndim if axis is None else axis + 1
shape = list(x.shape)
shape.insert(sum_axis, repeats)
gx = gz.reshape(shape).sum(axis=sum_axis)

elif repeats.ndim == 1:
# To sum the gradients that belong to the same repeated x,
# We create a repeated eye and dot product it with the gradient.
axis_size = x.size if axis is None else x.shape[axis]
repeated_eye = repeat(
ptb.eye(axis_size), repeats, axis=0
) # A sparse repeat would be neat

if axis is None:
gx = gz @ repeated_eye
# Undo the ravelling when axis=None
gx = gx.reshape(x.shape)
else:
# Place gradient axis at end for dot product
gx = ptb.moveaxis(gz, axis, -1)
gx = gx @ repeated_eye
# Place gradient back into the correct axis
gx = ptb.moveaxis(gx, -1, axis)

else:
raise ValueError()
# Use IncSubtensor to sum the gradients that belong to the repeated entries of x
axis_size = x.shape[axis]
repeated_arange = repeat(ptb.arange(axis_size), repeats, axis=0)

# Move the axis to repeat to front for easier indexing
x_transpose = ptb.moveaxis(x, axis, 0)
gz_transpose = ptb.moveaxis(gz, axis, 0)
gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
gx = ptb.moveaxis(gx_transpose, 0, axis)

return [gx, DisconnectedType()()]
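Numerical sketch of the gradient trick above: each input element's gradient is the sum of the output-gradient entries that were repeated from it. The repeated arange maps every output position back to its source index, and an indexed increment (inc_subtensor in PyTensor; np.add.at in the NumPy analogue below) performs the summation.

import numpy as np

x = np.array([1.0, 2.0, 3.0])
repeats = np.array([2, 0, 3])
gz = np.arange(float(repeats.sum()))                         # some upstream gradient, length 5
repeated_arange = np.repeat(np.arange(x.shape[0]), repeats)  # [0, 0, 2, 2, 2]

gx = np.zeros_like(x)
np.add.at(gx, repeated_arange, gz)  # accumulate gz into the source positions
print(gx)  # [1. 0. 9.]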

@@ -771,22 +770,8 @@ def infer_shape(self, fgraph, node, ins_shapes):
dtype = None
if repeats.dtype in ("uint8", "uint16", "uint32"):
dtype = "int64"
if axis is None:
if repeats.ndim == 0:
if len(i0_shapes) == 0:
out_shape = [repeats]
else:
res = 1
for d in i0_shapes:
res = res * d
out_shape = (res * repeats,)
else:
out_shape = [pt_sum(repeats, dtype=dtype)]
else:
if repeats.ndim == 0:
out_shape[axis] = out_shape[axis] * repeats
else:
out_shape[axis] = pt_sum(repeats, dtype=dtype)

out_shape[axis] = pt_sum(repeats, dtype=dtype)
return [out_shape]
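infer_shape in one line, checked numerically: with vector repeats, the output length along the repeated axis is simply the sum of the repeat counts.

import numpy as np

x = np.arange(5)
repeats = np.array([0, 1, 2, 3, 4])
assert np.repeat(x, repeats).shape[0] == repeats.sum()  # 10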


@@ -851,48 +836,42 @@ def repeat(
"""
a = ptb.as_tensor_variable(a)

if axis is not None:
if axis is None:
axis = 0
a = a.flatten()
else:
axis = normalize_axis_index(axis, a.ndim)

repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)

if repeats.ndim > 1:
raise ValueError("The dimension of repeats should not exceed 1.")

if repeats.ndim == 1 and not repeats.broadcastable[0]:
if repeats.type.broadcastable == (True,):
# This behaves the same as scalar repeat
repeats = repeats.squeeze()

if repeats.ndim == 1:
# We only use the Repeat Op for vector repeats
return Repeat(axis=axis)(a, repeats)
else:
if repeats.ndim == 1:
repeats = repeats[0]

if a.dtype == "uint64":
# Multiplying int64 (shape) by uint64 (repeats) yields a float64
# Which is not valid for the `reshape` operation at the end
raise TypeError("repeat doesn't support dtype uint64")

if axis is None:
axis = 0
a = a.flatten()

repeat_shape = list(a.shape)
# Scalar repeat, we implement this with canonical Ops broadcast + reshape
a_shape = a.shape

# alloc_shape is the shape of the intermediate tensor which has
# an additional dimension comparing to x. We use alloc to
# allocate space for this intermediate tensor to replicate x
# along that additional dimension.
alloc_shape = repeat_shape[:]
alloc_shape.insert(axis + 1, repeats)
# Replicate a along a new axis (axis+1) repeats times
broadcast_shape = list(a_shape)
broadcast_shape.insert(axis + 1, repeats)
broadcast_a = broadcast_to(ptb.expand_dims(a, axis + 1), broadcast_shape)

# repeat_shape is now the shape of output, where shape[axis] becomes
# shape[axis]*repeats.
# Reshape broadcast_a to the final shape, merging axis and axis+1
repeat_shape = list(a_shape)
repeat_shape[axis] = repeat_shape[axis] * repeats

# After the original tensor is duplicated along the additional
# dimension, we reshape it to the expected output shape
return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
repeat_shape
)
return broadcast_a.reshape(repeat_shape)
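Numerical check of the scalar-repeat path above: inserting a new axis right after `axis`, broadcasting it to length `repeats`, and merging the two axes reproduces np.repeat with a scalar count. The last assert shows where the uint64 restriction comes from: NumPy promotes int64 combined with uint64 to float64, which is not a valid reshape dimension.

import numpy as np

a = np.arange(6).reshape(2, 3)
axis, repeats = 1, 4

# Replicate along a new axis (axis + 1) `repeats` times...
broadcast_shape = list(a.shape)
broadcast_shape.insert(axis + 1, repeats)
broadcast_a = np.broadcast_to(np.expand_dims(a, axis + 1), broadcast_shape)

# ...then merge axis and axis + 1 into the output shape
out_shape = list(a.shape)
out_shape[axis] *= repeats
assert np.array_equal(broadcast_a.reshape(out_shape), np.repeat(a, repeats, axis=axis))

# int64 (shape) combined with uint64 (repeats) promotes to float64
assert np.result_type(np.int64, np.uint64) == np.float64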


class Bartlett(Op):
1 change: 0 additions & 1 deletion pytensor/utils.py
@@ -34,7 +34,6 @@

Note that according to Python documentation, `platform.architecture()` is
not reliable on OS X with universal binaries.
Also, sys.maxsize does not exist in Python < 2.6.
'P' denotes a void*, and the size is expressed in bytes.
"""

18 changes: 3 additions & 15 deletions tests/link/numba/test_extra_ops.py
@@ -212,27 +212,15 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
@pytest.mark.parametrize(
"x, repeats, axis, exc",
[
(
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lscalar(), np.array(0, dtype="int64")),
None,
None,
),
(
(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
None,
None,
),
(
(pt.lvector(), np.arange(2, dtype="int64")),
(pt.lvector(), np.array([1, 1], dtype="int64")),
None,
(pt.lvector(), np.array([1, 3], dtype="int64")),
0,
None,
),
(
(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lvector(), np.array([1, 3], dtype="int64")),
0,
UserWarning,
),
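A sketch of the fast path the first updated case exercises (assuming Numba is installed and the dispatch behaviour shown above): a vector input with axis=0 compiles to Numba's native np.repeat, while the matrix input in the second case takes the object-mode fallback and is expected to emit the UserWarning checked there.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.lvector("x")
reps = pt.lvector("reps")
fn = pytensor.function([x, reps], pt.repeat(x, reps, axis=0), mode="NUMBA")
print(fn(np.arange(2, dtype="int64"), np.array([1, 3], dtype="int64")))  # [0 1 1 1]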