diff --git a/cirq-core/cirq/ops/dense_pauli_string.py b/cirq-core/cirq/ops/dense_pauli_string.py index ccbbe361a3c..75560b8ff85 100644 --- a/cirq-core/cirq/ops/dense_pauli_string.py +++ b/cirq-core/cirq/ops/dense_pauli_string.py @@ -614,13 +614,13 @@ def _as_pauli_mask(val: Iterable[cirq.PAULI_GATE_LIKE] | np.ndarray) -> np.ndarr def _attempt_value_to_pauli_index(v: cirq.Operation) -> tuple[int, int] | None: - if not isinstance(v, raw_types.Operation): + if (ps := pauli_string._try_interpret_as_pauli_string(v)) is None: return None - if not isinstance(v.gate, pauli_gates.Pauli): + if len(ps.qubits) != 1: return None # pragma: no cover - q = v.qubits[0] + q = ps.qubits[0] from cirq import devices if not isinstance(q, devices.LineQubit): @@ -629,7 +629,7 @@ def _attempt_value_to_pauli_index(v: cirq.Operation) -> tuple[int, int] | None: 'other than `cirq.LineQubit` so its dense index is ambiguous.\n' f'v={repr(v)}.' ) - return pauli_string.PAULI_GATE_LIKE_TO_INDEX_MAP[v.gate], q.x + return pauli_string.PAULI_GATE_LIKE_TO_INDEX_MAP[ps[q]], q.x def _vectorized_pauli_mul_phase(lhs: int | np.ndarray, rhs: int | np.ndarray) -> complex: diff --git a/cirq-core/cirq/ops/dense_pauli_string_test.py b/cirq-core/cirq/ops/dense_pauli_string_test.py index da58141acde..f8c092a25bc 100644 --- a/cirq-core/cirq/ops/dense_pauli_string_test.py +++ b/cirq-core/cirq/ops/dense_pauli_string_test.py @@ -503,6 +503,10 @@ def test_commutes(): assert cirq.commutes(f('IIIXII'), cirq.X(cirq.LineQubit(2))) assert not cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(3))) assert cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(2))) + assert cirq.commutes(f('IIIXII'), cirq.X(cirq.LineQubit(3)) ** 3) + assert cirq.commutes(f('IIIXII'), cirq.X(cirq.LineQubit(2)) ** 3) + assert not cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(3)) ** 3) + assert cirq.commutes(f('IIIXII'), cirq.Z(cirq.LineQubit(2)) ** 3) assert cirq.commutes(f('XX'), "test", default=NotImplemented) is NotImplemented diff --git a/cirq-core/cirq/ops/parity_gates_test.py b/cirq-core/cirq/ops/parity_gates_test.py index be5495ef398..47de0bbc979 100644 --- a/cirq-core/cirq/ops/parity_gates_test.py +++ b/cirq-core/cirq/ops/parity_gates_test.py @@ -352,3 +352,38 @@ def test_clifford_protocols( else: assert not cirq.has_stabilizer_effect(gate) assert gate._decompose_into_clifford_with_qubits_(cirq.LineQubit.range(2)) is NotImplemented + + +def test_parity_gate_multiplication(): + q1, q2, q3 = cirq.LineQubit.range(3) + + # XX gate + xx_12 = cirq.XX(q1, q2) + xx_23 = cirq.XX(q2, q3) + result = xx_12 * xx_23 + expected = cirq.PauliString({q1: cirq.X, q3: cirq.X}) + assert result == expected + + # YY gate + yy_12 = cirq.YY(q1, q2) + yy_23 = cirq.YY(q2, q3) + result_yy = yy_12 * yy_23 + expected_yy = cirq.PauliString({q1: cirq.Y, q3: cirq.Y}) + assert result_yy == expected_yy + + # ZZ gate + zz_12 = cirq.ZZ(q1, q2) + zz_23 = cirq.ZZ(q2, q3) + result_zz = zz_12 * zz_23 + expected_zz = cirq.PauliString({q1: cirq.Z, q3: cirq.Z}) + assert result_zz == expected_zz + + +def test_parity_gate_multiplication_same_qubits(): + q1, q2 = cirq.LineQubit.range(2) + + # XX * XX should be identity + xx = cirq.XX(q1, q2) + result = xx * xx + expected = cirq.PauliString({q1: cirq.I, q2: cirq.I}) + assert result == expected diff --git a/cirq-core/cirq/ops/pauli_string.py b/cirq-core/cirq/ops/pauli_string.py index c86753a11ee..1fb90a8ce15 100644 --- a/cirq-core/cirq/ops/pauli_string.py +++ b/cirq-core/cirq/ops/pauli_string.py @@ -266,16 +266,9 @@ def __mul__(self, other: complex) -> cirq.PauliString[TKey]: pass def __mul__(self, other): - known = False - if isinstance(other, raw_types.Operation) and isinstance(other.gate, identity.IdentityGate): - known = True - elif isinstance(other, (PauliString, numbers.Number)): - known = True - if known: + if isinstance(other, (PauliString, numbers.Number)): return PauliString( - cast(PAULI_STRING_LIKE, other), - qubit_pauli_map=self._qubit_pauli_map, - coefficient=self.coefficient, + other, qubit_pauli_map=self._qubit_pauli_map, coefficient=self.coefficient ) return NotImplemented @@ -295,9 +288,6 @@ def __rmul__(self, other) -> PauliString: qubit_pauli_map=self._qubit_pauli_map, coefficient=self._coefficient * other ) - if isinstance(other, raw_types.Operation) and isinstance(other.gate, identity.IdentityGate): - return self # pragma: no cover - # Note: PauliString case handled by __mul__. return NotImplemented @@ -1100,21 +1090,16 @@ def _validate_qubit_mapping( ) -def _try_interpret_as_pauli_string(op: Any): +def _try_interpret_as_pauli_string(op: Any) -> PauliString | None: """Return a reprepresentation of an operation as a pauli string, if it is possible.""" - if isinstance(op, gate_operation.GateOperation): - gates = { - common_gates.XPowGate: pauli_gates.X, - common_gates.YPowGate: pauli_gates.Y, - common_gates.ZPowGate: pauli_gates.Z, - } - if (pauli := gates.get(type(op.gate), None)) is not None: - exponent = op.gate.exponent # type: ignore - if exponent % 2 == 0: - return PauliString() - if exponent % 2 == 1: - return pauli.on(op.qubits[0]) - return None + if not isinstance(op, raw_types.Operation): + return None + + pauli_expansion_op = protocols.pauli_expansion(op, default=None) + if pauli_expansion_op is None or len(pauli_expansion_op) != 1: + return None + gates, coef = next(iter(pauli_expansion_op.items())) + return PauliString(dict(zip(op.qubits, gates)), coefficient=coef) # Ignoring type because mypy believes `with_qubits` methods are incompatible. @@ -1148,28 +1133,6 @@ def qubit(self) -> raw_types.Qid: assert len(self.qubits) == 1 return self.qubits[0] - def _as_pauli_string(self) -> PauliString: - return PauliString(qubit_pauli_map={self.qubit: self.pauli}) - - def __mul__(self, other): - if isinstance(other, SingleQubitPauliStringGateOperation): - return self._as_pauli_string() * other._as_pauli_string() - if isinstance(other, (PauliString, numbers.Complex)): - return self._as_pauli_string() * other - if (as_pauli_string := _try_interpret_as_pauli_string(other)) is not None: - return self * as_pauli_string - return NotImplemented - - def __rmul__(self, other): - if isinstance(other, (PauliString, numbers.Complex)): - return other * self._as_pauli_string() - if (as_pauli_string := _try_interpret_as_pauli_string(other)) is not None: - return as_pauli_string * self - return NotImplemented - - def __neg__(self): - return -self._as_pauli_string() - def _json_dict_(self) -> dict[str, Any]: return protocols.obj_to_dict_helper(self, ['pauli', 'qubit']) diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index 99cea2c20dc..cd1d26fb63d 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -2025,3 +2025,44 @@ def test_pauli_ops_identity_gate_operation(gate1: cirq.Pauli, gate2: cirq.Pauli) subtraction = pauli1 - pauli2 assert isinstance(subtraction, cirq.PauliSum) assert np.array_equal(subtraction.matrix(), unitary1 - unitary2) + + +def test_pauli_gate_multiplication_with_power(): + q = cirq.LineQubit(0) + + # Test all Pauli gates (X, Y, Z) + pauli_gates = [cirq.X, cirq.Y, cirq.Z] + for pauli_gate in pauli_gates: + gate = pauli_gate(q) + + # Test multiplication + assert gate**2 * gate * gate * gate == gate**5 + assert gate * gate**2 * gate * gate == gate**5 + assert gate * gate * gate**2 * gate == gate**5 + assert gate * gate * gate * gate**2 == gate**5 + + # Test with different powers + assert gate**0 * gate**5 == gate**5 + assert gate**1 * gate**4 == gate**5 + assert gate**2 * gate**3 == gate**5 + assert gate**3 * gate**2 == gate**5 + assert gate**4 * gate**1 == gate**5 + assert gate**5 * gate**0 == gate**5 + + +def test_try_interpret_as_pauli_string(): + from cirq.ops.pauli_string import _try_interpret_as_pauli_string + + q = cirq.LineQubit(0) + + # Pauli gate operation + x_gate = cirq.X(q) + assert _try_interpret_as_pauli_string(x_gate) == cirq.PauliString({q: cirq.X}) + + # powered gates + x_squared = x_gate**2 + assert _try_interpret_as_pauli_string(x_squared) == cirq.PauliString({q: cirq.I}) + + # non-Pauli operation + h_gate = cirq.H(q) + assert _try_interpret_as_pauli_string(h_gate) is None diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index a4548b1a303..efd4f27f897 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -483,10 +483,18 @@ def _commutes_(self, other: Any, *, atol: float = 1e-8) -> None | NotImplemented def _mul_with_qubits(self, qubits: tuple[cirq.Qid, ...], other): """cirq.GateOperation.__mul__ delegates to this method.""" + from cirq.ops.pauli_string import _try_interpret_as_pauli_string + + if (as_pauli_string := _try_interpret_as_pauli_string(self.on(*qubits))) is not None: + return as_pauli_string * other return NotImplemented def _rmul_with_qubits(self, qubits: tuple[cirq.Qid, ...], other): """cirq.GateOperation.__rmul__ delegates to this method.""" + from cirq.ops.pauli_string import _try_interpret_as_pauli_string + + if (as_pauli_string := _try_interpret_as_pauli_string(self.on(*qubits))) is not None: + return other * as_pauli_string return NotImplemented def _json_dict_(self) -> dict[str, Any]: