[inductor] Persistent reductions #92267
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/92267
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 3ecad52: This comment was automatically generated by Dr. CI and updates every 15 minutes.
Do you have any benchmarks for this?
Yeah, I have some RTX 3090 numbers that show some decent speedups (lots of ~1-5% wins spread out across most models). Need to rerun on an A100 though.
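(A rough way to sanity-check numbers like these locally; this is an editor's sketch, not the harness behind the figures above, and it assumes a PyTorch 2.x build with a CUDA device.)

```python
import torch
from torch.utils import benchmark

x = torch.randn(4096, 512, device="cuda")
compiled_softmax = torch.compile(lambda t: torch.softmax(t, dim=-1))
compiled_softmax(x)  # warm up so compilation time is excluded from the timing

for label, fn in [("eager", lambda: torch.softmax(x, dim=-1)),
                  ("inductor", lambda: compiled_softmax(x))]:
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    print(label, timer.timeit(100))
```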
This one may need to wait for the new MLIR Triton to land as it triggers some Triton crashes.

Before:

```
$ pytest test/inductor/test_torchinductor.py -vsk test_softmax_one_kernel_loop_cuda
...
@reduction(
    size_hints=[16, 32],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}
)
@triton.jit
def triton_(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16
    rnumel = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp1 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + float("-inf")
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
        _tmp1 = tl.where(xmask & rmask & (_tmp1 < tmp0), tmp0, _tmp1)
    tmp1 = tl.max(_tmp1, 1)[:, None]
    _tmp5 = tl.zeros([XBLOCK, RBLOCK], tl.float32) + 0
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp2 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
        tmp3 = tmp2 - tmp1
        tmp4 = tl.exp(tmp3)
        _tmp5 = tl.where(xmask & rmask, _tmp5 + tmp4, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp6 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask, eviction_policy='evict_last')
        tmp7 = tmp6 - tmp1
        tmp8 = tl.exp(tmp7)
        tmp9 = tmp8 / tmp5
        tl.store(out_ptr2 + (r1 + (32*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp9, rmask & xmask)
```

After:

```
$ pytest test/inductor/test_torchinductor.py -vsk test_softmax_one_kernel_persist_cuda
...
@persistent_reduction(
    size_hints=[16, 32],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}
)
@triton.jit
def triton_(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 16
    rnumel = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (32*x0)), rmask & xmask)
    tmp2 = tl.where(xmask & rmask, tmp0, float("-inf"))
    tmp3 = tl.max(tmp2, 1)[:, None]
    tmp4 = tmp0 - tmp3
    tmp5 = tl.exp(tmp4)
    tmp7 = tl.where(xmask & rmask, tmp5, 0)
    tmp8 = tl.sum(tmp7, 1)[:, None]
    tmp9 = tmp5 / tmp8
    tl.store(out_ptr2 + (r1 + (32*x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), tmp9, rmask & xmask)
```
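For intuition, here is a rough Python analogy (an editor's sketch, not Inductor output): the looped form above walks the reduction dimension in RBLOCK-sized chunks and reloads the input once per reduction, while the persistent form assumes the whole reduction dimension fits in a single block, so the chunk loops and repeated loads disappear and intermediates stay live in registers.

```python
import math

# Rough analogy of the looped ("Before") kernel: the row is processed in
# RBLOCK-sized chunks, and the data is re-read once per reduction pass.
def softmax_row_looped(row, RBLOCK):
    m = float("-inf")
    for start in range(0, len(row), RBLOCK):      # pass 1: running max
        m = max(m, max(row[start:start + RBLOCK]))
    denom = 0.0
    for start in range(0, len(row), RBLOCK):      # pass 2: sum of exponentials
        denom += sum(math.exp(v - m) for v in row[start:start + RBLOCK])
    out = []
    for start in range(0, len(row), RBLOCK):      # pass 3: normalize and write out
        out.extend(math.exp(v - m) / denom for v in row[start:start + RBLOCK])
    return out

# Rough analogy of the persistent ("After") kernel: the whole row is held at
# once, so there is no chunk loop and the input is read a single time.
def softmax_row_persistent(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    denom = sum(exps)
    return [e / denom for e in exps]
```

The generated kernels follow the same shape: the persistent variant keeps `tmp0` resident and applies masked max/sum directly, instead of accumulating across `roffset` iterations.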
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Tried to rebase and push PR #92267, but it was already up to date |
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Stack from ghstack (oldest at bottom):
This one may need to wait for the new MLIR Triton to land as it triggers some Triton crashes.
Before/after kernel listings (looped vs. persistent reduction, `test_softmax_one_kernel_loop_cuda` vs. `test_softmax_one_kernel_persist_cuda`) are shown in the update comment above.
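A minimal reproduction sketch (an editor's assumption about how to exercise this path, not part of the PR): compiling a last-dimension softmax on a small CUDA tensor should produce a single reduction kernel like the listings above; setting the `TORCH_COMPILE_DEBUG=1` environment variable should dump the generated Triton code for inspection.

```python
import torch

@torch.compile
def softmax_last_dim(x):
    # Same op family as the test above; reduces over the last dimension.
    return torch.softmax(x, dim=-1)

x = torch.randn(16, 32, device="cuda")   # matches the size_hints=[16, 32] listings
y = softmax_last_dim(x)
torch.testing.assert_close(y, torch.softmax(x, dim=-1))
```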
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire