Skip to content

Commit

Permalink
Return unrecognized random state unmodified (#2622)
Browse files Browse the repository at this point in the history
* return unrecognized random state unmodified

* lint

* document type and fix docstring

* edit doc

* doc
  • Loading branch information
kevinsung committed Feb 12, 2020
1 parent 6ba76b8 commit 1b37885
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
41 changes: 35 additions & 6 deletions cirq/value/random_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast, Optional, Union
from typing import cast, Any

import numpy as np

RANDOM_STATE_LIKE = Optional[Union[np.random.RandomState, int]]
from cirq._doc import document

RANDOM_STATE_LIKE = Any
document(
RANDOM_STATE_LIKE,
"""A pseudorandom number generator or object that can be converted to one.
If None, turns into the module `np.random`.
If an integer, turns into a `np.random.RandomState` seeded with that
integer.
If none of the above, it is used unmodified. In this case, it is assumed
that the object implements whatever methods are required for the use case
at hand. For example, it might be an existing instance of
`np.random.RandomState` or a custom pseudorandom number generator
implementation.
""")


def parse_random_state(random_state: RANDOM_STATE_LIKE
) -> np.random.RandomState:
"""Interpret an object as a pseudorandom number generator.
If `random_state` is None, returns the module `np.random`.
If `random_state` is an integer, returns
`np.random.RandomState(random_state)`.
Otherwise, returns `random_state` unmodified.
Args:
random_state: The object to be used as or converted to a pseudorandom
number generator.
Returns:
The pseudorandom number generator object.
"""
if random_state is None:
return cast(np.random.RandomState, np.random)
elif (isinstance(random_state, np.random.RandomState) or
random_state == np.random):
return cast(np.random.RandomState, random_state)
elif isinstance(random_state, int):
return np.random.RandomState(random_state)
raise TypeError(f'Argument must be of type cirq.value.RANDOM_STATE_LIKE.')
else:
return cast(np.random.RandomState, random_state)
4 changes: 0 additions & 4 deletions cirq/value/random_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import numpy as np
import pytest

import cirq

Expand Down Expand Up @@ -43,6 +42,3 @@ def rand(prng):
vals = [prng.rand() for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)

with pytest.raises(TypeError):
cirq.value.parse_random_state('random_state')

0 comments on commit 1b37885

Please sign in to comment.