From bde310e9f6e2034d5310c2b8de50ef9f2987ee68 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Mon, 14 Mar 2022 23:18:12 +0000 Subject: [PATCH] GateFamily: tag validation --- cirq-core/cirq/ops/gateset.py | 96 ++++++++++++++++++- cirq-core/cirq/ops/gateset_test.py | 60 +++++++++++- .../protocols/json_test_data/GateFamily.json | 30 +++++- .../protocols/json_test_data/GateFamily.repr | 6 +- .../protocols/json_test_data/Gateset.json | 16 +++- .../json_test_data/GridDeviceMetadata.json | 26 +++-- 6 files changed, 213 insertions(+), 21 deletions(-) diff --git a/cirq-core/cirq/ops/gateset.py b/cirq-core/cirq/ops/gateset.py index 1e36e67f635..53a31b650cb 100644 --- a/cirq-core/cirq/ops/gateset.py +++ b/cirq-core/cirq/ops/gateset.py @@ -15,7 +15,20 @@ """Functionality for grouping and validating Cirq Gates""" import warnings -from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + cast, + Dict, + FrozenSet, + Hashable, + List, + Optional, + Sequence, + Type, + TYPE_CHECKING, + Union, +) from cirq import _compat, protocols, value from cirq.ops import global_phase_op, op_tree, raw_types @@ -56,6 +69,37 @@ class GateFamily: >>> assert cirq.Rx(rads=np.pi) in gate_family >>> assert cirq.X ** sympy.Symbol("theta") in gate_family + As seen in the examples above, GateFamily supports containment checks for instances of both + `cirq.Operation` and `cirq.Gate`. By default, a `cirq.Operation` instance `op` is accepted if + the underlying `op.gate` is accepted. + + Further constraints can be added on containment checks for `cirq.Operation` objects by setting + `tags_to_accept` and/or `tags_to_ignore` in the GateFamily constructor. For a tagged + operation, the underlying gate `op.gate` will be checked for containment only if: + + * `op.tags` has no intersection with `tags_to_ignore`, and + * if `tags_to_accept` is not empty, then `op.tags` should have a non-empty intersection with + `tags_to_accept`. + + If a `cirq.Operation` contains tags from both `tags_to_accept` and `tags_to_ignore`, it is + rejected. Furthermore, tags cannot appear in both `tags_to_accept` and `tags_to_ignore`. + + For the purpose of tag comparisons, a `Gate` is considered as an `Operation` without tags. + + For example: + >>> q = cirq.NamedQubit('q') + >>> gate_family = cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['accepted_tag']) + >>> assert cirq.Z(q).with_tags('accepted_tag') in gate_family + >>> assert cirq.Z(q).with_tags('other_tag') not in gate_family + >>> assert cirq.Z(q) not in gate_family + >>> assert cirq.Z not in gate_family + ... + >>> gate_family = cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=['ignored_tag']) + >>> assert cirq.Z(q).with_tags('ignored_tag') not in gate_family + >>> assert cirq.Z(q).with_tags('other_tag') in gate_family + >>> assert cirq.Z(q) in gate_family + >>> assert cirq.Z in gate_family + In order to create gate families with constraints on parameters of a gate type, users should derive from the `cirq.GateFamily` class and override the `_predicate` method used to check for gate containment. @@ -68,6 +112,8 @@ def __init__( name: Optional[str] = None, description: Optional[str] = None, ignore_global_phase: bool = True, + tags_to_accept: Sequence[Hashable] = (), + tags_to_ignore: Sequence[Hashable] = (), ) -> None: """Init GateFamily. @@ -78,10 +124,16 @@ def __init__( description: Human readable description of the gate family. ignore_global_phase: If True, value equality is checked via `cirq.equal_up_to_global_phase`. + tags_to_accept: If non-empty, only `cirq.Operations` containing at least one tag in this + sequence can be accepted. + tags_to_ignore: Any `cirq.Operation` containing at least one tag in this sequence is + rejected. Note that this takes precedence over `tags_to_accept`, so an operation + which contains tags from both `tags_to_accept` and `tags_to_ignore` is rejected. Raises: ValueError: if `gate` is not a `cirq.Gate` instance or subclass. ValueError: if `gate` is a parameterized instance of `cirq.Gate`. + ValueError: if `tags_to_accept` and `tags_to_ignore` contain common tags. """ if not ( isinstance(gate, raw_types.Gate) @@ -95,6 +147,14 @@ def __init__( self._name = name if name else self._default_name() self._description = description if description else self._default_description() self._ignore_global_phase = ignore_global_phase + self._tags_to_accept = frozenset(tags_to_accept) + self._tags_to_ignore = frozenset(tags_to_ignore) + + common_tags = self._tags_to_accept & self._tags_to_ignore + if common_tags: + raise ValueError( + f"Tag(s) '{list(common_tags)}' cannot be in both tags_to_accept and tags_to_ignore." + ) def _gate_str(self, gettr: Callable[[Any], str] = str) -> str: return _gate_str(self.gate, gettr) @@ -142,6 +202,13 @@ def _predicate(self, gate: raw_types.Gate) -> bool: return isinstance(gate, self.gate) def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool: + if self._tags_to_accept and ( + not isinstance(item, raw_types.Operation) or self._tags_to_accept.isdisjoint(item.tags) + ): + return False + if isinstance(item, raw_types.Operation) and not self._tags_to_ignore.isdisjoint(item.tags): + return False + if isinstance(item, raw_types.Operation): if item.gate is None: return False @@ -159,7 +226,9 @@ def __repr__(self) -> str: f'cirq.GateFamily(' f'gate={self._gate_str(repr)}, ' f'{name_and_description}' - f'ignore_global_phase={self._ignore_global_phase})' + f'ignore_global_phase={self._ignore_global_phase}, ' + f'tags_to_accept={self._tags_to_accept}, ' + f'tags_to_ignore={self._tags_to_ignore})' ) def _value_equality_values_(self) -> Any: @@ -170,24 +239,41 @@ def _value_equality_values_(self) -> Any: self.name, self.description, self._ignore_global_phase, + self._tags_to_accept, + self._tags_to_ignore, ) def _json_dict_(self) -> Dict[str, Any]: - return { + d: Dict[str, Any] = { 'gate': self._gate_json(), 'name': self.name, 'description': self.description, 'ignore_global_phase': self._ignore_global_phase, + 'tags_to_accept': list(self._tags_to_accept), + 'tags_to_ignore': list(self._tags_to_ignore), } + return d @classmethod def _from_json_dict_( - cls, gate, name, description, ignore_global_phase, **kwargs + cls, + gate, + name, + description, + ignore_global_phase, + tags_to_accept=(), + tags_to_ignore=(), + **kwargs, ) -> 'GateFamily': if isinstance(gate, str): gate = protocols.cirq_type_from_json(gate) return cls( - gate, name=name, description=description, ignore_global_phase=ignore_global_phase + gate, + name=name, + description=description, + ignore_global_phase=ignore_global_phase, + tags_to_accept=tags_to_accept, + tags_to_ignore=tags_to_ignore, ) diff --git a/cirq-core/cirq/ops/gateset_test.py b/cirq-core/cirq/ops/gateset_test.py index 7427ed35ae5..02a827d02d1 100644 --- a/cirq-core/cirq/ops/gateset_test.py +++ b/cirq-core/cirq/ops/gateset_test.py @@ -46,6 +46,7 @@ def _num_qubits_(self) -> int: CustomX = CustomXPowGate() +q = cirq.NamedQubit("q") @pytest.mark.parametrize('gate', [CustomX, CustomXPowGate]) @@ -72,6 +73,9 @@ def test_invalid_gate_family(): with pytest.raises(ValueError, match='non-parameterized instance of `cirq.Gate`'): _ = cirq.GateFamily(gate=CustomX ** sympy.Symbol('theta')) + with pytest.raises(ValueError, match='cannot be in both'): + _ = cirq.GateFamily(gate=cirq.H, tags_to_accept={'a', 'b'}, tags_to_ignore={'b', 'c'}) + def test_gate_family_immutable(): g = cirq.GateFamily(CustomX) @@ -151,7 +155,6 @@ def test_gate_family_eq(): ], ) def test_gate_family_predicate_and_containment(gate_family, gates_to_check): - q = cirq.NamedQubit("q") for gate, result in gates_to_check: assert gate_family._predicate(gate) == result assert (gate in gate_family) == result @@ -160,6 +163,61 @@ def test_gate_family_predicate_and_containment(gate_family, gates_to_check): assert (gate(q).with_tags('tags') in gate_family) == result +@pytest.mark.parametrize( + 'gate_family, gates_to_check', + [ + ( + # Accept only if the input operation contains at least one of the accepted tags. + cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['a', 'b']), + [ + (cirq.Z(q).with_tags('a', 'b'), True), + (cirq.Z(q).with_tags('a'), True), + (cirq.Z(q).with_tags('b'), True), + (cirq.Z(q).with_tags('c'), False), + (cirq.Z(q).with_tags('a', 'c'), True), + (cirq.Z(q).with_tags(), False), + (cirq.Z(q), False), + (cirq.Z, False), + (cirq.X(q).with_tags('a'), False), + (cirq.X(q).with_tags('c'), False), + ], + ), + ( + # Reject if input operation contains at least one of the rejected tags. + cirq.GateFamily(cirq.ZPowGate, tags_to_ignore=['a', 'b']), + [ + (cirq.Z(q).with_tags('a', 'b'), False), + (cirq.Z(q).with_tags('a'), False), + (cirq.Z(q).with_tags('b'), False), + (cirq.Z(q).with_tags('c'), True), + (cirq.Z(q).with_tags('a', 'c'), False), + (cirq.Z(q).with_tags(), True), + (cirq.Z(q), True), + (cirq.Z, True), + (cirq.X(q).with_tags('a'), False), + (cirq.X(q).with_tags('c'), False), + ], + ), + ( + cirq.GateFamily(cirq.ZPowGate, tags_to_accept=['a'], tags_to_ignore=['c']), + [ + (cirq.Z(q).with_tags('a', 'c'), False), # should prioritize tags_to_ignore + (cirq.Z(q).with_tags('a'), True), + (cirq.Z(q).with_tags('c'), False), + (cirq.Z(q).with_tags(), False), + (cirq.Z(q), False), + (cirq.Z, False), + (cirq.X(q).with_tags('a'), False), + (cirq.X(q).with_tags('c'), False), + ], + ), + ], +) +def test_gate_family_tagged_operations(gate_family, gates_to_check): + for gate, result in gates_to_check: + assert (gate in gate_family) == result + + class CustomXGateFamily(cirq.GateFamily): """Accepts all integer powers of CustomXPowGate""" diff --git a/cirq-core/cirq/protocols/json_test_data/GateFamily.json b/cirq-core/cirq/protocols/json_test_data/GateFamily.json index 8d8a02379d7..cb151f2bd51 100644 --- a/cirq-core/cirq/protocols/json_test_data/GateFamily.json +++ b/cirq-core/cirq/protocols/json_test_data/GateFamily.json @@ -4,7 +4,9 @@ "gate": "XPowGate", "name": "Type GateFamily: cirq.ops.common_gates.XPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "GateFamily", @@ -15,6 +17,30 @@ }, "name": "XFamily", "description": "Just the X gate.", - "ignore_global_phase": false + "ignore_global_phase": false, + "tags_to_accept": [], + "tags_to_ignore": [] + }, + { + "cirq_type": "GateFamily", + "gate": "ZPowGate", + "name": "Type GateFamily: cirq.ops.common_gates.ZPowGate", + "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`", + "ignore_global_phase": true, + "tags_to_accept": [ + "physical_z" + ], + "tags_to_ignore": [] + }, + { + "cirq_type": "GateFamily", + "gate": "ZPowGate", + "name": "Type GateFamily: cirq.ops.common_gates.ZPowGate", + "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`", + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [ + "physical_z" + ] } ] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/GateFamily.repr b/cirq-core/cirq/protocols/json_test_data/GateFamily.repr index ed9e506cffd..d4e98f335de 100644 --- a/cirq-core/cirq/protocols/json_test_data/GateFamily.repr +++ b/cirq-core/cirq/protocols/json_test_data/GateFamily.repr @@ -1,4 +1,6 @@ [ - cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True), - cirq.GateFamily(gate=cirq.X, name="XFamily", description="Just the X gate.", ignore_global_phase=False) + cirq.GateFamily(gate=cirq.ops.common_gates.XPowGate, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset()), + cirq.GateFamily(gate=cirq.X, name="XFamily", description="Just the X gate.", ignore_global_phase=False, tags_to_accept=frozenset(), tags_to_ignore=frozenset()), + cirq.GateFamily(gate=cirq.ops.common_gates.ZPowGate, ignore_global_phase=True, tags_to_accept=frozenset({'physical_z'}), tags_to_ignore=frozenset()), + cirq.GateFamily(gate=cirq.ops.common_gates.ZPowGate, ignore_global_phase=True, tags_to_accept=frozenset(), tags_to_ignore=frozenset({'physical_z'})), ] \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/Gateset.json b/cirq-core/cirq/protocols/json_test_data/Gateset.json index b6c4dbfd6c5..51ffe208475 100644 --- a/cirq-core/cirq/protocols/json_test_data/Gateset.json +++ b/cirq-core/cirq/protocols/json_test_data/Gateset.json @@ -7,7 +7,9 @@ "gate": "YPowGate", "name": "Type GateFamily: cirq.ops.common_gates.YPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.YPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "AnyUnitaryGateFamily", @@ -22,7 +24,9 @@ }, "name": "Instance GateFamily: X", "description": "Accepts `cirq.Gate` instances `g` s.t. `g == X`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] } ], "name": null, @@ -40,14 +44,18 @@ }, "name": "Instance GateFamily: X", "description": "Accepts `cirq.Gate` instances `g` s.t. `g == X`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "GateFamily", "gate": "YPowGate", "name": "Type GateFamily: cirq.ops.common_gates.YPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.YPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "AnyUnitaryGateFamily", diff --git a/cirq-core/cirq/protocols/json_test_data/GridDeviceMetadata.json b/cirq-core/cirq/protocols/json_test_data/GridDeviceMetadata.json index 2f17bc263e4..90e32478175 100644 --- a/cirq-core/cirq/protocols/json_test_data/GridDeviceMetadata.json +++ b/cirq-core/cirq/protocols/json_test_data/GridDeviceMetadata.json @@ -94,21 +94,27 @@ "gate": "XPowGate", "name": "Type GateFamily: cirq.ops.common_gates.XPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "GateFamily", "gate": "YPowGate", "name": "Type GateFamily: cirq.ops.common_gates.YPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.YPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "GateFamily", "gate": "ZPowGate", "name": "Type GateFamily: cirq.ops.common_gates.ZPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] } ], "name": null, @@ -121,7 +127,9 @@ "gate": "XPowGate", "name": "Type GateFamily: cirq.ops.common_gates.XPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.XPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "Duration", @@ -134,7 +142,9 @@ "gate": "YPowGate", "name": "Type GateFamily: cirq.ops.common_gates.YPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.YPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "Duration", @@ -147,7 +157,9 @@ "gate": "ZPowGate", "name": "Type GateFamily: cirq.ops.common_gates.ZPowGate", "description": "Accepts `cirq.Gate` instances `g` s.t. `isinstance(g, cirq.ops.common_gates.ZPowGate)`", - "ignore_global_phase": true + "ignore_global_phase": true, + "tags_to_accept": [], + "tags_to_ignore": [] }, { "cirq_type": "Duration", @@ -197,4 +209,4 @@ "col": 10 } ] -} +} \ No newline at end of file