Store NVFP4 block scales in swizzled layout on tensor #2438
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2438. Note: links to docs will display an error until the docs builds have been completed. ❗ There is 1 currently active SEV; if your PR is affected, please view it below. ✅ No failures as of commit 9337cc6 with merge base faf788a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@vkuzo I plan to refactor the test from test_mx_tensor into a different file in the same folder after these two land, to avoid rebase conflicts
Indeed tensor cores have become so fast that everything else is a bottleneck 😄
@drisspg I'm curious about the profile trace you show here. Is it for inference or training? I'm guessing it's for inference, since from my understanding, training wouldn't benefit from this, at least without fused quantize+scale_swizzle
Yup this is for inference:

```python
import os
from typing import Literal

import torch
from rich import print
from jsonargparse import CLI

from transformer_nuggets.misc.mlp import FeedForward
from transformer_nuggets.utils.benchmark import (
    profiler,
    benchmark_do_bench_in_microseconds,
    benchmark_cuda_function_in_microseconds,
)
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.mx_subclass import (
    MXFPInferenceConfig,
    NVFP4InferenceConfig,
)
from torchao.prototype.mx_formats import MXGemmKernelChoice

torch._logging.set_logs(fusion=True)
torch._logging.set_logs(graph=True)
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"


def main(
    num_tokens: int = 1024,
    flavor: str = "8B",
    mode: str = "mxfp8",
    bf16: bool = False,
    setting: Literal["max-autotune", "inductor"] = "max-autotune",
    backend: Literal["inductor", "eager"] = "inductor",
    use_triton: bool = False,
):
    device = "cuda"
    dtype = torch.bfloat16
    assert flavor in ("8B", "70B", "405B")
    assert mode in ("mxfp8", "mxfp4", "nvfp4")

    if bf16:
        # Optional bf16 baseline
        model = FeedForward.llama3_mlp(flavor).to(device=device, dtype=dtype)
        inpt = model.get_input(num_tokens, device, dtype)
        model = torch.compile(
            model, fullgraph=True, dynamic=False, mode=setting, backend=backend
        )
        # Warmup
        for _ in range(5):
            model(inpt)
        with profiler("data/bf16", with_stack=True) as p:
            for _ in range(5):
                model(inpt)
                p.step()
        runtime_us = benchmark_cuda_function_in_microseconds(lambda: model(inpt))
        print(
            f"[bold green]Bf16 Runtime:[/bold green] {runtime_us:.2f} μs per iteration"
        )

    model = FeedForward.llama3_mlp(flavor).to(device=device, dtype=dtype)
    inpt = model.get_input(num_tokens, device, dtype)

    match mode:
        case "mxfp8":
            elem_dtype = torch.float8_e4m3fn
            choice = MXGemmKernelChoice.CUBLAS
            config = MXFPInferenceConfig(
                activation_dtype=elem_dtype,
                weight_dtype=elem_dtype,
                block_size=32,
                gemm_kernel_choice=choice,
            )
        case "mxfp4":
            elem_dtype = torch.float4_e2m1fn_x2
            choice = MXGemmKernelChoice.CUTLASS
            config = MXFPInferenceConfig(
                activation_dtype=elem_dtype,
                weight_dtype=elem_dtype,
                block_size=32,
                gemm_kernel_choice=choice,
            )
        case "nvfp4":
            config = NVFP4InferenceConfig(use_triton_kernel=use_triton)
        case _:
            raise ValueError(f"Unknown mode {mode}")

    quantize_(model, config)
    model = torch.compile(
        model, fullgraph=True, dynamic=False, mode="max-autotune", backend=backend
    )
    # Warmup
    for _ in range(5):
        model(inpt)

    # Use profiler for detailed profiling
    with profiler(f"data/mx_fp_8_subclass_{mode}", with_stack=True) as p:
        for _ in range(5):
            model(inpt)
            p.step()

    # Use benchmark function for pretty printed runtime
    runtime_us = benchmark_do_bench_in_microseconds(lambda: model(inpt))
    print(f"[bold green]{mode} Runtime:[/bold green] {runtime_us:.2f} μs per iteration")


if __name__ == "__main__":
    CLI(main)
```
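For reference, since the script exposes `main` through jsonargparse's `CLI`, the function parameters become command-line flags, so a run matching the nvfp4 70B numbers below would look roughly like `python bench_mlp.py --flavor 70B --num_tokens 1024 --mode nvfp4` (the script filename here is hypothetical).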
Stacked PRs:
Store NVFP4 block scales in swizzled layout on tensor
For Llama3 70B sizes (no TP) with 1024 tokens: 15% E2E speedup
In eager:
Before (https://fburl.com/7w3j6b1q): nvfp4 Runtime: 2436.98 μs per iteration
After (https://fburl.com/s7ggvm94): nvfp4 Runtime: 2356.77 μs per iteration

In compile:
Before (https://fburl.com/1gvfjjlu): nvfp4 Runtime: 576.14 μs per iteration
After (https://fburl.com/usp1xelj): nvfp4 Runtime: 486.69 μs per iteration