Skip to content

Commit

Permalink
Update on "[reland][inductor] make thread order consistent with loop …
Browse files Browse the repository at this point in the history
…order"


This PR relands #106827 which get reverted because of causing compilation error for some ads model.

Yanbo provide a repro in one of the 14k model ( `pytest ./generated/test_KaiyangZhou_deep_person_reid.py -k test_044`). This is also the model I used to confirm the fix and come up with a unit test. In this model, we call `tritoin_heuristics.triton_config` with size_hints [2048, 2]. Previously this would result in a trition config with XBLOCK=2048 and YBLOCK=2 . But since we change the mapping between size_hints and XYZ dimension, we now generate a triton config with XBLOCK=2 and YBLOCK=2048.  This fails compilation since we set max YBLOCK to be 1024.

My fix is to make sure we never generate a triton config that exceeds the maximum block size.


[ghstack-poisoned]
  • Loading branch information
shunting314 committed Aug 24, 2023
1 parent c92edce commit f6d509f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions test/inductor/test_triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
import unittest

try:
import triton
import triton # noqa: F401
except ImportError:
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires triton")

from torch._dynamo.test_case import run_tests, TestCase
from torch._inductor import config
from torch._inductor.triton_heuristics import triton_config
from torch._inductor import config


class TestTritonHeuristics(TestCase):
def test_triton_config(self):
Expand All @@ -26,5 +27,6 @@ def test_triton_config(self):
continue
self.assertTrue(cfg.kwargs[key] <= config.triton.max_block[label])


if __name__ == "__main__":
run_tests()

0 comments on commit f6d509f

Please sign in to comment.