Skip to content

Commit

Permalink
copy random states when combining random variables (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanWenger committed Sep 7, 2020
1 parent 080a380 commit 832fb25
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/probnum/utils/randomutils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from typing import Union
import copy

import numpy as np


def derive_random_seed(*args: Union[np.random.RandomState, np.random.Generator]) -> int:
def _sample(rng: Union[np.random.RandomState, np.random.Generator]) -> int:
if isinstance(rng, np.random.RandomState):
return rng.randint(0, 2 ** 32, size=None, dtype=int)
return copy.copy(rng).randint(0, 2 ** 32, size=None, dtype=int)
elif isinstance(rng, np.random.Generator):
return rng.integers(0, 2 ** 32, size=None, dtype=int, endpoint=False)
return copy.copy(rng).integers(
0, 2 ** 32, size=None, dtype=int, endpoint=False
)
else:
raise ValueError("Unsupported type of random number generator")

Expand Down
2 changes: 2 additions & 0 deletions tests/test_utils/test_fctutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


class TestAssertsEvaluatesToScalar(unittest.TestCase):
"""Test case for utility functions dealing with functions."""

def test_assert_evaluates_to_scalar_pass(self):
def fct(x):
return np.linalg.norm(x)
Expand Down
46 changes: 46 additions & 0 deletions tests/test_utils/test_randomutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import unittest

import numpy as np

from tests.testing import NumpyAssertions
from probnum.utils import randomutils


class RandomUtilsTestCase(unittest.TestCase, NumpyAssertions):
"""Test case for utility functions handling objects dealing with randomness."""

def setUp(self) -> None:
self.seed = 42
self.random_state = np.random.RandomState(seed=self.seed)
self.random_generator = np.random.default_rng(seed=self.seed)
self.random_generator_list = [np.random.default_rng(seed=s) for s in range(5)]

def test_derive_random_seed_invariant_random_state(self):
"""
Test whether deriving a random seed leaves the original random states
untouched.
"""
# Original random states
rs_state = self.random_state.get_state()[1]
rng_state = self.random_generator.bit_generator.state["state"]["state"]
rng_list_states = [
rng.bit_generator.state["state"]["state"]
for rng in self.random_generator_list
]

# Combine RandomState and Generator object
_ = randomutils.derive_random_seed(self.random_state, self.random_generator)
self.assertArrayEqual(rs_state, self.random_state.get_state()[1])
self.assertEqual(
rng_state, self.random_generator.bit_generator.state["state"]["state"]
)

# Combine list of generators
_ = randomutils.derive_random_seed(*self.random_generator_list)
self.assertArrayEqual(
rng_list_states,
[
rng.bit_generator.state["state"]["state"]
for rng in self.random_generator_list
],
)

0 comments on commit 832fb25

Please sign in to comment.