
Conversation

fdrocha (Collaborator) commented Aug 31, 2022

Stack from ghstack (oldest at bottom):

I ran some benchmarks comparing the performance of eager mode, torchinductor using this decomposition, and torchinductor using the previously existing lowering of this function:

Benchmarks
[-------------------------------------- grid_sampler_2d --------------------------------------]
                                                           |  Decomposed  |  Lowering  |  Eager
32 threads: -----------------------------------------------------------------------------------
      (16, 3, 128, 128), (16, 128, 128, 2), (0, 0, 0)      |      100     |     120    |    118
      (16, 3, 128, 128), (16, 128, 128, 2), (0, 1, 0)      |      127     |     127    |    127
      (16, 3, 128, 128), (16, 128, 128, 2), (0, 0, 1)      |      119     |     121    |    119
      (16, 3, 128, 128), (16, 128, 128, 2), (0, 1, 1)      |      127     |     127    |    130
      (16, 3, 128, 128), (16, 512, 512, 2), (0, 0, 0)      |     1400     |    1406    |   1400
      (16, 3, 128, 128), (16, 512, 512, 2), (0, 1, 0)      |     1530     |    1450    |   1530
      (16, 3, 128, 128), (16, 512, 512, 2), (0, 0, 1)      |     1406     |    1421    |   1390
      (16, 3, 128, 128), (16, 512, 512, 2), (0, 1, 1)      |     1540     |    1460    |   1537
      (16, 3, 128, 128), (16, 1024, 1024, 2), (0, 0, 0)    |     5490     |    5542    |   5400
      (16, 3, 128, 128), (16, 1024, 1024, 2), (0, 1, 0)    |     6060     |    5720    |   6040
      (16, 3, 128, 128), (16, 1024, 1024, 2), (0, 0, 1)    |     5525     |    5570    |   5480
      (16, 3, 128, 128), (16, 1024, 1024, 2), (0, 1, 1)    |     6080     |    5730    |   6060
      (16, 3, 512, 512), (16, 128, 128, 2), (0, 0, 0)      |      556     |     655    |    578
      (16, 3, 512, 512), (16, 128, 128, 2), (0, 1, 0)      |      573     |     664    |    600
      (16, 3, 512, 512), (16, 128, 128, 2), (0, 0, 1)      |      556     |     661    |    578
      (16, 3, 512, 512), (16, 128, 128, 2), (0, 1, 1)      |      573     |     664    |    602
      (16, 3, 512, 512), (16, 512, 512, 2), (0, 0, 0)      |     1900     |    2198    |   2808
      (16, 3, 512, 512), (16, 512, 512, 2), (0, 1, 0)      |     2108     |    2267    |   2973
      (16, 3, 512, 512), (16, 512, 512, 2), (0, 0, 1)      |     1938     |    2187    |   2812
      (16, 3, 512, 512), (16, 512, 512, 2), (0, 1, 1)      |     2110     |    2266    |   2978
      (16, 3, 512, 512), (16, 1024, 1024, 2), (0, 0, 0)    |     6430     |    7300    |   9870
      (16, 3, 512, 512), (16, 1024, 1024, 2), (0, 1, 0)    |     7250     |    7658    |  10500
      (16, 3, 512, 512), (16, 1024, 1024, 2), (0, 0, 1)    |     6440     |    7500    |   9880
      (16, 3, 512, 512), (16, 1024, 1024, 2), (0, 1, 1)    |     7260     |    7674    |  10510
      (16, 3, 1024, 1024), (16, 128, 128, 2), (0, 0, 0)    |     1894     |    2113    |   1906
      (16, 3, 1024, 1024), (16, 128, 128, 2), (0, 1, 0)    |     1950     |    2134    |   1992
      (16, 3, 1024, 1024), (16, 128, 128, 2), (0, 0, 1)    |     1893     |    2138    |   1910
      (16, 3, 1024, 1024), (16, 128, 128, 2), (0, 1, 1)    |     1953     |    2146    |   1998
      (16, 3, 1024, 1024), (16, 512, 512, 2), (0, 0, 0)    |     5214     |    6160    |   9630
      (16, 3, 1024, 1024), (16, 512, 512, 2), (0, 1, 0)    |     5470     |    6560    |  10900
      (16, 3, 1024, 1024), (16, 512, 512, 2), (0, 0, 1)    |     5215     |    6240    |   9820
      (16, 3, 1024, 1024), (16, 512, 512, 2), (0, 1, 1)    |     5474     |    6616    |  10900
      (16, 3, 1024, 1024), (16, 1024, 1024, 2), (0, 0, 0)  |    15350     |   18800    |  34480
      (16, 3, 1024, 1024), (16, 1024, 1024, 2), (0, 1, 0)  |    16100     |   20110    |  38680
      (16, 3, 1024, 1024), (16, 1024, 1024, 2), (0, 0, 1)  |    15400     |   19000    |  34600
      (16, 3, 1024, 1024), (16, 1024, 1024, 2), (0, 1, 1)  |    16100     |   20150    |  38680

Times are in microseconds (us).

The decomposed version seems to be the fastest most of the time. There are two lines where the lowering is about 10% faster, but for larger sizes in particular the decomposition is 20% faster.
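For reference, the option tuples in the sub-labels are (interpolation_mode, padding_mode, align_corners), i.e. the positional arguments of aten::grid_sampler_2d. Below is a minimal sketch (not part of this PR) of how they relate to the more familiar torch.nn.functional.grid_sample, assuming the standard aten enum values (interpolation_mode 0 = bilinear, padding_mode 0 = zeros / 1 = border); shapes and values are illustrative only.

import torch
import torch.nn.functional as F

image = torch.randn(16, 3, 128, 128)
# grid holds normalized sampling coordinates; values slightly outside [-1, 1]
# (as in the benchmark script's lo=-1.2, hi=1.2) exercise the padding_mode.
grid = torch.rand(16, 128, 128, 2) * 2.4 - 1.2

# (0, 1, 1) in the table ~ bilinear interpolation, border padding, align_corners=True
out_aten = torch.ops.aten.grid_sampler_2d(image, grid, 0, 1, True)
out_ref = F.grid_sample(image, grid, mode="bilinear",
                        padding_mode="border", align_corners=True)
torch.testing.assert_close(out_aten, out_ref)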

Here is the script used to run the benchmarks:

Script
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils.benchmark import Timer, Compare
import torchinductor
from torchinductor.compile_fx import compile_fx_inner, cudagraphify
from torchinductor.decomposition import decompositions
from itertools import product
from functools import partial

torchinductor.config.debug = True

benchmark_name = "grid_sampler_2d"
Ns = [16]
Cs = [3]
# We will use the same value for width and height
iHs = [128, 512, 1024]
oHs = [128, 512, 1024]
# (interpolation_mode, padding_mode, align_corners)
options = [(0, 0, False), (0, 1, False), (0, 0, True), (0, 1, True)]

def rand_uniform(*size, **kwargs):
    # Sample uniformly from [lo, hi) rather than torch.rand's default [0, 1)
    hi_val = kwargs.pop('hi')
    lo_val = kwargs.pop('lo')
    xs = torch.rand(*size, **kwargs)
    return (hi_val - lo_val) * xs + lo_val

def gen_inputs():
    make_arg = partial(rand_uniform, dtype=torch.float32, device="cuda", lo=-1.2, hi=1.2)
    for N, C, iH, oH, option in product(Ns, Cs, iHs, oHs, options):
        image_shape = (N, C, iH, iH)
        grid_shape = (N, oH, oH, 2)
        yield (make_arg(image_shape), make_arg(grid_shape), option)


def benchmark(label, sublabel, f, args):
    return Timer("f(*args)",
                 globals={"f": f, "args": args},
                 label=benchmark_name,
                 description=label,
                 sub_label=sublabel,
                 num_threads=torch.get_num_threads()).blocked_autorange()


def compare(image, grid, option):
    def f(image, grid):
        val = torch.ops.aten.grid_sampler_2d(image, grid, *option)
        return (val,)

    sublabel = f"{tuple(image.shape)}, {tuple(grid.shape)}, {tuple(map(int, option))}"
    print(sublabel)

    t_args = [image, grid]

    # Trace through the decomposition table so inductor compiles the decomposed graph
    decomposed = make_fx(f, decomposition_table=decompositions, tracing_mode="fake")(*t_args)
    compiled_decomposed = compile_fx_inner(decomposed, t_args)
    yield benchmark("Decomposed", sublabel, compiled_decomposed, t_args)

    # Trace without decompositions so inductor falls back to its grid_sampler_2d lowering
    non_decomposed = make_fx(f, tracing_mode="fake")(*t_args)
    compiled_nondecomposed = compile_fx_inner(non_decomposed, t_args)
    yield benchmark("Lowering", sublabel, compiled_nondecomposed, t_args)

    # Just show the first two generated kernels
    if torchinductor.config.debug:
        torchinductor.config.debug = False

    # Eager baseline, wrapped in CUDA graphs to match the compiled calls
    cuda_f = cudagraphify(f, t_args)
    yield benchmark("Eager", sublabel, cuda_f, t_args)


if __name__ == '__main__':
    results = []
    for image, grid, option in gen_inputs():
        for res in compare(image, grid, option):
            results.append(res)

    compare = Compare(results)
    compare.trim_significant_figures()
    compare.print()

And here is the code generated by torchinductor for the decomposed and lowering versions:

Generated code
(16, 3, 128, 128), (16, 128, 128, 2), (0, 0, 0)
torchinductor.compile_fx: [INFO] Compiling FORWARDS graph
torchinductor.codegen.triton: [INFO] schedule: [SchedulerNode(name='buf0')]

from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import CppCodeCache, TritonCodeCache

aten = torch.ops.aten

import triton
import triton.language as tl

from torchinductor.triton_ops.autotune import pointwise_heuristics
from torchinductor.triton_ops.autotune import reduction_heuristics
from torchinductor.triton_ops.autotune import grid


kernel0 = TritonCodeCache.load('''
import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import pointwise_heuristics

@pointwise_heuristics(size_hints=[1048576], contiguous=False, filename=__file__)
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = (xindex // 49152)
    x1 = (xindex // 16384) % 3
    x0 = xindex % 16384
    x4 = xindex
    tmp2 = tl.load(in_ptr0 + (2*x0) + (32768*x2), xmask)
    tmp12 = tl.load(in_ptr0 + 1 + (2*x0) + (32768*x2), xmask)
    tmp0 = x2
    tmp1 = x1
    tmp3 = 64.0
    tmp4 = tmp2 * tmp3
    tmp5 = 63.5
    tmp6 = tmp4 + tmp5
    tmp7 = tl.libdevice.floor(tmp6)
    tmp8 = 0
    tmp9 = tmp7 >= tmp8
    tmp10 = 128
    tmp11 = tmp7 < tmp10
    tmp13 = tmp12 * tmp3
    tmp14 = tmp13 + tmp5
    tmp15 = tl.libdevice.floor(tmp14)
    tmp16 = tmp15 >= tmp8
    tmp17 = tmp15 < tmp10
    tmp18 = tmp16 & tmp17
    tmp19 = tmp11 & tmp18
    tmp20 = tmp9 & tmp19
    tmp21 = tmp15.to(tl.int64)
    tmp22 = tmp20 | tl.zeros(tmp21.shape, tmp20.dtype) if tmp21.numel > 1 else tmp20
    tmp23 = tmp22 | tl.zeros(tmp8.shape, tmp22.dtype) if tmp8.numel > 1 else tmp22
    tmp24 = tl.where(tmp23, tmp21, tmp8)
    tmp25 = tmp7.to(tl.int64)
    tmp26 = tmp20 | tl.zeros(tmp25.shape, tmp20.dtype) if tmp25.numel > 1 else tmp20
    tmp27 = tmp26 | tl.zeros(tmp8.shape, tmp26.dtype) if tmp8.numel > 1 else tmp26
    tmp28 = tl.where(tmp27, tmp25, tmp8)
    tmp29 = tl.load(in_ptr1 + tmp28 + (128*tmp24) + (16384*tmp1) + (16384*ks0*tmp0), xmask)
    tmp30 = 1
    tmp31 = tmp7 + tmp30
    tmp32 = tmp31 - tmp6
    tmp33 = tmp15 + tmp30
    tmp34 = tmp33 - tmp14
    tmp35 = tmp32 * tmp34
    tmp36 = tmp20 | tl.zeros(tmp35.shape, tmp20.dtype) if tmp35.numel > 1 else tmp20
    tmp37 = tmp36 | tl.zeros(tmp8.shape, tmp36.dtype) if tmp8.numel > 1 else tmp36
    tmp38 = tl.where(tmp37, tmp35, tmp8)
    tmp39 = tmp29 * tmp38
    tmp40 = tmp31 >= tmp8
    tmp41 = tmp31 < tmp10
    tmp42 = tmp41 & tmp18
    tmp43 = tmp40 & tmp42
    tmp44 = tmp43 | tl.zeros(tmp21.shape, tmp43.dtype) if tmp21.numel > 1 else tmp43
    tmp45 = tmp44 | tl.zeros(tmp8.shape, tmp44.dtype) if tmp8.numel > 1 else tmp44
    tmp46 = tl.where(tmp45, tmp21, tmp8)
    tmp47 = tmp31.to(tl.int64)
    tmp48 = tmp43 | tl.zeros(tmp47.shape, tmp43.dtype) if tmp47.numel > 1 else tmp43
    tmp49 = tmp48 | tl.zeros(tmp8.shape, tmp48.dtype) if tmp8.numel > 1 else tmp48
    tmp50 = tl.where(tmp49, tmp47, tmp8)
    tmp51 = tl.load(in_ptr1 + tmp50 + (128*tmp46) + (16384*tmp1) + (16384*ks0*tmp0), xmask)
    tmp52 = tmp6 - tmp7
    tmp53 = tmp52 * tmp34
    tmp54 = tmp43 | tl.zeros(tmp53.shape, tmp43.dtype) if tmp53.numel > 1 else tmp43
    tmp55 = tmp54 | tl.zeros(tmp8.shape, tmp54.dtype) if tmp8.numel > 1 else tmp54
    tmp56 = tl.where(tmp55, tmp53, tmp8)
    tmp57 = tmp51 * tmp56
    tmp58 = tmp39 + tmp57
    tmp59 = tmp33 >= tmp8
    tmp60 = tmp33 < tmp10
    tmp61 = tmp59 & tmp60
    tmp62 = tmp11 & tmp61
    tmp63 = tmp9 & tmp62
    tmp64 = tmp33.to(tl.int64)
    tmp65 = tmp63 | tl.zeros(tmp64.shape, tmp63.dtype) if tmp64.numel > 1 else tmp63
    tmp66 = tmp65 | tl.zeros(tmp8.shape, tmp65.dtype) if tmp8.numel > 1 else tmp65
    tmp67 = tl.where(tmp66, tmp64, tmp8)
    tmp68 = tmp63 | tl.zeros(tmp25.shape, tmp63.dtype) if tmp25.numel > 1 else tmp63
    tmp69 = tmp68 | tl.zeros(tmp8.shape, tmp68.dtype) if tmp8.numel > 1 else tmp68
    tmp70 = tl.where(tmp69, tmp25, tmp8)
    tmp71 = tl.load(in_ptr1 + tmp70 + (128*tmp67) + (16384*tmp1) + (16384*ks0*tmp0), xmask)
    tmp72 = tmp14 - tmp15
    tmp73 = tmp32 * tmp72
    tmp74 = tmp63 | tl.zeros(tmp73.shape, tmp63.dtype) if tmp73.numel > 1 else tmp63
    tmp75 = tmp74 | tl.zeros(tmp8.shape, tmp74.dtype) if tmp8.numel > 1 else tmp74
    tmp76 = tl.where(tmp75, tmp73, tmp8)
    tmp77 = tmp71 * tmp76
    tmp78 = tmp58 + tmp77
    tmp79 = tmp41 & tmp61
    tmp80 = tmp40 & tmp79
    tmp81 = tmp80 | tl.zeros(tmp64.shape, tmp80.dtype) if tmp64.numel > 1 else tmp80
    tmp82 = tmp81 | tl.zeros(tmp8.shape, tmp81.dtype) if tmp8.numel > 1 else tmp81
    tmp83 = tl.where(tmp82, tmp64, tmp8)
    tmp84 = tmp80 | tl.zeros(tmp47.shape, tmp80.dtype) if tmp47.numel > 1 else tmp80
    tmp85 = tmp84 | tl.zeros(tmp8.shape, tmp84.dtype) if tmp8.numel > 1 else tmp84
    tmp86 = tl.where(tmp85, tmp47, tmp8)
    tmp87 = tl.load(in_ptr1 + tmp86 + (128*tmp83) + (16384*tmp1) + (16384*ks0*tmp0), xmask)
    tmp88 = tmp52 * tmp72
    tmp89 = tmp80 | tl.zeros(tmp88.shape, tmp80.dtype) if tmp88.numel > 1 else tmp80
    tmp90 = tmp89 | tl.zeros(tmp8.shape, tmp89.dtype) if tmp8.numel > 1 else tmp89
    tmp91 = tl.where(tmp90, tmp88, tmp8)
    tmp92 = tmp87 * tmp91
    tmp93 = tmp78 + tmp92
    tl.store(out_ptr0 + x4 + tl.zeros([XBLOCK], tl.int32), tmp93, xmask)
''').kernel


def call(image_1, grid_1):
    image_1_size = image_1.size()
    s0 = image_1_size[0]
    s1 = image_1_size[1]
    s2 = image_1_size[2]
    grid_1_size = grid_1.size()
    s3 = grid_1_size[3]
    buf0 = empty_strided((16, 3, 128, 128), (49152, 16384, 128, 1), device='cuda', dtype=torch.float32)
    kernel0[grid(786432)](grid_1, image_1, buf0, s1, 786432)
    return (buf0, )


if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from torchinductor.utils import print_performance
    image_1 = rand_strided((16, 3, 128, 128), (49152, 16384, 128, 1), device='cuda', dtype=torch.float32)
    grid_1 = rand_strided((16, 128, 128, 2), (32768, 256, 2, 1), device='cuda', dtype=torch.float32)
    print_performance(lambda: call(image_1, grid_1))

torchinductor.graph: [INFO] Output code: /tmp/torchinductor_fdrocha/jz/cjzxwktxfityc7qc2x73mzdwc2licss2pmpti5obszqxndot3zz3.py
torchinductor.compile_fx: [INFO] Compiling FORWARDS graph
torchinductor.codegen.triton: [INFO] schedule: [SchedulerNode(name='buf0')]

from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torchinductor.codecache import CppCodeCache, TritonCodeCache

aten = torch.ops.aten

import triton
import triton.language as tl

from torchinductor.triton_ops.autotune import pointwise_heuristics
from torchinductor.triton_ops.autotune import reduction_heuristics
from torchinductor.triton_ops.autotune import grid


kernel0 = TritonCodeCache.load('''
import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import pointwise_heuristics

@pointwise_heuristics(size_hints=[1048576], contiguous=False, filename=__file__)
@triton.jit
def kernel(in_ptr0, in_ptr1, out_ptr0, ks0, ks1, ks2, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % (ks0*ks0)
    x2 = (xindex // (ks1*(ks0*ks0)))
    x4 = (xindex // (ks0*ks0))
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (ks2*x0) + (ks2*x2*(ks0*ks0)), xmask)
    tmp1 = tl.load(in_ptr0 + 1 + (ks2*x0) + (ks2*x2*(ks0*ks0)), xmask)
    tmp2 = 0.0
    tmp3 = 1.0
    tmp4 = 2.0
    tmp5 = ks0
    tmp6 = tmp0 + tmp3
    tmp7 = tmp6 * tmp5
    tmp8 = tmp7 - tmp3
    tmp9 = tmp8 / tmp4
    tmp10 = tmp1 + tmp3
    tmp11 = tmp10 * tmp5
    tmp12 = tmp11 - tmp3
    tmp13 = tmp12 / tmp4
    tmp14 = tl.libdevice.floor(tmp9)
    tmp15 = tl.libdevice.floor(tmp13)
    tmp16 = tmp14 + tmp3
    tmp17 = tmp15 + tmp3
    tmp18 = tmp16 - tmp9
    tmp19 = tmp17 - tmp13
    tmp20 = tmp18 * tmp19
    tmp21 = tmp9 - tmp14
    tmp22 = tmp21 * tmp19
    tmp23 = tmp13 - tmp15
    tmp24 = tmp18 * tmp23
    tmp25 = tmp21 * tmp23
    tmp26 = 0
    tmp27 = (-1) + ks0
    tmp28 = tl.minimum(tmp27, tmp14)
    tmp29 = tl.maximum(tmp26, tmp28)
    tmp30 = tl.minimum(tmp27, tmp15)
    tmp31 = tl.maximum(tmp26, tmp30)
    tmp32 = tl.minimum(tmp27, tmp16)
    tmp33 = tl.maximum(tmp26, tmp32)
    tmp34 = tl.minimum(tmp27, tmp17)
    tmp35 = tl.maximum(tmp26, tmp34)
    tmp36 = tmp31.to(tl.int64)
    tmp37 = tmp29.to(tl.int64)
    tmp38 = tl.load(in_ptr1 + tmp37 + (ks0*tmp36) + (x4*(ks0*ks0)) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp39 = tmp33.to(tl.int64)
    tmp40 = tl.load(in_ptr1 + tmp39 + (ks0*tmp36) + (x4*(ks0*ks0)) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp41 = tmp35.to(tl.int64)
    tmp42 = tl.load(in_ptr1 + tmp37 + (ks0*tmp41) + (x4*(ks0*ks0)) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp43 = tl.load(in_ptr1 + tmp39 + (ks0*tmp41) + (x4*(ks0*ks0)) + tl.zeros([XBLOCK], tl.int32), xmask)
    tmp44 = tmp38 * tmp20
    tmp45 = tmp40 * tmp22
    tmp46 = tmp42 * tmp24
    tmp47 = tmp43 * tmp25
    tmp48 = tmp44 + tmp45
    tmp49 = tmp48 + tmp46
    tmp50 = tmp49 + tmp47
    tl.store(out_ptr0 + x3 + tl.zeros([XBLOCK], tl.int32), tmp50, xmask)
''').kernel


def call(image_1, grid_1):
    image_1_size = image_1.size()
    s0 = image_1_size[0]
    s1 = image_1_size[1]
    s2 = image_1_size[2]
    grid_1_size = grid_1.size()
    s3 = grid_1_size[3]
    buf0 = empty_strided((s0, s1, s2, s2), (s1*(s2*s2), s2*s2, s2, 1), device='cuda', dtype=torch.float32)
    kernel0_xnumel = s0*s1*(s2*s2)
    kernel0[grid(kernel0_xnumel)](grid_1, image_1, buf0, s2, s1, s3, kernel0_xnumel)
    return (buf0, )


if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from torchinductor.utils import print_performance
    image_1 = rand_strided((16, 3, 128, 128), (49152, 16384, 128, 1), device='cuda', dtype=torch.float32)
    grid_1 = rand_strided((16, 128, 128, 2), (32768, 256, 2, 1), device='cuda', dtype=torch.float32)
    print_performance(lambda: call(image_1, grid_1))

torchinductor.graph: [INFO] Output code: /tmp/torchinductor_fdrocha/6v/c6vxi6qdelkisdfsdjwhelhmscvts42ph2sleup2tm4m3llggpda.py

facebook-github-bot (Contributor) commented Aug 31, 2022

✅ No Failures (0 Pending)

As of commit 842de6c (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


fdrocha pushed a commit that referenced this pull request Aug 31, 2022
ghstack-source-id: 2c866fb
Pull Request resolved: #84350
lezcano requested review from jansel and Chillee August 31, 2022 15:11
Comment on lines +1738 to +1739
x = grid[..., 0]
y = grid[..., 1]

Something to benchmark / think about: Try doing grid[..., [0,1]] to merge both calls. We could then see whether we can do everything in a vectorised fashion via broadcasting at least for interpolation_mode == 0.
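For what it's worth, here is a minimal sketch of that merged-indexing / broadcasting idea for the bilinear case (interpolation_mode == 0). It is not the decomposition in this PR; the helper name bilinear_corners and the align_corners=False unnormalization are assumptions for illustration only.

import torch

def bilinear_corners(grid, iH):
    # grid: (N, oH, oW, 2) with normalized coordinates in [-1, 1]
    xy = grid[..., [0, 1]]            # one indexing call instead of grid[..., 0] / grid[..., 1]
    xy = ((xy + 1) * iH - 1) / 2      # unnormalize, assuming align_corners=False and a square input
    lo = xy.floor()                   # lower ("north-west") corner per coordinate
    frac = xy - lo                    # fractional part, shape (N, oH, oW, 2)

    # The four corner offsets, broadcast against frac to get all bilinear
    # weights and integer corner indices in one shot.
    offs = torch.tensor([[0., 0.], [1., 0.], [0., 1.], [1., 1.]],
                        device=grid.device, dtype=grid.dtype)      # (4, 2)
    frac_b = frac.unsqueeze(-2)                                    # (N, oH, oW, 1, 2)
    w = torch.where(offs.bool(), frac_b, 1 - frac_b).prod(-1)      # (N, oH, oW, 4)
    idx = (lo.unsqueeze(-2) + offs).long()                         # (N, oH, oW, 4, 2)
    return idx, w

The sampled corner values would then be combined as (vals * w).sum(-1) after clamping or masking idx according to padding_mode; whether this fuses into a better kernel than the per-coordinate version would still need benchmarking.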

fdrocha (Collaborator, Author) commented Sep 5, 2022

Note: I ended up finding a small bug in an edge case of the CUDA implementation, related to rounding. This PR also fixes it.

fdrocha pushed a commit that referenced this pull request Sep 5, 2022
ghstack-source-id: 52fb594
Pull Request resolved: #84350
fdrocha (Collaborator, Author) commented Sep 5, 2022

@pytorchbot merge -g

pytorchmergebot (Collaborator) commented

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered with the green (-g) flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

github-actions bot (Contributor) commented Sep 5, 2022

Hey @fdrocha.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

fdrocha pushed a commit to pytorch/torchdynamo that referenced this pull request Sep 6, 2022
There is now a decomposition in pytorch that seems to have better performance, see benchmarks at pytorch/pytorch#84350

Pull Request resolved: #1134
facebook-github-bot pushed a commit that referenced this pull request Sep 7, 2022
Summary:
Pull Request resolved: #84350
Approved by: https://github.com/jansel, https://github.com/Lezcano

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/91a5f52f51de9d6aa305d184fe07fe15d20b82c9

Reviewed By: mehtanirav

Differential Revision: D39277804

fbshipit-source-id: dab01a97cea62949684a12ae7a785a295dcb1ff9
facebook-github-bot deleted the gh/fdrocha/14/head branch September 9, 2022 14:19
ezyang pushed a commit to pytorch/torchdynamo that referenced this pull request Sep 15, 2022
There is now a decomposition in pytorch that seems to have
better performance, see benchmarks at
pytorch/pytorch#84350

ghstack-source-id: 93522ab
Pull Request resolved: #1134