[ghstack-poisoned]
oulgen added a commit that referenced this pull request Nov 9, 2025

pytorch-bot bot commented Nov 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167426

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5b7b216 with merge base 04a85b4:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Nov 9, 2025
@oulgen oulgen requested a review from jansel November 9, 2025 18:30
@oulgen oulgen added the ciflow/trunk and topic: not user facing labels Nov 9, 2025
@oulgen oulgen marked this pull request as ready for review November 9, 2025 18:34
oulgen added a commit that referenced this pull request Nov 9, 2025
Collaborator

@yarongmu-google yarongmu-google left a comment

This could have performance implications but may not be of concern right now. (Sorry, posted in the wrong place. Please ignore.)

Collaborator

@yarongmu-google yarongmu-google left a comment

There seems to be more focus on strided access and less on scatter?

for inp in input_params:
-    code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
+    code.writeline(
+        f"{inp}_jax = jax.dlpack.from_dlpack({inp}.contiguous())"
Collaborator

This may have performance implications, but may not be of concern right now.

Contributor Author

Yep agreed, I’m prioritizing passing all the unit tests first (there’s a lot!!) and then we can look at perf

you can see me enabling unit tests one by one at the end of the test file

oulgen added a commit that referenced this pull request Nov 9, 2025
Comment on lines 198 to 210
dtype_map = {
    torch.float32: "jnp.float32",
    torch.float64: "jnp.float64",
    torch.float16: "jnp.float16",
    torch.bfloat16: "jnp.bfloat16",
    torch.int32: "jnp.int32",
    torch.int64: "jnp.int64",
    torch.int16: "jnp.int16",
    torch.int8: "jnp.int8",
    torch.uint8: "jnp.uint8",
    torch.bool: "jnp.bool_",
}
jax_dtype = dtype_map.get(dtype, f"jnp.{dtype}")
Contributor

Create a helper to map torch types to jnp.*.

I think you had a similar dict in the prior PR (in the output code), let's combine this into one place.
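
One possible shape for such a shared helper is sketched below. The name `torch_dtype_to_jax` and the choice to raise on unmapped dtypes are assumptions for illustration, not what the PR implements. The sketch keys on the dtype's string form (`str(torch.float32)` is `"torch.float32"`) so that it runs without importing torch; a real helper would more likely key on the `torch.dtype` objects, as the quoted dict does. The quoted fallback `f"jnp.{dtype}"` would appear to render an unmapped dtype as `jnp.torch.<name>`, which is why this sketch raises instead.

```python
# Hypothetical shared helper; the name and error behavior are assumptions.
# Keys are the string forms of torch dtypes so the sketch needs no imports.
_TORCH_TO_JNP = {
    "torch.float32": "jnp.float32",
    "torch.float64": "jnp.float64",
    "torch.float16": "jnp.float16",
    "torch.bfloat16": "jnp.bfloat16",
    "torch.int32": "jnp.int32",
    "torch.int64": "jnp.int64",
    "torch.int16": "jnp.int16",
    "torch.int8": "jnp.int8",
    "torch.uint8": "jnp.uint8",
    "torch.bool": "jnp.bool_",
}


def torch_dtype_to_jax(dtype) -> str:
    """Return the jnp dtype expression (as source text) for a torch dtype.

    Accepts a torch.dtype or its string form, e.g. "torch.float32".
    """
    key = str(dtype)
    try:
        return _TORCH_TO_JNP[key]
    except KeyError:
        # Fail loudly rather than emit a name like "jnp.torch.complex64".
        raise NotImplementedError(f"no JAX dtype mapping for {key}") from None
```

Both call sites (the kernel codegen and the output code) could then call the one helper instead of carrying separate dicts.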

# Get iteration variables from range_tree_nodes (these are the actual symbols used in indices)
iter_vars = (
    OrderedSet(self.range_tree_nodes.keys())
    if hasattr(self, "range_tree_nodes")
Contributor

When will this be false?

Comment on lines 289 to 342
# Find which iteration variable(s) are used
used_vars = free_symbols & iter_vars

if len(used_vars) == 0:
    # No iteration variables, this is a constant index
    return str(index)
elif len(used_vars) == 1:
    # Single iteration variable - try to extract stride and offset
    var = next(iter(used_vars))

    # Expand and collect terms
    expanded = sympy.expand(index)

    # Try to extract coefficient (stride) and constant (offset)
    # index = stride*var + offset
    stride = expanded.coeff(var, 1)
    offset = expanded.coeff(var, 0)

    if stride is not None:
        stride_val = stride
        offset_val = offset if offset is not None else 0

        # Generate JAX slice notation
        if stride_val == 1 and offset_val == 0:
            # Contiguous access
            return "..."
        elif offset_val == 0:
            # Pure stride: ::stride
            return f"::{stride_val}"
        else:
            # Offset + stride: offset::stride
            return f"{offset_val}::{stride_val}"
elif len(used_vars) > 1:
    # Multi-dimensional indexing - need to generate proper index arrays
    # For patterns like 2*x0 + 30*x1, we need to reshape and use advanced indexing
    # For now, we'll use ellipsis which works for contiguous multi-dim access
    # and fall back to error for truly strided multi-dim cases

    # Check if all coefficients are 1 (contiguous multi-dim access)
    all_unit_stride = True
    for var in used_vars:
        coeff = index.coeff(var, 1)
        if coeff != 1:
            all_unit_stride = False
            break

    if all_unit_stride:
        # Contiguous multi-dimensional access
        return "..."
    else:
        # Strided multi-dimensional access - requires advanced indexing
        # For now, use ellipsis which may work for many cases
        # TODO: Implement proper multi-dimensional strided indexing
        return "..."
Contributor

I don't think you need to rewrite this logic from scratch. The block_ptr handling in the Triton backend is doing something very similar. Can we reuse that?

Contributor Author

I shared a bunch of things via BlockPatternMatcher, but I'm not sure what more can be shared, since we are not using arange for broadcasting in Pallas.
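
To make the single-variable branch in the quoted code concrete: an index of the form `stride*var + offset` over a flat buffer visits the same elements as the slice `offset::stride`. The sketch below checks that with plain Python lists (no JAX or Pallas involved), and also works through the `2*x0 + 30*x1` pattern named in the code comments, where the strided multi-dimensional case does not read the same elements as a whole-buffer `...` access. The helper name `affine_index_to_slice` and the buffer sizes are made up for illustration; the quoted code emits slice strings such as `"2::3"` rather than `slice` objects.

```python
def affine_index_to_slice(stride: int, offset: int) -> slice:
    """Slice that visits the same elements as index = stride*var + offset."""
    return slice(offset, None, stride)


flat = list(range(60))  # stand-in for a flat buffer

# 1-D case: index = 3*x0 + 2 matches flat[2::3].
sliced_1d = flat[affine_index_to_slice(stride=3, offset=2)]
gathered_1d = [flat[3 * x0 + 2] for x0 in range(len(sliced_1d))]
assert gathered_1d == sliced_1d

# 2-D case: index = 2*x0 + 30*x1 with x0 in range(15), x1 in range(2).
# Viewing the buffer as rows of 30 elements, this is every 2nd element of each row.
rows = [flat[r * 30:(r + 1) * 30] for r in range(2)]
sliced_2d = [row[::2] for row in rows]
gathered_2d = [[flat[2 * x0 + 30 * x1] for x0 in range(15)] for x1 in range(2)]
assert gathered_2d == sliced_2d

# Reading the whole buffer (what "..." would do) yields different elements.
flattened = [value for row in gathered_2d for value in row]
assert flattened != flat
```

The last assertion is the reason the strided multi-dimensional branch carries a TODO: returning `...` there only matches the intended access when every coefficient is 1.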

oulgen added a commit that referenced this pull request Nov 10, 2025
@oulgen oulgen requested a review from jansel November 10, 2025 20:46
oulgen added a commit that referenced this pull request Nov 11, 2025
Contributor Author

oulgen commented Nov 11, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Nov 11, 2025