Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions cmdstanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,22 @@ def cxx_toolchain_path(
return compiler_path, tool_path


def rewrite_inf_nan(
Comment thread
WardBrian marked this conversation as resolved.
data: Union[float, int, List[Any]]
) -> Union[str, int, float, List[Any]]:
"""Replaces NaN and Infinity with string representations"""
if isinstance(data, float):
if math.isnan(data):
return 'NaN'
if math.isinf(data):
return ('+' if data > 0 else '-') + 'inf'
return data
elif isinstance(data, list):
return [rewrite_inf_nan(item) for item in data]
else:
return data


def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
"""
Dump a mapping of strings to data to a JSON file.
Expand All @@ -430,6 +446,7 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
"""
data_out = {}
for key, val in data.items():
handle_nan_inf = False
if val is not None:
if isinstance(val, (str, bytes)) or (
type(val).__module__ != 'numpy'
Expand All @@ -440,18 +457,14 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
+ f"write_stan_json for key '{key}'"
)
try:
if not np.all(np.isfinite(val)):
raise ValueError(
"Input to write_stan_json has nan or infinite "
+ f"values for key '{key}'"
)
handle_nan_inf = not np.all(np.isfinite(val))
except TypeError:
# handles cases like val == ['hello']
# pylint: disable=raise-missing-from
raise ValueError(
"Invalid type provided to "
+ f"write_stan_json for key '{key}' "
+ f"as part of collection {type(val)}"
f"write_stan_json for key '{key}' "
f"as part of collection {type(val)}"
)

if type(val).__module__ == 'numpy':
Expand All @@ -463,6 +476,9 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None:
else:
data_out[key] = val

if handle_nan_inf:
data_out[key] = rewrite_inf_nan(data_out[key])

with open(path, 'w') as fd:
json.dump(data_out, fd)

Expand Down
29 changes: 21 additions & 8 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,27 @@ def cmp(d1, d2):
with open(file_scalr) as fd:
cmp(json.load(fd), dict_scalr)

# custom Stan serialization
dict_inf_nan = {
'a': np.array(
[
[-np.inf, np.inf, np.NaN],
[-float('inf'), float('inf'), float('NaN')],
[
np.float32(-np.inf),
np.float32(np.inf),
np.float32(np.NaN),
],
[1e200 * -1e200, 1e220 * 1e200, -np.nan],
]
)
}
dict_inf_nan_exp = {'a': [["-inf", "+inf", "NaN"]] * 4}
file_fin = os.path.join(_TMPDIR, 'inf.json')
write_stan_json(file_fin, dict_inf_nan)
with open(file_fin) as fd:
cmp(json.load(fd), dict_inf_nan_exp)

def test_write_stan_json_bad(self):
file_bad = os.path.join(_TMPDIR, 'bad.json')

Expand All @@ -349,14 +370,6 @@ def test_write_stan_json_bad(self):
with self.assertRaises(ValueError):
write_stan_json(file_bad, dict_badtype_nested)

dict_inf = {'a': [np.inf]}
with self.assertRaises(ValueError):
write_stan_json(file_bad, dict_inf)

dict_nan = {'a': np.nan}
with self.assertRaises(ValueError):
write_stan_json(file_bad, dict_nan)


class ReadStanCsvTest(unittest.TestCase):
def test_check_sampler_csv_1(self):
Expand Down