Skip to content

Commit

Permalink
Add __cirq_debug__ flag and conditionally disable qid validations i…
Browse files Browse the repository at this point in the history
…n gates and operations (#6000)

* Add __cirq_debug__ flag and conditionally disable qid validations in gates and operations

* fix mypy errors

* Fix typo

* Address comments and add a context manager

* Address nit
  • Loading branch information
tanujkhattar committed Feb 16, 2023
1 parent fd491e0 commit 27dd607
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 14 deletions.
1 change: 1 addition & 0 deletions asv.conf.json
Expand Up @@ -9,6 +9,7 @@
"environment_type": "virtualenv",
"show_commit_url": "https://github.com/quantumlib/Cirq/commit/",
"pythons": ["3.8"],
"matrix": {"env_nobuild": {"PYTHONOPTIMIZE": ["-O", ""]}},
"benchmark_dir": "benchmarks",
"env_dir": ".asv/env",
"results_dir": ".asv/results",
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Expand Up @@ -16,6 +16,8 @@

from cirq import _import

from cirq._compat import __cirq_debug__, with_debug

# A module can only depend on modules imported earlier in this list of modules
# at import time. Pytest will fail otherwise (enforced by
# dev_tools/import_test.py).
Expand Down
29 changes: 28 additions & 1 deletion cirq-core/cirq/_compat.py
Expand Up @@ -14,6 +14,7 @@

"""Workarounds for compatibility issues between versions and libraries."""
import contextlib
import contextvars
import dataclasses
import functools
import importlib
Expand All @@ -24,15 +25,41 @@
import traceback
import warnings
from types import ModuleType
from typing import Any, Callable, Dict, Optional, overload, Set, Tuple, Type, TypeVar
from typing import Any, Callable, Dict, Iterator, Optional, overload, Set, Tuple, Type, TypeVar

import numpy as np
import pandas as pd
import sympy
import sympy.printing.repr

from cirq._doc import document

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'

__cirq_debug__ = contextvars.ContextVar('__cirq_debug__', default=__debug__)
document(
__cirq_debug__,
"A cirq specific flag which can be used to conditionally turn off all validations across Cirq "
"to boost performance in production mode. Defaults to python's built-in constant __debug__. "
"The flag is implemented as a `ContextVar` and is thread safe.",
)


@contextlib.contextmanager
def with_debug(value: bool) -> Iterator[None]:
"""Sets the value of global constant `cirq.__cirq_debug__` within the context.
If `__cirq_debug__` is set to False, all validations in Cirq are disabled to optimize
performance. Users should use the `cirq.with_debug` context manager instead of manually
mutating the value of `__cirq_debug__` flag. On exit, the context manager resets the
value of `__cirq_debug__` flag to what it was before entering the context manager.
"""
token = __cirq_debug__.set(value)
try:
yield
finally:
__cirq_debug__.reset(token)


try:
from functools import cached_property # pylint: disable=unused-import
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/_compat_test.py
Expand Up @@ -51,6 +51,16 @@
)


def test_with_debug():
assert cirq.__cirq_debug__.get()
with cirq.with_debug(False):
assert not cirq.__cirq_debug__.get()
with cirq.with_debug(True):
assert cirq.__cirq_debug__.get()
assert not cirq.__cirq_debug__.get()
assert cirq.__cirq_debug__.get()


