[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1) #110890

Status: Closed (wants to merge 30 commits)

Changes shown are from 20 of the 30 commits.

Commits:
35b6084  [Inductor CUTLASS backend] Epilogue fusion codegen (Step 1) (kadeng, Oct 9, 2023)
45c42c8  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 11, 2023)
0a23a0a  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 12, 2023)
bf137ba  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 12, 2023)
ea94597  using Cutlass 3.2.0 to work around Pytorch CUDA IMA issues in aten fl… (kadeng, Oct 17, 2023)
76d9e9f  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 17, 2023)
91d802a  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 17, 2023)
97cc508  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 17, 2023)
d3b7451  Refactoring with goal to let CUDA C++ backend use more of the preexis… (kadeng, Oct 20, 2023)
3a81daa  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 20, 2023)
69bc176  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 20, 2023)
a8c5e68  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 20, 2023)
6c3be5b  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 20, 2023)
1fa2450  Rebased and updated Cutlass to v3.2.2 on "[Inductor CUTLASS backend] … (kadeng, Oct 31, 2023)
b4acb71  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 31, 2023)
bb0af00  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Oct 31, 2023)
2451f8e  Refactoring, making APIs more similar to original version / Triton ba… (kadeng, Nov 1, 2023)
d06f92f  Simplifying PR on "[Inductor CUTLASS backend] Epilogue fusion codegen… (kadeng, Nov 1, 2023)
3216e8a  Fixing test failure in test_compiled_optimizers.py on "[Inductor CUTL… (kadeng, Nov 1, 2023)
bcb4f37  Fix new CI failures on "[Inductor CUTLASS backend] Epilogue fusion co… (kadeng, Nov 1, 2023)
f88de72  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Nov 2, 2023)
4a4dcb3  Minor changes as requested in review on "[Inductor CUTLASS backend] E… (kadeng, Nov 2, 2023)
f96270f  Attempted hotfix of Cutlass 3.2.2 on "[Inductor CUTLASS backend] Epil… (kadeng, Nov 3, 2023)
42e2dc1  Attempted hotfix (attempt 2) of Cutlass 3.2.2 on "[Inductor CUTLASS b… (kadeng, Nov 3, 2023)
e1cbf27  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Nov 3, 2023)
5ff9cd0  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Nov 3, 2023)
d76c31a  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Nov 3, 2023)
eb23b1b  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Nov 3, 2023)
07df836  Triggering new checks after github outage on "[Inductor CUTLASS backe… (kadeng, Nov 3, 2023)
601d762  Update on "[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)" (kadeng, Nov 3, 2023)
155 changes: 152 additions & 3 deletions test/inductor/test_max_autotune.py
@@ -2,7 +2,7 @@
import os
import unittest

from typing import List, Optional
from typing import Callable, List, Optional

import torch
from torch import multiprocessing as mp
@@ -256,14 +256,163 @@ def mm(a, b):
Y = mm(a, b)
torch.testing.assert_close(Y_compiled, Y)

def _test_max_autotune_cutlass_backend_epilogue_fusion(
self,
dynamic: bool = False,
max_autotune_gemm_backends: str = "CUTLASS",
mixed_precision=False,
fp16=True,
expected_fuse_count=1,
mm: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
):
from torch._inductor.codegen.cuda import cuda_cpp_scheduling

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
mixed_precision
)

# Note: the available ops also depend on the alignment of the shapes,
# so if these shapes don't all align to at least 8 elements,
# it can happen that no fusion-capable Cutlass 3.x op is available.
a = torch.randn(256, 32).cuda()
b = torch.randn(32, 256).cuda()
if fp16:
a = a.half()
b = b.half()

with config.patch(
{
"max_autotune": True,
"autotune_in_subproc": False,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 4,
"cuda.cutlass_only_evt_capable_ops": True,
"cuda.version": "12.2", # required to enable the Kernels we need
Review comment (Contributor):

But the underlying CUDA in the CI is 12.1? Does this work fine?

Reply (kadeng, PR author):

Yes, it's necessary here; otherwise the kernels we need for testing won't be listed at all. It does not mean that these kernels are performant when CUDA < 12.2, but at least they seem to exist, and the tests are working.

Reply (Contributor):

Yeah, I understand that the value is passed to the CUTLASS generator, which won't generate the EVT kernels under 12.2 (although I'd expect under 12.1). I was just trying to make sure that the kernels are functioning. Perf doesn't matter at this point. Thanks for confirming.

Reply (Contributor):

@kadeng I think 12.1 should be enough. I just checked https://github.com/NVIDIA/cutlass/blob/main/python/cutlass_library/generator.py and only found rules for 12.1, not 12.2.
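# Note (see review discussion above): this value only influences which kernels the
# CUTLASS generator is willing to emit; the CUDA toolkit actually installed in CI
# may be older (e.g. 12.1).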

}
):
cuda_cpp_scheduling._cuda_epilogue_fusion_counter = 0
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
Y = mm(a, b)
assert (
cuda_cpp_scheduling._cuda_epilogue_fusion_counter == expected_fuse_count
), f"Expected fuse count of {expected_fuse_count} but got {cuda_cpp_scheduling._cuda_epilogue_fusion_counter}"
torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)
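
As an aside on the alignment note inside the helper above, here is a minimal, hypothetical sketch (not part of this PR) of the kind of divisibility check that comment alludes to; the name shapes_look_cutlass_friendly and the align=8 threshold are illustrative assumptions taken from the comment, not an API of the backend.

import torch

# Illustrative only: the helper's comment states that fusion-capable Cutlass 3.x ops
# may be unavailable unless the GEMM shapes all align to at least 8 elements.
def shapes_look_cutlass_friendly(a: torch.Tensor, b: torch.Tensor, align: int = 8) -> bool:
    m, k = a.shape
    k2, n = b.shape
    return k == k2 and all(dim % align == 0 for dim in (m, n, k))

# The 256x32 @ 32x256 shapes used by the helper above satisfy this check.
assert shapes_look_cutlass_friendly(torch.randn(256, 32), torch.randn(32, 256))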

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
def mm(a, b):
return (a @ b) * 3.0

# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
def mm(a, b):
return (a @ b) * 3.0

self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
def mm(a, b):
return (a @ b) * 3.3 - 1.234

# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self):
def mm(a, b):
return (a @ b) * 3.3 - 1.234

self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
def mm(a, b):
return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self):
def mm(a, b):
return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self):
def mm(a, b):
return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)

# The pointwise ops seem to be pre-fused into a single Pointwise
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self):
def mm(a, b):
# this should not be fused, since the output dtype is different from the matmul dtype
return (a @ b).to(torch.float32) * 0.00001

self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self):
def mm(a, b):
return (a @ b) / b.size(1)
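# Note: with the default dynamic=False, b.size(1) is presumably a compile-time constant
# here, so the division is an ordinary pointwise op that can fuse into the GEMM epilogue.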

self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
)

# TODO: Enable dynamic test cases when dynamic support is added.
@unittest.skipIf(not SM75OrLater, "need sm_75")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@parametrize("dynamic", (False,))
@parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_mm_bias(
self, dynamic: bool, max_autotune_gemm_backends: str
self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
):
"""
Make sure autotuning mm in sub processes work without crashes.
@@ -294,7 +443,7 @@ def mm(a, b, bias):
torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)

@parametrize("dynamic", (False, True))
def test_max_autotune_addmm(self, dynamic):
def test_max_autotune_addmm(self, dynamic=False):
"""
Make sure autotuning addmm in sub processes work without crashes.
"""
2 changes: 1 addition & 1 deletion third_party/cutlass
Review comment (Contributor):

Cutlass updates should go in a separate PR.

Reply (kadeng, PR author):

Done, and in the meantime it has been merged as well.
Submodule cutlass updated 711 files