Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZipLongest to cirq_google #6074

Merged
merged 4 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@
Result,
UnitSweep,
Zip,
ZipLongest,
)

from cirq.value import (
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def _symmetricalqidpair(qids):
'YYPowGate': cirq.YYPowGate,
'_ZEigenState': cirq.value.product_state._ZEigenState,
'Zip': cirq.Zip,
'ZipLongest': cirq.ZipLongest,
'ZPowGate': cirq.ZPowGate,
'ZZPowGate': cirq.ZZPowGate,
# Old types, only supported for backwards-compatibility
Expand Down
19 changes: 19 additions & 0 deletions cirq-core/cirq/protocols/json_test_data/ZipLongest.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"cirq_type": "ZipLongest",
"sweeps": [
{
"cirq_type": "Linspace",
"key": "a",
"start": 0,
"stop": 1,
"length": 2
},
{
"cirq_type": "Linspace",
"key": "b",
"start": 0,
"stop": 2,
"length": 4
}
]
}
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/ZipLongest.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cirq.ZipLongest(cirq.Linspace('a', start=0, stop=1, length=2), cirq.Linspace('b', start=0, stop=2, length=4))
1 change: 1 addition & 0 deletions cirq-core/cirq/study/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Sweep,
UnitSweep,
Zip,
ZipLongest,
dict_to_product_sweep,
dict_to_zip_sweep,
)
Expand Down
57 changes: 56 additions & 1 deletion cirq-core/cirq/study/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,62 @@ def _json_dict_(self) -> Dict[str, Any]:

@classmethod
def _from_json_dict_(cls, sweeps, **kwargs):
return Zip(*sweeps)
return cls(*sweeps)


class ZipLongest(Zip):
"""Iterate over constituent sweeps in parallel

Analogous to itertools.zip_longest.
Note that we iterate until all sweeps terminate,
so if the sweeps are different lengths, the
shorter sweeps will be filled by repeating their last value
until all sweeps have equal length.

Note that this is different from itertools.zip_longest,
which uses a fixed fill value.

Raises:
ValueError if an input sweep if completely empty.
"""

def __init__(self, *sweeps: Sweep) -> None:
super().__init__(*sweeps)
if any(len(sweep) == 0 for sweep in self.sweeps):
raise ValueError('All sweeps must be non-empty for ZipLongest')

def __eq__(self, other):
if not isinstance(other, ZipLongest):
return NotImplemented
return self.sweeps == other.sweeps

def __len__(self) -> int:
if not self.sweeps:
return 0
return max(len(sweep) for sweep in self.sweeps)

def __hash__(self) -> int:
return hash(tuple(self.sweeps))

def __repr__(self) -> str:
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
return f'cirq_google.ZipLongest({sweeps_repr})'

def __str__(self) -> str:
sweeps_repr = ', '.join(repr(s) for s in self.sweeps)
return f'ZipLongest({sweeps_repr})'

def param_tuples(self) -> Iterator[Params]:
def _iter_and_repeat_last(one_iter: Iterator[Params]):
last = None
for last in one_iter:
yield last
while True:
yield last

iters = [_iter_and_repeat_last(sweep.param_tuples()) for sweep in self.sweeps]
for values in itertools.islice(zip(*iters), len(self)):
yield tuple(item for value in values for item in value)


class SingleSweep(Sweep):
Expand Down
52 changes: 52 additions & 0 deletions cirq-core/cirq/study/sweeps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,58 @@ def test_zip():
assert _values(sweep, 'b') == [4, 5, 6]


def test_zip_longest():
sweep = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7]))
assert tuple(sweep.param_tuples()) == (
(('a', 1), ('b', 4)),
(('a', 2), ('b', 5)),
(('a', 3), ('b', 6)),
(('a', 3), ('b', 7)),
)
assert sweep.keys == ['a', 'b']
assert (
str(sweep) == 'ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))'
)
assert (
repr(sweep)
== 'cirq_google.ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))'
)


def test_zip_longest_compatibility():
sweep = cirq.Zip(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6]))
sweep_longest = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6]))
assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples())

sweep = cirq.Zip(
(cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7])
)
sweep_longest = cirq.ZipLongest(
(cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7])
)
assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples())


def test_empty_zip():
assert len(cirq.ZipLongest()) == 0
with pytest.raises(ValueError, match='non-empty'):
_ = cirq.ZipLongest(cirq.Points('e', []), cirq.Points('a', [1, 2, 3]))


def test_zip_eq():
sweep1 = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7]))
sweep2 = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7]))
sweep3 = cirq.ZipLongest(cirq.Points('a', [1, 2]), cirq.Points('b', [4, 5, 6, 7]))
sweep4 = cirq.Zip(cirq.Points('a', [1, 2]), cirq.Points('b', [4, 5, 6, 7]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As it is, both sweep3 == sweep4 and hash(sweep3) == hash(sweep4) evaluate to True.
Perhaps we need to use a strict type(other) is Zip in the Zip.__eq__ function

def __eq__(self, other):
if not isinstance(other, Zip):
return NotImplemented
return self.sweeps == other.sweeps

and insert ZipLongest or a similar type-marker to the hashed tuple at

def __hash__(self) -> int:
return hash(tuple(self.sweeps))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to use 'is not Zip' and also changed the test to use Cirq's equality tester.


assert sweep1 == sweep2
assert hash(sweep1) == hash(sweep2)
assert sweep2 != sweep3
assert hash(sweep2) != hash(sweep3)
assert sweep1 != sweep4
assert hash(sweep1) != hash(sweep4)


def test_product():
sweep = cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6, 7])
assert len(sweep) == 12
Expand Down