Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@
)

from cirq.study import (
dict_to_product_sweep,
dict_to_zip_sweep,
ExpressionMap,
flatten,
flatten_with_params,
Expand Down
2 changes: 2 additions & 0 deletions cirq/study/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
Sweep,
UnitSweep,
Zip,
dict_to_product_sweep,
dict_to_zip_sweep,
)

from cirq.study.trial_result import (
Expand Down
33 changes: 14 additions & 19 deletions cirq/study/sweepable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

"""Defines which types are Sweepable."""

from typing import Dict, Iterable, Iterator, List, Union, cast
import itertools
from typing import Dict, Iterable, Iterator, List, Sequence, Union, cast
import warnings

from cirq._doc import document
from cirq.study.resolver import ParamResolver, ParamResolverOrSimilarType
from cirq.study.sweeps import ListSweep, Points, Sweep, UnitSweep, Zip
from cirq.study.sweeps import (ListSweep, Points, Sweep, UnitSweep, Zip,
dict_to_product_sweep)

SweepLike = Union[ParamResolverOrSimilarType, Sweep]
document(
Expand Down Expand Up @@ -59,22 +60,16 @@ def to_sweeps(sweepable: Sweepable) -> List[Sweep]:
if isinstance(sweepable, Sweep):
return [sweepable]
if isinstance(sweepable, dict):
# change dictionary of lists to list of dictionaries
# of single values using Cartesian product.
newsweepable = {}
for key, value in sweepable.items():
if isinstance(value, Iterable):
newsweepable[key] = value
else:
newsweepable[key] = [value]
expandsweepable = [
dict(zip(newsweepable.keys(), v))
for v in itertools.product(*newsweepable.values())
]
return [
_resolver_to_sweep(ParamResolver(cast(Dict, dictitem)))
for dictitem in expandsweepable
]
if any(isinstance(val, Sequence) for val in sweepable.values()):
warnings.warn(
'Implicit expansion of a dictionary into a Cartesian product '
'of sweeps is deprecated and will be removed in cirq 0.10. '
'Instead, expand the sweep explicitly using '
'`cirq.dict_to_product_sweep`.',
DeprecationWarning,
stacklevel=2)
product_sweep = dict_to_product_sweep(sweepable)
return [_resolver_to_sweep(resolver) for resolver in product_sweep]
if isinstance(sweepable, Iterable) and not isinstance(sweepable, str):
return [
sweep for item in sweepable for sweep in to_sweeps(
Expand Down
19 changes: 10 additions & 9 deletions cirq/study/sweepable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,16 @@ def test_to_sweeps_iterable_sweeps():


def test_to_sweeps_dictionary_of_list():
assert cirq.study.to_sweeps({'t': [0, 2, 3]}) == \
cirq.study.to_sweeps([{'t': 0}, {'t': 2}, {'t': 3}])
assert cirq.study.to_sweeps({'t': [0, 1], 's': [2, 3], 'r': 4}) == \
cirq.study.to_sweeps([
{'t': 0, 's': 2, 'r': 4},
{'t': 0, 's': 3, 'r': 4},
{'t': 1, 's': 2, 'r': 4},
{'t': 1, 's': 3, 'r': 4},
])
with pytest.warns(DeprecationWarning, match='dict_to_product_sweep'):
assert cirq.study.to_sweeps({'t': [0, 2, 3]}) == \
cirq.study.to_sweeps([{'t': 0}, {'t': 2}, {'t': 3}])
assert cirq.study.to_sweeps({'t': [0, 1], 's': [2, 3], 'r': 4}) == \
cirq.study.to_sweeps([
{'t': 0, 's': 2, 'r': 4},
{'t': 0, 's': 3, 'r': 4},
{'t': 1, 's': 2, 'r': 4},
{'t': 1, 's': 3, 'r': 4},
])


def test_to_sweeps_invalid():
Expand Down
62 changes: 47 additions & 15 deletions cirq/study/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import cirq

Params = Iterable[Tuple['cirq.TParamKey', 'cirq.TParamVal']]
ProductOrZipSweepLike = Dict[
'cirq.TParamKey', Union['cirq.TParamVal', Sequence['cirq.TParamVal']]]


def _check_duplicate_keys(sweeps):
Expand Down Expand Up @@ -94,7 +96,7 @@ def __ne__(self, other):

@property
@abc.abstractmethod
def keys(self) -> List[str]:
def keys(self) -> List['cirq.TParamKey']:
"""The keys for the all of the sympy.Symbols that are resolved."""

@abc.abstractmethod
Expand Down Expand Up @@ -175,7 +177,7 @@ def __eq__(self, other):
return True

@property
def keys(self) -> List[str]:
def keys(self) -> List['cirq.TParamKey']:
return []

def __len__(self) -> int:
Expand Down Expand Up @@ -212,7 +214,7 @@ def __hash__(self):
return hash(tuple(self.factors))

@property
def keys(self) -> List[str]:
def keys(self) -> List['cirq.TParamKey']:
return sum((factor.keys for factor in self.factors), [])

def __len__(self) -> int:
Expand Down Expand Up @@ -276,7 +278,7 @@ def __hash__(self) -> int:
return hash(tuple(self.sweeps))

@property
def keys(self) -> List[str]:
def keys(self) -> List['cirq.TParamKey']:
return sum((sweep.keys for sweep in self.sweeps), [])

def __len__(self) -> int:
Expand All @@ -303,7 +305,7 @@ def __str__(self) -> str:
class SingleSweep(Sweep):
"""A simple sweep over one parameter with values from an iterator."""

def __init__(self, key: Union[str, sympy.Symbol]) -> None:
def __init__(self, key: 'cirq.TParamKey') -> None:
if isinstance(key, sympy.Symbol):
key = str(key)
self.key = key
Expand All @@ -321,7 +323,7 @@ def _tuple(self) -> Tuple[Any, ...]:
pass

@property
def keys(self) -> List[str]:
def keys(self) -> List['cirq.TParamKey']:
return [self.key]

def param_tuples(self) -> Iterator[Params]:
Expand All @@ -336,9 +338,8 @@ def _values(self) -> Iterator[float]:
class Points(SingleSweep):
"""A simple sweep with explicitly supplied values."""

def __init__(
self, key: Union[str, sympy.Symbol],
points: Sequence[float]) -> None:
def __init__(self, key: 'cirq.TParamKey',
points: Sequence['cirq.TParamVal']) -> None:
super(Points, self).__init__(key)
self.points = points

Expand All @@ -358,11 +359,8 @@ def __repr__(self) -> str:
class Linspace(SingleSweep):
"""A simple sweep over linearly-spaced values."""

def __init__(
self, key: Union[str, sympy.Symbol],
start: float,
stop: float,
length: int) -> None:
def __init__(self, key: 'cirq.TParamKey', start: float, stop: float,
length: int) -> None:
"""Creates a linear-spaced sweep for a given key.

For the given args, assigns to the list of values
Expand Down Expand Up @@ -418,7 +416,7 @@ def __ne__(self, other):
return not self == other

@property
def keys(self) -> List[str]:
def keys(self) -> List['cirq.TParamKey']:
if not self.resolver_list:
return []
return list(map(str, self.resolver_list[0].param_dict))
Expand All @@ -439,3 +437,37 @@ def _params_without_symbols(resolver: resolver.ParamResolver) -> Params:
if isinstance(sym, sympy.Symbol):
sym = sym.name
yield cast(str, sym), cast(float, val)


def dict_to_product_sweep(factor_dict: ProductOrZipSweepLike) -> Product:
"""Cartesian product of sweeps from a dictionary.

Each entry in the dictionary specifies a sweep as a mapping from the
parameter to a value or sequence of values. The Cartesian product of these
sweeps is returned.

Args:
factor_dict: The dictionary containing the sweeps.

Returns:
Cartesian product of the sweeps.
"""
return Product(*(Points(k, v if isinstance(v, Sequence) else [v])
for k, v in factor_dict.items()))


def dict_to_zip_sweep(factor_dict: ProductOrZipSweepLike) -> Zip:
"""Zip product of sweeps from a dictionary.

Each entry in the dictionary specifies a sweep as a mapping from the
parameter to a value or sequence of values. The zip product of these
sweeps is returned.

Args:
factor_dict: The dictionary containing the sweeps.

Returns:
Zip product of the sweeps.
"""
return Zip(*(Points(k, v if isinstance(v, Sequence) else [v])
for k, v in factor_dict.items()))
24 changes: 24 additions & 0 deletions cirq/study/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,27 @@ def test_list_sweep_str():
{'a': 2.0, 'b': 2.0}
{'a': 3.0, 'b': 1.0}
{'a': 3.0, 'b': 2.0}'''


def test_dict_to_product_sweep():
assert cirq.dict_to_product_sweep({'t': [0, 2, 3]}) == (cirq.Product(
cirq.Points('t', [0, 2, 3])))

assert cirq.dict_to_product_sweep({
't': [0, 1],
's': [2, 3],
'r': 4
}) == (cirq.Product(cirq.Points('t', [0, 1]), cirq.Points('s', [2, 3]),
cirq.Points('r', [4])))


def test_dict_to_zip_sweep():
assert cirq.dict_to_zip_sweep({'t': [0, 2, 3]
}) == (cirq.Zip(cirq.Points('t', [0, 2, 3])))

assert cirq.dict_to_zip_sweep({
't': [0, 1],
's': [2, 3],
'r': 4
}) == (cirq.Zip(cirq.Points('t', [0, 1]), cirq.Points('s', [2, 3]),
cirq.Points('r', [4])))
2 changes: 2 additions & 0 deletions rtd_docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ results.
cirq.big_endian_digits_to_int
cirq.big_endian_int_to_bits
cirq.big_endian_int_to_digits
cirq.dict_to_product_sweep
cirq.dict_to_zip_sweep
cirq.final_density_matrix
cirq.final_state_vector
cirq.flatten
Expand Down