[inductor max autotune] Flexible GEMM layout autotuning
This diff introduces memory layout autotuning and relaxes the
memory layouts that are accepted and written by the Cutlass
GEMM kernels.

During autotuning, if the Cutlass GEMM kernels have inputs with
flexible layouts, all possible combinations of row-major and
column-major layouts are tried.
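
As an illustration of what this means, the sketch below (not the code in this diff; all names are hypothetical) enumerates the layout combinations that autotuning would consider for a GEMM whose inputs are still flexible:

# Conceptual sketch only -- illustrates the combinations considered during
# autotuning when some GEMM inputs still have flexible layouts.
import itertools

def candidate_layout_combinations(input_is_flexible):
    """Yield one tuple of per-input layouts for every combination to try."""
    options = [
        ("row_major", "column_major") if flexible else ("fixed",)
        for flexible in input_is_flexible
    ]
    yield from itertools.product(*options)

# Example: A is fixed, B is flexible -> two candidate configurations are tried.
print(list(candidate_layout_combinations([False, True])))
# [('fixed', 'row_major'), ('fixed', 'column_major')]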

Note: Flexible input layouts occur in certain internal production models, which made these changes necessary.

Test Plan:

 * Additional unit test(s) (more TBD)
 * CI

ghstack-source-id: 5dcfc8eb1712ec40672e9cf2b1a878cae1ee2311
Pull Request resolved: #114319
kadeng committed Jan 23, 2024
1 parent e1e6b67 commit 2e42a2d
Showing 9 changed files with 333 additions and 122 deletions.
22 changes: 22 additions & 0 deletions test/inductor/test_max_autotune.py
@@ -357,6 +357,28 @@ def mm(a, b):
k=257,
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
def test_max_autotune_cutlass_backend_double_matmul(
self,
):
def mm(a, b):
return ((a @ b).T @ a) - 4.5

# For this test case, no Cutlass kernel is available because of alignment constraints.
# We expect the ATen fallback to be used; since that does not register a fusion,
# expected_fuse_count=0.
self._test_max_autotune_cutlass_backend_epilogue_fusion(
mixed_precision=False,
fp16=True,
expected_fuse_count=0,
mm=mm,
m=128,
n=128,
k=128,
)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(torch.version.hip, "HIP not supported")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
9 changes: 5 additions & 4 deletions torch/_inductor/codegen/common.py
@@ -12,6 +12,7 @@
Callable,
ClassVar,
Dict,
Generator,
List,
NamedTuple,
Optional,
@@ -1419,17 +1420,17 @@ def maybe_append_choice(self, choices, **kwargs):
Maybe generates a new ChoiceCaller and appends it into existing choices.
choices: A list of ChoiceCallers.
kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
kwargs: Additional kwargs to be passed to self.generate() to generate new ChoiceCallers.
"""

try:
choices.append(self.generate(**kwargs))
choices.extend(self.generate(**kwargs))
except NotImplementedError:
pass

def generate(self, **kwargs) -> ChoiceCaller:
def generate(self, **kwargs) -> Generator[ChoiceCaller, None, None]:
"""
Generates a ChoiceCaller instance from the given arguments.
Generates a sequence of ChoiceCaller instances from the given arguments.
"""

raise NotImplementedError()
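
For context, a minimal sketch of the new contract (simplified; MyTemplate and its internals are hypothetical, not inductor code): generate() may now yield several ChoiceCallers, for example one per candidate layout combination, and maybe_append_choice collects them all.

# Simplified sketch of the generator-based contract, not actual inductor code.
from typing import Generator, List

class ChoiceCaller:  # stand-in for torch._inductor's ChoiceCaller
    def __init__(self, name: str) -> None:
        self.name = name

class MyTemplate:
    def generate(self, **kwargs) -> Generator[ChoiceCaller, None, None]:
        # One choice per candidate layout combination (hypothetical knob).
        for layout in ("row_major", "column_major"):
            yield ChoiceCaller(f"gemm_{layout}")

    def maybe_append_choice(self, choices: List[ChoiceCaller], **kwargs) -> None:
        try:
            choices.extend(self.generate(**kwargs))  # was: choices.append(...)
        except NotImplementedError:
            pass

choices: List[ChoiceCaller] = []
MyTemplate().maybe_append_choice(choices)
assert [c.name for c in choices] == ["gemm_row_major", "gemm_column_major"]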
36 changes: 21 additions & 15 deletions torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
@@ -94,8 +94,6 @@ def _can_fuse_epilogue_impl(
)
if last_epilogue_name not in additional_node.get_read_names():
return False
if additional_node.layout != cuda_template_buffer.layout:
return False

template_buffer_names: Set[str] = cuda_template_buffer.get_read_names()
fused_reading_buffer_names: Set[str] = set(template_buffer_names)
@@ -120,21 +118,24 @@
fused_reading_buffer_names.union(additional_node.get_read_names())
- fused_written_names
)
if len(after_fuse_reading_buffers) > 3:
return False
if len(after_fuse_reading_buffers) > len(fused_reading_buffer_names):
# Check that the layout of the additional input is compatible
added_names = after_fuse_reading_buffers - fused_reading_buffer_names
assert len(added_names) == 1, "Only one additional input is supported."
added_name = added_names.pop()
added_node = V.graph.get_buffer(added_name)
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate

template: CUDATemplate = cuda_template_buffer.template
if not template.are_inputs_layout_compatible(
[n.layout for n in template.input_nodes] + [added_node.layout]
):
return False
for added_name in added_names:
added_node = V.graph.get_buffer(added_name)
from torch._inductor.codegen.cuda.cuda_template import CUDATemplate

template: CUDATemplate = cuda_template_buffer.template
check_layouts = [n.layout for n in template.input_nodes[:2]] + [
added_node.layout
]
if not template.are_inputs_layout_compatible(check_layouts):
log.warning(
f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, since the layouts (A,B,C)={check_layouts} are not compatible" # noqa: B950, G004
)
return False
try:
from torch._inductor.codegen.cuda.cutlass_epilogue_gen import (
CutlassEVTEpilogueArgumentFormatter,
@@ -158,9 +159,14 @@ def _can_fuse_epilogue_impl(
else:
# Likely due to unsupported dtype.
log.warning(
f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. Reason: {not_implemented_op}" # noqa: G004, B950
f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. Reason: {not_implemented_op}" # noqa: G004, B950, G004
)
return False
if len(after_fuse_reading_buffers) > 3:
log.warning(
f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, since that would require auxiliary input support." # noqa: G004, B950, G004
)
return False
return True

@staticmethod
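
To make the layout check in this hunk concrete, here is a hedged, hypothetical stand-in for the kind of test are_inputs_layout_compatible performs (the real logic lives on the CUTLASS GEMM template class and may differ in detail): A and B may be row- or column-major, while a fused extra epilogue input is assumed to share the row-major output layout.

# Hypothetical stand-in for the fusion guard's layout check; the real
# are_inputs_layout_compatible lives on the CUTLASS GEMM template class.
def is_row_major(size, stride):
    return len(size) >= 2 and stride[-1] == 1 and stride[-2] == size[-1]

def is_column_major(size, stride):
    return len(size) >= 2 and stride[-2] == 1 and stride[-1] == size[-2]

def gemm_layouts_compatible(a, b, c):
    """a, b, c are (size, stride) pairs for GEMM inputs A, B and the extra epilogue input C."""
    # A and B may be either row- or column-major; the extra epilogue input is
    # assumed to match the (row-major) output layout.
    ab_ok = all(is_row_major(*t) or is_column_major(*t) for t in (a, b))
    return ab_ok and is_row_major(*c)

# Example: row-major A (128x64), column-major B (64x32), row-major C (128x32).
assert gemm_layouts_compatible(
    ((128, 64), (64, 1)), ((64, 32), (1, 64)), ((128, 32), (32, 1))
)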
12 changes: 11 additions & 1 deletion torch/_inductor/codegen/cuda/cuda_kernel.py
@@ -3,7 +3,7 @@

from ... import ir
from ...autotune_process import CUDABenchmarkRequest
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox
from ...ir import Buffer, CUDATemplateBuffer, FlexibleLayout, IRNode, Layout, TensorBox
from ...select_algorithm import ChoiceCaller
from ...utils import sympy_product
from ...virtualized import V
@@ -352,6 +352,16 @@ def info_dict(self) -> Dict[str, PrimitiveInfoType]:
return {"backend": "CUDA", "op_type": "unknown"}

def output_node(self) -> TensorBox:
for i, node in enumerate(self.input_nodes):
if (
hasattr(node, "freeze_layout_with_same_order")
and hasattr(node, "layout")
and isinstance(node.layout, FlexibleLayout)
):
node.freeze_layout_with_same_order(
self.bmreq.input_tensor_meta[i].strides
)

return TensorBox.create(
CUDATemplateBuffer(
layout=self.layout,
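
The hunk above pins any still-flexible input layout to the strides the winning kernel was actually benchmarked with. A rough sketch of the underlying idea (hypothetical helper, not inductor's freeze_layout_with_same_order implementation):

# Hypothetical sketch: recompute contiguous strides that preserve the stride
# order observed during benchmarking, so codegen emits the same row-/column-
# major layout the selected CUTLASS kernel expects.
def freeze_with_same_order(size, benchmarked_strides):
    order = sorted(
        range(len(benchmarked_strides)),
        key=lambda i: benchmarked_strides[i],
        reverse=True,
    )  # dimensions from slowest- to fastest-varying
    new_strides = [0] * len(size)
    running = 1
    for dim in reversed(order):  # assign the fastest-varying dimension first
        new_strides[dim] = running
        running *= size[dim]
    return new_strides

# A 128x64 input benchmarked with column-major strides stays column-major.
assert freeze_with_same_order([128, 64], [1, 128]) == [1, 128]
# A row-major input keeps row-major strides.
assert freeze_with_same_order([128, 64], [64, 1]) == [64, 1]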