From 632ebe43072b7ec5191781281f790354d3b78e43 Mon Sep 17 00:00:00 2001
From: Fabio Rocha
Date: Wed, 14 Sep 2022 15:03:43 +0000
Subject: [PATCH 1/2] Lowering for upsample_bicubic2d

---
 test/test_torchinductor.py |   9 ++++
 torchinductor/lowering.py  | 106 ++++++++++++++++++++++++++++++++++++-
 2 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 203a6614a4..036bc35b55 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -2175,6 +2175,15 @@ def fn(a, b):
             check_lowp=False,
         )
 
+    def test_upsample_bicubic2d(self):
+        def fn(a):
+            return (
+                aten.upsample_bicubic2d(a, (128, 128), True),
+                aten.upsample_bicubic2d(a, (128, 256), False),
+            )
+
+        self.common(fn, (torch.randn([4, 3, 64, 32], dtype=torch.float32),))
+
     def test_sort(self):
         def fn(a):
             return torch.sort(a)
diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index 95caffe949..6315e18e98 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -62,6 +62,7 @@ def add_needs_realized_inputs(fn):
         aten.mm,
         aten.upsample_bilinear2d,
         aten.upsample_nearest2d,
+        aten.upsample_bicubic2d,
     ]
 )
 
@@ -953,7 +954,6 @@ def inner_fn(index):
 make_fallback(aten.topk)
 make_fallback(aten.unfold)
 make_fallback(aten.unfold_backward)
-make_fallback(aten.upsample_bicubic2d)
 make_fallback(aten.upsample_bicubic2d_backward)
 make_fallback(aten.upsample_bilinear2d_backward)
 
@@ -1718,6 +1718,110 @@ def fn(idx):
     )
 
 
+@register_lowering(aten.upsample_bicubic2d)
+def upsample_bicubic2d(x, output_size, align_corners, scales_h=None, scales_w=None):
+    x.realize_hint()
+    x_loader = x.make_loader()
+
+    N, C, iH, iW = x.get_size()
+    oH, oW = output_size
+
+    iH = V.graph.sizevars.guard_static_shape(iH)
+    iW = V.graph.sizevars.guard_static_shape(iW)
+
+    def compute_scale(in_size, out_size, align_corners, scale=None):
+        if align_corners:
+            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
+        else:
+            return 1 / scale if scale is not None and scale > 0 else in_size / out_size
+
+    def compute_source_index(scale, dst_index, align_corners):
+        dst_index_ie = ops.index_expr(dst_index, torch.float32)
+        if align_corners:
+            return ops.mul(scale, dst_index_ie)
+        else:
+            return ops.sub(
+                ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
+            )  # scale * (dst_index + 0.5) - 0.5
+
+    def cubic_convolution1(x, A):
+        # ((A + 2) * x - (A + 3)) * x * x + 1
+        return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)
+
+    def cubic_convolution2(x, A):
+        # ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
+        return ops.sub(
+            ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
+        )
+
+    def get_cubic_upsample_coefficients(t):
+        A = -0.75
+        c0 = cubic_convolution2(ops.add(t, 1.0), A)
+        c1 = cubic_convolution1(t, A)
+
+        x2 = ops.sub(1.0, t)
+        c2 = cubic_convolution1(x2, A)
+        c3 = cubic_convolution2(ops.add(x2, 1.0), A)
+        return (
+            c0,
+            c1,
+            c2,
+            c3,
+        )
+
+    def cubic_interp1d(xs, t):
+        cs = get_cubic_upsample_coefficients(t)
+        # dot product between xs and cs
+        return ops.add(
+            ops.mul(xs[0], cs[0]),
+            ops.add(
+                ops.mul(xs[1], cs[1]),
+                ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
+            ),
+        )
+
+    height_scale = compute_scale(iH, oH, align_corners, scales_h)
+    width_scale = compute_scale(iW, oW, align_corners, scales_w)
+
+    def clamp(v, min, max):
+        return ops.maximum(min, ops.minimum(max, v))
+
+    def fn(idx):
+        n, c, oy, ox = idx
+
+        real_x = compute_source_index(width_scale, ox, align_corners)
+        in_x = ops.floor(real_x)
+        t_x = ops.sub(real_x, in_x)
+
+        real_y = compute_source_index(height_scale, oy, align_corners)
+        in_y = ops.floor(real_y)
+        t_y = ops.sub(real_y, in_y)
+
+        def load_bounded(fy, fx):
+            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
+            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
+            return x_loader([n, c, iy, ix])
+
+        iy = ops.to_dtype(in_y, torch.int32)
+        ix = ops.to_dtype(in_x, torch.int32)
+        iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2))
+        ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2))
+
+        def get_x_interp(y):
+            coeffs_x = tuple(load_bounded(y, x) for x in ixs_ofs)
+            return cubic_interp1d(coeffs_x, t_x)
+
+        coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
+        return cubic_interp1d(coeffs_y, t_y)
+
+    return Pointwise.create(
+        device=x.get_device(),
+        dtype=x.get_dtype(),
+        inner_fn=fn,
+        ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
+    )
+
+
 @register_lowering(aten.reflection_pad2d)
 def reflection_pad2d(x, padding):
     assert len(padding) == 4
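
For reference, the nested ops.* calls in this lowering encode, node by node,
the Keys cubic convolution kernel with A = -0.75, the same constant PyTorch's
eager bicubic kernels use. Below is a minimal pure-Python sketch of the
arithmetic those expressions build; the names mirror the lowering for
readability, but none of this is torchinductor API:

    # Sketch of the Keys bicubic kernel that the ops.* nodes encode.
    A = -0.75

    def cubic_convolution1(x, A):
        # Kernel branch for |x| <= 1: ((A + 2)*x - (A + 3))*x*x + 1
        return ((A + 2) * x - (A + 3)) * x * x + 1

    def cubic_convolution2(x, A):
        # Kernel branch for 1 < |x| < 2: ((A*x - 5*A)*x + 8*A)*x - 4*A
        return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A

    def get_cubic_upsample_coefficients(t):
        # t in [0, 1) is the fractional offset; the four taps lie at
        # distances 1 + t, t, 1 - t and 2 - t from the source point.
        return (
            cubic_convolution2(t + 1.0, A),
            cubic_convolution1(t, A),
            cubic_convolution1(1.0 - t, A),
            cubic_convolution2(2.0 - t, A),
        )

    def cubic_interp1d(xs, t):
        # Dot product of the four neighbouring samples with their weights.
        return sum(x * c for x, c in zip(xs, get_cubic_upsample_coefficients(t)))

    # The weights form a partition of unity, so a constant signal is
    # reproduced exactly:
    assert abs(sum(get_cubic_upsample_coefficients(0.3)) - 1.0) < 1e-12
    assert abs(cubic_interp1d((5.0, 5.0, 5.0, 5.0), 0.3) - 5.0) < 1e-12

The 2-D case is separable, which is exactly the structure of fn above: four
rows are interpolated along x first (get_x_interp), and those four results
are then interpolated along y.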
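Likewise, compute_scale and compute_source_index implement the standard
align_corners coordinate mapping: with align_corners=True the corner pixels
of input and output coincide, while otherwise pixel centers are aligned,
which introduces the 0.5 offsets and can push source coordinates slightly
outside the valid range, hence the clamp in load_bounded. A plain-float
sketch of the same formulas (illustrative only):

    def compute_scale(in_size, out_size, align_corners, scale=None):
        if align_corners:
            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        # An explicitly recorded scale factor takes precedence when given.
        return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size

    def compute_source_index(scale, dst_index, align_corners):
        if align_corners:
            return scale * dst_index
        return scale * (dst_index + 0.5) - 0.5

    # Upsampling width 32 -> 256 without align_corners, as in the test above:
    s = compute_scale(32, 256, align_corners=False)        # 0.125
    assert compute_source_index(s, 0, False) == -0.4375    # clamped to 0 on load
    assert compute_source_index(s, 255, False) == 31.4375  # taps clamped to 31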
From 121d149b45b7d3226ec8646f4766a5b40feacee4 Mon Sep 17 00:00:00 2001
From: Fabio Rocha
Date: Thu, 22 Sep 2022 12:30:21 +0000
Subject: [PATCH 2/2] Pick large enough int dtype for index calculations

---
 torchinductor/lowering.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index 6315e18e98..1b7a17ee52 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -1729,6 +1729,11 @@ def upsample_bicubic2d(x, output_size, align_corners, scales_h=None, scales_w=No
     iH = V.graph.sizevars.guard_static_shape(iH)
     iW = V.graph.sizevars.guard_static_shape(iW)
 
+    def get_int_dtype(maxval):
+        if maxval > torch.iinfo(torch.int32).max:
+            return torch.int64
+        return torch.int32
+
     def compute_scale(in_size, out_size, align_corners, scale=None):
         if align_corners:
             return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
@@ -1802,8 +1807,8 @@ def load_bounded(fy, fx):
             ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
             return x_loader([n, c, iy, ix])
 
-        iy = ops.to_dtype(in_y, torch.int32)
-        ix = ops.to_dtype(in_x, torch.int32)
+        iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
+        ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
         iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2))
         ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2))
 
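
A note on the maxval arguments: before clamping, the four taps along each
axis are floor(real) + {-1, 0, 1, 2}, so the largest intermediate index is
(iW - 1) + 2 = iW + 1 (likewise iH + 1), and it is that value, not the last
valid index, that must be representable in the chosen dtype. A standalone
sketch of the selection logic; the helper is the one added above, the
asserts are illustrative:

    import torch

    def get_int_dtype(maxval):
        # Keep int32 indices unless the largest pre-clamp index value
        # could overflow them, in which case fall back to int64.
        if maxval > torch.iinfo(torch.int32).max:
            return torch.int64
        return torch.int32

    assert get_int_dtype(4096 + 1) is torch.int32  # any realistic image size
    assert get_int_dtype(2**31) is torch.int64     # would overflow int32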