def test_proper_repr():
v = sympy.Symbol('t') * 3
v2 = eval(proper_repr(v))
Expand Down
44 changes: 31 additions & 13 deletions cirq-core/cirq/ops/raw_types.py
Expand Up @@ -17,6 +17,7 @@
import abc
import functools
from typing import (
cast,
AbstractSet,
Any,
Callable,
Expand All @@ -40,6 +41,7 @@

from cirq import protocols, value
from cirq._import import LazyLoader
from cirq._compat import __cirq_debug__
from cirq.type_workarounds import NotImplementedType
from cirq.ops import control_values as cv

Expand Down Expand Up @@ -215,7 +217,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']) -> None:
Raises:
ValueError: The gate can't be applied to the qubits.
"""
_validate_qid_shape(self, qubits)
if __cirq_debug__.get():
_validate_qid_shape(self, qubits)

def on(self, *qubits: Qid) -> 'Operation':
"""Returns an application of this gate to the given qubits.
Expand Down Expand Up @@ -254,19 +257,33 @@ def on_each(self, *targets: Union[Qid, Iterable[Any]]) -> List['cirq.Operation']
raise TypeError(f'{targets[0]} object is not iterable.')
t0 = list(targets[0])
iterator = [t0] if t0 and isinstance(t0[0], Qid) else t0
for target in iterator:
if not isinstance(target, Sequence):
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
if __cirq_debug__.get():
for target in iterator:
if not isinstance(target, Sequence):
raise ValueError(
f'Inputs to multi-qubit gates must be Sequence[Qid].'
f' Type: {type(target)}'
)
if not all(isinstance(x, Qid) for x in target):
raise ValueError(f'All values in sequence should be Qids, but got {target}')
if len(target) != self._num_qubits_():
raise ValueError(f'Expected {self._num_qubits_()} qubits, got {target}')
operations.append(self.on(*target))
else:
operations = [self.on(*target) for target in iterator]
return operations

if not __cirq_debug__.get():
return [
op
for q in targets
for op in (
self.on_each(*q)
if isinstance(q, Iterable) and not isinstance(q, str)
else [self.on(cast('cirq.Qid', q))]
)
]

for target in targets:
if isinstance(target, Qid):
operations.append(self.on(target))
Expand Down Expand Up @@ -617,7 +634,8 @@ def validate_args(self, qubits: Sequence['cirq.Qid']):
Raises:
ValueError: The operation had qids that don't match it's qid shape.
"""
_validate_qid_shape(self, qubits)
if __cirq_debug__.get():
_validate_qid_shape(self, qubits)

def _commutes_(
self, other: Any, *, atol: float = 1e-8
Expand Down
31 changes: 31 additions & 0 deletions cirq-core/cirq/ops/raw_types_test.py
Expand Up @@ -151,6 +151,29 @@ def test_op_validate():
op2.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)])


def test_disable_op_validation():
q0, q1 = cirq.LineQubit.range(2)
h_op = cirq.H(q0)

# Fails normally.
with pytest.raises(ValueError, match='Wrong number'):
_ = cirq.H(q0, q1)
with pytest.raises(ValueError, match='Wrong number'):
h_op.validate_args([q0, q1])

# Passes, skipping validation.
with cirq.with_debug(False):
op = cirq.H(q0, q1)
assert op.qubits == (q0, q1)
h_op.validate_args([q0, q1])

# Fails again when validation is re-enabled.
with pytest.raises(ValueError, match='Wrong number'):
_ = cirq.H(q0, q1)
with pytest.raises(ValueError, match='Wrong number'):
h_op.validate_args([q0, q1])


def test_default_validation_and_inverse():
class TestGate(cirq.Gate):
def _num_qubits_(self):
Expand Down Expand Up @@ -787,6 +810,10 @@ def matrix(self):
test_non_qubits = [str(i) for i in range(3)]
with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)

with cirq.with_debug(False):
assert g.on_each(*test_non_qubits)[0].qubits == ('0',)

with pytest.raises(ValueError):
_ = g.on_each(*test_non_qubits)

Expand Down Expand Up @@ -853,6 +880,10 @@ def test_on_each_two_qubits():
g.on_each([(a,)])
with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each([(a, b, a)])

with cirq.with_debug(False):
assert g.on_each([(a, b, a)])[0].qubits == (a, b, a)

with pytest.raises(ValueError, match='Expected 2 qubits'):
g.on_each(zip([a, a]))
with pytest.raises(ValueError, match='Expected 2 qubits'):
Expand Down

0 comments on commit 27dd607

Please sign in to comment.