Skip to content

Commit

Permalink
feat: Model accepts a wider range of data
Browse files Browse the repository at this point in the history
Values which resemble NumPy `ndarray` and pandas `Series` should be
accepted. They will be converted into (JSON-serializable) lists.

Closes #51
  • Loading branch information
riddell-stan committed Dec 16, 2018
1 parent b7929d3 commit 661d064
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
35 changes: 35 additions & 0 deletions stan/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import collections.abc
import json
import time
import typing

import requests

import httpstan.models
import httpstan.schemas
import httpstan.services.arguments as arguments
import httpstan.utils
import stan.common
import stan.fit

Expand All @@ -13,6 +17,36 @@
import numpy as np


def _ensure_json_serializable(data: dict) -> dict:
"""Convert `data` with numpy.ndarray-like values to JSON serializable form.
Arguments:
data (dict): A Python dictionary or mapping providing the data for the
model. Variable names are the keys and the values are their
associated values. Default is an empty dictionary.
Returns:
dict: Data dictionary with JSON-serializable values.
"""
for key, value in data.copy().items():
# first, see if the value is already JSON-serializable
try:
json.dumps(value)
except TypeError:
pass
else:
continue
# numpy scalar
if isinstance(value, np.ndarray) and value.ndim == 0:
data[key] = np.asarray(value).tolist()
# numpy.ndarray, pandas.Series, and anything similar
elif isinstance(value, collections.abc.Collection):
data[key] = np.asarray(value).tolist()
else:
raise TypeError(f"Value associated with variable `{key}` is not JSON serializable.")
return data


class Model:
"""Stores data associated with and proxies calls to a Stan model.
Expand Down Expand Up @@ -164,6 +198,7 @@ def build(program_code, data=None, random_seed=None):
"""
if data is None:
data = {}
data = _ensure_json_serializable(data)
with stan.common.httpstan_server() as server:
host, port = server.host, server.port

Expand Down
35 changes: 35 additions & 0 deletions tests/test_data_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json

import numpy as np
import pandas as pd
import pytest

import stan.model


def test_ensure_json():
    """Already-serializable data passes through; unsupported values raise."""
    # Compare against separate literals: the function may hand back the very
    # object it was given, which would make `result == data` trivially true.
    assert stan.model._ensure_json_serializable({"K": 3}) == {"K": 3}
    assert stan.model._ensure_json_serializable({"K": 3, "x": [3, 4, 5]}) == {"K": 3, "x": [3, 4, 5]}

    class DummyClass:
        pass

    # An arbitrary object is neither JSON-serializable nor array-like.
    data = {"K": DummyClass(), "x": [3, 4, 5]}
    with pytest.raises(TypeError, match=r"Value associated with variable `K`"):
        stan.model._ensure_json_serializable(data)


def test_ensure_json_numpy():
    """A NumPy array value is converted to a plain list that JSON can encode."""
    data = {"K": 3, "x": np.array([3, 4, 5])}
    expected = {"K": 3, "x": [3, 4, 5]}
    # Convert once and inspect that single result, so both assertions
    # describe the conversion of the original NumPy input.
    result = stan.model._ensure_json_serializable(data)
    assert result == expected
    # Round-trip through JSON instead of only checking for a non-empty string.
    assert json.loads(json.dumps(result)) == expected


def test_ensure_json_pandas():
    """A pandas Series value is converted to a plain list that JSON can encode."""
    data = {"K": 3, "x": pd.Series([3, 4, 5])}
    expected = {"K": 3, "x": [3, 4, 5]}
    # Convert once and inspect that single result, so both assertions
    # describe the conversion of the original pandas input.
    result = stan.model._ensure_json_serializable(data)
    assert result == expected
    # Round-trip through JSON instead of only checking for a non-empty string.
    assert json.loads(json.dumps(result)) == expected
13 changes: 13 additions & 0 deletions tests/test_eight_schools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np
import pandas as pd
import pytest

import stan
Expand Down Expand Up @@ -41,6 +43,17 @@ def test_eight_schools_build(posterior):
assert posterior is not None


def test_eight_schools_build_numpy(posterior):
    """Build the model from data given as a NumPy array and a pandas Series."""
    y_values = np.array([28, 8, -3, 7, -1, 1, 18, 12])
    sigma_values = pd.Series([15, 10, 16, 11, 9, 11, 10, 18], name="sigma")
    alt_data = dict(J=8, y=y_values, sigma=sigma_values)
    # Building must succeed even though two values are not plain Python lists.
    assert stan.build(program_code, data=alt_data) is not None


def test_eight_schools_sample(posterior):
"""Sample from a simple model."""
num_chains, num_samples = 2, 200
Expand Down

0 comments on commit 661d064

Please sign in to comment.