diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 0dc50753ef..fd9de4451a 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -2183,15 +2183,6 @@ def fn(a, b):
             check_lowp=False,
         )
 
-    def test_upsample_bicubic2d(self):
-        def fn(a):
-            return (
-                aten.upsample_bicubic2d(a, (128, 128), True),
-                aten.upsample_bicubic2d(a, (128, 256), False),
-            )
-
-        self.common(fn, (torch.randn([4, 3, 64, 32], dtype=torch.float32),))
-
     def test_sort(self):
         def fn(a):
             return torch.sort(a)
diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index 9e4e485e65..18314e6e1a 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -62,7 +62,6 @@ def add_needs_realized_inputs(fn):
         aten.mm,
         aten.upsample_bilinear2d,
         aten.upsample_nearest2d,
-        aten.upsample_bicubic2d,
     ]
 )
 
@@ -954,6 +953,7 @@ def inner_fn(index):
 make_fallback(aten.topk)
 make_fallback(aten.unfold)
 make_fallback(aten.unfold_backward)
+make_fallback(aten.upsample_bicubic2d)
 make_fallback(aten.upsample_bicubic2d_backward)
 make_fallback(aten.upsample_bilinear2d_backward)
 
@@ -1719,115 +1719,6 @@ def fn(idx):
     )
 
 
-@register_lowering(aten.upsample_bicubic2d)
-def upsample_bicubic2d(x, output_size, align_corners, scales_h=None, scales_w=None):
-    x.realize_hint()
-    x_loader = x.make_loader()
-
-    N, C, iH, iW = x.get_size()
-    oH, oW = output_size
-
-    iH = V.graph.sizevars.guard_static_shape(iH)
-    iW = V.graph.sizevars.guard_static_shape(iW)
-
-    def get_int_dtype(maxval):
-        if maxval > torch.iinfo(torch.int32).max:
-            return torch.int64
-        return torch.int32
-
-    def compute_scale(in_size, out_size, align_corners, scale=None):
-        if align_corners:
-            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
-        else:
-            return 1 / scale if scale is not None and scale > 0 else in_size / out_size
-
-    def compute_source_index(scale, dst_index, align_corners):
-        dst_index_ie = ops.index_expr(dst_index, torch.float32)
-        if align_corners:
-            return ops.mul(scale, dst_index_ie)
-        else:
-            return ops.sub(
-                ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
-            )  # scale * (dst_index + 0.5) - 0.5
-
-    def cubic_convolution1(x, A):
-        # ((A + 2) * x - (A + 3)) * x * x + 1
-        return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)
-
-    def cubic_convolution2(x, A):
-        # ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
-        return ops.sub(
-            ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
-        )
-
-    def get_cubic_upsample_coefficients(t):
-        A = -0.75
-        c0 = cubic_convolution2(ops.add(t, 1.0), A)
-        c1 = cubic_convolution1(t, A)
-
-        x2 = ops.sub(1.0, t)
-        c2 = cubic_convolution1(x2, A)
-        c3 = cubic_convolution2(ops.add(x2, 1.0), A)
-        return (
-            c0,
-            c1,
-            c2,
-            c3,
-        )
-
-    def cubic_interp1d(xs, t):
-        cs = get_cubic_upsample_coefficients(t)
-        # dot product between xs and cs
-        return ops.add(
-            ops.mul(xs[0], cs[0]),
-            ops.add(
-                ops.mul(xs[1], cs[1]),
-                ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
-            ),
-        )
-
-    height_scale = compute_scale(iH, oH, align_corners, scales_h)
-    width_scale = compute_scale(iW, oW, align_corners, scales_h)
-
-    def clamp(v, min, max):
-        return ops.maximum(min, ops.minimum(max, v))
-
-    def fn(idx):
-        n, c, oy, ox = idx
-
-        real_x = compute_source_index(width_scale, ox, align_corners)
-        in_x = ops.floor(real_x)
-        t_x = ops.sub(real_x, in_x)
-
-        real_y = compute_source_index(height_scale, oy, align_corners)
-        in_y = ops.floor(real_y)
-        t_y = ops.sub(real_y, in_y)
-
-        def load_bounded(fy, fx):
-            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
-            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
-            return x_loader([n, c, iy, ix])
-
-        iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
-        ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
-        iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2))
-        ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2))
-
-        def get_x_interp(y):
-            coeffs_x = tuple(load_bounded(y, x) for x in ixs_ofs)
-            return cubic_interp1d(coeffs_x, t_x)
-
-        coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
-        return cubic_interp1d(coeffs_y, t_y)
-
-    return Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=fn,
-        ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
-    )
-
-
 @register_lowering(aten.reflection_pad2d)
 def reflection_pad2d(x, padding):
     assert len(padding) == 4
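For context on what the deleted lowering computed: it implemented separable bicubic interpolation using the standard cubic convolution kernel with A = -0.75, the same scheme the ATen `aten.upsample_bicubic2d` kernel uses, so the `make_fallback` path keeps the behavior while dropping the inductor IR expansion. Below is a minimal plain-Python sketch of the 1-D interpolation step; the helper names mirror the deleted code but are illustrative only, not inductor APIs.

```python
# Sketch of the cubic convolution weights from the removed lowering
# (Keys kernel, A = -0.75). The real code built these as ops.* IR nodes.

def cubic_convolution1(x, A):
    # Weight for the two taps nearest the sample point, |x| <= 1:
    # ((A + 2) * x - (A + 3)) * x * x + 1
    return ((A + 2) * x - (A + 3)) * x * x + 1


def cubic_convolution2(x, A):
    # Weight for the two outer taps, 1 < |x| <= 2:
    # ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
    return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A


def cubic_interp1d(p, t, A=-0.75):
    """Interpolate at fractional offset t in [0, 1) between p[1] and p[2],
    given four consecutive samples p[0..3]."""
    coeffs = (
        cubic_convolution2(t + 1.0, A),   # c0: distance t + 1 from p[0]
        cubic_convolution1(t, A),         # c1: distance t from p[1]
        cubic_convolution1(1.0 - t, A),   # c2: distance 1 - t from p[2]
        cubic_convolution2(2.0 - t, A),   # c3: distance 2 - t from p[3]
    )
    # Dot product between samples and weights.
    return sum(v * c for v, c in zip(p, coeffs))


# The 2-D kernel applies this separably: interpolate along x for each of the
# four neighboring rows, then interpolate those four results along y.
# Sanity check: the four weights sum to 1, so constants are preserved.
assert abs(cubic_interp1d([1.0, 1.0, 1.0, 1.0], 0.3) - 1.0) < 1e-6
```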