diff --git a/_doc/examples/plot_f8.py b/_doc/examples/plot_f8.py new file mode 100644 index 0000000..9601cba --- /dev/null +++ b/_doc/examples/plot_f8.py @@ -0,0 +1,25 @@ +""" +.. _l-example-float8: + +About float 8 +============= + +Float 8 types were recently introduced to speed up the +training of deep learning models. + +Possible values ++++++++++++++++ + +First E4M3FN. +""" + +import pprint +from onnx_array_api.validation.f8 import CastFloat8 + +pprint.pprint(CastFloat8.values_e4m3fn) + + +############################################ +# Then E5M2. + +pprint.pprint(CastFloat8.values_e5m2) diff --git a/_unittests/ut_validation/test_f8.py b/_unittests/ut_validation/test_f8.py index fdbf9cb..c50cfed 100644 --- a/_unittests/ut_validation/test_f8.py +++ b/_unittests/ut_validation/test_f8.py @@ -6,6 +6,7 @@ import pandas from onnx_array_api.validation.f8 import ( CastFloat8, + UndefinedCastError, display_fe4m3, display_fe5m2, display_float16, @@ -120,16 +121,28 @@ def test_search_float32_into_fe4m3fn_simple(self): (480, 448), (0.001953125, 0.001953125), (416, 416), - (-447.5, -448), - (23.5, 24), (192.5, 192), + (304, 320), + (368, 384), + (248, 256), + (432, 448), + (100, 96), + (400, 384), + (336, 320), + (272, 256), + (23.5, 24), + (-447.5, -448), (79.5, 80), ] for v, expected in values: with self.subTest(v=v, expected=expected): - b = search_float32_into_fe4m3(v) - got = fe4m3_to_float32_float(b) - self.assertEqual(expected, got) + try: + b = search_float32_into_fe4m3(v) + except UndefinedCastError: + b = None + if b is not None: + got = fe4m3_to_float32_float(b) + self.assertEqual(expected, got) b = float32_to_fe4m3(v) got = fe4m3_to_float32_float(b) self.assertEqual(expected, got) @@ -143,6 +156,10 @@ def test_search_float32_into_fe5m2_simple(self): (20480.5, 20480), (14.5, 14), (-3584.5, -3584), + (352, 384), + (416, 384), + (0.4, 0.375), + (0.4068359375, 0.4375), ] for v, expected in values: with self.subTest(v=v, expected=expected): @@ -154,13 +171,19 @@ def test_search_float32_into_fe5m2_simple(self): got = fe5m2_to_float32_float(b) self.assertLess(abs(expected - got), 1e-5) else: - b1 = search_float32_into_fe5m2(v) + try: + b1 = search_float32_into_fe5m2(v) + except UndefinedCastError: + b1 = None + if b1 is not None: + got1 = fe5m2_to_float32_float(b1) + self.assertEqual(got1, expected) + b2 = float32_to_fe5m2(v) - self.assertEqual(b1, b2) - got1 = fe5m2_to_float32_float(b1) got2 = fe5m2_to_float32(b2) - self.assertEqual(got1, expected) self.assertEqual(got2, expected) + if b1 is not None: + self.assertEqual(b1, b2) def test_search_float32_into_fe4m3fn_equal(self): values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)] @@ -215,13 +238,33 @@ def test_search_float32_into_fe4m3fn(self): values += [(1e-9, 0), (-1e-9, 0), (1e8, 448), (-1e-8, -448)] wrong = 0 for value, origin in values: - for add in [0, -0.4, -1e-4, 1e-4, 0.4, (3, "x"), (0.3, "x")]: + for add in [ + 0, + -0.4, + -1e-4, + 1e-4, + 0.4, + (3, "x"), + (0.3, "x"), + 16, + 32, + 64, + -16, + -32, + -64, + ]: if isinstance(add, tuple): v = value * add[0] add = v - value else: v = value + add - b = search_float32_into_fe4m3(v) + try: + b = search_float32_into_fe4m3(v) + except UndefinedCastError: + if add == 0: + b = search_float32_into_fe4m3(origin) + else: + continue nf = float32_to_fe4m3(v) if b != nf: # signed, not signed zero? @@ -258,10 +301,31 @@ def test_search_float32_into_fe5m2(self): values.sort() obs = [] - values += [(1e-8, 0), (-1e-8, 0), (1e8, 448), (-1e-8, -448)] + values += [ + (1e-8, 0), + (-1e-8, 0), + (1e8, 57344), + (-1e8, -57344), + (352, 384), + (416, 384), + ] wrong = 0 - for value, _ in values: - for add in [0, -0.4, -1e-4, 1e-4, 0.4, (3, "x"), (0.3, "x")]: + for value, origin in values: + for add in [ + 0, + -0.4, + -1e-4, + 1e-4, + 0.4, + (3, "x"), + (0.3, "x"), + 16, + 32, + 64, + -16, + -32, + -64, + ]: if isinstance(add, tuple): v = value * add[0] with warnings.catch_warnings(record=True) as w: @@ -276,7 +340,13 @@ def test_search_float32_into_fe5m2(self): ) else: v = value + add - b = search_float32_into_fe5m2(v) + try: + b = search_float32_into_fe5m2(v) + except UndefinedCastError: + if add == 0: + b = search_float32_into_fe5m2(origin) + else: + continue nf = float32_to_fe5m2(v) if b != nf: # signed, not signed zero? @@ -373,7 +443,10 @@ def test_search_e4m3_pow(self): self.assertTrue(hasattr(CastFloat8, "values_e4m3fn")) for p in range(1, 40): v = 2 ** (-p) - r1 = search_float32_into_fe4m3(v) + try: + r1 = search_float32_into_fe4m3(v) + except UndefinedCastError: + continue r2 = float32_to_fe4m3(v) if r1 != r2: raise AssertionError( @@ -383,7 +456,10 @@ def test_search_e4m3_pow(self): ) for p in range(1, 40): v = -(2 ** (-p)) - r1 = search_float32_into_fe4m3(v) + try: + r1 = search_float32_into_fe4m3(v) + except UndefinedCastError: + continue r2 = float32_to_fe4m3(v) if r1 != r2: raise AssertionError( @@ -396,7 +472,10 @@ def test_search_e5m2_pow(self): self.assertTrue(hasattr(CastFloat8, "values_e5m2")) for p in range(1, 40): v = 2 ** (-p) - r1 = search_float32_into_fe5m2(v) + try: + r1 = search_float32_into_fe5m2(v) + except UndefinedCastError: + continue r2 = float32_to_fe5m2(v) if r1 != r2: raise AssertionError( @@ -406,7 +485,10 @@ def test_search_e5m2_pow(self): ) for p in range(1, 40): v = -(2 ** (-p)) - r1 = search_float32_into_fe5m2(v) + try: + r1 = search_float32_into_fe5m2(v) + except UndefinedCastError: + continue r2 = float32_to_fe5m2(v) if r1 != r2: raise AssertionError( @@ -577,7 +659,10 @@ def test_search_float32_into_fe4m3fnuz(self): add = v - value else: v = value + add - b = search_float32_into_fe4m3(v, uz=True) + try: + b = search_float32_into_fe4m3(v, uz=True) + except UndefinedCastError: + continue nf = float32_to_fe4m3(v, uz=True) if b != nf: wrong += 1 @@ -622,7 +707,10 @@ def test_search_float32_into_fe5m2fnuz(self): add = v - value else: v = value + add - b = search_float32_into_fe5m2(v, fn=True, uz=True) + try: + b = search_float32_into_fe5m2(v, fn=True, uz=True) + except UndefinedCastError: + continue nf = float32_to_fe5m2(v, fn=True, uz=True) if b != nf: wrong += 1 @@ -1066,5 +1154,4 @@ def test_float8_e5m2fnuz_negative_nan(self): if __name__ == "__main__": TestF8().test_search_float32_into_fe4m3fn_simple() - TestF8().test_search_float32_into_fe4m3fn() unittest.main(verbosity=2) diff --git a/onnx_array_api/validation/f8.py b/onnx_array_api/validation/f8.py index 43cf7e1..d95b3b3 100644 --- a/onnx_array_api/validation/f8.py +++ b/onnx_array_api/validation/f8.py @@ -4,6 +4,14 @@ # display functions +class UndefinedCastError(FloatingPointError): + """ + Unable to case a number. + """ + + pass + + def display_float32(value, sign=1, exponent=8, mantissa=23): """ Displays a float32 into b. @@ -386,7 +394,10 @@ def find_closest_value(value, sorted_values): if d1 < d2: return sorted_values[a][1] if d1 == d2: - return sorted_values[a][1] if value < 0 else sorted_values[b][1] + raise UndefinedCastError( + f"Unable to cast {value}, d1={d1}, d2={d2}, " + f"options are {sorted_values[a][1]} and {sorted_values[b][1]}." + ) return sorted_values[b][1] return sorted_values[a][1] @@ -572,7 +583,9 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True ret |= m >> 20 if (ret & 0x7F) == 0x7F: ret &= 0xFE - if m & 0x80000: + if (m & 0x80000) and ( + (m & 0x100000) or (m & 0x7C000) + ): # round to nearest even if (ret & 0x7F) < 0x7E: # rounding ret += 1 @@ -675,7 +688,9 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru ex = e - 112 # 127 - 15 ret |= ex << 2 ret |= m >> 21 - if m & 0x100000: + if m & 0x100000 and ( + (m & 0xFFFFF) or (m & 0x200000) + ): # round to nearest even if (ret & 0x7F) < 0x7B: # rounding ret += 1