Skip to content

Commit 69b3fb6

Browse files
authored
Remove unrolling with tma + pipelining (#994)
1 parent 3a71689 commit 69b3fb6

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

helion/_compiler/tile_strategy.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,23 +136,21 @@ def get_tl_range_kwargs(config: Config, block_idx: int) -> list[str]:
136136
range_unroll_factor = env.config_spec.range_unroll_factors.config_get(
137137
config.range_unroll_factors, block_idx, 0
138138
)
139-
if range_unroll_factor > 0:
140-
kwargs.append(f"loop_unroll_factor={range_unroll_factor}")
141-
142139
range_warp_specialize = env.config_spec.range_warp_specialize.config_get(
143140
config.range_warp_specializes, block_idx, None
144141
)
145-
if range_warp_specialize is not None:
146-
kwargs.append(f"warp_specialize={range_warp_specialize}")
147-
148142
range_num_stages = env.config_spec.range_num_stages.config_get(
149143
config.range_num_stages, block_idx, 0
150144
)
145+
num_stages = config.num_stages
151146

152-
if config.indexing == "tensor_descriptor" and range_num_stages > 0:
153-
# Tensor descriptor + multi-stage tl.range pipelines tend to cause
147+
if config.indexing == "tensor_descriptor":
148+
# Tensor descriptor + multi-stage pipelines in addition to unrolling tend to cause
154149
# CUDA "misaligned address" or "unspecified launch failure" errors.
155-
range_num_stages = 0
150+
if range_num_stages > 0:
151+
range_num_stages = 0
152+
if range_unroll_factor > 0 and num_stages > 1:
153+
range_unroll_factor = 0
156154
elif (
157155
range_num_stages > 1
158156
and range_unroll_factor > 1
@@ -170,6 +168,10 @@ def get_tl_range_kwargs(config: Config, block_idx: int) -> list[str]:
170168
max(1, int(math.ceil(remainder / step))), range_num_stages
171169
)
172170

171+
if range_unroll_factor > 0:
172+
kwargs.append(f"loop_unroll_factor={range_unroll_factor}")
173+
if range_warp_specialize is not None:
174+
kwargs.append(f"warp_specialize={range_warp_specialize}")
173175
if range_num_stages > 0:
174176
kwargs.append(f"num_stages={range_num_stages}")
175177

0 commit comments

Comments
 (0)