
Commit 39dc4bc

Add impl for torchvision.nms and support multiple batch dim for convolution. (#8291)
1 parent 7fd4e1c commit 39dc4bc

7 files changed: +314 -15 lines

experimental/torch_xla2/test/test_ops.py

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ def test_reference_eager(self, device, dtype, op):
                       ignore_indices=ignore_index)


-instantiate_device_type_tests(TestOpInfo, globals())
+instantiate_device_type_tests(TestOpInfo, globals(), only_for='cpu')

 if __name__ == '__main__':
   unittest.main()
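Note: only_for='cpu' stops instantiate_device_type_tests from also generating per-device variants for other backends (e.g. CUDA), which matters now that a 'jax' device module is registered below; presumably only the CPU reference comparison is wanted here.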

experimental/torch_xla2/torch_xla2/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -81,7 +81,10 @@ def disable_globally():
     unsupported_dtype=unsupported_dtype)

 import jax
-torch._register_device_module('jax', jax)
+import torch_xla2.device_module
+torch._register_device_module('jax', torch_xla2.device_module)
+
+


 def enable_accuracy_mode():
experimental/torch_xla2/torch_xla2/device_module.py (new file)

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+def _is_in_bad_fork():
+  return False
+
+def manual_seed_all(seed):
+  pass
+
+def device_count():
+  return 1
+
+def get_rng_state():
+  return []
+
+def set_rng_state(new_state, device):
+  pass
+
+def is_available():
+  return True
+
+def current_device():
+  return 0

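The stub above replaces the raw jax module that was previously handed to torch._register_device_module: torch expects a device module to expose hooks such as _is_in_bad_fork, manual_seed_all, and get_rng_state/set_rng_state, which the jax package does not provide. A minimal sketch of what the registration enables (assuming importing torch_xla2 performs the registration, as in __init__.py above):

import torch
import torch_xla2  # __init__.py registers the stub under the 'jax' device key

# torch._register_device_module attaches the module as an attribute of torch,
# so the stub's protocol functions become reachable as torch.jax.*:
print(torch.jax.is_available())    # True
print(torch.jax.device_count())    # 1
print(torch.jax.current_device())  # 0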
experimental/torch_xla2/torch_xla2/ops/jaten.py

Lines changed: 25 additions & 2 deletions
@@ -904,6 +904,22 @@ def _aten_bucketize(input, boundaries, *, out_int32=False, right=False, out=None
   return_type = jnp.int32 if out_int32 else jnp.int64
   return jnp.digitize(input, boundaries, right=not right).astype(return_type)

+
+@op(torch.ops.aten.conv2d)
+def _aten_conv2d(
+    input,
+    weight,
+    bias,
+    stride,
+    padding,
+    dilation,
+    groups,
+):
+  return _aten_convolution(
+      input, weight, bias, stride, padding,
+      dilation, transposed=False,
+      output_padding=1, groups=groups)
+
 @op(torch.ops.aten.convolution)
 def _aten_convolution(
     input,
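Note: the new aten.conv2d lowering just delegates to the existing _aten_convolution path. output_padding is only meaningful for transposed convolutions, and _aten_convolution raises NotImplementedError for transposed=True, so the output_padding=1 passed here appears to be inert.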
@@ -919,6 +935,11 @@ def _aten_convolution(
   if transposed:
     raise NotImplementedError("Transposed convolution is not implemented.")

+  num_shape_dim = weight.ndim - 1
+  batch_dims = input.shape[:-num_shape_dim]
+
+  input = input.reshape((-1, *input.shape[-num_shape_dim:]))
+
   def make_padding(padding, num_spatial_dims):
     # Expand single padding to pairs expected by jax
     if len(padding) == 1 and len(padding) < num_spatial_dims:
@@ -960,6 +981,8 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
     shape[1] = bias.shape[0]
     bias = bias.reshape(tuple(shape))
     res = res + bias
+
+  res = res.reshape((*batch_dims, *res.shape[-num_shape_dim:]))
   return res

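The reshape pair above is the whole multiple-batch-dim mechanism: any leading dimensions beyond the (C_in, *spatial) shape the weight expects are folded into one batch axis before the JAX convolution and unfolded afterwards. A minimal sketch of the round-trip, with hypothetical shapes:

import jax.numpy as jnp

x = jnp.zeros((2, 3, 4, 8, 8))  # hypothetical input with two batch dims: (B1, B2, C_in, H, W)
w = jnp.zeros((16, 4, 3, 3))    # weight: (C_out, C_in, kH, kW)

num_shape_dim = w.ndim - 1             # 3 trailing dims the conv consumes: (C_in, H, W)
batch_dims = x.shape[:-num_shape_dim]  # (2, 3)

x_flat = x.reshape((-1, *x.shape[-num_shape_dim:]))  # (6, 4, 8, 8)
# The ordinary rank-4 convolution runs on x_flat; with a 3x3 kernel and no
# padding it would produce a result of shape (6, 16, 6, 6).
res = jnp.zeros((6, 16, 6, 6))
res = res.reshape((*batch_dims, *res.shape[-num_shape_dim:]))  # (2, 3, 16, 6, 6)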
@@ -3214,9 +3237,9 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai
     running_var = jnp.ones(input.shape[1], dtype=input.dtype)  # Initialize running variance if None

   if training:
-    return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
+    return _aten__native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
   else:
-    return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)
+    return _aten__native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)


 @op(torch.ops.aten.normal, needs_env=True)
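Note: the batch-norm fix calls the local _aten__native_batch_norm_legit / _aten__native_batch_norm_legit_no_training helpers directly instead of going back through torch.ops.aten, which would re-enter the PyTorch dispatcher from inside a JAX lowering rather than staying in JAX code.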

experimental/torch_xla2/torch_xla2/ops/jtorch.py

Lines changed: 10 additions & 7 deletions
@@ -1,4 +1,5 @@
 """Tensor constructor overrides"""
+import math
 import collections.abc
 import functools
 from typing import Optional, Sequence
@@ -98,13 +99,11 @@ def get_params(*a):
   return jnp.einsum(equation, *filtered_operands)


-def _sdpa_reference(
-    query, key, value, attn_mask=None,
-    dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
+def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
+                    is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
   L, S = query.size(-2), key.size(-2)
-  scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
+  scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
   attn_bias = torch.zeros(L, S, dtype=query.dtype)
-
   if is_causal:
     assert attn_mask is None
     temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
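Note: math.sqrt of the Python int returned by query.size(-1) is a plain float, so the scale-factor line no longer depends on numpy's np.sqrt.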
@@ -116,6 +115,10 @@ def _sdpa_reference(
     attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
   else:
     attn_bias += attn_mask
+  if enable_gqa:
+    key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
+    value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
+
   attn_weight = query @ key.transpose(-2, -1) * scale_factor
   attn_weight += attn_bias
   attn_weight = torch.softmax(attn_weight, dim=-1)
@@ -209,14 +212,14 @@ def pad(tensor, pad, mode="constant", value=None):
 @register_function(torch.nn.functional.scaled_dot_product_attention, is_jax_function=False, needs_env=True)
 def scaled_dot_product_attention(
     query, key, value, attn_mask=None,
-    dropout_p=0.0, is_causal=False, scale=None, env=None) -> torch.Tensor:
+    dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False, env=None) -> torch.Tensor:

   if env.config.use_tpu_flash_attention:
     jquery, jkey, jvalue = env.t2j_iso((query, key, value))
     res = _tpu_flash_attention(jquery, jkey, jvalue, env)
     return env.j2t_iso(res)

-  return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, scale)
+  return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)

 @register_function(torch.Tensor.__getitem__)
 def getitem(self, indexes):
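The enable_gqa flag mirrors the grouped-query-attention option of torch.nn.functional.scaled_dot_product_attention: when several query heads share one key/value head, the KV tensors are expanded along the head axis before the standard attention math. A small sketch of what the two repeat_interleave calls do, with hypothetical shapes:

import torch

q = torch.randn(1, 8, 16, 64)  # (batch, 8 query heads, seq_len, head_dim)
k = torch.randn(1, 2, 16, 64)  # 2 KV heads, each shared by 4 query heads
v = torch.randn(1, 2, 16, 64)

# Repeat each KV head query_heads // kv_heads = 4 times along the head axis
# (dim -3); afterwards q, k, v all carry 8 heads and plain SDPA applies.
k = k.repeat_interleave(q.size(-3) // k.size(-3), -3)  # -> (1, 8, 16, 64)
v = v.repeat_interleave(q.size(-3) // v.size(-3), -3)  # -> (1, 8, 16, 64)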
Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
+"""
+Forked at: https://raw.githubusercontent.com/mlperf/training_results_v0.7/refs/heads/master/Google/benchmarks/ssd/implementations/ssd-research-JAX-tpu-v3-4096/nms.py
+"""
+
+import functools
+from typing import List, Union, Optional, Tuple
+
+import torch
+from jax import lax
+import jax.numpy as jnp
+from . import ops_registry
+
+_NMS_TILE_SIZE = 256
+
+
+def _bbox_overlap(boxes, gt_boxes):
+  """Find bounding box overlap.
+
+  Args:
+    boxes: first set of bounding boxes
+    gt_boxes: second set of boxes to compute IOU
+
+  Returns:
+    iou: Intersection over union matrix of all input bounding boxes
+  """
+  bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split(
+      ary=boxes, indices_or_sections=4, axis=2)
+  gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split(
+      ary=gt_boxes, indices_or_sections=4, axis=2)
+
+  # Calculates the intersection area.
+  i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1]))
+  i_xmax = jnp.minimum(bb_x_max, jnp.transpose(gt_x_max, [0, 2, 1]))
+  i_ymin = jnp.maximum(bb_y_min, jnp.transpose(gt_y_min, [0, 2, 1]))
+  i_ymax = jnp.minimum(bb_y_max, jnp.transpose(gt_y_max, [0, 2, 1]))
+  i_area = jnp.maximum((i_xmax - i_xmin), 0) * jnp.maximum((i_ymax - i_ymin), 0)
+
+  # Calculates the union area.
+  bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
+  gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
+  # Adds a small epsilon to avoid divide-by-zero.
+  u_area = bb_area + jnp.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8
+
+  # Calculates IoU.
+  iou = i_area / u_area
+
+  return iou
+
+
+def _self_suppression(in_args):
+  iou, _, iou_sum = in_args
+  batch_size = iou.shape[0]
+  can_suppress_others = jnp.reshape(
+      jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype)
+  iou_suppressed = jnp.reshape(
+      (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype),
+      [batch_size, -1, 1]) * iou
+  iou_sum_new = jnp.sum(iou_suppressed, [1, 2])
+  return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new
+
+
+def _cross_suppression(in_args):
+  boxes, box_slice, iou_threshold, inner_idx = in_args
+  batch_size = boxes.shape[0]
+  new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0],
+                                [batch_size, _NMS_TILE_SIZE, 4])
+  iou = _bbox_overlap(new_slice, box_slice)
+  ret_slice = jnp.expand_dims(
+      (jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype),
+      2) * box_slice
+  return boxes, ret_slice, iou_threshold, inner_idx + 1
+
+
+def _suppression_loop_body(in_args):
+  """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).
+
+  Args:
+    in_args: A tuple of arguments: boxes, iou_threshold, output_size, idx
+
+  Returns:
+    boxes: updated boxes.
+    iou_threshold: pass down iou_threshold to the next iteration.
+    output_size: the updated output_size.
+    idx: the updated induction variable.
+  """
+  boxes, iou_threshold, output_size, idx = in_args
+  num_tiles = boxes.shape[1] // _NMS_TILE_SIZE
+  batch_size = boxes.shape[0]
+
+  # Iterates over tiles that can possibly suppress the current tile.
+  box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
+                                [batch_size, _NMS_TILE_SIZE, 4])
+  def _loop_cond(in_args):
+    _, _, _, inner_idx = in_args
+    return inner_idx < idx
+
+  _, box_slice, _, _ = lax.while_loop(
+      _loop_cond,
+      _cross_suppression, (boxes, box_slice, iou_threshold,
+                           0))
+
+  # Iterates over the current tile to compute self-suppression.
+  iou = _bbox_overlap(box_slice, box_slice)
+  mask = jnp.expand_dims(
+      jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) > jnp.reshape(
+          jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0)
+  iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype)
+
+  def _loop_cond2(in_args):
+    _, loop_condition, _ = in_args
+    return loop_condition
+
+  suppressed_iou, _, _ = lax.while_loop(
+      _loop_cond2, _self_suppression,
+      (iou, True,
+       jnp.sum(iou, [1, 2])))
+  suppressed_box = jnp.sum(suppressed_iou, 1) > 0
+  box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2)
+
+  # Uses box_slice to update the input boxes.
+  mask = jnp.reshape(
+      (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype),
+      [1, -1, 1, 1])
+  boxes = jnp.tile(jnp.expand_dims(
+      box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape(
+          boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
+  boxes = jnp.reshape(boxes, [batch_size, -1, 4])
+
+  # Updates output_size.
+  output_size += jnp.sum(
+      jnp.any(box_slice > 0, [2]).astype(jnp.int32), [1])
+  return boxes, iou_threshold, output_size, idx + 1
+
+
+def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold):
+  """A wrapper that handles non-maximum suppression.
+
+  Assumption:
+    * The boxes are sorted by scores unless the box is a dot (all coordinates
+      are zero).
+    * Boxes with higher scores can be used to suppress boxes with lower scores.
+
+  The overall design of the algorithm is to handle boxes tile-by-tile:
+
+  boxes = boxes.pad_to_multiple_of(tile_size)
+  num_tiles = len(boxes) // tile_size
+  output_boxes = []
+  for i in range(num_tiles):
+    box_tile = boxes[i*tile_size : (i+1)*tile_size]
+    for j in range(i - 1):
+      suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
+      iou = _bbox_overlap(box_tile, suppressing_tile)
+      # if the box is suppressed in iou, clear it to a dot
+      box_tile *= _update_boxes(iou)
+    # Iteratively handle the diagonal tile.
+    iou = _bbox_overlap(box_tile, box_tile)
+    iou_changed = True
+    while iou_changed:
+      # boxes that are not suppressed by anything else
+      suppressing_boxes = _get_suppressing_boxes(iou)
+      # boxes that are suppressed by suppressing_boxes
+      suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
+      # clear iou to 0 for boxes that are suppressed, as they cannot be used
+      # to suppress other boxes any more
+      new_iou = _clear_iou(iou, suppressed_boxes)
+      iou_changed = (new_iou != iou)
+      iou = new_iou
+    # remaining boxes that can still suppress others, are selected boxes.
+    output_boxes.append(_get_suppressing_boxes(iou))
+    if len(output_boxes) >= max_output_size:
+      break
+
+  Args:
+    scores: a tensor with a shape of [batch_size, anchors].
+    boxes: a tensor with a shape of [batch_size, anchors, 4].
+    max_output_size: a scalar integer `Tensor` representing the maximum number
+      of boxes to be selected by non max suppression.
+    iou_threshold: a float representing the threshold for deciding whether boxes
+      overlap too much with respect to IOU.
+  Returns:
+    nms_scores: a tensor with a shape of [batch_size, anchors]. It has same
+      dtype as input scores.
+    nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has
+      same dtype as input boxes.
+  """
+  batch_size = boxes.shape[0]
+  num_boxes = boxes.shape[1]
+  pad = int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)
+           ) * _NMS_TILE_SIZE - num_boxes
+  boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]])
+  scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]])
+  num_boxes += pad
+
+  def _loop_cond(in_args):
+    unused_boxes, unused_threshold, output_size, idx = in_args
+    return jnp.logical_and(
+        jnp.min(output_size) < max_output_size,
+        idx < num_boxes // _NMS_TILE_SIZE)
+
+  selected_boxes, _, output_size, _ = lax.while_loop(
+      _loop_cond, _suppression_loop_body, (
+          boxes, iou_threshold,
+          jnp.zeros([batch_size], jnp.int32),
+          0
+      ))
+  idx = num_boxes - lax.top_k(
+      jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) *
+      jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0),
+      max_output_size)[0].astype(jnp.int32)
+  idx = jnp.minimum(idx, num_boxes - 1)
+  idx = jnp.reshape(
+      idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1])
+
+  return idx
+  boxes = jnp.reshape(
+      (jnp.reshape(boxes, [-1, 4]))[idx],
+      [batch_size, max_output_size, 4])
+  boxes = boxes * (
+      jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) < jnp.reshape(
+          output_size, [-1, 1, 1])).astype(boxes.dtype)
+  scores = jnp.reshape(
+      jnp.reshape(scores, [-1, 1])[idx],
+      [batch_size, max_output_size])
+  scores = scores * (
+      jnp.reshape(jnp.arange(max_output_size), [1, -1]) < jnp.reshape(
+          output_size, [-1, 1])).astype(scores.dtype)
+  return scores, boxes
+
+
+# registry:
+
+def nms(boxes, scores, iou_threshold):
+  max_output_size = boxes.shape[0]
+  boxes = boxes.reshape((1, *boxes.shape))
+  scores = scores.reshape((1, *scores.shape))
+  res = non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold)
+  return res


+try:
+  import torch
+  import torchvision
+  ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms)
+except ImportError:
+  pass
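Two notes on the fork: non_max_suppression_padded now exits at the early return idx with flattened keep indices (batch offsets already applied), so the score/box gathering below that return is unreachable leftover from the original mlperf code; and the registration at the bottom only takes effect when torchvision is importable. A hypothetical usage sketch, assuming torch_xla2.default_env() is the environment entry point:

import torch
import torchvision
import torch_xla2

env = torch_xla2.default_env()
with env:
  boxes = torch.tensor([[0., 0., 10., 10.],
                        [1., 1., 11., 11.],
                        [50., 50., 60., 60.]])
  scores = torch.tensor([0.9, 0.8, 0.7])
  # torch.ops.torchvision.nms now dispatches to the JAX kernel registered above.
  keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)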
