[Inductor CUTLASS backend] Epilogue fusion codegen (Step 1) (#110890)
Summary:

This PR adds epilogue fusion code generation support for the new experimental
Inductor CUTLASS backend (#108015).

Details:

A fusion happens at the GEMM template level: a Cutlass 3.x GEMM Universal Matmul kernel template
is extended with a custom template functor based on Cutlass's new "Epilogue Visitor Trees" (EVT),
which represents and performs the computation of the fused pointwise/elementwise nodes.

This follows the approach shown in [NVIDIA/cutlass example 49](https://github.com/NVIDIA/cutlass/blob/main/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu),
which is currently the only documentation and example of Cutlass Epilogue Visitor Trees.

The EVT functor is itself a hierarchical template expression that represents the abstract syntax tree of the fused computation.
A second codegen task is to create a matching hierarchical initializer expression, which supplies the arguments
that each of the functor subexpressions may require.
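
For illustration, here is a minimal Python sketch of these two codegen tasks: rendering an epilogue expression tree both as a nested functor-template expression and as a matching nested initializer expression. All names in the sketch (EpilogueNode, render_functor, render_args) are invented for illustration and are not the actual Inductor codegen API; the real backend emits C++ built from Cutlass EVT node types.

```python
# Illustrative sketch only: names are invented; the real backend emits C++
# built from Cutlass EVT node types.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class EpilogueNode:
    op: str  # e.g. "multiplies", "maximum", or the leaves "accum" / "constant"
    args: List["EpilogueNode"] = field(default_factory=list)
    value: Optional[float] = None  # payload for scalar constants


def render_functor(node: EpilogueNode) -> str:
    """Render the AST as a nested template expression (the EVT functor)."""
    if node.op == "accum":
        return "Accum"
    if node.op == "constant":
        return "ScalarBroadcast"
    inner = ", ".join(render_functor(a) for a in node.args)
    return f"EVTCompute<Compute<{node.op}>, {inner}>"


def render_args(node: EpilogueNode) -> str:
    """Render the matching hierarchical initializer (arguments) expression."""
    if node.op == "accum":
        return "{}"
    if node.op == "constant":
        return f"{{{node.value}}}"
    inner = ", ".join(render_args(a) for a in node.args)
    return f"{{{inner}, {{}}}}"


# relu((a @ b) * 3.0)  ==  maximum(multiplies(accum, 3.0), 0.0)
tree = EpilogueNode(
    "maximum",
    [
        EpilogueNode("multiplies", [EpilogueNode("accum"), EpilogueNode("constant", value=3.0)]),
        EpilogueNode("constant", value=0.0),
    ],
)
print(render_functor(tree))  # nested functor-template expression
print(render_args(tree))     # nested initializer expression
```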

Step 1 functionality:

 * End-to-end code generation is possible using the above approach (a minimal usage sketch follows this list).
 * Supports fusion of simple chains of elementwise operations (with scalar constants) after a matmul.
 * Supported elementwise operations include addition, subtraction, multiplication, division, minimum, and maximum.
 * Examples / unit tests include ReLU and ReLU6 fusion.
 * Supports fp16, as well as fp16 with fp32 accumulation.
 * Generates SM90 (Hopper) CUDA kernels (as Cutlass up to 3.2.0 only supports EVT for SM90).
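
A minimal usage sketch (assuming an SM90 GPU and a local CUTLASS checkout; the config keys mirror those used in the new unit tests, and the cutlass_dir path is a placeholder to adapt):

```python
import torch
from torch._inductor import config


def mm_relu(a, b):
    return torch.nn.functional.relu((a @ b) * 3.0)


a = torch.randn(256, 32, device="cuda", dtype=torch.float16)
b = torch.randn(32, 256, device="cuda", dtype=torch.float16)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",  # only consider CUTLASS GEMM templates
        "cuda.cutlass_dir": "/path/to/cutlass",   # placeholder: local CUTLASS 3.2.2 checkout
        "cuda.cutlass_max_profiling_configs": 2,  # limit autotuning time
    }
):
    compiled = torch.compile(mm_relu)
    y = compiled(a, b)  # the multiply/ReLU epilogue can be fused into the generated CUTLASS kernel
```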

The following is not yet supported, and is left for future work:

 * Full operation support (i.e. the full set of ops usually handled via V.ops handlers).
 * Cutlass EVT on SM80 (possible in Cutlass 3.2.1 according to the release notes, but not yet documented).
 * Additional (auxiliary) inputs, which change the template kernels' call signature.
 * Additional (auxiliary) outputs (requires support for full computation graphs).
 * Reduction operations and operations whose output layout differs from the input layout.
 * Additional dtypes (as far as Cutlass allows).

This PR also updates third_party/cutlass to v3.2.2, which brings some important improvements and features
for the Inductor backend.

See also Cutlass release notes:
https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1 and https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2

Notable changes in Cutlass 3.2.1 include:
 * The Cutlass codegen Python code has moved into a package under the "cutlass_library" namespace, which makes it possible
   to prevent namespace clashes without resorting to monkey-patching (as was done earlier); see the import sketch after this list.
 * Support for SM80 epilogue visitor trees (according to the release notes; not tried yet).
 * Small changes to the cutlass_library API (which required adapting the Inductor backend code).
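
As an illustration of the new namespaced imports, a short sketch follows (module and enum names reflect the cutlass_library package layout as of CUTLASS 3.2.x and should be treated as an assumption if your checkout differs):

```python
# Namespaced imports: no monkey-patching of a top-level "cutlass" module is needed.
import cutlass_library.generator as cutlass_generator  # kernel enumeration / emission
import cutlass_library.library as cutlass_lib          # enums and operation descriptions

print(cutlass_lib.OperationKind.Gemm)
print(cutlass_lib.DataType.f16)
```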

Notable changes in Cutlass 3.2.2 include:
 * A fix for a bug that caused CUDA illegal memory accesses in some PyTorch unit tests involving flash attention.

Test Plan:
 * CI
 * pytest test/inductor/test_max_autotune.py

Note: So far, the CUTLASS backend is still disabled by default. Benchmarks are planned once more advanced fusions are enabled.

Differential Revision: [D50988161](https://our.internmc.facebook.com/intern/diff/D50988161)
Pull Request resolved: #110890
Approved by: https://github.com/jansel
ghstack dependencies: #112762
kadeng authored and pytorchmergebot committed Nov 6, 2023
1 parent 59e003d commit bdfde62
Showing 16 changed files with 1,348 additions and 122 deletions.
155 changes: 152 additions & 3 deletions test/inductor/test_max_autotune.py
@@ -2,12 +2,13 @@
import os
import unittest

from typing import List, Optional
from typing import Callable, List, Optional

import torch
from torch import multiprocessing as mp
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import reset_rng_state
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.autotune_process import (
    BenchmarkRequest,
@@ -256,14 +257,162 @@ def mm(a, b):
            Y = mm(a, b)
            torch.testing.assert_close(Y_compiled, Y)

    def _test_max_autotune_cutlass_backend_epilogue_fusion(
        self,
        dynamic: bool = False,
        max_autotune_gemm_backends: str = "CUTLASS",
        mixed_precision=False,
        fp16=True,
        expected_fuse_count=1,
        mm: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            mixed_precision
        )

        # Note: The ops that are available
        # also depend on the alignment of the shapes
        # so if these shapes don't all align to at least 8 elements
        # it can happen that no Cutlass 3.x op is available
        # that allows fusions
        a = torch.randn(256, 32).cuda()
        b = torch.randn(32, 256).cuda()
        if fp16:
            a = a.half()
            b = b.half()

        with config.patch(
            {
                "max_autotune": True,
                "autotune_in_subproc": False,
                "max_autotune_gemm_backends": max_autotune_gemm_backends,
                "cuda.cutlass_dir": _CUTLASS_DIR,
                "cuda.cutlass_max_profiling_configs": 4,
                "cuda.cutlass_only_evt_capable_ops": True,
                "cuda.version": "12.2",  # required to enable the Kernels we need
            }
        ):
            counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
            Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
            Y = mm(a, b)
            actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"]
            assert (
                actual_count == expected_fuse_count
            ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
            torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.0

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.0

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return (a @ b) * 3.3 - 1.234

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self):
        def mm(a, b):
            return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0)

        # The pointwise ops seem to be pre-fused into a single Pointwise
        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self):
        def mm(a, b):
            # this should not be fused, since the output dtype is different from the matmul dtype
            return (a @ b).to(torch.float32) * 0.00001

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
        )

    @unittest.skipIf(not SM90OrLater, "need sm_90")
    @unittest.skipIf(torch.version.hip, "HIP not supported")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self):
        def mm(a, b):
            return (a @ b) / b.size(1)

        self._test_max_autotune_cutlass_backend_epilogue_fusion(
            mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm
        )

    # TODO: Enable dynamic test cases when dynamic support is added.
    @unittest.skipIf(not SM75OrLater, "need sm_75")
    @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
    @parametrize("dynamic", (False,))
    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS"))
    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
    def test_max_autotune_cutlass_backend_mm_bias(
        self, dynamic: bool, max_autotune_gemm_backends: str
        self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
    ):
        """
        Make sure autotuning mm in sub processes work without crashes.
@@ -294,7 +443,7 @@ def mm(a, b, bias):
            torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1)

@parametrize("dynamic", (False, True))
def test_max_autotune_addmm(self, dynamic):
def test_max_autotune_addmm(self, dynamic=False):
"""
Make sure autotuning addmm in sub processes work without crashes.
"""
