[pallas backend] Implementing Strided/Scatter Access #167426
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167426
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 5b7b216 with merge base 04a85b4. UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This could have performance implications but may not be of concern right now. (Sorry, posted at the wrong place. Please ignore.)
yarongmu-google left a comment:
There seems to be more focus on strided access here, but less on scatter?
```diff
 for inp in input_params:
-    code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
+    code.writeline(
+        f"{inp}_jax = jax.dlpack.from_dlpack({inp}.contiguous())"
```
This may have performance implications but may not be of concern right now.
Yep, agreed. I'm prioritizing passing all the unit tests first (there are a lot!!) and then we can look at perf.
You can see me enabling unit tests one by one at the end of the test file.
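For context, a minimal standalone sketch (not the inductor codegen itself; the tensor names are illustrative) of why the `.contiguous()` call matters for perf: on a non-contiguous view it materializes a copy before the otherwise zero-copy DLPack hand-off.

```python
import torch
import jax.dlpack

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
view = x[:, ::2]                 # strided, non-contiguous view
assert not view.is_contiguous()

# .contiguous() allocates a fresh dense buffer and copies into it...
dense = view.contiguous()
# ...so the DLPack import below is zero-copy only w.r.t. that copy.
view_jax = jax.dlpack.from_dlpack(dense)
```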
torch/_inductor/codegen/pallas.py (Outdated)
```python
dtype_map = {
    torch.float32: "jnp.float32",
    torch.float64: "jnp.float64",
    torch.float16: "jnp.float16",
    torch.bfloat16: "jnp.bfloat16",
    torch.int32: "jnp.int32",
    torch.int64: "jnp.int64",
    torch.int16: "jnp.int16",
    torch.int8: "jnp.int8",
    torch.uint8: "jnp.uint8",
    torch.bool: "jnp.bool_",
}
jax_dtype = dtype_map.get(dtype, f"jnp.{dtype}")
```
Create a helper to map torch types to jnp.*.
I think you had a similar dict in the prior PR (in the output code); let's combine this into one place.
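A minimal sketch of what that shared helper could look like (the function name and module placement are hypothetical):

```python
import torch

# Single source of truth for the torch -> jnp dtype mapping used in codegen.
_TORCH_TO_JNP_DTYPE = {
    torch.float32: "jnp.float32",
    torch.float64: "jnp.float64",
    torch.float16: "jnp.float16",
    torch.bfloat16: "jnp.bfloat16",
    torch.int32: "jnp.int32",
    torch.int64: "jnp.int64",
    torch.int16: "jnp.int16",
    torch.int8: "jnp.int8",
    torch.uint8: "jnp.uint8",
    torch.bool: "jnp.bool_",
}


def torch_dtype_to_jnp(dtype: torch.dtype) -> str:
    """Return the jnp.* name to emit into generated Pallas source."""
    # Fallback strips the "torch." prefix, e.g. torch.complex64 -> "jnp.complex64".
    return _TORCH_TO_JNP_DTYPE.get(dtype, f"jnp.{str(dtype).replace('torch.', '')}")
```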
torch/_inductor/codegen/pallas.py (Outdated)
```python
# Get iteration variables from range_tree_nodes (these are the actual symbols used in indices)
iter_vars = (
    OrderedSet(self.range_tree_nodes.keys())
    if hasattr(self, "range_tree_nodes")
```
When will this be false?
```python
# Find which iteration variable(s) are used
used_vars = free_symbols & iter_vars

if len(used_vars) == 0:
    # No iteration variables, this is a constant index
    return str(index)
elif len(used_vars) == 1:
    # Single iteration variable - try to extract stride and offset
    var = next(iter(used_vars))

    # Expand and collect terms
    expanded = sympy.expand(index)

    # Try to extract coefficient (stride) and constant (offset)
    # index = stride*var + offset
    stride = expanded.coeff(var, 1)
    offset = expanded.coeff(var, 0)

    if stride is not None:
        stride_val = stride
        offset_val = offset if offset is not None else 0

        # Generate JAX slice notation
        if stride_val == 1 and offset_val == 0:
            # Contiguous access
            return "..."
        elif offset_val == 0:
            # Pure stride: ::stride
            return f"::{stride_val}"
        else:
            # Offset + stride: offset::stride
            return f"{offset_val}::{stride_val}"
elif len(used_vars) > 1:
    # Multi-dimensional indexing - need to generate proper index arrays
    # For patterns like 2*x0 + 30*x1, we need to reshape and use advanced indexing
    # For now, we'll use ellipsis which works for contiguous multi-dim access
    # and fall back to error for truly strided multi-dim cases

    # Check if all coefficients are 1 (contiguous multi-dim access)
    all_unit_stride = True
    for var in used_vars:
        coeff = index.coeff(var, 1)
        if coeff != 1:
            all_unit_stride = False
            break

    if all_unit_stride:
        # Contiguous multi-dimensional access
        return "..."
    else:
        # Strided multi-dimensional access - requires advanced indexing
        # For now, use ellipsis which may work for many cases
        # TODO: Implement proper multi-dimensional strided indexing
        return "..."
```
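To make the single-variable branch concrete, here is a standalone toy (a hypothetical helper, not the actual kernel code path) showing how the sympy coefficient extraction turns an affine index expression into slice notation:

```python
import sympy


def index_to_slice(index: sympy.Expr, var: sympy.Symbol) -> str:
    # index = stride*var + offset  ->  "offset::stride" slice notation
    expanded = sympy.expand(index)
    stride = expanded.coeff(var, 1)  # linear coefficient (the stride)
    offset = expanded.coeff(var, 0)  # constant term (the offset)
    if stride == 1 and offset == 0:
        return "..."                 # contiguous access
    if offset == 0:
        return f"::{stride}"         # pure stride
    return f"{offset}::{stride}"     # offset + stride


x0 = sympy.Symbol("x0")
print(index_to_slice(x0, x0))          # ...
print(index_to_slice(2 * x0, x0))      # ::2
print(index_to_slice(2 * x0 + 5, x0))  # 5::2
```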
I don't think you need to rewrite this logic from scratch. The block_ptr handling in the Triton backend is doing something very similar. Can we reuse that?
I shared a bunch of things via BlockPatternMatcher, but I'm not sure what more can be shared, since we are not using arange for broadcasting in Pallas.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #167493
Approved by: https://github.com/jansel
ghstack dependencies: #167426
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben