Skip to content

Commit

Permalink
Change type signature of cirq.resolve_parameters to preserve types (#…
Browse files Browse the repository at this point in the history
…3922)

Fixes #3390

Review: @balopat
  • Loading branch information
maffoo committed Mar 17, 2021
1 parent 38cb2f2 commit aa6ec84
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cirq/ops/diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _gen_gray_code(n: int) -> Iterator[Tuple[int, int]]:
class DiagonalGate(raw_types.Gate):
"""A gate given by a diagonal (2^n)\\times(2^n) matrix."""

def __init__(self, diag_angles_radians: List[value.TParamVal]) -> None:
def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None:
r"""A n-qubit gate with only diagonal elements.
This gate's off-diagonal elements are zero and it's on diagonal
Expand Down
5 changes: 2 additions & 3 deletions cirq/ops/two_qubit_diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
passed as a list.
"""

from typing import AbstractSet, Any, Tuple, List, Optional, TYPE_CHECKING
from typing import AbstractSet, Any, Tuple, Optional, Sequence, TYPE_CHECKING
import numpy as np
import sympy

Expand All @@ -26,15 +26,14 @@
from cirq.ops import gate_features

if TYPE_CHECKING:
# pylint: disable=unused-import
import cirq


@value.value_equality()
class TwoQubitDiagonalGate(gate_features.TwoQubitGate):
"""A gate given by a diagonal 4\\times 4 matrix."""

def __init__(self, diag_angles_radians: List[value.TParamVal]) -> None:
def __init__(self, diag_angles_radians: Sequence[value.TParamVal]) -> None:
r"""A two qubit gate with only diagonal elements.
This gate's off-diagonal elements are zero and it's on diagonal
Expand Down
24 changes: 18 additions & 6 deletions cirq/protocols/resolve_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import numbers
from typing import AbstractSet, Any, TYPE_CHECKING
from typing import AbstractSet, Any, cast, TYPE_CHECKING, TypeVar

import sympy
from typing_extensions import Protocol
Expand All @@ -25,6 +25,9 @@
import cirq


T = TypeVar('T')


class SupportsParameterization(Protocol):
"""An object that can be parameterized by Symbols and resolved
via a ParamResolver"""
Expand All @@ -45,7 +48,7 @@ def _parameter_names_(self: Any) -> AbstractSet[str]:
"""

@doc_private
def _resolve_parameters_(self: Any, param_resolver: 'cirq.ParamResolver', recursive: bool):
def _resolve_parameters_(self: T, param_resolver: 'cirq.ParamResolver', recursive: bool) -> T:
"""Resolve the parameters in the effect."""


Expand Down Expand Up @@ -130,8 +133,8 @@ def parameter_symbols(val: Any) -> AbstractSet[sympy.Symbol]:


def resolve_parameters(
val: Any, param_resolver: 'cirq.ParamResolverOrSimilarType', recursive: bool = True
):
val: T, param_resolver: 'cirq.ParamResolverOrSimilarType', recursive: bool = True
) -> T:
"""Resolves symbol parameters in the effect using the param resolver.
This function will use the `_resolve_parameters_` magic method
Expand All @@ -149,6 +152,12 @@ def resolve_parameters(
replaced with floats or terminal symbols according to the
given ParamResolver. If `val` has no `_resolve_parameters_`
method or if it returns NotImplemented, `val` itself is returned.
Note that in some cases, such as when directly resolving a sympy
Symbol, the return type could differ from the input type; however,
for the much more common case of resolving parameters on cirq
objects (or if resolving a Union[Symbol, float] instead of just a
Symbol), the return type will be the same as val so we reflect
that in the type signature of this protocol function.
Raises:
RecursionError if the ParamResolver detects a loop in resolution.
Expand All @@ -160,10 +169,13 @@ def resolve_parameters(

# Ensure it is a dictionary wrapped in a ParamResolver.
param_resolver = study.ParamResolver(param_resolver)

# Handle special cases for sympy expressions and sequences.
# These may not in fact preserve types, but we pretend they do by casting.
if isinstance(val, sympy.Basic):
return param_resolver.value_of(val, recursive)
return cast(T, param_resolver.value_of(val, recursive))
if isinstance(val, (list, tuple)):
return type(val)(resolve_parameters(e, param_resolver, recursive) for e in val)
return cast(T, type(val)(resolve_parameters(e, param_resolver, recursive) for e in val))

getter = getattr(val, '_resolve_parameters_', None)
if getter is None:
Expand Down

0 comments on commit aa6ec84

Please sign in to comment.