# enable 3d weights for NVFP4Tensor #3109
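Summary, as reflected in the diff below: `NVFP4Tensor.to_nvfp4`, `to_dtype`, and `get_hp_scales` generalize from rank-2 to rank-3 inputs by carrying leading dims through their shape, stride, and scale-layout logic; `aten.transpose.int` gains a 3D implementation; and `aten.slice` explicitly stays rank-2 for now.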
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import sys
 from dataclasses import dataclass
 from enum import Enum
```
```diff
@@ -112,7 +113,7 @@ def __new__(
 
         new_size = tensor_size_fp4x2_to_hp(
             new_size,
-            qdata.stride(0) > qdata.stride(1),
+            qdata.stride(-2) > qdata.stride(-1),
         )
 
         self = torch.Tensor._make_wrapper_subclass(
```
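For context, a minimal sketch (not part of the diff; toy shapes) of why comparing the strides of the last two dims generalizes the old 2D row-major check to batched tensors:

```python
import torch

x = torch.empty(4, 8, 16)  # contiguous 3D: strides (128, 16, 1)
xt = x.transpose(-2, -1)   # last two dims swapped: strides (128, 1, 16)

# stride(-2) > stride(-1) is True for row-major layouts regardless of rank,
# which is what the old stride(0) > stride(1) check captured only for 2D.
print(x.stride(-2) > x.stride(-1))    # True
print(xt.stride(-2) > xt.stride(-1))  # False
```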
```diff
@@ -174,21 +175,21 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
-        assert len(data_hp.shape) == 2, "unsupported"
-        M, K = data_hp.shape[0], data_hp.shape[1]
+        assert len(data_hp.shape) in (2, 3), "unsupported"
+        leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]
 
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
-            assert data_hp.shape[1] % 16 == 0, (
-                f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
+            assert K % 16 == 0, (
+                f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
             blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
             )
             if is_swizzled_scales:
-                scale_shape = (M, K // block_size)
+                scale_shape = (math.prod(leading_dims) * M, K // block_size)
                 blockwise_scales = to_blocked(
                     blockwise_scales.view(scale_shape)
                 ).flatten()
```
```diff
@@ -199,7 +200,7 @@ def to_nvfp4(
             # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
             # scale element
             scale_M, scale_K = M, K // block_size
-            blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+            blockwise_scales = blockwise_scales.view(*leading_dims, scale_M, scale_K)
 
         return NVFP4Tensor(
             data_lp,
```
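A small sketch (hypothetical shapes, not from the PR) of the two scale layouts this produces for a 3D input with `block_size=16`:

```python
import math

leading_dims, M, K = (4,), 128, 256
block_size = 16

# Unswizzled: one e4m3 scale per 1x16 block, with leading dims preserved.
unswizzled_shape = (*leading_dims, M, K // block_size)  # (4, 128, 16)

# Swizzled: leading dims are folded into the row dimension before to_blocked,
# so the blocked layout sees a single (batch * M, K // block_size) matrix.
swizzled_rows = math.prod(leading_dims) * M  # 512
print(unswizzled_shape, swizzled_rows)
```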
```diff
@@ -225,22 +226,26 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
         else:
-            M, K = self.shape[0], self.shape[1]
-        data = self.qdata.t() if is_transposed else self.qdata
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
+        data = self.qdata.transpose(-2, -1) if is_transposed else self.qdata
         data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
         data_f32 = f4_unpacked_to_f32(data_unpacked)
 
-        data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
-        scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
+        data_f32 = data_f32.view(
+            *leading_dims, M, K // self._block_size, self._block_size
+        )
+        scale_e4m3_reshaped = self.get_hp_scales().view(
+            *leading_dims, M, K // self._block_size, 1
+        )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
 
         if is_transposed:
-            result = result.t()
+            result = result.transpose(-2, -1)
 
         return result
```
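The block-wise rescale itself is plain broadcasting; here is a self-contained sketch with toy numbers (`block=4` rather than NVFP4's 16) mirroring the reshapes above:

```python
import torch

leading_dims, M, K, block = (2,), 3, 8, 4
data_f32 = torch.randn(*leading_dims, M, K)
scales = torch.rand(*leading_dims, M, K // block)

# One scale broadcasts across each 1 x block tile, as in to_dtype.
out = (
    data_f32.view(*leading_dims, M, K // block, block)
    * scales.view(*leading_dims, M, K // block, 1)
).view(*leading_dims, M, K)
print(out.shape)  # torch.Size([2, 3, 8])
```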
```diff
@@ -250,16 +255,18 @@ def get_hp_scales(self) -> torch.Tensor:
         Returns:
             torch.Tensor: Scales of the NVFP4Tensor
         """
-        is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
+        is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
-            M, K = self.shape[1], self.shape[0]
-            scale_e4m3 = self._scale_e4m3.t()
+            leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
+            scale_e4m3 = self._scale_e4m3.transpose(-2, -1)
         else:
-            M, K = self.shape[0], self.shape[1]
+            leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
             scale_e4m3 = self._scale_e4m3
 
         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
+            scale_e4m3 = from_blocked(
+                scale_e4m3, math.prod(leading_dims) * M, K // self._block_size
+            )
 
         return (
             scale_e4m3.to(self._orig_dtype)
```
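With these changes, a 3D weight should round-trip end to end. A hypothetical usage sketch (the import path and CUDA requirement are assumptions, not shown in this diff):

```python
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor  # assumed path

# e.g. a grouped/MoE weight of shape (num_experts, N, K)
w = torch.randn(8, 256, 512, dtype=torch.bfloat16, device="cuda")
w_nvfp4 = NVFP4Tensor.to_nvfp4(w)        # default block_size=16
w_dq = w_nvfp4.to_dtype(torch.bfloat16)  # exercises the 3D path above
assert w_dq.shape == w.shape
```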
```diff
@@ -380,6 +387,9 @@ def nvfp4_slice(func, types, args, kwargs):
         raise ValueError("Only support aten.slice with step=1")
 
     assert x.qdata.is_contiguous(), "Only support contiguous data for now"
+    assert len(x.shape) == 2, (
+        f"only rank 2 is supported for slice, got rank {len(x.shape)}"
+    )
 
     M, K = x.shape[0], x.shape[1]
```
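Slice support is intentionally left at rank 2 for now, so on a 3D tensor the new assert should fire. A hypothetical illustration, continuing the sketch above:

```python
w3d = NVFP4Tensor.to_nvfp4(
    torch.randn(4, 128, 256, dtype=torch.bfloat16, device="cuda")
)
try:
    _ = w3d[:, 0:64, :]  # indexing dispatches to aten.slice per dim
except AssertionError as e:
    print(e)  # "only rank 2 is supported for slice, got rank 3"
```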
```diff
@@ -583,6 +593,28 @@ def nvfp4_t(func, types, args, kwargs):
     return new
 
 
+@implements([aten.transpose.int])
+def nvfp4_transpose(func, types, args, kwargs):
+    old, dim0, dim1 = args
+    assert len(old.shape) == 3, f"unsupported rank {len(old.shape)}"
+    valid_3d_dims = ((1, 2), (2, 1), (-1, -2), (-2, -1))
+    assert (dim0, dim1) in valid_3d_dims, f"transpose unsupported for {dim0=} {dim1=}"
+    new_qdata = func(old.qdata, dim0, dim1, **kwargs)
+    new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
+    new = NVFP4Tensor(
+        new_qdata,
+        new_scale,
+        old._block_size,
+        old._orig_dtype,
+        old._per_tensor_scale,
+        old._act_per_tensor_scale,
+        old._is_swizzled_scales,
+        old.use_triton_kernel,
+        old.act_quant_kwargs,
+    )
+    return new
+
+
 @implements([aten.view.default])
 def nvfp4_view_op(func, types, args, kwargs):
     data = args[0].qdata
```

Review thread on the `old._block_size` line:

> would block size change with transpose?

> currently
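For illustration, the new transpose overload in action, e.g. to produce the (E, K, N) layout a grouped GEMM expects (hypothetical usage, same assumed import as above):

```python
w_nvfp4 = NVFP4Tensor.to_nvfp4(
    torch.randn(8, 256, 512, dtype=torch.bfloat16, device="cuda")
)
w_t = w_nvfp4.transpose(-2, -1)  # routes to nvfp4_transpose
print(w_t.shape)                 # expected: torch.Size([8, 512, 256])
```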
Review comment on `nvfp4_view_op`:

> just fyi, i think you can do: