Roofline quantized conv3d/2d layer #3419

Merged: jainapurva merged 17 commits into main from conv_roofline on Dec 24, 2025
Conversation

@jainapurva (Contributor) commented Dec 3, 2025

This pull request extends the float8 inference roofline benchmarking code to support convolution operations (conv2d and conv3d) in addition to linear layers. It introduces new utilities and refactors the workflow to enable roofline modeling and kernel benchmarking for convolutions, including calculation of equivalent GEMM dimensions and measurement of kernel times. The FBGEMM conv kernel is a combination of im2col and an implicit GEMM.
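Under the im2col view above, the equivalent GEMM dimensions follow directly from the conv shapes. A minimal sketch (function and argument names are illustrative, not taken from this PR):

```python
import math

def conv_to_gemm_dims(batch, c_in, c_out, out_spatial, kernel_size):
    """Equivalent GEMM dimensions for a convolution via im2col.

    out_spatial and kernel_size are tuples of spatial extents, e.g.
    (D_out, H_out, W_out) and (kD, kH, kW) for conv3d.
    """
    M = batch * math.prod(out_spatial)   # one GEMM row per output location
    K = c_in * math.prod(kernel_size)    # unrolled input patch length
    N = c_out                            # one GEMM column per output channel
    return M, K, N

# conv3d example: batch 1, 16 -> 32 channels, 8x8x8 output, 3x3x3 kernel
print(conv_to_gemm_dims(1, 16, 32, (8, 8, 8), (3, 3, 3)))  # (512, 432, 32)
```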

pytorch-bot bot commented Dec 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3419

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5f3abb1 with merge base 095a7e6:

BROKEN TRUNK - The following job failed but was also failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 3, 2025
@jainapurva jainapurva marked this pull request as ready for review December 10, 2025 05:57
# Filter out aten::fill_ and other non-conv operations
filtered_data = {k: v for k, v in data.items() if k in expected_conv_kernels}

assert len(filtered_data) >= 1, f"unexpected data: {data}"
maybe the error message should indicate something about potential incompleteness of the above expected conv kernel list?
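Following this suggestion, the assertion message could point at the list itself. A hedged sketch (the kernel names here are a partial, hypothetical subset of the real list in the PR):

```python
# Hypothetical subset; the authoritative list lives in the PR's utils module.
expected_conv_kernels = {
    "aten::slow_conv_dilated3d",
    "fbgemm::f8f8bf16_conv",
}

def check_profiled_kernels(data):
    # Filter out aten::fill_ and other non-conv operations
    filtered_data = {k: v for k, v in data.items() if k in expected_conv_kernels}
    assert len(filtered_data) >= 1, (
        f"no expected conv kernels found in profiler data: {data}; "
        "note: the expected_conv_kernels list may be incomplete for new backends"
    )
    return filtered_data

print(check_profiled_kernels({"fbgemm::f8f8bf16_conv": 1.5, "aten::fill_": 0.1}))
```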


@jbschlosser jbschlosser left a comment


looks pretty good from what I can tell. Thanks for the hard work here!

"aten::slow_conv_dilated3d",
"torchao::_conv2d_fp8_inner",
"torchao::_conv3d_fp8_inner",
"fbgemm::f8f8bf16_conv",

we just updated this to mslk btw


@jerryzh168 jerryzh168 left a comment


rest looks good

)

# For fp8 conv timing, we need to use fbgemm operator
if recipe_name in ("mxfp4_cutlass", "nvfp4"):

these are not supported, so we don't need to check for them I think; we can add a check that says what the supported recipe_names are
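The commit messages below note that only tensorwise fp8 is supported for conv, so the check suggested here could look like the following sketch (constant and function names are illustrative):

```python
# Assumption from this thread: only tensorwise fp8 is supported for conv.
SUPPORTED_CONV_RECIPES = ("tensorwise",)

def validate_conv_recipe(recipe_name):
    """Fail early with the list of supported recipes instead of
    special-casing unsupported ones like mxfp4_cutlass or nvfp4."""
    if recipe_name not in SUPPORTED_CONV_RECIPES:
        raise NotImplementedError(
            f"recipe {recipe_name!r} is not supported for conv roofline; "
            f"supported recipes: {SUPPORTED_CONV_RECIPES}"
        )
```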

# Try to use fbgemm fp8 conv operator
try:
# Check if fbgemm fp8 conv is available
if not hasattr(torch.ops.fbgemm, "f8f8bf16_conv"):

should use […] to check I think

print(
"Warning: fbgemm.f8f8bf16_conv not available, skipping fp8 conv timing"
)
f8_time_s = 0.0

@jerryzh168 jerryzh168 Dec 23, 2025


nit: I think we can just error out, and talk about how user can turn off the conv benchmarking by setting do_benchmarks to False?

this will also reduce the indentation here
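The fail-fast version suggested here could be sketched as follows (helper name and signature are hypothetical, not from the PR):

```python
def require_fp8_conv_op(op_available, do_benchmarks=True):
    """Raise instead of printing a warning and recording a 0.0 time.

    Erroring out surfaces the missing operator immediately, and the caller
    can opt out of conv benchmarking entirely with do_benchmarks=False.
    """
    if do_benchmarks and not op_available:
        raise RuntimeError(
            "fp8 conv operator is not available; set do_benchmarks=False "
            "to skip conv kernel benchmarking"
        )
```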

)
f8_time_s = 0.0

except Exception as e:

just raise the exception

- Replace hasattr check with _is_mslk_available() utility
- Add recipe validation for conv operations (only tensorwise supported)
- Add early error checks with helpful messages for conv2d and mslk availability
- Remove redundant exception handling in get_conv_times()
- Improve defense in depth with validation at multiple levels
- Re-raise validation errors (NotImplementedError, RuntimeError, ValueError) to fail fast
- Remove unused kernel names from utils.py (torchao::_conv2d_fp8_inner, torchao::_conv3d_fp8_inner)
This commit fixes multiple issues in the conv3d fp8 benchmarking code
to align with the updated mslk operator (from commit 095a7e6):

1. Remove outdated permute operations
   - The mslk operator was updated to support standard PyTorch tensor
     shapes with channels_last_3d memory format
   - Permute operations are no longer needed and were causing errors
   - Tensors now stay in shape (N, C_in, D, H, W) with channels_last_3d
     memory format, matching Float8Tensor implementation

2. Add mslk.conv import
   - Import mslk.conv module to properly register the fp8 conv operator
   - This ensures the operator is available for benchmarking

3. Add kernel_size=1 validation
   - kernel_size=1 creates ambiguous memory layouts (both contiguous
     and channels_last_3d simultaneously)
   - The mslk operator cannot correctly identify channel dimensions
     in this edge case
   - Added validation to reject kernel_size=1 with clear error message
   - Consistent with test suite constraints

4. Remove redundant try-except wrapper
   - Simplified error handling by removing unnecessary exception
     catching around get_conv_times() call
   - Validation errors now propagate directly with clear messages

Test Plan:
- Verified kernel_size=3 configurations run successfully
- Verified kernel_size=1 configurations fail with clear error message
- Aligned with Float8Tensor test suite behavior
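The kernel_size=1 validation described in point 3 could be sketched as follows (function name and signature are illustrative, not from the PR):

```python
def validate_conv3d_kernel_size(kernel_size):
    """Reject kernel_size=1, whose weights have an ambiguous memory layout.

    A (C_out, C_in, 1, 1, 1) weight tensor is simultaneously contiguous and
    channels_last_3d, so the operator cannot infer the channel dimension.
    """
    ks = (kernel_size,) * 3 if isinstance(kernel_size, int) else tuple(kernel_size)
    if all(k == 1 for k in ks):
        raise ValueError(
            "kernel_size=1 is ambiguous between contiguous and "
            "channels_last_3d memory formats; use kernel_size >= 2"
        )
```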
@jainapurva jainapurva added the topic: performance Use this tag if this PR improves the performance of a feature label Dec 24, 2025
Replace the epsilon-based division (1e-20) with explicit conditional
logic for calculating b_fp8_e2e_speedup:

- Returns -1 when benchmarks weren't run (clearer sentinel value)
- Only calculates speedup when both bf16 and fp8 times are valid (> 0)
- More explicit and consistent with other ratio calculations like
  rb_bf16_gemm_ratio and rb_fp8_gemm_ratio

This makes the code clearer and avoids calculating meaningless huge
numbers when b_fp8_e2e_time_s is 0.
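The conditional logic described above amounts to the following (function and argument names are illustrative):

```python
def fp8_e2e_speedup(bf16_time_s, fp8_time_s):
    """Speedup with an explicit sentinel instead of epsilon-based division.

    Returns -1 when benchmarks were not run (either time is not positive),
    rather than dividing by (time + 1e-20) and producing a huge number.
    """
    if bf16_time_s > 0 and fp8_time_s > 0:
        return bf16_time_s / fp8_time_s
    return -1

print(fp8_e2e_speedup(2.0, 1.0))  # 2.0
print(fp8_e2e_speedup(0.0, 0.0))  # -1
```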
@jainapurva jainapurva merged commit 0fd0872 into main Dec 24, 2025
20 of 21 checks passed