Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Instance GateFamily check for equality ignoring global phase #4542

Merged
merged 8 commits into from
Oct 13, 2021
2 changes: 1 addition & 1 deletion cirq-core/cirq/ion/ion_device.py
Expand Up @@ -30,7 +30,7 @@ def get_ion_gateset() -> ops.Gateset:
ops.ZPowGate,
ops.PhasedXPowGate,
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)


Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices.py
Expand Up @@ -44,7 +44,7 @@ def neutral_atom_gateset(max_parallel_z=None, max_parallel_xy=None):
ops.MeasurementGate,
ops.IdentityGate,
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)


Expand Down Expand Up @@ -100,15 +100,15 @@ def __init__(
ops.ParallelGateFamily(ops.YPowGate),
ops.ParallelGateFamily(ops.PhasedXPowGate),
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)
self.controlled_gateset = ops.Gateset(
ops.AnyIntegerPowerGateFamily(ops.CNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CZPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)
self.gateset = neutral_atom_gateset(max_parallel_z, max_parallel_xy)
for q in qubits:
Expand Down
72 changes: 44 additions & 28 deletions cirq-core/cirq/ops/gateset.py
Expand Up @@ -14,7 +14,7 @@

"""Functionality for grouping and validating Cirq Gates"""

from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Dict, FrozenSet, Optional, Type, TYPE_CHECKING, Union
from cirq.ops import global_phase_op, op_tree, raw_types
from cirq import protocols, value

Expand All @@ -31,18 +31,20 @@ class GateFamily:
b) Python types inheriting from `cirq.Gate` (Type Family).

By default, the containment checks depend on the initialization type:
a) Instance Family: Containment check is done by object equality.
a) Instance Family: Containment check is done via `cirq.equal_up_to_global_phase`.
b) Type Family: Containment check is done by type comparison.

For example:
a) Instance Family:
>>> gate_family = cirq.GateFamily(cirq.X)
>>> assert cirq.X in gate_family
>>> assert cirq.Rx(rads=np.pi) in gate_family
>>> assert cirq.X ** sympy.Symbol("theta") not in gate_family

b) Type Family:
>>> gate_family = cirq.GateFamily(cirq.XPowGate)
>>> assert cirq.X in gate_family
>>> assert cirq.Rx(rads=np.pi) in gate_family
>>> assert cirq.X ** sympy.Symbol("theta") in gate_family

In order to create gate families with constraints on parameters of a gate
Expand All @@ -56,6 +58,7 @@ def __init__(
*,
name: Optional[str] = None,
description: Optional[str] = None,
ignore_global_phase: bool = True,
) -> None:
"""Init GateFamily.

Expand All @@ -64,6 +67,8 @@ def __init__(
a non-parameterized instance of a `cirq.Gate` for equality based membership checks.
name: The name of the gate family.
description: Human readable description of the gate family.
ignore_global_phase: If True, value equality is checked via
`cirq.equal_up_to_global_phase`.

Raises:
ValueError: if `gate` is not a `cirq.Gate` instance or subclass.
Expand All @@ -80,6 +85,7 @@ def __init__(
self._gate = gate
self._name = name if name else self._default_name()
self._description = description if description else self._default_description()
self._ignore_global_phase = ignore_global_phase

def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
return (
Expand Down Expand Up @@ -112,17 +118,20 @@ def _predicate(self, gate: raw_types.Gate) -> bool:
"""Checks whether `cirq.Gate` instance `gate` belongs to this GateFamily.

The default predicate depends on the gate family initialization type:
a) Instance Family: `gate == self.gate`.
a) Instance Family: `cirq.equal_up_to_global_phase(gate, self.gate)`
if self._ignore_global_phase else `gate == self.gate`.
b) Type Family: `isinstance(gate, self.gate)`.

Args:
gate: `cirq.Gate` instance which should be checked for containment.
"""
return (
gate == self.gate
if isinstance(self.gate, raw_types.Gate)
else isinstance(gate, self.gate)
)
if isinstance(self.gate, raw_types.Gate):
return (
protocols.equal_up_to_global_phase(gate, self.gate)
if self._ignore_global_phase
else gate == self._gate
)
return isinstance(gate, self.gate)

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
if isinstance(item, raw_types.Operation):
Expand All @@ -138,12 +147,19 @@ def __repr__(self) -> str:
return (
f'cirq.GateFamily(gate={self._gate_str(repr)},'
f'name="{self.name}", '
f'description="{self.description}")'
f'description="{self.description}",'
f'ignore_global_phase={self._ignore_global_phase})'
)

def _value_equality_values_(self) -> Any:
# `isinstance` is used to ensure the a gate type and gate instance is not compared.
return isinstance(self.gate, raw_types.Gate), self.gate, self.name, self.description
return (
isinstance(self.gate, raw_types.Gate),
self.gate,
self.name,
self.description,
self._ignore_global_phase,
)


@value.value_equality()
Expand All @@ -163,7 +179,7 @@ def __init__(
*gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily],
name: Optional[str] = None,
unroll_circuit_op: bool = True,
accept_global_phase: bool = True,
accept_global_phase_op: bool = True,
) -> None:
"""Init Gateset.

