diff --git a/cirq-core/cirq/circuits/qasm_output.py b/cirq-core/cirq/circuits/qasm_output.py index 971d8c2e72d..35abb4c26a5 100644 --- a/cirq-core/cirq/circuits/qasm_output.py +++ b/cirq-core/cirq/circuits/qasm_output.py @@ -17,12 +17,14 @@ from __future__ import annotations import re -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence, Set from typing import TYPE_CHECKING import numpy as np +import sympy from cirq import linalg, ops, protocols, value +from cirq._compat import proper_repr if TYPE_CHECKING: import cirq @@ -30,7 +32,7 @@ @value.value_equality(approximate=True) class QasmUGate(ops.Gate): - def __init__(self, theta, phi, lmda) -> None: + def __init__(self, theta: cirq.TParamVal, phi: cirq.TParamVal, lmda: cirq.TParamVal) -> None: """A QASM gate representing any single qubit unitary with a series of three rotations, Z, Y, and Z. @@ -41,9 +43,9 @@ def __init__(self, theta, phi, lmda) -> None: phi: Half turns to rotate about Z (applied last). lmda: Half turns to rotate about Z (applied first). """ - self.lmda = lmda % 2 self.theta = theta % 2 self.phi = phi % 2 + self.lmda = lmda % 2 def _num_qubits_(self) -> int: return 1 @@ -54,7 +56,28 @@ def from_matrix(mat: np.ndarray) -> QasmUGate: return QasmUGate(rotation / np.pi, post_phase / np.pi, pre_phase / np.pi) def _has_unitary_(self): - return True + return not self._is_parameterized_() + + def _is_parameterized_(self) -> bool: + return ( + protocols.is_parameterized(self.theta) + or protocols.is_parameterized(self.phi) + or protocols.is_parameterized(self.lmda) + ) + + def _parameter_names_(self) -> Set[str]: + return ( + protocols.parameter_names(self.theta) + | protocols.parameter_names(self.phi) + | protocols.parameter_names(self.lmda) + ) + + def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> QasmUGate: + return QasmUGate( + protocols.resolve_parameters(self.theta, resolver, recursive), + protocols.resolve_parameters(self.phi, resolver, recursive), + protocols.resolve_parameters(self.lmda, resolver, recursive), + ) def _qasm_(self, qubits: tuple[cirq.Qid, ...], args: cirq.QasmArgs) -> str: args.validate_version('2.0', '3.0') @@ -69,18 +92,21 @@ def _qasm_(self, qubits: tuple[cirq.Qid, ...], args: cirq.QasmArgs) -> str: def __repr__(self) -> str: return ( f'cirq.circuits.qasm_output.QasmUGate(' - f'theta={self.theta!r}, ' - f'phi={self.phi!r}, ' - f'lmda={self.lmda})' + f'theta={proper_repr(self.theta)}, ' + f'phi={proper_repr(self.phi)}, ' + f'lmda={proper_repr(self.lmda)})' ) def _decompose_(self, qubits): + def mul_pi(x): + return x * (sympy.pi if protocols.is_parameterized(x) else np.pi) + q = qubits[0] phase_correction_half_turns = (self.phi + self.lmda) / 2 return [ - ops.rz(self.lmda * np.pi).on(q), - ops.ry(self.theta * np.pi).on(q), - ops.rz(self.phi * np.pi).on(q), + ops.rz(mul_pi(self.lmda)).on(q), + ops.ry(mul_pi(self.theta)).on(q), + ops.rz(mul_pi(self.phi)).on(q), ops.global_phase_operation(1j ** (2 * phase_correction_half_turns)), ] diff --git a/cirq-core/cirq/circuits/qasm_output_test.py b/cirq-core/cirq/circuits/qasm_output_test.py index 8285ffa7d5c..12f4e2cbf05 100644 --- a/cirq-core/cirq/circuits/qasm_output_test.py +++ b/cirq-core/cirq/circuits/qasm_output_test.py @@ -19,6 +19,7 @@ import numpy as np import pytest +import sympy import cirq from cirq.circuits.qasm_output import QasmTwoQubitGate, QasmUGate @@ -68,6 +69,28 @@ def test_u_gate_from_qiskit_ugate_unitary(_) -> None: np.testing.assert_allclose(cirq.unitary(g), u, atol=1e-7) +def test_u_gate_params() -> None: + q = cirq.LineQubit(0) + a, b, c = sympy.symbols('a b c') + u_gate = QasmUGate(a, b, c) + assert u_gate == QasmUGate(a, b + 2, c - 2) + assert u_gate != QasmUGate(a, b + 1, c - 1) + assert cirq.is_parameterized(u_gate) + assert cirq.parameter_names(u_gate) == {'a', 'b', 'c'} + assert not cirq.has_unitary(u_gate) + cirq.testing.assert_equivalent_repr(u_gate) + cirq.testing.assert_implements_consistent_protocols(u_gate) + u_gate_caps = cirq.resolve_parameters(u_gate, {'a': 'A', 'b': 'B', 'c': 'C'}) + assert u_gate_caps == QasmUGate(*sympy.symbols('A B C')) + resolver = {'A': 0.1, 'B': 2.2, 'C': -1.7} + resolved = cirq.resolve_parameters(u_gate_caps, resolver) + assert cirq.approx_eq(resolved, QasmUGate(0.1, 0.2, 0.3)) + resolved_then_decomposed = cirq.decompose_once_with_qubits(resolved, [q]) + decomposed = cirq.decompose_once_with_qubits(u_gate_caps, [q]) + decomposed_then_resolved = [cirq.resolve_parameters(g, resolver) for g in decomposed] + assert resolved_then_decomposed == decomposed_then_resolved + + def test_qasm_two_qubit_gate_repr() -> None: cirq.testing.assert_equivalent_repr( QasmTwoQubitGate.from_matrix(cirq.testing.random_unitary(4))