Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Instance GateFamily check for equality ignoring global phase #4542

Merged
merged 8 commits into from Oct 13, 2021
43 changes: 29 additions & 14 deletions cirq-core/cirq/ops/gateset.py
Expand Up @@ -14,7 +14,7 @@

"""Functionality for grouping and validating Cirq Gates"""

from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Dict, FrozenSet, Optional, Type, TYPE_CHECKING, Union
from cirq.ops import global_phase_op, op_tree, raw_types
from cirq import protocols, value

Expand All @@ -31,18 +31,20 @@ class GateFamily:
b) Python types inheriting from `cirq.Gate` (Type Family).

By default, the containment checks depend on the initialization type:
a) Instance Family: Containment check is done by object equality.
a) Instance Family: Containment check is done via `cirq.equal_up_to_global_phase`.
b) Type Family: Containment check is done by type comparison.

For example:
a) Instance Family:
>>> gate_family = cirq.GateFamily(cirq.X)
>>> assert cirq.X in gate_family
>>> assert cirq.Rx(rads=np.pi) in gate_family
>>> assert cirq.X ** sympy.Symbol("theta") not in gate_family

b) Type Family:
>>> gate_family = cirq.GateFamily(cirq.XPowGate)
>>> assert cirq.X in gate_family
>>> assert cirq.Rx(rads=np.pi) in gate_family
>>> assert cirq.X ** sympy.Symbol("theta") in gate_family

In order to create gate families with constraints on parameters of a gate
Expand All @@ -56,6 +58,7 @@ def __init__(
*,
name: Optional[str] = None,
description: Optional[str] = None,
ignore_global_phase: bool = True,
) -> None:
"""Init GateFamily.

Expand All @@ -64,6 +67,8 @@ def __init__(
a non-parameterized instance of a `cirq.Gate` for equality based membership checks.
name: The name of the gate family.
description: Human readable description of the gate family.
ignore_global_phase: If True, value equality is checked via
`cirq.equal_up_to_global_phase`.

Raises:
ValueError: if `gate` is not a `cirq.Gate` instance or subclass.
Expand All @@ -80,6 +85,7 @@ def __init__(
self._gate = gate
self._name = name if name else self._default_name()
self._description = description if description else self._default_description()
self._ignore_global_phase = ignore_global_phase

def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
return (
Expand Down Expand Up @@ -112,17 +118,22 @@ def _predicate(self, gate: raw_types.Gate) -> bool:
"""Checks whether `cirq.Gate` instance `gate` belongs to this GateFamily.

The default predicate depends on the gate family initialization type:
a) Instance Family: `gate == self.gate`.
a) Instance Family: `cirq.equal_up_to_global_phase(gate, self.gate)`
if self._ignore_global_phase else `gate == self.gate`.
b) Type Family: `isinstance(gate, self.gate)`.

