From c1d64518ce298d22bdf04fc202e9fa0902717137 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Tue, 25 Apr 2023 06:58:14 -0700 Subject: [PATCH] Add ZipLongest to cirq_google (#6074) * Add ZipLongest to cirq_google - This class is similar to cirq.Zip but repeats the last value if the combined sweeps are not the same length. --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/json_resolver_cache.py | 1 + .../protocols/json_test_data/ZipLongest.json | 19 ++++++ .../protocols/json_test_data/ZipLongest.repr | 1 + cirq-core/cirq/study/__init__.py | 1 + cirq-core/cirq/study/sweeps.py | 59 ++++++++++++++++++- cirq-core/cirq/study/sweeps_test.py | 58 ++++++++++++++++++ 7 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/ZipLongest.json create mode 100644 cirq-core/cirq/protocols/json_test_data/ZipLongest.repr diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index b246ccdbf38..73b21c1cd9d 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -502,6 +502,7 @@ Result, UnitSweep, Zip, + ZipLongest, ) from cirq.value import ( diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 6a2d4f69c7c..20a7377b294 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -241,6 +241,7 @@ def _symmetricalqidpair(qids): 'YYPowGate': cirq.YYPowGate, '_ZEigenState': cirq.value.product_state._ZEigenState, 'Zip': cirq.Zip, + 'ZipLongest': cirq.ZipLongest, 'ZPowGate': cirq.ZPowGate, 'ZZPowGate': cirq.ZZPowGate, # Old types, only supported for backwards-compatibility diff --git a/cirq-core/cirq/protocols/json_test_data/ZipLongest.json b/cirq-core/cirq/protocols/json_test_data/ZipLongest.json new file mode 100644 index 00000000000..d3f5c5cc3fd --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ZipLongest.json @@ -0,0 +1,19 @@ +{ + "cirq_type": "ZipLongest", + "sweeps": [ + { + "cirq_type": "Linspace", + "key": "a", + "start": 0, + "stop": 1, + "length": 2 + }, + { + "cirq_type": "Linspace", + "key": "b", + "start": 0, + "stop": 2, + "length": 4 + } + ] +} diff --git a/cirq-core/cirq/protocols/json_test_data/ZipLongest.repr b/cirq-core/cirq/protocols/json_test_data/ZipLongest.repr new file mode 100644 index 00000000000..c35c3e52f0f --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/ZipLongest.repr @@ -0,0 +1 @@ +cirq.ZipLongest(cirq.Linspace('a', start=0, stop=1, length=2), cirq.Linspace('b', start=0, stop=2, length=4)) diff --git a/cirq-core/cirq/study/__init__.py b/cirq-core/cirq/study/__init__.py index 61b88fc83cc..0cc5b3b46cb 100644 --- a/cirq-core/cirq/study/__init__.py +++ b/cirq-core/cirq/study/__init__.py @@ -38,6 +38,7 @@ Sweep, UnitSweep, Zip, + ZipLongest, dict_to_product_sweep, dict_to_zip_sweep, ) diff --git a/cirq-core/cirq/study/sweeps.py b/cirq-core/cirq/study/sweeps.py index ecce76284ed..fd364525d85 100644 --- a/cirq-core/cirq/study/sweeps.py +++ b/cirq-core/cirq/study/sweeps.py @@ -292,7 +292,7 @@ def __init__(self, *sweeps: Sweep) -> None: self.sweeps = sweeps def __eq__(self, other): - if not isinstance(other, Zip): + if type(other) is not Zip: return NotImplemented return self.sweeps == other.sweeps @@ -327,7 +327,62 @@ def _json_dict_(self) -> Dict[str, Any]: @classmethod def _from_json_dict_(cls, sweeps, **kwargs): - return Zip(*sweeps) + return cls(*sweeps) + + +class ZipLongest(Zip): + """Iterate over constituent sweeps in parallel + + Analogous to itertools.zip_longest. + Note that we iterate until all sweeps terminate, + so if the sweeps are different lengths, the + shorter sweeps will be filled by repeating their last value + until all sweeps have equal length. + + Note that this is different from itertools.zip_longest, + which uses a fixed fill value. + + Raises: + ValueError if an input sweep if completely empty. + """ + + def __init__(self, *sweeps: Sweep) -> None: + super().__init__(*sweeps) + if any(len(sweep) == 0 for sweep in self.sweeps): + raise ValueError('All sweeps must be non-empty for ZipLongest') + + def __eq__(self, other): + if not isinstance(other, ZipLongest): + return NotImplemented + return self.sweeps == other.sweeps + + def __len__(self) -> int: + if not self.sweeps: + return 0 + return max(len(sweep) for sweep in self.sweeps) + + def __hash__(self) -> int: + return hash((self.__class__.__name__, tuple(self.sweeps))) + + def __repr__(self) -> str: + sweeps_repr = ', '.join(repr(s) for s in self.sweeps) + return f'cirq_google.ZipLongest({sweeps_repr})' + + def __str__(self) -> str: + sweeps_repr = ', '.join(repr(s) for s in self.sweeps) + return f'ZipLongest({sweeps_repr})' + + def param_tuples(self) -> Iterator[Params]: + def _iter_and_repeat_last(one_iter: Iterator[Params]): + last = None + for last in one_iter: + yield last + while True: + yield last + + iters = [_iter_and_repeat_last(sweep.param_tuples()) for sweep in self.sweeps] + for values in itertools.islice(zip(*iters), len(self)): + yield tuple(item for value in values for item in value) class SingleSweep(Sweep): diff --git a/cirq-core/cirq/study/sweeps_test.py b/cirq-core/cirq/study/sweeps_test.py index 4da1008a9df..09c4d645af2 100644 --- a/cirq-core/cirq/study/sweeps_test.py +++ b/cirq-core/cirq/study/sweeps_test.py @@ -77,6 +77,64 @@ def test_zip(): assert _values(sweep, 'b') == [4, 5, 6] +def test_zip_longest(): + sweep = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6, 7])) + assert tuple(sweep.param_tuples()) == ( + (('a', 1), ('b', 4)), + (('a', 2), ('b', 5)), + (('a', 3), ('b', 6)), + (('a', 3), ('b', 7)), + ) + assert sweep.keys == ['a', 'b'] + assert ( + str(sweep) == 'ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))' + ) + assert ( + repr(sweep) + == 'cirq_google.ZipLongest(cirq.Points(\'a\', [1, 2, 3]), cirq.Points(\'b\', [4, 5, 6, 7]))' + ) + + +def test_zip_longest_compatibility(): + sweep = cirq.Zip(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6])) + sweep_longest = cirq.ZipLongest(cirq.Points('a', [1, 2, 3]), cirq.Points('b', [4, 5, 6])) + assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples()) + + sweep = cirq.Zip( + (cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7]) + ) + sweep_longest = cirq.ZipLongest( + (cirq.Points('a', [1, 3]) * cirq.Points('b', [2, 4])), cirq.Points('c', [4, 5, 6, 7]) + ) + assert tuple(sweep.param_tuples()) == tuple(sweep_longest.param_tuples()) + + +def test_empty_zip(): + assert len(cirq.ZipLongest()) == 0 + with pytest.raises(ValueError, match='non-empty'): + _ = cirq.ZipLongest(cirq.Points('e', []), cirq.Points('a', [1, 2, 3])) + + +def test_zip_eq(): + et = cirq.testing.EqualsTester() + point_sweep1 = cirq.Points('a', [1, 2, 3]) + point_sweep2 = cirq.Points('b', [4, 5, 6, 7]) + point_sweep3 = cirq.Points('c', [1, 2]) + + et.add_equality_group(cirq.ZipLongest(), cirq.ZipLongest()) + + et.add_equality_group( + cirq.ZipLongest(point_sweep1, point_sweep2), cirq.ZipLongest(point_sweep1, point_sweep2) + ) + + et.add_equality_group(cirq.ZipLongest(point_sweep3, point_sweep2)) + et.add_equality_group(cirq.ZipLongest(point_sweep2, point_sweep1)) + et.add_equality_group(cirq.ZipLongest(point_sweep1, point_sweep2, point_sweep3)) + + et.add_equality_group(cirq.Zip(point_sweep1, point_sweep2, point_sweep3)) + et.add_equality_group(cirq.Zip(point_sweep1, point_sweep2)) + + def test_product(): sweep = cirq.Points('a', [1, 2, 3]) * cirq.Points('b', [4, 5, 6, 7]) assert len(sweep) == 12