Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions cirq/google/engine/engine_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,37 @@ def run_sweep(
gate_set=self._gate_set)
return job.results()

def run_batch(
self,
programs: List['cirq.Circuit'],
params_list: Optional[List['cirq.Sweepable']] = None,
repetitions: Union[int, List[int]] = 1,
) -> List[List['cirq.TrialResult']]:
"""Runs the supplied circuits.

In order to gain a speedup from using this method instead of other run
methods, the following conditions must be satisfied:
1. All circuits must measure the same set of qubits.
2. The number of circuit repetitions must be the same for all
circuits. That is, the `repetitions` argument must be an integer,
or else a list with identical values.
"""
if isinstance(repetitions, List) and len(programs) != len(repetitions):
raise ValueError('len(programs) and len(repetitions) must match. '
f'Got {len(programs)} and {len(repetitions)}.')
if isinstance(repetitions, int) or len(set(repetitions)) == 1:
# All repetitions are the same so batching can be done efficiently
if isinstance(repetitions, List):
repetitions = repetitions[0]
job = self._engine.run_batch(programs=programs,
params_list=params_list,
repetitions=repetitions,
processor_ids=self._processor_ids,
gate_set=self._gate_set)
return job.batched_results()
# Varying number of repetitions so no speedup
return super().run_batch(programs, params_list, repetitions)

@property
def engine(self) -> 'cirq.google.Engine':
return self._engine
Expand Down
81 changes: 81 additions & 0 deletions cirq/google/engine/engine_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,87 @@ def test_run_engine_program():
engine.run_sweep.assert_not_called()


def test_run_batch():
engine = mock.Mock()
sampler = cg.QuantumEngineSampler(engine=engine,
processor_id='tmp',
gate_set=cg.XMON)
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(cirq.X(a))
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
sampler.run_batch(circuits, params_list, 5)
engine.run_batch.assert_called_with(gate_set=cg.XMON,
params_list=params_list,
processor_ids=['tmp'],
programs=circuits,
repetitions=5)


def test_run_batch_identical_repetitions():
engine = mock.Mock()
sampler = cg.QuantumEngineSampler(engine=engine,
processor_id='tmp',
gate_set=cg.XMON)
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(cirq.X(a))
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
sampler.run_batch(circuits, params_list, [5, 5])
engine.run_batch.assert_called_with(gate_set=cg.XMON,
params_list=params_list,
processor_ids=['tmp'],
programs=circuits,
repetitions=5)


def test_run_batch_bad_number_of_repetitions():
engine = mock.Mock()
sampler = cg.QuantumEngineSampler(engine=engine,
processor_id='tmp',
gate_set=cg.XMON)
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(cirq.X(a))
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
with pytest.raises(ValueError, match='2 and 3'):
sampler.run_batch(circuits, params_list, [5, 5, 5])


def test_run_batch_differing_repetitions():
engine = mock.Mock()
job = mock.Mock()
job.results.return_value = []
engine.run_sweep.return_value = job
sampler = cg.QuantumEngineSampler(engine=engine,
processor_id='tmp',
gate_set=cg.XMON)
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(cirq.X(a))
circuit2 = cirq.Circuit(cirq.Y(a))
params1 = [cirq.ParamResolver({'t': 1})]
params2 = [cirq.ParamResolver({'t': 2})]
circuits = [circuit1, circuit2]
params_list = [params1, params2]
repetitions = [1, 2]
sampler.run_batch(circuits, params_list, repetitions)
engine.run_sweep.assert_called_with(gate_set=cg.XMON,
params=params2,
processor_ids=['tmp'],
program=circuit2,
repetitions=2)
engine.run_batch.assert_not_called()


def test_engine_sampler_engine_property():
engine = mock.Mock()
sampler = cg.QuantumEngineSampler(engine=engine,
Expand Down
58 changes: 57 additions & 1 deletion cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Abstract base class for things sampling quantum circuits."""

from typing import List, TYPE_CHECKING
from typing import List, Optional, TYPE_CHECKING, Union
import abc

import pandas as pd
Expand Down Expand Up @@ -199,3 +199,59 @@ async def run_sweep_async(
An awaitable TrialResult.
"""
return self.run_sweep(program, params=params, repetitions=repetitions)

def run_batch(
self,
programs: List['cirq.Circuit'],
params_list: Optional[List['cirq.Sweepable']] = None,
repetitions: Union[int, List[int]] = 1,
) -> List[List['cirq.TrialResult']]:
"""Runs the supplied circuits.

Each circuit provided in `programs` will pair with the optional
associated parameter sweep provided in the `params_list`, and be run
with the associated repetitions provided in `repetitions` (if
`repetitions` is an integer, then all runs will have that number of
repetitions). If `params_list` is specified, then the number of
circuits is required to match the number of sweeps. Similarly, when
`repetitions` is a list, the number of circuits is required to match
the length of this list.

By default, this method simply invokes `run_sweep` sequentially for
each (circuit, parameter sweep, repetitions) tuple. Child classes that
are capable of sampling batches more efficiently should override it to
use other strategies. Note that child classes may have certain
requirements that must be met in order for a speedup to be possible,
such as a constant number of repetitions being used for all circuits.
Refer to the documentation of the child class for any such requirements.

Args:
programs: The circuits to execute as a batch.
params_list: Parameter sweeps to use with the circuits. The number
of sweeps should match the number of circuits and will be
paired in order with the circuits.
repetitions: Number of circuit repetitions to run. Can be specified
as a single value to use for all runs, or as a list of values,
one for each circuit.

Returns:
A list of lists of TrialResults. The outer list corresponds to
the circuits, while each inner list contains the TrialResults
for the corresponding circuit, in the order imposed by the
associated parameter sweep.
"""
if params_list is None:
params_list = [None] * len(programs)
if len(programs) != len(params_list):
raise ValueError('len(programs) and len(params_list) must match. '
f'Got {len(programs)} and {len(params_list)}.')
if isinstance(repetitions, int):
repetitions = [repetitions] * len(programs)
if len(programs) != len(repetitions):
raise ValueError('len(programs) and len(repetitions) must match. '
f'Got {len(programs)} and {len(repetitions)}.')
return [
self.run_sweep(circuit, params=params, repetitions=repetitions)
for circuit, params, repetitions in zip(programs, params_list,
repetitions)
]
58 changes: 58 additions & 0 deletions cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Tests for cirq.Sampler."""
import pytest

import numpy as np
import pandas as pd
import sympy

Expand Down Expand Up @@ -155,3 +156,60 @@ def run_sweep(self, *args, **kwargs):
assert not ran
assert await a == []
assert ran


def test_sampler_run_batch():
sampler = cirq.ZerosSampler()
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(
cirq.X(a)**sympy.Symbol('t'), cirq.measure(a, key='m'))
circuit2 = cirq.Circuit(
cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m'))
params1 = cirq.Points('t', [0.3, 0.7])
params2 = cirq.Points('t', [0.4, 0.6])
results = sampler.run_batch([circuit1, circuit2],
params_list=[params1, params2],
repetitions=[1, 2])
assert len(results) == 2
for result, param in zip(results[0], [0.3, 0.7]):
assert result.repetitions == 1
assert result.params.param_dict == {'t': param}
assert result.measurements == {'m': np.array([[0]], dtype='uint8')}
for result, param in zip(results[1], [0.4, 0.6]):
assert result.repetitions == 2
assert result.params.param_dict == {'t': param}
assert len(result.measurements) == 1
assert np.array_equal(result.measurements['m'],
np.array([[0], [0]], dtype='uint8'))


def test_sampler_run_batch_default_params_and_repetitions():
sampler = cirq.ZerosSampler()
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(cirq.X(a), cirq.measure(a, key='m'))
circuit2 = cirq.Circuit(cirq.Y(a), cirq.measure(a, key='m'))
results = sampler.run_batch([circuit1, circuit2])
assert len(results) == 2
for result_list in results:
assert len(result_list) == 1
result = result_list[0]
assert result.repetitions == 1
assert result.params.param_dict == {}
assert result.measurements == {'m': np.array([[0]], dtype='uint8')}


def test_sampler_run_batch_bad_input_lengths():
sampler = cirq.ZerosSampler()
a = cirq.LineQubit(0)
circuit1 = cirq.Circuit(
cirq.X(a)**sympy.Symbol('t'), cirq.measure(a, key='m'))
circuit2 = cirq.Circuit(
cirq.Y(a)**sympy.Symbol('t'), cirq.measure(a, key='m'))
params1 = cirq.Points('t', [0.3, 0.7])
params2 = cirq.Points('t', [0.4, 0.6])
with pytest.raises(ValueError, match='2 and 1'):
_ = sampler.run_batch([circuit1, circuit2], params_list=[params1])
with pytest.raises(ValueError, match='2 and 3'):
_ = sampler.run_batch([circuit1, circuit2],
params_list=[params1, params2],
repetitions=[1, 2, 3])