Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cirq-core/cirq/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@

from __future__ import annotations

from typing import Any
from typing import TypeVar

T = TypeVar('T')
RECORDED_CONST_DOCS: dict[int, str] = {}


def document(value: Any, doc_string: str = ''):
def document(value: T, doc_string: str = '') -> T:
"""Stores documentation details about the given value.

This method is used to associate a docstring with global constants. It is
Expand Down Expand Up @@ -64,7 +65,7 @@ def document(value: Any, doc_string: str = ''):
_DOC_PRIVATE = "_tf_docs_doc_private"


def doc_private(obj):
def doc_private(obj: T) -> T:
"""A decorator: Generates docs for private methods/functions.

For example:
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ def __repr__(self) -> str:
f' global_phase={self.global_phase!r})'
)

def _has_unitary_(self) -> bool:
return True

def _unitary_(self) -> np.ndarray:
"""Returns the decomposition's two-qubit unitary matrix.

Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def test_kak_plot_empty() -> None:
)
def test_kak_decomposition(target) -> None:
kak = cirq.kak_decomposition(target)
assert cirq.has_unitary(kak)
np.testing.assert_allclose(cirq.unitary(kak), target, atol=1e-8)


Expand Down
6 changes: 2 additions & 4 deletions cirq-core/cirq/ops/fourier_transform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ def test_qft() -> None:

arr = np.array([[1, 1, 1, 1], [1, -1j, -1, 1j], [1, -1, 1, -1], [1, 1j, -1, -1j]]) / 2
np.testing.assert_allclose(
cirq.unitary(cirq.qft(*cirq.LineQubit.range(2)) ** -1), # type: ignore[operator]
arr, # type: ignore[arg-type]
atol=1e-8,
cirq.unitary(cirq.qft(*cirq.LineQubit.range(2)) ** -1), arr, atol=1e-8
)

for k in range(4):
Expand All @@ -121,7 +119,7 @@ def test_qft() -> None:

def test_inverse() -> None:
a, b, c = cirq.LineQubit.range(3)
assert cirq.qft(a, b, c, inverse=True) == cirq.qft(a, b, c) ** -1 # type: ignore[operator]
assert cirq.qft(a, b, c, inverse=True) == cirq.qft(a, b, c) ** -1
assert cirq.qft(a, b, c, inverse=True, without_reverse=True) == cirq.inverse(
cirq.qft(a, b, c, without_reverse=True)
)
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,9 @@ def _num_qubits_(self) -> int:
def _qid_shape_(self) -> tuple[int, ...]:
return protocols.qid_shape(self.qubits)

def __pow__(self, exponent: Any) -> Operation:
return NotImplemented # pragma: no cover

@abc.abstractmethod
def with_qubits(self, *new_qubits: cirq.Qid) -> cirq.Operation:
"""Returns the same operation, but applied to different qubits.
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/protocols/apply_unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def assert_is_swap_simple(val: cirq.SupportsConsistentApplyUnitary) -> None:
op_indices, tuple(qid_shape[i] for i in op_indices)
)
sub_result = val._apply_unitary_(sub_args)
assert isinstance(sub_result, np.ndarray)
result = _incorporate_result_into_target(args, sub_args, sub_result)
np.testing.assert_allclose(result, expected, atol=1e-8)

Expand All @@ -258,6 +259,7 @@ def assert_is_swap(val: cirq.SupportsConsistentApplyUnitary) -> None:
op_indices, tuple(qid_shape[i] for i in op_indices)
)
sub_result = val._apply_unitary_(sub_args)
assert isinstance(sub_result, np.ndarray)
result = _incorporate_result_into_target(args, sub_args, sub_result)
np.testing.assert_allclose(result, expected, atol=1e-8, verbose=True)

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/protocols/kraus_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import warnings
from types import NotImplementedType
from typing import Any, Protocol, Sequence, TypeVar
from typing import Any, Iterable, Protocol, TypeVar

import numpy as np

Expand All @@ -31,7 +31,7 @@

# This is a special indicator value used by the channel method to determine
# whether or not the caller provided a 'default' argument. It must be of type
# Sequence[np.ndarray] to ensure the method has the correct type signature in
# Iterable[np.ndarray] to ensure the method has the correct type signature in
# that case. It is checked for using `is`, so it won't have a false positive
# if the user provides a different (np.array([]),) value.
RaiseTypeErrorIfNotProvided: tuple[np.ndarray] = (np.array([]),)
Expand All @@ -44,7 +44,7 @@ class SupportsKraus(Protocol):
"""An object that may be describable as a quantum channel."""

