diff --git a/cirq/__init__.py b/cirq/__init__.py index b71ca7cfdb6..c39209b4b42 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -374,6 +374,8 @@ ) from cirq.study import ( + dict_to_product_sweep, + dict_to_zip_sweep, ExpressionMap, flatten, flatten_with_params, diff --git a/cirq/study/__init__.py b/cirq/study/__init__.py index 265c3267321..a9ddb6ca3da 100644 --- a/cirq/study/__init__.py +++ b/cirq/study/__init__.py @@ -42,6 +42,8 @@ Sweep, UnitSweep, Zip, + dict_to_product_sweep, + dict_to_zip_sweep, ) from cirq.study.trial_result import ( diff --git a/cirq/study/sweepable.py b/cirq/study/sweepable.py index b26ff4b577c..a40416b88ba 100644 --- a/cirq/study/sweepable.py +++ b/cirq/study/sweepable.py @@ -14,12 +14,13 @@ """Defines which types are Sweepable.""" -from typing import Dict, Iterable, Iterator, List, Union, cast -import itertools +from typing import Dict, Iterable, Iterator, List, Sequence, Union, cast +import warnings from cirq._doc import document from cirq.study.resolver import ParamResolver, ParamResolverOrSimilarType -from cirq.study.sweeps import ListSweep, Points, Sweep, UnitSweep, Zip +from cirq.study.sweeps import (ListSweep, Points, Sweep, UnitSweep, Zip, + dict_to_product_sweep) SweepLike = Union[ParamResolverOrSimilarType, Sweep] document( @@ -59,22 +60,16 @@ def to_sweeps(sweepable: Sweepable) -> List[Sweep]: if isinstance(sweepable, Sweep): return [sweepable] if isinstance(sweepable, dict): - # change dictionary of lists to list of dictionaries - # of single values using Cartesian product. - newsweepable = {} - for key, value in sweepable.items(): - if isinstance(value, Iterable): - newsweepable[key] = value - else: - newsweepable[key] = [value] - expandsweepable = [ - dict(zip(newsweepable.keys(), v)) - for v in itertools.product(*newsweepable.values()) - ] - return [ - _resolver_to_sweep(ParamResolver(cast(Dict, dictitem))) - for dictitem in expandsweepable - ] + if any(isinstance(val, Sequence) for val in sweepable.values()): + warnings.warn( + 'Implicit expansion of a dictionary into a Cartesian product ' + 'of sweeps is deprecated and will be removed in cirq 0.10. ' + 'Instead, expand the sweep explicitly using ' + '`cirq.dict_to_product_sweep`.', + DeprecationWarning, + stacklevel=2) + product_sweep = dict_to_product_sweep(sweepable) + return [_resolver_to_sweep(resolver) for resolver in product_sweep] if isinstance(sweepable, Iterable) and not isinstance(sweepable, str): return [ sweep for item in sweepable for sweep in to_sweeps( diff --git a/cirq/study/sweepable_test.py b/cirq/study/sweepable_test.py index 63429877264..43d941a1d60 100644 --- a/cirq/study/sweepable_test.py +++ b/cirq/study/sweepable_test.py @@ -83,15 +83,16 @@ def test_to_sweeps_iterable_sweeps(): def test_to_sweeps_dictionary_of_list(): - assert cirq.study.to_sweeps({'t': [0, 2, 3]}) == \ - cirq.study.to_sweeps([{'t': 0}, {'t': 2}, {'t': 3}]) - assert cirq.study.to_sweeps({'t': [0, 1], 's': [2, 3], 'r': 4}) == \ - cirq.study.to_sweeps([ - {'t': 0, 's': 2, 'r': 4}, - {'t': 0, 's': 3, 'r': 4}, - {'t': 1, 's': 2, 'r': 4}, - {'t': 1, 's': 3, 'r': 4}, - ]) + with pytest.warns(DeprecationWarning, match='dict_to_product_sweep'): + assert cirq.study.to_sweeps({'t': [0, 2, 3]}) == \ + cirq.study.to_sweeps([{'t': 0}, {'t': 2}, {'t': 3}]) + assert cirq.study.to_sweeps({'t': [0, 1], 's': [2, 3], 'r': 4}) == \ + cirq.study.to_sweeps([ + {'t': 0, 's': 2, 'r': 4}, + {'t': 0, 's': 3, 'r': 4}, + {'t': 1, 's': 2, 'r': 4}, + {'t': 1, 's': 3, 'r': 4}, + ]) def test_to_sweeps_invalid(): diff --git a/cirq/study/sweeps.py b/cirq/study/sweeps.py index 7ef05bce9d6..08abb0cf90e 100644 --- a/cirq/study/sweeps.py +++ b/cirq/study/sweeps.py @@ -26,6 +26,8 @@ import cirq Params = Iterable[Tuple['cirq.TParamKey', 'cirq.TParamVal']] +ProductOrZipSweepLike = Dict[ + 'cirq.TParamKey', Union['cirq.TParamVal', Sequence['cirq.TParamVal']]] def _check_duplicate_keys(sweeps): @@ -94,7 +96,7 @@ def __ne__(self, other): @property @abc.abstractmethod - def keys(self) -> List[str]: + def keys(self) -> List['cirq.TParamKey']: """The keys for the all of the sympy.Symbols that are resolved.""" @abc.abstractmethod @@ -175,7 +177,7 @@ def __eq__(self, other): return True @property - def keys(self) -> List[str]: + def keys(self) -> List['cirq.TParamKey']: return [] def __len__(self) -> int: @@ -212,7 +214,7 @@ def __hash__(self): return hash(tuple(self.factors)) @property - def keys(self) -> List[str]: + def keys(self) -> List['cirq.TParamKey']: return sum((factor.keys for factor in self.factors), []) def __len__(self) -> int: @@ -276,7 +278,7 @@ def __hash__(self) -> int: return hash(tuple(self.sweeps)) @property - def keys(self) -> List[str]: + def keys(self) -> List['cirq.TParamKey']: return sum((sweep.keys for sweep in self.sweeps), []) def __len__(self) -> int: @@ -303,7 +305,7 @@ def __str__(self) -> str: class SingleSweep(Sweep): """A simple sweep over one parameter with values from an iterator.""" - def __init__(self, key: Union[str, sympy.Symbol]) -> None: + def __init__(self, key: 'cirq.TParamKey') -> None: if isinstance(key, sympy.Symbol): key = str(key) self.key = key @@ -321,7 +323,7 @@ def _tuple(self) -> Tuple[Any, ...]: pass @property - def keys(self) -> List[str]: + def keys(self) -> List['cirq.TParamKey']: return [self.key] def param_tuples(self) -> Iterator[Params]: @@ -336,9 +338,8 @@ def _values(self) -> Iterator[float]: class Points(SingleSweep): """A simple sweep with explicitly supplied values.""" - def __init__( - self, key: Union[str, sympy.Symbol], - points: Sequence[float]) -> None: + def __init__(self, key: 'cirq.TParamKey', + points: Sequence['cirq.TParamVal']) -> None: super(Points, self).__init__(key) self.points = points @@ -358,11 +359,8 @@ def __repr__(self) -> str: class Linspace(SingleSweep): """A simple sweep over linearly-spaced values.""" - def __init__( - self, key: Union[str, sympy.Symbol], - start: float, - stop: float, - length: int) -> None: + def __init__(self, key: 'cirq.TParamKey', start: float, stop: float, + length: int) -> None: """Creates a linear-spaced sweep for a given key. For the given args, assigns to the list of values @@ -418,7 +416,7 @@ def __ne__(self, other): return not self == other @property - def keys(self) -> List[str]: + def keys(self) -> List['cirq.TParamKey']: if not self.resolver_list: return [] return list(map(str, self.resolver_list[0].param_dict)) @@ -439,3 +437,37 @@ def _params_without_symbols(resolver: resolver.ParamResolver) -> Params: if isinstance(sym, sympy.Symbol): sym = sym.name yield cast(str, sym), cast(float, val) + + +def dict_to_product_sweep(factor_dict: ProductOrZipSweepLike) -> Product: + """Cartesian product of sweeps from a dictionary. + + Each entry in the dictionary specifies a sweep as a mapping from the + parameter to a value or sequence of values. The Cartesian product of these + sweeps is returned. + + Args: + factor_dict: The dictionary containing the sweeps. + + Returns: + Cartesian product of the sweeps. + """ + return Product(*(Points(k, v if isinstance(v, Sequence) else [v]) + for k, v in factor_dict.items())) + + +def dict_to_zip_sweep(factor_dict: ProductOrZipSweepLike) -> Zip: + """Zip product of sweeps from a dictionary. + + Each entry in the dictionary specifies a sweep as a mapping from the + parameter to a value or sequence of values. The zip product of these + sweeps is returned. + + Args: + factor_dict: The dictionary containing the sweeps. + + Returns: + Zip product of the sweeps. + """ + return Zip(*(Points(k, v if isinstance(v, Sequence) else [v]) + for k, v in factor_dict.items())) diff --git a/cirq/study/sweeps_test.py b/cirq/study/sweeps_test.py index 0835101450a..a1732fa3a9c 100644 --- a/cirq/study/sweeps_test.py +++ b/cirq/study/sweeps_test.py @@ -280,3 +280,27 @@ def test_list_sweep_str(): {'a': 2.0, 'b': 2.0} {'a': 3.0, 'b': 1.0} {'a': 3.0, 'b': 2.0}''' + + +def test_dict_to_product_sweep(): + assert cirq.dict_to_product_sweep({'t': [0, 2, 3]}) == (cirq.Product( + cirq.Points('t', [0, 2, 3]))) + + assert cirq.dict_to_product_sweep({ + 't': [0, 1], + 's': [2, 3], + 'r': 4 + }) == (cirq.Product(cirq.Points('t', [0, 1]), cirq.Points('s', [2, 3]), + cirq.Points('r', [4]))) + + +def test_dict_to_zip_sweep(): + assert cirq.dict_to_zip_sweep({'t': [0, 2, 3] + }) == (cirq.Zip(cirq.Points('t', [0, 2, 3]))) + + assert cirq.dict_to_zip_sweep({ + 't': [0, 1], + 's': [2, 3], + 'r': 4 + }) == (cirq.Zip(cirq.Points('t', [0, 1]), cirq.Points('s', [2, 3]), + cirq.Points('r', [4]))) diff --git a/rtd_docs/api.rst b/rtd_docs/api.rst index e13cff99d00..4f6740da0c5 100644 --- a/rtd_docs/api.rst +++ b/rtd_docs/api.rst @@ -227,6 +227,8 @@ results. cirq.big_endian_digits_to_int cirq.big_endian_int_to_bits cirq.big_endian_int_to_digits + cirq.dict_to_product_sweep + cirq.dict_to_zip_sweep cirq.final_density_matrix cirq.final_state_vector cirq.flatten