From 193ac6e16a2cb56a8fe90db6be236f23605d6381 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Thu, 20 Aug 2020 14:48:07 -0700 Subject: [PATCH 1/8] add run_batch to Sampler --- cirq/work/sampler.py | 37 +++++++++++++++++++++++++++++++++++++ cirq/work/sampler_test.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index c730bd14679..c3589ec82f7 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -199,3 +199,40 @@ async def run_sweep_async( An awaitable TrialResult. """ return self.run_sweep(program, params=params, repetitions=repetitions) + + def run_batch( + self, + programs: List['cirq.Circuit'], + params_list: List['cirq.Sweepable'] = None, + repetitions: int = 1, + ) -> List['cirq.TrialResult']: + """Runs the supplied circuits. + + Each circuit provided in `programs` will pair with the associated + parameter sweep provided in the `params_list`. The number of circuits + is required to match the number of sweeps. + + By default, this method simply invokes `run_sweep` sequentially for + each (circuit, parameter sweep) pair. Child classes that are capable of + sampling batches more efficiently should override it to use other + strategies. + + Args: + programs: The circuits to execute as a batch. + params_list: Parameter sweeps to use with the circuits. The number + of sweeps should match the number of circuits and will be + paired in order with the circuits. + repetitions: Number of circuit repetitions to run. Each sweep value + of each circuit in the batch will run with the same repetitions. + + Returns: + A list of TrialResults. All TrialResults for the first circuit are + listed first, then the TrialResults for the second, etc. + The TrialResults for a circuit are listed in the order imposed by + the associated parameter sweep. + """ + if not params_list or len(programs) != len(params_list): + raise ValueError('Number of circuits and sweeps must match') + return [trial_result + for circuit, params in zip(programs, params_list) + for trial_result in self.run_sweep(circuit, params=params, repetitions=repetitions)] diff --git a/cirq/work/sampler_test.py b/cirq/work/sampler_test.py index 29e9baffd89..b870e9df860 100644 --- a/cirq/work/sampler_test.py +++ b/cirq/work/sampler_test.py @@ -14,6 +14,7 @@ """Tests for cirq.Sampler.""" import pytest +import numpy as np import pandas as pd import sympy @@ -155,3 +156,34 @@ def run_sweep(self, *args, **kwargs): assert not ran assert await a == [] assert ran + + +def test_sampler_run_batch(): + sampler = cirq.ZerosSampler() + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)**sympy.Symbol('t'), + cirq.measure(a, key='m')) + circuit2 = cirq.Circuit(cirq.Y(a)**sympy.Symbol('t'), + cirq.measure(a, key='m')) + params1 = cirq.Points('t', [0.3, 0.7]) + params2 = cirq.Points('t', [0.4, 0.6]) + results = sampler.run_batch([circuit1, circuit2], + params_list=[params1, params2]) + assert len(results) == 4 + for result, param in zip(results, [0.3, 0.7, 0.4, 0.6]): + assert result.repetitions == 1 + assert result.params.param_dict == {'t': param} + assert result.measurements == {'m': np.array([[0]], dtype='uint8')} + + +def test_sampler_run_batch_bad_input_lengths(): + sampler = cirq.ZerosSampler() + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)**sympy.Symbol('t'), + cirq.measure(a, key='m')) + circuit2 = cirq.Circuit(cirq.Y(a)**sympy.Symbol('t'), + cirq.measure(a, key='m')) + params = cirq.Points('t', [0.3, 0.7]) + with pytest.raises(ValueError): + _ = sampler.run_batch([circuit1, circuit2], + params_list=[params]) From 5f39284873560f65bd00644e80eaad8ca2cbd438 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Mon, 24 Aug 2020 17:41:09 -0700 Subject: [PATCH 2/8] add run_batch to QuantumEngineSampler --- cirq/google/engine/engine_sampler.py | 13 +++++++++++++ cirq/google/engine/engine_sampler_test.py | 20 ++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/cirq/google/engine/engine_sampler.py b/cirq/google/engine/engine_sampler.py index 4ec77075086..167401ca581 100644 --- a/cirq/google/engine/engine_sampler.py +++ b/cirq/google/engine/engine_sampler.py @@ -61,6 +61,19 @@ def run_sweep( gate_set=self._gate_set) return job.results() + def run_batch( + self, + programs: List['cirq.Circuit'], + params_list: List['cirq.Sweepable'] = None, + repetitions: int = 1, + ) -> List['cirq.TrialResult']: + job = self._engine.run_batch(programs=programs, + params_list=params_list, + repetitions=repetitions, + processor_ids=self._processor_ids, + gate_set=self._gate_set) + return job.results() + @property def engine(self) -> 'cirq.google.Engine': return self._engine diff --git a/cirq/google/engine/engine_sampler_test.py b/cirq/google/engine/engine_sampler_test.py index aa6b5e35849..da0cad4f5c9 100644 --- a/cirq/google/engine/engine_sampler_test.py +++ b/cirq/google/engine/engine_sampler_test.py @@ -50,6 +50,26 @@ def test_run_engine_program(): engine.run_sweep.assert_not_called() +def test_run_batch(): + engine = mock.Mock() + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + sampler.run_batch(circuits, params_list, 5) + engine.run_batch.assert_called_with(gate_set=cg.XMON, + params_list=params_list, + processor_ids=['tmp'], + programs=circuits, + repetitions=5) + + def test_engine_sampler_engine_property(): engine = mock.Mock() sampler = cg.QuantumEngineSampler(engine=engine, From e877b3e322b8aebd2375cc8591737905b31982f8 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Mon, 24 Aug 2020 17:50:04 -0700 Subject: [PATCH 3/8] format --- cirq/work/sampler.py | 8 +++++--- cirq/work/sampler_test.py | 19 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index c3589ec82f7..5a495b46030 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -233,6 +233,8 @@ def run_batch( """ if not params_list or len(programs) != len(params_list): raise ValueError('Number of circuits and sweeps must match') - return [trial_result - for circuit, params in zip(programs, params_list) - for trial_result in self.run_sweep(circuit, params=params, repetitions=repetitions)] + return [ + trial_result for circuit, params in zip(programs, params_list) + for trial_result in self.run_sweep( + circuit, params=params, repetitions=repetitions) + ] diff --git a/cirq/work/sampler_test.py b/cirq/work/sampler_test.py index b870e9df860..2beea69e07e 100644 --- a/cirq/work/sampler_test.py +++ b/cirq/work/sampler_test.py @@ -161,10 +161,10 @@ def run_sweep(self, *args, **kwargs): def test_sampler_run_batch(): sampler = cirq.ZerosSampler() a = cirq.LineQubit(0) - circuit1 = cirq.Circuit(cirq.X(a)**sympy.Symbol('t'), - cirq.measure(a, key='m')) - circuit2 = cirq.Circuit(cirq.Y(a)**sympy.Symbol('t'), - cirq.measure(a, key='m')) + circuit1 = cirq.Circuit( + cirq.X(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) + circuit2 = cirq.Circuit( + cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) params1 = cirq.Points('t', [0.3, 0.7]) params2 = cirq.Points('t', [0.4, 0.6]) results = sampler.run_batch([circuit1, circuit2], @@ -179,11 +179,10 @@ def test_sampler_run_batch(): def test_sampler_run_batch_bad_input_lengths(): sampler = cirq.ZerosSampler() a = cirq.LineQubit(0) - circuit1 = cirq.Circuit(cirq.X(a)**sympy.Symbol('t'), - cirq.measure(a, key='m')) - circuit2 = cirq.Circuit(cirq.Y(a)**sympy.Symbol('t'), - cirq.measure(a, key='m')) + circuit1 = cirq.Circuit( + cirq.X(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) + circuit2 = cirq.Circuit( + cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) params = cirq.Points('t', [0.3, 0.7]) with pytest.raises(ValueError): - _ = sampler.run_batch([circuit1, circuit2], - params_list=[params]) + _ = sampler.run_batch([circuit1, circuit2], params_list=[params]) From 3a50ec31079e9ed9d2bf5548e49f7fa9277e6e10 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Thu, 3 Sep 2020 15:24:34 -0400 Subject: [PATCH 4/8] allow variable number of repetitions and make params_list optional --- cirq/google/engine/engine_sampler.py | 31 ++++++++++++---- cirq/google/engine/engine_sampler_test.py | 25 +++++++++++++ cirq/work/sampler.py | 45 +++++++++++++++-------- cirq/work/sampler_test.py | 24 +++++++++++- 4 files changed, 100 insertions(+), 25 deletions(-) diff --git a/cirq/google/engine/engine_sampler.py b/cirq/google/engine/engine_sampler.py index 167401ca581..7794d755844 100644 --- a/cirq/google/engine/engine_sampler.py +++ b/cirq/google/engine/engine_sampler.py @@ -64,15 +64,30 @@ def run_sweep( def run_batch( self, programs: List['cirq.Circuit'], - params_list: List['cirq.Sweepable'] = None, - repetitions: int = 1, + params_list: Optional[List['cirq.Sweepable']] = None, + repetitions: Union[int, List[int]] = 1, ) -> List['cirq.TrialResult']: - job = self._engine.run_batch(programs=programs, - params_list=params_list, - repetitions=repetitions, - processor_ids=self._processor_ids, - gate_set=self._gate_set) - return job.results() + """Runs the supplied circuits. + + In order to gain a speedup from using this method instead of other run + methods, the number of circuit repetitions must be the same for all + circuits. That is, the `repetitions` argument must be an integer, or + else a list with identical values. + """ + if isinstance(repetitions, List) and len(programs) != len(repetitions): + raise ValueError('Number of circuits and repetitions must match') + if isinstance(repetitions, int) or len(set(repetitions)) == 1: + # All repetitions are the same so batching can be done efficiently + if isinstance(repetitions, List): + repetitions = repetitions[0] + job = self._engine.run_batch(programs=programs, + params_list=params_list, + repetitions=repetitions, + processor_ids=self._processor_ids, + gate_set=self._gate_set) + return job.results() + # Varying number of repetitions so no speedup + return super().run_batch(programs, params_list, repetitions) @property def engine(self) -> 'cirq.google.Engine': diff --git a/cirq/google/engine/engine_sampler_test.py b/cirq/google/engine/engine_sampler_test.py index da0cad4f5c9..596934015b5 100644 --- a/cirq/google/engine/engine_sampler_test.py +++ b/cirq/google/engine/engine_sampler_test.py @@ -70,6 +70,31 @@ def test_run_batch(): repetitions=5) +def test_run_batch_differing_repetitions(): + engine = mock.Mock() + job = mock.Mock() + job.results.return_value = [] + engine.run_sweep.return_value = job + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + repetitions = [1, 2] + sampler.run_batch(circuits, params_list, repetitions) + engine.run_sweep.assert_called_with(gate_set=cg.XMON, + params=params2, + processor_ids=['tmp'], + program=circuit2, + repetitions=2) + engine.run_batch.assert_not_called() + + def test_engine_sampler_engine_property(): engine = mock.Mock() sampler = cg.QuantumEngineSampler(engine=engine, diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index 5a495b46030..7d98c59811a 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -13,7 +13,7 @@ # limitations under the License. """Abstract base class for things sampling quantum circuits.""" -from typing import List, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Union import abc import pandas as pd @@ -203,27 +203,36 @@ async def run_sweep_async( def run_batch( self, programs: List['cirq.Circuit'], - params_list: List['cirq.Sweepable'] = None, - repetitions: int = 1, + params_list: Optional[List['cirq.Sweepable']] = None, + repetitions: Union[int, List[int]] = 1, ) -> List['cirq.TrialResult']: """Runs the supplied circuits. - Each circuit provided in `programs` will pair with the associated - parameter sweep provided in the `params_list`. The number of circuits - is required to match the number of sweeps. + Each circuit provided in `programs` will pair with the optional + associated parameter sweep provided in the `params_list`, and be run + with the associated repetitions provided in `repetitions` (if + `repetitions` is an integer then all runs will have that number of + repetitions). If `params_list` is specified, then the number of + circuits is required to match the number of sweeps. Similarly, when + `repetitions` is a list, the number of circuits is required to match + the length of this list. By default, this method simply invokes `run_sweep` sequentially for - each (circuit, parameter sweep) pair. Child classes that are capable of - sampling batches more efficiently should override it to use other - strategies. + each (circuit, parameter sweep, repetitions) tuple. Child classes that + are capable of sampling batches more efficiently should override it to + use other strategies. Note that child classes may have certain + requirements that must be met in order for a speedup to be possible, + such as a constant number of repetitions being used for all circuits. + Refer to the documentation of the child class for any such requirements. Args: programs: The circuits to execute as a batch. params_list: Parameter sweeps to use with the circuits. The number of sweeps should match the number of circuits and will be paired in order with the circuits. - repetitions: Number of circuit repetitions to run. Each sweep value - of each circuit in the batch will run with the same repetitions. + repetitions: Number of circuit repetitions to run. Can be specified + as a single value to use for all runs, or as a list of values, + one for each circuit. Returns: A list of TrialResults. All TrialResults for the first circuit are @@ -231,10 +240,16 @@ def run_batch( The TrialResults for a circuit are listed in the order imposed by the associated parameter sweep. """ - if not params_list or len(programs) != len(params_list): + if params_list is None: + params_list = [None] * len(programs) + if len(programs) != len(params_list): raise ValueError('Number of circuits and sweeps must match') + if isinstance(repetitions, int): + repetitions = [repetitions] * len(programs) + if len(programs) != len(repetitions): + raise ValueError('Number of circuits and repetitions must match') return [ - trial_result for circuit, params in zip(programs, params_list) - for trial_result in self.run_sweep( - circuit, params=params, repetitions=repetitions) + trial_result for circuit, params, repetitions in zip( + programs, params_list, repetitions) for trial_result in + self.run_sweep(circuit, params=params, repetitions=repetitions) ] diff --git a/cirq/work/sampler_test.py b/cirq/work/sampler_test.py index 2beea69e07e..5588d8ad27a 100644 --- a/cirq/work/sampler_test.py +++ b/cirq/work/sampler_test.py @@ -168,12 +168,32 @@ def test_sampler_run_batch(): params1 = cirq.Points('t', [0.3, 0.7]) params2 = cirq.Points('t', [0.4, 0.6]) results = sampler.run_batch([circuit1, circuit2], - params_list=[params1, params2]) + params_list=[params1, params2], + repetitions=[1, 2]) assert len(results) == 4 - for result, param in zip(results, [0.3, 0.7, 0.4, 0.6]): + for result, param in zip(results[:2], [0.3, 0.7]): assert result.repetitions == 1 assert result.params.param_dict == {'t': param} assert result.measurements == {'m': np.array([[0]], dtype='uint8')} + for result, param in zip(results[2:], [0.4, 0.6]): + assert result.repetitions == 2 + assert result.params.param_dict == {'t': param} + assert len(result.measurements) == 1 + assert np.array_equal(result.measurements['m'], + np.array([[0], [0]], dtype='uint8')) + + +def test_sampler_run_batch_default_params_and_repetitions(): + sampler = cirq.ZerosSampler() + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a), cirq.measure(a, key='m')) + circuit2 = cirq.Circuit(cirq.Y(a), cirq.measure(a, key='m')) + results = sampler.run_batch([circuit1, circuit2]) + assert len(results) == 2 + for result in results: + assert result.repetitions == 1 + assert result.params.param_dict == {} + assert result.measurements == {'m': np.array([[0]], dtype='uint8')} def test_sampler_run_batch_bad_input_lengths(): From 4db6642dff28f794e2eafb99bd9854f9b175c961 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Thu, 3 Sep 2020 15:34:22 -0400 Subject: [PATCH 5/8] add tests --- cirq/google/engine/engine_sampler_test.py | 36 +++++++++++++++++++++++ cirq/work/sampler_test.py | 9 ++++-- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/cirq/google/engine/engine_sampler_test.py b/cirq/google/engine/engine_sampler_test.py index 596934015b5..1ad0d651146 100644 --- a/cirq/google/engine/engine_sampler_test.py +++ b/cirq/google/engine/engine_sampler_test.py @@ -70,6 +70,42 @@ def test_run_batch(): repetitions=5) +def test_run_batch_identical_repetitions(): + engine = mock.Mock() + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + sampler.run_batch(circuits, params_list, [5, 5]) + engine.run_batch.assert_called_with(gate_set=cg.XMON, + params_list=params_list, + processor_ids=['tmp'], + programs=circuits, + repetitions=5) + + +def test_run_batch_number_of_repetitions(): + engine = mock.Mock() + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + with pytest.raises(ValueError): + sampler.run_batch(circuits, params_list, [5, 5, 5]) + + def test_run_batch_differing_repetitions(): engine = mock.Mock() job = mock.Mock() diff --git a/cirq/work/sampler_test.py b/cirq/work/sampler_test.py index 5588d8ad27a..05d458ccdde 100644 --- a/cirq/work/sampler_test.py +++ b/cirq/work/sampler_test.py @@ -203,6 +203,11 @@ def test_sampler_run_batch_bad_input_lengths(): cirq.X(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) circuit2 = cirq.Circuit( cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) - params = cirq.Points('t', [0.3, 0.7]) + params1 = cirq.Points('t', [0.3, 0.7]) + params2 = cirq.Points('t', [0.4, 0.6]) + with pytest.raises(ValueError): + _ = sampler.run_batch([circuit1, circuit2], params_list=[params1]) with pytest.raises(ValueError): - _ = sampler.run_batch([circuit1, circuit2], params_list=[params]) + _ = sampler.run_batch([circuit1, circuit2], + params_list=[params1, params2], + repetitions=[1, 2, 3]) From eff62ed3ca5131bff0371148031bec45bb438836 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Thu, 10 Sep 2020 21:06:08 -0400 Subject: [PATCH 6/8] return list of lists --- cirq/google/engine/engine_sampler.py | 4 ++-- cirq/google/engine/engine_sampler_test.py | 2 +- cirq/work/sampler.py | 14 +++++++------- cirq/work/sampler_test.py | 10 ++++++---- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/cirq/google/engine/engine_sampler.py b/cirq/google/engine/engine_sampler.py index 7794d755844..2f44cccc03f 100644 --- a/cirq/google/engine/engine_sampler.py +++ b/cirq/google/engine/engine_sampler.py @@ -66,7 +66,7 @@ def run_batch( programs: List['cirq.Circuit'], params_list: Optional[List['cirq.Sweepable']] = None, repetitions: Union[int, List[int]] = 1, - ) -> List['cirq.TrialResult']: + ) -> List[List['cirq.TrialResult']]: """Runs the supplied circuits. In order to gain a speedup from using this method instead of other run @@ -85,7 +85,7 @@ def run_batch( repetitions=repetitions, processor_ids=self._processor_ids, gate_set=self._gate_set) - return job.results() + return job.batched_results() # Varying number of repetitions so no speedup return super().run_batch(programs, params_list, repetitions) diff --git a/cirq/google/engine/engine_sampler_test.py b/cirq/google/engine/engine_sampler_test.py index 1ad0d651146..5ba2c051dd9 100644 --- a/cirq/google/engine/engine_sampler_test.py +++ b/cirq/google/engine/engine_sampler_test.py @@ -90,7 +90,7 @@ def test_run_batch_identical_repetitions(): repetitions=5) -def test_run_batch_number_of_repetitions(): +def test_run_batch_bad_number_of_repetitions(): engine = mock.Mock() sampler = cg.QuantumEngineSampler(engine=engine, processor_id='tmp', diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index 7d98c59811a..fcf6f3b8e72 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -205,7 +205,7 @@ def run_batch( programs: List['cirq.Circuit'], params_list: Optional[List['cirq.Sweepable']] = None, repetitions: Union[int, List[int]] = 1, - ) -> List['cirq.TrialResult']: + ) -> List[List['cirq.TrialResult']]: """Runs the supplied circuits. Each circuit provided in `programs` will pair with the optional @@ -235,10 +235,10 @@ def run_batch( one for each circuit. Returns: - A list of TrialResults. All TrialResults for the first circuit are - listed first, then the TrialResults for the second, etc. - The TrialResults for a circuit are listed in the order imposed by - the associated parameter sweep. + A list of lists of TrialResults. The outer list corresponds to + the circuits, while each inner list contains the TrialResults + for the corresponding circuit, in the order imposed by the + associated parameter sweep. """ if params_list is None: params_list = [None] * len(programs) @@ -249,7 +249,7 @@ def run_batch( if len(programs) != len(repetitions): raise ValueError('Number of circuits and repetitions must match') return [ - trial_result for circuit, params, repetitions in zip( - programs, params_list, repetitions) for trial_result in self.run_sweep(circuit, params=params, repetitions=repetitions) + for circuit, params, repetitions in zip(programs, params_list, + repetitions) ] diff --git a/cirq/work/sampler_test.py b/cirq/work/sampler_test.py index 05d458ccdde..391de299d84 100644 --- a/cirq/work/sampler_test.py +++ b/cirq/work/sampler_test.py @@ -170,12 +170,12 @@ def test_sampler_run_batch(): results = sampler.run_batch([circuit1, circuit2], params_list=[params1, params2], repetitions=[1, 2]) - assert len(results) == 4 - for result, param in zip(results[:2], [0.3, 0.7]): + assert len(results) == 2 + for result, param in zip(results[0], [0.3, 0.7]): assert result.repetitions == 1 assert result.params.param_dict == {'t': param} assert result.measurements == {'m': np.array([[0]], dtype='uint8')} - for result, param in zip(results[2:], [0.4, 0.6]): + for result, param in zip(results[1], [0.4, 0.6]): assert result.repetitions == 2 assert result.params.param_dict == {'t': param} assert len(result.measurements) == 1 @@ -190,7 +190,9 @@ def test_sampler_run_batch_default_params_and_repetitions(): circuit2 = cirq.Circuit(cirq.Y(a), cirq.measure(a, key='m')) results = sampler.run_batch([circuit1, circuit2]) assert len(results) == 2 - for result in results: + for result_list in results: + assert len(result_list) == 1 + result = result_list[0] assert result.repetitions == 1 assert result.params.param_dict == {} assert result.measurements == {'m': np.array([[0]], dtype='uint8')} From d80a5e7cd35cd4dd4ab2b24cfd6b7c49b544b624 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Fri, 11 Sep 2020 14:43:57 -0400 Subject: [PATCH 7/8] update Engine batching requirements and improve error messages --- cirq/google/engine/engine_sampler.py | 11 +++++++---- cirq/google/engine/engine_sampler_test.py | 2 +- cirq/work/sampler.py | 6 ++++-- cirq/work/sampler_test.py | 4 ++-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/cirq/google/engine/engine_sampler.py b/cirq/google/engine/engine_sampler.py index 2f44cccc03f..2ebc3802533 100644 --- a/cirq/google/engine/engine_sampler.py +++ b/cirq/google/engine/engine_sampler.py @@ -70,12 +70,15 @@ def run_batch( """Runs the supplied circuits. In order to gain a speedup from using this method instead of other run - methods, the number of circuit repetitions must be the same for all - circuits. That is, the `repetitions` argument must be an integer, or - else a list with identical values. + methods, the following conditions must be satisfied: + 1. All circuits must measure the same set of qubits. + 2. The number of circuit repetitions must be the same for all + circuits. That is, the `repetitions` argument must be an integer, + or else a list with identical values. """ if isinstance(repetitions, List) and len(programs) != len(repetitions): - raise ValueError('Number of circuits and repetitions must match') + raise ValueError('len(programs) and len(repetitions) must match. ' + f'Got {len(programs)} and {len(repetitions)}.') if isinstance(repetitions, int) or len(set(repetitions)) == 1: # All repetitions are the same so batching can be done efficiently if isinstance(repetitions, List): diff --git a/cirq/google/engine/engine_sampler_test.py b/cirq/google/engine/engine_sampler_test.py index 5ba2c051dd9..4b8d6998f6d 100644 --- a/cirq/google/engine/engine_sampler_test.py +++ b/cirq/google/engine/engine_sampler_test.py @@ -102,7 +102,7 @@ def test_run_batch_bad_number_of_repetitions(): params2 = [cirq.ParamResolver({'t': 2})] circuits = [circuit1, circuit2] params_list = [params1, params2] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='2 and 3'): sampler.run_batch(circuits, params_list, [5, 5, 5]) diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index fcf6f3b8e72..4690f508703 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -243,11 +243,13 @@ def run_batch( if params_list is None: params_list = [None] * len(programs) if len(programs) != len(params_list): - raise ValueError('Number of circuits and sweeps must match') + raise ValueError('len(programs) and len(params_list) must match. ' + f'Got {len(programs)} and {len(params_list)}.') if isinstance(repetitions, int): repetitions = [repetitions] * len(programs) if len(programs) != len(repetitions): - raise ValueError('Number of circuits and repetitions must match') + raise ValueError('len(programs) and len(repetitions) must match. ' + f'Got {len(programs)} and {len(repetitions)}.') return [ self.run_sweep(circuit, params=params, repetitions=repetitions) for circuit, params, repetitions in zip(programs, params_list, diff --git a/cirq/work/sampler_test.py b/cirq/work/sampler_test.py index 391de299d84..1fc4a4aeb24 100644 --- a/cirq/work/sampler_test.py +++ b/cirq/work/sampler_test.py @@ -207,9 +207,9 @@ def test_sampler_run_batch_bad_input_lengths(): cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) params1 = cirq.Points('t', [0.3, 0.7]) params2 = cirq.Points('t', [0.4, 0.6]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='2 and 1'): _ = sampler.run_batch([circuit1, circuit2], params_list=[params1]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='2 and 3'): _ = sampler.run_batch([circuit1, circuit2], params_list=[params1, params2], repetitions=[1, 2, 3]) From 909fec3787641bca3b1695cd9d8969374627aa08 Mon Sep 17 00:00:00 2001 From: "Kevin J. Sung" Date: Fri, 11 Sep 2020 14:49:09 -0400 Subject: [PATCH 8/8] comma --- cirq/work/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index 4690f508703..98c960b0313 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -211,7 +211,7 @@ def run_batch( Each circuit provided in `programs` will pair with the optional associated parameter sweep provided in the `params_list`, and be run with the associated repetitions provided in `repetitions` (if - `repetitions` is an integer then all runs will have that number of + `repetitions` is an integer, then all runs will have that number of repetitions). If `params_list` is specified, then the number of circuits is required to match the number of sweeps. Similarly, when `repetitions` is a list, the number of circuits is required to match