7 changes: 5 additions & 2 deletions test/prototype/mx_formats/test_mx_dtensor.py
@@ -59,13 +59,16 @@ def _test_dtensor_cast_to_mxfp8(mesh: DeviceMesh, size=4):
     local_rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     assert size % world_size == 0, "unsupported"
-    x_fp8_fp32 = x_fp8.to_dtype(torch.float32)
+    x_fp8_fp32 = x_fp8.dequantize(torch.float32)
     rows_per_slice = size // world_size
     slice_start = local_rank * rows_per_slice
     slice_end = (local_rank + 1) * rows_per_slice
     x_fp8_fp32_slice = x_fp8_fp32[slice_start:slice_end]
     torch.testing.assert_close(
-        x_fp8_fp32_slice, dist_x_fp8.to_local().to_dtype(torch.float32), atol=0, rtol=0
+        x_fp8_fp32_slice,
+        dist_x_fp8.to_local().dequantize(torch.float32),
+        atol=0,
+        rtol=0,
     )

6 changes: 3 additions & 3 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -49,9 +49,9 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     a_scale_block = to_blocked(a_scale)
     b_scale_block = to_blocked(b_scale)

-    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-        -1, -2
-    )
+    out_hp = a_mx.dequantize(torch.bfloat16) @ b_mx.dequantize(
+        torch.bfloat16
+    ).transpose(-1, -2)
     out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

     return compute_error(out_hp, out).item()
14 changes: 7 additions & 7 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -53,7 +53,7 @@ def _test_mx(
     data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR
 ):
     data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode)
-    data_mx_dq = data_mx.to_dtype(data_hp.dtype)
+    data_mx_dq = data_mx.dequantize(data_hp.dtype)

     def assert_sqnr_gt_threshold(orig, new, threshold):
         sqnr = compute_error(orig, new)
@@ -389,7 +389,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         pack_fp6,
         None,
     )
-    tensor_hp = tensor_mx.to_dtype(torch.float)
+    tensor_hp = tensor_mx.dequantize(torch.float)
     assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
     assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))

@@ -436,10 +436,10 @@ def test_transpose(elem_dtype):
         elem_dtype,
         block_size,
     )
-    tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
+    tensor_mx_dq_t = tensor_mx.dequantize(tensor_hp.dtype).t()

     tensor_mx_t = tensor_mx.t()
-    tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype)
+    tensor_mx_t_dq = tensor_mx_t.dequantize(tensor_hp.dtype)

     assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
@@ -461,8 +461,8 @@ def test_clone():
     data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
     data_mx_c = data_mx.clone()
     torch.testing.assert_close(
-        data_mx.to_dtype(torch.bfloat16),
-        data_mx_c.to_dtype(torch.bfloat16),
+        data_mx.dequantize(torch.bfloat16),
+        data_mx_c.dequantize(torch.bfloat16),
         atol=0,
         rtol=0,
     )
@@ -571,7 +571,7 @@ def test_index_select():

     x_mx_1 = x_mx[1]
     torch.testing.assert_close(
-        x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0
+        x_mx.dequantize(x.dtype)[1], x_mx_1.dequantize(x.dtype), atol=0, rtol=0
     )

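The tests above all exercise the same round trip with the renamed method: quantize to MX, call dequantize, and compare against the original input via SQNR. A minimal sketch of that pattern, assuming the usual MX block size of 32 and that the SQNR helper compute_error is importable from torchao.quantization.utils; the shape and threshold below are illustrative only, not the values used in the tests.

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.quantization.utils import compute_error  # SQNR in dB (assumed import path)

# illustrative input; the real tests sweep dtypes, element formats, and block sizes
data_hp = torch.randn(256, 256, dtype=torch.bfloat16)
data_mx = MXTensor.to_mx(data_hp, torch.float8_e4m3fn, 32)  # block_size=32

# formerly data_mx.to_dtype(data_hp.dtype)
data_mx_dq = data_mx.dequantize(data_hp.dtype)

sqnr = compute_error(data_hp, data_mx_dq)
assert sqnr >= 20.0  # illustrative threshold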
18 changes: 9 additions & 9 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -58,7 +58,7 @@ def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
         scale = None

     x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
-    x_reconstructed = x_nvfp4.to_dtype(dtype)
+    x_reconstructed = x_nvfp4.dequantize(dtype)

     def assert_sqnr_gt_threshold(orig, new, threshold):
         sqnr = compute_error(orig, new)
@@ -91,7 +91,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     x_nvfp4_t = x_nvfp4.transpose(-2, -1)
     x_t = x.transpose(-2, -1)

-    x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
+    x_reconstructed_t = x_nvfp4_t.dequantize(dtype)
     assert_sqnr_gt_threshold(x_t, x_reconstructed_t, 8.0)

     assert x_t.shape == x_reconstructed_t.shape, (
@@ -127,7 +127,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):

     tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
     assert tensor._is_swizzled_scales == is_swizzled_scales
-    reconstructed = tensor.to_dtype(torch.bfloat16)
+    reconstructed = tensor.dequantize(torch.bfloat16)
     assert reconstructed.shape == data.shape


@@ -181,10 +181,10 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
     assert sliced_tensor._is_swizzled_scales == True

     # Verify sliced tensor can be dequantized
-    sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)
+    sliced_reconstructed = sliced_tensor.dequantize(torch.bfloat16)

     # Compare with direct slicing of original data
-    original_reconstructed = tensor.to_dtype(torch.bfloat16)
+    original_reconstructed = tensor.dequantize(torch.bfloat16)
     if slice_dim == 0:
         expected = original_reconstructed[slice_spec, :]
     else:
@@ -324,8 +324,8 @@ def test_nvfp4_swizzled_scales_serialization():
     assert reconstructed_tensor._is_swizzled_scales == True

     # Verify functionality is preserved
-    original_dq = original_tensor.to_dtype(torch.bfloat16)
-    reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16)
+    original_dq = original_tensor.dequantize(torch.bfloat16)
+    reconstructed_dq = reconstructed_tensor.dequantize(torch.bfloat16)

     torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6)

@@ -404,8 +404,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
         rtol=0,
     )

-    x_pt_dequant = nvfp4_pt.to_dtype(dtype)
-    x_triton_dequant = nvfp4_triton.to_dtype(dtype)
+    x_pt_dequant = nvfp4_pt.dequantize(dtype)
+    x_triton_dequant = nvfp4_triton.dequantize(dtype)

     sqnr = compute_error(x_pt_dequant, x_triton_dequant)
     SQNR_THRESHOLD = 40.0
12 changes: 7 additions & 5 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -567,13 +567,15 @@ def __repr__(self):
     def _quantization_type(self):
         return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"

-    def to_dtype(self, target_dtype):
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
         return to_dtype(
             self.qdata,
             self.scale,
             self._elem_dtype,
             self._block_size,
-            target_dtype,
+            output_dtype,
             self._pack_fp6,
         )

@@ -718,8 +720,8 @@ def _addmm_mx_dispatch(

     else:
         # emulated MX gemm
-        a_hp = a.to_dtype(a._orig_dtype)
-        b_hp = b.to_dtype(b._orig_dtype)
+        a_hp = a.dequantize(a._orig_dtype)
+        b_hp = b.dequantize(b._orig_dtype)
         # assert memory layout we expect to be required in hardware
         assert a_hp.is_contiguous()
         assert b_hp.t().is_contiguous()
@@ -780,7 +782,7 @@ def mx_cast_up_op(func, types, args, kwargs):

     def unwrap(x):
         if isinstance(x, MXTensor):
-            return x.to_dtype(x._orig_dtype)
+            return x.dequantize(x._orig_dtype)
         return x

     new_args = tree_map(unwrap, args)
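For callers of MXTensor, the rename amounts to the migration sketched below, assuming the element dtype and block size shown; per the hunk above, the new output_dtype argument is optional and falls back to self.dtype when omitted.

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_hp = torch.randn(128, 128, dtype=torch.bfloat16)   # illustrative input
x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)  # block_size=32

# before this PR: x_mx.to_dtype(torch.bfloat16)
x_dq = x_mx.dequantize(torch.bfloat16)  # explicit output dtype
x_dq_default = x_mx.dequantize()        # output_dtype=None, defaults to self.dtype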
18 changes: 10 additions & 8 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -136,7 +136,7 @@ def __new__(
         return self

     def __repr__(self):
-        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.to_dtype(self._orig_dtype)}"
+        return f"NVFP4Tensor: scale: {self.scale}, per_tensor_scale: {self.per_tensor_scale}, d: {self.qdata}, d_hp: {self.dequantize(self._orig_dtype)}"

     def _quantization_type(self):
         return f"{self._is_swizzled_scales=}, {self.use_triton_kernel=}, {self.act_quant_kwargs=}"
@@ -217,7 +217,7 @@ def to_nvfp4(
     # Do not force the NVFP4Tensor type on the returned tensor
     __torch_function__ = torch._C._disabled_torch_function_impl

-    def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Convert NVFP4Tensor back to high precision dtype.

         Args:
@@ -226,6 +226,8 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
         Returns:
             torch.Tensor: Dequantized tensor in the target dtype
         """
+        if output_dtype is None:
+            output_dtype = self.dtype
         is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
         if is_transposed:
             leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
@@ -242,7 +244,7 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
             *leading_dims, M, K // self._block_size, 1
         )
         data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
-        result = data_scaled.view(*leading_dims, M, K).to(target_dtype)
+        result = data_scaled.view(*leading_dims, M, K).to(output_dtype)

         if is_transposed:
             result = result.transpose(-2, -1)
@@ -731,7 +733,7 @@ def nvfp4_linear(func, types, args, kwargs):

     if weight_tensor.act_quant_kwargs is None:
         # weight_only quant
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
     else:
         # dynamic quant
@@ -759,9 +761,9 @@ def nvfp4_mm(func, types, args, kwargs):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

     if weight_tensor.act_quant_kwargs is None:
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
-            input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
+            input_dequant = input_tensor.dequantize(input_tensor._orig_dtype)
             return func(input_dequant, weight_dequant)
         else:
             return func(input_tensor, weight_dequant)
@@ -791,9 +793,9 @@ def nvfp4_addmm(func, types, args, kwargs):
         raise NotImplementedError("NVFP4Tensor: weight must be NVFP4Tensor")

     if weight_tensor.act_quant_kwargs is None:
-        weight_dequant = weight_tensor.to_dtype(weight_tensor._orig_dtype)
+        weight_dequant = weight_tensor.dequantize(weight_tensor._orig_dtype)
         if isinstance(input_tensor, NVFP4Tensor):
-            input_dequant = input_tensor.to_dtype(input_tensor._orig_dtype)
+            input_dequant = input_tensor.dequantize(input_tensor._orig_dtype)
             return torch.addmm(bias, input_dequant, weight_dequant)
         else:
             return torch.addmm(bias, input_tensor, weight_dequant)
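NVFP4Tensor gets the same treatment; a short sketch under the same assumptions (illustrative shape, to_nvfp4 called with its defaults).

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

w_hp = torch.randn(128, 64, dtype=torch.bfloat16)  # dims divisible by the 16-element NVFP4 block
w_nvfp4 = NVFP4Tensor.to_nvfp4(w_hp)

# before this PR: w_nvfp4.to_dtype(torch.bfloat16)
w_dq = w_nvfp4.dequantize(torch.bfloat16)  # explicit output dtype
w_dq_default = w_nvfp4.dequantize()        # also falls back to self.dtype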