test/test_torchinductor.py (0 additions & 9 deletions)

@@ -2183,15 +2183,6 @@ def fn(a, b):
             check_lowp=False,
         )
 
-    def test_upsample_bicubic2d(self):
-        def fn(a):
-            return (
-                aten.upsample_bicubic2d(a, (128, 128), True),
-                aten.upsample_bicubic2d(a, (128, 256), False),
-            )
-
-        self.common(fn, (torch.randn([4, 3, 64, 32], dtype=torch.float32),))
-
     def test_sort(self):
        def fn(a):
            return torch.sort(a)
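The removed test exercised both align_corners modes through the suite's common() harness, which compiles the function with Inductor and compares against eager. For context, a standalone equivalent might look like the sketch below; torch.compile is assumed here as the entry point, standing in for the harness (the archived repo predates it):

    import torch

    aten = torch.ops.aten

    def fn(a):
        return (
            aten.upsample_bicubic2d(a, (128, 128), True),
            aten.upsample_bicubic2d(a, (128, 256), False),
        )

    x = torch.randn(4, 3, 64, 32, dtype=torch.float32)
    expected = fn(x)                    # eager reference
    actual = torch.compile(fn)(x)       # compiled path
    for got, ref in zip(actual, expected):
        torch.testing.assert_close(got, ref)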
torchinductor/lowering.py (1 addition & 110 deletions)

@@ -62,7 +62,6 @@ def add_needs_realized_inputs(fn):
         aten.mm,
         aten.upsample_bilinear2d,
         aten.upsample_nearest2d,
-        aten.upsample_bicubic2d,
     ]
 )
 
@@ -954,6 +953,7 @@ def inner_fn(index):
 make_fallback(aten.topk)
 make_fallback(aten.unfold)
 make_fallback(aten.unfold_backward)
+make_fallback(aten.upsample_bicubic2d)
 make_fallback(aten.upsample_bicubic2d_backward)
 make_fallback(aten.upsample_bilinear2d_backward)
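For context: make_fallback registers a lowering that calls back into the eager ATen kernel at runtime, so aten.upsample_bicubic2d keeps working under Inductor; it just no longer gets a generated kernel. A minimal usage sketch, assuming today's torch.compile entry point (this archived repo used torchdynamo.optimize("inductor") instead):

    import torch
    import torch.nn.functional as F

    def upscale(a):
        # mode="bicubic" dispatches to aten.upsample_bicubic2d, which after
        # this change takes the eager fallback path instead of a generated kernel.
        return F.interpolate(a, size=(128, 128), mode="bicubic", align_corners=True)

    out = torch.compile(upscale)(torch.randn(4, 3, 64, 32))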

@@ -1719,115 +1719,6 @@ def fn(idx):
     )
 
 
-@register_lowering(aten.upsample_bicubic2d)
-def upsample_bicubic2d(x, output_size, align_corners, scales_h=None, scales_w=None):
-    x.realize_hint()
-    x_loader = x.make_loader()
-
-    N, C, iH, iW = x.get_size()
-    oH, oW = output_size
-
-    iH = V.graph.sizevars.guard_static_shape(iH)
-    iW = V.graph.sizevars.guard_static_shape(iW)
-
-    def get_int_dtype(maxval):
-        if maxval > torch.iinfo(torch.int32).max:
-            return torch.int64
-        return torch.int32
-
-    def compute_scale(in_size, out_size, align_corners, scale=None):
-        if align_corners:
-            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
-        else:
-            return 1 / scale if scale is not None and scale > 0 else in_size / out_size
-
-    def compute_source_index(scale, dst_index, align_corners):
-        dst_index_ie = ops.index_expr(dst_index, torch.float32)
-        if align_corners:
-            return ops.mul(scale, dst_index_ie)
-        else:
-            return ops.sub(
-                ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
-            )  # scale * (dst_index + 0.5) - 0.5
-
-    def cubic_convolution1(x, A):
-        # ((A + 2) * x - (A + 3)) * x * x + 1
-        return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)
-
-    def cubic_convolution2(x, A):
-        # ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
-        return ops.sub(
-            ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
-        )
-
-    def get_cubic_upsample_coefficients(t):
-        A = -0.75
-        c0 = cubic_convolution2(ops.add(t, 1.0), A)
-        c1 = cubic_convolution1(t, A)
-
-        x2 = ops.sub(1.0, t)
-        c2 = cubic_convolution1(x2, A)
-        c3 = cubic_convolution2(ops.add(x2, 1.0), A)
-        return (
-            c0,
-            c1,
-            c2,
-            c3,
-        )
-
-    def cubic_interp1d(xs, t):
-        cs = get_cubic_upsample_coefficients(t)
-        # dot product between xs and cs
-        return ops.add(
-            ops.mul(xs[0], cs[0]),
-            ops.add(
-                ops.mul(xs[1], cs[1]),
-                ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
-            ),
-        )
-
-    height_scale = compute_scale(iH, oH, align_corners, scales_h)
-    width_scale = compute_scale(iW, oW, align_corners, scales_w)
-
-    def clamp(v, min, max):
-        return ops.maximum(min, ops.minimum(max, v))
-
-    def fn(idx):
-        n, c, oy, ox = idx
-
-        real_x = compute_source_index(width_scale, ox, align_corners)
-        in_x = ops.floor(real_x)
-        t_x = ops.sub(real_x, in_x)
-
-        real_y = compute_source_index(height_scale, oy, align_corners)
-        in_y = ops.floor(real_y)
-        t_y = ops.sub(real_y, in_y)
-
-        def load_bounded(fy, fx):
-            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
-            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
-            return x_loader([n, c, iy, ix])
-
-        iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
-        ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
-        iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2))
-        ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2))
-
-        def get_x_interp(y):
-            coeffs_x = tuple(load_bounded(y, x) for x in ixs_ofs)
-            return cubic_interp1d(coeffs_x, t_x)
-
-        coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
-        return cubic_interp1d(coeffs_y, t_y)
-
-    return Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=fn,
-        ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
-    )
-
-
 @register_lowering(aten.reflection_pad2d)
 def reflection_pad2d(x, padding):
     assert len(padding) == 4
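For reference, the deleted lowering implemented standard bicubic resampling: a separable Keys cubic convolution kernel with A = -0.75, sampling a clamped 4x4 neighborhood per output pixel. A minimal NumPy sketch of the same math, assuming NCHW input; upsample_bicubic2d_ref and its helpers are illustrative names, not part of either codebase:

    import numpy as np

    def cubic_weights(t, A=-0.75):
        # Keys (1981) kernel evaluated at the four taps around the sample point.
        def conv1(x):
            # |x| <= 1 branch: ((A + 2) * x - (A + 3)) * x * x + 1
            return ((A + 2) * x - (A + 3)) * x * x + 1
        def conv2(x):
            # 1 < |x| < 2 branch: ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
            return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
        # Tap offsets -1, 0, +1, +2 relative to floor(source index).
        return np.array([conv2(t + 1.0), conv1(t), conv1(1.0 - t), conv2(2.0 - t)])

    def compute_scale(in_size, out_size, align_corners, scale=None):
        if align_corners:
            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size

    def source_index(scale, dst, align_corners):
        return scale * dst if align_corners else scale * (dst + 0.5) - 0.5

    def upsample_bicubic2d_ref(x, output_size, align_corners,
                               scales_h=None, scales_w=None):
        N, C, iH, iW = x.shape
        oH, oW = output_size
        height_scale = compute_scale(iH, oH, align_corners, scales_h)
        width_scale = compute_scale(iW, oW, align_corners, scales_w)
        out = np.empty((N, C, oH, oW), dtype=x.dtype)
        for oy in range(oH):
            ry = source_index(height_scale, oy, align_corners)
            iy = int(np.floor(ry))
            wy = cubic_weights(ry - iy)
            ys = np.clip(np.arange(iy - 1, iy + 3), 0, iH - 1)  # border clamp
            for ox in range(oW):
                rx = source_index(width_scale, ox, align_corners)
                ix = int(np.floor(rx))
                wx = cubic_weights(rx - ix)
                xs = np.clip(np.arange(ix - 1, ix + 3), 0, iW - 1)
                patch = x[:, :, ys[:, None], xs[None, :]]  # (N, C, 4, 4)
                out[:, :, oy, ox] = np.einsum("ncij,i,j->nc", patch, wy, wx)
        return out

Note that with align_corners=False the source index scale * (dst + 0.5) - 0.5 can go slightly negative at the top/left edge; the floor plus index clamping above reproduces what the lowering's load_bounded did in that case.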