Skip to content

Commit

Permalink
Add commutes protocol (#1853)
Browse files Browse the repository at this point in the history
Minimal version so far. Wanted to get some feedback.

Some thoughts:
- A lot of issues seemed to be caused by `values_equality` being the name of both a module and a method, so we should probably change at least one of those for `commutes`.
- Maybe `_commutes_with_` for the magic method? (I just went with convention in the initial issue for now.)
- I changed `cirq.commutes` to be the protocol, which delegates to `cirq.linalg.commutes` in the relevant cases. Shouldn't be a problem, but may cause confusion.
- Some classes already define a `commute_with` method. Should I add some underscores to that or add a new method that delegates to the existing one?
- If passed a `Gate` and `GateOperation` whose gate commutes with the `Gate`, what should be returned?

Fixes #1125
  • Loading branch information
bryano authored and CirqBot committed Dec 17, 2019
1 parent 1017bae commit ed6188b
Show file tree
Hide file tree
Showing 29 changed files with 469 additions and 137 deletions.
3 changes: 2 additions & 1 deletion cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
bidiagonalize_real_matrix_pair_with_symmetric_products,
bidiagonalize_unitary_with_special_orthogonals,
block_diag,
commutes,
CONTROL_TAG,
diagonalize_real_symmetric_and_sorted_diagonal_matrices,
diagonalize_real_symmetric_matrix,
Expand Down Expand Up @@ -394,9 +393,11 @@
circuit_diagram_info,
CircuitDiagramInfo,
CircuitDiagramInfoArgs,
commutes,
decompose,
decompose_once,
decompose_once_with_qubits,
definitely_commutes,
equal_up_to_global_phase,
has_channel,
has_mixture,
Expand Down
13 changes: 11 additions & 2 deletions cirq/contrib/acquaintance/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.

import abc
from typing import (cast, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING,
TypeVar, Union)
from typing import (Any, cast, Dict, Iterable, Sequence, Tuple, TYPE_CHECKING,
TypeVar, Union, TYPE_CHECKING)

from cirq import circuits, ops, optimizers, protocols, value
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -150,6 +151,14 @@ def __repr__(self):
def _value_equality_values_(self):
return (self.swap_gate,)

def _commutes_(self, other: Any, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType]:
if (isinstance(other, ops.Gate) and
isinstance(other, ops.InterchangeableQubitsGate) and
protocols.num_qubits(other) == 2):
return True
return NotImplemented


def _canonicalize_permutation(permutation: Dict[int, int]) -> Dict[int, int]:
return {i: j for i, j in permutation.items() if i != j}
Expand Down
4 changes: 4 additions & 0 deletions cirq/contrib/acquaintance/permutation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def test_swap_permutation_gate():
expander(circuit)
assert tuple(circuit.all_operations()) == (cirq.CZ(a, b),)

assert cirq.commutes(gate, cirq.ZZ)
with pytest.raises(TypeError):
cirq.commutes(gate, cirq.CCZ)


def test_validate_permutation_errors():
validate_permutation = cca.PermutationGate.validate_permutation
Expand Down
4 changes: 2 additions & 2 deletions cirq/contrib/paulistring/pauli_string_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

from typing import cast

from cirq import ops, circuits
from cirq import circuits, ops, protocols


def pauli_string_reorder_pred(op1: ops.Operation,
op2: ops.Operation) -> bool:
ps1 = cast(ops.PauliStringGateOperation, op1).pauli_string
ps2 = cast(ops.PauliStringGateOperation, op2).pauli_string
return ps1.commutes_with(ps2)
return protocols.commutes(ps1, ps2)


def pauli_string_dag_from_circuit(circuit: circuits.Circuit
Expand Down
5 changes: 3 additions & 2 deletions cirq/contrib/paulistring/recombine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import (Any, Callable, Iterable, Sequence, Tuple, Union, cast, List)

from cirq import ops, circuits
from cirq import circuits, ops, protocols

from cirq.contrib.paulistring.pauli_string_dag import (
pauli_string_reorder_pred,
Expand All @@ -41,7 +41,8 @@ def _sorted_best_string_placements(
# Skip if operations don't share qubits
continue
if (isinstance(out_op, ops.PauliStringPhasor) and
out_op.pauli_string.commutes_with(string_op.pauli_string)):
protocols.commutes(out_op.pauli_string,
string_op.pauli_string)):
# Pass through another Pauli string if they commute
continue
if not (isinstance(out_op, ops.GateOperation) and
Expand Down
1 change: 1 addition & 0 deletions cirq/contrib/quantum_volume/quantum_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def prepare_circuits(
circuits.append((model_circuit, heavy_set))
return circuits


def execute_circuits(
*,
device: cirq.google.XmonDevice,
Expand Down
2 changes: 0 additions & 2 deletions cirq/google/common_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def _near_mod_2(e, t, atol=1e-8):
],
)


#
# Measurement Serializer and Deserializer
#
Expand Down Expand Up @@ -337,7 +336,6 @@ def _near_mod_2(e, t, atol=1e-8):
lambda e: _near_mod_2pi(cast(ops.FSimGate, e).theta, np.pi / 2) and
_near_mod_2pi(cast(ops.FSimGate, e).phi, np.pi / 6)))


SYC_DESERIALIZER = op_deserializer.GateOpDeserializer(
serialized_gate_id='syc',
gate_constructor=lambda: ops.FSimGate(theta=np.pi / 2, phi=np.pi / 6),
Expand Down
1 change: 0 additions & 1 deletion cirq/google/devices/known_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def _json_dict_(self):
proto=SYCAMORE_PROTO,
gate_sets=[gate_sets.SQRT_ISWAP_GATESET, gate_sets.SYC_GATESET])


# Subset of the Sycamore grid with a reduced layout.
_SYCAMORE23_GRID = """
----------
Expand Down
46 changes: 18 additions & 28 deletions cirq/linalg/predicates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,40 +294,30 @@ def test_is_special_unitary_tolerance():


def test_commutes():
assert cirq.commutes(
np.empty((0, 0)),
np.empty((0, 0)))
assert not cirq.commutes(
np.empty((1, 0)),
np.empty((0, 1)))
assert not cirq.commutes(
np.empty((0, 1)),
np.empty((1, 0)))
assert not cirq.commutes(
np.empty((1, 0)),
np.empty((1, 0)))
assert not cirq.commutes(
np.empty((0, 1)),
np.empty((0, 1)))

assert cirq.commutes(np.array([[1]]), np.array([[2]]))
assert cirq.commutes(np.array([[1]]), np.array([[0]]))
assert cirq.linalg.commutes(np.empty((0, 0)), np.empty((0, 0)))
assert not cirq.linalg.commutes(np.empty((1, 0)), np.empty((0, 1)))
assert not cirq.linalg.commutes(np.empty((0, 1)), np.empty((1, 0)))
assert not cirq.linalg.commutes(np.empty((1, 0)), np.empty((1, 0)))
assert not cirq.linalg.commutes(np.empty((0, 1)), np.empty((0, 1)))

assert cirq.linalg.commutes(np.array([[1]]), np.array([[2]]))
assert cirq.linalg.commutes(np.array([[1]]), np.array([[0]]))

x = np.array([[0, 1], [1, 0]])
y = np.array([[0, -1j], [1j, 0]])
z = np.array([[1, 0], [0, -1]])
xx = np.kron(x, x)
zz = np.kron(z, z)

assert cirq.commutes(x, x)
assert cirq.commutes(y, y)
assert cirq.commutes(z, z)
assert not cirq.commutes(x, y)
assert not cirq.commutes(x, z)
assert not cirq.commutes(y, z)
assert cirq.linalg.commutes(x, x)
assert cirq.linalg.commutes(y, y)
assert cirq.linalg.commutes(z, z)
assert not cirq.linalg.commutes(x, y)
assert not cirq.linalg.commutes(x, z)
assert not cirq.linalg.commutes(y, z)

assert cirq.commutes(xx, zz)
assert cirq.commutes(xx, np.diag([1, -1, -1, 1 + 1e-9]))
assert cirq.linalg.commutes(xx, zz)
assert cirq.linalg.commutes(xx, np.diag([1, -1, -1, 1 + 1e-9]))


def test_commutes_tolerance():
Expand All @@ -337,8 +327,8 @@ def test_commutes_tolerance():
z = np.array([[1, 0], [0, -1]])

# Pays attention to specified tolerance.
assert cirq.commutes(x, x + z * 0.1, atol=atol)
assert not cirq.commutes(x, x + z * 0.5, atol=atol)
assert cirq.linalg.commutes(x, x + z * 0.1, atol=atol)
assert not cirq.linalg.commutes(x, x + z * 0.5, atol=atol)


def test_allclose_up_to_global_phase():
Expand Down
26 changes: 15 additions & 11 deletions cirq/ops/clifford_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, NamedTuple, Optional, Sequence, Tuple, Union, cast, \
TYPE_CHECKING
from typing import (Any, cast, Dict, NamedTuple, Optional, Sequence, Tuple,
TYPE_CHECKING, Union)

import numpy as np

from cirq import protocols, value
from cirq._compat import deprecated
from cirq._doc import document
from cirq.ops import common_gates, gate_features, named_qubit, pauli_gates
from cirq.ops import (common_gates, gate_features, named_qubit, pauli_gates,
raw_types)
from cirq.ops.pauli_gates import Pauli
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -221,15 +224,16 @@ def __pow__(self, exponent) -> 'SingleQubitCliffordGate':
return SingleQubitCliffordGate(_rotation_map=self._inverse_map,
_inverse_map=self._rotation_map)

def commutes_with(self,
gate_or_pauli: Union['SingleQubitCliffordGate', Pauli]
) -> bool:
if isinstance(gate_or_pauli, SingleQubitCliffordGate):
gate = gate_or_pauli
return self.commutes_with_single_qubit_gate(gate)
def _commutes_(self, other: Any, *, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType]:
if isinstance(other, SingleQubitCliffordGate):
return self.commutes_with_single_qubit_gate(other)
if isinstance(other, Pauli):
return self.commutes_with_pauli(other)
return NotImplemented

pauli = gate_or_pauli
return self.commutes_with_pauli(pauli)
commutes_with = deprecated(deadline='v0.7.0',
fix='Use `cirq.commutes()` instead.')(_commutes_)

def commutes_with_single_qubit_gate(self,
gate: 'SingleQubitCliffordGate') \
Expand Down
16 changes: 12 additions & 4 deletions cirq/ops/clifford_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,18 @@ def test_inverse_matrix(gate):
rtol=1e-7, atol=1e-7)


def test_commutes_notimplemented_type():
with pytest.raises(TypeError):
cirq.commutes(cirq.SingleQubitCliffordGate.X, 'X')
assert (cirq.commutes(cirq.SingleQubitCliffordGate.X,
'X',
default='default') == 'default')


@pytest.mark.parametrize('gate,other',
itertools.product(_all_clifford_gates(),
_all_clifford_gates()))
def test_commutes_with_single_qubit_gate(gate, other):
def test_commutes_single_qubit_gate(gate, other):
q0 = cirq.NamedQubit('q0')
mat = cirq.Circuit(
gate(q0),
Expand All @@ -410,7 +418,7 @@ def test_commutes_with_single_qubit_gate(gate, other):
other(q0),
gate(q0),
).unitary()
commutes = gate.commutes_with(other)
commutes = cirq.commutes(gate, other)
commutes_check = cirq.allclose_up_to_global_phase(mat, mat_swap)
assert commutes == commutes_check

Expand All @@ -419,7 +427,7 @@ def test_commutes_with_single_qubit_gate(gate, other):
itertools.product(_all_clifford_gates(),
_paulis,
(0.1, 0.25, 0.5, -0.5)))
def test_commutes_with_pauli(gate, pauli, half_turns):
def test_commutes_pauli(gate, pauli, half_turns):
pauli_gate = pauli ** half_turns
q0 = cirq.NamedQubit('q0')
mat = cirq.Circuit(
Expand All @@ -430,7 +438,7 @@ def test_commutes_with_pauli(gate, pauli, half_turns):
pauli_gate(q0),
gate(q0),
).unitary()
commutes = gate.commutes_with(pauli)
commutes = cirq.commutes(gate, pauli)
commutes_check = cirq.allclose_up_to_global_phase(mat, mat_swap)
assert commutes == commutes_check

Expand Down
15 changes: 14 additions & 1 deletion cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from cirq import protocols, value
from cirq._compat import deprecated, proper_repr
from cirq._doc import document
from cirq.ops import controlled_gate, gate_features, eigen_gate, raw_types
from cirq.ops import (controlled_gate, eigen_gate, gate_features,
gate_operation, raw_types)

from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -503,6 +504,18 @@ def __repr__(self) -> str:
'global_shift={!r})'
).format(proper_repr(self._exponent), self._global_shift)

def _commutes_on_qids_(self,
qids: 'Sequence[cirq.Qid]',
other: Any,
*,
atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType, None]:
if not isinstance(other, gate_operation.GateOperation):
return None
if isinstance(other.gate, (ZPowGate, CZPowGate)):
return True
return super()._commutes_on_qids_(qids, other, atol=atol)


class HPowGate(eigen_gate.EigenGate, gate_features.SingleQubitGate):
"""A Gate that performs a rotation around the X+Z axis of the Bloch sphere.
Expand Down
7 changes: 7 additions & 0 deletions cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,3 +652,10 @@ def test_trace_distance():
assert cirq.approx_eq(cirq.trace_distance_bound(cirq.CX**2), 0.0)
assert cirq.approx_eq(cirq.trace_distance_bound(cirq.CZ**(1 / 9)),
np.sin(np.pi / 18))


def test_commutes():
assert cirq.commutes(cirq.ZPowGate(exponent=sympy.Symbol('t')), cirq.Z)
assert cirq.commutes(cirq.Z, cirq.Z(cirq.LineQubit(0)),
default=None) is None
assert cirq.commutes(cirq.Z**0.1, cirq.XPowGate(exponent=0))
2 changes: 1 addition & 1 deletion cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def __repr__(self):
return (f'cirq.{type(self).__name__}({repr(paulis)}, '
f'coefficient={proper_repr(self.coefficient)})')

def _commutes_(self, other):
def _commutes_(self, other, *, atol: Union[int, float] = 1e-8):
if isinstance(other, BaseDensePauliString):
n = min(len(self.pauli_mask), len(other.pauli_mask))
phase = _vectorized_pauli_mul_phase(self.pauli_mask[:n],
Expand Down
27 changes: 12 additions & 15 deletions cirq/ops/dense_pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,21 +505,18 @@ def test_tensor_product():
def test_commutes():
f = cirq.DensePauliString
m = cirq.MutableDensePauliString
# TODO(craiggidney,bryano): use commutes protocol instead
commutes = lambda a, b, default=None: cirq.BaseDensePauliString._commutes_(
a, b)
assert cirq.commutes is cirq.linalg.commutes

assert commutes(f('XX'), m('ZZ'))
assert commutes(2 * f('XX'), m('ZZ', coefficient=3))
assert commutes(2 * f('IX'), 3 * f('IX'))
assert not commutes(f('IX'), f('IZ'))
assert commutes(f('IIIXII'), cirq.X(cirq.LineQubit(3)))
assert commutes(f('IIIXII'), cirq.X(cirq.LineQubit(2)))
assert not commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(3)))
assert commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(2)))

assert commutes(f('XX'), "test", default=NotImplemented) is NotImplemented

assert cirq.commutes(f('XX'), m('ZZ'))
assert cirq.commutes(2 * f('XX'), m('ZZ', coefficient=3))
assert cirq.commutes(2 * f('IX'), 3 * f('IX'))
assert not cirq.commutes(f('IX'), f('IZ'))
assert cirq.commutes(f('IIIXII'), cirq.X(cirq.LineQubit(3)))
assert cirq.commutes(f('IIIXII'), cirq.X(cirq.LineQubit(2)))
assert not cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(3)))
assert cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(2)))

assert cirq.commutes(f('XX'), "test",
default=NotImplemented) is NotImplemented


def test_copy():
Expand Down
4 changes: 4 additions & 0 deletions cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def _has_unitary_(self) -> bool:
def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
return protocols.unitary(self.gate, default=None)

def _commutes_(self, other: Any, *, atol: Union[int, float] = 1e-8
) -> Union[bool, NotImplementedType, None]:
return self.gate._commutes_on_qids_(self.qubits, other, atol=atol)

def _has_mixture_(self) -> bool:
return protocols.has_mixture(self.gate)

Expand Down

0 comments on commit ed6188b

Please sign in to comment.