33
44
55@triton .jit
6- def triton_nms_IoU_kernel (boxes , output_ptr , threshold , num_boxes , BLOCK_SIZE : tl .constexpr ):
6+ def _combine_bits (val0 , val1 ):
7+ tl .static_assert (val0 .dtype == tl .int32 , "input must be int32" )
8+ tl .static_assert (val1 .dtype == tl .int32 , "input must be int32" )
9+ return val0 | val1
10+
11+
12+ def triton_nms_IoU_kernel (boxes , output_ptr , threshold , num_boxes , stride_i , stride_j , BLOCK_SIZE : tl .constexpr ):
713 """
814 This nms_kernel computes the supressed mask of boxes [i, j].
915 mask[i, j]==1 means if we choose box 1, the box j will be supressed.
@@ -14,6 +20,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
1420 output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored.
1521 threshold (float): The IoU threshold for suppressing boxes.
1622 num_boxes (int): The total number of boxes.
23+ stride_i (int): The stride of the output tensor along the first dimension.
24+ stride_j (int): The stride of the output tensor along the second dimension.
1725 BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel.
1826 """
1927
@@ -59,14 +67,23 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
5967 area_b = (col_block_x2 - col_block_x1 ) * (col_block_y2 - col_block_y1 )
6068 union = area_a + area_b - intersection
6169
62- iou_keep_out_mask = ((intersection / union ) > threshold ).to (tl .int8 )
70+ iou_keep_out_bit_mask = ((intersection / union ) > threshold ).to (tl .int32 )
71+
72+ shift_offsets = tl .arange (0 , BLOCK_SIZE ) % 32
73+ shift_offsets = tl .flip (shift_offsets , 0 )[None , :]
74+ shift_offsets = tl .broadcast_to (shift_offsets .to (tl .int32 ), [BLOCK_SIZE , BLOCK_SIZE ])
75+ iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets
76+
77+ iou_keep_out_bit_mask = tl .reshape (iou_keep_out_bit_mask , (BLOCK_SIZE , (BLOCK_SIZE + 32 - 1 ) // 32 , 32 ))
78+ iou_keep_out_combined = tl .reduce (iou_keep_out_bit_mask , axis = 2 , combine_fn = _combine_bits )
6379
80+ iou_keep_out_combined = iou_keep_out_combined .to (tl .int64 )
6481 output_block_ptr = tl .make_block_ptr (
6582 output_ptr ,
66- shape = (num_boxes , num_boxes ),
67- strides = (num_boxes , 1 ),
68- offsets = (row_block_start , col_block_start ),
69- block_shape = (BLOCK_SIZE , BLOCK_SIZE ),
83+ shape = (num_boxes , ( num_boxes + 32 - 1 ) // 32 ),
84+ strides = (stride_i , stride_j ),
85+ offsets = (row_block_start , 0 ),
86+ block_shape = (BLOCK_SIZE , ( BLOCK_SIZE + 32 - 1 ) // 32 ),
7087 order = (0 , 1 ),
7188 )
72- tl .store (output_block_ptr , iou_keep_out_mask , boundary_check = (0 , 1 ))
89+ tl .store (output_block_ptr , iou_keep_out_combined , boundary_check = (0 , 1 ))
0 commit comments