9 changes: 9 additions & 0 deletions test/test_torchinductor.py
@@ -2175,6 +2175,15 @@ def fn(a, b):
            check_lowp=False,
        )

    def test_upsample_bicubic2d(self):
        def fn(a):
            return (
                aten.upsample_bicubic2d(a, (128, 128), True),
                aten.upsample_bicubic2d(a, (128, 256), False),
            )

        self.common(fn, (torch.randn([4, 3, 64, 32], dtype=torch.float32),))

    def test_sort(self):
        def fn(a):
            return torch.sort(a)
111 changes: 110 additions & 1 deletion torchinductor/lowering.py
@@ -62,6 +62,7 @@ def add_needs_realized_inputs(fn):
        aten.mm,
        aten.upsample_bilinear2d,
        aten.upsample_nearest2d,
        aten.upsample_bicubic2d,
    ]
)

@@ -953,7 +954,6 @@ def inner_fn(index):
make_fallback(aten.topk)
make_fallback(aten.unfold)
make_fallback(aten.unfold_backward)
make_fallback(aten.upsample_bicubic2d)
make_fallback(aten.upsample_bicubic2d_backward)
make_fallback(aten.upsample_bilinear2d_backward)

@@ -1718,6 +1718,115 @@ def fn(idx):
    )


@register_lowering(aten.upsample_bicubic2d)
def upsample_bicubic2d(x, output_size, align_corners, scales_h=None, scales_w=None):
    x.realize_hint()
    x_loader = x.make_loader()

    N, C, iH, iW = x.get_size()
    oH, oW = output_size

    iH = V.graph.sizevars.guard_static_shape(iH)
    iW = V.graph.sizevars.guard_static_shape(iW)

    def get_int_dtype(maxval):
        if maxval > torch.iinfo(torch.int32).max:
            return torch.int64
        return torch.int32

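    # With align_corners=True the corner pixels of input and output coincide,
    # so adjacent output pixels step (in_size - 1) / (out_size - 1) input
    # pixels; otherwise the scale is in_size / out_size (or 1 / scale when an
    # explicit scale factor was passed in).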
    def compute_scale(in_size, out_size, align_corners, scale=None):
        if align_corners:
            return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
        else:
            return 1 / scale if scale is not None and scale > 0 else in_size / out_size

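    # Half-pixel example: upsampling 64 -> 128 gives scale = 0.5, so output
    # index 3 maps to 0.5 * (3 + 0.5) - 0.5 = 1.25, a quarter of the way from
    # input pixel 1 to pixel 2 (align_corners=False).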
    def compute_source_index(scale, dst_index, align_corners):
        dst_index_ie = ops.index_expr(dst_index, torch.float32)
        if align_corners:
            return ops.mul(scale, dst_index_ie)
        else:
            return ops.sub(
                ops.mul(scale, ops.add(dst_index_ie, 0.5)), 0.5
            )  # scale * (dst_index + 0.5) - 0.5

    def cubic_convolution1(x, A):
        # ((A + 2) * x - (A + 3)) * x * x + 1
        return ops.add(ops.mul(ops.mul(ops.sub(ops.mul(A + 2, x), A + 3), x), x), 1.0)

    def cubic_convolution2(x, A):
        # ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
        return ops.sub(
            ops.mul(ops.add(ops.mul(ops.sub(ops.mul(A, x), 5 * A), x), 8 * A), x), 4 * A
        )

    def get_cubic_upsample_coefficients(t):
        A = -0.75
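        # -0.75 is the cubic convolution constant ATen uses for bicubic
        # interpolation (the same value as OpenCV).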
        c0 = cubic_convolution2(ops.add(t, 1.0), A)
        c1 = cubic_convolution1(t, A)

        x2 = ops.sub(1.0, t)
        c2 = cubic_convolution1(x2, A)
        c3 = cubic_convolution2(ops.add(x2, 1.0), A)
        return (c0, c1, c2, c3)

    def cubic_interp1d(xs, t):
        cs = get_cubic_upsample_coefficients(t)
        # dot product between xs and cs
        return ops.add(
            ops.mul(xs[0], cs[0]),
            ops.add(
                ops.mul(xs[1], cs[1]),
                ops.add(ops.mul(xs[2], cs[2]), ops.mul(xs[3], cs[3])),
            ),
        )

    height_scale = compute_scale(iH, oH, align_corners, scales_h)
    width_scale = compute_scale(iW, oW, align_corners, scales_w)

    def clamp(v, min, max):
        return ops.maximum(min, ops.minimum(max, v))

    def fn(idx):
        n, c, oy, ox = idx
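        # For each output pixel: locate the fractional source position, read a
        # 4x4 border-clamped neighborhood, interpolate each of the 4 rows
        # along x, then interpolate the 4 row results along y.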

        real_x = compute_source_index(width_scale, ox, align_corners)
        in_x = ops.floor(real_x)
        t_x = ops.sub(real_x, in_x)

        real_y = compute_source_index(height_scale, oy, align_corners)
        in_y = ops.floor(real_y)
        t_y = ops.sub(real_y, in_y)

        def load_bounded(fy, fx):
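            # Taps that fall outside the image are clamped to the border
            # (replicate padding), matching ATen's boundary handling.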
            iy = ops.indirect_indexing(clamp(fy, 0, iH - 1))
            ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
            return x_loader([n, c, iy, ix])

        iy = ops.to_dtype(in_y, get_int_dtype(iH + 1))
        ix = ops.to_dtype(in_x, get_int_dtype(iW + 1))
        iys_ofs = tuple(ops.add(iy, ofs) for ofs in (-1, 0, 1, 2))
        ixs_ofs = tuple(ops.add(ix, ofs) for ofs in (-1, 0, 1, 2))

        def get_x_interp(y):
            coeffs_x = tuple(load_bounded(y, x) for x in ixs_ofs)
            return cubic_interp1d(coeffs_x, t_x)

        coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
[Review thread on the line above]

ngimel (Sep 21, 2022):
    where are coeffs_x?
    Edit: ah ok, it's rolled into the get_x_interp function, but that's pretty confusing, no?

Contributor Author:
    There is an asymmetry between coeffs_x and coeffs_y. For coeffs_x, they are direct memory reads from the input tensor, and there are actually 4 sets of them, one for each of the 4 y offsets. For coeffs_y, they are the result of applying cubic_interp1d to each of the 4 sets of coeffs_x.

        return cubic_interp1d(coeffs_y, t_y)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        ranges=[N, C, sympy.Integer(oH), sympy.Integer(oW)],
    )

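A note for readers of the review thread above: the coeffs_x / coeffs_y asymmetry is easier to see outside the ops.* IR. Below is a minimal NumPy sketch of the same separable scheme (an editor's illustration under the align_corners=False convention, not code from this PR); the names mirror the lowering.

import numpy as np

def cubic_weights(t, A=-0.75):
    # Keys cubic kernel weights for the 4 taps at offsets -1, 0, 1, 2
    # around floor(source); t is the fractional part in [0, 1).
    def k(x):
        x = abs(x)
        if x <= 1:
            return ((A + 2) * x - (A + 3)) * x * x + 1
        return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A if x < 2 else 0.0

    return np.array([k(t + 1), k(t), k(1 - t), k(2 - t)])

def bicubic_pixel(img, oy, ox, scale_h, scale_w):
    iH, iW = img.shape
    ry = scale_h * (oy + 0.5) - 0.5  # half-pixel source coordinates
    rx = scale_w * (ox + 0.5) - 0.5
    y0, x0 = int(np.floor(ry)), int(np.floor(rx))
    wy, wx = cubic_weights(ry - y0), cubic_weights(rx - x0)
    # "coeffs_x": four border-clamped rows of reads, each reduced along x.
    rows = []
    for dy in (-1, 0, 1, 2):
        y = min(max(y0 + dy, 0), iH - 1)
        taps = [img[y, min(max(x0 + dx, 0), iW - 1)] for dx in (-1, 0, 1, 2)]
        rows.append(np.dot(wx, taps))
    # "coeffs_y": the four x-interpolated values, reduced once along y.
    return np.dot(wy, rows)

Evaluating this at every (oy, ox) with scale_h = iH / oH and scale_w = iW / oW should agree with the lowering (and with F.interpolate(..., mode="bicubic")) up to floating-point rounding.
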

@register_lowering(aten.reflection_pad2d)
def reflection_pad2d(x, padding):
    assert len(padding) == 4