Skip to content

Conversation

pianpwk
Copy link
Contributor

@pianpwk pianpwk commented Aug 25, 2025

Summary:
Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/56218d85e2da09d9ede3809718ec989c2151632c

Rollback Plan:

Differential Revision: D80948073

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

Copy link

pytorch-bot bot commented Aug 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161414

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 08cf5ae with merge base 84b57c9 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D80948073

facebook-github-bot pushed a commit that referenced this pull request Sep 3, 2025
Summary:

Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.


Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/56218d85e2da09d9ede3809718ec989c2151632c

Rollback Plan:

Differential Revision: D80948073
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D80948073

@pianpwk pianpwk changed the title [reland] Dynamic shapes: unbacked-safe slicing (#157944) (#157944) [dynamic shapes] unbacked-safe slicing Sep 8, 2025
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 8, 2025

# realize to get strides/storage offset
if x.maybe_get_layout() is None:
x.realize()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eellison I'm not very familiar with inductor lowering, so I wasn't sure if this was the right approach; the goal was to call as_strided with codegened sizes/strides/offset, but later figured I had to call realize(), on inputs like ir.Pointwise.

maybe this case should be lowering in a different way?

elif guard_or_false(start_index >= end_index):
new_size = 0

# create unbacked if case unknown
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unknown -> undecided due to unbacked shapes.

def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
"""
Lowers a slice call, creating ExternKernels for the output size & storage offset symbols,
if the indices are unbacked and appropriate semantics aren't known.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: semantics are the same for known and unknown.

maybe:
creating ExternKernels to compute the output size & storage offset symbols dynamically if they cant be determined at compile time because indices are unbacked.

):
return x
except TypeError:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would a type error be thrown? statically_known_leq should be resilient to unbacked.

sym_storage = sym

if V.graph.current_node is None or not clamp or (sym_size is None and sym_storage is None):
return TensorBox(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// This path is taken if no unbacked symbols has been allocated meaning that we know how to determine things statically and no need for dynamic computations.

except TypeError:
pass

# try to avoid dynamic (unbacked) slice
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not seem the appropriate line for this comment?

pass

# try to avoid dynamic (unbacked) slice
def compute_slice_index(index, size, default=None):
Copy link
Contributor

@laithsakka laithsakka Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is only called once with default being 0?
sounds like we never call it with index None in this version?

Copy link
Contributor

@laithsakka laithsakka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approved but please address my comments before landing

pianpwk added a commit that referenced this pull request Sep 16, 2025
Summary:

Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.


Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/56218d85e2da09d9ede3809718ec989c2151632c

Rollback Plan:

Reviewed By: laithsakka

Differential Revision: D80948073
@facebook-github-bot
Copy link
Contributor

@pianpwk has exported this pull request. If you are a Meta employee, you can view the originating diff in D80948073.

facebook-github-bot pushed a commit that referenced this pull request Sep 16, 2025
Summary:

Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.


Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/56218d85e2da09d9ede3809718ec989c2151632c

Rollback Plan:

Reviewed By: laithsakka

Differential Revision: D80948073
@facebook-github-bot
Copy link
Contributor

@pianpwk has exported this pull request. If you are a Meta employee, you can view the originating diff in D80948073.

facebook-github-bot pushed a commit that referenced this pull request Sep 23, 2025
Summary:

Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.


Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/56218d85e2da09d9ede3809718ec989c2151632c

Reviewed By: laithsakka

Differential Revision: D80948073
@facebook-github-bot
Copy link
Contributor

@pianpwk has exported this pull request. If you are a Meta employee, you can view the originating diff in D80948073.

Summary:

Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.


Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/56218d85e2da09d9ede3809718ec989c2151632c

Reviewed By: laithsakka

Differential Revision: D80948073
@facebook-github-bot
Copy link
Contributor

@pianpwk has exported this pull request. If you are a Meta employee, you can view the originating diff in D80948073.

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants