diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index fed48e90254..a3daa6a0443 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -212,7 +212,9 @@ freeze_op_tree, FSimGate, Gate, + GateFamily, GateOperation, + Gateset, generalized_amplitude_damp, GeneralizedAmplitudeDampingChannel, givens, diff --git a/cirq-core/cirq/ops/__init__.py b/cirq-core/cirq/ops/__init__.py index 65bef2e2d50..af3bd520799 100644 --- a/cirq-core/cirq/ops/__init__.py +++ b/cirq-core/cirq/ops/__init__.py @@ -110,6 +110,8 @@ GateOperation, ) +from cirq.ops.gateset import GateFamily, Gateset + from cirq.ops.identity import ( I, identity_each, diff --git a/cirq-core/cirq/ops/gateset.py b/cirq-core/cirq/ops/gateset.py new file mode 100644 index 00000000000..490073799a2 --- /dev/null +++ b/cirq-core/cirq/ops/gateset.py @@ -0,0 +1,379 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functionality for grouping and validating Cirq Gates""" + +from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union +from cirq.ops import global_phase_op, op_tree, raw_types +from cirq import protocols, value + +if TYPE_CHECKING: + import cirq + + +@value.value_equality(distinct_child_types=True) +class GateFamily: + """Wrapper around gate instances/types describing a set of accepted gates. + + GateFamily supports initialization via + a) Non-parameterized instances of `cirq.Gate` (Instance Family). + b) Python types inheriting from `cirq.Gate` (Type Family). + + By default, the containment checks depend on the initialization type: + a) Instance Family: Containment check is done by object equality. + b) Type Family: Containment check is done by type comparison. + + For example: + a) Instance Family: + >>> gate_family = cirq.GateFamily(cirq.X) + >>> assert cirq.X in gate_family + >>> assert cirq.X ** sympy.Symbol("theta") not in gate_family + + b) Type Family: + >>> gate_family = cirq.GateFamily(cirq.XPowGate) + >>> assert cirq.X in gate_family + >>> assert cirq.X ** sympy.Symbol("theta") in gate_family + + In order to create gate families with constraints on parameters of a gate + type, users should derive from the `cirq.GateFamily` class and override the + `_predicate` method used to check for gate containment. + """ + + def __init__( + self, + gate: Union[Type[raw_types.Gate], raw_types.Gate], + *, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> None: + """Init GateFamily. + + Args: + gate: A python `type` inheriting from `cirq.Gate` for type based membership checks, or + a non-parameterized instance of a `cirq.Gate` for equality based membership checks. + name: The name of the gate family. + description: Human readable description of the gate family. + + Raises: + ValueError: if `gate` is not a `cirq.Gate` instance or subclass. + ValueError: if `gate` is a parameterized instance of `cirq.Gate`. + """ + if not ( + isinstance(gate, raw_types.Gate) + or (isinstance(gate, type) and issubclass(gate, raw_types.Gate)) + ): + raise ValueError(f'Gate {gate} must be an instance or subclass of `cirq.Gate`.') + if isinstance(gate, raw_types.Gate) and protocols.is_parameterized(gate): + raise ValueError(f'Gate {gate} must be a non-parameterized instance of `cirq.Gate`.') + + self._gate = gate + self._name = name if name else self._default_name() + self._description = description if description else self._default_description() + + def _gate_str(self, gettr: Callable[[Any], str] = str) -> str: + return ( + gettr(self.gate) + if isinstance(self.gate, raw_types.Gate) + else f'{self.gate.__module__}.{self.gate.__name__}' + ) + + def _default_name(self) -> str: + family_type = 'Instance' if isinstance(self.gate, raw_types.Gate) else 'Type' + return f'{family_type} GateFamily: {self._gate_str()}' + + def _default_description(self) -> str: + check_type = r'g == {}' if isinstance(self.gate, raw_types.Gate) else r'isinstance(g, {})' + return f'Accepts `cirq.Gate` instances `g` s.t. `{check_type.format(self._gate_str())}`' + + @property + def gate(self) -> Union[Type[raw_types.Gate], raw_types.Gate]: + return self._gate + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._description + + def _predicate(self, gate: raw_types.Gate) -> bool: + """Checks whether `cirq.Gate` instance `gate` belongs to this GateFamily. + + The default predicate depends on the gate family initialization type: + a) Instance Family: `gate == self.gate`. + b) Type Family: `isinstance(gate, self.gate)`. + + Args: + gate: `cirq.Gate` instance which should be checked for containment. + """ + return ( + gate == self.gate + if isinstance(self.gate, raw_types.Gate) + else isinstance(gate, self.gate) + ) + + def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool: + if isinstance(item, raw_types.Operation): + if item.gate is None: + return False + item = item.gate + return self._predicate(item) + + def __str__(self) -> str: + return f'{self.name}\n{self.description}' + + def __repr__(self) -> str: + return ( + f'cirq.GateFamily(gate={self._gate_str(repr)},' + f'name="{self.name}", ' + f'description="{self.description}")' + ) + + def _value_equality_values_(self) -> Any: + # `isinstance` is used to ensure the a gate type and gate instance is not compared. + return isinstance(self.gate, raw_types.Gate), self.gate, self.name, self.description + + +@value.value_equality() +class Gateset: + """Gatesets represent a collection of `cirq.GateFamily` objects. + + Gatesets are useful for + a) Describing the set of allowed gates in a human readable format + b) Validating a given gate / optree against the set of allowed gates + + Gatesets rely on the underlying `cirq.GateFamily` for both description and + validation purposes. + """ + + def __init__( + self, + *gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily], + name: Optional[str] = None, + unroll_circuit_op: bool = True, + accept_global_phase: bool = True, + ) -> None: + """Init Gateset. + + Accepts a list of gates, each of which should be either + a) `cirq.Gate` subclass + b) `cirq.Gate` instance + c) `cirq.GateFamily` instance + + `cirq.Gate` subclasses and instances are converted to the default + `cirq.GateFamily(gate=g)` instance and thus a default name and + description is populated. + + Args: + *gates: A list of `cirq.Gate` subclasses / `cirq.Gate` instances / + `cirq.GateFamily` instances to initialize the Gateset. + name: (Optional) Name for the Gateset. Useful for description. + unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively + validated by validating the underlying `cirq.Circuit`. + accept_global_phase: If True, `cirq.GlobalPhaseOperation` is accepted. + """ + self._name = name + self._gates = frozenset( + g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates + ) + self._unroll_circuit_op = unroll_circuit_op + self._accept_global_phase = accept_global_phase + self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {} + self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {} + self._custom_gate_families: List[GateFamily] = [] + for g in self._gates: + if type(g) == GateFamily: + if isinstance(g.gate, raw_types.Gate): + self._instance_gate_families[g.gate] = g + else: + self._type_gate_families[g.gate] = g + else: + self._custom_gate_families.append(g) + + @property + def name(self) -> Optional[str]: + return self._name + + @property + def gates(self) -> FrozenSet[GateFamily]: + return self._gates + + def with_params( + self, + *, + name: Optional[str] = None, + unroll_circuit_op: Optional[bool] = None, + accept_global_phase: Optional[bool] = None, + ) -> 'Gateset': + """Returns a copy of this Gateset with identical gates and new values for named arguments. + + If a named argument is None then corresponding value of this Gateset is used instead. + + Args: + name: New name for the Gateset. + unroll_circuit_op: If True, new Gateset will recursively validate + `cirq.CircuitOperation` by validating the underlying `cirq.Circuit`. + accept_global_phase: If True, new Gateset will accept `cirq.GlobalPhaseOperation`. + + Returns: + `self` if all new values are None or identical to the values of current Gateset. + else a new Gateset with identical gates and new values for named arguments. + """ + + def val_if_none(var: Any, val: Any) -> Any: + return var if var is not None else val + + name = val_if_none(name, self._name) + unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op) + accept_global_phase = val_if_none(accept_global_phase, self._accept_global_phase) + if ( + name == self._name + and unroll_circuit_op == self._unroll_circuit_op + and accept_global_phase == self._accept_global_phase + ): + return self + return Gateset( + *self.gates, + name=name, + unroll_circuit_op=cast(bool, unroll_circuit_op), + accept_global_phase=cast(bool, accept_global_phase), + ) + + def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool: + """Check for containment of a given Gate/Operation in this Gateset. + + Containment checks are handled as follows: + a) For Gates or Operations that have an underlying gate (i.e. op.gate is not None): + - Forwards the containment check to the underlying GateFamily's + - Examples of such operations include `cirq.GateOperations` and their controlled + and tagged variants (i.e. instances of `cirq.TaggedOperation`, + `cirq.ControlledOperation` where `op.gate` is not None) etc. + + b) For Operations that do not have an underlying gate: + - Forwards the containment check to `self._validate_operation(item)`. + - Examples of such operations include `cirq.CircuitOperations` and their controlled + and tagged variants (i.e. instances of `cirq.TaggedOperation`, + `cirq.ControlledOperation` where `op.gate` is None) etc. + + The complexity of the method is: + a) O(1) for checking containment in the default `cirq.GateFamily` instances. + b) O(n) for checking containment in custom GateFamily instances. + + Args: + item: The `cirq.Gate` or `cirq.Operation` instance to check containment for. + """ + if isinstance(item, raw_types.Operation) and item.gate is None: + return self._validate_operation(item) + + g = item if isinstance(item, raw_types.Gate) else item.gate + assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`' + + if g in self._instance_gate_families: + assert item in self._instance_gate_families[g], ( + f"{item} instance matches {self._instance_gate_families[g]} but " + f"is not accepted by it." + ) + return True + + for gate_mro_type in type(g).mro(): + if gate_mro_type in self._type_gate_families: + assert item in self._type_gate_families[gate_mro_type], ( + f"{g} type {gate_mro_type} matches Type GateFamily:" + f"{self._type_gate_families[gate_mro_type]} but is not accepted by it." + ) + return True + + return any(item in gate_family for gate_family in self._custom_gate_families) + + def validate( + self, + circuit_or_optree: Union['cirq.AbstractCircuit', op_tree.OP_TREE], + ) -> bool: + """Validates gates forming `circuit_or_optree` should be contained in Gateset. + + Args: + circuit_or_optree: The `cirq.Circuit` or `cirq.OP_TREE` to validate. + """ + # To avoid circular import. + from cirq.circuits import circuit + + optree = circuit_or_optree + if isinstance(circuit_or_optree, circuit.AbstractCircuit): + optree = circuit_or_optree.all_operations() + return all(self._validate_operation(op) for op in op_tree.flatten_to_ops(optree)) + + def _validate_operation(self, op: raw_types.Operation) -> bool: + """Validates whether the given `cirq.Operation` is contained in this Gateset. + + The containment checks are handled as follows: + + a) For any operation which has an underlying gate (i.e. `op.gate` is not None): + - Containment is checked via `self.__contains__` which further checks for containment + in any of the underlying gate families. + + b) For all other types of operations (eg: `cirq.CircuitOperation`, + `cirq.GlobalPhaseOperation` etc): + - The behavior is controlled via flags passed to the constructor. + + Users should override this method to define custom behavior for operations that do not + have an underlying `cirq.Gate`. + + Args: + op: The `cirq.Operation` instance to check containment for. + """ + + # To avoid circular import. + from cirq.circuits import circuit_operation + + if op.gate is not None: + return op in self + + if isinstance(op, raw_types.TaggedOperation): + return self._validate_operation(op.sub_operation) + elif isinstance(op, circuit_operation.CircuitOperation) and self._unroll_circuit_op: + op_circuit = protocols.resolve_parameters( + op.circuit.unfreeze(), op.param_resolver, recursive=False + ) + op_circuit = op_circuit.transform_qubits( + lambda q: cast(circuit_operation.CircuitOperation, op).qubit_map.get(q, q) + ) + return self.validate(op_circuit) + elif isinstance(op, global_phase_op.GlobalPhaseOperation): + return self._accept_global_phase + else: + return False + + def _value_equality_values_(self) -> Any: + return ( + frozenset(self.gates), + self.name, + self._unroll_circuit_op, + self._accept_global_phase, + ) + + def __repr__(self) -> str: + return ( + f'cirq.Gateset(' + f'{",".join([repr(g) for g in self.gates])},' + f'name = "{self.name}",' + f'unroll_circuit_op = {self._unroll_circuit_op},' + f'accept_global_phase = {self._accept_global_phase})' + ) + + def __str__(self) -> str: + header = 'Gateset: ' + if self.name: + header += self.name + return f'{header}\n' + "\n\n".join([str(g) for g in self.gates]) diff --git a/cirq-core/cirq/ops/gateset_test.py b/cirq-core/cirq/ops/gateset_test.py new file mode 100644 index 00000000000..12215a0f69a --- /dev/null +++ b/cirq-core/cirq/ops/gateset_test.py @@ -0,0 +1,314 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, List, cast +import re +import pytest +import sympy +import cirq +from cirq._compat import proper_repr +import numpy as np + + +class CustomXPowGate(cirq.EigenGate): + def _eigen_components(self) -> List[Tuple[float, np.ndarray]]: + return [ + (0, np.array([[0.5, 0.5], [0.5, 0.5]])), + (1, np.array([[0.5, -0.5], [-0.5, 0.5]])), + ] + + def __str__(self) -> str: + if self._global_shift == 0: + if self._exponent == 1: + return 'CustomX' + return f'CustomX**{self._exponent}' + return f'CustomXPowGate(exponent={self._exponent}, global_shift={self._global_shift!r})' + + def __repr__(self) -> str: + if self._global_shift == 0: + if self._exponent == 1: + return 'cirq.ops.gateset_test.CustomX' + return f'(cirq.ops.gateset_test.CustomX**{proper_repr(self._exponent)})' + return 'cirq.ops.gateset_test.CustomXPowGate(exponent={}, global_shift={!r})'.format( + proper_repr(self._exponent), self._global_shift + ) + + def _num_qubits_(self) -> int: + return 1 + + +CustomX = CustomXPowGate() + + +@pytest.mark.parametrize('gate', [CustomX, CustomXPowGate]) +def test_gate_family_init(gate): + name = 'test_name' + description = 'test_description' + g = cirq.GateFamily(gate=gate, name=name, description=description) + assert g.gate == gate + assert g.name == name + assert g.description == description + + +@pytest.mark.parametrize('gate', [CustomX, CustomXPowGate]) +def test_gate_family_default_name_and_description(gate): + g = cirq.GateFamily(gate) + assert re.match('.*GateFamily.*CustomX.*', g.name) + assert re.match('Accepts.*instances.*CustomX.*', g.description) + + +def test_invalid_gate_family(): + with pytest.raises(ValueError, match='instance or subclass of `cirq.Gate`'): + _ = cirq.GateFamily(gate=cirq.Operation) + + with pytest.raises(ValueError, match='non-parameterized instance of `cirq.Gate`'): + _ = cirq.GateFamily(gate=CustomX ** sympy.Symbol('theta')) + + +def test_gate_family_immutable(): + g = cirq.GateFamily(CustomX) + with pytest.raises(AttributeError, match="can't set attribute"): + g.gate = CustomXPowGate + with pytest.raises(AttributeError, match="can't set attribute"): + g.name = 'new name' + with pytest.raises(AttributeError, match="can't set attribute"): + g.description = 'new description' + + +@pytest.mark.parametrize( + 'gate', [CustomX, CustomXPowGate(exponent=0.5, global_shift=0.1), CustomXPowGate] +) +@pytest.mark.parametrize('name, description', [(None, None), ('custom_name', 'custom_description')]) +def test_gate_family_repr_and_str(gate, name, description): + g = cirq.GateFamily(gate, name=name, description=description) + cirq.testing.assert_equivalent_repr(g) + assert g.name in str(g) + assert g.description in str(g) + + +def test_gate_family_eq(): + eq = cirq.testing.EqualsTester() + eq.add_equality_group(cirq.GateFamily(CustomX)) + eq.add_equality_group(cirq.GateFamily(CustomX ** 3)) + eq.add_equality_group( + cirq.GateFamily(CustomX, name='custom_name', description='custom_description'), + cirq.GateFamily(CustomX ** 3, name='custom_name', description='custom_description'), + ) + eq.add_equality_group(cirq.GateFamily(CustomXPowGate)) + eq.add_equality_group( + cirq.GateFamily(CustomXPowGate, name='custom_name', description='custom_description') + ) + + +@pytest.mark.parametrize( + 'gate_family, gates_to_check', + [ + ( + cirq.GateFamily(CustomXPowGate), + [ + (CustomX, True), + (CustomX ** 0.5, True), + (CustomX ** sympy.Symbol('theta'), True), + (CustomXPowGate(exponent=0.25, global_shift=0.15), True), + (cirq.SingleQubitGate(), False), + (cirq.X ** 0.5, False), + (None, False), + (cirq.GlobalPhaseOperation(1j), False), + ], + ), + ( + cirq.GateFamily(CustomX), + [ + (CustomX, True), + (CustomX ** 2, False), + (CustomX ** 3, True), + (CustomX ** sympy.Symbol('theta'), False), + (None, False), + (cirq.GlobalPhaseOperation(1j), False), + ], + ), + ], +) +def test_gate_family_predicate_and_containment(gate_family, gates_to_check): + q = cirq.NamedQubit("q") + for gate, result in gates_to_check: + assert gate_family._predicate(gate) == result + assert (gate in gate_family) == result + if isinstance(gate, cirq.Gate): + assert (gate(q) in gate_family) == result + assert (gate(q).with_tags('tags') in gate_family) == result + + +class CustomXGateFamily(cirq.GateFamily): + """Accepts all integer powers of CustomXPowGate""" + + def __init__(self) -> None: + super().__init__( + gate=CustomXPowGate, + name='CustomXGateFamily', + description='Accepts all integer powers of CustomXPowGate', + ) + + def _predicate(self, g: cirq.Gate) -> bool: + """Checks whether gate instance `g` belongs to this GateFamily.""" + if not super()._predicate(g) or cirq.is_parameterized(g): + return False + exp = cast(CustomXPowGate, g).exponent + return int(exp) == exp + + def __repr__(self): + return 'cirq.ops.gateset_test.CustomXGateFamily()' + + +gateset = cirq.Gateset( + CustomX ** 0.5, cirq.testing.TwoQubitGate, CustomXGateFamily(), name='custom gateset' +) + + +def test_gateset_init(): + assert gateset.name == 'custom gateset' + assert gateset.gates == frozenset( + [ + cirq.GateFamily(CustomX ** 0.5), + cirq.GateFamily(cirq.testing.TwoQubitGate), + CustomXGateFamily(), + ] + ) + + +def test_gateset_repr_and_str(): + cirq.testing.assert_equivalent_repr(gateset) + assert gateset.name in str(gateset) + for gate_family in gateset.gates: + assert str(gate_family) in str(gateset) + + +@pytest.mark.parametrize( + 'gate, result', + [ + (CustomX, True), + (CustomX ** 2, True), + (CustomXPowGate(exponent=3, global_shift=0.5), True), + (CustomX ** 0.5, True), + (CustomXPowGate(exponent=0.5, global_shift=0.5), False), + (CustomX ** 0.25, False), + (CustomX ** sympy.Symbol('theta'), False), + (cirq.testing.TwoQubitGate(), True), + ], +) +def test_gateset_contains(gate, result): + assert (gate in gateset) is result + op = gate(*cirq.LineQubit.range(gate.num_qubits())) + assert (op in gateset) is result + assert (op.with_tags('tags') in gateset) is result + circuit_op = cirq.CircuitOperation(cirq.FrozenCircuit([op] * 5), repetitions=5) + assert (circuit_op in gateset) is result + assert circuit_op not in gateset.with_params(unroll_circuit_op=False) + + +@pytest.mark.parametrize('use_circuit_op', [True, False]) +@pytest.mark.parametrize('use_global_phase', [True, False]) +def test_gateset_validate(use_circuit_op, use_global_phase): + def optree_and_circuit(optree): + yield optree + yield cirq.Circuit(optree) + + def get_ops(use_circuit_op, use_global_phase): + q = cirq.LineQubit.range(3) + yield [CustomX(q[0]).with_tags('custom tags'), CustomX(q[1]) ** 2, CustomX(q[2]) ** 3] + yield [CustomX(q[0]) ** 0.5, cirq.testing.TwoQubitGate()(*q[:2])] + if use_circuit_op: + circuit_op = cirq.CircuitOperation( + cirq.FrozenCircuit(get_ops(False, False)), repetitions=10 + ).with_tags('circuit op tags') + recursive_circuit_op = cirq.CircuitOperation( + cirq.FrozenCircuit([circuit_op, CustomX(q[2]) ** 0.5]), + repetitions=10, + qubit_map={q[0]: q[1], q[1]: q[2], q[2]: q[0]}, + ) + yield [circuit_op, recursive_circuit_op] + if use_global_phase: + yield cirq.GlobalPhaseOperation(1j) + + def assert_validate_and_contains_consistent(gateset, op_tree, result): + assert all(op in gateset for op in cirq.flatten_to_ops(op_tree)) is result + for item in optree_and_circuit(op_tree): + assert gateset.validate(item) is result + + op_tree = [*get_ops(use_circuit_op, use_global_phase)] + assert_validate_and_contains_consistent( + gateset.with_params( + unroll_circuit_op=use_circuit_op, + accept_global_phase=use_global_phase, + ), + op_tree, + True, + ) + if use_circuit_op or use_global_phase: + assert_validate_and_contains_consistent( + gateset.with_params( + unroll_circuit_op=False, + accept_global_phase=False, + ), + op_tree, + False, + ) + + +def test_with_params(): + assert gateset.with_params() is gateset + assert ( + gateset.with_params( + name=gateset.name, + unroll_circuit_op=gateset._unroll_circuit_op, + accept_global_phase=gateset._accept_global_phase, + ) + is gateset + ) + gateset_with_params = gateset.with_params( + name='new name', unroll_circuit_op=False, accept_global_phase=False + ) + assert gateset_with_params.name == 'new name' + assert gateset_with_params._unroll_circuit_op is False + assert gateset_with_params._accept_global_phase is False + + +def test_gateset_eq(): + eq = cirq.testing.EqualsTester() + eq.add_equality_group(cirq.Gateset(CustomX)) + eq.add_equality_group(cirq.Gateset(CustomX ** 3)) + eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset')) + eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', unroll_circuit_op=False)) + eq.add_equality_group(cirq.Gateset(CustomX, name='Custom Gateset', accept_global_phase=False)) + eq.add_equality_group( + cirq.Gateset( + cirq.GateFamily(CustomX, name='custom_name', description='custom_description') + ), + cirq.Gateset( + cirq.GateFamily(CustomX ** 3, name='custom_name', description='custom_description') + ), + ) + eq.add_equality_group( + cirq.Gateset(CustomX, CustomXPowGate), cirq.Gateset(CustomXPowGate, CustomX) + ) + eq.add_equality_group(cirq.Gateset(CustomXGateFamily())) + eq.add_equality_group( + cirq.Gateset( + cirq.GateFamily( + gate=CustomXPowGate, + name='CustomXGateFamily', + description='Accepts all integer powers of CustomXPowGate', + ) + ) + ) diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index b75bdb1dc68..e0c155493e3 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -37,6 +37,8 @@ 'DensityMatrixStepResult', 'DensityMatrixTrialResult', 'ExpressionMap', + 'GateFamily', + 'Gateset', 'InsertStrategy', 'IonDevice', 'KakDecomposition',