From 4f0ba49e187b1ef3088fcae99a4ce0ff09eaeda0 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Mon, 14 Mar 2022 18:47:53 -0700 Subject: [PATCH 01/14] Allow coefficient to be parameterized --- cirq-core/cirq/ops/global_phase_op.py | 29 +++++++++++++++---- cirq-core/cirq/ops/pauli_string.py | 40 ++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 13 deletions(-) diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index 139f51ab177..e71e6a158e1 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """A no-qubit global phase operation.""" -from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING +import numbers +from typing import AbstractSet, Any, Dict, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np +import sympy from cirq import value, protocols from cirq._compat import deprecated_class @@ -24,6 +26,9 @@ import cirq +ComplexParam = Union[value.Scalar, sympy.Basic] + + @value.value_equality(approximate=True) @deprecated_class(deadline='v0.16', fix='Use cirq.global_phase_operation') class GlobalPhaseOperation(gate_operation.GateOperation): @@ -57,20 +62,20 @@ def _json_dict_(self) -> Dict[str, Any]: @value.value_equality(approximate=True) class GlobalPhaseGate(raw_types.Gate): - def __init__(self, coefficient: value.Scalar, atol: float = 1e-8) -> None: - if abs(1 - abs(coefficient)) > atol: + def __init__(self, coefficient: ComplexParam, atol: float = 1e-8) -> None: + if not isinstance(coefficient, sympy.Basic) and abs(1 - abs(coefficient)) > atol: raise ValueError(f'Coefficient is not unitary: {coefficient!r}') self._coefficient = coefficient @property - def coefficient(self) -> value.Scalar: + def coefficient(self) -> ComplexParam: return self._coefficient def _value_equality_values_(self) -> Any: return self.coefficient def _has_unitary_(self) -> bool: - return True + return not self._is_parameterized_() def __pow__(self, power) -> 'cirq.GlobalPhaseGate': if isinstance(power, (int, float)): @@ -102,6 +107,18 @@ def _json_dict_(self) -> Dict[str, Any]: def _qid_shape_(self) -> Tuple[int, ...]: return tuple() + def _is_parameterized_(self) -> bool: + return protocols.is_parameterized(self.coefficient) + + def _parameter_names_(self) -> AbstractSet[str]: + return protocols.parameter_names(self.coefficient) + + def _resolve_parameters_( + self, resolver: 'cirq.ParamResolver', recursive: bool + ) -> 'cirq.GlobalPhaseGate': + coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive) + return GlobalPhaseGate(coefficient=coefficient) + -def global_phase_operation(coefficient: value.Scalar, atol: float = 1e-8) -> 'cirq.GateOperation': +def global_phase_operation(coefficient: ComplexParam, atol: float = 1e-8) -> 'cirq.GateOperation': return GlobalPhaseGate(coefficient, atol)() diff --git a/cirq-core/cirq/ops/pauli_string.py b/cirq-core/cirq/ops/pauli_string.py index 9a44cc78176..7f606c29ab9 100644 --- a/cirq-core/cirq/ops/pauli_string.py +++ b/cirq-core/cirq/ops/pauli_string.py @@ -39,6 +39,7 @@ ) import numpy as np +import sympy from cirq import value, protocols, linalg, qis from cirq._doc import document @@ -115,7 +116,7 @@ def __init__( self, *contents: 'cirq.PAULI_STRING_LIKE', qubit_pauli_map: Optional[Dict[TKey, 'cirq.Pauli']] = None, - coefficient: Union[int, float, complex] = 1, + coefficient: Union[sympy.Basic, int, float, complex] = 1, ): """Initializes a new PauliString. @@ -151,7 +152,7 @@ def __init__( argument specifies values that are logically *before* factors specified in `contents`; `contents` are *right* multiplied onto the values in this dictionary. - coefficient: Initial scalar coefficient. Defaults to 1. + coefficient: Initial scalar coefficient or symbol. Defaults to 1. Raises: TypeError: If the `qubit_pauli_map` has values that are not Paulis. @@ -162,14 +163,14 @@ def __init__( raise TypeError(f'{v} is not a Pauli') self._qubit_pauli_map: Dict[TKey, 'cirq.Pauli'] = qubit_pauli_map or {} - self._coefficient = complex(coefficient) + self._coefficient = complex(coefficient) if coefficient is numbers.Number else coefficient if contents: m = self.mutable_copy().inplace_left_multiply_by(contents).frozen() self._qubit_pauli_map = m._qubit_pauli_map self._coefficient = m._coefficient @property - def coefficient(self) -> complex: + def coefficient(self) -> Union[sympy.Basic, complex]: return self._coefficient def _value_equality_values_(self): @@ -340,8 +341,10 @@ def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs') -> List[st prefix = 'i' elif self.coefficient == -1j: prefix = '-i' - else: + elif isinstance(self.coefficient, numbers.Number): prefix = f'({args.format_complex(self.coefficient)})*' + else: + prefix = f'({self.coefficient})*' symbols[0] = f'PauliString({prefix}{symbols[0]})' return symbols @@ -441,6 +444,8 @@ def matrix(self, qubits: Optional[Iterable[TKey]] = None) -> np.ndarray: return linalg.kron(self.coefficient, *[protocols.unitary(f) for f in factors]) def _has_unitary_(self) -> bool: + if self._is_parameterized_(): + return False return abs(1 - abs(self.coefficient)) < 1e-6 def _unitary_(self) -> Optional[np.ndarray]: @@ -498,6 +503,9 @@ def expectation_from_state_vector( TypeError: If the input state is not complex. ValueError: If the input state does not have the correct shape. """ + if self._is_parameterized_(): + raise ValueError('Cannot get expectation value when parameterized') + if abs(self.coefficient.imag) > 0.0001: raise NotImplementedError( 'Cannot compute expectation value of a non-Hermitian ' @@ -602,6 +610,8 @@ def expectation_from_density_matrix( TypeError: If the input state is not complex. ValueError: If the input state does not have the correct shape. """ + if self._is_parameterized_(): + raise ValueError('Cannot get expectation value when parameterized') if abs(self.coefficient.imag) > 0.0001: raise NotImplementedError( 'Cannot compute expectation value of a non-Hermitian ' @@ -703,6 +713,8 @@ def __pow__(self, power): return PauliString( qubit_pauli_map=self._qubit_pauli_map, coefficient=self.coefficient ** -1 ) + if self._is_parameterized_(): + raise ValueError('Cannot raise to power when parameterized.') if isinstance(power, (int, float)): r, i = cmath.polar(self.coefficient) if abs(r - 1) > 0.0001: @@ -731,6 +743,8 @@ def __pow__(self, power): return NotImplemented def __rpow__(self, base): + if self._is_parameterized_(): + raise ValueError('Cannot raise to power when parameterized.') if isinstance(base, (int, float)) and base > 0: if abs(self.coefficient.real) > 0.0001: raise NotImplementedError( @@ -946,6 +960,18 @@ def pass_operations_over( coef = -self._coefficient if should_negate else self.coefficient return PauliString(qubit_pauli_map=pauli_map, coefficient=coef) + def _is_parameterized_(self) -> bool: + return protocols.is_parameterized(self.coefficient) + + def _parameter_names_(self) -> AbstractSet[str]: + return protocols.parameter_names(self.coefficient) + + def _resolve_parameters_( + self, resolver: 'cirq.ParamResolver', recursive: bool + ) -> 'cirq.PauliString': + coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive) + return PauliString(qubit_pauli_map=self._qubit_pauli_map, coefficient=coefficient) + def _validate_qubit_mapping( qubit_map: Mapping[TKey, int], pauli_qubits: Tuple[TKey, ...], num_state_qubits: int @@ -1053,10 +1079,10 @@ class MutablePauliString(Generic[TKey]): def __init__( self, *contents: 'cirq.PAULI_STRING_LIKE', - coefficient: Union[int, float, complex] = 1, + coefficient: Union[sympy.Basic, int, float, complex] = 1, pauli_int_dict: Optional[Dict[TKey, int]] = None, ): - self.coefficient = complex(coefficient) + self.coefficient = complex(coefficient) if coefficient is numbers.Number else coefficient self.pauli_int_dict: Dict[TKey, int] = {} if pauli_int_dict is None else pauli_int_dict if contents: self.inplace_left_multiply_by(contents) From 025f054076f23be71716a5bc2c81f605704db4b1 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 16 Mar 2022 17:57:26 -0700 Subject: [PATCH 02/14] add tests --- cirq-core/cirq/ops/dense_pauli_string_test.py | 7 +++- cirq-core/cirq/ops/global_phase_op.py | 1 - cirq-core/cirq/ops/global_phase_op_test.py | 18 +++++++++ cirq-core/cirq/ops/pauli_string.py | 11 +++--- cirq-core/cirq/ops/pauli_string_test.py | 37 +++++++++++++++++++ 5 files changed, 67 insertions(+), 7 deletions(-) diff --git a/cirq-core/cirq/ops/dense_pauli_string_test.py b/cirq-core/cirq/ops/dense_pauli_string_test.py index 99dbed21b8a..ed68dfe518e 100644 --- a/cirq-core/cirq/ops/dense_pauli_string_test.py +++ b/cirq-core/cirq/ops/dense_pauli_string_test.py @@ -397,11 +397,16 @@ def test_protocols(): def test_parameterizable(resolve_fn): t = sympy.Symbol('t') x = cirq.DensePauliString('X') + xt = x * t + x2 = x * 2 + q = cirq.LineQubit(0) assert not cirq.is_parameterized(x) assert not cirq.is_parameterized(x * 2) assert cirq.is_parameterized(x * t) - assert resolve_fn(x * t, {'t': 2}) == x * 2 + assert resolve_fn(xt, {'t': 2}) == x2 assert resolve_fn(x * 3, {'t': 2}) == x * 3 + assert resolve_fn(xt(q), {'t': 2}).gate == x2 + assert resolve_fn(xt(q).gate, {'t': 2}) == x2 def test_item_immutable(): diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index e71e6a158e1..04a51174f7e 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """A no-qubit global phase operation.""" -import numbers from typing import AbstractSet, Any, Dict, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np diff --git a/cirq-core/cirq/ops/global_phase_op_test.py b/cirq-core/cirq/ops/global_phase_op_test.py index 1780aabb898..3197239b694 100644 --- a/cirq-core/cirq/ops/global_phase_op_test.py +++ b/cirq-core/cirq/ops/global_phase_op_test.py @@ -14,6 +14,7 @@ import numpy as np import pytest +import sympy import cirq @@ -345,3 +346,20 @@ def test_gate_global_phase_op_json_dict(): assert cirq.GlobalPhaseGate(-1j)._json_dict_() == { 'coefficient': -1j, } + + +def test_parameterization(): + t = sympy.Symbol('t') + gpt = cirq.GlobalPhaseGate(coefficient=t) + assert cirq.is_parameterized(gpt) + assert cirq.parameter_names(gpt) == {'t'} + assert not cirq.has_unitary(gpt) + assert gpt.coefficient == t + assert (gpt ** 2).coefficient == t ** 2 + + +@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once]) +def test_resolve(resolve_fn): + t = sympy.Symbol('t') + gpt = cirq.GlobalPhaseGate(coefficient=t) + assert resolve_fn(gpt, {'t': -1}) == cirq.GlobalPhaseGate(coefficient=-1) diff --git a/cirq-core/cirq/ops/pauli_string.py b/cirq-core/cirq/ops/pauli_string.py index 7f606c29ab9..8d33cb93210 100644 --- a/cirq-core/cirq/ops/pauli_string.py +++ b/cirq-core/cirq/ops/pauli_string.py @@ -67,6 +67,7 @@ TKey = TypeVar('TKey', bound=raw_types.Qid) TKeyNew = TypeVar('TKeyNew', bound=raw_types.Qid) TKeyOther = TypeVar('TKeyOther', bound=raw_types.Qid) +ComplexParam = Union[value.Scalar, sympy.Basic] # A value that can be unambiguously converted into a `cirq.PauliString`. @@ -116,7 +117,7 @@ def __init__( self, *contents: 'cirq.PAULI_STRING_LIKE', qubit_pauli_map: Optional[Dict[TKey, 'cirq.Pauli']] = None, - coefficient: Union[sympy.Basic, int, float, complex] = 1, + coefficient: ComplexParam = 1, ): """Initializes a new PauliString. @@ -163,7 +164,7 @@ def __init__( raise TypeError(f'{v} is not a Pauli') self._qubit_pauli_map: Dict[TKey, 'cirq.Pauli'] = qubit_pauli_map or {} - self._coefficient = complex(coefficient) if coefficient is numbers.Number else coefficient + self._coefficient = coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient) if contents: m = self.mutable_copy().inplace_left_multiply_by(contents).frozen() self._qubit_pauli_map = m._qubit_pauli_map @@ -354,7 +355,7 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'PauliString': coefficient=self._coefficient, ) - def with_coefficient(self, new_coefficient: Union[int, float, complex]) -> 'PauliString': + def with_coefficient(self, new_coefficient: ComplexParam) -> 'PauliString': return PauliString(qubit_pauli_map=dict(self._qubit_pauli_map), coefficient=new_coefficient) def values(self) -> ValuesView[pauli_gates.Pauli]: @@ -1079,10 +1080,10 @@ class MutablePauliString(Generic[TKey]): def __init__( self, *contents: 'cirq.PAULI_STRING_LIKE', - coefficient: Union[sympy.Basic, int, float, complex] = 1, + coefficient: ComplexParam = 1, pauli_int_dict: Optional[Dict[TKey, int]] = None, ): - self.coefficient = complex(coefficient) if coefficient is numbers.Number else coefficient + self.coefficient = coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient) self.pauli_int_dict: Dict[TKey, int] = {} if pauli_int_dict is None else pauli_int_dict if contents: self.inplace_left_multiply_by(contents) diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index 04f4fd8785f..a6fb6ce6c75 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -1944,3 +1944,40 @@ def test_transform_qubits(): assert m is m2 assert m == p2 assert m2 == p2 + + +def test_parameterization(): + t = sympy.Symbol('t') + q = cirq.LineQubit(0) + pst = cirq.PauliString({q: 'x'}, coefficient=t) + assert cirq.is_parameterized(pst) + assert cirq.parameter_names(pst) == {'t'} + assert pst.coefficient == 1.0 * t + assert not cirq.has_unitary(pst) + assert not cirq.is_parameterized(pst.with_coefficient(2)) + with pytest.raises(TypeError): + cirq.decompose_once(pst) + with pytest.raises(ValueError, match='parameterized'): + pst.expectation_from_state_vector(np.array([]), {}) + with pytest.raises(ValueError, match='parameterized'): + pst.expectation_from_density_matrix(np.array([]), {}) + assert pst ** 1 == pst + assert pst ** -1 == pst.with_coefficient(1.0/t) + assert (-pst) ** 1 == -pst + assert (-pst) ** -1 == -pst.with_coefficient(1.0/t) + assert (1j * pst) ** 1 == 1j * pst + assert (1j * pst) ** -1 == -1j * pst.with_coefficient(1.0/t) + with pytest.raises(ValueError, match='parameterized'): + pst ** 2 + with pytest.raises(ValueError, match='parameterized'): + 1 ** pst + cirq.testing.assert_has_diagram(cirq.Circuit(pst), '0: ───PauliString((1.0*t)*X)───') + + +@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once]) +def test_resolve(resolve_fn): + t = sympy.Symbol('t') + q = cirq.LineQubit(0) + pst = cirq.PauliString({q: 'x'}, coefficient=t) + ps1 = cirq.PauliString({q: 'x'}, coefficient=1) + assert resolve_fn(pst, {'t': 1}) == ps1 From b65b4a6d9edefe9dfd312cb563b92a8370aad851 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 16 Mar 2022 18:03:30 -0700 Subject: [PATCH 03/14] format --- cirq-core/cirq/ops/pauli_string.py | 10 +++++++--- cirq-core/cirq/ops/pauli_string_test.py | 6 +++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/ops/pauli_string.py b/cirq-core/cirq/ops/pauli_string.py index 8d33cb93210..ec3faef2f99 100644 --- a/cirq-core/cirq/ops/pauli_string.py +++ b/cirq-core/cirq/ops/pauli_string.py @@ -164,14 +164,16 @@ def __init__( raise TypeError(f'{v} is not a Pauli') self._qubit_pauli_map: Dict[TKey, 'cirq.Pauli'] = qubit_pauli_map or {} - self._coefficient = coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient) + self._coefficient = ( + coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient) + ) if contents: m = self.mutable_copy().inplace_left_multiply_by(contents).frozen() self._qubit_pauli_map = m._qubit_pauli_map self._coefficient = m._coefficient @property - def coefficient(self) -> Union[sympy.Basic, complex]: + def coefficient(self) -> ComplexParam: return self._coefficient def _value_equality_values_(self): @@ -1083,7 +1085,9 @@ def __init__( coefficient: ComplexParam = 1, pauli_int_dict: Optional[Dict[TKey, int]] = None, ): - self.coefficient = coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient) + self.coefficient = ( + coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient) + ) self.pauli_int_dict: Dict[TKey, int] = {} if pauli_int_dict is None else pauli_int_dict if contents: self.inplace_left_multiply_by(contents) diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index a6fb6ce6c75..d5ff7947f6c 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -1962,11 +1962,11 @@ def test_parameterization(): with pytest.raises(ValueError, match='parameterized'): pst.expectation_from_density_matrix(np.array([]), {}) assert pst ** 1 == pst - assert pst ** -1 == pst.with_coefficient(1.0/t) + assert pst ** -1 == pst.with_coefficient(1.0 / t) assert (-pst) ** 1 == -pst - assert (-pst) ** -1 == -pst.with_coefficient(1.0/t) + assert (-pst) ** -1 == -pst.with_coefficient(1.0 / t) assert (1j * pst) ** 1 == 1j * pst - assert (1j * pst) ** -1 == -1j * pst.with_coefficient(1.0/t) + assert (1j * pst) ** -1 == -1j * pst.with_coefficient(1.0 / t) with pytest.raises(ValueError, match='parameterized'): pst ** 2 with pytest.raises(ValueError, match='parameterized'): From 6a5a552e1ecf4ad46e1c223639794fb694813e6e Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 16 Mar 2022 18:04:53 -0700 Subject: [PATCH 04/14] lint --- cirq-core/cirq/ops/pauli_string_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index d5ff7947f6c..ba6dd8ff2ac 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -1968,9 +1968,9 @@ def test_parameterization(): assert (1j * pst) ** 1 == 1j * pst assert (1j * pst) ** -1 == -1j * pst.with_coefficient(1.0 / t) with pytest.raises(ValueError, match='parameterized'): - pst ** 2 + _ = pst ** 2 with pytest.raises(ValueError, match='parameterized'): - 1 ** pst + _ = 1 ** pst cirq.testing.assert_has_diagram(cirq.Circuit(pst), '0: ───PauliString((1.0*t)*X)───') From eee670c760d8a4ae2fbd1db1cce49a842824ec44 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 16 Mar 2022 20:49:36 -0700 Subject: [PATCH 05/14] diagonal --- cirq-core/cirq/ops/diagonal_gate.py | 8 +++----- cirq-core/cirq/ops/diagonal_gate_test.py | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cirq-core/cirq/ops/diagonal_gate.py b/cirq-core/cirq/ops/diagonal_gate.py index 670a683f675..88cf33cea3e 100644 --- a/cirq-core/cirq/ops/diagonal_gate.py +++ b/cirq-core/cirq/ops/diagonal_gate.py @@ -185,11 +185,9 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': # we add global phase. # Global phase is ignored for parameterized gates as `cirq.GlobalPhaseGate` expects a # scalar value. - decomposed_circ: List[Any] = ( - [global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))] - if not protocols.is_parameterized(hat_angles[0]) - else [] - ) + decomposed_circ = [ + global_phase_op.global_phase_operation(1j ** (2 * hat_angles[0] / np.pi)) + ] for i, bit_flip in _gen_gray_code(n): decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits)) return decomposed_circ diff --git a/cirq-core/cirq/ops/diagonal_gate_test.py b/cirq-core/cirq/ops/diagonal_gate_test.py index c4529766628..e171e338b38 100644 --- a/cirq-core/cirq/ops/diagonal_gate_test.py +++ b/cirq-core/cirq/ops/diagonal_gate_test.py @@ -92,7 +92,7 @@ def test_decomposition_with_parameterization(n): ) resolved_op = cirq.resolve_parameters(parameterized_op, resolver) resolved_circuit = cirq.resolve_parameters(decomposed_circuit, resolver) - cirq.testing.assert_allclose_up_to_global_phase( + np.testing.assert_allclose( cirq.unitary(resolved_op), cirq.unitary(resolved_circuit), atol=1e-8 ) From a32346df7cc46232d6bdcdffeac27269d70ded94 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 16 Mar 2022 20:50:55 -0700 Subject: [PATCH 06/14] lint --- cirq-core/cirq/ops/diagonal_gate.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/ops/diagonal_gate.py b/cirq-core/cirq/ops/diagonal_gate.py index 88cf33cea3e..ccf958de1d9 100644 --- a/cirq-core/cirq/ops/diagonal_gate.py +++ b/cirq-core/cirq/ops/diagonal_gate.py @@ -18,7 +18,7 @@ passed as a list. """ -from typing import AbstractSet, Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import AbstractSet, Any, Iterator, Optional, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np import sympy @@ -66,7 +66,7 @@ def _gen_gray_code(n: int) -> Iterator[Tuple[int, int]]: class DiagonalGate(raw_types.Gate): """A gate given by a diagonal (2^n)\\times(2^n) matrix.""" - def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None: + def __init__(self, diag_angles_radians: Sequence['cirq.TParamVal']) -> None: r"""A n-qubit gate with only diagonal elements. This gate's off-diagonal elements are zero and it's on diagonal @@ -77,7 +77,7 @@ def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None: If these values are $(x_0, x_1, \ldots , x_N)$ then the unitary has diagonal values $(e^{i x_0}, e^{i x_1}, \ldots, e^{i x_N})$. """ - self._diag_angles_radians: Tuple[value.TParamVal, ...] = tuple(diag_angles_radians) + self._diag_angles_radians: Tuple['cirq.TParamVal', ...] = tuple(diag_angles_radians) def _num_qubits_(self): return int(np.log2(len(self._diag_angles_radians))) @@ -144,7 +144,7 @@ def _value_equality_values_(self) -> Any: return tuple(self._diag_angles_radians) def _decompose_for_basis( - self, index: int, bit_flip: int, theta: value.TParamVal, qubits: Sequence['cirq.Qid'] + self, index: int, bit_flip: int, theta: 'cirq.TParamVal', qubits: Sequence['cirq.Qid'] ) -> Iterator[Union['cirq.ZPowGate', 'cirq.CXPowGate']]: if index == 0: return [] From 39990e122b7bcd42df89868a77e65e0b5c4aa716 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 16 Mar 2022 20:59:52 -0700 Subject: [PATCH 07/14] format --- cirq-core/cirq/ops/diagonal_gate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/diagonal_gate.py b/cirq-core/cirq/ops/diagonal_gate.py index ccf958de1d9..562e560e365 100644 --- a/cirq-core/cirq/ops/diagonal_gate.py +++ b/cirq-core/cirq/ops/diagonal_gate.py @@ -18,7 +18,7 @@ passed as a list. """ -from typing import AbstractSet, Any, Iterator, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import AbstractSet, Any, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np import sympy @@ -185,7 +185,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': # we add global phase. # Global phase is ignored for parameterized gates as `cirq.GlobalPhaseGate` expects a # scalar value. - decomposed_circ = [ + decomposed_circ: List[Any] = [ global_phase_op.global_phase_operation(1j ** (2 * hat_angles[0] / np.pi)) ] for i, bit_flip in _gen_gray_code(n): From dd756b2afd92fae7753ac3e773f132fbd02055c2 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 29 Mar 2022 12:06:49 -0700 Subject: [PATCH 08/14] format --- cirq-core/cirq/ops/global_phase_op_test.py | 2 +- cirq-core/cirq/ops/pauli_string_test.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/ops/global_phase_op_test.py b/cirq-core/cirq/ops/global_phase_op_test.py index f175758f4ee..018be1a1c02 100644 --- a/cirq-core/cirq/ops/global_phase_op_test.py +++ b/cirq-core/cirq/ops/global_phase_op_test.py @@ -353,7 +353,7 @@ def test_parameterization(): assert cirq.parameter_names(gpt) == {'t'} assert not cirq.has_unitary(gpt) assert gpt.coefficient == t - assert (gpt ** 2).coefficient == t ** 2 + assert (gpt**2).coefficient == t**2 @pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once]) diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index f94dd36893f..2df27ae1043 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -1949,16 +1949,16 @@ def test_parameterization(): pst.expectation_from_state_vector(np.array([]), {}) with pytest.raises(ValueError, match='parameterized'): pst.expectation_from_density_matrix(np.array([]), {}) - assert pst ** 1 == pst - assert pst ** -1 == pst.with_coefficient(1.0 / t) + assert pst**1 == pst + assert pst**-1 == pst.with_coefficient(1.0 / t) assert (-pst) ** 1 == -pst assert (-pst) ** -1 == -pst.with_coefficient(1.0 / t) assert (1j * pst) ** 1 == 1j * pst assert (1j * pst) ** -1 == -1j * pst.with_coefficient(1.0 / t) with pytest.raises(ValueError, match='parameterized'): - _ = pst ** 2 + _ = pst**2 with pytest.raises(ValueError, match='parameterized'): - _ = 1 ** pst + _ = 1**pst cirq.testing.assert_has_diagram(cirq.Circuit(pst), '0: ───PauliString((1.0*t)*X)───') From 017aca192226542ee2ef8285f8a8b4ef899d3beb Mon Sep 17 00:00:00 2001 From: daxfohl Date: Tue, 29 Mar 2022 13:24:07 -0700 Subject: [PATCH 09/14] Fix global phase unitary when parameterized --- cirq-core/cirq/ops/global_phase_op.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index f8d3b2fd05a..6797f264a7a 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -19,6 +19,7 @@ from cirq import value, protocols from cirq._compat import deprecated_class +from cirq.type_workarounds import NotImplementedType from cirq.ops import gate_operation, raw_types if TYPE_CHECKING: @@ -81,10 +82,16 @@ def __pow__(self, power) -> 'cirq.GlobalPhaseGate': return GlobalPhaseGate(self.coefficient**power) return NotImplemented - def _unitary_(self) -> np.ndarray: + def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: + if not self._has_unitary_(): + return NotImplemented return np.array([[self.coefficient]]) - def _apply_unitary_(self, args) -> np.ndarray: + def _apply_unitary_( + self, args: 'cirq.ApplyUnitaryArgs' + ) -> Union[np.ndarray, NotImplementedType]: + if not self._has_unitary_(): + return NotImplemented args.target_tensor *= self.coefficient return args.target_tensor From 78e400245f8c03641781a4ed1e2e718772150ae1 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Wed, 20 Apr 2022 21:09:49 -0700 Subject: [PATCH 10/14] Code review cleanup --- cirq-core/cirq/__init__.py | 1 + cirq-core/cirq/ops/diagonal_gate.py | 2 -- cirq-core/cirq/ops/global_phase_op.py | 11 ++++----- cirq-core/cirq/ops/global_phase_op_test.py | 8 +++++++ cirq-core/cirq/ops/pauli_string.py | 23 ++++++++++--------- cirq-core/cirq/ops/pauli_string_test.py | 8 +++---- .../cirq/protocols/json_test_data/spec.py | 1 + cirq-core/cirq/value/__init__.py | 2 +- cirq-core/cirq/value/type_alias.py | 7 ++++++ 9 files changed, 39 insertions(+), 24 deletions(-) diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 040eee77296..fb9e817ef23 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -544,6 +544,7 @@ Timestamp, TParamKey, TParamVal, + TParamValComplex, validate_probability, value_equality, KET_PLUS, diff --git a/cirq-core/cirq/ops/diagonal_gate.py b/cirq-core/cirq/ops/diagonal_gate.py index 8eb53defa93..b01bf84c754 100644 --- a/cirq-core/cirq/ops/diagonal_gate.py +++ b/cirq-core/cirq/ops/diagonal_gate.py @@ -183,8 +183,6 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': # decomposed gates. On its own it is not physically observable. However, if using this # diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence, # we add global phase. - # Global phase is ignored for parameterized gates as `cirq.GlobalPhaseGate` expects a - # scalar value. decomposed_circ: List[Any] = [ global_phase_op.global_phase_operation(1j ** (2 * hat_angles[0] / np.pi)) ] diff --git a/cirq-core/cirq/ops/global_phase_op.py b/cirq-core/cirq/ops/global_phase_op.py index 6797f264a7a..5c4f0b9538d 100644 --- a/cirq-core/cirq/ops/global_phase_op.py +++ b/cirq-core/cirq/ops/global_phase_op.py @@ -26,9 +26,6 @@ import cirq -ComplexParam = Union[value.Scalar, sympy.Basic] - - @value.value_equality(approximate=True) @deprecated_class(deadline='v0.16', fix='Use cirq.global_phase_operation') class GlobalPhaseOperation(gate_operation.GateOperation): @@ -62,13 +59,13 @@ def _json_dict_(self) -> Dict[str, Any]: @value.value_equality(approximate=True) class GlobalPhaseGate(raw_types.Gate): - def __init__(self, coefficient: ComplexParam, atol: float = 1e-8) -> None: + def __init__(self, coefficient: 'cirq.TParamValComplex', atol: float = 1e-8) -> None: if not isinstance(coefficient, sympy.Basic) and abs(1 - abs(coefficient)) > atol: raise ValueError(f'Coefficient is not unitary: {coefficient!r}') self._coefficient = coefficient @property - def coefficient(self) -> ComplexParam: + def coefficient(self) -> 'cirq.TParamValComplex': return self._coefficient def _value_equality_values_(self) -> Any: @@ -126,5 +123,7 @@ def _resolve_parameters_( return GlobalPhaseGate(coefficient=coefficient) -def global_phase_operation(coefficient: ComplexParam, atol: float = 1e-8) -> 'cirq.GateOperation': +def global_phase_operation( + coefficient: 'cirq.TParamValComplex', atol: float = 1e-8 +) -> 'cirq.GateOperation': return GlobalPhaseGate(coefficient, atol)() diff --git a/cirq-core/cirq/ops/global_phase_op_test.py b/cirq-core/cirq/ops/global_phase_op_test.py index 4585158b697..5c55b5354e1 100644 --- a/cirq-core/cirq/ops/global_phase_op_test.py +++ b/cirq-core/cirq/ops/global_phase_op_test.py @@ -289,3 +289,11 @@ def test_resolve(resolve_fn): t = sympy.Symbol('t') gpt = cirq.GlobalPhaseGate(coefficient=t) assert resolve_fn(gpt, {'t': -1}) == cirq.GlobalPhaseGate(coefficient=-1) + + +@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once]) +def test_resolve_error(resolve_fn): + t = sympy.Symbol('t') + gpt = cirq.GlobalPhaseGate(coefficient=t) + with pytest.raises(ValueError, match='Coefficient is not unitary'): + resolve_fn(gpt, {'t': -2}) diff --git a/cirq-core/cirq/ops/pauli_string.py b/cirq-core/cirq/ops/pauli_string.py index 855a76c05f0..0990b5d40dd 100644 --- a/cirq-core/cirq/ops/pauli_string.py +++ b/cirq-core/cirq/ops/pauli_string.py @@ -67,7 +67,6 @@ TKey = TypeVar('TKey', bound=raw_types.Qid) TKeyNew = TypeVar('TKeyNew', bound=raw_types.Qid) TKeyOther = TypeVar('TKeyOther', bound=raw_types.Qid) -ComplexParam = Union[value.Scalar, sympy.Basic] # A value that can be unambiguously converted into a `cirq.PauliString`. @@ -112,7 +111,7 @@ def __init__( self, *contents: 'cirq.PAULI_STRING_LIKE', qubit_pauli_map: Optional[Dict[TKey, 'cirq.Pauli']] = None, - coefficient: ComplexParam = 1, + coefficient: 'cirq.TParamValComplex' = 1, ): """Initializes a new PauliString. @@ -168,7 +167,7 @@ def __init__( self._coefficient = m._coefficient @property - def coefficient(self) -> ComplexParam: + def coefficient(self) -> 'cirq.TParamValComplex': return self._coefficient def _value_equality_values_(self): @@ -352,7 +351,7 @@ def with_qubits(self, *new_qubits: 'cirq.Qid') -> 'PauliString': coefficient=self._coefficient, ) - def with_coefficient(self, new_coefficient: ComplexParam) -> 'PauliString': + def with_coefficient(self, new_coefficient: 'cirq.TParamValComplex') -> 'PauliString': return PauliString(qubit_pauli_map=dict(self._qubit_pauli_map), coefficient=new_coefficient) def values(self) -> ValuesView[pauli_gates.Pauli]: @@ -497,12 +496,13 @@ def expectation_from_state_vector( The expectation value of the input state. Raises: - NotImplementedError: If this PauliString is non-Hermitian. + NotImplementedError: If this PauliString is non-Hermitian or + parameterized. TypeError: If the input state is not complex. ValueError: If the input state does not have the correct shape. """ if self._is_parameterized_(): - raise ValueError('Cannot get expectation value when parameterized') + raise NotImplementedError('Cannot get expectation value when parameterized') if abs(self.coefficient.imag) > 0.0001: raise NotImplementedError( @@ -604,12 +604,13 @@ def expectation_from_density_matrix( The expectation value of the input state. Raises: - NotImplementedError: If this PauliString is non-Hermitian. + NotImplementedError: If this PauliString is non-Hermitian or + parameterized. TypeError: If the input state is not complex. ValueError: If the input state does not have the correct shape. """ if self._is_parameterized_(): - raise ValueError('Cannot get expectation value when parameterized') + raise NotImplementedError('Cannot get expectation value when parameterized') if abs(self.coefficient.imag) > 0.0001: raise NotImplementedError( 'Cannot compute expectation value of a non-Hermitian ' @@ -712,7 +713,7 @@ def __pow__(self, power): qubit_pauli_map=self._qubit_pauli_map, coefficient=self.coefficient**-1 ) if self._is_parameterized_(): - raise ValueError('Cannot raise to power when parameterized.') + return NotImplemented if isinstance(power, (int, float)): r, i = cmath.polar(self.coefficient) if abs(r - 1) > 0.0001: @@ -742,7 +743,7 @@ def __pow__(self, power): def __rpow__(self, base): if self._is_parameterized_(): - raise ValueError('Cannot raise to power when parameterized.') + return NotImplemented if isinstance(base, (int, float)) and base > 0: if abs(self.coefficient.real) > 0.0001: raise NotImplementedError( @@ -1077,7 +1078,7 @@ class MutablePauliString(Generic[TKey]): def __init__( self, *contents: 'cirq.PAULI_STRING_LIKE', - coefficient: ComplexParam = 1, + coefficient: 'cirq.TParamValComplex' = 1, pauli_int_dict: Optional[Dict[TKey, int]] = None, ): self.coefficient = ( diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index 94b56c1c45c..9d4685c1819 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -1931,9 +1931,9 @@ def test_parameterization(): assert not cirq.is_parameterized(pst.with_coefficient(2)) with pytest.raises(TypeError): cirq.decompose_once(pst) - with pytest.raises(ValueError, match='parameterized'): + with pytest.raises(NotImplementedError, match='parameterized'): pst.expectation_from_state_vector(np.array([]), {}) - with pytest.raises(ValueError, match='parameterized'): + with pytest.raises(NotImplementedError, match='parameterized'): pst.expectation_from_density_matrix(np.array([]), {}) assert pst**1 == pst assert pst**-1 == pst.with_coefficient(1.0 / t) @@ -1941,9 +1941,9 @@ def test_parameterization(): assert (-pst) ** -1 == -pst.with_coefficient(1.0 / t) assert (1j * pst) ** 1 == 1j * pst assert (1j * pst) ** -1 == -1j * pst.with_coefficient(1.0 / t) - with pytest.raises(ValueError, match='parameterized'): + with pytest.raises(TypeError): _ = pst**2 - with pytest.raises(ValueError, match='parameterized'): + with pytest.raises(TypeError): _ = 1**pst cirq.testing.assert_has_diagram(cirq.Circuit(pst), '0: ───PauliString((1.0*t)*X)───') diff --git a/cirq-core/cirq/protocols/json_test_data/spec.py b/cirq-core/cirq/protocols/json_test_data/spec.py index 8fc2e49f64d..4fa53bad7c5 100644 --- a/cirq-core/cirq/protocols/json_test_data/spec.py +++ b/cirq-core/cirq/protocols/json_test_data/spec.py @@ -178,6 +178,7 @@ 'Sweepable', 'TParamKey', 'TParamVal', + 'TParamValComplex', 'TRANSFORMER', 'ParamDictType', # utility: diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 728c564404e..423acee7967 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -62,6 +62,6 @@ from cirq.value.timestamp import Timestamp -from cirq.value.type_alias import TParamKey, TParamVal +from cirq.value.type_alias import TParamKey, TParamVal, TParamValComplex from cirq.value.value_equality_attr import value_equality diff --git a/cirq-core/cirq/value/type_alias.py b/cirq-core/cirq/value/type_alias.py index 68458b07b7f..86bf934b603 100644 --- a/cirq-core/cirq/value/type_alias.py +++ b/cirq-core/cirq/value/type_alias.py @@ -16,6 +16,7 @@ import sympy from cirq._doc import document +from cirq.value import linear_dict """Supply aliases for commonly used types. """ @@ -27,3 +28,9 @@ document( TParamVal, """A value that a parameter resolver may return for a parameter.""" # type: ignore ) + +TParamValComplex = Union[linear_dict.Scalar, sympy.Basic] +document( + TParamValComplex, + """A complex value that parameter resolvers may use for parameters.""", # type: ignore +) From fed0fc51591f68b2845b7fe74c82da1cf992263d Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 16:07:43 -0700 Subject: [PATCH 11/14] Change test to use complex numbers --- cirq-core/cirq/ops/pauli_string_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/ops/pauli_string_test.py b/cirq-core/cirq/ops/pauli_string_test.py index 9d4685c1819..457e42dc4eb 100644 --- a/cirq-core/cirq/ops/pauli_string_test.py +++ b/cirq-core/cirq/ops/pauli_string_test.py @@ -1953,5 +1953,5 @@ def test_resolve(resolve_fn): t = sympy.Symbol('t') q = cirq.LineQubit(0) pst = cirq.PauliString({q: 'x'}, coefficient=t) - ps1 = cirq.PauliString({q: 'x'}, coefficient=1) - assert resolve_fn(pst, {'t': 1}) == ps1 + ps1 = cirq.PauliString({q: 'x'}, coefficient=1j) + assert resolve_fn(pst, {'t': 1j}) == ps1 From 8da82329cf9a140443b5cd0ecd9c5dd1cd25c0af Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 16:39:23 -0700 Subject: [PATCH 12/14] modify the return type of ParamResolver to allow complex --- cirq-core/cirq/study/resolver.py | 11 ++++++----- cirq-core/cirq/work/observable_measurement.py | 16 ++++++++++++++-- cirq-core/cirq/work/sampler.py | 6 ++++-- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index 499d1859337..9370f7359fb 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -26,7 +26,8 @@ import cirq -ParamDictType = Dict['cirq.TParamKey', 'cirq.TParamVal'] +TParamValAny = Union['cirq.TParamVal', 'cirq.TParamValComplex'] +ParamDictType = Dict['cirq.TParamKey', TParamValAny] document(ParamDictType, """Dictionary from symbols to values.""") # type: ignore ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamDictType, None] @@ -71,11 +72,11 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None self._deep_eval_map: ParamDictType = {} def value_of( - self, value: Union['cirq.TParamKey', float], recursive: bool = True - ) -> 'cirq.TParamVal': + self, value: Union['cirq.TParamKey', TParamValAny], recursive: bool = True + ) -> TParamValAny: """Attempt to resolve a parameter to its assigned value. - Floats are returned without modification. Strings are resolved via + Scalars are returned without modification. Strings are resolved via the parameter dictionary with exact match only. Otherwise, strings are considered to be sympy.Symbols with the name as the input string. @@ -207,7 +208,7 @@ def __iter__(self) -> Iterator[Union[str, sympy.Symbol]]: def __bool__(self) -> bool: return bool(self.param_dict) - def __getitem__(self, key: Union[sympy.Basic, float, str]) -> 'cirq.TParamVal': + def __getitem__(self, key: Union['cirq.TParamKey', TParamValAny]) -> TParamValAny: return self.value_of(key) def __hash__(self) -> int: diff --git a/cirq-core/cirq/work/observable_measurement.py b/cirq-core/cirq/work/observable_measurement.py index 6fcb442d7af..b19487ab14b 100644 --- a/cirq-core/cirq/work/observable_measurement.py +++ b/cirq-core/cirq/work/observable_measurement.py @@ -18,7 +18,19 @@ import os import tempfile import warnings -from typing import Optional, Union, Iterable, Dict, List, Tuple, TYPE_CHECKING, Set, Sequence, Any +from typing import ( + Any, + cast, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) import numpy as np import pandas as pd @@ -530,7 +542,7 @@ def measure_grouped_settings( for max_setting, param_resolver in itertools.product( grouped_settings.keys(), study.to_resolvers(circuit_sweep) ): - circuit_params = param_resolver.param_dict + circuit_params = cast(Dict[str, float], param_resolver.param_dict) meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params) accumulator = BitstringAccumulator( meas_spec=meas_spec, diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index ac03d43a76a..264802329a1 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -15,7 +15,7 @@ import abc import collections -from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import cast, Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import pandas as pd @@ -329,7 +329,9 @@ def sample_expectation_values( # Flatten Circuit Sweep into one big list of Params. # Keep track of their indices so we can map back. - flat_params: List[Dict[str, float]] = [pr.param_dict for pr in study.to_resolvers(params)] + flat_params = cast( + List[Dict[str, float]], [pr.param_dict for pr in study.to_resolvers(params)] + ) circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, float]], int] = { _hashable_param(param.items()): i for i, param in enumerate(flat_params) } From 83c37a85be404200ed186d7eb141c20183f8ae81 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Thu, 21 Apr 2022 16:51:06 -0700 Subject: [PATCH 13/14] simplify cast --- cirq-core/cirq/work/sampler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cirq-core/cirq/work/sampler.py b/cirq-core/cirq/work/sampler.py index 264802329a1..1d2d976028a 100644 --- a/cirq-core/cirq/work/sampler.py +++ b/cirq-core/cirq/work/sampler.py @@ -329,9 +329,7 @@ def sample_expectation_values( # Flatten Circuit Sweep into one big list of Params. # Keep track of their indices so we can map back. - flat_params = cast( - List[Dict[str, float]], [pr.param_dict for pr in study.to_resolvers(params)] - ) + flat_params = [cast(Dict[str, float], pr.param_dict) for pr in study.to_resolvers(params)] circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, float]], int] = { _hashable_param(param.items()): i for i, param in enumerate(flat_params) } From e464b5c4f52f2f100aee6c68a4b98916783c8d18 Mon Sep 17 00:00:00 2001 From: daxfohl Date: Fri, 22 Apr 2022 15:50:28 -0700 Subject: [PATCH 14/14] Remove TParamValAny --- cirq-core/cirq/study/resolver.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cirq-core/cirq/study/resolver.py b/cirq-core/cirq/study/resolver.py index 9370f7359fb..872cbb3d1e9 100644 --- a/cirq-core/cirq/study/resolver.py +++ b/cirq-core/cirq/study/resolver.py @@ -26,8 +26,7 @@ import cirq -TParamValAny = Union['cirq.TParamVal', 'cirq.TParamValComplex'] -ParamDictType = Dict['cirq.TParamKey', TParamValAny] +ParamDictType = Dict['cirq.TParamKey', 'cirq.TParamValComplex'] document(ParamDictType, """Dictionary from symbols to values.""") # type: ignore ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamDictType, None] @@ -72,8 +71,8 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None self._deep_eval_map: ParamDictType = {} def value_of( - self, value: Union['cirq.TParamKey', TParamValAny], recursive: bool = True - ) -> TParamValAny: + self, value: Union['cirq.TParamKey', 'cirq.TParamValComplex'], recursive: bool = True + ) -> 'cirq.TParamValComplex': """Attempt to resolve a parameter to its assigned value. Scalars are returned without modification. Strings are resolved via @@ -208,7 +207,9 @@ def __iter__(self) -> Iterator[Union[str, sympy.Symbol]]: def __bool__(self) -> bool: return bool(self.param_dict) - def __getitem__(self, key: Union['cirq.TParamKey', TParamValAny]) -> TParamValAny: + def __getitem__( + self, key: Union['cirq.TParamKey', 'cirq.TParamValComplex'] + ) -> 'cirq.TParamValComplex': return self.value_of(key) def __hash__(self) -> int: