Skip to content

Commit

Permalink
Add ZipLongest to cirq_google (#6074)
Browse files Browse the repository at this point in the history
* Add ZipLongest to cirq_google

- This class is similar to cirq.Zip but repeats the
last value if the combined sweeps are not the same
length.
  • Loading branch information
dstrain115 committed Apr 25, 2023
1 parent 2a57894 commit c1d6451
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 2 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -502,6 +502,7 @@
Result,
UnitSweep,
Zip,
ZipLongest,
)

from cirq.value import (
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Expand Up @@ -241,6 +241,7 @@ def _symmetricalqidpair(qids):
'YYPowGate': cirq.YYPowGate,
'_ZEigenState': cirq.value.product_state._ZEigenState,
'Zip': cirq.Zip,
'ZipLongest': cirq.ZipLongest,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,
# Old types, only supported for backwards-compatibility
Expand Down
19 changes: 19 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/ZipLongest.json
@@ -0,0 +1,19 @@
{
"cirq_type": "ZipLongest",
"sweeps": [
{
"cirq_type": "Linspace",
"key": "a",
"start": 0,
"stop": 1,
"length": 2
},
{
"cirq_type": "Linspace",
"key": "b",
"start": 0,
"stop": 2,
"length": 4
}
]
}
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/ZipLongest.repr
@@ -0,0 +1 @@
cirq.ZipLongest(cirq.Linspace('a', start=0, stop=1, length=2), cirq.Linspace('b', start=0, stop=2, length=4))
1 change: 1 addition & 0 deletions cirq-core/cirq/study/__init__.py
Expand Up @@ -38,6 +38,7 @@
Sweep,
UnitSweep,
Zip,
ZipLongest,
dict_to_product_sweep,
dict_to_zip_sweep,
)
Expand Down
59 changes: 57 additions & 2 deletions cirq-core/cirq/study/sweeps.py
Expand Up @@ -292,7 +292,7 @@ def __init__(self, *sweeps: Sweep) -> None:
self.sweeps = sweeps

def __eq__(self, other):
if not isinstance(other, Zip):
if type(other) is not Zip:
return NotImplemented
return self.sweeps == other.sweeps

Expand Down Expand Up @@ -327,7 +327,62 @@ def _json_dict_(self) -> Dict[str, Any]:

@classmethod
def _from_json_dict_(cls, sweeps, **kwargs):
return Zip(*sweeps)
return cls(*sweeps)


class ZipLongest(Zip):
"""Iterate over constituent sweeps in parallel
Analogous to itertools.zip_longest.
Note that we iterate until all sweeps terminate,
so if the sweeps are different lengths, the
shorter sweeps will be filled by repeating their last value
until all sweeps have equal length.
Note that this is different from itertools.zip_longest,
which uses a fixed fill value.
Raises:
ValueError if an input sweep if completely empty.
"""

def __init__(self, *sweeps: Sweep) -> None:
super().__init__(*sweeps)
if any(len(sweep) == 0 for sweep in self.sweeps):
raise ValueError('All sweeps must be non-empty for ZipLongest')

def __eq__(self, other):
if not isinstance(other, ZipLongest):
return NotImplemented
return self.sweeps == other.sweeps

def __len__(self) -> int:
if not self.sweeps:
return 0
return max(len(sweep) for sweep in self.sweeps)

def __hash__(self) -> int:
return hash((self.__class__.__name__, tuple(self.sweeps)))

def __repr__(self) -> str:
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
return f'cirq_google.ZipLongest({sweeps_repr})'

def __str__(self) -> str:
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
return f'ZipLongest({sweeps_repr})'

def param_tuples(self) -> Iterator[Params]:
def _iter_and_repeat_last(one_iter: Iterator[Params]):
last = None
for last in one_iter:
yield last
while True:
yield last

iters = [_iter_and_repeat_last(sweep.param_tuples()) for sweep in self.sweeps]
for values in itertools.islice(zip(*iters), len(self)):
yield tuple(item for value in values for item in value)


class SingleSweep(Sweep):
Expand Down
58 changes: 58 additions & 0 deletions cirq-core/cirq/study/sweeps_test.py
Expand Up @@ -77,6 +77,64 @@ def test_zip():
assert _values(sweep, 'b') == [4, 5, 6]


def test_zip_longest():
sweep = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7]))
assert tuple(sweep.param_tuples()) == (
(('a', 1), ('b', 4)),
(('a', 2), ('b', 5)),
(('a', 3), ('b', 6)),
(('a', 3), ('b', 7)),
)
assert sweep.keys == ['a', 'b']
assert (
str(sweep) == 'ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))'
)
assert (
repr(sweep)
== 'cirq_google.ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))'
)


def test_zip_longest_compatibility():
sweep = cirq.Zip(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6]))
sweep_longest = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6]))
assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples())

sweep = cirq.Zip(
(cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7])
)
sweep_longest = cirq.ZipLongest(
(cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7])
)
assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples())


def test_empty_zip():
assert len(cirq.ZipLongest()) == 0
with pytest.raises(ValueError, match='non-empty'):
_ = cirq.ZipLongest(cirq.Points('e', []), cirq.Points('a', [1, 2, 3]))


def test_zip_eq():
et = cirq.testing.EqualsTester()
point_sweep1 = cirq.Points('a', [1, 2, 3])
point_sweep2 = cirq.Points('b', [4, 5, 6, 7])
point_sweep3 = cirq.Points('c', [1, 2])

et.add_equality_group(cirq.ZipLongest(), cirq.ZipLongest())

et.add_equality_group(
cirq.ZipLongest(point_sweep1, point_sweep2), cirq.ZipLongest(point_sweep1, point_sweep2)
)

et.add_equality_group(cirq.ZipLongest(point_sweep3, point_sweep2))
et.add_equality_group(cirq.ZipLongest(point_sweep2, point_sweep1))
et.add_equality_group(cirq.ZipLongest(point_sweep1, point_sweep2, point_sweep3))

et.add_equality_group(cirq.Zip(point_sweep1, point_sweep2, point_sweep3))
et.add_equality_group(cirq.Zip(point_sweep1, point_sweep2))


def test_product():
sweep = cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6, 7])
assert len(sweep) == 12
Expand Down

0 comments on commit c1d6451

Please sign in to comment.