
Add BEVPool Kernel#37

Merged
jmanning-stackav merged 7 commits into main from feature/jmanning/bevpool on Jul 15, 2025

Conversation

@jmanning-stackav (Collaborator) commented Jul 8, 2025

Description

This PR adds BEVPool/QuickCumSum from BEVFusion to Conch.

We provide a pure PyTorch implementation, a Triton kernel, and an optionally-compiled CUDA kernel for both the forward and backward passes of this operation. There are also unit tests and microbenchmarks for both passes.
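As background, the BEVPool forward pass sums the features of all points that land in the same BEV grid cell, with points pre-sorted so that each cell's points form a contiguous "interval". A minimal pure-Python sketch of that reduction (function and argument names are illustrative only, not the actual Conch API):

```python
def bev_pool_forward(feats, interval_starts, interval_lengths, cell_ids,
                     num_cells, num_channels):
    """Sum point features within each interval into its output BEV cell.

    feats:            per-point feature vectors, pre-sorted by cell
    interval_starts:  start index of each run of points sharing a cell
    interval_lengths: number of points in each run
    cell_ids:         flattened output-cell index for each interval
    """
    out = [[0.0] * num_channels for _ in range(num_cells)]
    for start, length, cell in zip(interval_starts, interval_lengths, cell_ids):
        for p in range(start, start + length):
            for c in range(num_channels):
                out[cell][c] += feats[p][c]
    return out
```

The backward pass is the transpose of this scatter-sum: each point's gradient is just the gradient of the cell it was pooled into, which is why it parallelizes well over intervals.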

Testing

Please select all that apply.

  • Existing unit tests
  • Unit tests added by this PR
  • Other (please explain)
  • This PR is not tested

Test instructions

# Add CONCH_ENABLE_CUDA_EXT=1 to compare against CUDA instead of PyTorch
pytest tests/bev_pool_test.py

Platforms

Please select all hardware platforms that this PR was tested on.

  • Nvidia GPU
  • AMD GPU
  • Other (please explain)

Note: we are seeing correctness issues with the HIP-ified CUDA kernel on AMD MI300X. The test fails when the CUDA extension is enabled, but passes against the PyTorch reference implementation.

Benchmarks

Forward

CONCH_ENABLE_CUDA_EXT=1 python benchmarks/bev_pool_benchmark.py
# -Or-
python benchmarks/bev_pool_benchmark.py --compile-ref --cuda-ref

Backward

CONCH_ENABLE_CUDA_EXT=1 python benchmarks/bev_pool_backward_benchmark.py
# Could also run the line below, but the PyTorch/compiled versions are _super_ slow
# python benchmarks/bev_pool_backward_benchmark.py --compile-ref --cuda-ref

A10

Forward

# Omitted the PyTorch reference + compiled variants because they're slow (see H100/MI300X numbers below)
Number of intervals: 4790425
Min interval length: 1.0
Mean interval length: 1.2524983882904053
Max interval length: 7.0
Reference vs Conch: Results matched with atol=0.001 :)
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=755, min=12.668 ms, max=12.791 ms, mean=12.696 ms, median=12.697 ms
Baseline: num_iterations=636, min=15.142 ms, max=15.248 ms, mean=15.150 ms, median=15.149 ms

Backward

# Omitted the PyTorch reference + compiled variants because they're slow (see H100/MI300X numbers below)
Reference vs Conch: Results matched with atol=0.001 :)
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=1026, min=9.180 ms, max=9.213 ms, mean=9.195 ms, median=9.194 ms
Baseline: num_iterations=816, min=11.688 ms, max=11.703 ms, mean=11.694 ms, median=11.694 ms

H100

Forward

Number of intervals: 4789543
Min interval length: 1.0
Mean interval length: 1.2527291774749756
Max interval length: 8.0
Reference vs Conch: Results matched with atol=0.001 :)
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=3790, min=2.548 ms, max=2.566 ms, mean=2.551 ms, median=2.551 ms
Baseline: num_iterations=1, min=691031.875 ms, max=691031.875 ms, mean=691031.875 ms, median=691031.875 ms
Reference (Compiled): num_iterations=1, min=691136.375 ms, max=691136.375 ms, mean=691136.375 ms, median=691136.375 ms
CUDA: num_iterations=2876, min=3.384 ms, max=3.399 ms, mean=3.387 ms, median=3.387 ms

Backward

# Omitted the PyTorch reference + compiled variants because they're slow (see above)
Reference vs Conch: Results matched with atol=0.001 :)
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=6518, min=1.445 ms, max=1.458 ms, mean=1.448 ms, median=1.447 ms
Baseline: num_iterations=3240, min=2.991 ms, max=3.003 ms, mean=2.994 ms, median=2.994 ms

MI300X

Forward

Number of intervals: 4789796
Min interval length: 1.0
Mean interval length: 1.2526628971099854
Max interval length: 7.0
Reference vs Conch: Results matched with atol=0.001 :)
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=6740, min=1.410 ms, max=1.459 ms, mean=1.431 ms, median=1.431 ms
Baseline: num_iterations=1, min=860818.938 ms, max=860818.938 ms, mean=860818.938 ms, median=860818.938 ms
Reference (Compiled): num_iterations=1, min=875389.500 ms, max=875389.500 ms, mean=875389.500 ms, median=875389.500 ms
CUDA: num_iterations=14138, min=0.635 ms, max=0.686 ms, mean=0.648 ms, median=0.647 ms

Backward

# Omitted the PyTorch reference + compiled variants because they're slow (see above)
# Note: the HIP-ified CUDA impl appears to be giving incorrect answers...
WARNING: Reference and Conch results differ! (atol=0.001)
Output max diff: 6.037900924682617
Ref shape: torch.Size([6000000, 64]), Conch shape: torch.Size([6000000, 64])
Parameters: {'num_points': 6000000, 'num_channels': 64, 'batch_size': 1, 'grid_cells_z': 20, 'grid_cells_x': 800, 'grid_cells_y': 800}
Conch: num_iterations=8597, min=1.101 ms, max=1.139 ms, mean=1.116 ms, median=1.115 ms
Baseline: num_iterations=19586, min=0.298 ms, max=0.326 ms, mean=0.305 ms, median=0.304 ms

@jmanning-stackav jmanning-stackav self-assigned this Jul 8, 2025
Comment thread conch/kernels/vision/bev_pool.py
Comment thread benchmarks/bev_pool_benchmark.py
@jmanning-stackav jmanning-stackav changed the base branch from feature/jmanning/nms-v5 to main July 9, 2025 15:43
@jmanning-stackav jmanning-stackav force-pushed the feature/jmanning/bevpool branch from 918e9d6 to ce3c1d6 Compare July 9, 2025 15:46
The previous implementation used a grid of `(num_intervals,)`, where
each program would process all of the points in an interval, blockwise,
in parallel. This is optimal when there are many points per interval.
However, in some cases intervals contain only a few points, so it's
actually better to process the intervals blockwise, with each program
handling a block of intervals.
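The launch strategy described above can be illustrated with a small Python sketch that simulates how a 1-D grid of programs partitions the intervals into blocks (the block size and names here are assumptions for illustration, not the actual kernel code):

```python
def assign_intervals(num_intervals, block_size):
    """Map each simulated program id to the contiguous block of interval
    indices it owns. The old strategy is the special case block_size == 1
    with a grid of (num_intervals,); larger blocks amortize per-program
    overhead when each interval contains only a few points."""
    num_programs = -(-num_intervals // block_size)  # ceiling division
    return {
        pid: list(range(pid * block_size,
                        min((pid + 1) * block_size, num_intervals)))
        for pid in range(num_programs)
    }
```

For example, `assign_intervals(10, 4)` launches three simulated programs handling intervals 0-3, 4-7, and 8-9, instead of ten single-interval programs.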
@jmanning-stackav jmanning-stackav force-pushed the feature/jmanning/bevpool branch from 43da8c4 to 8a68608 Compare July 9, 2025 20:19
@jmanning-stackav jmanning-stackav merged commit e4ac60c into main Jul 15, 2025
@jmanning-stackav jmanning-stackav deleted the feature/jmanning/bevpool branch July 15, 2025 16:28