diff --git a/cirq/google/engine/engine_sampler.py b/cirq/google/engine/engine_sampler.py index 4ec77075086..2ebc3802533 100644 --- a/cirq/google/engine/engine_sampler.py +++ b/cirq/google/engine/engine_sampler.py @@ -61,6 +61,37 @@ def run_sweep( gate_set=self._gate_set) return job.results() + def run_batch( + self, + programs: List['cirq.Circuit'], + params_list: Optional[List['cirq.Sweepable']] = None, + repetitions: Union[int, List[int]] = 1, + ) -> List[List['cirq.TrialResult']]: + """Runs the supplied circuits. + + In order to gain a speedup from using this method instead of other run + methods, the following conditions must be satisfied: + 1. All circuits must measure the same set of qubits. + 2. The number of circuit repetitions must be the same for all + circuits. That is, the `repetitions` argument must be an integer, + or else a list with identical values. + """ + if isinstance(repetitions, List) and len(programs) != len(repetitions): + raise ValueError('len(programs) and len(repetitions) must match. ' + f'Got {len(programs)} and {len(repetitions)}.') + if isinstance(repetitions, int) or len(set(repetitions)) == 1: + # All repetitions are the same so batching can be done efficiently + if isinstance(repetitions, List): + repetitions = repetitions[0] + job = self._engine.run_batch(programs=programs, + params_list=params_list, + repetitions=repetitions, + processor_ids=self._processor_ids, + gate_set=self._gate_set) + return job.batched_results() + # Varying number of repetitions so no speedup + return super().run_batch(programs, params_list, repetitions) + @property def engine(self) -> 'cirq.google.Engine': return self._engine diff --git a/cirq/google/engine/engine_sampler_test.py b/cirq/google/engine/engine_sampler_test.py index aa6b5e35849..4b8d6998f6d 100644 --- a/cirq/google/engine/engine_sampler_test.py +++ b/cirq/google/engine/engine_sampler_test.py @@ -50,6 +50,87 @@ def test_run_engine_program(): engine.run_sweep.assert_not_called() +def test_run_batch(): + engine = mock.Mock() + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + sampler.run_batch(circuits, params_list, 5) + engine.run_batch.assert_called_with(gate_set=cg.XMON, + params_list=params_list, + processor_ids=['tmp'], + programs=circuits, + repetitions=5) + + +def test_run_batch_identical_repetitions(): + engine = mock.Mock() + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + sampler.run_batch(circuits, params_list, [5, 5]) + engine.run_batch.assert_called_with(gate_set=cg.XMON, + params_list=params_list, + processor_ids=['tmp'], + programs=circuits, + repetitions=5) + + +def test_run_batch_bad_number_of_repetitions(): + engine = mock.Mock() + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + with pytest.raises(ValueError, match='2 and 3'): + sampler.run_batch(circuits, params_list, [5, 5, 5]) + + +def test_run_batch_differing_repetitions(): + engine = mock.Mock() + job = mock.Mock() + job.results.return_value = [] + engine.run_sweep.return_value = job + sampler = cg.QuantumEngineSampler(engine=engine, + processor_id='tmp', + gate_set=cg.XMON) + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a)) + circuit2 = cirq.Circuit(cirq.Y(a)) + params1 = [cirq.ParamResolver({'t': 1})] + params2 = [cirq.ParamResolver({'t': 2})] + circuits = [circuit1, circuit2] + params_list = [params1, params2] + repetitions = [1, 2] + sampler.run_batch(circuits, params_list, repetitions) + engine.run_sweep.assert_called_with(gate_set=cg.XMON, + params=params2, + processor_ids=['tmp'], + program=circuit2, + repetitions=2) + engine.run_batch.assert_not_called() + + def test_engine_sampler_engine_property(): engine = mock.Mock() sampler = cg.QuantumEngineSampler(engine=engine, diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index c730bd14679..98c960b0313 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -13,7 +13,7 @@ # limitations under the License. """Abstract base class for things sampling quantum circuits.""" -from typing import List, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Union import abc import pandas as pd @@ -199,3 +199,59 @@ async def run_sweep_async( An awaitable TrialResult. """ return self.run_sweep(program, params=params, repetitions=repetitions) + + def run_batch( + self, + programs: List['cirq.Circuit'], + params_list: Optional[List['cirq.Sweepable']] = None, + repetitions: Union[int, List[int]] = 1, + ) -> List[List['cirq.TrialResult']]: + """Runs the supplied circuits. + + Each circuit provided in `programs` will pair with the optional + associated parameter sweep provided in the `params_list`, and be run + with the associated repetitions provided in `repetitions` (if + `repetitions` is an integer, then all runs will have that number of + repetitions). If `params_list` is specified, then the number of + circuits is required to match the number of sweeps. Similarly, when + `repetitions` is a list, the number of circuits is required to match + the length of this list. + + By default, this method simply invokes `run_sweep` sequentially for + each (circuit, parameter sweep, repetitions) tuple. Child classes that + are capable of sampling batches more efficiently should override it to + use other strategies. Note that child classes may have certain + requirements that must be met in order for a speedup to be possible, + such as a constant number of repetitions being used for all circuits. + Refer to the documentation of the child class for any such requirements. + + Args: + programs: The circuits to execute as a batch. + params_list: Parameter sweeps to use with the circuits. The number + of sweeps should match the number of circuits and will be + paired in order with the circuits. + repetitions: Number of circuit repetitions to run. Can be specified + as a single value to use for all runs, or as a list of values, + one for each circuit. + + Returns: + A list of lists of TrialResults. The outer list corresponds to + the circuits, while each inner list contains the TrialResults + for the corresponding circuit, in the order imposed by the + associated parameter sweep. + """ + if params_list is None: + params_list = [None] * len(programs) + if len(programs) != len(params_list): + raise ValueError('len(programs) and len(params_list) must match. ' + f'Got {len(programs)} and {len(params_list)}.') + if isinstance(repetitions, int): + repetitions = [repetitions] * len(programs) + if len(programs) != len(repetitions): + raise ValueError('len(programs) and len(repetitions) must match. ' + f'Got {len(programs)} and {len(repetitions)}.') + return [ + self.run_sweep(circuit, params=params, repetitions=repetitions) + for circuit, params, repetitions in zip(programs, params_list, + repetitions) + ] diff --git a/cirq/work/sampler_test.py b/cirq/work/sampler_test.py index 29e9baffd89..1fc4a4aeb24 100644 --- a/cirq/work/sampler_test.py +++ b/cirq/work/sampler_test.py @@ -14,6 +14,7 @@ """Tests for cirq.Sampler.""" import pytest +import numpy as np import pandas as pd import sympy @@ -155,3 +156,60 @@ def run_sweep(self, *args, **kwargs): assert not ran assert await a == [] assert ran + + +def test_sampler_run_batch(): + sampler = cirq.ZerosSampler() + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit( + cirq.X(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) + circuit2 = cirq.Circuit( + cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) + params1 = cirq.Points('t', [0.3, 0.7]) + params2 = cirq.Points('t', [0.4, 0.6]) + results = sampler.run_batch([circuit1, circuit2], + params_list=[params1, params2], + repetitions=[1, 2]) + assert len(results) == 2 + for result, param in zip(results[0], [0.3, 0.7]): + assert result.repetitions == 1 + assert result.params.param_dict == {'t': param} + assert result.measurements == {'m': np.array([[0]], dtype='uint8')} + for result, param in zip(results[1], [0.4, 0.6]): + assert result.repetitions == 2 + assert result.params.param_dict == {'t': param} + assert len(result.measurements) == 1 + assert np.array_equal(result.measurements['m'], + np.array([[0], [0]], dtype='uint8')) + + +def test_sampler_run_batch_default_params_and_repetitions(): + sampler = cirq.ZerosSampler() + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.X(a), cirq.measure(a, key='m')) + circuit2 = cirq.Circuit(cirq.Y(a), cirq.measure(a, key='m')) + results = sampler.run_batch([circuit1, circuit2]) + assert len(results) == 2 + for result_list in results: + assert len(result_list) == 1 + result = result_list[0] + assert result.repetitions == 1 + assert result.params.param_dict == {} + assert result.measurements == {'m': np.array([[0]], dtype='uint8')} + + +def test_sampler_run_batch_bad_input_lengths(): + sampler = cirq.ZerosSampler() + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit( + cirq.X(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) + circuit2 = cirq.Circuit( + cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m')) + params1 = cirq.Points('t', [0.3, 0.7]) + params2 = cirq.Points('t', [0.4, 0.6]) + with pytest.raises(ValueError, match='2 and 1'): + _ = sampler.run_batch([circuit1, circuit2], params_list=[params1]) + with pytest.raises(ValueError, match='2 and 3'): + _ = sampler.run_batch([circuit1, circuit2], + params_list=[params1, params2], + repetitions=[1, 2, 3])