constants.py
import torch

# This is conceptually an enum of non-core dtypes
# TODO(future PR): change to a cleaner way to represent this without
# regressing torch.compile and while keeping things readable.
DTYPE_FP4 = "fp4_e2m1"
DTYPE_FP6_E3M2 = "fp6_e3m2"
DTYPE_FP6_E2M3 = "fp6_e2m3"

# Supported element dtypes
# TODO(future PR): add support for MX int8
SUPPORTED_ELEM_DTYPES = [
torch.float8_e4m3fn,
torch.float8_e5m2,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
DTYPE_FP4,
]
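
# Illustrative sketch, not part of the original file: callers can validate an
# element dtype against the list above before using it. The helper name
# `_assert_elem_dtype_supported` is hypothetical.
def _assert_elem_dtype_supported(elem_dtype) -> None:
    if elem_dtype not in SUPPORTED_ELEM_DTYPES:
        raise AssertionError(
            f"unsupported element dtype {elem_dtype}, "
            f"expected one of {SUPPORTED_ELEM_DTYPES}"
        )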

# Largest representable values, queried from PyTorch's float8 dtypes
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max  # 57344.0

# Largest power of two representable in each element dtype
F8E4M3_MAX_POW2 = 8  # 256
F8E5M2_MAX_POW2 = 15  # 32768
F6_E2M3_MAX_POW2 = 2  # 4
F6_E3M2_MAX_POW2 = 4  # 16
F4_E2M1_MAX_POW2 = 2  # 4

# E8M0 scale dtype: 8 exponent bits and no mantissa; the all-ones
# exponent encodes NaN
E8M0_EXPONENT_BIAS = 127
E8M0_EXPONENT_NAN_VAL = 255
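
# Illustrative sketch, not part of the original file: an E8M0 scale stores only
# a biased exponent, so decoding a scale to float32 looks like the hypothetical
# helper below.
def _e8m0_to_f32(biased_exponent: int) -> float:
    # the all-ones exponent is reserved for NaN
    if biased_exponent == E8M0_EXPONENT_NAN_VAL:
        return float("nan")
    return 2.0 ** (biased_exponent - E8M0_EXPONENT_BIAS)

assert _e8m0_to_f32(E8M0_EXPONENT_BIAS) == 1.0  # biased 127 -> 2**0
assert _e8m0_to_f32(E8M0_EXPONENT_BIAS + 3) == 8.0  # biased 130 -> 2**3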

# Exponent biases
F32_EXP_BIAS = 127
F6_E2M3_EXP_BIAS = 1
F6_E3M2_EXP_BIAS = 3
F4_E2M1_EXP_BIAS = 1

# Smallest normal float32 value, 2**-126
F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1)

# fp6 e2m3: largest normal, smallest normal, and largest unsigned encoding
F6_E2M3_MAX = 7.5
F6_E2M3_MIN_NORMAL = 1.0
F6_E2M3_MAX_INT = 31  # integer corresponding to 0b00011111

# fp6 e3m2
F6_E3M2_MAX = 28.0
F6_E3M2_MIN_NORMAL = 0.25
F6_E3M2_MAX_INT = 31  # integer corresponding to 0b00011111

# fp4 e2m1
F4_E2M1_MAX = 6.0
F4_E2M1_MIN_NORMAL = 1.0
F4_E2M1_MAX_INT = 7  # integer corresponding to 0b0111

# Default MX block size, per the OCP Microscaling (MX) spec
BLOCK_SIZE_DEFAULT = 32
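
# Illustrative sanity checks, not part of the original file: for a format with
# m mantissa bits, the largest normal value is 2**MAX_POW2 * (2 - 2**-m), and
# the smallest normal value is 2**(1 - exponent_bias).
assert F6_E2M3_MAX == 2**F6_E2M3_MAX_POW2 * (2 - 2**-3)  # 4 * 1.875 == 7.5
assert F6_E3M2_MAX == 2**F6_E3M2_MAX_POW2 * (2 - 2**-2)  # 16 * 1.75 == 28.0
assert F4_E2M1_MAX == 2**F4_E2M1_MAX_POW2 * (2 - 2**-1)  # 4 * 1.5 == 6.0
assert F6_E2M3_MIN_NORMAL == 2 ** (1 - F6_E2M3_EXP_BIAS)  # 2**0 == 1.0
assert F6_E3M2_MIN_NORMAL == 2 ** (1 - F6_E3M2_EXP_BIAS)  # 2**-2 == 0.25
assert F4_E2M1_MIN_NORMAL == 2 ** (1 - F4_E2M1_EXP_BIAS)  # 2**0 == 1.0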