-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
copy random states when combining random variables (#207)
- Loading branch information
1 parent
080a380
commit 832fb25
Showing
3 changed files
with
53 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from tests.testing import NumpyAssertions | ||
from probnum.utils import randomutils | ||
|
||
|
||
class RandomUtilsTestCase(unittest.TestCase, NumpyAssertions): | ||
"""Test case for utility functions handling objects dealing with randomness.""" | ||
|
||
def setUp(self) -> None: | ||
self.seed = 42 | ||
self.random_state = np.random.RandomState(seed=self.seed) | ||
self.random_generator = np.random.default_rng(seed=self.seed) | ||
self.random_generator_list = [np.random.default_rng(seed=s) for s in range(5)] | ||
|
||
def test_derive_random_seed_invariant_random_state(self): | ||
""" | ||
Test whether deriving a random seed leaves the original random states | ||
untouched. | ||
""" | ||
# Original random states | ||
rs_state = self.random_state.get_state()[1] | ||
rng_state = self.random_generator.bit_generator.state["state"]["state"] | ||
rng_list_states = [ | ||
rng.bit_generator.state["state"]["state"] | ||
for rng in self.random_generator_list | ||
] | ||
|
||
# Combine RandomState and Generator object | ||
_ = randomutils.derive_random_seed(self.random_state, self.random_generator) | ||
self.assertArrayEqual(rs_state, self.random_state.get_state()[1]) | ||
self.assertEqual( | ||
rng_state, self.random_generator.bit_generator.state["state"]["state"] | ||
) | ||
|
||
# Combine list of generators | ||
_ = randomutils.derive_random_seed(*self.random_generator_list) | ||
self.assertArrayEqual( | ||
rng_list_states, | ||
[ | ||
rng.bit_generator.state["state"]["state"] | ||
for rng in self.random_generator_list | ||
], | ||
) |