diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 8511965637..5d063bb003 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -2184,6 +2184,15 @@ def fn(a, b):
             check_lowp=False,
         )
 
+    def test_upsample_bicubic2d(self):
+        def fn(a):
+            return (
+                aten.upsample_bicubic2d(a, (128, 128), True),
+                aten.upsample_bicubic2d(a, (128, 256), False),
+            )
+
+        self.common(fn, (torch.randn([4, 3, 64, 32], dtype=torch.float32),))
+
     def test_sort(self):
         def fn(a):
             return torch.sort(a)
diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index acacf5ac49..3939c06452 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -4,6 +4,8 @@
 import operator
 from collections.abc import Iterable
 from typing import List
+from typing import Optional
+from typing import Tuple
 
 import sympy
 import torch
@@ -62,6 +64,7 @@ def add_needs_realized_inputs(fn):
         aten.mm,
         aten.upsample_bilinear2d,
         aten.upsample_nearest2d,
+        aten.upsample_bicubic2d,
     ]
 )
 
@@ -953,7 +956,6 @@ def inner_fn(index):
 make_fallback(aten.topk)
 make_fallback(aten.unfold)
 make_fallback(aten.unfold_backward)
-make_fallback(aten.upsample_bicubic2d)
 make_fallback(aten.upsample_bicubic2d_backward)
 make_fallback(aten.upsample_bilinear2d_backward)
 
@@ -1726,6 +1728,141 @@ def fn(idx):
     )
 
 
+@register_lowering(aten.upsample_bicubic2d.default)
+def upsample_bicubic2d_default(
+    x,
+    output_size,
+    align_corners: bool,
+    scales_h: Optional[float] = None,
+    scales_w: Optional[float] = None,
+):
+    x.realize_hint()
+    x_loader = x.make_loader()
+
+    N, C, iH, iW = x.get_size()
+    oH, oW = output_size
+
+    iH = V.graph.sizevars.guard_static_shape(iH)
+    iW = V.graph.sizevars.guard_static_shape(iW)
+
+    def get_int_dtype(maxval):
+        if maxval > torch.iinfo(torch.int32).max:
+            return torch.int64
+        return torch.int32
+
+    def compute_scale(in_size, out_size, align_corners, scale=None):
+        if align_corners:
+            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
+        else:
+            return 1 / scale if scale is not None and scale > 0 else in_size / out_size
+
+    def compute_source_index(scale, dst_index, align_corners):
+        dst_index_ie = ops.index_expr(dst_index, torch.float32)
+        if align_corners:
+            return ops.mul(scale, dst_index_ie)
+        else:
+            return ops.sub(
+                ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
+            )  # scale * (dst_index + 0.5) - 0.5
+
+    def cubic_convolution1(x, A):
+        # ((A + 2) * x - (A + 3)) * x * x + 1
+        return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)
+
+    def cubic_convolution2(x, A):
+        # ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
+        return ops.sub(
+            ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
+        )
+
+    def get_cubic_upsample_coefficients(t):
+        A = -0.75
+        c0 = cubic_convolution2(ops.add(t, 1.0), A)
+        c1 = cubic_convolution1(t, A)
+
+        x2 = ops.sub(1.0, t)
+        c2 = cubic_convolution1(x2, A)
+        c3 = cubic_convolution2(ops.add(x2, 1.0), A)
+        return (
+            c0,
+            c1,
+            c2,
+            c3,
+        )
+
+    def cubic_interp1d(xs, t):
+        cs = get_cubic_upsample_coefficients(t)
+        # dot product between xs and cs
+        return ops.add(
+            ops.mul(xs[0], cs[0]),
+            ops.add(
+                ops.mul(xs[1], cs[1]),
+                ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
+            ),
+        )
+
+    height_scale = compute_scale(iH, oH, align_corners, scales_h)
+    width_scale = compute_scale(iW, oW, align_corners, scales_w)
+
+    def clamp(v, min, max):
+        return ops.maximum(min, ops.minimum(max, v))
+
+    def fn(idx):
+        n, c, oy, ox = idx
+
+        real_x = compute_source_index(width_scale, ox, align_corners)
+        in_x = ops.floor(real_x)
+        t_x = ops.sub(real_x, in_x)
+
+        real_y = compute_source_index(height_scale, oy, align_corners)
+        in_y = ops.floor(real_y)
+        t_y = ops.sub(real_y, in_y)
+
+        def load_bounded(fy, fx):
+            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
+            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
+            return x_loader([n, c, iy, ix])
+
+        iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
+        ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
+        iys_ofs = tuple((ops.add(iy, ofs) for ofs in (-1, 0, 1, 2)))
+        ixs_ofs = tuple((ops.add(ix, ofs) for ofs in (-1, 0, 1, 2)))
+
+        def get_x_interp(y):
+            coeffs_x = tuple((load_bounded(y, x) for x in ixs_ofs))
+            return cubic_interp1d(coeffs_x, t_x)
+
+        coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
+        return cubic_interp1d(coeffs_y, t_y)
+
+    return Pointwise.create(
+        device=x.get_device(),
+        dtype=x.get_dtype(),
+        inner_fn=fn,
+        ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
+    )
+
+
+@register_lowering(aten.upsample_bicubic2d.vec)
+def upsample_bicubic2d_vec(
+    a,
+    output_size,
+    align_corners: bool,
+    scale_factors: Optional[Tuple[float, float]] = None,
+):
+    _, _, iH, iW = a.get_size()
+    iH = V.graph.sizevars.guard_static_shape(iH)
+    iW = V.graph.sizevars.guard_static_shape(iW)
+
+    if bool(output_size) + bool(scale_factors) != 1:
+        raise RuntimeError("Must specify exactly one of output_size and scale_factor.")
+    if output_size is None:
+        assert scale_factors is not None
+        output_size = (int(iH * scale_factors[0]), int(iW * scale_factors[1]))
+    scale_h, scale_w = scale_factors if scale_factors else (None, None)
+    return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
+
+
 @register_lowering(aten.reflection_pad2d)
 def reflection_pad2d(x, padding):
     assert len(padding) == 4
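
For reference, here is a minimal NumPy sketch (not part of the patch) of the Keys cubic convolution with A = -0.75 that get_cubic_upsample_coefficients and cubic_interp1d above emit element-wise through ops.* calls; the helper names cubic_weights and cubic_interp1d_ref are illustrative only. It checks the partition-of-unity property: the four taps sum to 1 for any fractional offset, which is why upsampling a constant image reproduces the constant.

import numpy as np

A = -0.75  # same constant as in get_cubic_upsample_coefficients


def cubic_weights(t):
    """Four filter taps for a fractional offset t in [0, 1)."""

    def conv1(x):
        # |x| <= 1 branch: ((A + 2) * x - (A + 3)) * x * x + 1
        return ((A + 2) * x - (A + 3)) * x * x + 1

    def conv2(x):
        # 1 < |x| < 2 branch: ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
        return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A

    # taps sit at distances (1 + t, t, 1 - t, 2 - t) from the source sample
    return np.array([conv2(t + 1.0), conv1(t), conv1(1.0 - t), conv2(2.0 - t)])


def cubic_interp1d_ref(p, t):
    """Interpolate four neighbouring samples p[0..3] at fractional offset t."""
    return float(np.dot(cubic_weights(t), p))


if __name__ == "__main__":
    for t in (0.0, 0.25, 0.5, 0.75):
        assert abs(cubic_weights(t).sum() - 1.0) < 1e-6  # partition of unity
    assert abs(cubic_interp1d_ref(np.ones(4), 0.3) - 1.0) < 1e-6  # constants pass through

In eager mode the same overloads are reached via torch.nn.functional.interpolate(x, size=(128, 128), mode="bicubic", align_corners=True), which dispatches to aten.upsample_bicubic2d; the new test_upsample_bicubic2d exercises the aten overload directly through self.common.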