diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 1813588c21e53..463867cf2f85b 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -265,6 +265,7 @@ def run(*ex, **kwargs): ("cpu", "cuda"), is_skip=True ), "test_sdpa_dynamic_shapes": TestFailure(("cpu",), is_skip=True), + "test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True), # # The following tests do not support dynamic shapes yet: # diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 179abb6185057..223176c99052d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2048,15 +2048,20 @@ def is_aligned(x): return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0 assert isinstance(arg, TensorBox) - unaligned_input_shape = isinstance(arg.data, ir.ExpandView) and not is_aligned( - arg - ) - aligned_input_view = unaligned_input_shape and is_aligned(arg.unwrap_view()) - # input is padded, requiring_stride_order will unwrap the view and unpad. - # Would be nice to be able to require certain padding from inductor ir, nyi - if aligned_input_view: - return arg + # This correctly handles the forward case: + if isinstance(arg.data, (ir.SliceView, ir.ExpandView)): + if not is_aligned(arg): + # input is padded, requiring_stride_order will unwrap the view and unpad. + # Would be nice to be able to require certain padding from inductor ir, nyi + if is_aligned(arg.unwrap_view()): + return arg + # This is needed for the backward case: + # Hacky but not sure of a better way + if isinstance(arg.data, ir.StorageBox) and arg.data.is_input_buffer(): + if "expand" in arg.data.get_name(): + if (V.graph.sizevars.size_hint(arg.get_stride()[-2]) % ALIGNMENT) == 0: + return arg return ir.ExternKernel.require_stride_order(arg, stride_order)