
Commit 269c990

coconutruben authored and pytorchmergebot committed
[inductor][choices] rename get_mm_configs to get_template_configs (#162293)
# why

- eventually we want all templates to go through this
- we're exposing this through diode as a sort of interface/API
- avoid later renaming

# what

- rename get_mm_configs to get_template_configs
- rename _finalize_mm_configs to _finalize_template_configs

# testing

- lintrunner
- ci

Differential Revision: [D81820641](https://our.internmc.facebook.com/intern/diff/D81820641)

Pull Request resolved: #162293
Approved by: https://github.com/eellison
ghstack dependencies: #161351, #161350
1 parent a326ef3 commit 269c990
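
To illustrate what the rename means at a call site, here is a minimal before/after sketch based on the mm.py hunk further below; the surrounding setup (`kernel_inputs`, `templates_to_use`, `V.choices`, `choices`) is taken from the diff and assumed to already exist as before:

```python
# Before this commit: the op-wide entry point was named get_mm_configs.
choices.extend(
    V.choices.get_mm_configs(kernel_inputs, templates_to_use, "mm")
)

# After this commit: same call, renamed to reflect that it is not mm-specific.
choices.extend(
    V.choices.get_template_configs(kernel_inputs, templates_to_use, "mm")
)
```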

4 files changed (+22, -14 lines)


torch/_inductor/choices.py

Lines changed: 5 additions & 5 deletions

@@ -106,7 +106,7 @@ def get_flex_decode_configs(
         flex_heuristics = self.get_config_heuristics(device_type)
         return flex_heuristics.get_flex_decode_configs(head_dim, dtype)
 
-    def _finalize_mm_configs(
+    def _finalize_template_configs(
         self,
         template_choices: dict[str, Generator[KernelTemplateChoice, None, None]],
         kernel_inputs: KernelInputs,
@@ -148,12 +148,12 @@ def get_ktc(
         """
         Utility to get the KernelTemplateChoice generator for a specific input.
 
-        This is a per template/op call, whereas get_mm_configs is an op wide call (all templates).
+        This is a per template/op call, whereas get_template_configs is an op wide call (all templates).
         Consider when overriding/using at which level you need to make decisions
         """
         # Extract device_type from kernel_inputs
         device_type = kernel_inputs.device_type
-        assert device_type is not None, "get_mm_configs requires a valid device type"
+        assert device_type is not None, "get_ktc requires a valid device type"
         # Extract template_name from the template object
         template_name = template.uid
 
@@ -233,7 +233,7 @@ def _need_to_fix_layout(
             not isinstance(ktc.template, ExternKernelChoice) for ktc in adjusted_choices
         )
 
-    def get_mm_configs(
+    def get_template_configs(
         self,
         kernel_inputs: KernelInputs,
         templates: list[Union[KernelTemplate, ExternKernelChoice]],
@@ -270,7 +270,7 @@ def get_mm_configs(
         )
 
         # Second pass: Adjust the template choices
-        adjusted_choices = self._finalize_mm_configs(
+        adjusted_choices = self._finalize_template_configs(
             template_choices,
             kernel_inputs,
             templates,
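
For readers new to this code path, a heavily simplified sketch of how the renamed pieces relate: `get_template_configs` is the op-wide entry point (all templates), `get_ktc` is the per-template call, and `_finalize_template_configs` is the second pass that adjusts the collected choices. The class, method bodies, and signatures below are illustrative stand-ins, not the real `InductorChoices` implementation.

```python
class ChoicesSketch:
    """Hypothetical, minimal stand-in for the real choices handler."""

    def get_ktc(self, kernel_inputs, template, op_name):
        # Per template/op call: yield KernelTemplateChoice-like objects
        # for a single template.
        yield (template, op_name)  # placeholder payload

    def _finalize_template_configs(self, template_choices, kernel_inputs, templates):
        # Op-wide second pass: adjust/flatten the per-template generators.
        return [ktc for gen in template_choices.values() for ktc in gen]

    def get_template_configs(self, kernel_inputs, templates, op_name):
        # First pass: one choice generator per template (per-template calls).
        template_choices = {
            template: self.get_ktc(kernel_inputs, template, op_name)
            for template in templates
        }
        # Second pass: adjust the template choices across all templates.
        return self._finalize_template_configs(
            template_choices, kernel_inputs, templates
        )
```

Call sites in the kernel files then simply extend their `choices` list with the result of this single op-wide call, as the hunks below show.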

torch/_inductor/kernel/bmm.py

Lines changed: 4 additions & 2 deletions

@@ -214,7 +214,7 @@ def may_require_contiguous(t, meta_t):
 
     # Single unified call for all templates
     choices.extend(
-        V.choices.get_mm_configs(
+        V.choices.get_template_configs(
             kernel_inputs,
             templates_to_use,
             name,
@@ -290,6 +290,8 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         templates_to_use.append(bmm_template)
 
     # Single unified call for all templates
-    choices.extend(V.choices.get_mm_configs(kernel_inputs, templates_to_use, name))
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, name)
+    )
 
     return autotune_select_algorithm(name, choices, kernel_inputs.nodes(), layout)

torch/_inductor/kernel/mm.py

Lines changed: 12 additions & 6 deletions

@@ -770,7 +770,9 @@ def tuned_mm(mat1, mat2, *, layout=None):
         templates_to_use.append(mm_contiguous_subgraph_template)
 
     # Single unified call for all non-autoheuristic templates
-    choices.extend(V.choices.get_mm_configs(kernel_inputs, templates_to_use, "mm"))
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, "mm")
+    )
 
     if (
         is_nonzero
@@ -805,7 +807,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
         always_included.append("extern_mm")
     num_choices_before_extra_configs = len(choices)
     choices.extend(
-        V.choices.get_mm_configs(
+        V.choices.get_template_configs(
            # TODO(coconutruben): remove once we deprecate ah
            # mm-extra is a hack to keep the ah functionality alive
            # while we transition to the unified kwargs retrieval
@@ -898,7 +900,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
         templates_to_use.append(mm_template)
 
     # Single unified call for all templates
-    choices.extend(V.choices.get_mm_configs(kernel_inputs, templates_to_use, name))
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, name)
+    )
 
     if use_cutlass and _use_cutlass_for_op(name):
         CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(
@@ -944,7 +948,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta)
     )
     choices.extend(
-        V.choices.get_mm_configs(
+        V.choices.get_template_configs(
            kernel_inputs,
            [aten_addmm],
            name,
@@ -966,7 +970,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         templates_to_use.append(addmm_contiguous_subgraph_template)
 
     # Single unified call for all templates
-    choices.extend(V.choices.get_mm_configs(kernel_inputs, templates_to_use, name))
+    choices.extend(
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, name)
+    )
 
     if (
         is_nonzero
@@ -1153,7 +1159,7 @@ def tuned_scaled_mm(
 
     # Single unified call for all templates
     choices.extend(
-        V.choices.get_mm_configs(
+        V.choices.get_template_configs(
            kernel_inputs,
            templates_to_use,
            name,

torch/_inductor/kernel/mm_plus_mm.py

Lines changed: 1 addition & 1 deletion

@@ -167,7 +167,7 @@ def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None):
 
     # Single unified call for all templates
     choices.extend(
-        V.choices.get_mm_configs(kernel_inputs, templates_to_use, "mm_plus_mm")
+        V.choices.get_template_configs(kernel_inputs, templates_to_use, "mm_plus_mm")
     )
 
     return autotune_select_algorithm(
