Conversation

njriasan
Contributor

@njriasan njriasan commented Sep 17, 2025

Summary: Enables support for epilogue subtiling in the blackwell ws template. This requires the ability to call store_output twice in the same kernel and reuse the same tensor descriptor across allocations.
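
Schematically, the epilogue splits each (BLOCK_M, BLOCK_N) accumulator tile into two half-width subtiles and writes each through the same output TMA descriptor. A minimal sketch of the store pattern (illustrative names, not the template's literal codegen; the full generated kernel appears later in this thread):

# Sketch of the epilogue-subtile store: one shared descriptor, two stores per
# tile. out_desc, offs_m, offs_n, and out_dtype are illustrative names.
acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)
out_desc.store([offs_m, offs_n], acc0.to(out_dtype))
out_desc.store([offs_m, offs_n + BLOCK_N // 2], acc1.to(out_dtype))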

Test Plan:
Tested with test_max_autotune.py on a Blackwell server.

Rollback Plan:

Differential Revision: D82610077

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos


pytorch-bot bot commented Sep 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163145

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3bdfdae with merge base c261c71:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

@njriasan has exported this pull request. If you are a Meta employee, you can view the originating diff in D82610077.

@njriasan njriasan added the Blackwell (specific failures or issues related to sm100+ Cuda arches), ciflow/h100, and release notes: inductor labels Sep 17, 2025
njriasan added a commit to njriasan/pytorch that referenced this pull request Sep 17, 2025
…plate (pytorch#163145)

Summary:

Enables support for epilogue subtiling in the blackwell ws template. This requires the ability to call `store_output` twice in the same kernel and reuse the same tensor descriptor across allocations.

Test Plan:
Tested with test_max_autotune.py on a Blackwell server.

Rollback Plan:

Differential Revision: D82610077

@njriasan
Contributor Author

Here is an example of the output code for addmm (the definition of _compute_pid is omitted here; a sketch follows the kernel):

@triton.jit
def triton_tem_fused_addmm_0(in_ptr0, arg_A, arg_B, out_ptr0, ks0, ws_ptr):
    EVEN_K : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = False
    USE_FAST_ACCUM : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 128
    BLOCK_N : tl.constexpr = 128
    BLOCK_K : tl.constexpr = 128
    GROUP_M : tl.constexpr = 8
    A_ROW_MAJOR : tl.constexpr = True
    B_ROW_MAJOR : tl.constexpr = True
    NUM_SMS : tl.constexpr = 148
    TMA_SIZE : tl.constexpr = 128
    TMA_EXPERIMENTAL_API : tl.constexpr = False
    FLATTEN : tl.constexpr = True
    WARP_SPECIALIZE : tl.constexpr = True
    EPILOGUE_SUBTILE : tl.constexpr = True
    INDEX_DTYPE : tl.constexpr = tl.int32
    A = arg_A
    B = arg_B
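    # Epilogue subtiling: the (BLOCK_M, BLOCK_N) accumulator is stored as two
    # (BLOCK_M, BLOCK_N // 2) subtiles, so the output descriptor is half-width in N.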
    YBLOCK: tl.constexpr = BLOCK_N // 2
    XBLOCK: tl.constexpr = BLOCK_M
    tma_descriptor0 = tl.make_tensor_descriptor(out_ptr0, shape=[ks0, 248], strides=[248, 1], block_shape=[XBLOCK, YBLOCK])

    M = ks0
    N = 248
    K = 88
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    start_pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = grid_m * grid_n

    # Note: We require TMA_EXPERIMENTAL_API == False, which
    # we will check before invoking this template.
    stride_am = 88
    stride_ak = 1
    stride_bk = 248
    stride_bn = 1
    a_desc = triton.language.make_tensor_descriptor(
        base=A,
        shape=[M, K] if A_ROW_MAJOR else [K, M],
        strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1],
        block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
    )
    b_desc = triton.language.make_tensor_descriptor(
        base=B,
        shape=[K, N] if B_ROW_MAJOR else [N, K],
        strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1],
        block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
    )

    # tile_id_c is used in the epilogue to break the dependency between
    # the prologue and the epilogue
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_M * grid_n

    for tile_id in tl.range(
        start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE
    ):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
        )
        offs_am = pid_m * BLOCK_M
        offs_bn = pid_n * BLOCK_N

        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_K
            a = tl.load_tensor_descriptor(
                a_desc,
                [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am],
            )
            b = tl.load_tensor_descriptor(
                b_desc,
                [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k],
            )
            accumulator += tl.dot(
                a if A_ROW_MAJOR else a.T,
                b if B_ROW_MAJOR else b.T,
                allow_tf32=ALLOW_TF32,
            )

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_M
        offs_cn = pid_n * BLOCK_N
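        # Epilogue subtile split along N:
        # (BLOCK_M, BLOCK_N) -> (BLOCK_M, 2, BLOCK_N // 2) -> two (BLOCK_M, BLOCK_N // 2) halves.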
        acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2))
        acc = tl.permute(acc, (0, 2, 1))
        acc0, acc1 = tl.split(acc)
        yoffset = offs_cn
        yindex = (yoffset + tl.arange(0, YBLOCK))[None, :, ]
        ymask = yindex < 248
        xoffset = offs_cm
        xindex = (xoffset + tl.arange(0, XBLOCK))[:, None]
        xmask = xindex < ks0
        tmp0 = tl.load(in_ptr0 + (yindex), ymask, eviction_policy='evict_last').to(tl.float32)
        tmp1 = acc0 + tmp0
        tma_descriptor0.store([xoffset, yoffset], tmp1.to(tl.float16))
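        # Second subtile: reuse the same TMA descriptor, offset by half a block in N.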
        offs_cn2 = offs_cn + BLOCK_N // 2
        yoffset = offs_cn2
        yindex = (yoffset + tl.arange(0, YBLOCK))[None, :, ]
        ymask = yindex < 248
        xoffset = offs_cm
        xindex = (xoffset + tl.arange(0, XBLOCK))[:, None]
        xmask = xindex < ks0
        tmp2 = tl.load(in_ptr0 + (yindex), ymask, eviction_policy='evict_last').to(tl.float32)
        tmp3 = acc1 + tmp2
        tma_descriptor0.store([xoffset, yoffset], tmp3.to(tl.float16))
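
For reference, _compute_pid follows the grouped-ordering helper used by Triton's persistent-matmul tutorial. A sketch of a typical definition (an assumption; the actual omitted code may differ):

@triton.jit
def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS):
    # Map a linear tile id to a (pid_m, pid_n) pair using grouped ordering,
    # so that consecutive tiles revisit the same rows of A for L2 reuse.
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(grid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n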

@njriasan
Contributor Author

> looks good! Would it be possible to test this with @drisspg's recent B200 enablement?

I spoke to @drisspg, and this is not working with test_max_autotune.py. I'll verify that the file works on my B200 server, or at least does not appear to be broken by my changes.

@njriasan
Contributor Author

@eellison My Blackwell-specific tests all pass. This includes a fix to the tests that may unblock other tests. Once this lands, I can help with getting test_max_autotune.py fully enabled for Blackwell testing.

@njriasan njriasan requested a review from eellison September 23, 2025 22:11
@pytorch-bot pytorch-bot bot added the ciflow/trunk (trigger trunk jobs on your pull request) label Sep 24, 2025
@njriasan
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR!


@facebook-github-bot
Contributor

@njriasan has imported this pull request. If you are a Meta employee, you can view this in D83115051.

@njriasan
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR!


@njriasan
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
…plate (pytorch#163145)

Summary: Enables support for epilogue subtiling in the blackwell ws template. This requires the ability to call `store_output` twice in the same kernel and reuse the same tensor descriptor across allocations.

Test Plan:
Tested with test_max_autotune.py on a Blackwell server.

Rollback Plan:

Differential Revision: D82610077

Pull Request resolved: pytorch#163145
Approved by: https://github.com/eellison
jainapurva pushed a commit that referenced this pull request Sep 29, 2025
…plate (#163145)
