diff --git a/cirq-pasqal/cirq_pasqal/pasqal_device.py b/cirq-pasqal/cirq_pasqal/pasqal_device.py index 28375eb9856..ee7dd2dbd43 100644 --- a/cirq-pasqal/cirq_pasqal/pasqal_device.py +++ b/cirq-pasqal/cirq_pasqal/pasqal_device.py @@ -63,6 +63,21 @@ def __init__(self, qubits: Sequence[cirq.ops.Qid]) -> None: 'qubits.'.format(type(self), self.maximum_qubit_number) ) + self.gateset = cirq.Gateset( + cirq.ParallelGateFamily(cirq.H), + cirq.ParallelGateFamily(cirq.PhasedXPowGate), + cirq.ParallelGateFamily(cirq.XPowGate), + cirq.ParallelGateFamily(cirq.YPowGate), + cirq.ParallelGateFamily(cirq.ZPowGate), + cirq.AnyIntegerPowerGateFamily(cirq.CNotPowGate), + cirq.AnyIntegerPowerGateFamily(cirq.CCNotPowGate), + cirq.AnyIntegerPowerGateFamily(cirq.CZPowGate), + cirq.AnyIntegerPowerGateFamily(cirq.CCZPowGate), + cirq.IdentityGate, + cirq.MeasurementGate, + unroll_circuit_op=False, + accept_global_phase=False, + ) self.qubits = qubits # pylint: enable=missing-raises-doc @@ -96,43 +111,9 @@ def decompose_operation(self, operation: cirq.ops.Operation) -> 'cirq.OP_TREE': return decomposition def is_pasqal_device_op(self, op: cirq.ops.Operation) -> bool: - if not isinstance(op, cirq.ops.Operation): raise ValueError('Got unknown operation:', op) - - if isinstance(op.gate, cirq.ops.MeasurementGate): - return True - - op_gate = op.gate.sub_gate if isinstance(op.gate, cirq.ops.ParallelGate) else op.gate - - if isinstance( - op_gate, - ( - cirq.ops.IdentityGate, - cirq.ops.PhasedXPowGate, - cirq.ops.XPowGate, - cirq.ops.YPowGate, - cirq.ops.ZPowGate, - ), - ): - return True - - if ( - isinstance( - op_gate, - ( - cirq.ops.HPowGate, - cirq.ops.CNotPowGate, - cirq.ops.CZPowGate, - cirq.ops.CCZPowGate, - cirq.ops.CCXPowGate, - ), - ) - and not cirq.is_parameterized(op) - ): - expo = op_gate.exponent - return np.isclose(expo, np.around(expo, decimals=0)) - return False + return op in self.gateset # TODO(#3388) Add documentation for Raises. # pylint: disable=missing-raises-doc @@ -267,6 +248,15 @@ def __init__( ) self.control_radius = control_radius + self.exclude_gateset = cirq.Gateset( + cirq.AnyIntegerPowerGateFamily(cirq.CNotPowGate), + cirq.AnyIntegerPowerGateFamily(cirq.CCNotPowGate), + cirq.AnyIntegerPowerGateFamily(cirq.CCZPowGate), + ) + self.controlled_gateset = cirq.Gateset( + *self.exclude_gateset.gates, + cirq.AnyIntegerPowerGateFamily(cirq.CZPowGate), + ) @property def supported_qubit_type(self): @@ -278,9 +268,7 @@ def supported_qubit_type(self): ) def is_pasqal_device_op(self, op: cirq.ops.Operation) -> bool: - return super().is_pasqal_device_op(op) and not isinstance( - op.gate, (cirq.ops.CNotPowGate, cirq.ops.CCZPowGate, cirq.ops.CCXPowGate) - ) + return super().is_pasqal_device_op(op) and op not in self.exclude_gateset def validate_operation(self, operation: cirq.ops.Operation): """Raises an error if the given operation is invalid on this device. @@ -293,14 +281,11 @@ def validate_operation(self, operation: cirq.ops.Operation): super().validate_operation(operation) # Verify that a controlled gate operation is valid - if isinstance(operation, cirq.ops.GateOperation): - if len(operation.qubits) > 1 and not isinstance( - operation.gate, (cirq.ops.MeasurementGate, cirq.ops.ParallelGate) - ): - for p in operation.qubits: - for q in operation.qubits: - if self.distance(p, q) > self.control_radius: - raise ValueError(f"Qubits {p!r}, {q!r} are too far away") + if operation in self.controlled_gateset: + for p in operation.qubits: + for q in operation.qubits: + if self.distance(p, q) > self.control_radius: + raise ValueError(f"Qubits {p!r}, {q!r} are too far away") def validate_moment(self, moment: cirq.ops.Moment): """Raises an error if the given moment is invalid on this device.