
Conversation

oulgen (Contributor) commented on Nov 2, 2025:

Stack from ghstack (oldest at bottom):

Very simple Pallas TorchInductor backend
Given:

import torch

def f(x, y):
    return x.sin() + y

torch._inductor.config.cuda_backend = "pallas"

x = torch.randn(4).cuda()
y = torch.randn(4).cuda()

compiled = torch.compile(f, backend="inductor", fullgraph=True)
torch.testing.assert_close(compiled(x, y), f(x, y))

it outputs:

import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from torch.utils import dlpack as torch_dlpack
def pallas_fused_add_sin_56b646d2_kernel(in_ptr0, in_ptr1, out_ptr0):
    tmp0 = in_ptr0[...]
    tmp1 = jnp.sin(tmp0)
    tmp2 = in_ptr1[...]
    tmp3 = tmp1 + tmp2
    out_ptr0[...] = tmp3
def pallas_fused_add_sin_56b646d2_main(in_ptr0, in_ptr1, out_ptr0, stream=None):
    # Convert Torch -> JAX for inputs
    in_ptr0_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr0))
    in_ptr1_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr1))
    # Prepare output spec from PyTorch tensor
    # Map PyTorch dtype to JAX dtype string
    _torch_dtype_to_jax = {
        torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,
        torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,
        torch.uint8: jnp.uint8, torch.bool: jnp.bool_,
    }
    out_spec = jax.ShapeDtypeStruct(out_ptr0.shape, _torch_dtype_to_jax[out_ptr0.dtype])
    compiled = pl.pallas_call(
        lambda *refs: pallas_fused_add_sin_56b646d2_kernel(*refs),
        out_shape=out_spec,
        grid=(1,),
    )
    res = compiled(in_ptr0_jax, in_ptr1_jax)
    # Copy result back into the provided torch output tensor
    res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))
    out_ptr0.copy_(res_t)
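
For illustration, a minimal sketch (not part of the generated output) of how the emitted entrypoint could be invoked by hand, given the signature above; the caller pre-allocates the output tensor and the result is copied into it:

x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = torch.empty_like(x)  # pre-allocated output buffer, filled by the copy_ at the end
pallas_fused_add_sin_56b646d2_main(x, y, out)
torch.testing.assert_close(out, x.sin() + y)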

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

pytorch-bot bot commented Nov 2, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166822

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fac62ad with merge base d980d8d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Nov 2, 2025
EikanWang (Collaborator) commented:

@oulgen, can I expect Pallas to deliver a more competitive performance advantage than Gluon?

oulgen (Contributor, Author) commented on Nov 3, 2025:

@oulgen, can I expect Pallas to deliver a more competitive performance advantage than Gluon?

No clue; probably not, though, considering Gluon is a much lower-level language that can express hardware semantics better.

oulgen added commits that referenced this pull request Nov 3, 2025
@oulgen oulgen requested a review from jansel November 3, 2025 20:22
oulgen added a commit that referenced this pull request Nov 3, 2025
@oulgen oulgen marked this pull request as ready for review November 3, 2025 20:33
@miladm miladm requested a review from zou3519 November 3, 2025 21:42
oulgen added a commit that referenced this pull request Nov 3, 2025
- Compute expression with Python operators (compatible with jax.numpy broadcasting)
- Store as full-array ref assignment: "out_ptrY[...] = <expr>"
- Generate Python code that defines a Pallas kernel and a host entrypoint.
- Use async_compile.cutedsl path to compile and load Python code (generic wrapper).

A reviewer (Contributor) commented:

cutedsl?

# Pallas refs must be unpacked with [...] to load the array
return self.cse.generate(
    self.compute,
    f"{buf}[...]",

A reviewer (Contributor) commented:

Add an assert based on index so this errors if the load order is not contiguous.

out = self.args.output(name)
self.store_buffer_names.add(name)
# Pallas refs must use [...] assignment to store back to the ref
self.stores.writeline(f"{out}[...] = {value}")

A reviewer (Contributor) commented:

Add an assert based on index so this errors if the load order is not contiguous. Use a shared indexing helper to compute the "..." subscript.
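
A hypothetical sketch of the kind of shared indexing helper both of these suggestions point at (names are illustrative, not the PR's actual implementation): validate that the index is the trivial contiguous one before emitting the "..." subscript, so non-contiguous loads/stores fail loudly instead of silently touching the whole buffer.

import sympy

def contiguous_subscript(index: sympy.Expr, flat_var: sympy.Symbol) -> str:
    # Hypothetical helper (illustrative only): the minimal Pallas backend can only
    # load/store an entire ref via `ref[...]`, so the only index it can honor is
    # the flat iteration variable itself. Anything strided or indirect should error.
    assert sympy.simplify(index - flat_var) == 0, (
        f"Pallas backend requires contiguous access, got index {index}"
    )
    return "..."

# contiguous_subscript(sympy.Symbol("x0"), sympy.Symbol("x0")) -> "..."
# contiguous_subscript(2 * sympy.Symbol("x0"), sympy.Symbol("x0")) raises AssertionError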

@classmethod
def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
    # Start minimal: no special features advertised
    return OrderedSet()

A reviewer (Contributor) commented:

When you do reductions, consider reducing to a single element here if that is something Pallas can do fast. Basically, should we break single-element-output reductions into multiple kernels?
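
For context, a standalone Pallas sketch (independent of this PR) of the single-element-output reduction pattern the question refers to; interpret mode is used so the sketch runs without a GPU/TPU:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def sum_kernel(x_ref, out_ref):
    # Reduce the whole input block down to a single element inside one kernel.
    out_ref[...] = jnp.sum(x_ref[...], keepdims=True)

x = jnp.arange(1024, dtype=jnp.float32)
total = pl.pallas_call(
    sum_kernel,
    out_shape=jax.ShapeDtypeStruct((1,), jnp.float32),
    interpret=True,  # interpreter mode so the example runs anywhere
)(x)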

if not has_pallas_package():
    return False

import torch

A reviewer (Contributor) commented:

You can import torch in global scope

cuda_backends = {
    "triton": CUDACombinedScheduling,
    "halide": HalideScheduling,
    "pallas": PallasScheduling,

A reviewer commented:

Why is Pallas registered as a CUDA backend? Asking from a technical perspective; for example, is this a placeholder, or does the concrete backend/HW difference not matter at this layer?

oulgen (Contributor, Author) replied on Nov 4, 2025:

We can add Pallas to other backends too (see Halide, which is registered for both CPU and GPU); I only added it to CUDA here because I was testing on CUDA for now. Once we have a TPU backend we can test on, we would register Pallas for the TPU device as well.
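
A purely hypothetical illustration of that point (not code from this PR): registering Pallas for another device type would mirror the cuda_backends entry quoted above, e.g. for a future TPU device backend:

# Hypothetical sketch only: once an Inductor TPU device backend exists and is
# testable, Pallas scheduling could be registered for it the same way it is
# registered for CUDA above.
tpu_backends = {
    "pallas": PallasScheduling,
}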

oulgen added commits that referenced this pull request Nov 4, 2025
oulgen (Contributor, Author) commented on Nov 4, 2025:

@pytorchbot merge

pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Nov 4, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

