9 changes: 9 additions & 0 deletions test/test_torchinductor.py
@@ -2184,6 +2184,15 @@ def fn(a, b):
            check_lowp=False,
        )

    def test_upsample_bicubic2d(self):
        def fn(a):
            return (
                aten.upsample_bicubic2d(a, (128, 128), True),
                aten.upsample_bicubic2d(a, (128, 256), False),
            )

        self.common(fn, (torch.randn([4, 3, 64, 32], dtype=torch.float32),))

    def test_sort(self):
        def fn(a):
            return torch.sort(a)
139 changes: 138 additions & 1 deletion torchinductor/lowering.py
@@ -4,6 +4,8 @@
import operator
from collections.abc import Iterable
from typing import List
from typing import Optional
from typing import Tuple

import sympy
import torch
@@ -62,6 +64,7 @@ def add_needs_realized_inputs(fn):
        aten.mm,
        aten.upsample_bilinear2d,
        aten.upsample_nearest2d,
        aten.upsample_bicubic2d,
    ]
)

@@ -953,7 +956,6 @@ def inner_fn(index):
make_fallback(aten.topk)
make_fallback(aten.unfold)
make_fallback(aten.unfold_backward)
make_fallback(aten.upsample_bicubic2d)
make_fallback(aten.upsample_bicubic2d_backward)
make_fallback(aten.upsample_bilinear2d_backward)

@@ -1726,6 +1728,141 @@ def fn(idx):
    )


@register_lowering(aten.upsample_bicubic2d.default)
def upsample_bicubic2d_default(
    x,
    output_size,
    align_corners: bool,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
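    # Lower bicubic upsampling as a single pointwise kernel: each output pixel reads
    # a 4x4 neighborhood of the input and blends it with cubic convolution weights,
    # so no fallback to the ATen kernel is needed.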
    x.realize_hint()
    x_loader = x.make_loader()

    N, C, iH, iW = x.get_size()
    oH, oW = output_size

    iH = V.graph.sizevars.guard_static_shape(iH)
    iW = V.graph.sizevars.guard_static_shape(iW)

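    # Use the narrowest integer dtype that can hold the largest index we may produce.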
    def get_int_dtype(maxval):
        if maxval > torch.iinfo(torch.int32).max:
            return torch.int64
        return torch.int32

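    # Output->input coordinate scale.  With align_corners=True the corner pixels of
    # input and output line up exactly; otherwise pixel centers are aligned and an
    # explicit scale (if provided and positive) takes precedence over in_size / out_size.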
    def compute_scale(in_size, out_size, align_corners, scale=None):
        if align_corners:
            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
        else:
            return 1 / scale if scale is not None and scale > 0 else in_size / out_size

    def compute_source_index(scale, dst_index, align_corners):
        dst_index_ie = ops.index_expr(dst_index, torch.float32)
        if align_corners:
            return ops.mul(scale, dst_index_ie)
        else:
            return ops.sub(
                ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
            )  # scale * (dst_index + 0.5) - 0.5

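    # The two pieces of the cubic convolution kernel: cubic_convolution1 covers
    # |x| <= 1 and cubic_convolution2 covers 1 < |x| < 2, with A = -0.75 below
    # (the same constant used by ATen's upsample kernels).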
    def cubic_convolution1(x, A):
        # ((A + 2) * x - (A + 3)) * x * x + 1
        return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)

    def cubic_convolution2(x, A):
        # ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
        return ops.sub(
            ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
        )

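    # Interpolation weights for the four neighboring samples at fractional offset t
    # (t is the distance from the sample just left of / above the source position).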
    def get_cubic_upsample_coefficients(t):
        A = -0.75
        c0 = cubic_convolution2(ops.add(t, 1.0), A)
        c1 = cubic_convolution1(t, A)

        x2 = ops.sub(1.0, t)
        c2 = cubic_convolution1(x2, A)
        c3 = cubic_convolution2(ops.add(x2, 1.0), A)
        return (
            c0,
            c1,
            c2,
            c3,
        )

    def cubic_interp1d(xs, t):
        cs = get_cubic_upsample_coefficients(t)
        # dot product between xs and cs
        return ops.add(
            ops.mul(xs[0], cs[0]),
            ops.add(
                ops.mul(xs[1], cs[1]),
                ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
            ),
        )

    height_scale = compute_scale(iH, oH, align_corners, scales_h)
    width_scale = compute_scale(iW, oW, align_corners, scales_w)

    def clamp(v, min, max):
        return ops.maximum(min, ops.minimum(max, v))

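    # For each output pixel: map back to a (fractional) source position, load the
    # surrounding 4x4 neighborhood with border clamping, interpolate along x for each
    # of the four rows, then interpolate those four results along y.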
    def fn(idx):
        n, c, oy, ox = idx

        real_x = compute_source_index(width_scale, ox, align_corners)
        in_x = ops.floor(real_x)
        t_x = ops.sub(real_x, in_x)

        real_y = compute_source_index(height_scale, oy, align_corners)
        in_y = ops.floor(real_y)
        t_y = ops.sub(real_y, in_y)

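        # Neighbor indices are clamped to the valid range, so samples that fall
        # outside the image replicate the border pixels.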
        def load_bounded(fy, fx):
            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
            return x_loader([n, c, iy, ix])

        iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
        ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
        iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2))
        ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2))

        def get_x_interp(y):
            coeffs_x = tuple(load_bounded(y, x) for x in ixs_ofs)
            return cubic_interp1d(coeffs_x, t_x)

        coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
        return cubic_interp1d(coeffs_y, t_y)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
    )


@register_lowering(aten.upsample_bicubic2d.vec)
def upsample_bicubic2d_vec(
    a,
    output_size,
    align_corners: bool,
    scale_factors: Optional[Tuple[float, float]] = None,
):
    _, _, iH, iW = a.get_size()
    iH = V.graph.sizevars.guard_static_shape(iH)
    iW = V.graph.sizevars.guard_static_shape(iW)

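    # The .vec overload accepts either an explicit output_size or per-dimension
    # scale_factors; normalize to an output_size and dispatch to the default overload.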
    if bool(output_size) + bool(scale_factors) != 1:
        raise RuntimeError("Must specify exactly one of output_size and scale_factor.")
    if output_size is None:
        assert scale_factors is not None
        output_size = (int(iH * scale_factors[0]), int(iW * scale_factors[1]))
    scale_h, scale_w = scale_factors if scale_factors else (None, None)
    return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)


@register_lowering(aten.reflection_pad2d)
def reflection_pad2d(x, padding):
    assert len(padding) == 4