Skip to content

Commit

Permalink
Make Instance GateFamily check for equality ignoring global phase (#4542
Browse files Browse the repository at this point in the history
)

* GateFamily value equality ignoring global phase

* Rename accept_global_phase to accept_global_phase_op
  • Loading branch information
tanujkhattar committed Oct 13, 2021
1 parent 66c694a commit 888aeb7
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 70 deletions.
2 changes: 1 addition & 1 deletion cirq-core/cirq/ion/ion_device.py
Expand Up @@ -30,7 +30,7 @@ def get_ion_gateset() -> ops.Gateset:
ops.ZPowGate,
ops.PhasedXPowGate,
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)


Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/neutral_atoms/neutral_atom_devices.py
Expand Up @@ -44,7 +44,7 @@ def neutral_atom_gateset(max_parallel_z=None, max_parallel_xy=None):
ops.MeasurementGate,
ops.IdentityGate,
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)


Expand Down Expand Up @@ -100,15 +100,15 @@ def __init__(
ops.ParallelGateFamily(ops.YPowGate),
ops.ParallelGateFamily(ops.PhasedXPowGate),
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)
self.controlled_gateset = ops.Gateset(
ops.AnyIntegerPowerGateFamily(ops.CNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCNotPowGate),
ops.AnyIntegerPowerGateFamily(ops.CZPowGate),
ops.AnyIntegerPowerGateFamily(ops.CCZPowGate),
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
)
self.gateset = neutral_atom_gateset(max_parallel_z, max_parallel_xy)
for q in qubits:
Expand Down
72 changes: 44 additions & 28 deletions cirq-core/cirq/ops/gateset.py
Expand Up @@ -14,7 +14,7 @@

"""Functionality for grouping and validating Cirq Gates"""

from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Dict, FrozenSet, Optional, Type, TYPE_CHECKING, Union
from cirq.ops import global_phase_op, op_tree, raw_types
from cirq import protocols, value

Expand All @@ -31,18 +31,20 @@ class GateFamily:
b) Python types inheriting from `cirq.Gate` (Type Family).
By default, the containment checks depend on the initialization type:
a) Instance Family: Containment check is done by object equality.
a) Instance Family: Containment check is done via `cirq.equal_up_to_global_phase`.
b) Type Family: Containment check is done by type comparison.
For example:
a) Instance Family:
>>> gate_family = cirq.GateFamily(cirq.X)
>>> assert cirq.X in gate_family
>>> assert cirq.Rx(rads=np.pi) in gate_family
>>> assert cirq.X ** sympy.Symbol("theta") not in gate_family
b) Type Family:
>>> gate_family = cirq.GateFamily(cirq.XPowGate)
>>> assert cirq.X in gate_family
>>> assert cirq.Rx(rads=np.pi) in gate_family
>>> assert cirq.X ** sympy.Symbol("theta") in gate_family
In order to create gate families with constraints on parameters of a gate
Expand All @@ -56,6 +58,7 @@ def __init__(
*,
name: Optional[str] = None,
description: Optional[str] = None,
ignore_global_phase: bool = True,
) -> None:
"""Init GateFamily.
Expand All @@ -64,6 +67,8 @@ def __init__(
a non-parameterized instance of a `cirq.Gate` for equality based membership checks.
name: The name of the gate family.
description: Human readable description of the gate family.
ignore_global_phase: If True, value equality is checked via
`cirq.equal_up_to_global_phase`.
Raises:
ValueError: if `gate` is not a `cirq.Gate` instance or subclass.
Expand All @@ -80,6 +85,7 @@ def __init__(
self._gate = gate
self._name = name if name else self._default_name()
self._description = description if description else self._default_description()
self._ignore_global_phase = ignore_global_phase

def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
return (
Expand Down Expand Up @@ -112,17 +118,20 @@ def _predicate(self, gate: raw_types.Gate) -> bool:
"""Checks whether `cirq.Gate` instance `gate` belongs to this GateFamily.
The default predicate depends on the gate family initialization type:
a) Instance Family: `gate == self.gate`.
a) Instance Family: `cirq.equal_up_to_global_phase(gate, self.gate)`
if self._ignore_global_phase else `gate == self.gate`.
b) Type Family: `isinstance(gate, self.gate)`.
Args:
gate: `cirq.Gate` instance which should be checked for containment.
"""
return (
gate == self.gate
if isinstance(self.gate, raw_types.Gate)
else isinstance(gate, self.gate)
)
if isinstance(self.gate, raw_types.Gate):
return (
protocols.equal_up_to_global_phase(gate, self.gate)
if self._ignore_global_phase
else gate == self._gate
)
return isinstance(gate, self.gate)

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
if isinstance(item, raw_types.Operation):
Expand All @@ -138,12 +147,19 @@ def __repr__(self) -> str:
return (
f'cirq.GateFamily(gate={self._gate_str(repr)},'
f'name="{self.name}", '
f'description="{self.description}")'
f'description="{self.description}",'
f'ignore_global_phase={self._ignore_global_phase})'
)

def _value_equality_values_(self) -> Any:
# `isinstance` is used to ensure the a gate type and gate instance is not compared.
return isinstance(self.gate, raw_types.Gate), self.gate, self.name, self.description
return (
isinstance(self.gate, raw_types.Gate),
self.gate,
self.name,
self.description,
self._ignore_global_phase,
)


@value.value_equality()
Expand All @@ -163,7 +179,7 @@ def __init__(
*gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily],
name: Optional[str] = None,
unroll_circuit_op: bool = True,
accept_global_phase: bool = True,
accept_global_phase_op: bool = True,
) -> None:
"""Init Gateset.
Expand All @@ -182,25 +198,22 @@ def __init__(
name: (Optional) Name for the Gateset. Useful for description.
unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
validated by validating the underlying `cirq.Circuit`.
accept_global_phase: If True, `cirq.GlobalPhaseOperation` is accepted.
accept_global_phase_op: If True, `cirq.GlobalPhaseOperation` is accepted.
"""
self._name = name
self._gates = frozenset(
g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates
)
self._unroll_circuit_op = unroll_circuit_op
self._accept_global_phase = accept_global_phase
self._accept_global_phase_op = accept_global_phase_op
self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
self._custom_gate_families: List[GateFamily] = []
for g in self._gates:
if type(g) == GateFamily:
if isinstance(g.gate, raw_types.Gate):
self._instance_gate_families[g.gate] = g
else:
self._type_gate_families[g.gate] = g
else:
self._custom_gate_families.append(g)

@property
def name(self) -> Optional[str]:
Expand All @@ -215,7 +228,7 @@ def with_params(
*,
name: Optional[str] = None,
unroll_circuit_op: Optional[bool] = None,
accept_global_phase: Optional[bool] = None,
accept_global_phase_op: Optional[bool] = None,
) -> 'Gateset':
"""Returns a copy of this Gateset with identical gates and new values for named arguments.
Expand All @@ -225,7 +238,7 @@ def with_params(
name: New name for the Gateset.
unroll_circuit_op: If True, new Gateset will recursively validate
`cirq.CircuitOperation` by validating the underlying `cirq.Circuit`.
accept_global_phase: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.
accept_global_phase_op: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.
Returns:
`self` if all new values are None or identical to the values of current Gateset.
Expand All @@ -237,26 +250,26 @@ def val_if_none(var: Any, val: Any) -> Any:

name = val_if_none(name, self._name)
unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op)
accept_global_phase = val_if_none(accept_global_phase, self._accept_global_phase)
accept_global_phase_op = val_if_none(accept_global_phase_op, self._accept_global_phase_op)
if (
name == self._name
and unroll_circuit_op == self._unroll_circuit_op
and accept_global_phase == self._accept_global_phase
and accept_global_phase_op == self._accept_global_phase_op
):
return self
return Gateset(
*self.gates,
name=name,
unroll_circuit_op=cast(bool, unroll_circuit_op),
accept_global_phase=cast(bool, accept_global_phase),
accept_global_phase_op=cast(bool, accept_global_phase_op),
)

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
"""Check for containment of a given Gate/Operation in this Gateset.
Containment checks are handled as follows:
a) For Gates or Operations that have an underlying gate (i.e. op.gate is not None):
- Forwards the containment check to the underlying GateFamily's
- Forwards the containment check to the underlying `cirq.GateFamily` objects.
- Examples of such operations include `cirq.GateOperations` and their controlled
and tagged variants (i.e. instances of `cirq.TaggedOperation`,
`cirq.ControlledOperation` where `op.gate` is not None) etc.
Expand All @@ -268,8 +281,11 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
`cirq.ControlledOperation` where `op.gate` is None) etc.
The complexity of the method is:
a) O(1) for checking containment in the default `cirq.GateFamily` instances.
b) O(n) for checking containment in custom GateFamily instances.
a) O(1) when any default `cirq.GateFamily` instance accepts the given item, except
for an Instance GateFamily trying to match an item with a different global phase.
b) O(n) for all other cases: matching against custom gate families, matching across
global phase for the default Instance GateFamily, no match against any underlying
gate family.
Args:
item: The `cirq.Gate` or `cirq.Operation` instance to check containment for.
Expand All @@ -295,7 +311,7 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
)
return True

return any(item in gate_family for gate_family in self._custom_gate_families)
return any(item in gate_family for gate_family in self._gates)

def validate(
self,
Expand Down Expand Up @@ -351,7 +367,7 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
)
return self.validate(op_circuit)
elif isinstance(op, global_phase_op.GlobalPhaseOperation):
return self._accept_global_phase
return self._accept_global_phase_op
else:
return False

Expand All @@ -360,7 +376,7 @@ def _value_equality_values_(self) -> Any:
frozenset(self.gates),
self.name,
self._unroll_circuit_op,
self._accept_global_phase,
self._accept_global_phase_op,
)

def __repr__(self) -> str:
Expand All @@ -369,7 +385,7 @@ def __repr__(self) -> str:
f'{",".join([repr(g) for g in self.gates])},'
f'name = "{self.name}",'
f'unroll_circuit_op = {self._unroll_circuit_op},'
f'accept_global_phase = {self._accept_global_phase})'
f'accept_global_phase_op = {self._accept_global_phase_op})'
)

def __str__(self) -> str:
Expand Down
24 changes: 17 additions & 7 deletions cirq-core/cirq/ops/gateset_test.py
Expand Up @@ -131,13 +131,21 @@ def test_gate_family_eq():
cirq.GateFamily(CustomX),
[
(CustomX, True),
(CustomXPowGate(exponent=1, global_shift=0.15), True),
(CustomX ** 2, False),
(CustomX ** 3, True),
(CustomX ** sympy.Symbol('theta'), False),
(None, False),
(cirq.GlobalPhaseOperation(1j), False),
],
),
(
cirq.GateFamily(CustomX, ignore_global_phase=False),
[
(CustomX, True),
(CustomXPowGate(exponent=1, global_shift=0.15), False),
],
),
],
)
def test_gate_family_predicate_and_containment(gate_family, gates_to_check):
Expand Down Expand Up @@ -201,7 +209,7 @@ def test_gateset_repr_and_str():
(CustomX ** 2, True),
(CustomXPowGate(exponent=3, global_shift=0.5), True),
(CustomX ** 0.5, True),
(CustomXPowGate(exponent=0.5, global_shift=0.5), False),
(CustomXPowGate(exponent=0.5, global_shift=0.5), True),
(CustomX ** 0.25, False),
(CustomX ** sympy.Symbol('theta'), False),
(cirq.testing.TwoQubitGate(), True),
Expand Down Expand Up @@ -250,7 +258,7 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert_validate_and_contains_consistent(
gateset.with_params(
unroll_circuit_op=use_circuit_op,
accept_global_phase=use_global_phase,
accept_global_phase_op=use_global_phase,
),
op_tree,
True,
Expand All @@ -259,7 +267,7 @@ def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert_validate_and_contains_consistent(
gateset.with_params(
unroll_circuit_op=False,
accept_global_phase=False,
accept_global_phase_op=False,
),
op_tree,
False,
Expand All @@ -272,16 +280,16 @@ def test_with_params():
gateset.with_params(
name=gateset.name,
unroll_circuit_op=gateset._unroll_circuit_op,
accept_global_phase=gateset._accept_global_phase,
accept_global_phase_op=gateset._accept_global_phase_op,
)
is gateset
)
gateset_with_params = gateset.with_params(
name='new name', unroll_circuit_op=False, accept_global_phase=False
name='new name', unroll_circuit_op=False, accept_global_phase_op=False
)
assert gateset_with_params.name == 'new name'
assert gateset_with_params._unroll_circuit_op is False
assert gateset_with_params._accept_global_phase is False
assert gateset_with_params._accept_global_phase_op is False


def test_gateset_eq():
Expand All @@ -290,7 +298,9 @@ def test_gateset_eq():
eq.add_equality_group(cirq.Gateset(CustomX ** 3))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset'))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', unroll_circuit_op=False))
eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase=False))
eq.add_equality_group(
cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase_op=False)
)
eq.add_equality_group(
cirq.Gateset(
cirq.GateFamily(CustomX, name='custom_name', description='custom_description')
Expand Down
Expand Up @@ -133,6 +133,7 @@ def test_allow_partial_czs():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.CZ(q0, q1) ** 0.5,
cirq.CZPowGate(exponent=0.5, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates(allow_partial_czs=True).optimize_circuit(circuit)
Expand All @@ -155,6 +156,14 @@ def test_allow_partial_czs():

def test_dont_allow_partial_czs():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.CZ(q0, q1),
cirq.CZPowGate(exponent=1, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
assert circuit == c_orig

circuit = cirq.Circuit(
cirq.CZ(q0, q1) ** 0.5,
)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/optimizers/merge_interactions.py
Expand Up @@ -232,7 +232,7 @@ def __init__(
self.gateset = ops.Gateset(
ops.CZPowGate if allow_partial_czs else ops.CZ,
unroll_circuit_op=False,
accept_global_phase=True,
accept_global_phase_op=True,
)

def _may_keep_old_op(self, old_op: 'cirq.Operation') -> bool:
Expand Down

0 comments on commit 888aeb7

Please sign in to comment.