19 changes: 19 additions & 0 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -634,6 +634,25 @@ def nvfp4_view_op(func, types, args, kwargs):
)


@implements([aten.select.int])
def nvfp4_select(func, types, args, kwargs):
old, dim, index = args
assert dim == 0, f"NVFP4Tensor aten.select.int with {dim=} is not yet supported"
assert len(old.qdata.shape) == len(old._scale_e4m3.shape), "unsupported"
new = old.__class__(
old.qdata[index],
old._scale_e4m3[index],
old._block_size,
Contributor: should a dim get knocked off block size after you select?

Contributor Author: currently block_size is an integer for this tensor, 16 for NVFP4. If we change it to a multidimensional block, we'd have to update this code.

old._orig_dtype,
old._per_tensor_scale,
old._act_per_tensor_scale,
old._is_swizzled_scales,
old.use_triton_kernel,
old.act_quant_kwargs,
)
return return_and_correct_aliasing(func, args, kwargs, new)


def _addmm_nvfp4_dispatch(
a: NVFP4Tensor, b: NVFP4Tensor, aten_op, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
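
To make the reviewer exchange above concrete, here is a minimal sketch of the shape bookkeeping behind `nvfp4_select`, using hypothetical shapes (borrowed from the test below) and assuming the usual NVFP4 layout of two FP4 values packed per byte with one E4M3 scale per 16-element block along the last dimension; it uses plain PyTorch only, not torchao internals.

```python
import torch

# Hypothetical 3D MoE-style parameter: 60 experts, each 2816 x 2048.
# Assumed NVFP4 layout: qdata packs two FP4 values per uint8, and the
# E4M3 scales hold one entry per 16-element block along the last dim.
block_size = 16
qdata = torch.zeros(60, 2816, 2048 // 2, dtype=torch.uint8)
scale_e4m3 = torch.zeros(60, 2816, 2048 // block_size, dtype=torch.float8_e4m3fn)

# aten.select.int on dim 0 drops the leading (expert) dimension from both
# the packed data and the blockwise scales, which is what nvfp4_select does.
index = 3
qdata_slice = qdata[index]        # shape (2816, 1024)
scale_slice = scale_e4m3[index]   # shape (2816, 128)

# block_size stays the scalar 16: blocking runs along the last dimension,
# which a dim-0 select never touches, so no dim gets knocked off of it.
assert qdata_slice.shape == (2816, 1024)
assert scale_slice.shape == (2816, 128)
```

If `_block_size` ever became per-dimension, as the author notes, this path would also need to drop the entry for the selected dim.
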
6 changes: 4 additions & 2 deletions torchao/testing/utils.py
@@ -626,8 +626,9 @@ def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
)

def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
- # this happens when vLLM loads empty MoE weights and quantizes
- # them
+ # this happens when vLLM loads empty MoE weights, quantizes
+ # them, and stitches 2d params from the checkpoint into a 3d param
+ # in memory

dtype = torch.bfloat16
with torch.device("meta"):
@@ -636,6 +637,7 @@ def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
torch.randn(60, 2816, 2048, device="cuda", dtype=dtype)
)
quantize_(l, config)
+ _w_slice = l.weight[0]


common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
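
As a usage sketch of the flow this test describes, the vLLM-style path looks roughly like the following; the `NVFP4InferenceConfig` name and import path are assumptions (the test instead receives its config through the parametrized `config: AOBaseConfig` argument), and a CUDA device is required.

```python
import torch
from torchao.quantization import quantize_
# Assumed import path for an NVFP4 inference config; the test itself is
# parametrized over AOBaseConfig instances rather than hard-coding one.
from torchao.prototype.mx_formats import NVFP4InferenceConfig

dtype = torch.bfloat16

# vLLM-style MoE flow: create the module on the meta device with empty
# weights, then stitch a 3D (num_experts, out_features, in_features)
# parameter into it before quantizing.
with torch.device("meta"):
    linear = torch.nn.Linear(2048, 2816, dtype=dtype)
linear.weight = torch.nn.Parameter(
    torch.randn(60, 2816, 2048, device="cuda", dtype=dtype)
)
quantize_(linear, NVFP4InferenceConfig())

# Per-expert access is a dim-0 select, which routes through the new
# aten.select.int implementation and yields a 2D quantized slice.
expert_0_weight = linear.weight[0]
print(type(expert_0_weight), expert_0_weight.shape)
```

The final `linear.weight[0]` is the per-expert 2D slice that exercises the `aten.select.int` path added in this PR.
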