Merged. This repository was archived by the owner on Aug 1, 2025 and is now read-only.
8 changes: 5 additions & 3 deletions torchinductor/codegen/common.py
@@ -23,7 +23,7 @@
 
 log = logging.getLogger(__name__)
 
-TensorArg = namedtuple("TensorArg", ["name", "dtype"])
+TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
 SizeArg = namedtuple("SizeArg", ["name", "expr"])
 
 
@@ -368,7 +368,9 @@ def python_argdefs(self):
             call_args.append(inplaced.other_names[-1])
             precompile_args.append(
                 TensorArg(
-                    inplaced.inner_name, V.graph.get_dtype(inplaced.other_names[-1])
+                    inplaced.inner_name,
+                    inplaced.other_names[-1],
+                    V.graph.get_dtype(inplaced.other_names[-1]),
                 )
             )
         for outer, inner in chain(
@@ -378,7 +380,7 @@ def python_argdefs(self):
                 continue
             arg_defs.append(inner)
             call_args.append(outer)
-            precompile_args.append(TensorArg(inner, V.graph.get_dtype(outer)))
+            precompile_args.append(TensorArg(inner, outer, V.graph.get_dtype(outer)))
         for outer, inner in self.sizevars.items():
             arg_defs.append(inner)
             call_args.append(outer)
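For orientation, a minimal sketch (not code from this PR) of what the extra "buffer" field enables: TensorArg now carries the graph-level buffer name alongside the kernel-level argument name, so later passes can look up per-buffer facts such as alignment. The argument values and the unaligned_buffers contents below are assumptions for illustration.

```python
from collections import namedtuple

# Mirrors the updated definition: "name" is the kernel argument name,
# "buffer" is the graph-level buffer it is backed by.
TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])

# Hypothetical precompile args for a two-tensor kernel.
args = [
    TensorArg("in_ptr0", "buf0", "torch.float32"),
    TensorArg("out_ptr0", "buf1", "torch.float32"),
]

unaligned_buffers = {"buf1"}  # assumed: buf1 was flagged as possibly unaligned

# With the buffer name attached, alignment can be decided per argument.
aligned_args = [a.name for a in args if a.buffer not in unaligned_buffers]
print(aligned_args)  # ['in_ptr0']
```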
31 changes: 24 additions & 7 deletions torchinductor/codegen/triton.py
@@ -47,14 +47,15 @@ def signature_of(arg):
 
 
 def config_of(args):
-    from triton.runtime.jit import JITFunction
+    from ..compile_fx import ALIGNMENT
 
-    divisible_by_16 = [
-        i
-        for i, arg in enumerate(args)
-        if isinstance(arg, TensorArg)
-        or V.graph.sizevars.maybe_guard_multiple_of(arg.expr, JITFunction.divisibility)
-    ]
+    def is_aligned(x):
+        if isinstance(x, TensorArg):
+            return x.buffer not in V.graph.unaligned_buffers
+        assert isinstance(x, SizeArg)
+        return V.graph.sizevars.maybe_guard_multiple_of(x.expr, ALIGNMENT)
+
+    divisible_by_16 = [i for i, arg in enumerate(args) if is_aligned(arg)]
     return instance_descriptor(tuple(divisible_by_16), ())

Inline review comments on the new divisible_by_16 line:

Reviewer: lol the place I worried about turned out to be problematic

Contributor (author): Yeah it did :)
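A self-contained sketch of the new filtering in config_of, under simplifying assumptions: maybe_guard_multiple_of is reduced to a modulo check on concrete integers, ALIGNMENT is assumed to be 16, and the argument list is invented. The resulting tuple is what feeds the divisible_by_16 field of Triton's instance_descriptor.

```python
from collections import namedtuple

TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype"])
SizeArg = namedtuple("SizeArg", ["name", "expr"])

ALIGNMENT = 16  # assumed value of the constant imported from ..compile_fx
unaligned_buffers = {"buf3"}  # assumed: buffers whose data pointer may be misaligned

def is_aligned(x):
    # Tensor args keep the hint unless their backing buffer was flagged.
    if isinstance(x, TensorArg):
        return x.buffer not in unaligned_buffers
    # Size args keep it only if the value is a multiple of ALIGNMENT; the real
    # code uses V.graph.sizevars.maybe_guard_multiple_of to handle symbolic exprs.
    assert isinstance(x, SizeArg)
    return isinstance(x.expr, int) and x.expr % ALIGNMENT == 0

args = [
    TensorArg("in_ptr0", "buf0", "torch.float32"),
    TensorArg("in_ptr1", "buf3", "torch.float32"),  # backed by an unaligned buffer
    SizeArg("xnumel", 1024),
    SizeArg("rnumel", 33),
]

divisible_by_16 = tuple(i for i, arg in enumerate(args) if is_aligned(arg))
print(divisible_by_16)  # (0, 2): the unaligned tensor and the odd numel lose the hint
```

Before this change every TensorArg was unconditionally given the divisible-by-16 hint, which is the case the review comments above are pointing at.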


@@ -990,6 +991,7 @@ def codegen_kernel(self, name=None):
         code.writeline(f"def {name or 'KERNEL_NAME'}({', '.join(argdefs)}):")
         self.codegen_body()
         with code.indent():
+            self.codegen_static_numels(code)
             for old, new in self.args.aliases():
                 code.writeline(f"{old} = {new}")
             code.splice(self.body)
@@ -1003,6 +1005,21 @@
         wrapper.writeline("''')")
         return wrapper.getvalue()
 
+    def codegen_static_numels(self, code):
+        """
+        We get a small speedup from hard coding numels if they are static.
+        """
+        for tree in self.range_trees:
+            if tree.prefix != "r" or self.inside_reduction:
+                if isinstance(V.graph.sizevars.simplify(tree.numel), sympy.Integer):
+                    code.writeline(
+                        f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)}"
+                    )
+                elif not config.dynamic_shapes:
+                    code.writeline(
+                        f"{tree.prefix}numel = {V.graph.sizevars.size_hint(tree.numel)} # dynamic_shapes=False"
+                    )
+
     def reshape_size_str(self, i=None, x=None):
         sizes = ["1"] * (len(self.range_trees) - int(self.numels[-1] == 1))
         if i is not None:
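For a sense of the effect of codegen_static_numels: when a tree's numel is a static integer, the generated kernel re-assigns the numel parameter to a constant at the top of its body, which lets Triton specialize on it; the dynamic_shapes=False branch emits the same assignment from a size hint and tags it with a comment in the generated source. The kernel below is an illustrative mock-up, not actual output of this PR; names and shapes are assumptions.

```python
# Illustrative shape of a generated pointwise kernel with a hard-coded numel.
import triton
import triton.language as tl

@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xnumel = 1024  # emitted by codegen_static_numels; shadows the runtime argument
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr0 + xindex, xmask)
    tl.store(out_ptr0 + xindex, tmp0 + 1.0, xmask)
```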
2 changes: 2 additions & 0 deletions torchinductor/codegen/wrapper.py
@@ -273,6 +273,8 @@ def codegen_allocation(self, buffer):
             return
         if isinstance(layout, ir.AliasedLayout):
             assert isinstance(layout.view, ir.ReinterpretView)
+            if not layout.maybe_guard_aligned():
+                V.graph.unaligned_buffers.add(name)
             self.codegen_allocation(layout.view.data)
             allocation = DeferredLine(
                 name, f"{name} = {layout.view.codegen_reference()} # alias"
1 change: 1 addition & 0 deletions torchinductor/graph.py
@@ -91,6 +91,7 @@ def __init__(self, gm: torch.fx.GraphModule, num_dynamic_inputs=None):
         self.num_dynamic_inputs = num_dynamic_inputs
         self.num_static_inputs = None
         self.mutated_inputs = set()
+        self.unaligned_buffers = set()
         self.randomness_offset = sympy.Integer(0)
         self.randomness_seeds = []
         self.name_to_buffer = {}
8 changes: 8 additions & 0 deletions torchinductor/ir.py
@@ -1629,6 +1629,14 @@ def __init__(self, view: "ReinterpretView"):
     def make_indexer(self):
         return self.as_fixed().make_indexer()
 
+    def maybe_guard_aligned(self):
+        offset = self.view.get_layout().offset
+        if offset == 0:
+            return True
+        from .compile_fx import ALIGNMENT
+
+        return V.graph.sizevars.maybe_guard_multiple_of(offset, ALIGNMENT)
+
 
 class MutationLayout(Layout):
     def __init__(self, target: IRNode):
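Putting the wrapper.py, graph.py, and ir.py pieces together: an aliased view whose layout offset is zero, or provably a multiple of ALIGNMENT, keeps its alignment hint; otherwise its buffer name goes into graph.unaligned_buffers and config_of later drops the divisible-by-16 hint for it. A rough sketch under the same simplifying assumptions as above (concrete integer offsets, ALIGNMENT assumed to be 16):

```python
ALIGNMENT = 16  # assumed value of torchinductor.compile_fx.ALIGNMENT

def maybe_guard_multiple_of(expr, multiple):
    # Simplified stand-in for SizeVarAllocator.maybe_guard_multiple_of:
    # only handles concrete integers, while the real one can guard symbolic exprs.
    return isinstance(expr, int) and expr % multiple == 0

def maybe_guard_aligned(offset):
    # Mirrors AliasedLayout.maybe_guard_aligned from the diff above.
    if offset == 0:
        return True
    return maybe_guard_multiple_of(offset, ALIGNMENT)

unaligned_buffers = set()

# Hypothetical aliased buffers and their layout offsets.
for name, offset in [("buf0", 0), ("buf1", 32), ("buf2", 3)]:
    if not maybe_guard_aligned(offset):
        unaligned_buffers.add(name)

print(unaligned_buffers)  # {'buf2'}: only the odd offset loses the alignment hint
```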