Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from copy import copy, deepcopy
from copy import copy
from functools import singledispatch
from textwrap import dedent

Expand All @@ -25,6 +25,7 @@
)
from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.utils import custom_rng_deepcopy
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature

Expand All @@ -34,7 +35,7 @@ def copy_NumPyRandomGenerator(rng):
def impl(rng):
# TODO: Open issue on Numba?
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(rng)
new_rng = custom_rng_deepcopy(rng)

return new_rng

Expand Down
4 changes: 2 additions & 2 deletions pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, cast

import numpy as np
Expand All @@ -23,6 +22,7 @@
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
custom_rng_deepcopy,
explicit_expand_dims,
normalize_size_param,
)
Expand Down Expand Up @@ -421,7 +421,7 @@ def perform(self, node, inputs, outputs):

# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
if not self.inplace:
rng = deepcopy(rng)
rng = custom_rng_deepcopy(rng)

outputs[0][0] = rng
outputs[1][0] = np.asarray(
Expand Down
13 changes: 13 additions & 0 deletions pytensor/tensor/random/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Callable, Sequence
from copy import deepcopy
from functools import wraps
from itertools import zip_longest
from types import ModuleType
from typing import TYPE_CHECKING

import numpy as np
from numpy.random import Generator

from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable
Expand Down Expand Up @@ -201,6 +203,17 @@ def normalize_size_param(
return shape


# NOTE:
# This helper exists because copying numpy.random.Generator via deepcopy is slow.
# NumPy may implement a faster clone/copy API in the future:
# https://github.com/numpy/numpy/issues/24086
def custom_rng_deepcopy(rng):
old_bitgen = rng.bit_generator
new_bitgen = type(old_bitgen)(deepcopy(old_bitgen._seed_seq))
new_bitgen.state = old_bitgen.state
return Generator(new_bitgen)


class RandomStream:
"""Module component with similar interface to `numpy.random.Generator`.

Expand Down
40 changes: 40 additions & 0 deletions tests/tensor/random/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import timeit
from copy import deepcopy

import numpy as np
import pytest

Expand All @@ -7,6 +10,7 @@
from pytensor.tensor.random.utils import (
RandomStream,
broadcast_params,
custom_rng_deepcopy,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.type import matrix, tensor
Expand Down Expand Up @@ -327,3 +331,39 @@ def test_supp_shape_from_ref_param_shape():
ref_param_idx=1,
)
assert res == (3, 4)


def test_custom_rng_deepcopy_matches_deepcopy():
rng = np.random.default_rng(123)

dp = deepcopy(rng).bit_generator
fc = custom_rng_deepcopy(rng).bit_generator

# Same state
assert dp.state == fc.state
# Same seed sequence
assert dp.seed_seq.state == fc.seed_seq.state


def test_custom_rng_deepcopy_output_identical():
rng = np.random.default_rng(123)

rng1 = deepcopy(rng)
rng2 = custom_rng_deepcopy(rng)

# Generate numbers from each
x1 = rng1.normal(size=10)
x2 = rng2.normal(size=10)

assert np.allclose(x1, x2)


@pytest.mark.performance
def test_custom_rng_deepcopy_faster_than_deepcopy():
rng = np.random.default_rng()

t_dp = timeit.timeit(lambda: deepcopy(rng), number=2000)
t_fc = timeit.timeit(lambda: custom_rng_deepcopy(rng), number=2000)

# Fast copy should be at least 20% faster
assert t_fc < t_dp * 0.8
Comment on lines +361 to +369
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need, and we have a different framework for benchmarking

Loading