From 98628c78b6a4d000c3d5ca61b26dfdafc8b9c9eb Mon Sep 17 00:00:00 2001 From: xadupre Date: Wed, 29 Mar 2023 14:42:55 +0200 Subject: [PATCH] fix f8 conversion --- _unittests/ut_validation/test_f8.py | 354 +++++++++++++++++++++++++--- onnx_array_api/validation/f8.py | 167 ++++++++++--- 2 files changed, 445 insertions(+), 76 deletions(-) diff --git a/_unittests/ut_validation/test_f8.py b/_unittests/ut_validation/test_f8.py index b432fea..fdbf9cb 100644 --- a/_unittests/ut_validation/test_f8.py +++ b/_unittests/ut_validation/test_f8.py @@ -190,8 +190,8 @@ def test_search_float32_into_fe5m2_equal(self): with self.subTest( value=value, expected=expected, bin=display_float32(value) ): - b = search_float32_into_fe5m2(value) - nf = float32_to_fe5m2(value) + b = search_float32_into_fe5m2(value, saturate=False) + nf = float32_to_fe5m2(value, saturate=False) cf = new_cvt_float32_to_e5m2(value) if expected in {253, 254, 255, 125, 126, 127}: # nan self.assertIn(b, {253, 254, 255, 125, 126, 127}) @@ -330,7 +330,7 @@ def test_inf_nan(self): v_float32_to_fe5m2 = numpy.vectorize(float32_to_fe5m2) v_fe5m2_to_float32 = numpy.vectorize(fe5m2_to_float32) - got = v_fe4m3_to_float32(v_float32_to_fe4m3(np_fp32)) + got = v_fe4m3_to_float32(v_float32_to_fe4m3(np_fp32, saturate=False)) expected = numpy.array( [ 0.46875, @@ -349,7 +349,7 @@ def test_inf_nan(self): dtype=numpy.float32, ) self.assertEqualArray(expected, got) - got = v_fe5m2_to_float32(v_float32_to_fe5m2(np_fp32)) + got = v_fe5m2_to_float32(v_float32_to_fe5m2(np_fp32, saturate=False)) expected = numpy.array( [ 0.5, @@ -416,30 +416,28 @@ def test_search_e5m2_pow(self): ) def test_float32_to_fe4m3fn_inf(self): - mx = numpy.float32(numpy.nan) - v0 = numpy.float32(mx) + v0 = numpy.float32(numpy.nan) v1 = numpy.float32(numpy.inf) - a = search_float32_into_fe4m3(v0) - b = search_float32_into_fe4m3(v1) + a = search_float32_into_fe4m3(v0, saturate=False) + b = search_float32_into_fe4m3(v1, saturate=False) self.assertEqual(a, b) - v0 = numpy.float32(mx) + v0 = numpy.float32(numpy.nan) v1 = numpy.float32(numpy.inf) - a = float32_to_fe4m3(v0) - b = float32_to_fe4m3(v1) + a = float32_to_fe4m3(v0, saturate=False) + b = float32_to_fe4m3(v1, saturate=False) self.assertEqual(a, b) - mi = numpy.float32(-numpy.nan) - v0 = numpy.float32(mi) + v0 = numpy.float32(-numpy.nan) v1 = numpy.float32(-numpy.inf) a = search_float32_into_fe4m3(v0) b = search_float32_into_fe4m3(v1) self.assertEqual(a, b) - v0 = numpy.float32(mi) + v0 = numpy.float32(-numpy.nan) v1 = numpy.float32(-numpy.inf) - a = float32_to_fe4m3(v0) - b = float32_to_fe4m3(v1) + a = float32_to_fe4m3(v0, saturate=False) + b = float32_to_fe4m3(v1, saturate=False) self.assertEqual(a, b) v0 = numpy.float32(numpy.nan) @@ -653,17 +651,17 @@ def test_search_float32_into_fe5m2fnuz(self): f"{wrong} conversion are wrong\n{pprint.pformat(obs[:2])}" ) - def test_float32_to_fe4m3fnuz_inf(self): + def test_search_float32_to_fe4m3fnuz_inf(self): v0 = numpy.float32(numpy.nan) v1 = numpy.float32(numpy.inf) - a = search_float32_into_fe4m3(v0, uz=True) - b = search_float32_into_fe4m3(v1, uz=True) + a = search_float32_into_fe4m3(v0, uz=True, saturate=False) + b = search_float32_into_fe4m3(v1, uz=True, saturate=False) self.assertEqual(a, b) v0 = numpy.float32(-numpy.nan) v1 = numpy.float32(-numpy.inf) - a = search_float32_into_fe4m3(v0, uz=True) - b = search_float32_into_fe4m3(v1, uz=True) + a = search_float32_into_fe4m3(v0, uz=True, saturate=False) + b = search_float32_into_fe4m3(v1, uz=True, saturate=False) self.assertEqual(a, b) v0 = numpy.float32(numpy.nan) @@ -674,10 +672,11 @@ def test_float32_to_fe4m3fnuz_inf(self): v0 = numpy.float32(numpy.inf) v1 = numpy.float32(-numpy.inf) - a = search_float32_into_fe4m3(v0, uz=True) - b = search_float32_into_fe4m3(v1, uz=True) + a = search_float32_into_fe4m3(v0, uz=True, saturate=False) + b = search_float32_into_fe4m3(v1, uz=True, saturate=False) self.assertEqual(a, b) + def test_float32_to_fe4m3fnuz_inf(self): v0 = numpy.float32(numpy.nan) v1 = numpy.float32(-numpy.nan) a = float32_to_fe4m3(v0, uz=True) @@ -686,35 +685,35 @@ def test_float32_to_fe4m3fnuz_inf(self): v0 = numpy.float32(numpy.inf) v1 = numpy.float32(-numpy.inf) - a = float32_to_fe4m3(v0, uz=True) - b = float32_to_fe4m3(v1, uz=True) + a = float32_to_fe4m3(v0, uz=True, saturate=False) + b = float32_to_fe4m3(v1, uz=True, saturate=False) self.assertEqual(a, b) def test_float32_to_fe5m2fnuz_inf(self): mx = numpy.nan v0 = numpy.float32(mx) v1 = numpy.float32(numpy.inf) - a = search_float32_into_fe5m2(v0, fn=True, uz=True) - b = search_float32_into_fe5m2(v1, fn=True, uz=True) + a = search_float32_into_fe5m2(v0, fn=True, uz=True, saturate=False) + b = search_float32_into_fe5m2(v1, fn=True, uz=True, saturate=False) self.assertEqual(a, b) v0 = numpy.float32(mx) v1 = numpy.float32(numpy.inf) - a = float32_to_fe5m2(v0, fn=True, uz=True) - b = float32_to_fe5m2(v1, fn=True, uz=True) + a = float32_to_fe5m2(v0, fn=True, uz=True, saturate=False) + b = float32_to_fe5m2(v1, fn=True, uz=True, saturate=False) self.assertEqual(a, b) mi = numpy.nan v0 = numpy.float32(mi) v1 = numpy.float32(-numpy.inf) - a = search_float32_into_fe5m2(v0, fn=True, uz=True) - b = search_float32_into_fe5m2(v1, fn=True, uz=True) + a = search_float32_into_fe5m2(v0, fn=True, uz=True, saturate=False) + b = search_float32_into_fe5m2(v1, fn=True, uz=True, saturate=False) self.assertEqual(a, b) v0 = numpy.float32(mi) v1 = numpy.float32(-numpy.inf) - a = float32_to_fe5m2(v0, fn=True, uz=True) - b = float32_to_fe5m2(v1, fn=True, uz=True) + a = float32_to_fe5m2(v0, fn=True, uz=True, saturate=False) + b = float32_to_fe5m2(v1, fn=True, uz=True, saturate=False) self.assertEqual(a, b) v0 = numpy.float32(numpy.nan) @@ -725,8 +724,8 @@ def test_float32_to_fe5m2fnuz_inf(self): v0 = numpy.float32(numpy.inf) v1 = numpy.float32(-numpy.inf) - a = search_float32_into_fe5m2(v0, fn=True, uz=True) - b = search_float32_into_fe5m2(v1, fn=True, uz=True) + a = search_float32_into_fe5m2(v0, fn=True, uz=True, saturate=False) + b = search_float32_into_fe5m2(v1, fn=True, uz=True, saturate=False) self.assertEqual(a, b) v0 = numpy.float32(numpy.nan) @@ -737,8 +736,8 @@ def test_float32_to_fe5m2fnuz_inf(self): v0 = numpy.float32(numpy.inf) v1 = numpy.float32(-numpy.inf) - a = float32_to_fe5m2(v0, fn=True, uz=True) - b = float32_to_fe5m2(v1, fn=True, uz=True) + a = float32_to_fe5m2(v0, fn=True, uz=True, saturate=False) + b = float32_to_fe5m2(v1, fn=True, uz=True, saturate=False) self.assertEqual(a, b) def test_simple_fe4m3(self): @@ -764,8 +763,8 @@ def test_simple_fe4m3(self): def test_inf_nan_ml_dtypes(self): x = numpy.float32(numpy.inf) - g1 = float32_to_fe4m3(x) - g2 = float32_to_fe5m2(x) + g1 = float32_to_fe4m3(x, saturate=False) + g2 = float32_to_fe5m2(x, saturate=False) i1 = fe4m3_to_float32(g1) i2 = fe5m2_to_float32(g2) self.assertNotEqual(i1, 448) @@ -787,6 +786,283 @@ def test_inf_nan_ml_dtypes(self): self.assertTrue(numpy.isnan(m1)) self.assertTrue(numpy.isnan(m2)) + def test_float8_e4m3fn_inf(self): + x = numpy.float32(numpy.inf) + to = float32_to_fe4m3(x) + back = fe4m3_to_float32(to) + self.assertEqual(back, 448) + + x = numpy.float32(numpy.inf) + to = float32_to_fe4m3(x, saturate=False) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe4m3(x) + self.assertEqual(to & 0x80, 0x80) + back = fe4m3_to_float32(to) + self.assertEqual(back, -448) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe4m3(x, saturate=False) + self.assertEqual(to & 0x80, 0x80) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e4m3fnuz_inf(self): + x = numpy.float32(numpy.inf) + to = float32_to_fe4m3(x, uz=True) + back = fe4m3_to_float32(to, uz=True) + self.assertEqual(back, 224) + + x = numpy.float32(numpy.inf) + to = float32_to_fe4m3(x, uz=True, saturate=False) + back = fe4m3_to_float32(to, uz=True) + self.assertTrue(numpy.isnan(back)) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe4m3(x, uz=True) + back = fe4m3_to_float32(to, uz=True) + self.assertEqual(back, -224) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe4m3(x, uz=True, saturate=False) + back = fe4m3_to_float32(to, uz=True) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e5m2_inf(self): + x = numpy.float32(numpy.inf) + to = float32_to_fe5m2(x) + back = fe5m2_to_float32(to) + self.assertEqual(back, 57344) + + x = numpy.float32(numpy.inf) + to = float32_to_fe5m2(x, saturate=False) + back = fe5m2_to_float32(to) + self.assertTrue(numpy.isinf(back)) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe5m2(x) + self.assertEqual(to & 0x80, 0x80) + back = fe5m2_to_float32(to) + self.assertEqual(back, -57344) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe5m2(x, saturate=False) + self.assertEqual(to & 0x80, 0x80) + back = fe5m2_to_float32(to) + self.assertTrue(numpy.isinf(back)) + self.assertTrue(back < 0) + + def test_float8_e5m2fnuz_inf(self): + x = numpy.float32(numpy.inf) + to = float32_to_fe5m2(x, fn=True, uz=True) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertEqual(back, 57344) + + x = numpy.float32(numpy.inf) + to = float32_to_fe5m2(x, fn=True, uz=True, saturate=False) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertTrue(numpy.isnan(back)) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe5m2(x, fn=True, uz=True) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertEqual(back, -57344) + + x = numpy.float32(-numpy.inf) + to = float32_to_fe5m2(x, fn=True, uz=True, saturate=False) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e4m3fn_out_of_range(self): + x = numpy.float32(1000000) + to = float32_to_fe4m3(x) + back = fe4m3_to_float32(to) + self.assertEqual(back, 448) + + x = numpy.float32(1000000) + to = float32_to_fe4m3(x, saturate=False) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + x = numpy.float32(-1000000) + to = float32_to_fe4m3(x) + back = fe4m3_to_float32(to) + self.assertEqual(back, -448) + + x = numpy.float32(-1000000) + to = float32_to_fe4m3(x, saturate=False) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e4m3fnuz_out_of_range(self): + x = numpy.float32(1000000) + to = float32_to_fe4m3(x, uz=True) + back = fe4m3_to_float32(to, uz=True) + self.assertEqual(back, 240) + + x = numpy.float32(1000000) + to = float32_to_fe4m3(x, uz=True, saturate=False) + back = fe4m3_to_float32(to, uz=True) + self.assertTrue(numpy.isnan(back)) + + x = numpy.float32(-1000000) + to = float32_to_fe4m3(x, uz=True) + back = fe4m3_to_float32(to, uz=True) + self.assertEqual(back, -240) + + x = numpy.float32(-1000000) + to = float32_to_fe4m3(x, uz=True, saturate=False) + back = fe4m3_to_float32(to, uz=True) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e5m2_out_of_range(self): + x = numpy.float32(1000000) + to = float32_to_fe5m2(x) + back = fe5m2_to_float32(to) + self.assertEqual(back, 57344) + + x = numpy.float32(1000000) + to = float32_to_fe5m2(x, saturate=False) + back = fe5m2_to_float32(to) + self.assertTrue(numpy.isinf(back)) + + x = numpy.float32(-1000000) + to = float32_to_fe5m2(x) + back = fe5m2_to_float32(to) + self.assertEqual(back, -57344) + + x = numpy.float32(-1000000) + to = float32_to_fe5m2(x, saturate=False) + back = fe5m2_to_float32(to) + self.assertTrue(numpy.isinf(back)) + + def test_float8_e5m2fnuz_out_of_range(self): + x = numpy.float32(1000000) + to = float32_to_fe5m2(x, fn=True, uz=True) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertEqual(back, 57344) + + x = numpy.float32(1000000) + to = float32_to_fe5m2(x, fn=True, uz=True, saturate=False) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertTrue(numpy.isnan(back)) + + x = numpy.float32(-1000000) + to = float32_to_fe5m2(x, fn=True, uz=True) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertEqual(back, -57344) + + x = numpy.float32(-1000000) + to = float32_to_fe5m2(x, fn=True, uz=True, saturate=False) + back = fe5m2_to_float32(to, fn=True, uz=True) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e4m3fn_negative_zero(self): + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe4m3(x) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to) + self.assertEqual(back, 0) + + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe4m3(x, saturate=False) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to) + self.assertEqual(back, 0) + + def test_float8_e4m3fnuz_negative_zero(self): + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe4m3(x, uz=True) + self.assertEqual(to, 0) + back = fe4m3_to_float32(to, uz=True) + self.assertEqual(back, 0) + + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe4m3(x, uz=True, saturate=False) + back = fe4m3_to_float32(to, uz=True) + self.assertEqual(back, 0) + self.assertEqual(to, 0) + + def test_float8_e5m2_negative_zero(self): + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe5m2(x) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to) + self.assertEqual(back, 0) + + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe5m2(x, saturate=False) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to) + self.assertEqual(back, 0) + + def test_float8_e5m2fnuz_negative_zero(self): + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe5m2(x, fn=True, uz=True) + self.assertEqual(to, 0) + back = fe4m3_to_float32(to, fn=True, uz=True) + self.assertEqual(back, 0) + + x = fe5m2_to_float32(0x80) # -0 + to = float32_to_fe5m2(x, fn=True, uz=True, saturate=False) + self.assertEqual(to, 0) + back = fe4m3_to_float32(to, fn=True, uz=True) + self.assertEqual(back, 0) + + def test_float8_e4m3fn_negative_nan(self): + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe4m3(x) + self.assertEqual(to, 255) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe4m3(x, saturate=False) + self.assertEqual(to, 255) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e4m3fnuz_negative_nan(self): + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe4m3(x, uz=True) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to, uz=True) + self.assertTrue(numpy.isnan(back)) + + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe4m3(x, uz=True, saturate=False) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to, uz=True) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e5m2_negative_nan(self): + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe5m2(x) + self.assertEqual(to, 255) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe5m2(x, saturate=False) + self.assertEqual(to, 255) + back = fe4m3_to_float32(to) + self.assertTrue(numpy.isnan(back)) + + def test_float8_e5m2fnuz_negative_nan(self): + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe5m2(x, fn=True, uz=True) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to, fn=True, uz=True) + self.assertTrue(numpy.isnan(back)) + + x = fe5m2_to_float32(255) # -nan + to = float32_to_fe5m2(x, fn=True, uz=True, saturate=False) + self.assertEqual(to, 0x80) + back = fe4m3_to_float32(to, fn=True, uz=True) + self.assertTrue(numpy.isnan(back)) + if __name__ == "__main__": TestF8().test_search_float32_into_fe4m3fn_simple() diff --git a/onnx_array_api/validation/f8.py b/onnx_array_api/validation/f8.py index 2d7690f..43cf7e1 100644 --- a/onnx_array_api/validation/f8.py +++ b/onnx_array_api/validation/f8.py @@ -301,23 +301,17 @@ def fe5m2_to_float32(ival: int, fn: bool = False, uz: bool = False) -> float: # cast from float32 to float 8 -class CastFloat8: - """ - Helpers to cast float8 into float32 or the other way around. - """ - +class CastFloat8Sets: values_e4m3fn = list( sorted( (fe4m3_to_float32_float(i), i) for i in range(0, 256) if i not in (255, 127) ) ) - values_e4m3fnuz = list( sorted( (fe4m3_to_float32_float(i, uz=True), i) for i in range(0, 256) if i != 0x80 ) ) - values_e5m2 = list( sorted( (fe5m2_to_float32_float(i), i) @@ -325,7 +319,6 @@ class CastFloat8: if i not in {253, 254, 255, 125, 126, 127} ) ) - values_e5m2fnuz = list( sorted( (fe5m2_to_float32_float(i, fn=True, uz=True), i) @@ -334,6 +327,33 @@ class CastFloat8: ) ) + +class CastFloat8(CastFloat8Sets): + """ + Helpers to cast float8 into float32 or the other way around. + """ + + values_e4m3fn_max_value = max( + v + for v in CastFloat8Sets.values_e4m3fn + if not numpy.isinf(v[0]) and not numpy.isnan(v[0]) + ) + values_e4m3fnuz_max_value = max( + v + for v in CastFloat8Sets.values_e4m3fnuz + if not numpy.isinf(v[0]) and not numpy.isnan(v[0]) + ) + values_e5m2_max_value = max( + v + for v in CastFloat8Sets.values_e5m2 + if not numpy.isinf(v[0]) and not numpy.isnan(v[0]) + ) + values_e5m2fnuz_max_value = max( + v + for v in CastFloat8Sets.values_e5m2fnuz + if not numpy.isinf(v[0]) and not numpy.isnan(v[0]) + ) + @staticmethod def find_closest_value(value, sorted_values): """ @@ -371,13 +391,16 @@ def find_closest_value(value, sorted_values): return sorted_values[a][1] -def search_float32_into_fe4m3(value: float, fn: bool = True, uz: bool = False) -> int: +def search_float32_into_fe4m3( + value: float, fn: bool = True, uz: bool = False, saturate: bool = True +) -> int: """ Casts a float 32 into a float E4M3. :param value: float :param fn: no infinite values :param uz: no negative zero + :param saturate: to convert out of range and infinities to max value if True :return: byte """ if not fn: @@ -386,38 +409,65 @@ def search_float32_into_fe4m3(value: float, fn: bool = True, uz: bool = False) - b = int.from_bytes(struct.pack("> 24 # sign if uz: - if numpy.isnan(value) or numpy.isinf(value): + if numpy.isnan(value): + return 0x80 + if numpy.isinf(value) and not saturate: return 0x80 set_values = CastFloat8.values_e4m3fnuz + max_value = CastFloat8.values_e4m3fnuz_max_value + if value > max_value[0]: + return max_value[1] if saturate else 0x80 + if value < -max_value[0]: + return (max_value[1] | ret) if saturate else 0x80 else: if numpy.isnan(value) or numpy.isinf(value): return 0x7F | ret set_values = CastFloat8.values_e4m3fn + max_value = CastFloat8.values_e4m3fn_max_value + if value > max_value[0]: + return max_value[1] if saturate else 0x7F | ret + if value < -max_value[0]: + return (max_value[1] | ret) if saturate else 0x7F | ret f = numpy.float32(value) i = CastFloat8.find_closest_value(f, set_values) return (i & 0x7F) | ret -def search_float32_into_fe5m2(value: float, fn: bool = False, uz: bool = False) -> int: +def search_float32_into_fe5m2( + value: float, fn: bool = False, uz: bool = False, saturate: bool = True +) -> int: """ Casts a float 32 into a float E5M2. :param value: float :param fn: no infinite values :param uz: no negative zero + :param saturate: to convert out of range and infinities to max value if True :return: byte """ b = int.from_bytes(struct.pack("> 24 # sign if fn and uz: - if numpy.isnan(value) or numpy.isinf(value): + if numpy.isnan(value): + return 0x80 + if numpy.isinf(value) and not saturate: return 0x80 set_values = CastFloat8.values_e5m2fnuz + max_value = CastFloat8.values_e5m2fnuz_max_value + if value > max_value[0]: + return max_value[1] if saturate else 0x80 + if value < -max_value[0]: + return (max_value[1] | ret) if saturate else 0x80 elif not fn and not uz: if numpy.isnan(value): return 0x7F | ret set_values = CastFloat8.values_e5m2 + max_value = CastFloat8.values_e5m2_max_value + if value > max_value[0]: + return max_value[1] if saturate else (0x7C | ret) + if value < -max_value[0]: + return (max_value[1] | ret) if saturate else (0x7C | ret) else: raise NotImplementedError("fn and uz must both True or False.") @@ -426,13 +476,14 @@ def search_float32_into_fe5m2(value: float, fn: bool = False, uz: bool = False) return (i & 0x7F) | ret -def float32_to_fe4m3(x, fn: bool = True, uz: bool = False): +def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True): """ Converts a float32 into a float E4M3. :param x: numpy.float32 :param fn: no infinite values :param uz: no negative zero + :param saturate: to convert out of range and infinities to max value if True :return: byte """ if not fn: @@ -440,7 +491,11 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False): b = int.from_bytes(struct.pack("> 24 # sign if uz: - if (b & 0x7FC00000) == 0x7FC00000 or numpy.isinf(x): + if (b & 0x7FC00000) == 0x7FC00000: + return 0x80 + if numpy.isinf(x): + if saturate: + return ret | 126 return 0x80 e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa @@ -468,14 +523,26 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False): else: ret |= ex << 3 ret |= m >> 20 - if (m & 0x80000) and (ret & 0x7F) < 0x7F: - # rounding - ret += 1 - else: + if m & 0x80000: + if (ret & 0x7F) < 0x7F: + # rounding + ret += 1 + elif not saturate: + return 0x80 + elif saturate: ret |= 0x7F # 01111110 + else: + ret = 0x80 + elif m == 0: + # -0 + ret = 0 return int(ret) else: - if (b & 0x7FC00000) == 0x7FC00000 or numpy.isinf(x): + if (b & 0x7FC00000) == 0x7FC00000: + return 0x7F | ret + if numpy.isinf(x): + if saturate: + return ret | 126 return 0x7F | ret e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa @@ -505,30 +572,39 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False): ret |= m >> 20 if (ret & 0x7F) == 0x7F: ret &= 0xFE - if (m & 0x80000) and (ret & 0x7F) < 0x7E: - # rounding - ret += 1 - else: + if m & 0x80000: + if (ret & 0x7F) < 0x7E: + # rounding + ret += 1 + elif not saturate: + ret |= 0x7F + elif saturate: ret |= 126 # 01111110 + else: + ret |= 0x7F return int(ret) -def float32_to_fe5m2(x, fn: bool = False, uz: bool = False): +def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = True): """ Converts a float32 into a float E5M2. :param x: numpy.float32 :param fn: no infinite values :param uz: no negative zero + :param saturate: to convert out of range and infinities to max value if True :return: byte """ b = int.from_bytes(struct.pack("> 24 # sign if fn and uz: - if (b & 0x7FC00000) == 0x7FC00000: # NaN + if (b & 0x7FC00000) == 0x7FC00000: return 0x80 - if (b & 0x7FFFFFFF) == 0x7F800000: # Inf + if (b & 0x7FFFFFFF) == 0x7F800000: + # inf + if saturate: + return ret | 0x7F return 0x80 e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa @@ -553,17 +629,29 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False): ex = e - 111 # 127 - 16 ret |= ex << 2 ret |= m >> 21 - if (m & 0x100000) and (ret & 0x7F) < 0x7F: - # rounding - ret += 1 + if m & 0x100000: + if (ret & 0x7F) < 0x7F: + # rounding + ret += 1 + elif not saturate: + ret = 0x80 elif e == 255 and m == 0: # inf - return 0x80 - else: + ret = 0x80 + elif saturate: ret |= 0x7F # last possible number + else: + ret = 0x80 + elif m == 0: + # -0 + ret = 0 return int(ret) elif not fn and not uz: if (b & 0x7FC00000) == 0x7FC00000: return 0x7F | ret + if numpy.isinf(x): + if saturate: + return 0x7B | ret + return 0x7C | ret e = (b & 0x7F800000) >> 23 # exponent m = b & 0x007FFFFF # mantissa @@ -587,13 +675,18 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False): ex = e - 112 # 127 - 15 ret |= ex << 2 ret |= m >> 21 - if (m & 0x100000) and (ret & 0x7F) < 0x7B: - # rounding - ret += 1 - elif e == 255 and m == 0: # inf - ret |= 124 + if m & 0x100000: + if (ret & 0x7F) < 0x7B: + # rounding + ret += 1 + elif saturate: + ret |= 0x7B + else: + ret |= 0x7C + elif saturate: + ret |= 0x7B else: - ret |= 123 + ret |= 0x7C return int(ret) else: raise NotImplementedError("fn and uz must be both False or True.")