Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 91 additions & 5 deletions cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,20 @@
"""Functionality for grouping and validating Cirq Gates"""

import warnings
from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
cast,
Dict,
FrozenSet,
Hashable,
List,
Optional,
Sequence,
Type,
TYPE_CHECKING,
Union,
)

from cirq import _compat, protocols, value
from cirq.ops import global_phase_op, op_tree, raw_types
Expand Down Expand Up @@ -56,6 +69,37 @@ class GateFamily:
>>> assert cirq.Rx(rads=np.pi) in gate_family
>>> assert cirq.X ** sympy.Symbol("theta") in gate_family

As seen in the examples above, GateFamily supports containment checks for instances of both
`cirq.Operation` and `cirq.Gate`. By default, a `cirq.Operation` instance `op` is accepted if
the underlying `op.gate` is accepted.

Further constraints can be added on containment checks for `cirq.Operation` objects by setting
`tags_to_accept` and/or `tags_to_ignore` in the GateFamily constructor. For a tagged
operation, the underlying gate `op.gate` will be checked for containment only if:

* `op.tags` has no intersection with `tags_to_ignore`, and
* if `tags_to_accept` is not empty, then `op.tags` should have a non-empty intersection with
`tags_to_accept`.

If a `cirq.Operation` contains tags from both `tags_to_accept` and `tags_to_ignore`, it is
rejected. Furthermore, tags cannot appear in both `tags_to_accept` and `tags_to_ignore`.

For the purpose of tag comparisons, a `Gate` is considered as an `Operation` without tags.

For example:
>>> q = cirq.NamedQubit('q')
>>> gate_family = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['accepted_tag'])
>>> assert cirq.Z(q).with_tags('accepted_tag') in gate_family
>>> assert cirq.Z(q).with_tags('other_tag') not in gate_family
>>> assert cirq.Z(q) not in gate_family
>>> assert cirq.Z not in gate_family
...
>>> gate_family = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=['ignored_tag'])
>>> assert cirq.Z(q).with_tags('ignored_tag') not in gate_family
>>> assert cirq.Z(q).with_tags('other_tag') in gate_family
>>> assert cirq.Z(q) in gate_family
>>> assert cirq.Z in gate_family

In order to create gate families with constraints on parameters of a gate
type, users should derive from the `cirq.GateFamily` class and override the
`_predicate` method used to check for gate containment.
Expand All @@ -68,6 +112,8 @@ def __init__(
name: Optional[str] = None,
description: Optional[str] = None,
ignore_global_phase: bool = True,
tags_to_accept: Sequence[Hashable] = (),
tags_to_ignore: Sequence[Hashable] = (),
) -> None:
"""Init GateFamily.

Expand All @@ -78,10 +124,16 @@ def __init__(
description: Human readable description of the gate family.
ignore_global_phase: If True, value equality is checked via
`cirq.equal_up_to_global_phase`.
tags_to_accept: If non-empty, only `cirq.Operations` containing at least one tag in this
sequence can be accepted.
tags_to_ignore: Any `cirq.Operation` containing at least one tag in this sequence is
rejected. Note that this takes precedence over `tags_to_accept`, so an operation
which contains tags from both `tags_to_accept` and `tags_to_ignore` is rejected.

Raises:
ValueError: if `gate` is not a `cirq.Gate` instance or subclass.
ValueError: if `gate` is a parameterized instance of `cirq.Gate`.
ValueError: if `tags_to_accept` and `tags_to_ignore` contain common tags.
"""
if not (
isinstance(gate, raw_types.Gate)
Expand All @@ -95,6 +147,14 @@ def __init__(
self._name = name if name else self._default_name()
self._description = description if description else self._default_description()
self._ignore_global_phase = ignore_global_phase
self._tags_to_accept = frozenset(tags_to_accept)
self._tags_to_ignore = frozenset(tags_to_ignore)

common_tags = self._tags_to_accept & self._tags_to_ignore
if common_tags:
raise ValueError(
f"Tag(s) '{list(common_tags)}' cannot be in both tags_to_accept and tags_to_ignore."
)

def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
return _gate_str(self.gate, gettr)
Expand Down Expand Up @@ -142,6 +202,13 @@ def _predicate(self, gate: raw_types.Gate) -> bool:
return isinstance(gate, self.gate)

def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
if self._tags_to_accept and (
not isinstance(item, raw_types.Operation) or self._tags_to_accept.isdisjoint(item.tags)
):
return False
if isinstance(item, raw_types.Operation) and not self._tags_to_ignore.isdisjoint(item.tags):
return False

if isinstance(item, raw_types.Operation):
if item.gate is None:
return False
Expand All @@ -159,7 +226,9 @@ def __repr__(self) -> str:
f'cirq.GateFamily('
f'gate={self._gate_str(repr)}, '
f'{name_and_description}'
f'ignore_global_phase={self._ignore_global_phase})'
f'ignore_global_phase={self._ignore_global_phase}, '
f'tags_to_accept={self._tags_to_accept}, '
f'tags_to_ignore={self._tags_to_ignore})'
)

def _value_equality_values_(self) -> Any:
Expand All @@ -170,24 +239,41 @@ def _value_equality_values_(self) -> Any:
self.name,
self.description,
self._ignore_global_phase,
self._tags_to_accept,
self._tags_to_ignore,
)

