Skip to content

Commit ba19c52

Browse files
authored
Fix multi output layout error in indexing dtype calculation (#108085) (#108693)
Differential Revision: [D48757829](https://our.internmc.facebook.com/intern/diff/D48757829) Pull Request resolved: #108085 Approved by: https://github.com/yanboliang, https://github.com/davidberard98, https://github.com/jansel, https://github.com/peterbell10
1 parent c5c9536 commit ba19c52

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

test/inductor/test_cuda_repro.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from torch.fx.experimental.proxy_tensor import make_fx
1515
from torch.testing._internal.common_utils import (
1616
DeterministicGuard,
17+
freeze_rng_state,
1718
IS_FBCODE,
1819
TEST_WITH_ASAN,
1920
)
@@ -1008,6 +1009,19 @@ def test_linear_with_zero_infeature_size(self):
10081009
actual = opt_fn(x)
10091010
self.assertEqual(expect, actual)
10101011

1012+
@config.patch(fallback_random=True)
1013+
def test_multi_output_layout_fallback(self):
1014+
mod = nn.RReLU(lower=3.2350976, upper=8.4220314, inplace=True)
1015+
inp = torch.rand([4, 4]).cuda()
1016+
m = torch.compile(mod)
1017+
1018+
with freeze_rng_state():
1019+
o1 = m(inp.clone())
1020+
1021+
o2 = mod(inp.clone())
1022+
1023+
self.assertEqual(o1, o2)
1024+
10111025

10121026
if __name__ == "__main__":
10131027
from torch._dynamo.test_case import run_tests

torch/_inductor/codegen/triton.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2298,7 +2298,14 @@ def within_32bit(e):
22982298
if not within_32bit(numel):
22992299
return False
23002300

2301-
buf_sizes = [buf.get_layout().storage_size() for buf in buffers]
2301+
# Any use of a MultiOutputLayout will create a buffer with a
2302+
# Layout whose sizes are accounted for
2303+
buf_sizes = [
2304+
buf.get_layout().storage_size()
2305+
for buf in buffers
2306+
if not isinstance(buf.get_layout(), ir.MultiOutputLayout)
2307+
]
2308+
23022309
if not all(within_32bit(size) for size in buf_sizes):
23032310
return False
23042311

0 commit comments

Comments
 (0)