Skip to content

Commit

Permalink
fix backward, hopefully
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Nov 26, 2023
1 parent 6382992 commit 300081d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def run(*ex, **kwargs):
("cpu", "cuda"), is_skip=True
),
"test_sdpa_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
"test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
#
# The following tests do not support dynamic shapes yet:
#
Expand Down
21 changes: 13 additions & 8 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,15 +2048,20 @@ def is_aligned(x):
return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0

assert isinstance(arg, TensorBox)
unaligned_input_shape = isinstance(arg.data, ir.ExpandView) and not is_aligned(
arg
)
aligned_input_view = unaligned_input_shape and is_aligned(arg.unwrap_view())

# input is padded, requiring_stride_order will unwrap the view and unpad.
# Would be nice to be able to require certain padding from inductor ir, nyi
if aligned_input_view:
return arg
# This correctly handles the forward case:
if isinstance(arg.data, (ir.SliceView, ir.ExpandView)):
if not is_aligned(arg):
# input is padded, requiring_stride_order will unwrap the view and unpad.
# Would be nice to be able to require certain padding from inductor ir, nyi
if is_aligned(arg.unwrap_view()):
return arg
# This is needed for the backward case:
# Hacky but not sure of a better way
if isinstance(arg.data, ir.StorageBox) and arg.data.is_input_buffer():
if "expand" in arg.data.get_name():
if (V.graph.sizevars.size_hint(arg.get_stride()[-2]) % ALIGNMENT) == 0:
return arg

return ir.ExternKernel.require_stride_order(arg, stride_order)

Expand Down

0 comments on commit 300081d

Please sign in to comment.