Skip to content

Commit

Permalink
Use device instead of gateset validation in ZerosSampler (#3142)
Browse files Browse the repository at this point in the history
This introduces a dependency of cirq on cirq.google, which we want to avoid.  If we want a gate set validating version of this we should put it in cirq.google.
  • Loading branch information
dabacon committed Jul 22, 2020
1 parent 6f24531 commit 7bf16e1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 35 deletions.
33 changes: 17 additions & 16 deletions cirq/work/zeros_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,27 @@
# limitations under the License.

import abc
from typing import (Dict, List, TYPE_CHECKING)
from typing import Dict, List, TYPE_CHECKING

import numpy as np

from cirq import work, study, protocols
from cirq import devices, work, study, protocols

if TYPE_CHECKING:
import cirq.google
import cirq


class ZerosSampler(work.Sampler, metaclass=abc.ABCMeta):
"""A dummy sampler for testing. Immediately returns zeroes."""

def __init__(self, gate_set: 'cirq.google.SerializableGateSet' = None):
"""
def __init__(self, device: devices.Device = None):
"""Construct a sampler that returns 0 for all measurements.
Args:
gate_set: `SerializableGateSet`. If set, sampler will validate that
all gates in the circuit are from the given gate set.
device: A device against which to validate the circuit. If None,
no validation will be done.
"""
self.gate_set = gate_set
self.device = device

def run_sweep(
self,
Expand All @@ -50,20 +51,20 @@ def run_sweep(
Returns:
TrialResult list for this run; one for each possible parameter
resolver.
"""
if self.gate_set is not None:
for op in program.all_operations():
assert self.gate_set.is_supported_operation(op), (
"Unsupported operation: %s" % op)
Raises:
ValueError if this sampler has a device and the circuit is not
valid for the device.
"""
if self.device:
self.device.validate_circuit(program)
measurements = {} # type: Dict[str, np.ndarray]
for op in program.all_operations():
key = protocols.measurement_key(op, default=None)
if key is not None:
measurements[key] = np.zeros((repetitions, len(op.qubits)),
dtype=np.int8)
dtype=int)
return [
study.TrialResult.from_single_parameter_set(
params=param_resolver, measurements=measurements)
study.TrialResult(params=param_resolver, measurements=measurements)
for param_resolver in study.to_resolvers(params)
]
40 changes: 21 additions & 19 deletions cirq/work/zeros_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@
import sympy

import cirq
from cirq import study
from cirq.google.common_serializers import (SINGLE_QUBIT_SERIALIZERS,
SINGLE_QUBIT_DESERIALIZERS)
from cirq.google.serializable_gate_set import SerializableGateSet


def test_run_sweep():
Expand Down Expand Up @@ -48,25 +44,31 @@ def test_sample():
c += [cirq.measure(q) for q in qs[0:3]]
c += cirq.measure(qs[4], qs[5])
# Z to even power is an identity.
params = study.Points(sympy.Symbol('p'), [0, 2, 4, 6])
params = cirq.Points(sympy.Symbol('p'), [0, 2, 4, 6])

result1 = cirq.ZerosSampler().sample(c, repetitions=10,
params=params).sort_index(axis=1)
result2 = cirq.Simulator().sample(c, repetitions=10,
params=params).sort_index(axis=1)
result1 = cirq.ZerosSampler().sample(c, repetitions=10, params=params)
result2 = cirq.Simulator().sample(c, repetitions=10, params=params)

assert np.all(result1 == result2)


def test_sample_with_gate_set():
gate_set = SerializableGateSet('test', SINGLE_QUBIT_SERIALIZERS,
SINGLE_QUBIT_DESERIALIZERS)
sampler = cirq.ZerosSampler(gate_set=gate_set)
a, b = cirq.LineQubit.range(2)
circuit1 = cirq.Circuit([cirq.X(a)])
circuit2 = cirq.Circuit([cirq.CX(a, b)])
class OnlyMeasurementsDevice(cirq.Device):

sampler.sample(circuit1)
def validate_operation(self, operation: 'cirq.Operation') -> None:
if not cirq.is_measurement(operation):
raise ValueError(f'{operation} is not a measurement and this '
f'device only measures!')

with pytest.raises(AssertionError, match='Unsupported operation'):
sampler.sample(circuit2)

def test_validate_device():
device = OnlyMeasurementsDevice()
sampler = cirq.ZerosSampler(device)

a, b, c = [cirq.NamedQubit(s) for s in ['a', 'b', 'c']]
circuit = cirq.Circuit(cirq.measure(a), cirq.measure(b, c))

_ = sampler.run_sweep(circuit, None, 3)

circuit = cirq.Circuit(cirq.measure(a), cirq.X(b))
with pytest.raises(ValueError, match=r'X\(b\) is not a measurement'):
_ = sampler.run_sweep(circuit, None, 3)

0 comments on commit 7bf16e1

Please sign in to comment.