diff --git a/asv.conf.json b/asv.conf.json index 35813c17405..e5c94df2f88 100644 --- a/asv.conf.json +++ b/asv.conf.json @@ -9,6 +9,7 @@ "environment_type": "virtualenv", "show_commit_url": "https://github.com/quantumlib/Cirq/commit/", "pythons": ["3.8"], + "matrix": {"env_nobuild": {"PYTHONOPTIMIZE": ["-O", ""]}}, "benchmark_dir": "benchmarks", "env_dir": ".asv/env", "results_dir": ".asv/results", diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index d68fc645b25..b246ccdbf38 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -16,6 +16,8 @@ from cirq import _import +from cirq._compat import __cirq_debug__, with_debug + # A module can only depend on modules imported earlier in this list of modules # at import time. Pytest will fail otherwise (enforced by # dev_tools/import_test.py). diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index cf123d0408b..a39833abbfe 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -14,6 +14,7 @@ """Workarounds for compatibility issues between versions and libraries.""" import contextlib +import contextvars import dataclasses import functools import importlib @@ -24,15 +25,41 @@ import traceback import warnings from types import ModuleType -from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar +from typing import Any, Callable, Dict, Iterator, Optional, overload, Set, Tuple, Type, TypeVar import numpy as np import pandas as pd import sympy import sympy.printing.repr +from cirq._doc import document + ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST' +__cirq_debug__ = contextvars.ContextVar('__cirq_debug__', default=__debug__) +document( + __cirq_debug__, + "A cirq specific flag which can be used to conditionally turn off all validations across Cirq " + "to boost performance in production mode. Defaults to python's built-in constant __debug__. " + "The flag is implemented as a `ContextVar` and is thread safe.", +) + + +@contextlib.contextmanager +def with_debug(value: bool) -> Iterator[None]: + """Sets the value of global constant `cirq.__cirq_debug__` within the context. + + If `__cirq_debug__` is set to False, all validations in Cirq are disabled to optimize + performance. Users should use the `cirq.with_debug` context manager instead of manually + mutating the value of `__cirq_debug__` flag. On exit, the context manager resets the + value of `__cirq_debug__` flag to what it was before entering the context manager. + """ + token = __cirq_debug__.set(value) + try: + yield + finally: + __cirq_debug__.reset(token) + try: from functools import cached_property # pylint: disable=unused-import diff --git a/cirq-core/cirq/_compat_test.py b/cirq-core/cirq/_compat_test.py index 7515ec9f2b4..94505535dc8 100644 --- a/cirq-core/cirq/_compat_test.py +++ b/cirq-core/cirq/_compat_test.py @@ -51,6 +51,16 @@ ) +def test_with_debug(): + assert cirq.__cirq_debug__.get() + with cirq.with_debug(False): + assert not cirq.__cirq_debug__.get() + with cirq.with_debug(True): + assert cirq.__cirq_debug__.get() + assert not cirq.__cirq_debug__.get() + assert cirq.__cirq_debug__.get() + + def test_proper_repr(): v = sympy.Symbol('t') * 3 v2 = eval(proper_repr(v)) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 3dcd6b2004a..993cf3cc9ea 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -17,6 +17,7 @@ import abc import functools from typing import ( + cast, AbstractSet, Any, Callable, @@ -40,6 +41,7 @@ from cirq import protocols, value from cirq._import import LazyLoader +from cirq._compat import __cirq_debug__ from cirq.type_workarounds import NotImplementedType from cirq.ops import control_values as cv @@ -215,7 +217,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']) -> None: Raises: ValueError: The gate can't be applied to the qubits. """ - _validate_qid_shape(self, qubits) + if __cirq_debug__.get(): + _validate_qid_shape(self, qubits) def on(self, *qubits: Qid) -> 'Operation': """Returns an application of this gate to the given qubits. @@ -254,19 +257,33 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation'] raise TypeError(f'{targets[0]} object is not iterable.') t0 = list(targets[0]) iterator = [t0] if t0 and isinstance(t0[0], Qid) else t0 - for target in iterator: - if not isinstance(target, Sequence): - raise ValueError( - f'Inputs to multi-qubit gates must be Sequence[Qid].' - f' Type: {type(target)}' - ) - if not all(isinstance(x, Qid) for x in target): - raise ValueError(f'All values in sequence should be Qids, but got {target}') - if len(target) != self._num_qubits_(): - raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}') - operations.append(self.on(*target)) + if __cirq_debug__.get(): + for target in iterator: + if not isinstance(target, Sequence): + raise ValueError( + f'Inputs to multi-qubit gates must be Sequence[Qid].' + f' Type: {type(target)}' + ) + if not all(isinstance(x, Qid) for x in target): + raise ValueError(f'All values in sequence should be Qids, but got {target}') + if len(target) != self._num_qubits_(): + raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}') + operations.append(self.on(*target)) + else: + operations = [self.on(*target) for target in iterator] return operations + if not __cirq_debug__.get(): + return [ + op + for q in targets + for op in ( + self.on_each(*q) + if isinstance(q, Iterable) and not isinstance(q, str) + else [self.on(cast('cirq.Qid', q))] + ) + ] + for target in targets: if isinstance(target, Qid): operations.append(self.on(target)) @@ -617,7 +634,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']): Raises: ValueError: The operation had qids that don't match it's qid shape. """ - _validate_qid_shape(self, qubits) + if __cirq_debug__.get(): + _validate_qid_shape(self, qubits) def _commutes_( self, other: Any, *, atol: float = 1e-8 diff --git a/cirq-core/cirq/ops/raw_types_test.py b/cirq-core/cirq/ops/raw_types_test.py index 2235f43742f..81ab8f61935 100644 --- a/cirq-core/cirq/ops/raw_types_test.py +++ b/cirq-core/cirq/ops/raw_types_test.py @@ -151,6 +151,29 @@ def test_op_validate(): op2.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)]) +def test_disable_op_validation(): + q0, q1 = cirq.LineQubit.range(2) + h_op = cirq.H(q0) + + # Fails normally. + with pytest.raises(ValueError, match='Wrong number'): + _ = cirq.H(q0, q1) + with pytest.raises(ValueError, match='Wrong number'): + h_op.validate_args([q0, q1]) + + # Passes, skipping validation. + with cirq.with_debug(False): + op = cirq.H(q0, q1) + assert op.qubits == (q0, q1) + h_op.validate_args([q0, q1]) + + # Fails again when validation is re-enabled. + with pytest.raises(ValueError, match='Wrong number'): + _ = cirq.H(q0, q1) + with pytest.raises(ValueError, match='Wrong number'): + h_op.validate_args([q0, q1]) + + def test_default_validation_and_inverse(): class TestGate(cirq.Gate): def _num_qubits_(self): @@ -787,6 +810,10 @@ def matrix(self): test_non_qubits = [str(i) for i in range(3)] with pytest.raises(ValueError): _ = g.on_each(*test_non_qubits) + + with cirq.with_debug(False): + assert g.on_each(*test_non_qubits)[0].qubits == ('0',) + with pytest.raises(ValueError): _ = g.on_each(*test_non_qubits) @@ -853,6 +880,10 @@ def test_on_each_two_qubits(): g.on_each([(a,)]) with pytest.raises(ValueError, match='Expected 2 qubits'): g.on_each([(a, b, a)]) + + with cirq.with_debug(False): + assert g.on_each([(a, b, a)])[0].qubits == (a, b, a) + with pytest.raises(ValueError, match='Expected 2 qubits'): g.on_each(zip([a, a])) with pytest.raises(ValueError, match='Expected 2 qubits'):