Store NVFP4 block scales in swizzled layout on tensor #2438


Merged: 1 commit into main from drisspg/stack/80 on Jun 26, 2025

Conversation

@drisspg (Contributor) commented Jun 24, 2025

Stacked PRs:


Store NVFP4 block scales in swizzled layout on tensor

For Llama3 70B (no tensor parallelism) with 1024 tokens: ~15% E2E speedup.
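For context: cuBLAS and CUTLASS block-scaled GEMMs expect the per-block scale factors in a tiled 128x4 ("swizzled") layout rather than row-major order. Below is a minimal sketch of that rearrangement, modeled on torchao's to_blocked helper; the exact reshape sequence is illustrative, not a verbatim copy of this PR's code.

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def to_blocked(scales: torch.Tensor) -> torch.Tensor:
    """Rearrange a (rows, cols) scale matrix into the 128x4 tiled
    ("swizzled") layout that block-scaled GEMM kernels expect."""
    rows, cols = scales.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)
    # Pad up to whole (128, 4) tiles.
    padded = scales
    if (rows, cols) != (n_row_blocks * 128, n_col_blocks * 4):
        padded = torch.zeros(
            n_row_blocks * 128,
            n_col_blocks * 4,
            device=scales.device,
            dtype=scales.dtype,
        )
        padded[:rows, :cols] = scales
    # Split into (128, 4) tiles, then interleave each tile into the
    # 32x16 pattern the kernel consumes.
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return rearranged.flatten()

Since inference weights are static, doing this rearrangement once at quantization time and storing the result on the tensor takes it off the per-forward hot path, which is where the speedup below comes from.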

In eager, before: https://fburl.com/7w3j6b1q

nvfp4 Runtime: 2436.98 μs per iteration

In eager, after: https://fburl.com/s7ggvm94

nvfp4 Runtime: 2356.77 μs per iteration

In compile

Before: https://fburl.com/1gvfjjlu

nvfp4 Runtime: 576.14 μs per iteration

After: https://fburl.com/usp1xelj

nvfp4 Runtime: 486.69 μs per iteration
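For the compiled runs, that is (576.14 − 486.69) / 576.14 ≈ 15.5% off the MLP iteration time, consistent with the ~15% E2E speedup quoted above.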

Throughput: 47.12 requests/s, 19998.00 total tokens/s, 9635.87 output tokens/s
Total num prompt tokens: 225190
Total num output tokens: 209407

@drisspg drisspg force-pushed the drisspg/stack/80 branch from 2621b5c to a27abbc Compare June 24, 2025 23:00
drisspg added a commit that referenced this pull request Jun 24, 2025
stack-info: PR: #2438, branch: drisspg/stack/80

pytorch-bot bot commented Jun 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2438

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 9337cc6 with merge base faf788a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jun 24, 2025
@drisspg drisspg changed the title Store NVFP4 block scales in swwizzled layout on tensor Store NVFP4 block scales in swizzled layout on tensor Jun 24, 2025
@drisspg added the mx and topic: not user facing labels Jun 24, 2025
drisspg added a commit that referenced this pull request Jun 24, 2025
stack-info: PR: #2438, branch: drisspg/stack/80
@drisspg drisspg force-pushed the drisspg/stack/80 branch from a27abbc to 9d539f3 Compare June 24, 2025 23:17
@drisspg drisspg changed the title Store NVFP4 block scales in swizzled layout on tensor Store NVFP4 block scales in swwizzled layout on tensor Jun 24, 2025
drisspg added a commit that referenced this pull request Jun 24, 2025
stack-info: PR: #2438, branch: drisspg/stack/80
@drisspg drisspg force-pushed the drisspg/stack/80 branch from 9d539f3 to e53b456 Compare June 24, 2025 23:51
drisspg added a commit that referenced this pull request Jun 24, 2025
stack-info: PR: #2438, branch: drisspg/stack/80
@drisspg drisspg force-pushed the drisspg/stack/80 branch from e53b456 to 5c059bf Compare June 24, 2025 23:57
@drisspg drisspg mentioned this pull request Jun 25, 2025
@drisspg drisspg force-pushed the drisspg/stack/80 branch 4 times, most recently from ce3c2e9 to 36ff7cf Compare June 25, 2025 23:45
@drisspg drisspg requested review from gau-nernst and vkuzo June 25, 2025 23:48
@drisspg (Contributor, Author) commented Jun 25, 2025

@vkuzo I plan to refactor the test from test_mx_tensor into a different file in the same folder after these two land, to avoid rebase conflicts.

@gau-nernst (Collaborator) left a comment

Indeed, tensor cores have become so fast that everything else is a bottleneck 😄

@drisspg I'm curious about the profile trace you show here. Is it for inference or training? I'm guessing inference since, as I understand it, training wouldn't benefit from this, at least without a fused quantize+scale_swizzle.

@drisspg (Contributor, Author) commented Jun 26, 2025

Yup, this is for inference. Here's the benchmark script:

import torch
from typing import Literal

from transformer_nuggets.misc.mlp import FeedForward
from transformer_nuggets.utils.benchmark import (
    profiler,
    benchmark_do_bench_in_microseconds,
    benchmark_cuda_function_in_microseconds,
)

from torchao.quantization import quantize_
from torchao.prototype.mx_formats.mx_subclass import (
    MXFPInferenceConfig,
    NVFP4InferenceConfig,
)
from torchao.prototype.mx_formats import MXGemmKernelChoice
import os
from rich import print

from jsonargparse import CLI

# Configure all logging in one call; successive set_logs calls override earlier ones.
torch._logging.set_logs(fusion=True, graph=True)
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"


def main(
    num_tokens: int = 1024,
    flavor: str = "8B",
    mode: str = "mxfp8",
    bf16: bool = False,
    setting: Literal["max-autotune", "inductor"] = "max-autotune",
    backend: Literal["inductor", "eager"] = "inductor",
    use_triton: bool = False,
):
    device = "cuda"
    dtype = torch.bfloat16
    assert flavor in ("8B", "70B", "405B")
    assert mode in ("mxfp8", "mxfp4", "nvfp4")
    if bf16:
        model = FeedForward.llama3_mlp(flavor).to(device=device, dtype=dtype)
        inpt = model.get_input(num_tokens, device, dtype)
        model = torch.compile(
            model, fullgraph=True, dynamic=False, mode=setting, backend=backend
        )

        for _ in range(5):
            model(inpt)
        with profiler("data/bf16", with_stack=True) as p:
            for _ in range(5):
                model(inpt)
                p.step()
        runtime_us = benchmark_cuda_function_in_microseconds(lambda: model(inpt))
        print(
            f"[bold green]Bf16 Runtime:[/bold green] {runtime_us:.2f} μs per iteration"
        )

    model = FeedForward.llama3_mlp(flavor).to(device=device, dtype=dtype)
    inpt = model.get_input(num_tokens, device, dtype)
    match mode:
        case "mxfp8":
            elem_dtype = torch.float8_e4m3fn
            choice = MXGemmKernelChoice.CUBLAS
            config = MXFPInferenceConfig(
                activation_dtype=elem_dtype,
                weight_dtype=elem_dtype,
                block_size=32,
                gemm_kernel_choice=choice,
            )
        case "mxfp4":
            elem_dtype = torch.float4_e2m1fn_x2
            choice = MXGemmKernelChoice.CUTLASS
            config = MXFPInferenceConfig(
                activation_dtype=elem_dtype,
                weight_dtype=elem_dtype,
                block_size=32,
                gemm_kernel_choice=choice,
            )
        case "nvfp4":
            config = NVFP4InferenceConfig(use_triton_kernel=use_triton)
        case _:
            raise ValueError(f"Unknown mode {mode}")

    quantize_(model, config)
    model = torch.compile(
        model, fullgraph=True, dynamic=False, mode=setting, backend=backend
    )

    for _ in range(5):
        model(inpt)

    # Use profiler for detailed profiling
    with profiler(f"data/mx_fp_8_subclass_{mode}", with_stack=True) as p:
        for _ in range(5):
            model(inpt)
            p.step()

    # Use benchmark function for pretty printed runtime
    runtime_us = benchmark_do_bench_in_microseconds(lambda: model(inpt))
    print(f"[bold green]{mode} Runtime:[/bold green] {runtime_us:.2f} μs per iteration")


if __name__ == "__main__":
    CLI(main)
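For reference, a hypothetical invocation (jsonargparse's CLI maps main()'s parameters to flags; the script filename is assumed):

python bench_mlp.py --flavor 70B --mode nvfp4 --num_tokens 1024 --backend inductor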

> training wouldn't benefit from this, at least without fused quantize+scale_swizzle

Yup, this is a proxy for inference, where the weights are static. Like you said, for training, since the scales change at every grad step, I don't think this would help much.

@drisspg drisspg force-pushed the drisspg/stack/80 branch 2 times, most recently from 32e3621 to 4860a8a Compare June 26, 2025 05:19
@drisspg drisspg force-pushed the drisspg/stack/80 branch 3 times, most recently from 8339636 to 1b8dd95 Compare June 26, 2025 16:26
drisspg added a commit that referenced this pull request Jun 26, 2025
stack-info: PR: #2438, branch: drisspg/stack/80
@drisspg drisspg force-pushed the drisspg/stack/80 branch from 1b8dd95 to 9337cc6 Compare June 26, 2025 17:16
@drisspg drisspg merged commit 994a4ba into main Jun 26, 2025
19 checks passed