1+ """
2+ Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
3+ """
4+
5+ import functools
6+ from typing import List , Union , Optional , Tuple
7+
8+ import torch
9+ from jax import lax
10+ import jax .numpy as jnp
11+ from . import ops_registry
12+
13+ _NMS_TILE_SIZE = 256
14+
15+
16+ def _bbox_overlap (boxes , gt_boxes ):
17+ """Find Bounding box overlap.
18+
19+ Args:
20+ boxes: first set of bounding boxes
21+ gt_boxes: second set of boxes to compute IOU
22+
23+ Returns:
24+ iou: Intersection over union matrix of all input bounding boxes
25+ """
26+ bb_y_min , bb_x_min , bb_y_max , bb_x_max = jnp .split (
27+ ary = boxes , indices_or_sections = 4 , axis = 2 )
28+ gt_y_min , gt_x_min , gt_y_max , gt_x_max = jnp .split (
29+ ary = gt_boxes , indices_or_sections = 4 , axis = 2 )
30+
31+ # Calculates the intersection area.
32+ i_xmin = jnp .maximum (bb_x_min , jnp .transpose (gt_x_min , [0 , 2 , 1 ]))
33+ i_xmax = jnp .minimum (bb_x_max , jnp .transpose (gt_x_max , [0 , 2 , 1 ]))
34+ i_ymin = jnp .maximum (bb_y_min , jnp .transpose (gt_y_min , [0 , 2 , 1 ]))
35+ i_ymax = jnp .minimum (bb_y_max , jnp .transpose (gt_y_max , [0 , 2 , 1 ]))
36+ i_area = jnp .maximum ((i_xmax - i_xmin ), 0 ) * jnp .maximum ((i_ymax - i_ymin ), 0 )
37+
38+ # Calculates the union area.
39+ bb_area = (bb_y_max - bb_y_min ) * (bb_x_max - bb_x_min )
40+ gt_area = (gt_y_max - gt_y_min ) * (gt_x_max - gt_x_min )
41+ # Adds a small epsilon to avoid divide-by-zero.
42+ u_area = bb_area + jnp .transpose (gt_area , [0 , 2 , 1 ]) - i_area + 1e-8
43+
44+ # Calculates IoU.
45+ iou = i_area / u_area
46+
47+ return iou
48+
49+
50+ def _self_suppression (in_args ):
51+ iou , _ , iou_sum = in_args
52+ batch_size = iou .shape [0 ]
53+ can_suppress_others = jnp .reshape (
54+ jnp .max (iou , 1 ) <= 0.5 , [batch_size , - 1 , 1 ]).astype (iou .dtype )
55+ iou_suppressed = jnp .reshape (
56+ (jnp .max (can_suppress_others * iou , 1 ) <= 0.5 ).astype (iou .dtype ),
57+ [batch_size , - 1 , 1 ]) * iou
58+ iou_sum_new = jnp .sum (iou_suppressed , [1 , 2 ])
59+ return iou_suppressed , jnp .any (iou_sum - iou_sum_new > 0.5 ), iou_sum_new
60+
61+
62+ def _cross_suppression (in_args ):
63+ boxes , box_slice , iou_threshold , inner_idx = in_args
64+ batch_size = boxes .shape [0 ]
65+ new_slice = lax .dynamic_slice (boxes , [0 , inner_idx * _NMS_TILE_SIZE , 0 ],
66+ [batch_size , _NMS_TILE_SIZE , 4 ])
67+ iou = _bbox_overlap (new_slice , box_slice )
68+ ret_slice = jnp .expand_dims (
69+ (jnp .all (iou < iou_threshold , [1 ])).astype (box_slice .dtype ),
70+ 2 ) * box_slice
71+ return boxes , ret_slice , iou_threshold , inner_idx + 1
72+
73+
74+ def _suppression_loop_body (in_args ):
75+ """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).
76+
77+ Args:
78+ in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx
79+
80+ Returns:
81+ boxes: updated boxes.
82+ iou_threshold: pass down iou_threshold to the next iteration.
83+ output_size: the updated output_size.
84+ idx: the updated induction variable.
85+ """
86+ boxes , iou_threshold , output_size , idx = in_args
87+ num_tiles = boxes .shape [1 ] // _NMS_TILE_SIZE
88+ batch_size = boxes .shape [0 ]
89+
90+ # Iterates over tiles that can possibly suppress the current tile.
91+ box_slice = lax .dynamic_slice (boxes , [0 , idx * _NMS_TILE_SIZE , 0 ],
92+ [batch_size , _NMS_TILE_SIZE , 4 ])
93+ def _loop_cond (in_args ):
94+ _ , _ , _ , inner_idx = in_args
95+ return inner_idx < idx
96+
97+ _ , box_slice , _ , _ = lax .while_loop (
98+ _loop_cond ,
99+ _cross_suppression , (boxes , box_slice , iou_threshold ,
100+ 0 ))
101+
102+ # Iterates over the current tile to compute self-suppression.
103+ iou = _bbox_overlap (box_slice , box_slice )
104+ mask = jnp .expand_dims (
105+ jnp .reshape (jnp .arange (_NMS_TILE_SIZE ), [1 , - 1 ]) > jnp .reshape (
106+ jnp .arange (_NMS_TILE_SIZE ), [- 1 , 1 ]), 0 )
107+ iou *= (jnp .logical_and (mask , iou >= iou_threshold )).astype (iou .dtype )
108+
109+ def _loop_cond2 (in_args ):
110+ _ , loop_condition , _ = in_args
111+ return loop_condition
112+
113+ suppressed_iou , _ , _ = lax .while_loop (
114+ _loop_cond2 , _self_suppression ,
115+ (iou , True ,
116+ jnp .sum (iou , [1 , 2 ])))
117+ suppressed_box = jnp .sum (suppressed_iou , 1 ) > 0
118+ box_slice *= jnp .expand_dims (1.0 - suppressed_box .astype (box_slice .dtype ), 2 )
119+
120+ # Uses box_slice to update the input boxes.
121+ mask = jnp .reshape (
122+ (jnp .equal (jnp .arange (num_tiles ), idx )).astype (boxes .dtype ),
123+ [1 , - 1 , 1 , 1 ])
124+ boxes = jnp .tile (jnp .expand_dims (
125+ box_slice , 1 ), [1 , num_tiles , 1 , 1 ]) * mask + jnp .reshape (
126+ boxes , [batch_size , num_tiles , _NMS_TILE_SIZE , 4 ]) * (1 - mask )
127+ boxes = jnp .reshape (boxes , [batch_size , - 1 , 4 ])
128+
129+ # Updates output_size.
130+ output_size += jnp .sum (
131+ jnp .any (box_slice > 0 , [2 ]).astype (jnp .int32 ), [1 ])
132+ return boxes , iou_threshold , output_size , idx + 1
133+
134+
135+ def non_max_suppression_padded (scores , boxes , max_output_size , iou_threshold ):
136+ """A wrapper that handles non-maximum suppression.
137+
138+ Assumption:
139+ * The boxes are sorted by scores unless the box is a dot (all coordinates
140+ are zero).
141+ * Boxes with higher scores can be used to suppress boxes with lower scores.
142+
143+ The overal design of the algorithm is to handle boxes tile-by-tile:
144+
145+ boxes = boxes.pad_to_multiply_of(tile_size)
146+ num_tiles = len(boxes) // tile_size
147+ output_boxes = []
148+ for i in range(num_tiles):
149+ box_tile = boxes[i*tile_size : (i+1)*tile_size]
150+ for j in range(i - 1):
151+ suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
152+ iou = _bbox_overlap(box_tile, suppressing_tile)
153+ # if the box is suppressed in iou, clear it to a dot
154+ box_tile *= _update_boxes(iou)
155+ # Iteratively handle the diagnal tile.
156+ iou = _box_overlap(box_tile, box_tile)
157+ iou_changed = True
158+ while iou_changed:
159+ # boxes that are not suppressed by anything else
160+ suppressing_boxes = _get_suppressing_boxes(iou)
161+ # boxes that are suppressed by suppressing_boxes
162+ suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
163+ # clear iou to 0 for boxes that are suppressed, as they cannot be used
164+ # to suppress other boxes any more
165+ new_iou = _clear_iou(iou, suppressed_boxes)
166+ iou_changed = (new_iou != iou)
167+ iou = new_iou
168+ # remaining boxes that can still suppress others, are selected boxes.
169+ output_boxes.append(_get_suppressing_boxes(iou))
170+ if len(output_boxes) >= max_output_size:
171+ break
172+
173+ Args:
174+ scores: a tensor with a shape of [batch_size, anchors].
175+ boxes: a tensor with a shape of [batch_size, anchors, 4].
176+ max_output_size: a scalar integer `Tensor` representing the maximum number
177+ of boxes to be selected by non max suppression.
178+ iou_threshold: a float representing the threshold for deciding whether boxes
179+ overlap too much with respect to IOU.
180+ Returns:
181+ nms_scores: a tensor with a shape of [batch_size, anchors]. It has same
182+ dtype as input scores.
183+ nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has
184+ same dtype as input boxes.
185+ """
186+ batch_size = boxes .shape [0 ]
187+ num_boxes = boxes .shape [1 ]
188+ pad = int (jnp .ceil (float (num_boxes ) / _NMS_TILE_SIZE )
189+ ) * _NMS_TILE_SIZE - num_boxes
190+ boxes = jnp .pad (boxes .astype (jnp .float32 ), [[0 , 0 ], [0 , pad ], [0 , 0 ]])
191+ scores = jnp .pad (scores .astype (jnp .float32 ), [[0 , 0 ], [0 , pad ]])
192+ num_boxes += pad
193+
194+ def _loop_cond (in_args ):
195+ unused_boxes , unused_threshold , output_size , idx = in_args
196+ return jnp .logical_and (
197+ jnp .min (output_size ) < max_output_size ,
198+ idx < num_boxes // _NMS_TILE_SIZE )
199+
200+ selected_boxes , _ , output_size , _ = lax .while_loop (
201+ _loop_cond , _suppression_loop_body , (
202+ boxes , iou_threshold ,
203+ jnp .zeros ([batch_size ], jnp .int32 ),
204+ 0
205+ ))
206+ idx = num_boxes - lax .top_k (
207+ jnp .any (selected_boxes > 0 , [2 ]).astype (jnp .int32 ) *
208+ jnp .expand_dims (jnp .arange (num_boxes , 0 , - 1 ), 0 ),
209+ max_output_size )[0 ].astype (jnp .int32 )
210+ idx = jnp .minimum (idx , num_boxes - 1 )
211+ idx = jnp .reshape (
212+ idx + jnp .reshape (jnp .arange (batch_size ) * num_boxes , [- 1 , 1 ]), [- 1 ])
213+
214+ return idx
215+ boxes = jnp .reshape (
216+ (jnp .reshape (boxes , [- 1 , 4 ]))[idx ],
217+ [batch_size , max_output_size , 4 ])
218+ boxes = boxes * (
219+ jnp .reshape (jnp .arange (max_output_size ), [1 , - 1 , 1 ]) < jnp .reshape (
220+ output_size , [- 1 , 1 , 1 ])).astype (boxes .dtype )
221+ scores = jnp .reshape (
222+ jnp .reshape (scores , [- 1 , 1 ])[idx ],
223+ [batch_size , max_output_size ])
224+ scores = scores * (
225+ jnp .reshape (jnp .arange (max_output_size ), [1 , - 1 ]) < jnp .reshape (
226+ output_size , [- 1 , 1 ])).astype (scores .dtype )
227+ return scores , boxes
228+
229+
230+ # registry:
231+
232+ def nms (boxes , scores , iou_threshold ):
233+ max_output_size = boxes .shape [0 ]
234+ boxes = boxes .reshape ((1 , * boxes .shape ))
235+ scores = scores .reshape ((1 , * scores .shape ))
236+ res = non_max_suppression_padded (scores , boxes , max_output_size , iou_threshold )
237+ return res
238+
239+
240+ try :
241+ import torch
242+ import torchvision
243+ ops_registry .register_torch_dispatch_op (torch .ops .torchvision .nms , nms )
244+ except ImportError :
245+ pass
0 commit comments