@doc_private
def _kraus_(self) -> Sequence[np.ndarray] | NotImplementedType:
def _kraus_(self) -> Iterable[np.ndarray] | NotImplementedType:
r"""A list of Kraus matrices describing the quantum channel.

These matrices are the terms in the operator sum representation of a
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/protocols/kraus_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

from typing import Iterable, Sequence
from typing import Iterable

import numpy as np
import pytest
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_explicit_kraus() -> None:
c = (a0, a1)

class ReturnsKraus:
def _kraus_(self) -> Sequence[np.ndarray]:
def _kraus_(self) -> Iterable[np.ndarray]:
return c

assert cirq.kraus(ReturnsKraus()) is c
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class SupportsUnitary(Protocol):
"""An object that may be describable by a unitary matrix."""

@doc_private
def _unitary_(self) -> np.ndarray | NotImplementedType:
def _unitary_(self) -> np.ndarray | NotImplementedType | None:
"""A unitary matrix describing this value, e.g. the matrix of a gate.

This method is used by the global `cirq.unitary` method. If this method
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/testing/consistent_resolve_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import cirq


def assert_consistent_resolve_parameters(val: Any):
def assert_consistent_resolve_parameters(val: Any) -> None:
names = cirq.parameter_names(val)
symbols = cirq.parameter_symbols(val)

Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/testing/consistent_unitary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import cirq


def assert_unitary_is_consistent(val: Any, ignoring_global_phase: bool = False):
def assert_unitary_is_consistent(val: Any, ignoring_global_phase: bool = False) -> None:
if not isinstance(val, (cirq.Operation, cirq.Gate)):
return

Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/testing/consistent_unitary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ def test_failed_decomposition() -> None:
with pytest.raises(ValueError):
cirq.testing.assert_unitary_is_consistent(FailsOnDecompostion())

_ = cirq.testing.assert_unitary_is_consistent(cirq.Circuit())
cirq.testing.assert_unitary_is_consistent(cirq.Circuit())
4 changes: 2 additions & 2 deletions cirq-core/cirq/testing/equals_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _verify_equality_group(self, *group_items: Any):
"Common problem: returning NotImplementedError instead of NotImplemented. "
)

def add_equality_group(self, *group_items: Any):
def add_equality_group(self, *group_items: Any) -> None:
"""Tries to add a disjoint equivalence group to the equality tester.

This methods asserts that items within the group must all be equal to
Expand All @@ -114,7 +114,7 @@ def add_equality_group(self, *group_items: Any):
# Remember this group, to enable disjoint checks vs later groups.
self._groups.append(group_items)

def make_equality_group(self, *factories: Callable[[], Any]):
def make_equality_group(self, *factories: Callable[[], Any]) -> None:
"""Tries to add a disjoint equivalence group to the equality tester.

Uses the factory methods to produce two different objects with the same
Expand Down
4 changes: 3 additions & 1 deletion cirq-core/cirq/testing/equivalent_basis_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from cirq import circuits


def assert_equivalent_computational_basis_map(maps: dict[int, int], circuit: circuits.Circuit):
def assert_equivalent_computational_basis_map(
maps: dict[int, int], circuit: circuits.Circuit
) -> None:
"""Ensure equivalence of basis state mapping.

Args:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/testing/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def spec_for(module_name: str) -> ModuleJsonTestSpec:
return getattr(test_module, "TestSpec")


def assert_json_roundtrip_works(obj, text_should_be=None, resolvers=None):
def assert_json_roundtrip_works(obj, text_should_be=None, resolvers=None) -> None:
"""Tests that the given object can serialized and de-serialized

Args:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/testing/op_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cirq import ops


def assert_equivalent_op_tree(x: ops.OP_TREE, y: ops.OP_TREE):
def assert_equivalent_op_tree(x: ops.OP_TREE, y: ops.OP_TREE) -> None:
"""Ensures that the two OP_TREEs are equivalent.

Args:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/testing/order_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _verify_not_implemented_vs_unknown(self, item: Any):
f"That rule is being violated by this value: {item!r}"
) from ex

def add_ascending(self, *items: Any):
def add_ascending(self, *items: Any) -> None:
"""Tries to add a sequence of ascending items to the order tester.

This methods asserts that items must all be ascending
Expand All @@ -98,7 +98,7 @@ def add_ascending(self, *items: Any):
for item in items:
self.add_ascending_equivalence_group(item)

def add_ascending_equivalence_group(self, *group_items: Any):
def add_ascending_equivalence_group(self, *group_items: Any) -> None:
"""Tries to add an ascending equivalence group to the order tester.

Asserts that the group items are equal to each other, but strictly
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/testing/repr_pretty_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class FakePrinter:
def __init__(self):
self.text_pretty = ""

def text(self, to_print):
def text(self, to_print) -> None:
self.text_pretty += to_print


def assert_repr_pretty(val: Any, text: str, cycle: bool = False):
def assert_repr_pretty(val: Any, text: str, cycle: bool = False) -> None:
"""Assert that the given object has a `_repr_pretty_` method that produces the given text.

Args:
Expand All @@ -57,7 +57,7 @@ def assert_repr_pretty(val: Any, text: str, cycle: bool = False):
assert p.text_pretty == text, f"{p.text_pretty} != {text}"


def assert_repr_pretty_contains(val: Any, substr: str, cycle: bool = False):
def assert_repr_pretty_contains(val: Any, substr: str, cycle: bool = False) -> None:
"""Assert that the given object has a `_repr_pretty_` output that contains the given text.

Args:
Expand Down
26 changes: 13 additions & 13 deletions cirq-core/cirq/transformers/align_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import cirq


def test_align_basic_no_context():
def test_align_basic_no_context() -> None:
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
c = cirq.Circuit(
Expand Down Expand Up @@ -45,7 +45,7 @@ def test_align_basic_no_context():
)


def test_align_left_no_compile_context():
def test_align_left_no_compile_context() -> None:
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
cirq.testing.assert_same_circuits(
Expand All @@ -59,7 +59,7 @@ def test_align_left_no_compile_context():
cirq.measure(*[q1, q2], key='a'),
]
),
context=cirq.TransformerContext(tags_to_ignore=["nocompile"]),
context=cirq.TransformerContext(tags_to_ignore=("nocompile",)),
),
cirq.Circuit(
[
Expand All @@ -73,7 +73,7 @@ def test_align_left_no_compile_context():
)


def test_align_left_deep():
def test_align_left_deep() -> None:
q1, q2 = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(
[
Expand Down Expand Up @@ -104,11 +104,11 @@ def test_align_left_deep():
c_nested_aligned,
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
)
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
context = cirq.TransformerContext(tags_to_ignore=("nocompile",), deep=True)
cirq.testing.assert_same_circuits(cirq.align_left(c_orig, context=context), c_expected)


def test_align_left_subset_of_operations():
def test_align_left_subset_of_operations() -> None:
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
tag = "op_to_align"
Expand All @@ -134,15 +134,15 @@ def test_align_left_subset_of_operations():
cirq.toggle_tags(
cirq.align_left(
cirq.toggle_tags(c_orig, [tag]),
context=cirq.TransformerContext(tags_to_ignore=[tag]),
context=cirq.TransformerContext(tags_to_ignore=(tag,)),
),
[tag],
),
c_exp,
)


def test_align_right_no_compile_context():
def test_align_right_no_compile_context() -> None:
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
cirq.testing.assert_same_circuits(
Expand All @@ -156,7 +156,7 @@ def test_align_right_no_compile_context():
cirq.measure(*[q1, q2], key='a'),
]
),
context=cirq.TransformerContext(tags_to_ignore=["nocompile"]),
context=cirq.TransformerContext(tags_to_ignore=("nocompile",)),
),
cirq.Circuit(
[
Expand All @@ -170,7 +170,7 @@ def test_align_right_no_compile_context():
)


def test_align_right_deep():
def test_align_right_deep() -> None:
q1, q2 = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(
cirq.Moment([cirq.X(q1)]),
Expand Down Expand Up @@ -199,11 +199,11 @@ def test_align_right_deep():
c_nested_aligned,
cirq.CircuitOperation(c_nested_aligned).repeat(5).with_tags("preserve_tag"),
)
context = cirq.TransformerContext(tags_to_ignore=["nocompile"], deep=True)
context = cirq.TransformerContext(tags_to_ignore=("nocompile",), deep=True)
cirq.testing.assert_same_circuits(cirq.align_right(c_orig, context=context), c_expected)


def test_classical_control():
def test_classical_control() -> None:
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.H(q0), cirq.measure(q0, key='m'), cirq.X(q1).with_classical_controls('m')
Expand All @@ -212,7 +212,7 @@ def test_classical_control():
cirq.testing.assert_same_circuits(cirq.align_right(circuit), circuit)


def test_measurement_and_classical_control_same_moment_preserve_order():
def test_measurement_and_classical_control_same_moment_preserve_order() -> None:
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit()
op_measure = cirq.measure(q0, key='m')
Expand Down
Loading