def rewrite_inf_nan(
    data: Union[float, int, List[Any]]
) -> Union[str, int, float, List[Any]]:
    """
    Replace NaN and infinite floats with their string representations.

    Non-finite floats are mapped to ``'NaN'``, ``'+inf'`` and ``'-inf'``.
    Lists and tuples are traversed recursively and rebuilt with the same
    container type; every other value is returned unchanged.

    :param data: A scalar, or an arbitrarily nested list/tuple of scalars.
    :return: ``data`` with each non-finite float replaced by a string.
    """
    if isinstance(data, float):
        if math.isnan(data):
            return 'NaN'
        if math.isinf(data):
            return '+inf' if data > 0 else '-inf'
        return data
    if isinstance(data, list):
        return [rewrite_inf_nan(item) for item in data]
    if isinstance(data, tuple):
        # Tuples serialize as JSON arrays just like lists; without this
        # branch their non-finite members would be left as raw floats.
        return tuple(rewrite_inf_nan(item) for item in data)
    return data
b/test/test_utils.py @@ -338,6 +338,27 @@ def cmp(d1, d2): with open(file_scalr) as fd: cmp(json.load(fd), dict_scalr) + # custom Stan serialization + dict_inf_nan = { + 'a': np.array( + [ + [-np.inf, np.inf, np.NaN], + [-float('inf'), float('inf'), float('NaN')], + [ + np.float32(-np.inf), + np.float32(np.inf), + np.float32(np.NaN), + ], + [1e200 * -1e200, 1e220 * 1e200, -np.nan], + ] + ) + } + dict_inf_nan_exp = {'a': [["-inf", "+inf", "NaN"]] * 4} + file_fin = os.path.join(_TMPDIR, 'inf.json') + write_stan_json(file_fin, dict_inf_nan) + with open(file_fin) as fd: + cmp(json.load(fd), dict_inf_nan_exp) + def test_write_stan_json_bad(self): file_bad = os.path.join(_TMPDIR, 'bad.json') @@ -349,14 +370,6 @@ def test_write_stan_json_bad(self): with self.assertRaises(ValueError): write_stan_json(file_bad, dict_badtype_nested) - dict_inf = {'a': [np.inf]} - with self.assertRaises(ValueError): - write_stan_json(file_bad, dict_inf) - - dict_nan = {'a': np.nan} - with self.assertRaises(ValueError): - write_stan_json(file_bad, dict_nan) - class ReadStanCsvTest(unittest.TestCase): def test_check_sampler_csv_1(self):