53 changes: 53 additions & 0 deletions test/inductor/test_max_autotune.py
@@ -802,6 +802,59 @@ def fn(
y1_expected = fn(x1, w, b, mul1)
torch.testing.assert_close(y1, y1_expected)

def test_triton_template_with_prologues_and_dynamic_shape(self):
Contributor

Can we add another test where there is an additional usage of the prologue?
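
For illustration, a hedged sketch of what such a test function could look like (shapes follow the test below; this is not code from the PR and the function name is hypothetical):

import torch

# Hypothetical sketch: the scaled input feeds both the matmul (the prologue
# candidate) and a second elementwise consumer, so the prologue output has an
# additional user besides the template.
def fn_with_extra_use(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    scaled = torch.transpose(x, 0, 1) * 0.05
    return torch.matmul(scaled, torch.transpose(w, 0, 1)) + scaled.sum()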

Author

I updated this test so that it is more closely aligned with our target use case.

def const(x: torch.Tensor, val) -> torch.Tensor:
return torch.full(x.size(), val).cuda()

def fn(
x: torch.Tensor, w: torch.Tensor
) -> torch.Tensor:
return torch.matmul(
torch.transpose(x, 0, 1) * torch.transpose(const(x, 0.05), 0, 1) + torch.transpose(const(x, 0.1), 0, 1),
torch.transpose(w, 0, 1)
)

torch.backends.cuda.matmul.allow_tf32 = False

with config.patch(
{
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": "Triton",
"prologue_fusion": True,
"max_prologue_opcount": 16,
}
):
compiled_fn = torch.compile(
fn, fullgraph=True, dynamic=True, mode="max-autotune-no-cudagraphs"
)

counters["inductor"]["cuda_prologue_fusion_counter"] = 0

M0 = 5
K = 5
N = 5
w = torch.rand(N, K).cuda()
x0 = torch.rand(K, M0).cuda()
y0 = compiled_fn(x0, w)
y0_expected = fn(x0, w)
torch.testing.assert_close(y0, y0_expected)

M1 = 8
K = 8
N = 8
w = torch.rand(N, K).cuda()
x1 = torch.rand(K, M1).cuda()
y1 = compiled_fn(x1, w)
y1_expected = fn(x1, w)
torch.testing.assert_close(y1, y1_expected)
Contributor

Can we add an assert that the prologue was actually generated? Either by inspecting the output code (there are some helpers to get this) or by counting kernels.
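
For example (a sketch only; it reuses compiled_fn, x0, w, and fn from the test above, assumes the existing inductor test helper run_and_get_code applies here, and the "@triton.jit" marker counted below is an assumption about the generated source):

import torch
from torch._inductor.utils import run_and_get_code

# Sketch: compile once, capture the generated source, and check that only one
# Triton kernel was emitted, i.e. the pointwise op was fused into the template
# rather than compiled as a separate kernel.
y, (code,) = run_and_get_code(compiled_fn, x0, w)
torch.testing.assert_close(y, fn(x0, w))
assert code.count("@triton.jit") == 1, "expected the prologue to be fused into the matmul kernel"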

Author

I added an assert to check that the fusion count matches the expected value.


actual_count = counters["inductor"]["cuda_prologue_fusion_counter"]
assert (
actual_count == 1
), f"Expected fuse count of 1 but got {actual_count}"


@config.patch(
benchmark_kernel=True,
fallback_random=True,
4 changes: 4 additions & 0 deletions torch/_inductor/codegen/common.py
@@ -1525,6 +1525,10 @@ def load(name: str, index: sympy.Expr) -> CSEVariable:
return store_cache[name]
return self.load(name, index)

@staticmethod
def set_cse_store_cache(name: str, value: str):
self.cse.store_cache[name] = value

@staticmethod
def store(
name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py
@@ -181,7 +181,7 @@ def define_kernel(self, src_code: str, node_schedule) -> str:
return kernel_name

def codegen_template(
self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode], isEpilogue=True
Contributor

isEpilogue => is_epilogue

Author

Fixed the naming to follow the snake_case convention (is_epilogue).

):
"""
Codegen a CUDA template, possibly with fused epilogues
5 changes: 3 additions & 2 deletions torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -5,6 +5,7 @@

from .triton import TritonScheduling

from typing import Any, Dict, Optional, Set

class CUDACombinedScheduling(BaseScheduling):
"""
@@ -48,15 +49,15 @@ def group_fn(self, sizes):
return self._triton_scheduling.group_fn(sizes)

def codegen_template(
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode], rename_dict: Optional[Dict[str, Set[Any]]] = None, is_epilogue = True
):
if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
return self._cuda_cpp_scheduling.codegen_template(
template_node, epilogue_nodes
)
else:
return self._triton_scheduling.codegen_template(
template_node, epilogue_nodes
template_node, epilogue_nodes, rename_dict, is_epilogue
)

def codegen_nodes(self, nodes: List[SchedulerNode]):
48 changes: 42 additions & 6 deletions torch/_inductor/codegen/triton.py
@@ -1254,6 +1254,8 @@ def __init__(
self.iter_vars_count = itertools.count()
self.inside_reduction = self.numels[-1] != 1
self.body = IndentedBuffer()
self.prologue_body = IndentedBuffer()
self.prologue_reuse = IndentedBuffer()
self.indexing_code = IndentedBuffer()
self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment]
self.outside_loop_vars: Set[Any] = set()
@@ -1264,6 +1266,8 @@ def __init__(
self.block_ptr_id = itertools.count()
# buffer accesses in the kernel
self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
self.is_epilogue = True
self.rename_dict = dict()

self.persistent_reduction: bool = (
not disable_persistent_reduction
@@ -2423,6 +2427,27 @@ def where_cond(value):
result_var.mask_vars = masks # type: ignore[attr-defined]
return result_var

def codegen_prologue_body(self, x, mask=False):
self.prologue_body.splice(self.loads)
if mask:
self.prologue_body.clear()
self.prologue_body.splice(self.compute)
else:
if self.prologue_reuse:
self.prologue_body.clear()
self.prologue_body.splice(self.prologue_reuse.getvalue().replace("prologue_val", x))
self.prologue_reuse.clear()
else:
                compute = self.compute.getvalue()
                if "prologue_val" in compute:
                    self.prologue_reuse.splice(compute)
                    self.prologue_body.splice(compute.replace("prologue_val", x))
                else:
                    # "prologue_val" was not found in the compute buffer; nothing to emit.
                    pass

self.compute.clear()

def codegen_body(self):
"""
Concat output code from index_code, loads, compute, stores,
@@ -3087,6 +3112,11 @@ def can_fuse(self, node1, node2):
if not is_triton_template:
why("node1 is not TritonTemplateBuffer")
return is_triton_template
elif node2.is_template():
is_triton_template = isinstance(node2.node, TritonTemplateBuffer)
if not is_triton_template:
why("node2 is not TritonTemplateBuffer")
return is_triton_template

# check for a bad combined tiling
tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1)
@@ -3557,29 +3587,35 @@ def define_kernel(self, src_code, node_schedule):

return kernel_name

def codegen_template(self, template_node, epilogue_nodes):
def codegen_template(self, template_node: BaseSchedulerNode, nodes: List[SchedulerNode], rename_dict: Optional[Dict[str, Set[Any]]] = None, is_epilogue = True):
"""
Codegen a triton template
"""
if not is_epilogue:
counters["inductor"]["cuda_prologue_fusion_counter"] += len(nodes)

_, (numel, rnumel) = template_node.group
assert rnumel == 1
kernel, render = template_node.node.make_kernel_render(template_node.node)

with kernel:
for node in [template_node, *epilogue_nodes]:
kernel.is_epilogue = is_epilogue
kernel.rename_dict = rename_dict
for node in [template_node, *nodes]:
node.mark_run()
partial_code = render()
for node in epilogue_nodes:
for node in nodes:
node.codegen(kernel.split_and_set_ranges(node.get_ranges()))

# finalize must be called after adding epilogue above
with V.set_kernel_handler(kernel):
with kernel, V.set_kernel_handler(kernel):
# TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion.
src_code = (
partial_code
if isinstance(partial_code, str)
else partial_code.finalize()
)
node_schedule = [template_node, *epilogue_nodes]
node_schedule = [template_node, *nodes]

if config.benchmark_kernel:
num_gb = kernel.estimate_kernel_num_bytes() / 1e9
@@ -3594,7 +3630,7 @@ def codegen_template(self, template_node, epilogue_nodes):

kernel_name = self.define_kernel(src_code, node_schedule)
self.codegen_comment(node_schedule)
kernel.call_kernel(kernel_name, template_node.node)
kernel.call_kernel(kernel_name, template_node.node, rename_dict=rename_dict)
V.graph.removed_buffers |= kernel.removed_buffers
V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
self.scheduler.free_buffers()
6 changes: 6 additions & 0 deletions torch/_inductor/config.py
@@ -74,6 +74,12 @@ def is_fbcode():
# do epilogue fusions before other fusions
epilogue_fusion_first = False

# fuse pointwise into templates
prologue_fusion = True

# threshold to emit prologue code
max_prologue_opcount = 4

# enable pattern match+replace optimizations
pattern_matcher = True

15 changes: 14 additions & 1 deletion torch/_inductor/ir.py
@@ -3413,7 +3413,20 @@ def dummy(index, rindex):
deps = dependencies.extract_read_writes(
dummy, self.get_size(), (), normalize=True
)
deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}

if not config.prologue_fusion:
deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
else:
for x in self.inputs:
def loader(index):
indexer = self.layout.make_indexer()
return ops.load(x.get_name(), indexer(index))

deps_reads = dependencies.extract_read_writes(
loader, self.get_size(), normalize=True
)
deps.reads |= deps_reads.reads

return deps

def get_reduction_size(self):
4 changes: 4 additions & 0 deletions torch/_inductor/kernel/mm.py
@@ -76,6 +76,10 @@
b = tl.load(B, mask=rk[:, None] < k, other=0.)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
{% if ENABLED_PROLOGUE_FUSION %}
{{prologue("a", 0, "rk[None, :] < k")}}
{{prologue("b", 1, "rk[:, None] < k")}}
{% endif %}
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
1 change: 1 addition & 0 deletions torch/_inductor/kernel/mm_common.py
@@ -214,6 +214,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None):
ALLOW_TF32=allow_tf32,
ACC_TYPE=acc_type(layout.dtype),
B_PROLOGUE_CAST_TYPE=b_prologue_cast_type,
ENABLED_PROLOGUE_FUSION=inductor_config.prologue_fusion,
num_stages=config.num_stages,
num_warps=config.num_warps,
**config.kwargs,
6 changes: 6 additions & 0 deletions torch/_inductor/ops_handler.py
@@ -177,6 +177,12 @@ def store(
"""
...

def set_cse_store_cache(self, name: str, value: str) -> None:
"""
        Cache 'value' under 'name' in the CSE store cache, so that a subsequent
        load of 'name' returns 'value' instead of emitting a new load.
"""
...

# TODO: Better explain how the "collective" semantics of these ops;
# remember that the input value is a scalar, you can't reduce on it in the
# traditional sense!
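To make the intent of this hook concrete, here is a toy model of the store-cache short-circuit (purely illustrative, not PyTorch code; it only mirrors the load/store_cache interplay shown in common.py above):

from typing import Dict

# A minimal stand-in for the ops handler: once a value is seeded into the
# store cache, a later load of that buffer returns the cached value instead
# of emitting a new memory load.
class ToyHandler:
    def __init__(self) -> None:
        self.store_cache: Dict[str, str] = {}

    def set_cse_store_cache(self, name: str, value: str) -> None:
        self.store_cache[name] = value

    def load(self, name: str, index: int) -> str:
        if name in self.store_cache:          # mirrors the check in common.py
            return self.store_cache[name]
        return f"tl.load({name} + {index})"   # would emit a real load otherwise

h = ToyHandler()
h.set_cse_store_cache("buf0", "prologue_val")
assert h.load("buf0", 0) == "prologue_val"    # no new load is emitted
assert h.load("buf1", 0) == "tl.load(buf1 + 0)"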
68 changes: 61 additions & 7 deletions torch/_inductor/scheduler.py
@@ -837,7 +837,22 @@ def fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(
node2, (SchedulerNode, FusedSchedulerNode)
)
return cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes())) # type: ignore[arg-type]
fused = cls(node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes())) # type: ignore[arg-type]

common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
node2.read_writes.reads | node2.read_writes.writes
)
common_memory_deps = {
dep for dep in common_memory_deps if not dep.has_unbacked_symbols()
}
keys = (node1.read_writes.writes) & (common_memory_deps)
reads = set()
for read in node1.read_writes.reads:
reads.add(read.name)
if len(keys) > 0:
key = keys.pop()
fused.rename_dict[key.name] = reads
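            # Illustrative example (buffer names are hypothetical): if the prologue
            # node writes buf0 and reads arg0_1, and the template node reads buf0,
            # this records rename_dict == {"buf0": {"arg0_1"}}, so the template's
            # load of buf0 can later be redirected to the original input.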
return fused

def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]):
# NB: No need to call super().__init__() because we don't need to re-use any of its logic.
@@ -847,6 +862,7 @@ def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]):
self.users: List[NodeUser] = []
self.inverse_users = []
self.node_users = []
        self.rename_dict: Dict[str, Set[Any]] = dict()
self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
self.ancestors = set.union(
*[x.ancestors for x in snodes if x.ancestors is not None]
@@ -1988,9 +2004,6 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
):
return False

if node2.is_template():
why("templates can only fuse epilogues")
return False
if node1.is_template() and (
node2.has_aliasing_or_mutation()
or node2.is_reduction()
@@ -2013,6 +2026,40 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
why("no shared data")
return False # heuristic not needed for correctness

if node2.is_template():
if node2.node.get_origin_node().name != "mm":
why("prologue fusion can be enabled only with matmul")
return False

if not config.prologue_fusion:
why("prologue fusion is disabled in config")
return False
# Avoid fusing multiple nodes into template
if len(node1.get_nodes()) > 1:
why("fusing multiple nodes into template is not supported")
return False
        # Avoid fusing expensive ops into the prologue by estimating the cost from the code size
node = node1.get_nodes()[0]
if isinstance(node.node, ComputedBuffer):
pointwise = node.node.data
if pointwise.inner_fn_opcount() >= config.max_prologue_opcount:
why("prolgoue op count is more than the threshold %s", config.max_prologue_opcount)
return False
if len(pointwise.get_reads()) > 1:
why("multiple reads in prologue is currently not supported")
return False

# Check if node1's reads and node2's reads have common indices
index1 = {
node1_read.index for node1_read in node1.read_writes.reads
}
index2 = {
node2_read.index for node2_read in node2.read_writes.reads
}
if len(index1 - index2) > 0:
why("Indices are different")
return False

if (
not node1.is_foreach()
and not node2.is_foreach()
@@ -2321,8 +2368,15 @@ def codegen(self):
self.buffer_names_to_free.update(node.last_usage)

if node.is_template():
node, *epilogue = node.get_nodes()
self.get_backend(device).codegen_template(node, epilogue) # type: ignore[possibly-undefined]
node1, *node2 = node.get_nodes()
if isinstance(node1.node, ir.TemplateBuffer):
# epilogue
self.get_backend(device).codegen_template(node1, node2)
else:
*node1, node2 = node.get_nodes()
if isinstance(node2.node, ir.TemplateBuffer):
# prologue
self.get_backend(device).codegen_template(node2, node1, rename_dict=node.rename_dict, is_epilogue=False)
elif node.is_extern():
self.codegen_extern_call(node)
elif node.is_foreach():
@@ -2394,7 +2448,7 @@ def group_fn(self, sizes):
raise NotImplementedError()

def codegen_template(
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode], rename_dict: Optional[Dict[str, Set[Any]]] = None, is_epilogue=True
):
"""
Given a template node, generate a kernel.