# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Defines the tensor subclasses to represent the MX format spec from
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

Exponent E8M0 encoding details (OCP spec section 5.4.1):
  * bias: 127
  * supported exponent range: -127 to 127
  * infinities: N/A
  * NaN: 11111111
  * Zeros: N/A
"""

from typing import Dict, Union

import torch

import torchao.prototype.mx_formats.config as config
from torchao.prototype.mx_formats.constants import (
    BLOCK_SIZE_DEFAULT,
    DTYPE_FP4,
    DTYPE_FP6_E2M3,
    DTYPE_FP6_E3M2,
    E8M0_EXPONENT_BIAS,
    E8M0_EXPONENT_NAN_VAL,
    F32_MIN_NORMAL,
    F4_E2M1_MAX,
    F4_E2M1_MAX_POW2,
    F6_E2M3_MAX,
    F6_E2M3_MAX_POW2,
    F6_E3M2_MAX,
    F6_E3M2_MAX_POW2,
    F8E4M3_MAX,
    F8E4M3_MAX_POW2,
    F8E5M2_MAX,
    F8E5M2_MAX_POW2,
    SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.custom_cast import (
    f32_to_f4_unpacked,
    f32_to_f6_e2m3_unpacked,
    f32_to_f6_e3m2_unpacked,
    f4_unpacked_to_f32,
    f6_e2m3_unpacked_to_f32,
    f6_e3m2_unpacked_to_f32,
    pack_uint4,
    triton_f4_to_scaled_bf16,
    unpack_uint4,
)


def to_mx(
    data_hp: torch.Tensor,
    elem_dtype: Union[torch.dtype, str],
    block_size: int,
):
    """
    Takes a high precision tensor and converts to MX scale and raw data, in
    naive layout (scale and raw data are separate tensors).
    """

    assert data_hp.dtype in (
        torch.bfloat16,
        torch.float,
    ), f"{data_hp.dtype} is not supported yet"
    # TODO(future PR): consider supporting padding
    assert data_hp.numel() % block_size == 0, "unsupported"
    assert data_hp.is_contiguous(), "unsupported"
    assert elem_dtype in SUPPORTED_ELEM_DTYPES, "unsupported"

    # calculate the scale in e8m0 format
    orig_shape = data_hp.shape
    data_hp = data_hp.reshape(-1, block_size)

    # find max value of the data
    # Note: this only implements the `minimally supported` version of
    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    # section 6.3.
    max_abs = torch.amax(torch.abs(data_hp), 1)

    # Add an epsilon to prevent the log2 function call from returning -inf
    # where the values are zero.
    eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)

    # Find the largest power of 2 less than or equal to max_abs.
    largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps))

    # Set X to be the largest power-of-two less than or equal to
    # max_abs(v), divided by the largest power of two representable
    # in the element data type
    if elem_dtype == torch.float8_e4m3fn:
        target_max_pow2 = F8E4M3_MAX_POW2
    elif elem_dtype == torch.float8_e5m2:
        target_max_pow2 = F8E5M2_MAX_POW2
    elif elem_dtype == DTYPE_FP6_E2M3:
        target_max_pow2 = F6_E2M3_MAX_POW2
    elif elem_dtype == DTYPE_FP6_E3M2:
        target_max_pow2 = F6_E3M2_MAX_POW2
    elif elem_dtype == DTYPE_FP4:
        target_max_pow2 = F4_E2M1_MAX_POW2
    else:
        raise AssertionError("unsupported")
    scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2

    # Clamp to exponents that can be represented in e8m0
    scale_e8m0_unbiased = torch.clamp(
        scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS
    )

    # Create the biased e8m0 representation and cast it to 8 bits
    scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS
    scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)

    # Conversion to torch.uint8 sets NaN values to 0, fix this by
    # explicitly setting known NaN values to 255
    scale_e8m0_biased = torch.where(
        torch.isnan(scale_e8m0_unbiased),
        E8M0_EXPONENT_NAN_VAL,
        scale_e8m0_biased,
    )

    # For now, calculate the scale in floating point.
    # TODO(future) audit if there is a need to bit shift exponents instead.
    scale_fp = torch.pow(
        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
        scale_e8m0_unbiased,
    )

    # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
    # float32 denormal range. For now, manually adjust the fp scale. This is
    # relevant if all of the incoming block values are zeroes.
    # See https://github.com/pytorch/pytorch/issues/125557 for details.
    # Note: it would be more correct to set the minimum to 2**-127, but this
    # does not work in triton either as it looks like subnormal value handling
    # has some gaps. So, for now just set to the minimum normal value.
    scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)

    # scale and saturated-cast the data elements to the max of the target dtype
    if elem_dtype == torch.float8_e4m3fn:
        max_pos = F8E4M3_MAX
    elif elem_dtype == torch.float8_e5m2:
        max_pos = F8E5M2_MAX
    elif elem_dtype == DTYPE_FP6_E2M3:
        max_pos = F6_E2M3_MAX
    elif elem_dtype == DTYPE_FP6_E3M2:
        max_pos = F6_E3M2_MAX
    elif elem_dtype == DTYPE_FP4:
        max_pos = F4_E2M1_MAX
    else:
        raise AssertionError("unsupported")
    data_lp = torch.clamp(
        data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
    )
    data_lp = data_lp.reshape(orig_shape)

    # cast to target dtype
    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        data_lp = data_lp.to(elem_dtype)
    elif elem_dtype == DTYPE_FP6_E2M3:
        data_lp = f32_to_f6_e2m3_unpacked(data_lp)
    elif elem_dtype == DTYPE_FP6_E3M2:
        data_lp = f32_to_f6_e3m2_unpacked(data_lp)
    elif elem_dtype == DTYPE_FP4:
        data_lp = f32_to_f4_unpacked(data_lp)
        data_lp = pack_uint4(data_lp)
    else:
        raise AssertionError("unsupported")

    return scale_e8m0_biased, data_lp
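
# Example (illustrative sketch, not part of the original file): quantizing a
# bfloat16 tensor to MX with fp8 e4m3 elements and a block size of 32. The
# shapes below are assumptions for illustration only.
#
#   x = torch.randn(128, 256, dtype=torch.bfloat16)
#   scale_e8m0, data_lp = to_mx(x, torch.float8_e4m3fn, block_size=32)
#   # scale_e8m0: one uint8 biased exponent per block of 32 elements
#   # data_lp: the elements cast to torch.float8_e4m3fn, same shape as x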


def get_fp_scale(scale_e8m0):
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # TODO(later): it would be nice if there was a way to do the 2^x operation
    # in PyTorch without creating a tensor of twos
    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
    # pow(two, s_offset) can be out of range of floating point formats.
    # TODO(later): handle this for float16 if we decide to support float16
    # scales.
    s_fp = torch.pow(two, s_offset)

    # If a block exponent was 255, set values of that block to NaN
    s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))

    return s_fp
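
# Worked example (illustrative, not part of the original file): decoding e8m0
# bytes with bias 127, per the encoding described in the module docstring:
#
#   get_fp_scale(torch.tensor([127, 129, 255], dtype=torch.uint8))
#   # byte 127 -> exponent 0   -> scale 2**0 = 1.0
#   # byte 129 -> exponent 2   -> scale 2**2 = 4.0
#   # byte 255 -> NaN encoding -> scale NaN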


def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
    orig_shape = data_lp.shape
    is_transposed = not data_lp.is_contiguous()
    # if the underlying data is transposed, convert to row major before
    # unpacking and unscaling
    if is_transposed:
        data_lp = data_lp.t()
        assert data_lp.is_contiguous()
        orig_shape = (orig_shape[1], orig_shape[0])

    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        data_hp = data_lp.to(target_dtype)
    elif elem_dtype == DTYPE_FP6_E2M3:
        data_hp = f6_e2m3_unpacked_to_f32(data_lp)
        data_hp = data_hp.to(target_dtype)
    elif elem_dtype == DTYPE_FP6_E3M2:
        data_hp = f6_e3m2_unpacked_to_f32(data_lp)
        data_hp = data_hp.to(target_dtype)
    elif elem_dtype == DTYPE_FP4:
        if config.use_fp4_custom_triton_dequant_kernel:
            data_hp_rescaled = triton_f4_to_scaled_bf16(
                data_lp,
                scale_e8m0,
                block_size,
            )
            if is_transposed:
                data_hp_rescaled = data_hp_rescaled.t()
            return data_hp_rescaled.to(target_dtype)
        else:
            # fp4
            f4_unpacked = unpack_uint4(data_lp)
            # for now we only have a cast to f32
            # TODO(future PR): add a cast directly to bf16
            f32 = f4_unpacked_to_f32(f4_unpacked)
            data_hp = f32.to(target_dtype)
            # manually adjust shape to account for the unpacking
            # TODO(future PR): clean up the shape code and remove the hack
            # below
            orig_shape = (*orig_shape[:-1], orig_shape[-1] * 2)
    else:
        raise AssertionError("unsupported")

    data_hp = data_hp.reshape(-1, block_size)
    s_fp = get_fp_scale(scale_e8m0).reshape(-1, 1).to(target_dtype)
    data_hp = data_hp * s_fp
    data_hp = data_hp.reshape(orig_shape)

    # if we converted to row-major before unscaling, convert back
    if is_transposed:
        data_hp = data_hp.t()
    return data_hp
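
# Round-trip example (illustrative sketch, not part of the original file):
#
#   x = torch.randn(8, 32, dtype=torch.bfloat16)
#   scale_e8m0, data_lp = to_mx(x, torch.float8_e4m3fn, block_size=32)
#   x_hat = to_dtype(data_lp, scale_e8m0, torch.float8_e4m3fn, 32, torch.bfloat16)
#   # x_hat approximates x; the error is bounded by the fp8 element precision
#   # combined with the shared per-block power-of-two scale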


def tensor_size_hp_to_fp4x2(orig_size, is_contiguous):
    new_size = orig_size
    if is_contiguous:
        new_size = [*list(new_size[:-1]), new_size[-1] // 2]
    else:
        new_size = [new_size[0] // 2, *list(new_size[1:])]
    return new_size


def tensor_size_fp4x2_to_hp(orig_size, is_contiguous):
    new_size = orig_size
    if is_contiguous:
        new_size = [*list(new_size[:-1]), new_size[-1] * 2]
    else:
        new_size = [new_size[0] * 2, *list(new_size[1:])]
    return new_size
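
# Example (illustrative, not part of the original file): with fp4, two 4-bit
# elements are packed into each uint8 byte, so the packed size halves the
# innermost (contiguous) dim and unpacking doubles it back:
#
#   tensor_size_hp_to_fp4x2([128, 256], True)  # -> [128, 128]
#   tensor_size_fp4x2_to_hp([128, 128], True)  # -> [128, 256]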


@torch._dynamo.allow_in_graph
class ToMXConstrFunc(torch.autograd.Function):
    """
    Differentiable cast to MX, no-op in backward
    """

    @staticmethod
    def forward(ctx, data_hp, elem_dtype, block_size):
        scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size)
        return MXTensor(
            scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
        )

    @staticmethod
    def backward(ctx, g):
        return g, None, None


@torch._dynamo.allow_in_graph
class FromMXConstrFunc(torch.autograd.Function):
    """
    Differentiable cast from MX, no-op in backward
    """

    @staticmethod
    def forward(ctx, tensor_lp, target_dtype):
        return to_dtype(
            tensor_lp._data,
            tensor_lp._scale_e8m0,
            tensor_lp._elem_dtype,
            tensor_lp._block_size,
            target_dtype,
        )

    @staticmethod
    def backward(ctx, g):
        return g, None, None


class MXTensor(torch.Tensor):
    def __new__(
        cls,
        scale_e8m0_bits,
        data_bits,
        elem_dtype,
        block_size,
        orig_dtype,
    ):
        new_size = data_bits.size()
        if elem_dtype == DTYPE_FP4:
            # set the tensor size to what it would be without 2x4 packing
            new_size = tensor_size_fp4x2_to_hp(
                new_size,
                data_bits.is_contiguous(),
            )
        self = torch.Tensor._make_wrapper_subclass(
            cls,
            new_size,
            dtype=orig_dtype,
            device=data_bits.device,
        )
        assert scale_e8m0_bits.dtype == torch.uint8, "unsupported"
        assert len(scale_e8m0_bits.shape) == 1, "unsupported"
        assert data_bits.dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.uint8,
        ), "unsupported"
        if elem_dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            DTYPE_FP6_E2M3,
            DTYPE_FP6_E3M2,
        ):
            target_numel = scale_e8m0_bits.numel() * block_size
        elif elem_dtype == DTYPE_FP4:
            assert data_bits.dtype is torch.uint8  # fp4
            target_numel = scale_e8m0_bits.numel() * block_size / 2
        else:
            raise AssertionError("unsupported")
        if not issubclass(
            torch._subclasses.fake_tensor.FakeTensor,
            type(data_bits),
        ):
            # this check is sometimes broken for FakeTensor
            # TODO investigate
            assert (
                target_numel == data_bits.numel()
            ), f"{target_numel} != {data_bits.numel()}"

        # `_scale_e8m0` has rank 1 and applies to a row-major memory layout of
        # `_data`
        self._scale_e8m0 = scale_e8m0_bits
        self._data = data_bits
        self._elem_dtype = elem_dtype
        self._block_size = block_size
        self._orig_dtype = orig_dtype
        return self

    def __repr__(self):
        # TODO better elem dtype print for fp4
        return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self._data}, d_hp: {self.to_dtype(self._orig_dtype)}"  # noqa: E501

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # avoid circular dependency
        from torchao.prototype.mx_formats.mx_ops import MX_OPS_TABLE

        if func in MX_OPS_TABLE:
            return MX_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"{func} not implemented")

    def to_dtype(self, target_dtype):
        return FromMXConstrFunc.apply(self, target_dtype)

    @staticmethod
    @torch._dynamo.allow_in_graph
    def to_mx(
        data_hp: torch.Tensor,
        elem_dtype: Union[torch.dtype, str],
        block_size: int = BLOCK_SIZE_DEFAULT,
    ):
        return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size)

    def __tensor_flatten__(self):
        ctx = {
            "_elem_dtype": self._elem_dtype,
            "_block_size": self._block_size,
            "_orig_dtype": self._orig_dtype,
        }
        return ["_scale_e8m0", "_data"], ctx

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict,
        metadata,
        outer_size,
        outer_stride,
    ):
        return MXTensor(
            inner_tensors["_scale_e8m0"],
            inner_tensors["_data"],
            metadata["_elem_dtype"],
            metadata["_block_size"],
            metadata["_orig_dtype"],
        )

    # Do not force the MXTensor type on the returned tensor
    __torch_function__ = torch._C._disabled_torch_function_impl
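
# Usage sketch (illustrative, not part of the original file): the public entry
# points are `MXTensor.to_mx` and `MXTensor.to_dtype`, which route through the
# autograd functions above, so gradients pass through both casts unchanged.
#
#   x = torch.randn(128, 256, dtype=torch.bfloat16, requires_grad=True)
#   x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size=32)
#   x_hp = x_mx.to_dtype(torch.bfloat16)
#   x_hp.sum().backward()  # backward is a no-op pass-through for both casts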