
Update mx_formats README.md #2777


Merged
merged 2 commits into main on Aug 15, 2025

Conversation

vkuzo
Contributor

@vkuzo vkuzo commented Aug 15, 2025

add e2e torchtitan benchmarks on LLaMa 3 8B for mxfp8 training


pytorch-bot bot commented Aug 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2777

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 6 Pending

As of commit 3f6623f with merge base 49cb18a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 15, 2025
@vkuzo vkuzo added the topic: documentation Use this tag if this PR adds or improves documentation label Aug 15, 2025

ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>

## Training e2e benchmarks on NVIDIA B200

- Single-node training on 8xB100 GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC
Contributor

Can you specify that this is a power-throttled variant?

Contributor Author

Should also say B200 instead of B100; will fix.

@vkuzo vkuzo merged commit d8bb51f into main Aug 15, 2025
7 of 9 checks passed
| Model | Scaling | Peak Memory (GB) | Median tokens/second | Speedup over baseline
| ------------- | ---------------------------------- | ------------------ | -------------------- | ---------------------
| Llama3-8b | none (bfloat16) | 33.71 | 8307.5 | -
| Llama3-8b | float8 tensorwise (f8 all-gather) | 33.38 | 10417.0 | 25.4%
| Llama3-8b | mxfp8_cublas | 33.88 | 9969.0 | 20.0%
| Llama3-8b | mxfp8_cublas_rceil | 33.88 | 9642.0 | 16.1%
Contributor

This is odd: rceil uses the hardware-accelerated fp32 -> e8m0 casting instruction, so it should be faster than floor, and when I did benchmarking it was faster than floor. Any idea what could be going on here?

Contributor Author

I think it only uses the accelerated instruction in the dim1 kernel. Once we make it also use that instruction in dim0, it should beat floor.
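For background on the floor-vs-rceil distinction being discussed (an illustration, not code from this PR): both modes snap a block's fp32 amax to a power-of-two e8m0 scale, and differ only in which direction the scale exponent is rounded. A minimal sketch, with `elem_emax` standing in as a free parameter for the element format's headroom (real e8m0 stores this exponent with a bias, which is omitted here):

```python
import math

def e8m0_scale_exponent(amax: float, elem_emax: int, mode: str) -> int:
    """Map a block's fp32 amax to an unbiased power-of-two scale exponent.

    The goal is that amax / 2**exponent fits the element format's range
    (2**elem_emax stands in for the format's max magnitude here).
    """
    raw = math.log2(amax) - elem_emax
    if mode == "floor":
        return math.floor(raw)  # truncating cast
    if mode == "rceil":
        return math.ceil(raw)   # round-up cast, like the hw fp32 -> e8m0 instruction
    raise ValueError(f"unknown mode: {mode}")
```

Note that rceil can only pick an equal or larger scale than floor, so it trades a little quantization resolution for guaranteed no-overflow on the cast; the throughput question in this thread is purely about which cast instruction each kernel uses.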

| Model | Scaling | Peak Memory (GB) | Median tokens/second | Speedup over baseline
| ------------- | ---------------------------------- | ------------------| -------------------- | ---------------------
| Llama3-8b | none (bfloat16) | 33.71 | 8307.5 | -
| Llama3-8b | float8 tensorwise (f8 all-gather) | 33.38 | 10417.0 | 25.4%
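As a sanity check, the speedup column follows directly from the median tokens/second figures; a quick sketch reproducing the table's percentages (the helper name is mine):

```python
def speedup_pct(tps: float, baseline_tps: float) -> float:
    """Percent speedup in median tokens/second over the bfloat16 baseline."""
    return round((tps / baseline_tps - 1.0) * 100, 1)

baseline = 8307.5  # Llama3-8b, none (bfloat16)
float8_tensorwise = speedup_pct(10417.0, baseline)  # 25.4
mxfp8_cublas = speedup_pct(9969.0, baseline)        # 20.0
mxfp8_cublas_rceil = speedup_pct(9642.0, baseline)  # 16.1
```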
Contributor

Ah, when I benchmarked fp8 tensorwise vs mxfp8 and found they were roughly the same throughput, I didn't use fp8 all-gather; I wonder if that explains this difference.


ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>

## Training e2e benchmarks on NVIDIA B200

- Single-node training on 8xB100 GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC
Contributor
@danielvegamyhre danielvegamyhre Aug 15, 2025

Are we not using torch.compile here? When I ran benchmarks during mxfp8 dim1 cast CUDA kernel development, I was getting ~13.5k TPS with FSDP2 only on 8xB200. I can double-check the exact config.