Expand All @@ -182,25 +198,22 @@ def __init__(
name: (Optional) Name for the Gateset. Useful for description.
unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
validated by validating the underlying `cirq.Circuit`.
accept_global_phase: If True, `cirq.GlobalPhaseOperation` is accepted.
accept_global_phase_op: If True, `cirq.GlobalPhaseOperation` is accepted.
"""
self._name = name
self._gates = frozenset(
g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates
)
self._unroll_circuit_op = unroll_circuit_op
self._accept_global_phase = accept_global_phase
self._accept_global_phase_op = accept_global_phase_op
self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
self._custom_gate_families: List[GateFamily] = []
for g in self._gates:
if type(g) == GateFamily:
if isinstance(g.gate, raw_types.Gate):
self._instance_gate_families[g.gate] = g
else:
self._type_gate_families[g.gate] = g
else:
self._custom_gate_families.append(g)

@property
def name(self) -> Optional[str]:
Expand All @@ -215,7 +228,7 @@ def with_params(
*,
name: Optional[str] = None,
unroll_circuit_op: Optional[bool] = None,
accept_global_phase: Optional[bool] = None,
accept_global_phase_op: Optional[bool] = None,
) -> 'Gateset':
"""Returns a copy of this Gateset with identical gates and new values for named arguments.

Expand All @@ -225,7 +238,7 @@ def with_params(
name: New name for the Gateset.
unroll_circuit_op: If True, new Gateset will recursively validate
`cirq.CircuitOperation` by validating the underlying `cirq.Circuit`.
accept_global_phase: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.
accept_global_phase_op: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.

Returns:
`self` if all new values are None or identical to the values of current Gateset.
Expand All @@ -237,26 +250,26 @@ def val_if_none(var: Any, val: Any) -> Any:

name = val_if_none(name, self._name)
unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op)
accept_global_phase = val_if_none(accept_global_phase, self._accept_global_phase)
accept_global_phase_op = val_if_none(accept_global_phase_op, self._accept_global_phase_op)
if (
name == self._name
and unroll_circuit_op == self._unroll_circuit_op
and accept_global_phase == self._accept_global_phase
and accept_global_phase_op == self._accept_global_phase_op
):
return self
return Gateset(
*self.gates,
name=name,
unroll_circuit_op=cast(bool, unroll_circuit_op),
accept_global_phase=cast(bool, accept_global_phase),
accept_global_phase_op=cast(bool, accept_global_phase_op),
)

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
"""Check for containment of a given Gate/Operation in this Gateset.

Containment checks are handled as follows:
a) For Gates or Operations that have an underlying gate (i.e. op.gate is not None):
- Forwards the containment check to the underlying GateFamily's
- Forwards the containment check to the underlying `cirq.GateFamily` objects.
- Examples of such operations include `cirq.GateOperations` and their controlled
and tagged variants (i.e. instances of `cirq.TaggedOperation`,
`cirq.ControlledOperation` where `op.gate` is not None) etc.
Expand All @@ -268,8 +281,11 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
`cirq.ControlledOperation` where `op.gate` is None) etc.

The complexity of the method is:
a) O(1) for checking containment in the default `cirq.GateFamily` instances.
b) O(n) for checking containment in custom GateFamily instances.
a) O(1) when any default `cirq.GateFamily` instance accepts the given item, except
for an Instance GateFamily trying to match an item with a different global phase.
b) O(n) for all other cases: matching against custom gate families, matching across
global phase for the default Instance GateFamily, no match against any underlying
gate family.

Args:
item: The `cirq.Gate` or `cirq.Operation` instance to check containment for.
Expand All @@ -295,7 +311,7 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
)
return True

return any(item in gate_family for gate_family in self._custom_gate_families)
return any(item in gate_family for gate_family in self._gates)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment that the above checks will still catch GateFamilys that are typed based and GateFamilys that have ignore_global_phase=False ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expanded the docstring comment to explain the new code path.


def validate(
self,
Expand Down Expand Up @@ -351,7 +367,7 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
)
return self.validate(op_circuit)
elif isinstance(op, global_phase_op.GlobalPhaseOperation):
return self._accept_global_phase
return self._accept_global_phase_op
else:
return False

Expand All @@ -360,7 +376,7 @@ def _value_equality_values_(self) -> Any:
frozenset(self.gates),
self.name,
self._unroll_circuit_op,
self._accept_global_phase,
self._accept_global_phase_op,
)

def __repr__(self) -> str:
Expand All @@ -369,7 +385,7 @@ def __repr__(self) -> str:
f'{",".join([repr(g) for g in self.gates])},'
f'name = "{self.name}",'
f'unroll_circuit_op = {self._unroll_circuit_op},'
f'accept_global_phase = {self._accept_global_phase})'
f'accept_global_phase_op = {self._accept_global_phase_op})'
)

def __str__(self) -> str:
Expand Down
24 changes: 17 additions & 7 deletions cirq-core/cirq/ops/gateset_test.py
Expand Up @@ -131,13 +131,21 @@ def test_gate_family_eq():
cirq.GateFamily(CustomX),
[
(CustomX, True),
(CustomXPowGate(exponent=1, global_shift=0.15), True),
(CustomX ** 2, False),
(CustomX ** 3, True),
(CustomX ** sympy.Symbol('theta'), False),
(None, False),
(cirq.GlobalPhaseOperation(1j), False),
],
),
(
cirq.GateFamily(CustomX, ignore_global_phase=False),
[
(CustomX, True),
(CustomXPowGate(exponent=1, global_shift=0.15), False),
],
),
],
)
def test_gate_family_predicate_and_containment(gate_family, gates_to_check):
Expand Down Expand Up @@ -201,7 +209,7 @@ def test_gateset_repr_and_str():
(CustomX ** 2, True),
(CustomXPowGate(exponent=3, global_shift=0.5), True),
(CustomX ** 0.5, True),
(CustomXPowGate(exponent=0.5, global_shift=0.5), False),
(CustomXPowGate(exponent=0.5, global_shift=0.5), True),
(CustomX ** 0.25, False),
(CustomX ** sympy.Symbol('theta'), False),
(cirq.testing.TwoQubitGate(), True),
Expand Down Expand Up @@ -250,7 +258,7 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert_validate_and_contains_consistent(
gateset.with_params(
unroll_circuit_op=use_circuit_op,
accept_global_phase=use_global_phase,
accept_global_phase_op=use_global_phase,
),
op_tree,
True,
Expand All @@ -259,7 +267,7 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert_validate_and_contains_consistent(
gateset.with_params(
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
),
op_tree,
False,
Expand All @@ -272,16 +280,16 @@ def test_with_params():
gateset.with_params(
name=gateset.name,
unroll_circuit_op=gateset._unroll_circuit_op,
accept_global_phase=gateset._accept_global_phase,
accept_global_phase_op=gateset._accept_global_phase_op,
)
is gateset
)
gateset_with_params = gateset.with_params(
name='new name', unroll_circuit_op=False, accept_global_phase=False
name='new name', unroll_circuit_op=False, accept_global_phase_op=False
)
assert gateset_with_params.name == 'new name'
assert gateset_with_params._unroll_circuit_op is False
assert gateset_with_params._accept_global_phase is False
assert gateset_with_params._accept_global_phase_op is False


def test_gateset_eq():
Expand All @@ -290,7 +298,9 @@ def test_gateset_eq():
eq.add_equality_group(cirq.Gateset(CustomX ** 3))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset'))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', unroll_circuit_op=False))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase=False))
eq.add_equality_group(
cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase_op=False)
)
eq.add_equality_group(
cirq.Gateset(
cirq.GateFamily(CustomX, name='custom_name', description='custom_description')
Expand Down
Expand Up @@ -133,6 +133,7 @@ def test_allow_partial_czs():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.CZ(q0, q1) ** 0.5,
cirq.CZPowGate(exponent=0.5, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates(allow_partial_czs=True).optimize_circuit(circuit)
Expand All @@ -155,6 +156,14 @@ def test_allow_partial_czs():

def test_dont_allow_partial_czs():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.CZ(q0, q1),
cirq.CZPowGate(exponent=1, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
assert circuit == c_orig

circuit = cirq.Circuit(
cirq.CZ(q0, q1) ** 0.5,
)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/optimizers/merge_interactions.py
Expand Up @@ -232,7 +232,7 @@ def __init__(
self.gateset = ops.Gateset(
ops.CZPowGate if allow_partial_czs else ops.CZ,
unroll_circuit_op=False,
accept_global_phase=True,
accept_global_phase_op=True,
)

def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
Expand Down