Patch summary (free-text commentary placed ahead of the first file header).

Purpose
    Stop using copy.deepcopy to clone numpy.random.Generator objects and use a
    new helper, custom_rng_deepcopy, instead. The in-patch NOTE gives the
    motivation: deepcopy of a Generator is slow, and NumPy may offer a faster
    clone/copy API in the future (numpy/numpy issue 24086).

Changes
    pytensor/tensor/random/utils.py
        Adds custom_rng_deepcopy(rng). It builds a new bit generator of the
        same type as rng.bit_generator from a deepcopy of the old bit
        generator's _seed_seq, assigns the old bit generator's state to it,
        and returns it wrapped in numpy.random.Generator.
    pytensor/tensor/random/op.py
        The perform hunk calls the helper instead of deepcopy when
        self.inplace is false; the now-unused deepcopy import is removed.
    pytensor/link/numba/dispatch/random.py
        The objmode block in copy_NumPyRandomGenerator calls the helper
        instead of deepcopy; the import line keeps only copy.
    tests/tensor/random/test_utils.py
        Adds two tests that compare the helper against deepcopy for
        default_rng(123): equal bit-generator state and seed-sequence state,
        and matching output for ten normal draws.

Review notes (observations only; the patch content below is unchanged)
    - The helper reads the private attribute _seed_seq, while the new test
      reads the public seed_seq attribute. TODO confirm which of the two the
      supported NumPy versions guarantee.
    - The helper always returns a plain numpy.random.Generator, so a Generator
      subclass passed in would not round-trip. Presumably not a concern for
      current callers; verify.
    - The draw-comparison test uses np.allclose; identical generator state
      would be expected to give bit-identical draws, so an exact comparison
      would be a stricter check.
    - Neither test checks that drawing from the copy leaves the original
      generator untouched.

diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py
index 36618ceb26..ba714acbc2 100644
--- a/pytensor/link/numba/dispatch/random.py
+++ b/pytensor/link/numba/dispatch/random.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from copy import copy, deepcopy
+from copy import copy
 from functools import singledispatch
 from textwrap import dedent
 
@@ -25,6 +25,7 @@
 )
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
+from pytensor.tensor.random.utils import custom_rng_deepcopy
 from pytensor.tensor.type_other import NoneTypeT
 from pytensor.tensor.utils import _parse_gufunc_signature
 
@@ -34,7 +35,7 @@ def copy_NumPyRandomGenerator(rng):
     def impl(rng):
         # TODO: Open issue on Numba?
         with numba.objmode(new_rng=types.npy_rng):
-            new_rng = deepcopy(rng)
+            new_rng = custom_rng_deepcopy(rng)
 
         return new_rng
 
diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py
index 6891823576..ab15ca5649 100644
--- a/pytensor/tensor/random/op.py
+++ b/pytensor/tensor/random/op.py
@@ -1,7 +1,6 @@
 import abc
 import warnings
 from collections.abc import Sequence
-from copy import deepcopy
 from typing import Any, cast
 
 import numpy as np
@@ -23,6 +22,7 @@
 from pytensor.tensor.random.type import RandomGeneratorType, RandomType
 from pytensor.tensor.random.utils import (
     compute_batch_shape,
+    custom_rng_deepcopy,
     explicit_expand_dims,
     normalize_size_param,
 )
@@ -421,7 +421,7 @@ def perform(self, node, inputs, outputs):
 
         # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
         if not self.inplace:
-            rng = deepcopy(rng)
+            rng = custom_rng_deepcopy(rng)
 
         outputs[0][0] = rng
         outputs[1][0] = np.asarray(
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 86628a81cb..c34530eb0d 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -1,10 +1,12 @@
 from collections.abc import Callable, Sequence
+from copy import deepcopy
 from functools import wraps
 from itertools import zip_longest
 from types import ModuleType
 from typing import TYPE_CHECKING
 
 import numpy as np
+from numpy.random import Generator
 
 from pytensor.compile.sharedvalue import shared
 from pytensor.graph.basic import Constant, Variable
@@ -201,6 +203,17 @@ def normalize_size_param(
     return shape
 
 
+# NOTE:
+# This helper exists because copying numpy.random.Generator via deepcopy is slow.
+# NumPy may implement a faster clone/copy API in the future:
+# https://github.com/numpy/numpy/issues/24086
+def custom_rng_deepcopy(rng):
+    old_bitgen = rng.bit_generator
+    new_bitgen = type(old_bitgen)(deepcopy(old_bitgen._seed_seq))
+    new_bitgen.state = old_bitgen.state
+    return Generator(new_bitgen)
+
+
 class RandomStream:
     """Module component with similar interface to `numpy.random.Generator`.
 
diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py
index aa761d2922..d6ecec1d16 100644
--- a/tests/tensor/random/test_utils.py
+++ b/tests/tensor/random/test_utils.py
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 import numpy as np
 import pytest
 
@@ -7,6 +9,7 @@
 from pytensor.tensor.random.utils import (
     RandomStream,
     broadcast_params,
+    custom_rng_deepcopy,
     supp_shape_from_ref_param_shape,
 )
 from pytensor.tensor.type import matrix, tensor
@@ -327,3 +330,28 @@ def test_supp_shape_from_ref_param_shape():
         ref_param_idx=1,
     )
     assert res == (3, 4)
+
+
+def test_custom_rng_deepcopy_matches_deepcopy():
+    rng = np.random.default_rng(123)
+
+    dp = deepcopy(rng).bit_generator
+    fc = custom_rng_deepcopy(rng).bit_generator
+
+    # Same state
+    assert dp.state == fc.state
+    # Same seed sequence
+    assert dp.seed_seq.state == fc.seed_seq.state
+
+
+def test_custom_rng_deepcopy_output_identical():
+    rng = np.random.default_rng(123)
+
+    rng1 = deepcopy(rng)
+    rng2 = custom_rng_deepcopy(rng)
+
+    # Generate numbers from each
+    x1 = rng1.normal(size=10)
+    x2 = rng2.normal(size=10)
+
+    assert np.allclose(x1, x2)