Skip to content

Commit

Permalink
Fix typing of the protocols.commutes() function (quantumlib#5651)
Browse files Browse the repository at this point in the history
Return bool when the `default` argument is not specified.
Otherwise allow return type to be the same as the `default` argument type.

Co-authored-by: Matthew Neeley <maffoo@google.com>
  • Loading branch information
2 people authored and rht committed May 1, 2023
1 parent e54d259 commit 0ffe54f
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions cirq-core/cirq/protocols/commutes_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Protocol for determining commutativity."""

from typing import Any, TypeVar, Union
from typing import Any, overload, TypeVar, Union

import numpy as np
from typing_extensions import Protocol
Expand All @@ -26,7 +26,7 @@
# whether or not the caller provided a 'default' argument.
# It is checked for using `is`, so it won't have a false positive if the user
# provides a different np.array([]) value.
RaiseTypeErrorIfNotProvided = np.array([])
RaiseTypeErrorIfNotProvided = object()

TDefault = TypeVar('TDefault')

Expand Down Expand Up @@ -73,13 +73,21 @@ def _commutes_(self, other: Any, *, atol: float) -> Union[None, bool, NotImpleme
"""


@overload
def commutes(v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8) -> bool:
...


@overload
def commutes(
v1: Any,
v2: Any,
*,
atol: Union[int, float] = 1e-8,
default: Union[bool, TDefault] = RaiseTypeErrorIfNotProvided,
v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8, default: TDefault
) -> Union[bool, TDefault]:
...


def commutes(
v1: Any, v2: Any, *, atol: Union[int, float] = 1e-8, default: Any = RaiseTypeErrorIfNotProvided
) -> Any:
"""Determines whether two values commute.
This is determined by any one of the following techniques:
Expand Down

0 comments on commit 0ffe54f

Please sign in to comment.