def _json_dict_(self) -> Dict[str, Any]:
return {
d: Dict[str, Any] = {
'gate': self._gate_json(),
'name': self.name,
'description': self.description,
'ignore_global_phase': self._ignore_global_phase,
'tags_to_accept': list(self._tags_to_accept),
'tags_to_ignore': list(self._tags_to_ignore),
}
return d

@classmethod
def _from_json_dict_(
cls, gate, name, description, ignore_global_phase, **kwargs
cls,
gate,
name,
description,
ignore_global_phase,
tags_to_accept=(),
tags_to_ignore=(),
**kwargs,
) -> 'GateFamily':
if isinstance(gate, str):
gate = protocols.cirq_type_from_json(gate)
return cls(
gate, name=name, description=description, ignore_global_phase=ignore_global_phase
gate,
name=name,
description=description,
ignore_global_phase=ignore_global_phase,
tags_to_accept=tags_to_accept,
tags_to_ignore=tags_to_ignore,
)


Expand Down
60 changes: 59 additions & 1 deletion cirq-core/cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def _num_qubits_(self) -> int:


CustomX = CustomXPowGate()
q = cirq.NamedQubit("q")


@pytest.mark.parametrize('gate', [CustomX, CustomXPowGate])
Expand All @@ -72,6 +73,9 @@ def test_invalid_gate_family():
with pytest.raises(ValueError, match='non-parameterized instance of `cirq.Gate`'):
_ = cirq.GateFamily(gate=CustomX ** sympy.Symbol('theta'))

with pytest.raises(ValueError, match='cannot be in both'):
_ = cirq.GateFamily(gate=cirq.H, tags_to_accept={'a', 'b'}, tags_to_ignore={'b', 'c'})


def test_gate_family_immutable():
g = cirq.GateFamily(CustomX)
Expand Down Expand Up @@ -151,7 +155,6 @@ def test_gate_family_eq():
],
)
def test_gate_family_predicate_and_containment(gate_family, gates_to_check):
q = cirq.NamedQubit("q")
for gate, result in gates_to_check:
assert gate_family._predicate(gate) == result
assert (gate in gate_family) == result
Expand All @@ -160,6 +163,61 @@ def test_gate_family_predicate_and_containment(gate_family, gates_to_check):
assert (gate(q).with_tags('tags') in gate_family) == result


@pytest.mark.parametrize(
'gate_family, gates_to_check',
[
(
# Accept only if the input operation contains at least one of the accepted tags.
cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['a', 'b']),
[
(cirq.Z(q).with_tags('a', 'b'), True),
(cirq.Z(q).with_tags('a'), True),
(cirq.Z(q).with_tags('b'), True),
(cirq.Z(q).with_tags('c'), False),
(cirq.Z(q).with_tags('a', 'c'), True),
(cirq.Z(q).with_tags(), False),
(cirq.Z(q), False),
(cirq.Z, False),
(cirq.X(q).with_tags('a'), False),
(cirq.X(q).with_tags('c'), False),
],
),
(
# Reject if input operation contains at least one of the rejected tags.
cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=['a', 'b']),
[
(cirq.Z(q).with_tags('a', 'b'), False),
(cirq.Z(q).with_tags('a'), False),
(cirq.Z(q).with_tags('b'), False),
(cirq.Z(q).with_tags('c'), True),
(cirq.Z(q).with_tags('a', 'c'), False),
(cirq.Z(q).with_tags(), True),
(cirq.Z(q), True),
(cirq.Z, True),
(cirq.X(q).with_tags('a'), False),
(cirq.X(q).with_tags('c'), False),
],
),
(
cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['a'], tags_to_ignore=['c']),
[
(cirq.Z(q).with_tags('a', 'c'), False), # should prioritize tags_to_ignore
(cirq.Z(q).with_tags('a'), True),
(cirq.Z(q).with_tags('c'), False),
(cirq.Z(q).with_tags(), False),
(cirq.Z(q), False),
(cirq.Z, False),
(cirq.X(q).with_tags('a'), False),
(cirq.X(q).with_tags('c'), False),
],
),
],
)
def test_gate_family_tagged_operations(gate_family, gates_to_check):
for gate, result in gates_to_check:
assert (gate in gate_family) == result


class CustomXGateFamily(cirq.GateFamily):
"""Accepts all integer powers of CustomXPowGate"""

Expand Down
30 changes: 28 additions & 2 deletions cirq-core/cirq/protocols/json_test_data/GateFamily.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"gate": "XPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.XPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`",
"ignore_global_phase": true
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
Expand All @@ -15,6 +17,30 @@
},
"name": "XFamily",
"description": "Just the X gate.",
"ignore_global_phase": false
"ignore_global_phase": false,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "ZPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [
"physical_z"
],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "ZPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.ZPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`",
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": [
"physical_z"
]
}
]
6 changes: 4 additions & 2 deletions cirq-core/cirq/protocols/json_test_data/GateFamily.repr
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[
cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True),
cirq.GateFamily(gate=cirq.X, name="XFamily", description="Just the X gate.", ignore_global_phase=False)
cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset()),
cirq.GateFamily(gate=cirq.X, name="XFamily", description="Just the X gate.", ignore_global_phase=False, tags_to_accept=frozenset(), tags_to_ignore=frozenset()),
cirq.GateFamily(gate=cirq.ops.common_gates.ZPowGate, ignore_global_phase=True, tags_to_accept=frozenset({'physical_z'}), tags_to_ignore=frozenset()),
cirq.GateFamily(gate=cirq.ops.common_gates.ZPowGate, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset({'physical_z'})),
]
16 changes: 12 additions & 4 deletions cirq-core/cirq/protocols/json_test_data/Gateset.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
"gate": "YPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.YPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.YPowGate)`",
"ignore_global_phase": true
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "AnyUnitaryGateFamily",
Expand All @@ -22,7 +24,9 @@
},
"name": "Instance GateFamily: X",
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == X`",
"ignore_global_phase": true
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
}
],
"name": null,
Expand All @@ -40,14 +44,18 @@
},
"name": "Instance GateFamily: X",
"description": "Accepts `cirq.Gate` instances `g` s.t. `g == X`",
"ignore_global_phase": true
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "GateFamily",
"gate": "YPowGate",
"name": "Type GateFamily: cirq.ops.common_gates.YPowGate",
"description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.YPowGate)`",
"ignore_global_phase": true
"ignore_global_phase": true,
"tags_to_accept": [],
"tags_to_ignore": []
},
{
"cirq_type": "AnyUnitaryGateFamily",
Expand Down
Loading