Skip to content

Commit

Permalink
Extending types for symbol resolution fast pass-through (#3366)
Browse files Browse the repository at this point in the history
Extends ParamResolver's logic to circumvent sympy's slowness to members of numbers.Number. 
It also generalizes sympy constants instead of only handling pi and NegativeOne.

Fixes #3359.
  • Loading branch information
balopat committed Sep 30, 2020
1 parent fa80d4c commit 46f8df7
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 25 deletions.
38 changes: 25 additions & 13 deletions cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

"""Resolves ParameterValues to assigned values."""

import numbers
from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Union, cast
import numpy as np
import sympy
Expand Down Expand Up @@ -89,18 +89,21 @@ def value_of(self,
Returns:
The value of the parameter as resolved by this resolver.
"""
# Input is a float, no resolution needed: return early
if isinstance(value, float):
return value

# Input is a pass through type, no resolution needed: return early
v = _sympy_pass_through(value)
if v is not None:
return v

# Handles 2 cases:
# Input is a string and maps to a number in the dictionary
# Input is a symbol and maps to a number in the dictionary
# In both cases, return it directly.
if value in self.param_dict:
param_value = self.param_dict[value]
if isinstance(param_value, (float, int)):
return param_value
v = _sympy_pass_through(param_value)
if v is not None:
return v

# Input is a string and is not in the dictionary.
# Treat it as a symbol instead.
Expand All @@ -111,10 +114,11 @@ def value_of(self,

# Input is a symbol (sympy.Symbol('a')) and its string maps to a number
# in the dictionary ({'a': 1.0}). Return it.
if (isinstance(value, sympy.Symbol) and value.name in self.param_dict):
if isinstance(value, sympy.Symbol) and value.name in self.param_dict:
param_value = self.param_dict[value.name]
if isinstance(param_value, (float, int)):
return param_value
v = _sympy_pass_through(param_value)
if v is not None:
return v

# The following resolves common sympy expressions
# If sympy did its job and wasn't slower than molasses,
Expand All @@ -132,10 +136,6 @@ def value_of(self,
if isinstance(value, sympy.Pow) and len(value.args) == 2:
return np.power(self.value_of(value.args[0]),
self.value_of(value.args[1]))
if value == sympy.pi:
return np.pi
if value == sympy.S.NegativeOne:
return -1

# Input is either a sympy formula or the dictionary maps to a
# formula. Use sympy to resolve the value.
Expand Down Expand Up @@ -193,3 +193,15 @@ def _json_dict_(self) -> Dict[str, Any]:
@classmethod
def _from_json_dict_(cls, param_dict, **kwargs):
return cls(dict(param_dict))


def _sympy_pass_through(val: Any) -> Optional[Any]:
if isinstance(val, numbers.Number) and not isinstance(val, sympy.Basic):
return val
if isinstance(val, sympy.core.numbers.IntegerConstant):
return val.p
if isinstance(val, sympy.core.numbers.RationalConstant):
return val.p / val.q
if val == sympy.pi:
return np.pi
return None
111 changes: 99 additions & 12 deletions cirq/study/resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,120 @@
# limitations under the License.

"""Tests for parameter resolvers."""
import fractions

import numpy as np
import pytest
import sympy

import cirq


def test_value_of():
@pytest.mark.parametrize('val', [
3.2,
np.float32(3.2),
int(1),
np.int(3),
np.int32(45),
np.float64(6.3),
np.int32(2),
np.complex64(1j),
np.complex128(2j),
np.complex(1j),
fractions.Fraction(3, 2),
])
def test_value_of_pass_through_types(val):
_assert_consistent_resolution(val, val)


@pytest.mark.parametrize('val,resolved', [(sympy.pi, np.pi),
(sympy.S.NegativeOne, -1),
(sympy.S.Half, 0.5),
(sympy.S.One, 1)])
def test_value_of_transformed_types(val, resolved):
_assert_consistent_resolution(val, resolved)


@pytest.mark.parametrize('val,resolved', [(sympy.I, 1j)])
def test_value_of_substituted_types(val, resolved):
_assert_consistent_resolution(val, resolved, True)


def _assert_consistent_resolution(v, resolved, subs_called=False):
"""Asserts that parameter resolution works consistently.
The ParamResolver.value_of method can resolve any Sympy expression -
subclasses of sympy.Basic. In the generic case, it calls `sympy.Basic.subs`
to substitute symbols with values specified in a dict, which is known to be
very slow. Instead value_of defines a pass-through shortcut for known
numeric types. For a given value `v` it is asserted that value_of resolves
it to `resolved`, with the exact type of `resolved`.`subs_called` indicates
whether it is expected to have `subs` called or not during the resolution.
Args:
v: the value to resolve
resolved: the expected resolution result
subs_called: if True, it is expected that the slow subs method is called
Raises:
AssertionError in case resolution assertion fail.
"""

class SubsAwareSymbol(sympy.Symbol):
"""A Symbol that registers a call to its `subs` method."""

def __init__(self, sym: str):
self.called = False
self.symbol = sympy.Symbol(sym)

# note: super().subs() doesn't resolve based on the param_dict properly
# for some reason, that's why a delegate (self.symbol) is used instead
def subs(self, *args, **kwargs):
self.called = True
return self.symbol.subs(*args, **kwargs)

r = cirq.ParamResolver({'a': v})

# symbol based resolution
s = SubsAwareSymbol('a')
assert r.value_of(s) == resolved, (f"expected {resolved}, "
f"got {r.value_of(s)}")
assert subs_called == s.called, (
f"For pass-through type "
f"{type(v)} sympy.subs shouldn't have been called.")
assert isinstance(r.value_of(s),
type(resolved)), (f"expected {type(resolved)} "
f"got {type(r.value_of(s))}")

# string based resolution (which in turn uses symbol based resolution)
assert r.value_of('a') == resolved, (f"expected {resolved}, "
f"got {r.value_of('a')}")
assert isinstance(r.value_of('a'),
type(resolved)), (f"expected {type(resolved)} "
f"got {type(r.value_of('a'))}")

# value based resolution
assert r.value_of(v) == resolved, (f"expected {resolved}, "
f"got {r.value_of(v)}")
assert isinstance(r.value_of(v),
type(resolved)), (f"expected {type(resolved)} "
f"got {type(r.value_of(v))}")


def test_value_of_strings():
assert cirq.ParamResolver().value_of('x') == sympy.Symbol('x')


def test_value_of_calculations():
assert not bool(cirq.ParamResolver())

r = cirq.ParamResolver({'a': 0.5, 'b': 0.1, 'c': 1 + 1j})
assert bool(r)

assert r.value_of('x') == sympy.Symbol('x')
assert r.value_of('a') == 0.5
assert r.value_of(sympy.Symbol('a')) == 0.5
assert r.value_of(0.5) == 0.5
assert r.value_of(sympy.Symbol('b')) == 0.1
assert r.value_of(0.3) == 0.3
assert r.value_of(sympy.Symbol('a') * 3) == 1.5
assert r.value_of(sympy.Symbol('b') / 0.1 - sympy.Symbol('a')) == 0.5

assert r.value_of(sympy.pi) == np.pi
assert r.value_of(2 * sympy.pi) == 2 * np.pi
assert r.value_of(4**sympy.Symbol('a') + sympy.Symbol('b') * 10) == 3
assert r.value_of('c') == 1 + 1j
assert r.value_of(sympy.I * sympy.pi) == np.pi * 1j
assert r.value_of(sympy.Symbol('a') * 3) == 1.5
assert r.value_of(sympy.Symbol('b') / 0.1 - sympy.Symbol('a')) == 0.5


def test_param_dict():
Expand Down

0 comments on commit 46f8df7

Please sign in to comment.