Rebenchmark configs to avoid noise #654
Conversation
stack-info: PR: #654, branch: jansel/stack/146
Did you check how much this increases benchmarking time versus how much the results improve?
This change is somehow causing misaligned memory addresses on my local GPU when autotuning matmul. I am a bit puzzled by how the PR could be causing that, since it doesn't touch codegen, so I am still debugging. Ideas welcome!
In general, I think we need to prune the set of configs more for autotuning, because we are hitting similar CUDA errors in CI benchmarking.
Yeah, +1. I've seen it in #634 and #630. I wonder whether the "misaligned memory address" errors are due to autotuning now exploring a different part of the config space. #649 (merged) should allow us to see the full kernel config for reproing the misaligned memory issues. I can help with debugging some of the issues (I've seen two so far, both related to multi-stage pipelining: 5ae337b#diff-cb3b5c8f9dd5a38792e17c09e227adfe5346bb85e3ff62ddecadc28c085b1cecR264 and 89488e3, although I am not 100% confident on the root cause yet).
Yeah, it's a bit strange because the configs pass on the first run, then the same config (not even recompiled, the same fn) fails on the rerun.
@jansel I found that this can deterministically repro the matmul "misaligned address" issue:
# rm -rf /tmp/torchinductor_${USER}/ && HELION_AUTOTUNE_RANDOM_SEED=2011902841 CUDA_LAUNCH_BLOCKING=1 python examples/matmul.py
Testing helion correctness...
[0s] Set autotune random seed to 2011902841
[0s] Starting DifferentialEvolutionSearch with population=40, generations=20, crossover_rate=0.8
Traceback (most recent call last):
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 129, in benchmark_function
fn(*self.args) # make sure the kernel is compiled
^^^^^^^^^^^^^^
File "/tmp/torchinductor_willfeng/6o/c6o7gxb4etif4stiuvwlyzavwmvtigmpty7whaklj5wbf4vkymag.py", line 60, in matmul
_launcher(_helion_matmul, (_NUM_SM,), x, y, out, _NUM_SM, _BLOCK_SIZE_1, 1, _BLOCK_SIZE_2, num_warps=1, num_stages=4)
File "/data/users/willfeng/helion/helion/runtime/__init__.py", line 63, in default_launcher
return triton_kernel.run(
^^^^^^^^^^^^^^^^^^
File "/home/willfeng/local/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 757, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/home/willfeng/local/miniconda3/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 712, in __call__
self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl,
RuntimeError: Triton Error [CUDA]: misaligned address
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/willfeng/helion/examples/matmul.py", line 182, in <module>
main()
File "/data/users/willfeng/helion/examples/matmul.py", line 177, in main
check(1024, 1024, 1024)
File "/data/users/willfeng/helion/examples/matmul.py", line 95, in check
run_example(matmul, torch.matmul, (x, y))
File "/data/users/willfeng/helion/helion/_testing.py", line 461, in run_example
func(*args).to(torch.float32),
^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 285, in __call__
return self.bind(args)(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 617, in __call__
self.autotune(args)
File "/data/users/willfeng/helion/helion/runtime/kernel.py", line 506, in autotune
config = self.settings.autotuner_fn(self, args, **kwargs).autotune()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/autotuner/base_cache.py", line 168, in autotune
config = self.autotuner.autotune()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 253, in autotune
best = self._autotune()
^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/autotuner/differential_evolution.py", line 96, in _autotune
self.initial_two_generations()
File "/data/users/willfeng/helion/helion/autotuner/differential_evolution.py", line 59, in initial_two_generations
self.parallel_benchmark_flat(
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 376, in parallel_benchmark_flat
to_check, configs, self.parallel_benchmark(configs), strict=True
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 237, in parallel_benchmark
results.append((config, fn, self.benchmark_function(config, fn)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/willfeng/helion/helion/autotuner/base_search.py", line 146, in benchmark_function
raise exc.TritonError(
helion.exc.TritonError: Error running generated Triton program:
@helion.kernel(config=helion.Config(block_sizes=[1, 16, 16], indexing='tensor_descriptor', l2_groupings=[2], loop_orders=[[0, 1]], num_stages=4, num_warps=1, pid_type='persistent_blocked', range_flattens=[True, True], range_multi_buffers=[False, True], range_num_stages=[0, 1], range_unroll_factors=[0, 4]), static_shapes=True)
RuntimeError: Triton Error [CUDA]: misaligned address

This config can trigger the issue.
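For reference, a minimal Python sketch of the same deterministic repro as the shell command at the top of this comment (a hypothetical wrapper, assuming it is run from the helion repo root and that the environment variables are read at process start):

import os
import shutil
import subprocess

# Hypothetical wrapper around the repro command quoted above.
user = os.environ.get("USER", "")
shutil.rmtree(f"/tmp/torchinductor_{user}", ignore_errors=True)  # clear the inductor cache
env = dict(
    os.environ,
    HELION_AUTOTUNE_RANDOM_SEED="2011902841",  # pin the autotuner's random seed
    CUDA_LAUNCH_BLOCKING="1",  # surface the CUDA error at the offending launch
)
subprocess.run(["python", "examples/matmul.py"], env=env, check=True)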
@jansel Without this PR, I also found that
@yf225 Does adding a tl.debug_barrier between the store and load fix it? |
@oulgen I believe the matmul error is due to the TMA tile size being too small for the matmul instructions - I just opened a fix PR at #662. For the kl_div kernel I believe it's the store-then-load pattern -
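For illustration, a minimal standalone Triton sketch (not the generated kl_div kernel) of the store-then-load pattern with a tl.debug_barrier() inserted between the store and the dependent load, as suggested above:

import torch
import triton
import triton.language as tl

@triton.jit
def store_then_load(ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(ptr + offs, mask=mask)
    tl.store(ptr + offs, x + 1, mask=mask)  # store ...
    tl.debug_barrier()  # ... make the store visible before the dependent load
    y = tl.load(ptr + offs, mask=mask)
    tl.store(ptr + offs, y * 2, mask=mask)

buf = torch.arange(16, device="cuda", dtype=torch.float32)
store_then_load[(1,)](buf, buf.numel(), BLOCK=16)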
Thanks! Let me do some more testing with your fix; I still want to confirm this actually makes results more stable. (Which I wasn't able to do before because of that error.)
Stacked PRs (oldest at bottom):
Rebenchmark configs to avoid noise
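For context, a rough hypothetical sketch of the general "rebenchmark to avoid noise" idea the title refers to, not this PR's actual implementation (benchmark_once is a stand-in for Helion's benchmarking helper): re-time only the leading candidates a few times and rank them by a noise-resistant statistic, so one lucky measurement cannot win outright.

import statistics
from typing import Callable, Sequence

def pick_best(configs: Sequence, benchmark_once: Callable, top_k: int = 5, repeats: int = 3):
    # First pass: a single (noisy) timing per candidate config.
    ranked = sorted(configs, key=benchmark_once)
    # Second pass: re-benchmark only the leaders and take the median timing.
    def stable_time(cfg):
        return statistics.median(benchmark_once(cfg) for _ in range(repeats))
    return min(ranked[:top_k], key=stable_time)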