Skip to content

Commit

Permalink
Replace TTIR string parsing with structured MLIR walk in Triton kernel mutation analysis (#120476)
Browse files Browse the repository at this point in the history

Summary: Previously, we relied on the `lark`-based parsing of the string TTIR representation dumped by the Triton compiler. However, this has proven to be brittle in the face of changes both in the user-written Triton kernel code and in the Triton compiler code.

In this PR, we add an alternative way of mining the function information from the TTIR based on walking the tree of structured MLIR entities. To this end, we rely on the MLIR bindings exposed by `libtriton` (related PR in Triton: triton-lang/triton#3191).

For now, we introduce gating based on whether `ttir_module.hasattr("walk")`. This will allow switching to the newly introduced TTIR analysis approach only when the new MLIR bindings (including that of `ModuleOp::walk`) become available in the Triton pin. Before then, we'll keep using the old string TTIR parsing-based approach.

Test Plan: The new functionality was tested locally with the latest Triton version compiled with the added new MLIR bindings: all Triton kernel mutation tests in `test_triton_kernels.py` are passing. Here we rely on the CI for regression testing, but it won't cover the new functionality due to gating.

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: #120476
Approved by: https://github.com/oulgen
  • Loading branch information
aakhundov authored and pytorchmergebot committed Mar 1, 2024
1 parent 8861507 commit ea7149a
Showing 1 changed file with 173 additions and 7 deletions.
180 changes: 173 additions & 7 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Expand Up @@ -156,7 +156,156 @@ def generate_ttir(kernel, kwargs):
ttir_module = src.make_ir(options, context)
if not ttir_module.verify():
raise Exception("Verification for TTIR module has failed")
return str(ttir_module), ordered_tensor_names

return ttir_module, ordered_tensor_names


def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
    """
    Walk the `ttir_module` bottom up to mine the `functions` from
    the structured MLIR entities representing the Triton kernel
    (mlir::Operation, mlir::Block, mlir::Region).

    Returns a dict mapping each `tt.func` op's `sym_name` attribute to the
    ops of that function, keyed by the op's result value (`Intermediate`).
    Ops without results are keyed by synthetic negative-indexed
    Intermediates so they are still retained in the map.

    NOTE(review): assumes `ttir_module` exposes the MLIR bindings from
    triton-lang/triton#3191 (`walk`, `get_name`, `get_operand`, etc.) —
    callers gate on `hasattr(ttir_module, "walk")` before calling this.
    """
    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}

    # block id --> op result (Intermediate) --> one or more ops
    op_stack: Dict[int, Dict[Intermediate, List[Op]]] = defaultdict(
        lambda: defaultdict(list)
    )
    # region id --> ids of the blocks directly contained in that region
    region_id_to_block_ids: Dict[int, List[int]] = defaultdict(list)
    # block id --> reindexed ids of the block's arguments
    block_id_to_block_arg_ids: Dict[int, List[int]] = {}
    # reindexed value id --> replacement (function Param or fake
    # Intermediate) applied to op args when blocks are inlined
    replacements: Dict[int, Union[Intermediate, Param]] = {}
    # raw MLIR value id --> small dense id assigned in first-seen order
    reindex_map: Dict[int, int] = {}
    # counter for synthetic results; decremented before use so fake ids
    # are negative and can never collide with real reindexed ids (>= 0)
    next_fake_intermediate = 0

    def reindex(idx):
        # Map an arbitrary MLIR value id to a consecutive int (0, 1, 2, ...),
        # allocating a new dense id on first sight.
        if idx not in reindex_map:
            reindex_map[idx] = len(reindex_map)
        return reindex_map[idx]

    def mlir_to_functions(op) -> None:
        # Callback invoked by `ttir_module.walk` for every op, bottom-up:
        # children are always processed before their enclosing op.
        name: str = op.get_name()
        if name == "builtin.module":
            # this wraps all tt.func ops
            return

        # dense ids of the op's operand and result values
        operand_ids: List[int] = [
            reindex(op.get_operand(i).id()) for i in range(op.get_num_operands())
        ]
        result_ids: List[int] = [
            reindex(op.get_result(i).id()) for i in range(op.get_num_results())
        ]

        child_block_ids: List[int] = []
        for i in [op.get_region(i).id() for i in range(op.get_num_regions())]:
            # as the walk is bottom-up, the region_id_to_block_ids[i]
            # must be populated by the time we process the enclosing op
            child_block_ids.extend(region_id_to_block_ids[i])

        parent_block_id = -1
        parent_block = op.get_block()
        if parent_block is not None:
            parent_block_id = parent_block.id()
            if parent_block_id not in block_id_to_block_arg_ids:
                # first op seen inside this block: record the block's
                # arguments and register the block with its region
                block_id_to_block_arg_ids[parent_block_id] = []
                for i in range(parent_block.get_num_arguments()):
                    block_id_to_block_arg_ids[parent_block_id].append(
                        reindex(parent_block.get_argument(i).id()),
                    )
                # the region info is collected via ops' parent blocks to be
                # used later when the region's enclosing op is traversed
                parent_region = parent_block.get_parent()
                if parent_region is not None:
                    region_id_to_block_ids[parent_region.id()].append(parent_block_id)

        nonlocal next_fake_intermediate

        if name == "tt.func":
            # for function ops: gather and inline
            # the ops from all child blocks
            fn_ops = defaultdict(list)
            for child_block_id in child_block_ids:
                for result, block_fn_ops in op_stack.pop(child_block_id).items():
                    for block_fn_op in block_fn_ops:
                        fn_ops[result].append(block_fn_op)

            # replace the corresponding Intermediates in the
            # child op args with the function args (Params);
            # the entry block's args are the function's args
            for i, idx in enumerate(block_id_to_block_arg_ids[child_block_ids[0]]):
                replacements[idx] = Param(i)

            for fn_op_list in fn_ops.values():
                for fn_op in fn_op_list:
                    for i in range(len(fn_op.args)):
                        arg = fn_op.args[i]
                        if isinstance(arg, Intermediate) and arg.idx in replacements:
                            fn_op.args[i] = replacements[arg.idx]

            # next function capture starts
            # with empty replacements
            replacements.clear()

            fn_name = op.get_str_attr("sym_name")
            functions[fn_name] = fn_ops
        elif child_block_ids:
            if name in ("scf.if", "scf.for", "scf.while"):
                # for blocked control flow ops: inline the enclosed
                # ops into the parent block + rewire the last op in
                # each child block (yield) to return the scf result
                yield_ops = []
                for block_id in child_block_ids:
                    # the block args used as operands of the ops in the block
                    # (and nested blocks inlined in the current block by now)
                    # are replaced by new fake Intermediates to avoid "this
                    # operand is not returned by any other op in the fn"
                    # error in the downstream analysis
                    for idx in block_id_to_block_arg_ids[block_id]:
                        next_fake_intermediate -= 1
                        replacements[idx] = Intermediate(next_fake_intermediate)

                    if block_id in op_stack:
                        block_ops = op_stack.pop(block_id)
                        if not block_ops:
                            continue
                        # popitem() removes the most recently inserted
                        # entry, i.e. the block's terminator op(s)
                        last_ret, last_ops = block_ops.popitem()
                        if all(op.name == "scf.yield" for op in last_ops):
                            # if last_ops are scf.yield, treat them separately
                            yield_ops.extend(last_ops)
                        else:
                            # otherwise, return last_ops to the block
                            block_ops[last_ret] = last_ops
                        for op_result, child_ops in block_ops.items():
                            op_stack[parent_block_id][op_result].extend(child_ops)

                # attribute every collected yield op to each of the scf
                # op's results in the parent block
                scf_results = [Intermediate(idx) for idx in result_ids]
                for scf_result in scf_results:
                    for yield_op in yield_ops:
                        op_stack[parent_block_id][scf_result].append(yield_op)
            else:
                # TODO(oulgen): add support for tt.reduce
                raise Exception(
                    f"Unknown blocked function: {name}. Can't capture the TTIR."
                )
        else:
            # leaf (region-free) op: record it in the parent block
            callee = None
            if name == "tt.call":
                callee = op.get_flat_symbol_ref_attr("callee")
            args: List[Union[Param, Intermediate]] = [
                Intermediate(operand) for operand in operand_ids
            ]
            block_ops = op_stack[parent_block_id]
            if result_ids:
                for result_id in result_ids:
                    res = Intermediate(result_id)
                    block_ops[res].append(Op(name, callee, args, res))
            else:
                # op has no results (e.g. a store): key it by a fresh
                # fake Intermediate so it is still kept in the map
                next_fake_intermediate -= 1
                fake_res = Intermediate(next_fake_intermediate)
                block_ops[fake_res].append(Op(name, callee, args, fake_res))

    ttir_module.walk(mlir_to_functions)

    return functions


def parse_ttir(ttir, kwargs):
Expand Down Expand Up @@ -370,7 +519,7 @@ def analyze_kernel_mutations(functions, fn_name, num_args):
Analyzes the graph to detect all sinks from a predefined list of sinks
by using triton's MemWrite trait list. NOTE: What if triton exposed this?
From each sink, it traverses the CFG backwards to identify all the input
pointers that are mutated
pointers that are mutated.
"""
# Name of mutation op to mutated parameter indices
# List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td
Expand Down Expand Up @@ -430,15 +579,25 @@ def identify_mutated_tensors(kernel, kwargs):
3) Analyzes the graph to detect all input tensor mutations
"""

ttir_module = None
functions = None
try:
from torch._dynamo import config

if not config.optimize_user_defined_triton_kernels:
raise Exception("optimize_user_defined_triton_kernels is False")

ttir, ordered_tensor_names = generate_ttir(kernel, kwargs)
functions = parse_ttir(ttir, kwargs)
ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)

# extract functions from TTIR
if hasattr(ttir_module, "walk"):
# use MLIR bindings exposed by Triton code
functions = ttir_to_functions(ttir_module)
else:
# parse string representation of Triton IR
functions = parse_ttir(str(ttir_module), kwargs)

assert functions is not None
kernel_name = next(iter(functions.keys()))
# Triton codegen modifies the name
assert kernel.fn.__name__ in kernel_name
Expand All @@ -457,13 +616,20 @@ def identify_mutated_tensors(kernel, kwargs):
import traceback

warnings.warn(
"Encountered an exception in identify_mutated_tensors, assuming every input is mutated"
)
log.debug(
"Encountered an exception in identify_mutated_tensors, "
"assuming every input is mutated:\n"
"".join(
traceback.TracebackException.from_exception(e).format() # noqa: G001
)
)
if ttir_module is not None:
log.debug("TTIR:\n%s", str(ttir_module))
if functions is not None:
log.debug("functions:")
for name, fn in functions.items():
log.debug("===\t%s\t===", name)
for ret, ops in fn.items():
log.debug("%s\t=>\t%s", ret, ops)
return [key for key, value in kwargs.items() if isinstance(value, Tensor)]


Expand Down

0 comments on commit ea7149a

Please sign in to comment.