diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 043e1160e0..b42b722997 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -634,6 +634,25 @@ def nvfp4_view_op(func, types, args, kwargs): ) +@implements([aten.select.int]) +def nvfp4_select(func, types, args, kwargs): + old, dim, index = args + assert dim == 0, f"NVFP4Tensor aten.select.int with {dim=} is not yet supported" + assert len(old.qdata.shape) == len(old._scale_e4m3.shape), "unsupported" + new = old.__class__( + old.qdata[index], + old._scale_e4m3[index], + old._block_size, + old._orig_dtype, + old._per_tensor_scale, + old._act_per_tensor_scale, + old._is_swizzled_scales, + old.use_triton_kernel, + old.act_quant_kwargs, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + def _addmm_nvfp4_dispatch( a: NVFP4Tensor, b: NVFP4Tensor, aten_op, bias: Optional[torch.Tensor] = None ) -> torch.Tensor: diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 7f694b56d3..a1dc40fdd3 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -626,8 +626,9 @@ def _test_narrow_similar_to_vllm(self, config: AOBaseConfig): ) def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig): - # this happens when vLLM loads empty MoE weights and quantizes - # them + # this happens when vLLM loads empty MoE weights, quantizes + # them, and stitches 2d params from the checkpoint into a 3d param + # in memory dtype = torch.bfloat16 with torch.device("meta"): @@ -636,6 +637,7 @@ def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig): torch.randn(60, 2816, 2048, device="cuda", dtype=dtype) ) quantize_(l, config) + _w_slice = l.weight[0] common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)