diff --git a/src/whylogs/core/types/typeddataconverter.py b/src/whylogs/core/types/typeddataconverter.py index 48d8d8c17a..4de1197cb1 100644 --- a/src/whylogs/core/types/typeddataconverter.py +++ b/src/whylogs/core/types/typeddataconverter.py @@ -13,7 +13,7 @@ # Dictionary mapping from type Number to type name TYPENUM_TO_NAME = {k: v for v, k in InferredType.Type.items()} INTEGRAL_TYPES = (int, np.integer) -FLOAT_TYPES = (float, np.float) +FLOAT_TYPES = (float, np.floating) class TypedDataConverter: diff --git a/tests/unit/core/statistics/datatypes/test_floattracker.py b/tests/unit/core/statistics/datatypes/test_floattracker.py index 33fad8e840..31ad4134d3 100644 --- a/tests/unit/core/statistics/datatypes/test_floattracker.py +++ b/tests/unit/core/statistics/datatypes/test_floattracker.py @@ -1,45 +1,75 @@ +import numpy as np + from whylogs.core.statistics.datatypes import FloatTracker +def _test_tracker_vs_array(tracker_of_array, array): + for val in array: + tracker_of_array.update(val) + assert tracker_of_array.count == len(array) + assert tracker_of_array.max == max(array) + assert tracker_of_array.min == min(array) + assert tracker_of_array.sum == sum(array) + + def test_values_are_min_max(): - first = FloatTracker() + tracker = FloatTracker() vals1 = [1.0, 2.0, 3.0] - for val in vals1: - first.update(val) + _test_tracker_vs_array(tracker, vals1) + + +def test_np_float(): + float32_tracker = FloatTracker() + float32_array = np.array([1.0, 2.0, 3.0], dtype=np.float32) + _test_tracker_vs_array(float32_tracker, float32_array) - assert first.count == len(vals1) - assert first.max == max(vals1) - assert first.min == min(vals1) - assert first.sum == sum(vals1) + float64_tracker = FloatTracker() + float64_array = np.array([1.0, 2.0, 3.0], dtype=np.float64) + _test_tracker_vs_array(float64_tracker, float64_array) def test_merge_floattrackers_should_addup(): - first = FloatTracker() + first_tracker = FloatTracker() vals1 = [1.0, 2.0, 3.0] - for val in vals1: - first.update(val) - assert first.count == len(vals1) - assert first.max == max(vals1) - assert first.min == min(vals1) - assert first.sum == sum(vals1) + second_tracker = FloatTracker() + vals2 = [4.0, 5.0, 6.0] + + all_vals = vals1 + vals2 + _test_merged_tracker_vs_arrays(all_vals, first_tracker, second_tracker) + - second = FloatTracker() +def _test_merged_tracker_vs_arrays(combined_arrays, first_tracker, second_tracker): + merge_first = first_tracker.merge(second_tracker) + assert merge_first.count == len(combined_arrays) + assert merge_first.max == max(combined_arrays) + assert merge_first.min == min(combined_arrays) + assert merge_first.sum == sum(combined_arrays) + merge_second = second_tracker.merge(first_tracker) + assert merge_second.__dict__ == merge_first.__dict__ + + +def test_merge_floattrackers_should_addup(): + float32_tracker = FloatTracker() + float32_array = np.array([1.0, 2.0, 3.0], dtype=np.float32) + for val in float32_array: + float32_tracker.update(val) + + float64_tracker = FloatTracker() + float64_array = np.array([1.0, 2.0, 3.0], dtype=np.float64) + for val in float64_array: + float64_tracker.update(val) + + simple_tracker = FloatTracker() vals2 = [4.0, 5.0, 6.0] for val in vals2: - second.update(val) + simple_tracker.update(val) - assert second.count == len(vals2) - assert second.max == max(vals2) - assert second.min == min(vals2) - assert second.sum == sum(vals2) + merge_32_and_simple = float32_array.tolist() + vals2 + _test_merged_tracker_vs_arrays(merge_32_and_simple, float32_tracker, simple_tracker) - all_vals = vals1 + vals2 - merge_first = first.merge(second) - assert merge_first.count == len(all_vals) - assert merge_first.max == max(all_vals) - assert merge_first.min == min(all_vals) - assert merge_first.sum == sum(all_vals) + merge_64_and_simple = float64_array.tolist() + vals2 + _test_merged_tracker_vs_arrays(merge_64_and_simple, float64_tracker, simple_tracker) - merge_second = second.merge(first) - assert merge_second.__dict__ == merge_first.__dict__ + merge_64_and_32 = np.concatenate((float64_array, float32_array)) + _test_merged_tracker_vs_arrays(merge_64_and_32, float32_tracker, float64_tracker)