Skip to content

Commit

Permalink
remove cirq.google dependency from three qubit gates test (#3928)
Browse files Browse the repository at this point in the history
Moves ValidatingTestDevice to cirq.testing and it is used in `three_qubit_gates_test`.
This helps removing cirq.google references from cirq. 

This is related to #3737.
  • Loading branch information
balopat committed Mar 18, 2021
1 parent fa55cbb commit 2128257
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 82 deletions.
121 changes: 43 additions & 78 deletions cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
from collections import defaultdict
from random import randint, random, sample, randrange
from typing import Tuple, cast, AbstractSet
from typing import Tuple

import numpy as np
import pytest
Expand All @@ -23,6 +23,46 @@
import cirq
import cirq.testing
from cirq import ops
from cirq.testing.devices import ValidatingTestDevice


class _Foxy(ValidatingTestDevice):
def can_add_operation_into_moment(
self, operation: 'ops.Operation', moment: 'ops.Moment'
) -> bool:
if not super().can_add_operation_into_moment(operation, moment):
return False
# a fake rule for ensuring that no two CZs are executed at the same moment.
# this will ensure that CZs are always in separate moments in this device
return not (
isinstance(operation.gate, ops.CZPowGate)
and any(isinstance(op.gate, ops.CZPowGate) for op in moment.operations)
)


FOXY = _Foxy(
allowed_qubit_types=(cirq.GridQubit,),
allowed_gates=(
ops.CZPowGate,
ops.XPowGate,
ops.YPowGate,
ops.ZPowGate,
),
qubits=set(cirq.GridQubit.rect(2, 7)),
name=f'{__name__}.FOXY',
auto_decompose_gates=(ops.CCXPowGate,),
validate_locality=True,
)


BCONE = ValidatingTestDevice(
allowed_qubit_types=(cirq.GridQubit,),
allowed_gates=(ops.XPowGate,),
qubits={
cirq.GridQubit(0, 6),
},
name=f'{__name__}.BCONE',
)


class _MomentAndOpTypeValidatingDeviceType(cirq.Device):
Expand Down Expand Up @@ -341,7 +381,7 @@ def test_repr(circuit_cls):

c = circuit_cls(device=FOXY)
cirq.testing.assert_equivalent_repr(c)
assert repr(c) == f'cirq.{circuit_cls.__name__}(device=cirq.circuits.circuit_test.FOXY)'
assert repr(c) == f'cirq.{circuit_cls.__name__}(device={repr(FOXY)})'

c = circuit_cls(cirq.Z(cirq.GridQubit(0, 0)), device=FOXY)
cirq.testing.assert_equivalent_repr(c)
Expand All @@ -351,7 +391,7 @@ def test_repr(circuit_cls):
cirq.Moment(
cirq.Z(cirq.GridQubit(0, 0)),
),
], device=cirq.circuits.circuit_test.FOXY)"""
], device={repr(FOXY)})"""
)


Expand Down Expand Up @@ -584,81 +624,6 @@ def test_concatenate_with_device():
assert len(cone) == 0


class ValidatingTestDevice(cirq.Device):
"""A fake device that was created to ensure certain Device validation features are
leveraged in Circuit functions. It contains the minimum set of features that tests
require. Feel free to extend the features here as needed."""

def __init__(
self,
allowed_qubit_types: Tuple[type, ...],
allowed_gates: Tuple[type, ...],
qubits: AbstractSet[cirq.Qid],
name: str,
):
self.allowed_qubit_types = allowed_qubit_types
self.allowed_gates = allowed_gates
self.qubits = qubits
self._repr = name

def validate_operation(self, operation: cirq.Operation) -> None:
# This is pretty close to what the cirq.google.XmonDevice has for validation
for q in operation.qubits:
if not isinstance(q, self.allowed_qubit_types):
raise ValueError(f"Unsupported qubit type: {type(q)!r}")
if q not in self.qubits:
raise ValueError(f'Qubit not on device: {q!r}')
if not isinstance(operation.gate, self.allowed_gates):
raise ValueError(f"Unsupported gate type: {operation.gate!r}")
if len(operation.qubits) == 2 and not isinstance(operation.gate, ops.MeasurementGate):
p, q = operation.qubits
if not cast(cirq.GridQubit, p).is_adjacent(q):
raise ValueError(f'Non-local interaction: {operation!r}.')

def decompose_operation(self, operation: 'cirq.Operation') -> 'cirq.OP_TREE':
# a fake decomposer for only TOFFOLI gates
if isinstance(operation.gate, cirq.CCXPowGate):
return cirq.decompose(operation)
return operation

def can_add_operation_into_moment(
self, operation: 'cirq.Operation', moment: 'cirq.Moment'
) -> bool:
if not super().can_add_operation_into_moment(operation, moment):
return False
# a fake rule for ensuring that no two CZs are executed at the same moment.
# this will ensure that CZs are always in separate moments in this device11
return not (
isinstance(operation.gate, cirq.CZPowGate)
and any(isinstance(op.gate, cirq.CZPowGate) for op in moment.operations)
)

def __repr__(self):
return self._repr


FOXY = ValidatingTestDevice(
allowed_qubit_types=(cirq.GridQubit,),
allowed_gates=(
ops.CZPowGate,
ops.XPowGate,
ops.YPowGate,
ops.ZPowGate,
),
qubits=set(cirq.GridQubit.rect(2, 7)),
name='cirq.circuits.circuit_test.FOXY',
)

BCONE = ValidatingTestDevice(
allowed_qubit_types=(cirq.GridQubit,),
allowed_gates=(cirq.XPowGate,),
qubits={
cirq.GridQubit(0, 6),
},
name='cirq.circuits.circuit_test.BCONE',
)


@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_with_device(circuit_cls):

Expand Down
8 changes: 4 additions & 4 deletions cirq/ops/three_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def test_identity_multiplication():
],
)
def test_decomposition_cost(op: cirq.Operation, max_two_cost: int):
ops = tuple(cirq.flatten_op_tree(cirq.google.ConvertToXmonGates().convert(op)))
ops = tuple(cirq.flatten_op_tree(cirq.decompose(op)))
two_cost = len([e for e in ops if len(e.qubits) == 2])
over_cost = len([e for e in ops if len(e.qubits) > 2])
assert over_cost == 0
Expand All @@ -222,11 +222,11 @@ def test_decomposition_respects_locality(gate):
a = cirq.GridQubit(0, 0)
b = cirq.GridQubit(1, 0)
c = cirq.GridQubit(0, 1)

dev = cirq.testing.ValidatingTestDevice(qubits={a, b, c}, validate_locality=True)
for x, y, z in itertools.permutations([a, b, c]):
circuit = cirq.Circuit(gate(x, y, z))
cirq.google.ConvertToXmonGates().optimize_circuit(circuit)
cirq.google.Foxtail.validate_circuit(circuit)
circuit = cirq.Circuit(cirq.decompose(circuit))
dev.validate_circuit(circuit)


def test_diagram():
Expand Down
4 changes: 4 additions & 0 deletions cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
assert_specifies_has_unitary_if_unitary,
)

from cirq.testing.devices import (
ValidatingTestDevice,
)

from cirq.testing.equals_tester import (
EqualsTester,
)
Expand Down
75 changes: 75 additions & 0 deletions cirq/testing/devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides test devices that can validate circuits."""
from typing import Tuple, AbstractSet, cast

from cirq import devices, ops, protocols


class ValidatingTestDevice(devices.Device):
"""A fake device that was created to ensure certain Device validation features are
leveraged in Circuit functions. It contains the minimum set of features that tests
require. Feel free to extend the features here as needed.
Args:
qubits: set of qubits on this device
name: the name for repr
allowed_gates: tuple of allowed gate types
allowed_qubit_types: tuple of allowed qubit types
validate_locality: if True, device will validate 2 qubit operations
(except MeasurementGateOperations) whether the two qubits are adjacent. If True,
GridQubits are assumed to be part of the allowed_qubit_types
auto_decompose_gates: when set, for given gates it calls the cirq.decompose protocol
"""

def __init__(
self,
qubits: AbstractSet[ops.Qid],
name: str = "ValidatingTestDevice",
allowed_gates: Tuple[type, ...] = (ops.Gate,),
allowed_qubit_types: Tuple[type, ...] = (devices.GridQubit,),
validate_locality: bool = False,
auto_decompose_gates: Tuple[type, ...] = tuple(),
):
self.allowed_qubit_types = allowed_qubit_types
self.allowed_gates = allowed_gates
self.qubits = qubits
self._repr = name
self.validate_locality = validate_locality
self.auto_decompose_gates = auto_decompose_gates
if self.validate_locality and devices.GridQubit not in allowed_qubit_types:
raise ValueError("GridQubit must be an allowed qubit type with validate_locality=True")

def validate_operation(self, operation: ops.Operation) -> None:
# This is pretty close to what the cirq.google.XmonDevice has for validation
for q in operation.qubits:
if not isinstance(q, self.allowed_qubit_types):
raise ValueError(f"Unsupported qubit type: {type(q)!r}")
if q not in self.qubits:
raise ValueError(f'Qubit not on device: {q!r}')
if not isinstance(operation.gate, self.allowed_gates):
raise ValueError(f"Unsupported gate type: {operation.gate!r}")
if self.validate_locality:
if len(operation.qubits) == 2 and not isinstance(operation.gate, ops.MeasurementGate):
p, q = operation.qubits
if not cast(devices.GridQubit, p).is_adjacent(q):
raise ValueError(f'Non-local interaction: {operation!r}.')

def decompose_operation(self, operation: 'ops.Operation') -> 'ops.OP_TREE':
if isinstance(operation.gate, self.auto_decompose_gates):
return protocols.decompose(operation)
return operation

def __repr__(self):
return self._repr
122 changes: 122 additions & 0 deletions cirq/testing/devices_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import cirq
from cirq.testing.devices import ValidatingTestDevice


def test_validating_types_and_qubits():
dev = ValidatingTestDevice(
allowed_qubit_types=(cirq.GridQubit,),
allowed_gates=(cirq.XPowGate,),
qubits={cirq.GridQubit(0, 0)},
name='test',
)

dev.validate_operation(cirq.X(cirq.GridQubit(0, 0)))

with pytest.raises(ValueError, match="Unsupported qubit type"):
dev.validate_operation(cirq.X(cirq.NamedQubit("a")))

with pytest.raises(ValueError, match="Qubit not on device"):
dev.validate_operation(cirq.X(cirq.GridQubit(1, 0)))

with pytest.raises(ValueError, match="Unsupported gate type"):
dev.validate_operation(cirq.Y(cirq.GridQubit(0, 0)))


def test_validating_locality():
dev = ValidatingTestDevice(
allowed_qubit_types=(cirq.GridQubit,),
allowed_gates=(cirq.CZPowGate, cirq.MeasurementGate),
qubits=set(cirq.GridQubit.rect(3, 3)),
name='test',
validate_locality=True,
)

dev.validate_operation(cirq.CZ(cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)))
dev.validate_operation(cirq.measure(cirq.GridQubit(0, 0), cirq.GridQubit(0, 2)))

with pytest.raises(ValueError, match="Non-local interaction"):
dev.validate_operation(cirq.CZ(cirq.GridQubit(0, 0), cirq.GridQubit(0, 2)))

with pytest.raises(ValueError, match="GridQubit must be an allowed qubit type"):
ValidatingTestDevice(
allowed_qubit_types=(cirq.NamedQubit,),
allowed_gates=(cirq.CZPowGate, cirq.MeasurementGate),
qubits=set(cirq.GridQubit.rect(3, 3)),
name='test',
validate_locality=True,
)


def test_autodecompose():
dev = ValidatingTestDevice(
allowed_qubit_types=(cirq.LineQubit,),
allowed_gates=(
cirq.XPowGate,
cirq.ZPowGate,
cirq.CZPowGate,
cirq.YPowGate,
cirq.MeasurementGate,
),
qubits=set(cirq.LineQubit.range(3)),
name='test',
validate_locality=False,
auto_decompose_gates=(cirq.CCXPowGate,),
)

a, b, c = cirq.LineQubit.range(3)
circuit = cirq.Circuit(cirq.CCX(a, b, c), device=dev)
decomposed = cirq.decompose(cirq.CCX(a, b, c))
assert circuit.moments == cirq.Circuit(decomposed).moments

with pytest.raises(ValueError, match="Unsupported gate type: cirq.TOFFOLI"):
dev = ValidatingTestDevice(
allowed_qubit_types=(cirq.LineQubit,),
allowed_gates=(
cirq.XPowGate,
cirq.ZPowGate,
cirq.CZPowGate,
cirq.YPowGate,
cirq.MeasurementGate,
),
qubits=set(cirq.LineQubit.range(3)),
name='test',
validate_locality=False,
auto_decompose_gates=tuple(),
)

a, b, c = cirq.LineQubit.range(3)
cirq.Circuit(cirq.CCX(a, b, c), device=dev)


def test_repr():
dev = ValidatingTestDevice(
allowed_qubit_types=(cirq.GridQubit,),
allowed_gates=(cirq.CZPowGate, cirq.MeasurementGate),
qubits=set(cirq.GridQubit.rect(3, 3)),
name='test',
validate_locality=True,
)
assert repr(dev) == 'test'


def test_defaults():
dev = ValidatingTestDevice(qubits={cirq.GridQubit(0, 0)})
assert repr(dev) == 'ValidatingTestDevice'
assert dev.allowed_qubit_types == (cirq.GridQubit,)
assert not dev.validate_locality
assert not dev.auto_decompose_gates

0 comments on commit 2128257

Please sign in to comment.