Skip to content

Commit

Permalink
Move ZipLongest to cirq and address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
dstrain115 committed Apr 24, 2023
1 parent 560d9c5 commit 58d0205
Show file tree
Hide file tree
Showing 14 changed files with 113 additions and 150 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@
Result,
UnitSweep,
Zip,
ZipLongest,
)

from cirq.value import (
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def _symmetricalqidpair(qids):
'YYPowGate': cirq.YYPowGate,
'_ZEigenState': cirq.value.product_state._ZEigenState,
'Zip': cirq.Zip,
'ZipLongest': cirq.ZipLongest,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,
# Old types, only supported for backwards-compatibility
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cirq_type": "cirq.google.ZipLongest",
"cirq_type": "ZipLongest",
"sweeps": [
{
"cirq_type": "Linspace",
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/ZipLongest.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.ZipLongest(cirq.Linspace('a', start=0, stop=1, length=2), cirq.Linspace('b', start=0, stop=2, length=4))
1 change: 1 addition & 0 deletions cirq-core/cirq/study/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Sweep,
UnitSweep,
Zip,
ZipLongest,
dict_to_product_sweep,
dict_to_zip_sweep,
)
Expand Down
57 changes: 56 additions & 1 deletion cirq-core/cirq/study/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,62 @@ def _json_dict_(self) -> Dict[str, Any]:

@classmethod
def _from_json_dict_(cls, sweeps, **kwargs):
return Zip(*sweeps)
return cls(*sweeps)


class ZipLongest(Zip):
"""Iterate over constituent sweeps in parallel
Analogous to itertools.zip_longest.
Note that we iterate until all sweeps terminate,
so if the sweeps are different lengths, the
shorter sweeps will be filled by repeating their last value
until all sweeps have equal length.
Note that this is different from itertools.zip_longest,
which uses a fixed fill value.
Raises:
ValueError if an input sweep if completely empty.
"""

def __init__(self, *sweeps: Sweep) -> None:
super().__init__(*sweeps)
if any(len(sweep) == 0 for sweep in self.sweeps):
raise ValueError('All sweeps must be non-empty for ZipLongest')

def __eq__(self, other):
if not isinstance(other, ZipLongest):
return NotImplemented
return self.sweeps == other.sweeps

def __len__(self) -> int:
if not self.sweeps:
return 0
return max(len(sweep) for sweep in self.sweeps)

def __hash__(self) -> int:
return hash(tuple(self.sweeps))

def __repr__(self) -> str:
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
return f'cirq_google.ZipLongest({sweeps_repr})'

def __str__(self) -> str:
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
return f'ZipLongest({sweeps_repr})'

def param_tuples(self) -> Iterator[Params]:
def _iter_and_repeat_last(one_iter: Iterator[Params]):
last = None
for last in one_iter:
yield last
while True:
yield last

iters = [_iter_and_repeat_last(sweep.param_tuples()) for sweep in self.sweeps]
for values in itertools.islice(zip(*iters), len(self)):
yield tuple(item for value in values for item in value)


class SingleSweep(Sweep):
Expand Down
52 changes: 52 additions & 0 deletions cirq-core/cirq/study/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,58 @@ def test_zip():
assert _values(sweep, 'b') == [4, 5, 6]


def test_zip_longest():
sweep = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7]))
assert tuple(sweep.param_tuples()) == (
(('a', 1), ('b', 4)),
(('a', 2), ('b', 5)),
(('a', 3), ('b', 6)),
(('a', 3), ('b', 7)),
)
assert sweep.keys == ['a', 'b']
assert (
str(sweep) == 'ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))'
)
assert (
repr(sweep)
== 'cirq_google.ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))'
)


def test_zip_longest_compatibility():
sweep = cirq.Zip(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6]))
sweep_longest = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6]))
assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples())

sweep = cirq.Zip(
(cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7])
)
sweep_longest = cirq.ZipLongest(
(cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7])
)
assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples())


def test_empty_zip():
assert len(cirq.ZipLongest()) == 0
with pytest.raises(ValueError, match='non-empty'):
_ = cirq.ZipLongest(cirq.Points('e', []), cirq.Points('a', [1, 2, 3]))


def test_zip_eq():
sweep1 = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7]))
sweep2 = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7]))
sweep3 = cirq.ZipLongest(cirq.Points('a', [1, 2]), cirq.Points('b', [4, 5, 6, 7]))
sweep4 = cirq.Zip(cirq.Points('a', [1, 2]), cirq.Points('b', [4, 5, 6, 7]))

assert sweep1 == sweep2
assert hash(sweep1) == hash(sweep2)
assert sweep2 != sweep3
assert hash(sweep2) != hash(sweep3)
assert sweep1 != sweep4
assert hash(sweep1) != hash(sweep4)


def test_product():
sweep = cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6, 7])
assert len(sweep) == 12
Expand Down
2 changes: 0 additions & 2 deletions cirq-google/cirq_google/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@
Serializer,
)

from cirq_google.study import ZipLongest

from cirq_google.workflow import (
ExecutableSpec,
KeyValueExecutableSpec,
Expand Down
1 change: 0 additions & 1 deletion cirq-google/cirq_google/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,4 @@ def _old_xmon(*args, **kwargs):
'cirq.google.EngineResult': cirq_google.EngineResult,
'cirq.google.GridDevice': cirq_google.GridDevice,
'cirq.google.GoogleCZTargetGateset': cirq_google.GoogleCZTargetGateset,
'cirq.google.ZipLongest': cirq_google.ZipLongest,
}

This file was deleted.

1 change: 0 additions & 1 deletion cirq-google/cirq_google/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
'EngineResult',
'GridDevice',
'GoogleCZTargetGateset',
'ZipLongest',
]
},
resolver_cache=_class_resolver_dictionary(),
Expand Down
15 changes: 0 additions & 15 deletions cirq-google/cirq_google/study/__init__.py

This file was deleted.

76 changes: 0 additions & 76 deletions cirq-google/cirq_google/study/zip_longest.py

This file was deleted.

52 changes: 0 additions & 52 deletions cirq-google/cirq_google/study/zip_longest_test.py

This file was deleted.

0 comments on commit 58d0205

Please sign in to comment.