Add BEVPool Kernel#37
Merged
Merged
Conversation
lwang-stackav approved these changes on Jul 9, 2025
The previous implementation used a grid of `(num_intervals,)`, where each program processed all of the points in its interval, blockwise, in parallel. This is optimal when intervals contain many points. However, in some workloads intervals are short, so it's actually better to process the intervals themselves blockwise, with each program handling a block of intervals.
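The new launch strategy can be sketched in plain Python. This is a minimal illustration, not the actual Triton kernel: the outer loop stands in for a launch grid of `ceil(num_intervals / block_size)` programs, and the names (`pool_intervals_blocked`, `block_size`) are hypothetical.

```python
from math import ceil

def pool_intervals_blocked(feats, starts, lengths, block_size=128):
    """Sum point features per interval, one block of intervals per 'program'.

    feats:   list of per-point feature rows, sorted so that each interval's
             points are contiguous.
    starts:  first point index of each interval.
    lengths: number of points in each interval.
    """
    num_intervals = len(starts)
    out = [None] * num_intervals
    # Each iteration of this loop stands in for one GPU program in a
    # grid of ceil(num_intervals / block_size).
    for pid in range(ceil(num_intervals / block_size)):
        lo = pid * block_size
        for i in range(lo, min(lo + block_size, num_intervals)):
            s, n = starts[i], lengths[i]
            # Sum the point features belonging to interval i, per channel.
            out[i] = [sum(col) for col in zip(*feats[s:s + n])]
    return out
```

With short intervals (mean length ~1.25 in the benchmarks below), a per-interval grid leaves most of each program's block idle; grouping several intervals per program keeps the hardware busy instead.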
Description
This PR adds BEVPool/QuickCumSum from BEVFusion to Conch.
We provide a pure PyTorch implementation, a Triton kernel, and an optionally-compiled CUDA kernel for both the forward and backward pass of this operation, along with a unit test and a microbenchmark for each pass.
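At its core, BEVPool sums the features of all points that fall into the same BEV grid cell (an "interval" of consecutive points after sorting), and the backward pass broadcasts each cell's upstream gradient back to its points. A minimal sketch of both passes, with hypothetical names and pure-Python lists standing in for tensors:

```python
def bev_pool_fwd(feats, starts, lengths):
    # Forward: each output row is the per-channel sum of the point
    # features in its interval.
    return [[sum(col) for col in zip(*feats[s:s + n])]
            for s, n in zip(starts, lengths)]

def bev_pool_bwd(grad_out, starts, lengths, num_points, num_channels):
    # Backward: the gradient of a sum broadcasts the upstream gradient
    # to every point that contributed to the interval.
    grad = [[0.0] * num_channels for _ in range(num_points)]
    for g, s, n in zip(grad_out, starts, lengths):
        for p in range(s, s + n):
            grad[p] = list(g)
    return grad
```

The real implementations operate on sorted point tensors with precomputed interval starts/lengths; this sketch only shows the data dependence that the kernels parallelize.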
Testing
Test instructions
# Add CONCH_ENABLE_CUDA_EXT=1 to compare against CUDA instead of PyTorch
pytest tests/bev_pool_test.py
Platforms
Note: we are seeing correctness issues with the HIP-ified CUDA kernel on AMD MI300X. The test fails with the CUDA extension enabled but passes against the PyTorch reference implementation.
Benchmarks
Forward
CONCH_ENABLE_CUDA_EXT=1 python benchmarks/bev_pool_benchmark.py
# -Or-
python benchmarks/bev_pool_benchmark.py --compile-ref --cuda-ref
Backward
A10
Forward
Backward
H100
Forward
Number of intervals: 4789543
Min interval length: 1.0
Mean interval length: 1.2527291774749756
Max interval length: 8.0
Reference vs Conch: Results matched with atol=0.001 :)
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=3790, min=2.548 ms, max=2.566 ms, mean=2.551 ms, median=2.551 ms
Baseline: num_iterations=1, min=691031.875 ms, max=691031.875 ms, mean=691031.875 ms, median=691031.875 ms
Reference (Compiled): num_iterations=1, min=691136.375 ms, max=691136.375 ms, mean=691136.375 ms, median=691136.375 ms
CUDA: num_iterations=2876, min=3.384 ms, max=3.399 ms, mean=3.387 ms, median=3.387 ms
Backward
MI300X
Forward
Number of intervals: 4789796
Min interval length: 1.0
Mean interval length: 1.2526628971099854
Max interval length: 7.0
Reference vs Conch: Results matched with atol=0.001 :)
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=6740, min=1.410 ms, max=1.459 ms, mean=1.431 ms, median=1.431 ms
Baseline: num_iterations=1, min=860818.938 ms, max=860818.938 ms, mean=860818.938 ms, median=860818.938 ms
Reference (Compiled): num_iterations=1, min=875389.500 ms, max=875389.500 ms, mean=875389.500 ms, median=875389.500 ms
CUDA: num_iterations=14138, min=0.635 ms, max=0.686 ms, mean=0.648 ms, median=0.647 ms
Backward