[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1)
Summary:

This PR adds epilogue fusion code generation support for the new experimental Inductor CUTLASS backend (#108015).

Details:

Fusion happens at the GEMM template level: a CUTLASS 3.x GEMM Universal Matmul kernel template is extended with a custom template functor
based on CUTLASS's new "Epilogue Visitor Trees" (EVT), which represents and performs the computation of the fused pointwise / elementwise nodes.

This follows the approach shown in [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu),
which is currently the only documentation and example of CUTLASS Epilogue Visitor Trees.
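
For orientation, a minimal user-level sketch of how the fused path is exercised (mirroring the unit tests added in this PR; the CUTLASS checkout path is a placeholder, and running it requires an SM90 GPU):

```python
import torch
from torch._inductor import config

def mm(a, b):
    # matmul followed by a fusable elementwise epilogue
    return (a @ b) * 3.0

a = torch.randn(256, 32, device="cuda", dtype=torch.float16)
b = torch.randn(32, 256, device="cuda", dtype=torch.float16)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_dir": "/path/to/cutlass",  # placeholder: local CUTLASS checkout
    }
):
    Y_compiled = torch.compile(mm)(a, b)  # epilogue fused into the CUTLASS GEMM kernel
```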

The EVT functor is itself a hierarchical template expression that represents an abstract syntax tree of the fused computation.
A second codegen task creates a matching hierarchical initializer expression, which supplies any required arguments
to each of the functor's subexpressions.
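
To make the structure of these two expressions concrete, here is a small hypothetical Python sketch; the `PointwiseNode` class and the `emit_evt_type` / `emit_evt_args` helpers are invented for illustration, and the actual codegen emits CUTLASS C++ template types from Inductor's pointwise IR rather than strings like these:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PointwiseNode:
    op: str                           # e.g. "multiplies", "maximum"
    scalar: float                     # scalar constant operand
    child: Optional["PointwiseNode"]  # next node toward the accumulator

def emit_evt_type(node: Optional[PointwiseNode]) -> str:
    """Recursively build the nested template (type) expression."""
    if node is None:
        return "Accum"  # leaf: the GEMM accumulator
    return f"EVT<Compute<{node.op}>, {emit_evt_type(node.child)}>"

def emit_evt_args(node: Optional[PointwiseNode]) -> str:
    """Recursively build the matching nested initializer expression."""
    if node is None:
        return "{}"  # the accumulator needs no arguments
    return f"{{ {{{node.scalar}f}}, {emit_evt_args(node.child)} }}"

# relu((a @ b) * 0.5)  ->  maximum(accum * 0.5, 0.0)
chain = PointwiseNode("maximum", 0.0, PointwiseNode("multiplies", 0.5, None))
print(emit_evt_type(chain))  # EVT<Compute<maximum>, EVT<Compute<multiplies>, Accum>>
print(emit_evt_args(chain))  # { {0.0f}, { {0.5f}, {} } }
```

The point is the recursive pairing: every node of the fused computation contributes one level to the type expression and one matching level to the argument expression.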

Step 1 functionality:

 * End-to-end code generation is possible using the above approach.
 * Supports fusion of simple elementwise expressions: chains of elementwise operations (with scalar constants) after a matmul.
 * Elementwise operation support includes addition, subtraction, multiplication, division, minimum, maximum, etc.
 * Examples / unit tests include ReLU and ReLU6 fusion (see the sketch after this list).
 * Supports fp16 and fp16-with-fp32-accumulation data types.
 * Generates SM90 (Hopper) based CUDA kernels (CUTLASS up to 3.2.0 only supports EVT on SM90).
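
As a concrete illustration of the fusable pattern (mirroring the unit tests added below), ReLU and ReLU6 after a matmul reduce to chains of the supported elementwise ops with scalar constants:

```python
import torch

def mm_relu(a, b):
    # relu(x) == maximum(x, 0.0): one supported elementwise op after the matmul
    return torch.nn.functional.relu(a @ b)

def mm_relu6(a, b):
    # relu6(x) == minimum(maximum(x, 0.0), 6.0): a chain of two supported ops
    return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)
```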

Functionality beyond the above is not yet supported and is left for future work.

ghstack-source-id: 65b1d9466a7498cd97ee862add95daffcb9605f3
Pull Request resolved: #110890
kadeng committed Oct 17, 2023
1 parent db76052 commit dfb1df5
Showing 17 changed files with 1,486 additions and 152 deletions.
251 changes: 248 additions & 3 deletions test/inductor/test_max_autotune.py
@@ -2,7 +2,7 @@
import os
import unittest

from typing import List, Optional
from typing import Callable, List, Optional

import torch
from torch import multiprocessing as mp
@@ -255,14 +255,259 @@ def mm(a, b):
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    def test_max_autotune_cutlass_backend_dtype_conflict_fusion(
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=False,
        fp16=True,
    ):
        """
        Test simple fusion that's not covered by specific lowering
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        def mm(a, b):
            return (a @ b).to(torch.float32) * 0.00001

        # Note: The ops that are available
        # also depend on the alignment of the shapes
        # so if these shapes don't all align to at least 8 elements
        # it can happen that no Cutlass 3.x op is available
        # that allows fusions
        a = torch.randn(256, 32).cuda()
        b = torch.randn(32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.cutlass_only_evt_capable_ops": True,
                "cuda.version": "12.2",  # required to enable the Kernels we need
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    def _test_max_autotune_cutlass_backend_epilogue_fusion(
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=False,
        fp16=True,
        expected_fuse_count=1,
        mm: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        from torch._inductor.codegen.cuda import cuda_scheduling

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        # Note: The ops that are available
        # also depend on the alignment of the shapes
        # so if these shapes don't all align to at least 8 elements
        # it can happen that no Cutlass 3.x op is available
        # that allows fusions
        a = torch.randn(256, 32).cuda()
        b = torch.randn(32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.cutlass_only_evt_capable_ops": True,
                "cuda.version": "12.2",  # required to enable the Kernels we need
            }
        ):
            cuda_scheduling._cuda_epilogue_fusion_counter = 0
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            assert (
                cuda_scheduling._cuda_epilogue_fusion_counter == expected_fuse_count
            ), f"Expected fuse count of {expected_fuse_count} but got {cuda_scheduling._cuda_epilogue_fusion_counter}"
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu6_fusion(
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=True,
        fp16=True,
    ):
        """
        Test Relu6 fusion through Cutlass Epilogue Visitor Tree
        """

        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
            return

        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        def mm(a, b):
            return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)

        # Note: The ops that are available
        # also depend on the alignment of the shapes
        # so if these shapes don't all align to at least 8 elements
        # it can happen that no Cutlass 3.x op is available
        # that allows fusions
        a = torch.randn(256, 32).cuda()
        b = torch.randn(32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.cutlass_only_evt_capable_ops": True,
                "cuda.version": "12.2",  # required to enable the Kernels we need
            }
        ):
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.0

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.0

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self):
        def mm(a, b):
            # this should not be fused, since the output dtype is different from the matmul dtype
            return (a @ b).to(torch.float32) * 0.00001

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self):
        def mm(a, b):
            return (a @ b) / b.size(1)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen, Triton, CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_mm_bias(
        self, dynamic: bool, max_autotune_gemm_backends: str
        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
    ):
        """
        Make sure autotuning mm in sub processes work without crashes.
@@ -293,7 +538,7 @@ def mm(a, b, bias):
            torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)

    @parametrize("dynamic", (False, True))
    def test_max_autotune_addmm(self, dynamic):
    def test_max_autotune_addmm(self, dynamic=False):
        """
        Make sure autotuning addmm in sub processes work without crashes.
        """
2 changes: 1 addition & 1 deletion third_party/cutlass
Submodule cutlass updated 429 files
1 change: 1 addition & 0 deletions torch/_inductor/autotune_process.py
@@ -567,6 +567,7 @@ def make_run_fn(
                "Things need to be fixed to support non-zero workspace_size: "
                "1) max autotune cache needs to store workspace size; "
                "2) memory allocation needs to allocate / deallocate workspace correctly; "
                "3) CUDATemplateBuffer.workspace_size needs to be set correctly"
            )

        # Generate partial function.
13 changes: 7 additions & 6 deletions torch/_inductor/codecache.py
@@ -890,7 +890,7 @@ class InvalidVecISA(VecISA):
    _bit_width = 0
    _macro = ""
    _arch_flags = ""
    _dtype_nelements = {}
    _dtype_nelements = {}  # type: ignore[var-annotated]

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"
@@ -1961,9 +1961,10 @@ def compile(cls, source_code, dst_file_ext) -> Tuple[str, str, str]:
            with lock:
                output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
                if not os.path.exists(output_path):
                    cmd = cuda_compile_command(
                    cmdstr = cuda_compile_command(
                        [input_path], output_path, dst_file_ext
                    ).split(" ")
                    )
                    cmd = cmdstr.split(" ")
                    try:
                        subprocess.check_output(
                            cmd, stderr=subprocess.STDOUT, env=os.environ
@@ -2095,7 +2096,7 @@ def process_pool() -> ProcessPoolExecutor:
        # doesn't run, and we need to register our own handler.
        # exitpriority has to be high, because another one of the finalizers will
        # kill the worker thread that sends the shutdown message to the workers...
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)  # type: ignore[attr-defined]
        return pool

    @classmethod
@@ -2120,12 +2121,12 @@ def warm_pool(cls) -> None:

        # We force them to start here with some YOLOing of the internal methods.
        if hasattr(pool, "_start_queue_management_thread"):
            pool._start_queue_management_thread()
            pool._start_queue_management_thread()  # type: ignore[attr-defined]
        else:
            for _ in range(config.compile_threads):
                pool._adjust_process_count()
            if hasattr(pool, "_start_executor_manager_thread"):
                pool._start_executor_manager_thread()
                pool._start_executor_manager_thread()  # type: ignore[attr-defined]
        _compile_end()

    @classmethod
