82 changes: 76 additions & 6 deletions _doc/examples/plot_f8.py
@@ -4,13 +4,83 @@
float 8
=======

-Precision is not that important when it comes to train
-a deep neural network. That's what the paper
-`FP8 Formats for Deep Learning <https://arxiv.org/abs/2209.05433>`_
-shows. It introduces two new types encoded on one byte:
Two papers published in 2022 introduce floats stored on a single byte,
as opposed to float 32 stored on 4 bytes. The float precision is much
lower but the training accuracy does not suffer much.

-* E4M3: 1 bit for the sign, 4 for the exponent, 3 for the mantissa
-* E5M2: 1 bit for the sign, 5 for the exponent, 2 for the mantissa
`FP8 Formats for Deep Learning <https://arxiv.org/abs/2209.05433>`_
from NVIDIA introduces two types following the
`IEEE specifications <https://en.wikipedia.org/wiki/IEEE_754>`_.
The first one is E4M3: 1 bit for the sign, 4 bits for the exponent and 3
bits for the mantissa. The second one is E5M2: 1 bit for the sign,
5 bits for the exponent and 2 for the mantissa. The first type
is mostly used for the coefficients (weights), the second one for the gradients.

The second paper, `8-bit Numerical Formats For Deep Neural Networks
<https://arxiv.org/pdf/2206.02915.pdf>`_, introduces
similar types. The IEEE standard encodes the value zero twice,
as `+0` (integer 0) and `-0` (integer 128); the authors chose
instead to give distinct float values to these two encodings.
The paper experiments with different splits between
exponent and mantissa and shows that E4M3 and E5M2 are
the best ones.

:math:`S` stands for the sign bit. :math:`10_2` describes a number written in base 2.

.. list-table:: Float8 types
   :widths: 10 10 10
   :header-rows: 1

   * -
     - E4M3
     - E5M2
   * - Exponent bias
     - 7
     - 15
   * - Infinities
     -
     - :math:`S.11111.00_2`
   * - NaN
     - :math:`S.1111.111_2`
     - :math:`S.11111.\{01, 10, 11\}_2`
   * - Zeros
     - :math:`S.0000.000_2`
     - :math:`S.00000.00_2`
   * - Max
     - :math:`S.1111.110_2 = 448`
     - :math:`S.11110.11_2 = 1.75 \times 2^{15} = 57344`
   * - Min
     - :math:`S.0000.001_2 = 2^{-9}`
     - :math:`S.00000.01_2 = 2^{-16}`

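As a quick sanity check on the table, the extreme values can be recomputed
from the bias and the mantissa width; the variable names below are only
illustrative:

```python
# E4M3: bias 7, largest usable exponent 15 (1111, mantissa 111 is NaN)
e4m3_max = (1 + 6 / 8) * 2 ** (15 - 7)
# E5M2: bias 15, largest finite exponent 30 (11111 encodes inf/NaN)
e5m2_max = (1 + 3 / 4) * 2 ** (30 - 15)
# smallest subnormals: smallest exponent times the smallest mantissa step
e4m3_min = 2 ** (1 - 7) * 2 ** -3   # = 2 ** -9
e5m2_min = 2 ** (1 - 15) * 2 ** -2  # = 2 ** -16
print(e4m3_max, e5m2_max)  # 448.0 57344.0
```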

Let's denote the bit representation as :math:`S.b_6 b_5 b_4 b_3 b_2 b_1 b_0`.
The float value is defined by the following expressions:

.. list-table:: Float8 values
   :widths: 10 10 10
   :header-rows: 1

   * -
     - E4M3
     - E5M2
   * - exponent :math:`\neq` 0
     - :math:`(-1)^S 2^{\sum_{i=3}^6 b_i 2^{i-3} - 7} \left(1 + \sum_{i=0}^2 b_i 2^{i-3}\right)`
     - :math:`(-1)^S 2^{\sum_{i=2}^6 b_i 2^{i-2} - 15} \left(1 + \sum_{i=0}^1 b_i 2^{i-2}\right)`
   * - exponent :math:`=` 0
     - :math:`(-1)^S 2^{-6} \sum_{i=0}^2 b_i 2^{i-3}`
     - :math:`(-1)^S 2^{-14} \sum_{i=0}^1 b_i 2^{i-2}`

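These expressions translate directly into a small decoder. The function
below is a hypothetical helper applying them literally, not the library's
``fe4m3_to_float32``:

```python
import math

def e4m3_value(ival: int) -> float:
    """Decode one E4M3 byte with the formulas above (illustrative only)."""
    sign = -1.0 if ival & 0x80 else 1.0
    expo = (ival >> 3) & 0x0F  # bits b6..b3
    mant = ival & 0x07         # bits b2..b0
    if expo == 0x0F and mant == 0x07:
        return math.nan        # S.1111.111 is reserved for NaN
    if expo == 0:
        # subnormal: no implicit leading 1, fixed exponent -6
        return sign * 2.0 ** -6 * (mant / 8.0)
    # normal: implicit leading 1, bias 7
    return sign * 2.0 ** (expo - 7) * (1.0 + mant / 8.0)

print(e4m3_value(0x7E))  # 448.0, the largest E4M3 value
print(e4m3_value(0x01))  # 0.001953125, the smallest subnormal (2 ** -9)
```
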
Casting from float 8 to
`float 16 <https://en.wikipedia.org/wiki/Half-precision_floating-point_format>`_ (or E5M10),
`bfloat16 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_ (or E8M7), or
`float32 <https://en.wikipedia.org/wiki/Single-precision_floating-point_format>`_ (or E8M23) is easy:
the cast is exact because the wider types have more exponent and mantissa bits.
The tricky part is to distinguish between exponent :math:`= 0` and exponent :math:`\neq 0`.

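That exactness can be checked for E4M3 with the standard library alone,
assuming the literal decoder below (``struct``'s ``"e"`` format is IEEE
half precision):

```python
import struct

def e4m3_value(ival: int) -> float:
    # literal E4M3 decoder following the formulas above (illustrative)
    sign = -1.0 if ival & 0x80 else 1.0
    expo = (ival >> 3) & 0x0F
    mant = ival & 0x07
    if expo == 0:
        return sign * 2.0 ** -6 * (mant / 8.0)
    return sign * 2.0 ** (expo - 7) * (1.0 + mant / 8.0)

# skip the two NaN codes, then round trip through float 16 ("e" format)
finite = [i for i in range(256) if (i & 0x7F) != 0x7F]
exact = all(
    struct.unpack("e", struct.pack("e", e4m3_value(i)))[0] == e4m3_value(i)
    for i in finite
)
print(exact)  # True: the upcast loses nothing
```
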
Casting to float 8 consists of finding the closest float 8 value
to the original float 32 value. It is usually done by shifting
and truncating the mantissa. The tricky part is to handle rounding.
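A brute-force nearest-value search makes the rounding target explicit.
This is a slow sketch for checking results, not the shift-and-truncate
code from the library, and it does not implement round-to-even for ties:

```python
def e4m3_value(ival: int) -> float:
    # literal E4M3 decoder following the formulas above (illustrative)
    sign = -1.0 if ival & 0x80 else 1.0
    expo = (ival >> 3) & 0x0F
    mant = ival & 0x07
    if expo == 0:
        return sign * 2.0 ** -6 * (mant / 8.0)
    return sign * 2.0 ** (expo - 7) * (1.0 + mant / 8.0)

def float32_to_e4m3_nearest(x: float) -> int:
    """Return the E4M3 code whose value is closest to x (brute force)."""
    codes = [i for i in range(256) if (i & 0x7F) != 0x7F]  # skip NaN codes
    return min(codes, key=lambda i: abs(e4m3_value(i) - x))

print(float32_to_e4m3_nearest(1.0))  # 56 == 0x38: sign 0, exponent 7
print(float32_to_e4m3_nearest(1e6))  # 126 == 0x7E: saturates to 448
```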

.. index:: discrepancies, float8, float, E4M3, E5M2

16 changes: 3 additions & 13 deletions onnxcustom/experiment/f8.py
@@ -194,16 +194,11 @@ def fe4m3_to_float32(ival: int) -> float:
                mant &= 0x3
                mant <<= 1
                expo -= 1
-            if mant & 0x4 == 0:
-                mant &= 0x3
-                mant <<= 1
-                expo -= 1
            res |= (mant & 0x3) << 21
            res |= expo << 23
    else:
        res |= mant << 20
-        expo -= 0x7
-        expo += 0x7F
+        expo += 0x7F - 7
        res |= expo << 23
    f = numpy.uint32(res).view(numpy.float32)  # pylint: disable=E1121
    return f
@@ -232,10 +227,6 @@ def fe5m2_to_float32(ival: int) -> float:
    if expo == 0:
        if mant > 0:
            expo = 0x7F - 15
-            if mant & 0x2 == 0:
-                mant &= 0x1
-                mant <<= 1
-                expo -= 1
            if mant & 0x2 == 0:
                mant &= 0x1
                mant <<= 1
@@ -244,8 +235,7 @@
            res |= expo << 23
    else:
        res |= mant << 21
-        expo -= 15
-        expo += 0x7F
+        expo += 0x7F - 15
        res |= expo << 23
    f = numpy.uint32(res).view(numpy.float32)  # pylint: disable=E1121
    return f
@@ -341,7 +331,7 @@ def float32_to_fe4m3(x):
    if e != 0:
        if e < 117:
            pass
-        elif e < 118:
+        if e < 118:
            ret |= 1
            if (m >> 23) & 1:
                # rounding