Args:
gate: `cirq.Gate` instance which should be checked for containment.
"""
return (
gate == self.gate
if isinstance(self.gate, raw_types.Gate)
else isinstance(gate, self.gate)
)
if isinstance(self.gate, raw_types.Gate):
return (
protocols.equal_up_to_global_phase(gate, self.gate)
if self._ignore_global_phase
else gate == self._gate
)

else:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
return isinstance(gate, self.gate)
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
if isinstance(item, raw_types.Operation):
Expand All @@ -138,12 +149,19 @@ def __repr__(self) -> str:
return (
f'cirq.GateFamily(gate={self._gate_str(repr)},'
f'name="{self.name}", '
f'description="{self.description}")'
f'description="{self.description}",'
f'ignore_global_phase={self._ignore_global_phase})'
)

def _value_equality_values_(self) -> Any:
# `isinstance` is used to ensure the a gate type and gate instance is not compared.
return isinstance(self.gate, raw_types.Gate), self.gate, self.name, self.description
return (
isinstance(self.gate, raw_types.Gate),
self.gate,
self.name,
self.description,
self._ignore_global_phase,
)


@value.value_equality()
Expand Down Expand Up @@ -192,15 +210,12 @@ def __init__(
self._accept_global_phase = accept_global_phase
self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
self._custom_gate_families: List[GateFamily] = []
for g in self._gates:
if type(g) == GateFamily:
if isinstance(g.gate, raw_types.Gate):
self._instance_gate_families[g.gate] = g
else:
self._type_gate_families[g.gate] = g
else:
self._custom_gate_families.append(g)

@property
def name(self) -> Optional[str]:
Expand Down Expand Up @@ -295,7 +310,7 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
)
return True

return any(item in gate_family for gate_family in self._custom_gate_families)
return any(item in gate_family for gate_family in self._gates)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment that the above checks will still catch GateFamilys that are typed based and GateFamilys that have ignore_global_phase=False ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expanded the docstring comment to explain the new code path.


def validate(
self,
Expand Down
10 changes: 9 additions & 1 deletion cirq-core/cirq/ops/gateset_test.py
Expand Up @@ -131,13 +131,21 @@ def test_gate_family_eq():
cirq.GateFamily(CustomX),
[
(CustomX, True),
(CustomXPowGate(exponent=1, global_shift=0.15), True),
(CustomX ** 2, False),
(CustomX ** 3, True),
(CustomX ** sympy.Symbol('theta'), False),
(None, False),
(cirq.GlobalPhaseOperation(1j), False),
],
),
(
cirq.GateFamily(CustomX, ignore_global_phase=False),
[
(CustomX, True),
(CustomXPowGate(exponent=1, global_shift=0.15), False),
],
),
],
)
def test_gate_family_predicate_and_containment(gate_family, gates_to_check):
Expand Down Expand Up @@ -201,7 +209,7 @@ def test_gateset_repr_and_str():
(CustomX ** 2, True),
(CustomXPowGate(exponent=3, global_shift=0.5), True),
(CustomX ** 0.5, True),
(CustomXPowGate(exponent=0.5, global_shift=0.5), False),
(CustomXPowGate(exponent=0.5, global_shift=0.5), True),
(CustomX ** 0.25, False),
(CustomX ** sympy.Symbol('theta'), False),
(cirq.testing.TwoQubitGate(), True),
Expand Down
Expand Up @@ -133,6 +133,7 @@ def test_allow_partial_czs():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.CZ(q0, q1) ** 0.5,
cirq.CZPowGate(exponent=0.5, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates(allow_partial_czs=True).optimize_circuit(circuit)
Expand All @@ -155,6 +156,14 @@ def test_allow_partial_czs():

def test_dont_allow_partial_czs():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.CZ(q0, q1),
cirq.CZPowGate(exponent=1, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
assert circuit == c_orig

circuit = cirq.Circuit(
cirq.CZ(q0, q1) ** 0.5,
)
Expand Down
13 changes: 11 additions & 2 deletions cirq-core/cirq/optimizers/merge_interactions_test.py
Expand Up @@ -173,6 +173,15 @@ def test_optimizes_tagged_partial_cz():
), 'It should take 2 CZ gates to decompose a CZ**0.5 gate'


def test_not_decompose_czs():
circuit = cirq.Circuit(
cirq.CZPowGate(exponent=1, global_shift=-0.5).on(*cirq.LineQubit.range(2))
)
circ_orig = circuit.copy()
cirq.MergeInteractions(allow_partial_czs=False).optimize_circuit(circuit)
assert circ_orig == circuit


@pytest.mark.parametrize(
'circuit',
(
Expand All @@ -181,7 +190,7 @@ def test_optimizes_tagged_partial_cz():
),
cirq.Circuit(
cirq.CZPowGate(exponent=0.2)(*cirq.LineQubit.range(2)),
cirq.CZPowGate(exponent=0.3)(*cirq.LineQubit.range(2)),
cirq.CZPowGate(exponent=0.3, global_shift=-0.5)(*cirq.LineQubit.range(2)),
),
),
)
Expand All @@ -202,7 +211,7 @@ def test_decompose_partial_czs(circuit):

def test_not_decompose_partial_czs():
circuit = cirq.Circuit(
cirq.CZPowGate(exponent=0.1)(*cirq.LineQubit.range(2)),
cirq.CZPowGate(exponent=0.1, global_shift=-0.5)(*cirq.LineQubit.range(2)),
)

optimizer = cirq.MergeInteractions(allow_partial_czs=True)
Expand Down
37 changes: 14 additions & 23 deletions cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py
Expand Up @@ -155,35 +155,26 @@ def test_works_with_tags():

def test_no_touch_single_sqrt_iswap():
a, b = cirq.LineQubit.range(2)
assert_optimizes(
before=cirq.Circuit(
[
cirq.Moment([cirq.SQRT_ISWAP(a, b).with_tags('mytag')]),
]
),
expected=cirq.Circuit(
[
cirq.Moment([cirq.SQRT_ISWAP(a, b).with_tags('mytag')]),
]
),
circuit = cirq.Circuit(
[
cirq.Moment(
[cirq.ISwapPowGate(exponent=0.5, global_shift=-0.5).on(a, b).with_tags('mytag')]
),
]
)
assert_optimizes(before=circuit, expected=circuit)


def test_no_touch_single_sqrt_iswap_inv():
a, b = cirq.LineQubit.range(2)
assert_optimizes(
use_sqrt_iswap_inv=True,
before=cirq.Circuit(
[
cirq.Moment([cirq.SQRT_ISWAP_INV(a, b).with_tags('mytag')]),
]
),
expected=cirq.Circuit(
[
cirq.Moment([cirq.SQRT_ISWAP_INV(a, b).with_tags('mytag')]),
]
),
circuit = cirq.Circuit(
[
cirq.Moment(
[cirq.ISwapPowGate(exponent=-0.5, global_shift=-0.5).on(a, b).with_tags('mytag')]
),
]
)
assert_optimizes(before=circuit, expected=circuit, use_sqrt_iswap_inv=True)


def test_cnots_separated_by_single_gates_correct():
Expand Down
3 changes: 3 additions & 0 deletions cirq-ionq/cirq_ionq/ionq_devices_test.py
Expand Up @@ -28,16 +28,19 @@
cirq.ry(0.1),
cirq.rz(0.1),
cirq.H,
cirq.HPowGate(exponent=1, global_shift=-0.5),
cirq.T,
cirq.S,
cirq.CNOT,
cirq.CXPowGate(exponent=1, global_shift=-0.5),
cirq.XX,
cirq.YY,
cirq.ZZ,
cirq.XX ** 0.5,
cirq.YY ** 0.5,
cirq.ZZ ** 0.5,
cirq.SWAP,
cirq.SwapPowGate(exponent=1, global_shift=-0.5),
cirq.MeasurementGate(num_qubits=1, key='a'),
cirq.MeasurementGate(num_qubits=2, key='b'),
cirq.MeasurementGate(num_qubits=10, key='c'),
Expand Down
4 changes: 2 additions & 2 deletions cirq-pasqal/cirq_pasqal/pasqal_device_test.py
Expand Up @@ -112,8 +112,8 @@ def test_is_pasqal_device_op():
cirq.ops.CCX(cirq.NamedQubit('q0'), cirq.NamedQubit('q1'), cirq.NamedQubit('q2')) ** 0.2
)
assert not d.is_pasqal_device_op(bad_op(cirq.NamedQubit('q0'), cirq.NamedQubit('q1')))
op1 = cirq.ops.CNotPowGate(exponent=1.0)
assert d.is_pasqal_device_op(op1(cirq.NamedQubit('q0'), cirq.NamedQubit('q1')))
for op1 in [cirq.CNotPowGate(exponent=1.0), cirq.CNotPowGate(exponent=1.0, global_shift=-0.5)]:
assert d.is_pasqal_device_op(op1(cirq.NamedQubit('q0'), cirq.NamedQubit('q1')))

op2 = (cirq.ops.H ** sympy.Symbol('exp')).on(d.qubit_list()[0])
assert not d.is_pasqal_device_op(op2)
Expand Down