Skip to content

Commit

Permalink
Approximately equal circuits with symbols not detected by cirq.approx…
Browse files Browse the repository at this point in the history
…_eq (#3195)

Approximately equal circuits with symbols not detected by cirq.approx_eq (Fixes #3192.)

We should use sympy to make comparisons when available.
  • Loading branch information
tonybruguier-google committed Aug 10, 2020
1 parent 3c7fc5d commit 4bab2d2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 11 deletions.
26 changes: 15 additions & 11 deletions cirq/ops/eigen_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import numpy as np
import pytest
import sympy
Expand Down Expand Up @@ -132,22 +134,24 @@ def test_approx_eq():
ZGateDef(exponent=1.5),
atol=0.1
)
assert not cirq.approx_eq(
ZGateDef(exponent=1.5),
ZGateDef(exponent=sympy.Symbol('a')),
atol=0.1
)

with pytest.raises(TypeError,
match=re.escape("unsupported operand type(s) for"
" -: 'Symbol' and 'PeriodicValue'")):
cirq.approx_eq(ZGateDef(exponent=1.5),
ZGateDef(exponent=sympy.Symbol('a')),
atol=0.1)
assert cirq.approx_eq(
CExpZinGate(sympy.Symbol('a')),
CExpZinGate(sympy.Symbol('a')),
atol=0.1
)
assert not cirq.approx_eq(
CExpZinGate(sympy.Symbol('a')),
CExpZinGate(sympy.Symbol('b')),
atol=0.1
)
with pytest.raises(
AttributeError,
match="Insufficient information to decide whether expressions are "
"approximately equal .* vs .*"):
assert not cirq.approx_eq(CExpZinGate(sympy.Symbol('a')),
CExpZinGate(sympy.Symbol('b')),
atol=0.1)


def test_approx_eq_periodic():
Expand Down
9 changes: 9 additions & 0 deletions cirq/protocols/approximate_equality_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import numbers
import numpy as np
import sympy

from typing_extensions import Protocol

Expand Down Expand Up @@ -94,6 +95,14 @@ def approx_eq(val: Any, other: Any, *, atol: Union[int, float] = 1e-8) -> bool:
if isinstance(val, str):
return val == other

if isinstance(val, sympy.Basic) or isinstance(other, sympy.Basic):
delta = sympy.Abs(other - val).simplify()
if not delta.is_number:
raise AttributeError('Insufficient information to decide whether '
'expressions are approximately equal '
f'[{val}] vs [{other}]')
return sympy.LessThan(delta, atol) == sympy.true

# If the values are iterable, try comparing recursively on items.
if isinstance(val, Iterable) and isinstance(other, Iterable):
return _approx_eq_iterables(val, other, atol=atol)
Expand Down
28 changes: 28 additions & 0 deletions cirq/protocols/approximate_equality_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from decimal import Decimal
from numbers import Number
import numpy as np
import pytest
import sympy
import cirq


Expand Down Expand Up @@ -152,6 +154,32 @@ def test_approx_eq_list():
assert not cirq.approx_eq([1.1, 1.2, 1.3], [1, 1, 1], atol=0.2)


def test_approx_eq_symbol():
q = cirq.GridQubit(0, 0)
s = sympy.Symbol("s")
t = sympy.Symbol("t")

assert not cirq.approx_eq(t + 1.51 + s, t + 1.50 + s, atol=0.005)
assert cirq.approx_eq(t + 1.51 + s, t + 1.50 + s, atol=0.020)

with pytest.raises(
AttributeError,
match="Insufficient information to decide whether expressions are "
"approximately equal .* vs .*"):
cirq.approx_eq(t, 0.0, atol=0.005)

symbol_1 = cirq.Circuit(cirq.rz(1.515 + s)(q))
symbol_2 = cirq.Circuit(cirq.rz(1.510 + s)(q))
assert cirq.approx_eq(symbol_1, symbol_2, atol=0.2)

symbol_3 = cirq.Circuit(cirq.rz(1.510 + t)(q))
with pytest.raises(
AttributeError,
match="Insufficient information to decide whether expressions are "
"approximately equal .* vs .*"):
cirq.approx_eq(symbol_1, symbol_3, atol=0.2)


def test_approx_eq_default():
assert cirq.approx_eq(1.0, 1.0 + 1e-9)
assert cirq.approx_eq(1.0, 1.0 - 1e-9)
Expand Down

0 comments on commit 4bab2d2

Please sign in to comment.