diff --git a/_doc/examples/plot_f8.py b/_doc/examples/plot_f8.py index 5718105..cf66066 100644 --- a/_doc/examples/plot_f8.py +++ b/_doc/examples/plot_f8.py @@ -4,13 +4,83 @@ float 8 ======= -Precision is not that important when it comes to train -a deep neural network. That's what the paper -`FP8 Formats for Deep Learning `_ -shows. It introduces two new types encoded on one byte: +Two papers have been published in 2022 to introduce floats +stored on a byte as opposed to float 32 stored on 4 bytes. +The float precision is much lower but the training precision +does not suffer too much. -* E4M3: 1 bit for the sign, 4 for the exponent, 3 for the mantissa -* E5M2: 1 bit for the sign, 5 for the exponent, 2 for the mantissa +`FP8 Formats for Deep Learning `_ +from NVIDIA introduces two types following +`IEEE specifications `_. +First one is E4M3, 1 bit for the sign, 4 bits for the exponent and 3 +bits for the mantissa. Second one is E5M2, 1 bit for the sign, +5 bits for the exponent and 2 for the mantissa. The first type +is mostly used for the coefficients, the second one for the gradient. + +Second paper `8-bit Numerical Formats For Deep Neural Networks +`_ introduces +similar types. The IEEE standard gives the same value +to `+0` (or integer 0) and `-0` (or integer 128). +They chose to give distinct float values to these two +numbers. The paper experiments with different splits between +exponent and mantissa and shows that E4M3 and E5M2 are +the best ones. + +:math:`S` stands for the sign. :math:`10_2` describes a number in base 2. + +.. 
list-table:: Float8 types + :widths: 10 10 10 + :header-rows: 1 + + * - + - E4M3 + - E5M2 + * - Exponent bias + - 7 + - 15 + * - Infinities + - + - :math:`S.11111.00_2` + * - NaN + - :math:`S.1111.111_2` + - :math:`S.11111.\{01, 10, 11\}_2` + * - Zeros + - :math:`S.0000.000_2` + - :math:`S.00000.00_2` + * - Max + - :math:`S.1111.110_2` + - :math:`1.75 \times 2^{15}= 57344` + * - Min + - :math:`S.0000.001_2 = 2^{-9}` + - :math:`S.00000.01_2 = 2^{-16}` + + +Let's denote the bit representation as :math:`S.b_6 b_5 b_4 b_3 b_2 b_1 b_0`. +The float value is defined by the following expressions: + +.. list-table:: Float8 types values + :widths: 10 10 10 + :header-rows: 1 + + * - + - E4M3 + - E5M2 + * - exponent :math:`\neq` 0 + - :math:`(-1)^S 2^{\sum_{i=3}^6 b_i 2^{i-3} - 7} \left(1 + \sum_{i=0}^2 b_i 2^{i-3}\right)` + - :math:`(-1)^S 2^{\sum_{i=2}^6 b_i 2^{i-2} - 15} \left(1 + \sum_{i=0}^1 b_i 2^{i-2}\right)` + * - exponent :math:`=` 0 + - :math:`(-1)^S 2^{-6} \sum_{i=0}^2 b_i 2^{i-3}` + - :math:`(-1)^S 2^{-14} \sum_{i=0}^1 b_i 2^{i-2}` + +Cast from float 8 to +`float 16 `_ (or E5M10), +`bfloat16 `_ (or E8M7), +`float32 `_ (or E8M23) is easier. +The cast is exact. The tricky part is to distinguish between exponent = 0 and :math:`\neq 0`. + +Cast to float 8 consists in finding the closest float 8 +to the original float 32 value. It is usually done by shifting +and truncating. The tricky part is to handle rounding. .. 
index:: discrepencies, float8, float, E4M3, E5M2 diff --git a/onnxcustom/experiment/f8.py b/onnxcustom/experiment/f8.py index 0d115e9..9ece04f 100644 --- a/onnxcustom/experiment/f8.py +++ b/onnxcustom/experiment/f8.py @@ -194,16 +194,11 @@ def fe4m3_to_float32(ival: int) -> float: mant &= 0x3 mant <<= 1 expo -= 1 - if mant & 0x4 == 0: - mant &= 0x3 - mant <<= 1 - expo -= 1 res |= (mant & 0x3) << 21 res |= expo << 23 else: res |= mant << 20 - expo -= 0x7 - expo += 0x7F + expo += 0x7F - 7 res |= expo << 23 f = numpy.uint32(res).view(numpy.float32) # pylint: disable=E1121 return f @@ -232,10 +227,6 @@ def fe5m2_to_float32(ival: int) -> float: if expo == 0: if mant > 0: expo = 0x7F - 15 - if mant & 0x2 == 0: - mant &= 0x1 - mant <<= 1 - expo -= 1 if mant & 0x2 == 0: mant &= 0x1 mant <<= 1 @@ -244,8 +235,7 @@ def fe5m2_to_float32(ival: int) -> float: res |= expo << 23 else: res |= mant << 21 - expo -= 15 - expo += 0x7F + expo += 0x7F - 15 res |= expo << 23 f = numpy.uint32(res).view(numpy.float32) # pylint: disable=E1121 return f @@ -341,7 +331,7 @@ def float32_to_fe4m3(x): if e != 0: if e < 117: pass - elif e < 118: + if e < 118: ret |= 1 if (m >> 23) & 1: # rounding