Skip to content

Commit

Permalink
Fix serialization breakage due to sympy.Rational now being a numbers.…
Browse files Browse the repository at this point in the history
…Number (#2655)

- Reorder the sympy checks to come before generic number checks
- Rename json.py to json_serialization.py to avoid collisions with the built-in json library
- Detect integral values when deserializing symbolic protos, so that x - y does not become x - 1.0*y.
- Pin to 1.4 until sympy/sympy#18056 is fixed

Fixes #2650
Fixes #2646
  • Loading branch information
Strilanc authored and CirqBot committed Dec 17, 2019
1 parent c78d345 commit 9a5a09e
Show file tree
Hide file tree
Showing 13 changed files with 49 additions and 41 deletions.
24 changes: 12 additions & 12 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,18 @@ matrix:
- python -m pip install -r dev_tools/conf/pip-list-dev-tools.txt
- sudo apt-get install pandoc
script: dev_tools/build-docs.sh
- stage: test
os: osx
env: NAME=pytest (macOS)
python: '3.6'
language: generic
install:
- python3.6 -m venv venv
- source venv/bin/activate
- python -m pip install -r requirements.txt
- python -m pip install -r cirq/contrib/contrib-requirements.txt
- python -m pip install -r dev_tools/conf/pip-list-dev-tools.txt
script: check/pytest --benchmark-skip
#- stage: test
# os: osx
# env: NAME=pytest (macOS)
# python: '3.6'
# language: generic
# install:
# - python3.6 -m venv venv
# - source venv/bin/activate
# - python -m pip install -r requirements.txt
# - python -m pip install -r cirq/contrib/contrib-requirements.txt
# - python -m pip install -r dev_tools/conf/pip-list-dev-tools.txt
# script: check/pytest --benchmark-skip
- stage: test
os: linux
env: NAME=pytest (without contrib)
Expand Down
2 changes: 1 addition & 1 deletion check/pytest-changed-files
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ changed=$(git diff --name-only ${rev} -- \
if git diff --name-only "${rev}" -- | grep "__init__\.py$" > /dev/null; then
# Include global API tests when an __init__ file is touched.
changed+=('docs/docs_coverage_test.py')
changed+=('cirq/protocols/json_test.py')
changed+=('cirq/protocols/json_serialization_test.py')
fi
num_changed=$(echo -e "${changed[@]}" | wc -w)

Expand Down
2 changes: 1 addition & 1 deletion check/pytest-changed-files-and-incremental-coverage
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ changed_python_tests=$(git diff --name-only "${rev}" -- \
if git diff --name-only "${rev}" -- | grep "__init__\.py$" > /dev/null; then
# Include global API tests when an __init__ file is touched.
changed_python_tests+=('docs/docs_coverage_test.py')
changed_python_tests+=('cirq/protocols/json_test.py')
changed_python_tests+=('cirq/protocols/json_serialization_test.py')
fi
if [ "${#changed_python_tests[@]}" -eq 0 ]; then
echo -e "\033[33mNo changed files with associated python tests.\033[0m" >&2
Expand Down
2 changes: 1 addition & 1 deletion cirq/contrib/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from cirq.protocols.json import DEFAULT_RESOLVERS
from cirq.protocols.json_serialization import DEFAULT_RESOLVERS


def contrib_class_resolver(cirq_type: str):
Expand Down
7 changes: 5 additions & 2 deletions cirq/google/arg_func_langs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import (
List,
Union,
Expand Down Expand Up @@ -180,7 +180,10 @@ def _arg_from_proto(
arg_value = arg_proto.arg_value
which_val = arg_value.WhichOneof('arg_value')
if which_val == 'float_value':
return float(arg_value.float_value)
result = float(arg_value.float_value)
if math.ceil(result) == math.floor(result):
result = int(result)
return result
if which_val == 'bool_values':
return list(arg_value.bool_values.values)
if which_val == 'string_value':
Expand Down
2 changes: 1 addition & 1 deletion cirq/interop/quirk/cells/parse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_parse_real_formula():
assert parse_formula('1/2') == 0.5
assert parse_formula('t*t + ln(t)') == t * t + sympy.ln(t)
assert parse_formula('cos(pi*t)') == sympy.cos(sympy.pi * t)
assert parse_formula('5t') == 5 * t
assert parse_formula('5t') == 5.0 * t
np.testing.assert_allclose(parse_formula('cos(pi)'), -1, atol=1e-8)
assert type(parse_formula('cos(pi)')) is float

Expand Down
2 changes: 1 addition & 1 deletion cirq/optimizers/eject_phased_paulis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_absorbs_z():
[cirq.T(q)**-x],
),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0.125 + x / 8).on(q)],
[cirq.PhasedXPowGate(phase_exponent=0.125 + x * 0.125).on(q)],
[],
[],
),
Expand Down
2 changes: 1 addition & 1 deletion cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
)
from cirq.protocols.inverse_protocol import (
inverse,)
from cirq.protocols.json import (
from cirq.protocols.json_serialization import (
to_json,
read_json,
obj_to_dict_helper,
Expand Down
37 changes: 21 additions & 16 deletions cirq/protocols/json.py → cirq/protocols/json_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,23 +250,11 @@ class CirqEncoder(json.JSONEncoder):
"""

def default(self, o):
# Object with custom method?
if hasattr(o, '_json_dict_'):
return o._json_dict_()
if isinstance(o, np.bool_):
return bool(o)
if isinstance(o, numbers.Integral):
return int(o)
if isinstance(o, numbers.Real):
return float(o)
if isinstance(o, numbers.Complex):
return {
'cirq_type': 'complex',
'real': o.real,
'imag': o.imag,
}
if isinstance(o, np.ndarray):
return o.tolist()

# Sympy object? (Must come before general number checks.)
# TODO: More support for sympy
# https://github.com/quantumlib/Cirq/issues/2014
if isinstance(o, sympy.Symbol):
Expand All @@ -288,20 +276,37 @@ def default(self, o):
'q': o.q,
}

# A basic number object?
if isinstance(o, numbers.Integral):
return int(o)
if isinstance(o, numbers.Real):
return float(o)
if isinstance(o, numbers.Complex):
return {
'cirq_type': 'complex',
'real': o.real,
'imag': o.imag,
}

# Numpy object?
if isinstance(o, np.bool_):
return bool(o)
if isinstance(o, np.ndarray):
return o.tolist()

# Pandas object?
if isinstance(o, pd.MultiIndex):
return {
'cirq_type': 'pandas.MultiIndex',
'tuples': list(o),
'names': list(o.names),
}

if isinstance(o, pd.Index):
return {
'cirq_type': 'pandas.Index',
'data': list(o),
'name': o.name,
}

if isinstance(o, pd.DataFrame):
cols = [o[col].tolist() for col in o.columns]
rows = list(zip(*cols))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import cirq
from cirq._compat import proper_repr, proper_eq
from cirq.testing import assert_json_roundtrip_works
from cirq.protocols.json_serialization import RESOLVER_CACHE

TEST_DATA_PATH = pathlib.Path(__file__).parent / 'json_test_data'
TEST_DATA_REL = 'cirq/protocols/json_test_data'
Expand Down Expand Up @@ -324,7 +325,6 @@ def _find_classes_that_should_serialize() -> Set[Tuple[str, Type]]:
result.update(_get_all_public_classes(cirq))
result.update(_get_all_public_classes(cirq.google))

from cirq.protocols.json import RESOLVER_CACHE
for k, v in RESOLVER_CACHE.cirq_class_resolver_dictionary.items():
t = v if isinstance(v, type) else None
result.add((k, t))
Expand Down Expand Up @@ -464,7 +464,7 @@ def test_json_test_data_coverage(cirq_obj_name: str, cls):
f"docstring for protocols.SupportsJSON. If this object or "
f"class is not appropriate for serialization, add its name to "
f"the SHOULDNT_BE_SERIALIZED list in the "
f"cirq/protocols/json_test.py source file."))
f"cirq/protocols/json_serialization_test.py source file."))

repr_file = TEST_DATA_PATH / f'{cirq_obj_name}.repr'
if repr_file.exists() and cls is not None:
Expand Down
2 changes: 1 addition & 1 deletion dev_tools/bash_scripts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_pytest_changed_files_file_selection(tmpdir_factory):
'echo x > __init__.py\n')
assert result.exit_code == 0
assert result.out == ('INTERCEPTED pytest docs/docs_coverage_test.py '
'cirq/protocols/json_test.py\n')
'cirq/protocols/json_serialization_test.py\n')
assert result.err.split() == (
"Comparing against revision 'HEAD'.\n"
"Found 2 test files associated with changes.\n").split()
Expand Down
2 changes: 1 addition & 1 deletion docs/dev/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ If the returned object has a `_from_json_dict_` attribute, it is called instead.

All of Cirq's public classes should be serializable.
This is enforced by the `test_json_test_data_coverage` test in
`cirq/protocols/json_test.py`, which iterates over cirq's API looking for types
`cirq/protocols/json_serialization_test.py`, which iterates over cirq's API looking for types
with no associated json test data.

There are several steps needed to get a object serializing, deserializing, and
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ protobuf==3.8.0
requests~=2.18
sortedcontainers~=2.0
scipy
sympy
sympy==1.4 # Can't use 1.5 until https://github.com/sympy/sympy/issues/18056 is fixed
typing_extensions

0 comments on commit 9a5a09e

Please sign in to comment.