Connected to py39 (Python 3.9.19)

In [7]:
import torch
import numpy as np

In [15]:
# Construct the full range of bfloat16
def bits_to_bfloat16(x):
    assert x.dtype == torch.int32
    x = x.to(dtype=torch.uint16)
    return x.view(dtype=torch.bfloat16)

# Enumerate every 16-bit pattern (0 .. 0xFFFF) and reinterpret each one as bfloat16.
y = bits_to_bfloat16(torch.arange(1 << 16, dtype=torch.int32))
print(y.size(), y)

torch.Size([65536]) tensor([0.0000e+00, 9.1835e-41, 1.8367e-40,  ...,        nan,        nan,
               nan], dtype=torch.bfloat16)


In [14]:
def bfloat16_to_bits(x):
    x = x.view(dtype=torch.uint16)
    return x.to(dtype=torch.int32)
# Inverse direction: recover the raw bit pattern of every bfloat16 value in `y`.
z = bfloat16_to_bits(y)
print(z.size(), z)

torch.Size([65536]) tensor([    0,     1,     2,  ..., 65533, 65534, 65535], dtype=torch.int32)


In [16]:
# Enumerate every 8-bit pattern (0 .. 0xFF), stored as int32 because the
# decompose_8bit_* helpers assert that dtype.
x = torch.arange(256, dtype=torch.int32)
print(x.size(), x)

torch.Size([256]) tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175

In [17]:
def decompose_16bit(x):
    """Split 16-bit patterns into (sign, exponent, mantissa) bit fields.

    Field layout, most significant bit first: 1 sign bit, 8 exponent bits,
    7 mantissa bits.

    Args:
        x: ``torch.int32`` tensor of bit patterns.

    Returns:
        Tuple ``(sign, exponent, mantissa)`` of int32 tensors holding the raw
        (still biased) field values.

    Raises:
        AssertionError: if ``x`` is not an int32 tensor.
    """
    assert x.dtype == torch.int32
    sign_bit = (x >> 15) & 0x1
    exponent_bits = (x >> 7) & 0xFF
    mantissa_bits = x & 0x7F
    return sign_bit, exponent_bits, mantissa_bits

def decompose_8bit_e4m3(x):
    """Split 8-bit patterns into (sign, exponent, mantissa) for the e4m3 layout.

    Field layout, most significant bit first: 1 sign bit, 4 exponent bits,
    3 mantissa bits.

    Args:
        x: ``torch.int32`` tensor of bit patterns.

    Returns:
        Tuple ``(sign, exponent, mantissa)`` of int32 tensors holding the raw
        field values.

    Raises:
        AssertionError: if ``x`` is not an int32 tensor.
    """
    assert x.dtype == torch.int32
    sign_bit = (x >> 7) & 0x1
    exponent_bits = (x >> 3) & 0xF
    mantissa_bits = x & 0x7
    return sign_bit, exponent_bits, mantissa_bits

def decompose_8bit_e5m2(x):
    """Split 8-bit patterns into (sign, exponent, mantissa) for the e5m2 layout.

    Field layout, most significant bit first: 1 sign bit, 5 exponent bits,
    2 mantissa bits.

    Args:
        x: ``torch.int32`` tensor of bit patterns.

    Returns:
        Tuple ``(sign, exponent, mantissa)`` of int32 tensors holding the raw
        field values.

    Raises:
        AssertionError: if ``x`` is not an int32 tensor.
    """
    assert x.dtype == torch.int32
    sign_bit = (x >> 7) & 0x1
    exponent_bits = (x >> 2) & 0x1F
    mantissa_bits = x & 0x3
    return sign_bit, exponent_bits, mantissa_bits


In [18]:
# Show the (sign, exponent, mantissa) fields of every pattern under each layout.
for split_fields, patterns in (
    (decompose_16bit, z),
    (decompose_8bit_e4m3, x),
    (decompose_8bit_e5m2, x),
):
    print(*split_fields(patterns))

tensor([0, 0, 0,  ..., 1, 1, 1], dtype=torch.int32) tensor([  0,   0,   0,  ..., 255, 255, 255], dtype=torch.int32) tensor([  0,   1,   2,  ..., 125, 126, 127], dtype=torch.int32)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 

In [19]:
def compose_16bit(sign, exponent, mantissa):
    """Pack (sign, exponent, mantissa) fields into a 16-bit pattern.

    The sign is placed at bit 15, the exponent at bits 14..7 and the mantissa
    at bits 6..0. Fields are combined by addition, so a field value wider than
    its slot carries into the neighbouring field; pass in-range values.

    Args:
        sign: sign field (0 or 1).
        exponent: raw 8-bit exponent field.
        mantissa: raw 7-bit mantissa field.

    Returns:
        The packed bit pattern, of the same kind as the inputs.
    """
    sign_field = sign << 15
    exponent_field = exponent << 7
    return sign_field + exponent_field + mantissa

# Round-trip check: splitting every pattern into fields and packing them
# again must reproduce the original bits exactly.
sign, exponent, mantissa = decompose_16bit(z)
bits = compose_16bit(sign, exponent, mantissa)
roundtrip_ok = torch.equal(bits, z)
print(roundtrip_ok)

True


In [20]:
# bfloat16 -> bits -> bfloat16 must reproduce every value.
y_ = bits_to_bfloat16(bfloat16_to_bits(y))
# NaN never compares equal to itself, so positions where both sides are NaN
# are counted as matching via equal_nan.
comparison = torch.isclose(y, y_, equal_nan=True)
comparison.all()

tensor(True)

In [119]:
# Full pipeline: bfloat16 -> bits -> (sign, exponent, mantissa) -> bits -> bfloat16.
y_bits = bfloat16_to_bits(y)
components = decompose_16bit(y_bits)
y_bits_ = compose_16bit(*components)
y_ = bits_to_bfloat16(y_bits_)

# Positions where both sides are NaN are counted as matching via equal_nan.
comparison = torch.isclose(y, y_, equal_nan=True)
comparison.all()

tensor(True)

In [62]:
# Construct infinity: all-ones exponent field with a zero mantissa
s, e, m = (torch.tensor([field], dtype=torch.int32) for field in (0, 0xFF, 0))
n = bits_to_bfloat16(compose_16bit(s, e, m))
print(n)

tensor([inf], dtype=torch.bfloat16)


In [63]:
# Construct nan: all-ones exponent field with a non-zero mantissa
s, e, m = (torch.tensor([field], dtype=torch.int32) for field in (0, 0xFF, 1))
n = bits_to_bfloat16(compose_16bit(s, e, m))
print(n)

tensor([nan], dtype=torch.bfloat16)


In [75]:
# Worked example: a single bfloat16 value shown next to its raw bit pattern.
t = torch.tensor([0.3125], dtype=torch.bfloat16)
t_bits = bfloat16_to_bits(t)
print(t, bin(t_bits))

tensor([0.3125], dtype=torch.bfloat16) 0b11111010100000


In [93]:
# Split the example value into its fields, print each field in binary, then
# pack the fields again to confirm the value is recovered.
components = decompose_16bit(bfloat16_to_bits(t))
print(*(bin(field) for field in components))
print(bits_to_bfloat16(compose_16bit(*components)))

0b0 0b1111101 0b100000
tensor([0.3125], dtype=torch.bfloat16)


In [132]:
def encode_as_e5m2(s, e, m):
    """Quantize bfloat16 (sign, exponent, mantissa) fields to an e5m2 byte.

    The exponent is re-biased from the bfloat16 bias (127) to the e5m2 bias
    (15) and the 7-bit mantissa is chopped (truncated toward zero) to 2 bits.
    Values that do not fit the 5-bit exponent are handled explicitly instead
    of wrapping around:

    * zero and values below the e5m2 normal range become e5m2 zero/subnormals
      (exponent field 0), with the mantissa truncated;
    * finite values too large for e5m2 saturate to the largest finite
      magnitude (exponent field 30, mantissa 3), consistent with truncation;
    * infinity and NaN (bfloat16 exponent field 255) map to e5m2 exponent
      field 31; a NaN always keeps a non-zero mantissa so it stays a NaN.

    Args:
        s: sign field (0 or 1), integer tensor.
        e: raw 8-bit bfloat16 exponent field, integer tensor.
        m: raw 7-bit bfloat16 mantissa field, integer tensor.

    Returns:
        Integer tensor holding the 8-bit e5m2 pattern (sign at bit 7,
        exponent at bits 6..2, mantissa at bits 1..0).
    """
    # Accept plain Python ints as well as tensors.
    s, e, m = (torch.as_tensor(field) for field in (s, e, m))

    # Re-bias the exponent: remove the bfloat16 bias, add the e5m2 bias.
    e8 = e - 127 + 15
    # Chop the mantissa to its two most significant bits.
    m8 = m >> 5

    is_special = e == 0xFF                   # bfloat16 inf / nan
    is_nan = is_special & (m != 0)
    too_large = (e8 >= 0x1F) & ~is_special   # finite, above the e5m2 range
    too_small = e8 <= 0                      # zero or below the normal range

    # Below the normal range the result has exponent field 0 and a mantissa
    # counting multiples of 2**-16: shift the significand (implicit leading 1
    # restored) right, truncating. The shift is capped at 8, which already
    # clears all 8 significand bits, so smaller inputs flush to zero.
    significand = m | 0x80
    shift = torch.clamp(6 - e8, min=0, max=8)
    m8 = torch.where(too_small, significand >> shift, m8)
    e8 = torch.where(too_small, torch.zeros_like(e8), e8)

    # Saturate finite overflow to the largest finite e5m2 magnitude.
    e8 = torch.where(too_large, torch.full_like(e8, 0x1E), e8)
    m8 = torch.where(too_large, torch.full_like(m8, 0x3), m8)

    # Preserve inf / nan; chopping must not turn a nan payload into zero.
    e8 = torch.where(is_special, torch.full_like(e8, 0x1F), e8)
    m8 = torch.where(is_nan, torch.clamp(m8, min=1), m8)

    return (s << 7) + (e8 << 2) + m8

# Encode the current `components` as e5m2 and display the resulting byte in binary.
sign16, exponent16, mantissa16 = components
encoded = encode_as_e5m2(sign16, exponent16, mantissa16)
bin(encoded)

'0b110101'

In [143]:
# Inspect the e5m2 fields of the encoded byte, each printed in binary.
s, e, m = decompose_8bit_e5m2(encoded)
print(*map(bin, (s, e, m)))

0b0 0b1101 0b1


0b1101 0b1111101


In [159]:
def decode_from_e5m2(encoded):
    """Expand e5m2 bit patterns into the equivalent bfloat16 bit patterns.

    Every e5m2 value is exactly representable in bfloat16, so decoding is
    lossless. The exponent is re-biased from 15 to 127 and the 2-bit mantissa
    is widened to 7 bits. The two reserved exponent fields are handled
    explicitly rather than re-biased:

    * exponent field 0 (zero / subnormal, value = mantissa * 2**-16) becomes
      bfloat16 zero or the matching bfloat16 normal number;
    * exponent field 31 (inf / nan) becomes bfloat16 exponent field 255.

    Args:
        encoded: ``torch.int32`` tensor of e5m2 bit patterns.

    Returns:
        int32 tensor of 16-bit patterns in bfloat16 layout (sign at bit 15,
        exponent at bits 14..7, mantissa at bits 6..0).

    Raises:
        AssertionError: from ``decompose_8bit_e5m2`` if ``encoded`` is not int32.
    """
    s, e8, m8 = decompose_8bit_e5m2(encoded)

    # Normal numbers: update the exponent bias and expand the mantissa.
    e16 = e8 + (127 - 15)
    m16 = m8 << 5

    # Exponent field 0 has no implicit leading one:
    #   m8 == 0 -> zero          (exponent 0,   mantissa 0)
    #   m8 == 1 -> 2**-16        (exponent 111, mantissa 0)
    #   m8 == 2 -> 2**-15        (exponent 112, mantissa 0)
    #   m8 == 3 -> 1.5 * 2**-15  (exponent 112, mantissa 0b100_0000)
    is_tiny = e8 == 0
    tiny_e16 = torch.where(m8 == 0, torch.zeros_like(e8), 111 + (m8 >> 1))
    tiny_m16 = torch.where(m8 == 3, torch.full_like(m8, 0x40), torch.zeros_like(m8))
    e16 = torch.where(is_tiny, tiny_e16, e16)
    m16 = torch.where(is_tiny, tiny_m16, m16)

    # Exponent field 31 is inf / nan; the widened mantissa stays zero for inf
    # and non-zero for nan.
    e16 = torch.where(e8 == 0x1F, torch.full_like(e16, 0xFF), e16)

    return (s << 15) + (e16 << 7) + m16

# Map the e5m2 byte back to a bfloat16 bit pattern and show it as a float.
decoded = decode_from_e5m2(encoded)
decoded_value = bits_to_bfloat16(decoded)
decoded_value

tensor([0.3125], dtype=torch.bfloat16)

At this point we are able to:
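- Construct the full range of bit patterns for bfloat16 (16-bit) and uint8 (8-bit)
- Convert bfloat16 to bits and back
- Convert bfloat16 to sign, exponent, mantissa and back
- Test that the conversions round-trip correctly (NaN values are treated as matching each other, since NaN != NaN)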

Next:
- Round by chopping (truncating the low mantissa bits)
- Handle the exponent bias for bfloat16, e4m3, e5m2