From b205926a3a88307f46b0979f4979436a805db7cd Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 1 Nov 2021 12:42:18 -0400 Subject: [PATCH 1/3] Handle inf/nan in json --- cmdstanpy/utils.py | 32 +++++++++++++++++++++++++------- test/test_utils.py | 19 +++++++++++-------- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/cmdstanpy/utils.py b/cmdstanpy/utils.py index 2699ca80..6784c350 100644 --- a/cmdstanpy/utils.py +++ b/cmdstanpy/utils.py @@ -409,6 +409,21 @@ def cxx_toolchain_path( return compiler_path, tool_path +def rewrite_inf_nan( + data: Union[float, int, List[Any]] +) -> Union[str, int, float, List[Any]]: + if isinstance(data, float): + if math.isnan(data): + return 'NaN' + if math.isinf(data): + return ('+' if data > 0 else '-') + 'inf' + return data + elif isinstance(data, list): + return [rewrite_inf_nan(item) for item in data] + else: + return data + + def write_stan_json(path: str, data: Mapping[str, Any]) -> None: """ Dump a mapping of strings to data to a JSON file. @@ -427,9 +442,13 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None: :param data: A mapping from strings to values. This can be a dictionary or something more exotic like an :class:`xarray.Dataset`. This will be copied before type conversion, not modified + + :param handle_nan_inf: If enabled, perform the (Slow!) checks necessary to + output NaN and inf as required for Stan """ data_out = {} for key, val in data.items(): + handle_nan_inf = False if val is not None: if isinstance(val, (str, bytes)) or ( type(val).__module__ != 'numpy' @@ -440,18 +459,14 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None: + f"write_stan_json for key '{key}'" ) try: - if not np.all(np.isfinite(val)): - raise ValueError( - "Input to write_stan_json has nan or infinite " - + f"values for key '{key}'" - ) + handle_nan_inf = not np.all(np.isfinite(val)) except TypeError: # handles cases like val == ['hello'] # pylint: disable=raise-missing-from raise ValueError( "Invalid type provided to " - + f"write_stan_json for key '{key}' " - + f"as part of collection {type(val)}" + f"write_stan_json for key '{key}' " + f"as part of collection {type(val)}" ) if type(val).__module__ == 'numpy': @@ -463,6 +478,9 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None: else: data_out[key] = val + if handle_nan_inf: + data_out[key] = rewrite_inf_nan(data_out[key]) + with open(path, 'w') as fd: json.dump(data_out, fd) diff --git a/test/test_utils.py b/test/test_utils.py index 5d54bee2..00fad5c6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -338,6 +338,17 @@ def cmp(d1, d2): with open(file_scalr) as fd: cmp(json.load(fd), dict_scalr) + # custom Stan serialization + + dict_inf_nan = { + 'a': np.array([[-np.inf, np.inf, np.NaN, float('NaN')]]) + } + dict_inf_nan_exp = {'a': np.array([["-inf", "+inf", "NaN", "NaN"]])} + file_fin = os.path.join(_TMPDIR, 'inf.json') + write_stan_json(file_fin, dict_inf_nan) + with open(file_fin) as fd: + cmp(json.load(fd), dict_inf_nan_exp) + def test_write_stan_json_bad(self): file_bad = os.path.join(_TMPDIR, 'bad.json') @@ -349,14 +360,6 @@ def test_write_stan_json_bad(self): with self.assertRaises(ValueError): write_stan_json(file_bad, dict_badtype_nested) - dict_inf = {'a': [np.inf]} - with self.assertRaises(ValueError): - write_stan_json(file_bad, dict_inf) - - dict_nan = {'a': np.nan} - with self.assertRaises(ValueError): - write_stan_json(file_bad, dict_nan) - class ReadStanCsvTest(unittest.TestCase): def test_check_sampler_csv_1(self): From e58832c4dd79737337a4d75e2582cdd31d8641dc Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 1 Nov 2021 12:49:20 -0400 Subject: [PATCH 2/3] Docstrings --- cmdstanpy/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cmdstanpy/utils.py b/cmdstanpy/utils.py index 6784c350..3247d175 100644 --- a/cmdstanpy/utils.py +++ b/cmdstanpy/utils.py @@ -412,6 +412,7 @@ def cxx_toolchain_path( def rewrite_inf_nan( data: Union[float, int, List[Any]] ) -> Union[str, int, float, List[Any]]: + """Replaces NaN and Infinity with string representations""" if isinstance(data, float): if math.isnan(data): return 'NaN' @@ -442,9 +443,6 @@ def write_stan_json(path: str, data: Mapping[str, Any]) -> None: :param data: A mapping from strings to values. This can be a dictionary or something more exotic like an :class:`xarray.Dataset`. This will be copied before type conversion, not modified - - :param handle_nan_inf: If enabled, perform the (Slow!) checks necessary to - output NaN and inf as required for Stan """ data_out = {} for key, val in data.items(): From 86e7a81f5ba2087f04f0a354dc99a0e80826e16c Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Mon, 1 Nov 2021 14:12:04 -0400 Subject: [PATCH 3/3] Expand scalar json testing --- test/test_utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 00fad5c6..7260d65d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -339,11 +339,21 @@ def cmp(d1, d2): cmp(json.load(fd), dict_scalr) # custom Stan serialization - dict_inf_nan = { - 'a': np.array([[-np.inf, np.inf, np.NaN, float('NaN')]]) + 'a': np.array( + [ + [-np.inf, np.inf, np.NaN], + [-float('inf'), float('inf'), float('NaN')], + [ + np.float32(-np.inf), + np.float32(np.inf), + np.float32(np.NaN), + ], + [1e200 * -1e200, 1e220 * 1e200, -np.nan], + ] + ) } - dict_inf_nan_exp = {'a': np.array([["-inf", "+inf", "NaN", "NaN"]])} + dict_inf_nan_exp = {'a': [["-inf", "+inf", "NaN"]] * 4} file_fin = os.path.join(_TMPDIR, 'inf.json') write_stan_json(file_fin, dict_inf_nan) with open(file_fin) as fd: