Skip to content

Commit

Permalink
fix: simply code, make one less copy of data
Browse files Browse the repository at this point in the history
Makes one less copy of the values of input `data`. Adds tests.

Related issue #60.
  • Loading branch information
riddell-stan committed Mar 13, 2019
1 parent d5fe9a9 commit fa6f6a0
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 23 deletions.
25 changes: 12 additions & 13 deletions stan/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import collections.abc
from copy import deepcopy
import json
import time
import typing
Expand All @@ -18,18 +17,22 @@
import numpy as np


def _ensure_json_serializable(data: dict) -> dict:
"""Convert `data` with numpy.ndarray-like values to JSON serializable form.
def _make_json_serializable(data: dict) -> dict:
"""Convert `data` with numpy.ndarray-like values to JSON-serializable form.
Returns a new dictionary.
Arguments:
data (dict): A Python dictionary or mapping providing the data for the
model. Variable names are the keys and the values are their
associated values. Default is an empty dictionary.
Returns:
dict: Data dictionary with JSON-serializable values.
dict: Copy of `data` dict with JSON-serializable values.
"""
for key, value in data.copy().items():
# no need for deep copy, we do not modify mutable items
data = data.copy()
for key, value in data.items():
# first, see if the value is already JSON-serializable
try:
json.dumps(value)
Expand Down Expand Up @@ -199,9 +202,10 @@ def build(program_code, data=None, random_seed=None):
"""
if data is None:
data = {}
else:
data = deepcopy(data)
data = _ensure_json_serializable(data)
# _make_json_serializable returns a new dict, original `data` unchanged
data = _make_json_serializable(data)
assert all(not isinstance(value, np.ndarray) for value in data.values())

with stan.common.httpstan_server() as server:
host, port = server.host, server.port

Expand All @@ -215,11 +219,6 @@ def build(program_code, data=None, random_seed=None):
response_payload = response.json()
model_name = response_payload["name"]

# in `data`: convert numpy arrays to normal lists
for key, value in data.items():
if isinstance(value, np.ndarray):
data[key] = value.tolist()

path, payload = f"/v1/{model_name}/params", {"data": data}
response = requests.post(f"http://{host}:{port}{path}", json=payload)
response_payload = response.json()
Expand Down
20 changes: 10 additions & 10 deletions tests/test_data_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,29 @@
import stan.model


def test_ensure_json():
def test_make_json_serializable():
data = {"K": 3}
assert stan.model._ensure_json_serializable(data) == data
assert stan.model._make_json_serializable(data) == data
data = {"K": 3, "x": [3, 4, 5]}
assert stan.model._ensure_json_serializable(data) == data
assert stan.model._make_json_serializable(data) == data

class DummyClass:
pass

data = {"K": DummyClass(), "x": [3, 4, 5]}
with pytest.raises(TypeError, match=r"Value associated with variable `K`"):
stan.model._ensure_json_serializable(data)
stan.model._make_json_serializable(data)


def test_ensure_json_numpy():
def test_make_json_serializable_numpy():
data = {"K": 3, "x": np.array([3, 4, 5])}
expected = {"K": 3, "x": [3, 4, 5]}
assert stan.model._ensure_json_serializable(data) == expected
assert json.dumps(stan.model._ensure_json_serializable(data))
assert stan.model._make_json_serializable(data) == expected
assert json.dumps(stan.model._make_json_serializable(data))


def test_ensure_json_pandas():
def test_make_json_serializable_pandas():
data = {"K": 3, "x": pd.Series([3, 4, 5])}
expected = {"K": 3, "x": [3, 4, 5]}
assert stan.model._ensure_json_serializable(data) == expected
assert json.dumps(stan.model._ensure_json_serializable(data))
assert stan.model._make_json_serializable(data) == expected
assert json.dumps(stan.model._make_json_serializable(data))
51 changes: 51 additions & 0 deletions tests/test_model_build_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Test handling of `data` dictionary."""
import copy

import numpy as np
import pytest

import stan


program_code = """
data {
int<lower=0> N;
int<lower=0,upper=1> y[N];
}
parameters {
real<lower=0,upper=1> theta;
}
model {
for (n in 1:N)
y[n] ~ bernoulli(theta);
}
"""

data = {"N": 10, "y": [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]}


@pytest.fixture(scope="module")
def posterior():
return stan.build(program_code, data=data)


def test_data_wrong_dtype(posterior):
# pull in posterior to cache compilation
bad_data = copy.deepcopy(data)
# float is wrong dtype
bad_data["y"] = np.array(bad_data["y"], dtype=float)
assert bad_data["y"].dtype == float
with pytest.raises(RuntimeError, match=r"int variable contained non-int values"):
stan.build(program_code, data=bad_data)


def test_data_unmodified(posterior):
# pull in posterior to cache compilation
data_with_array = copy.deepcopy(data)
# `build` will convert data into a list, should not change original
data_with_array["y"] = np.array(data_with_array["y"], dtype=int)
assert data_with_array["y"].dtype == int
stan.build(program_code, data=data_with_array)
# `data_with_array` should be unchanged
assert not isinstance(data_with_array["y"], list)
assert data_with_array["y"].dtype == int

0 comments on commit fa6f6a0

Please sign in to comment.