diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index c61f973d03..7cc29c7ee1 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -70,9 +70,9 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
     elif elem_dtype == torch.float4_e2m1fn_x2:
         if not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
-        elif emulate and compile:
+        elif compile:
             # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + emulate + compile currently does not work, low SQNR")
+            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
 
     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 024586419a..1bd55e28f5 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -320,6 +320,26 @@ def test_fp4_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 0.5, 4.0, -0.0], [-0.0, 1.0, -6.0, 3.0]])
     orig_vals_f4_unpacked = f32_to_f4_unpacked(orig_vals)
     orig_vals_f4_packed = pack_uint4(orig_vals_f4_unpacked)
+
+    # ensure packing is
+    #
+    # 7654:3210
+    # val1:val0
+    expected_f4_packed = torch.tensor(
+        [
+            [
+                0b00010000,
+                0b10000110,
+            ],
+            [
+                0b00101000,
+                0b01011111,
+            ],
+        ],
+        dtype=torch.uint8,
+    )
+
+    assert torch.all(orig_vals_f4_packed == expected_f4_packed)
     assert orig_vals_f4_packed.numel() == (orig_vals.numel() / 2)
     orig_vals_f4_packed_unpacked = unpack_uint4(orig_vals_f4_packed)
     orig_vals_dq = f4_unpacked_to_f32(orig_vals_f4_packed_unpacked)
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 4a8c899d1c..260b61b9c9 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -142,10 +142,10 @@ def _fp4_packed_to_bf16(
     Output: a tensor of bfloat16 values
     """
 
-    # low-bits: original location 0:3
-    # high-bits: original location 4:7
-    x_low_bits = x_packed >> 4
-    x_high_bits = x_packed & 0xF
+    # high-bits: original location 0:3
+    # low-bits: original location 4:7
+    x_high_bits = x_packed >> 4
+    x_low_bits = x_packed & 0xF
     x = tl.interleave(x_low_bits, x_high_bits)
 
     # cast logic below
@@ -735,8 +735,8 @@ def unpack_uint4(uint8_data) -> torch.Tensor:
    # verified that we get a single triton kernel, but that is even slower
    # than the two kernels before this PR
    # * TODO add a microbenchmark of just the cast and profile this
-    first_elements = (uint8_data >> 4).to(torch.uint8)
-    second_elements = (uint8_data & 0b1111).to(torch.uint8)
+    first_elements = (uint8_data & 0b1111).to(torch.uint8)
+    second_elements = (uint8_data >> 4).to(torch.uint8)
     unpacked = torch.stack([first_elements, second_elements], dim=-1).view(
         up_size(shape)
     )
@@ -758,7 +758,7 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
     shape = uint8_data.shape
     assert shape[-1] % 2 == 0
     uint8_data = uint8_data.contiguous().view(-1)
-    return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape))
+    return (uint8_data[::2] | uint8_data[1::2] << 4).view(down_size(shape))
 
 
 # PyTorch implementation of fp6 packing for reference purposes
@@ -1250,8 +1250,8 @@ def convert_fp32_to_fp4_packed(x_pairs):
     Returns:
         Packed tensor with shape [...] (last dimension removed) where each
         element is an int8 containing 2 FP4 values:
-        - First value of pair → high nibble (bits 4-7)
-        - Second value of pair → low nibble (bits 0-3)
+        - First value of pair → low nibble (bits 0-3)
+        - Second value of pair → high nibble (bits 4-7)
 
     Example:
         Input: [128, 32, 2] containing FP32 pairs
@@ -1263,10 +1263,10 @@ def convert_fp32_to_fp4_packed(x_pairs):
         asm="""
         {
             .reg .b8 byte0, byte1, byte2, byte3;
-            cvt.rn.satfinite.e2m1x2.f32 byte0, $1, $5;
-            cvt.rn.satfinite.e2m1x2.f32 byte1, $2, $6;
-            cvt.rn.satfinite.e2m1x2.f32 byte2, $3, $7;
-            cvt.rn.satfinite.e2m1x2.f32 byte3, $4, $8;
+            cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1;
+            cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2;
+            cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3;
+            cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4;
             mov.b32 $0, {byte0, byte1, byte2, byte3};
         }
         """,
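
For reference, the nibble order this diff standardizes on (first value of a pair
in the low nibble, bits 0-3; second value in the high nibble, bits 4-7) can be
sanity-checked with a few lines of plain PyTorch. The sketch below is a
self-contained reimplementation for illustration only, not the torchao code
itself: it mirrors the post-fix semantics of pack_uint4/unpack_uint4 from
torchao/prototype/mx_formats/kernels.py, with the up_size/down_size shape
helpers inlined.

    import torch

    def pack_uint4(vals: torch.Tensor) -> torch.Tensor:
        # even-indexed element -> low nibble, odd-indexed element -> high nibble
        shape = vals.shape
        assert shape[-1] % 2 == 0
        flat = vals.contiguous().view(-1)
        return (flat[::2] | flat[1::2] << 4).view(*shape[:-1], shape[-1] // 2)

    def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
        shape = packed.shape
        first = (packed & 0b1111).to(torch.uint8)  # low nibble: first element
        second = (packed >> 4).to(torch.uint8)     # high nibble: second element
        return torch.stack([first, second], dim=-1).view(*shape[:-1], shape[-1] * 2)

    vals = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)
    packed = pack_uint4(vals)
    # pair (1, 2) -> 0b0010_0001, pair (3, 4) -> 0b0100_0011
    assert packed.tolist() == [[0x21, 0x43]]
    assert torch.equal(unpack_uint4(packed), vals)

Keeping pack_uint4, unpack_uint4, the Triton unpack path, and the PTX
cvt.rn.satfinite.e2m1x2.f32 pack path all on this same convention is what the
new expected_f4_packed assertion in test_fp4_pack_unpack locks in.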