Conversation

@fdrocha (Contributor) commented Sep 19, 2022

Added a lowering for upsample_bicubic2d.

Currently, performance is about 40% worse than eager. I'm not sure why this is happening and would appreciate suggestions on how to address it. @Chillee @ezyang @jansel @lezcano

Benchmarks
[------------------------------ upsample_bicubic2d ------------------------------]
                                                             |  Inductor  |  Eager
32 threads: ----------------------------------------------------------------------
      (torch.Size([16, 4, 128, 256]),), ((64, 128), True)    |     140    |   130
      (torch.Size([16, 4, 128, 256]),), ((64, 128), False)   |     141    |   130
      (torch.Size([16, 4, 128, 256]),), ((128, 256), True)   |     226    |   170
      (torch.Size([16, 4, 128, 256]),), ((128, 256), False)  |     224    |   169
      (torch.Size([16, 4, 128, 256]),), ((384, 768), True)   |    1100    |   804
      (torch.Size([16, 4, 128, 256]),), ((384, 768), False)  |    1100    |   803
The generated Triton looks reasonable to me:
kernel0 = TritonCodeCache.load('''
import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import pointwise_heuristics

@pointwise_heuristics(size_hints=[524288], filename=__file__)
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 128
    x1 = (xindex // 128) % 64
    x2 = (xindex // 8192)
    x4 = xindex
    tmp0 = x0
    tmp1 = 2.0078740157480315 * tmp0
    tmp2 = tl.libdevice.floor(tmp1)
    tmp3 = tmp1 - tmp2
    tmp4 = x1
    tmp5 = 2.015873015873016 * tmp4
    tmp6 = tl.libdevice.floor(tmp5)
    tmp7 = tmp5 - tmp6
    tmp8 = tmp6.to(tl.int32)
    tmp9 = tmp2.to(tl.int32)
    tmp10 = tmp8 + -1
    tmp11 = tmp8 + 0
    tmp12 = tmp8 + 1
    tmp13 = tmp8 + 2
    tmp14 = tmp9 + -1
    tmp15 = tmp9 + 0
    tmp16 = tmp9 + 1
    tmp17 = tmp9 + 2
    tmp18 = tl.minimum(127, tmp10)
    tmp19 = tl.maximum(0, tmp18)
    tmp20 = tl.minimum(255, tmp14)
    tmp21 = tl.maximum(0, tmp20)
    tmp22 = tl.load(in_ptr0 + tmp21 + (256*tmp19) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp23 = tl.minimum(255, tmp15)
    tmp24 = tl.maximum(0, tmp23)
    tmp25 = tl.load(in_ptr0 + tmp24 + (256*tmp19) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp26 = tl.minimum(255, tmp16)
    tmp27 = tl.maximum(0, tmp26)
    tmp28 = tl.load(in_ptr0 + tmp27 + (256*tmp19) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp29 = tl.minimum(255, tmp17)
    tmp30 = tl.maximum(0, tmp29)
    tmp31 = tl.load(in_ptr0 + tmp30 + (256*tmp19) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp32 = tmp3 + 1.0
    tmp33 = -0.75 * tmp32
    tmp34 = tmp33 - -3.75
    tmp35 = tmp34 * tmp32
    tmp36 = tmp35 + -6.0
    tmp37 = tmp36 * tmp32
    tmp38 = tmp37 - -3.0
    tmp39 = 1.25 * tmp3
    tmp40 = tmp39 - 2.25
    tmp41 = tmp40 * tmp3
    tmp42 = tmp41 * tmp3
    tmp43 = tmp42 + 1.0
    tmp44 = 1.0 - tmp3
    tmp45 = 1.25 * tmp44
    tmp46 = tmp45 - 2.25
    tmp47 = tmp46 * tmp44
    tmp48 = tmp47 * tmp44
    tmp49 = tmp48 + 1.0
    tmp50 = tmp44 + 1.0
    tmp51 = -0.75 * tmp50
    tmp52 = tmp51 - -3.75
    tmp53 = tmp52 * tmp50
    tmp54 = tmp53 + -6.0
    tmp55 = tmp54 * tmp50
    tmp56 = tmp55 - -3.0
    tmp57 = tmp22 * tmp38
    tmp58 = tmp25 * tmp43
    tmp59 = tmp28 * tmp49
    tmp60 = tmp31 * tmp56
    tmp61 = tmp59 + tmp60
    tmp62 = tmp58 + tmp61
    tmp63 = tmp57 + tmp62
    tmp64 = tl.minimum(127, tmp11)
    tmp65 = tl.maximum(0, tmp64)
    tmp66 = tl.load(in_ptr0 + tmp21 + (256*tmp65) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp67 = tl.load(in_ptr0 + tmp24 + (256*tmp65) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp68 = tl.load(in_ptr0 + tmp27 + (256*tmp65) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp69 = tl.load(in_ptr0 + tmp30 + (256*tmp65) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp70 = tmp66 * tmp38
    tmp71 = tmp67 * tmp43
    tmp72 = tmp68 * tmp49
    tmp73 = tmp69 * tmp56
    tmp74 = tmp72 + tmp73
    tmp75 = tmp71 + tmp74
    tmp76 = tmp70 + tmp75
    tmp77 = tl.minimum(127, tmp12)
    tmp78 = tl.maximum(0, tmp77)
    tmp79 = tl.load(in_ptr0 + tmp21 + (256*tmp78) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp80 = tl.load(in_ptr0 + tmp24 + (256*tmp78) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp81 = tl.load(in_ptr0 + tmp27 + (256*tmp78) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp82 = tl.load(in_ptr0 + tmp30 + (256*tmp78) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp83 = tmp79 * tmp38
    tmp84 = tmp80 * tmp43
    tmp85 = tmp81 * tmp49
    tmp86 = tmp82 * tmp56
    tmp87 = tmp85 + tmp86
    tmp88 = tmp84 + tmp87
    tmp89 = tmp83 + tmp88
    tmp90 = tl.minimum(127, tmp13)
    tmp91 = tl.maximum(0, tmp90)
    tmp92 = tl.load(in_ptr0 + tmp21 + (256*tmp91) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp93 = tl.load(in_ptr0 + tmp24 + (256*tmp91) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp94 = tl.load(in_ptr0 + tmp27 + (256*tmp91) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp95 = tl.load(in_ptr0 + tmp30 + (256*tmp91) + (32768*x2) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp96 = tmp92 * tmp38
    tmp97 = tmp93 * tmp43
    tmp98 = tmp94 * tmp49
    tmp99 = tmp95 * tmp56
    tmp100 = tmp98 + tmp99
    tmp101 = tmp97 + tmp100
    tmp102 = tmp96 + tmp101
    tmp103 = tmp7 + 1.0
    tmp104 = -0.75 * tmp103
    tmp105 = tmp104 - -3.75
    tmp106 = tmp105 * tmp103
    tmp107 = tmp106 + -6.0
    tmp108 = tmp107 * tmp103
    tmp109 = tmp108 - -3.0
    tmp110 = 1.25 * tmp7
    tmp111 = tmp110 - 2.25
    tmp112 = tmp111 * tmp7
    tmp113 = tmp112 * tmp7
    tmp114 = tmp113 + 1.0
    tmp115 = 1.0 - tmp7
    tmp116 = 1.25 * tmp115
    tmp117 = tmp116 - 2.25
    tmp118 = tmp117 * tmp115
    tmp119 = tmp118 * tmp115
    tmp120 = tmp119 + 1.0
    tmp121 = tmp115 + 1.0
    tmp122 = -0.75 * tmp121
    tmp123 = tmp122 - -3.75
    tmp124 = tmp123 * tmp121
    tmp125 = tmp124 + -6.0
    tmp126 = tmp125 * tmp121
    tmp127 = tmp126 - -3.0
    tmp128 = tmp63 * tmp109
    tmp129 = tmp76 * tmp114
    tmp130 = tmp89 * tmp120
    tmp131 = tmp102 * tmp127
    tmp132 = tmp130 + tmp131
    tmp133 = tmp129 + tmp132
    tmp134 = tmp128 + tmp133
    tl.store(out_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), tmp134, xmask)
''').kernel

def call(arg0_1):
    arg0_1_size = arg0_1.size()
    s0 = arg0_1_size[0]
    s1 = arg0_1_size[1]
    s2 = arg0_1_size[2]
    s3 = arg0_1_size[3]
    buf0 = empty_strided((s0, s1, 64, 128), (8192*s1, 8192, 128, 1), device='cuda', dtype=torch.float32)
    kernel0_xnumel = 8192*s0*s1
    kernel0[grid(kernel0_xnumel)](arg0_1, buf0, kernel0_xnumel)
    return (buf0, )
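
For reference, the coefficient chains in tmp32..tmp56 (x direction) and tmp103..tmp127 (y direction) above are the standard cubic-convolution weights with A = -0.75, the same constant the eager CUDA kernel uses. A minimal Python sketch (illustrative only) of the two weight polynomials, written in the same Horner-style form the generated code evaluates:

# Sketch of the cubic-convolution weight polynomials with A = -0.75, matching
# the coefficient chains in the generated Triton kernel above (illustrative).
A = -0.75

def cubic_weight_near(t):
    # weight for the two neighbors at distance |t| <= 1 (cf. tmp39..tmp43)
    return ((A + 2) * t - (A + 3)) * t * t + 1.0      # == ((1.25*t - 2.25)*t)*t + 1

def cubic_weight_far(t):
    # weight for the two neighbors at distance 1 < |t| < 2 (cf. tmp32..tmp38)
    return ((A * t - 5 * A) * t + 8 * A) * t - 4 * A  # == ((-0.75*t + 3.75)*t - 6)*t + 3

Each output pixel is then a weighted sum over a 4x4 input neighborhood: the kernel first combines 4 clamped loads along x per row (e.g. tmp57..tmp63) and then combines the 4 row results along y (tmp128..tmp134).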
Code used for benchmarking
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils.benchmark import Timer, Compare
import torchinductor
from torchinductor.compile_fx import compile_fx_inner, cudagraphify
from torchinductor.decomposition import decompositions
from itertools import product
from functools import partial
import sys


batch_size = (16, 4)
sizes =  ((128, 256),)
#((8,2), (2,8), (16, 32), (32,16), (73, 198), (198,74))
# (8,2), (8,3), (8,4), (8,5),) #(64, 128), (256, 256),)
scales = ( 0.5, 1, 3 )
options = (True, False)

test_func = torch.ops.aten.upsample_bicubic2d
benchmark_name = test_func.__name__

def gen_inputs():
    make_arg = partial(torch.randn, dtype=torch.float32, device="cuda")
    for (iH, iW), scale, option in product(sizes, scales, options):
        oH = int(scale*iH)
        oW = int(scale*iW)
        t_args = (make_arg( (*batch_size, iH, iW)), )
        nont_args = ( (oH, oW,), option)
        yield t_args, nont_args

def adapt_func(f, nont_args):
    def f_adapted(*t_args):
        val = f(*(t_args+nont_args))
        return (val,)
    return f_adapted

def make_nondecomposed(f, t_args, nont_args):
    f_adapted = adapt_func(f, nont_args)
    non_decomposed = make_fx(f_adapted, tracing_mode="fake")(*t_args)
    compiled_non_decomposed = compile_fx_inner(non_decomposed, t_args)
    return compiled_non_decomposed

def make_decomposed(f, t_args, nont_args):
    f_adapted = adapt_func(f, nont_args)
    decomposed = make_fx(f_adapted, decomposition_table=decompositions, tracing_mode="fake")(*t_args)
    compiled_decomposed = compile_fx_inner(decomposed, t_args)
    return compiled_decomposed

def make_eager(f, t_args, nont_args):
    f_adapted = adapt_func(f, nont_args)
    return cudagraphify(f_adapted, t_args)

def benchmark(label, sublabel, f, args):
    return Timer("f(*args)",
                 globals={"f": f, "args": args},
                 label=benchmark_name,
                 description=label,
                 sub_label=sublabel,
                 num_threads=torch.get_num_threads()).blocked_autorange()


def compare_func(test_func, t_args, nont_args):
    sublabel = f"{tuple(arg.shape for arg in t_args)}"
    if nont_args: sublabel += f", {nont_args}"
    print(sublabel)

    yield benchmark("Inductor", sublabel, make_decomposed(test_func, t_args, nont_args), t_args)

    yield benchmark("Eager", sublabel, make_eager(test_func, t_args, nont_args), t_args)


def run_benchmarks():
    results = []
    for t_args, nont_args in gen_inputs():
        for res in compare_func(test_func, t_args, nont_args):
            results.append(res)

    compare = Compare(results)
    compare.trim_significant_figures()
    compare.print()

def check_equality():
    for t_args, nont_args in gen_inputs():
        f_eager = make_eager(test_func, t_args, nont_args)
        f_inductor = make_decomposed(test_func, t_args, nont_args)
        result_eager = f_eager(*t_args)[0]
        result_inductor = f_inductor(*t_args)[0]

        if not torch.allclose(result_eager, result_inductor, atol=1e-4, rtol=10):
            print(f"ERROR Mismatch for t_args={t_args[0].shape}, non_targs={nont_args}!")
            diff = result_eager - result_inductor
            max_abs = diff.abs().max().item()
            print(f"Max abs difference: {max_abs}")
            max_rel = (diff.abs()/torch.maximum(result_eager.abs(), result_inductor.abs())).max().item()
            print(f"Max rel difference: {max_rel}")
            print(f"Input tensors: {t_args}")
            print(f"Eager returned\n{result_eager[0]}\n")
            print(f"Inductor returned\n{result_inductor[0]}\n\n")
            print(f"Difference\n{diff}\n\n")
        else:
            print(f"OK {t_args[0].shape} {nont_args}")


if __name__ == '__main__':
    torchinductor.config.debug = True
    t_args, nont_args = next(gen_inputs())
    fi = make_decomposed(test_func, t_args, nont_args)
    torchinductor.config.debug = False
    check_equality()
    run_benchmarks()

@fdrocha (Contributor, Author) commented Sep 21, 2022

@jansel could you PTAL?
FWIW, I spent some time trying to improve the lowering but found nothing that led to better performance; any pointers would be much appreciated.

@jansel (Contributor) left a comment

Not sure why this is slower; what does the eager-mode version do? Are there any tricks hiding in that version?

Perhaps it is a benchmark setup issue? Have you tried disabling cudagraphs for both legs of the experiment?

Maybe @ngimel has ideas.

@ezyang (Contributor) commented Sep 21, 2022

Part of it is 64-bit vs 32-bit indexing; check the linked issue.

@jansel (Contributor) commented Sep 21, 2022

We should be able to use 32-bit indexing in the lowering version.

@ngimel commented Sep 21, 2022

The lowering uses 32-bit indexing, so I don't think that's the reason. But yes, looking at the raw kernel times would be helpful, since CUDA graphs incur additional copies.
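
(For reference, one way to look at the raw kernel times outside of cudagraph capture is to time the compiled callables with CUDA events; the small harness below is illustrative and not part of the PR.)

import torch

# Illustrative: average per-call time of a compiled callable, without cudagraphs.
def raw_time_ms(fn, args, iters=100):
    fn(*args)  # warm-up (triggers compilation on the first call)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters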

ix = ops.indirect_indexing(clamp(fx, 0, iW - 1))
return x_loader([n, c, iy, ix])

iy = ops.to_dtype(in_y, torch.int32)
Reviewer comment:

Would this break if the x or y dimension is larger than INT_MAX? (That's a very rare situation, but our users tend to poke holes like this, even if only to file an issue.)

@fdrocha (Contributor, Author) replied:

That is a good point, and it would definitely break. Note that the CUDA implementation has the same issue, though: it uses int for those indices. There are probably a few other kernels with this issue too.

In the lowering it is easy enough to add a check and use int32 or int64 accordingly. But if we are worrying about that, we should probably also worry about float32 (used for all the floating-point index calculations) not being precise enough when the x/y dimensions are large enough. Should we also be picking the floating-point size according to those dimensions? The CUDA implementation uses at::acc_type<scalar_t, true>, which doesn't make much sense to me.

I will add a check for the int dtype, which seems like the bigger problem, but let me know what you think about the float type issue.
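
(A minimal sketch of the kind of check described above, choosing the index dtype from the spatial extents; the helper name is hypothetical and this is not the committed code. A real check might also need to account for the full flattened offset, not just the individual dimensions.)

import torch

# Illustrative: pick int32 for the indirect spatial indices unless a dimension
# could overflow it, falling back to int64 otherwise (hypothetical helper).
def pick_index_dtype(iH, iW):
    int32_max = torch.iinfo(torch.int32).max
    return torch.int32 if max(int(iH), int(iW)) <= int32_max else torch.int64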

Reviewer (Contributor) replied:

I wouldn't worry about fixing this here. I think this should be fixed at the compiler level, as discussed in #1293.

Reviewer (Contributor) replied:

Actually, yeah, your solution of just adding a check for now SGTM.

@fdrocha (Contributor, Author) replied:

Yes, ideally it should be handled at the compiler level, but in the meantime we can get significantly better performance by having the explicit check.

coeffs_x = tuple((load_bounded(y, x) for x in ixs_ofs))
return cubic_interp1d(coeffs_x, t_x)

coeffs_y = tuple(get_x_interp(y) for y in iys_ofs)
@ngimel commented Sep 21, 2022:

Where are coeffs_x?
Edit: ah, OK, it's rolled into the get_x_interp function, but that's pretty confusing, no?

@fdrocha (Contributor, Author) replied:

There is an asymmetry between coeffs_x and coeffs_y. The coeffs_x are direct memory reads from the input tensor, and there are actually 4 sets of them, one for each of the 4 y offsets. The coeffs_y are the result of applying cubic_interp1d to each of those 4 sets of coeffs_x.
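
(A self-contained toy version of that structure, in plain Python over a list-of-lists image; load_bounded, get_x_interp and cubic_interp1d here are stand-ins for the lowering's helpers, and the A = -0.75 weights match the generated kernel, but none of this is the actual lowering code.)

A = -0.75

def cubic_interp1d(vals, t):
    # weights for the 4 neighbors at offsets -1, 0, +1, +2, fractional position t in [0, 1)
    def near(s):  # |s| <= 1
        return ((A + 2) * s - (A + 3)) * s * s + 1.0
    def far(s):   # 1 < |s| < 2
        return ((A * s - 5 * A) * s + 8 * A) * s - 4 * A
    w = (far(t + 1.0), near(t), near(1.0 - t), far(2.0 - t))
    return sum(wi * vi for wi, vi in zip(w, vals))

def bicubic_sample(img, fy, fx):
    H, W = len(img), len(img[0])
    iy, ix = int(fy), int(fx)
    t_y, t_x = fy - iy, fx - ix

    def load_bounded(y, x):  # clamp-to-edge reads, like the min/max pairs in the kernel
        return img[max(0, min(H - 1, y))][max(0, min(W - 1, x))]

    def get_x_interp(y):
        coeffs_x = tuple(load_bounded(y, ix + ox) for ox in (-1, 0, 1, 2))  # 4 raw loads
        return cubic_interp1d(coeffs_x, t_x)

    # 4 x-interpolated values, one per y offset, then interpolated along y
    coeffs_y = tuple(get_x_interp(iy + oy) for oy in (-1, 0, 1, 2))
    return cubic_interp1d(coeffs_y, t_y)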

@fdrocha (Contributor, Author) commented Sep 22, 2022

(Replying to @jansel above)

Not sure why this is slower, what does the eager mode version do? Are there any tricks hiding in that version?

Not that I can see? I largely based this lowering on the CUDA implementation, which seems pretty straightforward.

Perhaps it is a benchmark setup issue? Have you tried disabling cudagraphs for both legs of the experiment?

This is what you get without and with cudagraphs (the +CG columns). I also included the results for the decomp in pytorch/pytorch#85403.

[---------------------------------------------------------- upsample_bicubic2d ---------------------------------------------------------]
                                                              |  lowering  |  lowering+CG  |  decomp  |  decomp+CG  |  eager  |  eager+CG
32 threads: -----------------------------------------------------------------------------------------------------------------------------
      (torch.Size([32, 4, 128, 256]),), ((256, 512), True)    |    1000    |      1000     |   2001   |     2121    |    711  |     830
      (torch.Size([32, 4, 128, 256]),), ((256, 512), False)   |     940    |      1050     |   2030   |     2119    |    711  |     830
      (torch.Size([32, 4, 128, 256]),), ((512, 1024), True)   |    3563    |      3677     |   7700   |     7800    |   2633  |    2755
      (torch.Size([32, 4, 128, 256]),), ((512, 1024), False)  |    3660    |      3779     |   7700   |     7800    |   2637  |    2752

So it doesn't seem like it's a cudagraphs thing.

By the way, I disabled cudagraphs for the torchinductor versions by passing cudagraphs=False when calling compile_fx_inner; I'm not sure if there's something else I need to do.
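
(Concretely, the Inductor legs were built like this; the cudagraphs keyword of compile_fx_inner is assumed to behave as described in the comment above.)

# Illustrative: compile the decomposed graph without cudagraphs.
decomposed = make_fx(f_adapted, decomposition_table=decompositions, tracing_mode="fake")(*t_args)
compiled_decomposed = compile_fx_inner(decomposed, t_args, cudagraphs=False)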

@fdrocha (Contributor, Author) commented Sep 24, 2022

@ngimel @jansel what do you think? Should I merge this?

@ngimel commented Sep 24, 2022

Yeah it's ok to merge this

@jansel (Contributor) commented Sep 24, 2022

Yes, we can merge this.

@ngimel ngimel merged commit 1d4794d into main Sep 25, 2022
ngimel pushed a commit to ngimel/torchdynamo that referenced this pull request Sep 25, 2022
@ngimel commented Sep 25, 2022

Sorry, had to revert, as master tests started failing

@jansel (Contributor) commented Sep 25, 2022

@fdrocha the failure is related to not handling None inputs:
https://github.com/pytorch/torchdynamo/actions/runs/3120977240/jobs/5061978080

    @register_lowering(aten.upsample_bicubic2d)
    def upsample_bicubic2d(x, output_size, align_corners, scales_h=None, scales_w=None):
        x.realize_hint()
        x_loader = x.make_loader()
    
        N, C, iH, iW = x.get_size()
>       oH, oW = output_size
E       TypeError: cannot unpack non-iterable NoneType object

torchinductor/lowering.py:1728: TypeError

In this case:

upsample_bicubic2d(x, output_size=None, scales_h=0.6, scales_w=0.6)
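
(A minimal sketch of the kind of guard that would handle this case, deriving the output size from the explicit scale factors when output_size is None; this is illustrative, not the committed fix, and assumes the eager floor(input_size * scale) convention.)

# Illustrative: fall back to scales_h/scales_w when output_size is None.
def compute_output_size(iH, iW, output_size, scales_h, scales_w):
    if output_size is not None:
        oH, oW = output_size
    else:
        assert scales_h is not None and scales_w is not None
        oH, oW = int(iH * scales_h), int(iW * scales_w)
    return oH, oW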
