-
Notifications
You must be signed in to change notification settings - Fork 25.4k
Open
Labels
module: custom-operators (custom operators, custom ops) · module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op) · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
The following code works fine if we comment out `torch.compile`.
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************
import torch
import triton
import triton.language as tl
@triton.jit
def _load_x(x_ptr, h, H, BLOCK_SIZE_H, indices_b, mask_b, other=None):
    """Load the ``h``-th column tile of a row-major ``(B, H)`` buffer.

    Rows are chosen by ``indices_b`` (guarded by ``mask_b``); columns are the
    ``BLOCK_SIZE_H`` consecutive positions starting at ``h * BLOCK_SIZE_H``,
    guarded so they never run past ``H``.

    Returns:
        ``(tile, offsets, tile_mask)``: the loaded values, the flat element
        offsets, and the 2-D validity mask. The offsets and mask can be reused
        by a matching ``tl.store``. Masked-out lanes hold ``other`` (undefined
        when ``other`` is None).
    """
    cols = h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    col_ok = cols < H
    # Flat offset of element (row, col), assuming contiguous rows of length H.
    offsets = indices_b[:, None] * H + cols[None, :]
    tile_mask = mask_b[:, None] & col_ok[None, :]
    tile = tl.load(x_ptr + offsets, mask=tile_mask, other=other)
    return tile, offsets, tile_mask
@triton.jit
def softmax_forward_triton_kernel(
    x_ptr,
    output_ptr,
    logits_multiplier,
    B,
    H,
    BLOCK_SIZE_B: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """Row-wise softmax of a contiguous ``(B, H)`` buffer, written to ``output_ptr``.

    Each program handles ``BLOCK_SIZE_B`` rows and streams over each row in
    ``BLOCK_SIZE_H``-wide tiles:

    * pass 1 keeps a running row max ``M`` and a running normaliser ``Z``,
      rescaling ``Z`` whenever ``M`` grows (online softmax);
    * pass 2 re-reads the row and stores ``exp(x - M) / Z``.

    Arithmetic is done in float32; ``tl.store`` casts back to the output
    buffer's element type. When ``logits_multiplier`` is not None, the logits
    are scaled by it before the softmax.
    """
    pid = tl.program_id(axis=0)
    indices_b = pid * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
    mask_b = indices_b < B

    Z = tl.zeros((BLOCK_SIZE_B, 1), dtype=tl.float32)
    M = tl.full((BLOCK_SIZE_B, 1), -float("inf"), dtype=tl.float32)
    num_blocks_h = tl.cdiv(H, BLOCK_SIZE_H)

    # Pass 1: running max and normaliser.
    for h in range(num_blocks_h):
        x, indices, mask_bh = _load_x(
            x_ptr=x_ptr, h=h, H=H, BLOCK_SIZE_H=BLOCK_SIZE_H, indices_b=indices_b, mask_b=mask_b, other=0.0
        )
        x = x.to(tl.float32)
        if logits_multiplier is not None:
            x *= logits_multiplier
        # Pad out-of-range lanes with -inf only *after* scaling. Padding first
        # would turn into +inf under a negative multiplier (or NaN under zero)
        # and corrupt the running max.
        x = tl.where(mask_bh, x, -float("inf"))

        prev_m = M
        m = tl.max(x, axis=1, keep_dims=True)
        M = tl.maximum(M, m)
        x = tl.exp(x - M)
        # Rescale the previous partial sum to the new max before accumulating.
        Z = Z * tl.exp(prev_m - M) + tl.sum(x, axis=1, keep_dims=True)

    # Pass 2: normalise and store (padded lanes are masked out on store).
    for h in range(num_blocks_h):
        x, indices, mask_bh = _load_x(
            x_ptr=x_ptr, h=h, H=H, BLOCK_SIZE_H=BLOCK_SIZE_H, indices_b=indices_b, mask_b=mask_b
        )
        x = x.to(tl.float32)
        if logits_multiplier is not None:
            x *= logits_multiplier
        x = tl.exp(x - M) / Z
        tl.store(output_ptr + indices, x, mask=mask_bh)
@triton.jit
def _load_output_output_grad(output_ptr, output_grad_ptr, h, H, BLOCK_SIZE_H, indices_b, mask_b):
    """Load matching ``(BLOCK_SIZE_B, BLOCK_SIZE_H)`` tiles of the softmax output and its gradient.

    Rows are selected by ``indices_b``/``mask_b``; columns are the ``h``-th
    tile of width ``BLOCK_SIZE_H``, guarded against running past ``H``.
    Elements are addressed as ``row * H + col`` (contiguous rows).

    Returns:
        ``(output, output_grad, indices, mask_bh)``. Out-of-range lanes are
        loaded as 0 so that callers reducing over the tile (e.g. ``tl.sum``)
        do not pick up undefined values.
    """
    indices_h = h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    mask_h = indices_h < H
    indices = indices_b[:, None] * H + indices_h[None, :]
    mask_bh = mask_b[:, None] & mask_h[None, :]
    # Without `other`, masked-out lanes are undefined per the tl.load contract.
    output = tl.load(output_ptr + indices, mask=mask_bh, other=0.0)
    output_grad = tl.load(output_grad_ptr + indices, mask=mask_bh, other=0.0)
    return output, output_grad, indices, mask_bh
@triton.jit
def softmax_backward_triton_kernel(
    output_ptr,
    output_grad_ptr,
    x_grad_ptr,
    logits_multiplier,
    B,
    H,
    BLOCK_SIZE_B: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
):
    """Backward of the row-wise softmax over contiguous ``(B, H)`` buffers.

    With ``y`` the saved softmax output and ``dy`` its gradient, writes
    ``dx = y * (dy - sum(dy * y, row))``, times ``logits_multiplier`` when it is
    not None (chain rule through the forward pre-scaling).

    Pass 1 accumulates the per-row dot product ``sum(dy * y)`` in float32.
    Pass 2 forms ``dx`` tile by tile; ``tl.store`` casts to ``x_grad``'s dtype.
    """
    pid = tl.program_id(axis=0)
    indices_b = pid * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)
    mask_b = indices_b < B

    accumulator = tl.zeros((BLOCK_SIZE_B, 1), dtype=tl.float32)
    num_blocks_h = tl.cdiv(H, BLOCK_SIZE_H)

    # Pass 1: row-wise sum(dy * y).
    for h in range(num_blocks_h):
        output, output_grad, indices, mask_bh = _load_output_output_grad(
            output_ptr=output_ptr,
            output_grad_ptr=output_grad_ptr,
            h=h,
            H=H,
            BLOCK_SIZE_H=BLOCK_SIZE_H,
            indices_b=indices_b,
            mask_b=mask_b,
        )
        # Upcast *before* multiplying so low-precision inputs (fp16/bf16) do not
        # lose accuracy in the product.
        acc = output_grad.to(tl.float32) * output.to(tl.float32)
        # Exclude out-of-range lanes from the row sum.
        acc = tl.where(mask_bh, acc, 0.0)
        accumulator += tl.sum(acc, axis=1, keep_dims=True)

    # Pass 2: dx = y * (dy - accumulator) [* logits_multiplier].
    for h in range(num_blocks_h):
        output, output_grad, indices, mask_bh = _load_output_output_grad(
            output_ptr=output_ptr,
            output_grad_ptr=output_grad_ptr,
            h=h,
            H=H,
            BLOCK_SIZE_H=BLOCK_SIZE_H,
            indices_b=indices_b,
            mask_b=mask_b,
        )
        x_grad = output.to(tl.float32) * (output_grad.to(tl.float32) - accumulator)
        if logits_multiplier is not None:
            x_grad *= logits_multiplier
        tl.store(x_grad_ptr + indices, x_grad, mask=mask_bh)
class _Softmax(torch.autograd.Function):
    """Autograd wrapper running the Triton softmax kernels on a 2-D ``(B, H)`` tensor.

    Softmax is taken along the last dimension: each of the ``B`` rows is
    normalised independently. ``logits_multiplier`` (float or None) scales the
    input before the softmax and is applied again in backward.

    NOTE(review): per the report this file comes from, wrapping ``_Softmax.apply``
    in ``torch.compile(fullgraph=True)`` fails in the min-cut partitioner
    ("Node getitem_2 was invalid, but is output") on torch 2.10.0.dev20250916.
    Registering the kernel launches as ``torch.library.triton_op`` custom ops is
    the documented route for compiling user Triton kernels and may avoid it.
    Not verified here.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, logits_multiplier: float | None) -> torch.Tensor:
        """Compute the row-wise softmax of ``x`` (pre-scaled by ``logits_multiplier`` if given).

        Args:
            ctx: autograd context; the output and ``logits_multiplier`` are stashed for backward.
            x: 2-D input of shape ``(B, H)``. A contiguous copy is made if needed.
            logits_multiplier: optional scalar applied to ``x`` before the softmax.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        # The kernels address element (b, h) as `b * H + h`, i.e. they assume a
        # dense row-major layout. Non-contiguous inputs (e.g. transposes) would
        # otherwise be read with the wrong strides.
        x = x.contiguous()
        output = torch.empty_like(x)
        B, H = x.size()

        # Nothing to compute for an empty tensor. H == 0 would also not give a
        # valid tile width.
        if x.numel() > 0:
            BLOCK_SIZE_B = 1
            BLOCK_SIZE_H = min(triton.next_power_of_2(H), 4096 if x.dtype == torch.float32 else 8192)
            softmax_forward_triton_kernel[(B + BLOCK_SIZE_B - 1) // BLOCK_SIZE_B,](
                x_ptr=x,
                output_ptr=output,
                logits_multiplier=logits_multiplier,
                B=B,
                H=H,
                BLOCK_SIZE_B=BLOCK_SIZE_B,
                BLOCK_SIZE_H=BLOCK_SIZE_H,
            )

        # Softmax backward only needs the output, not the input.
        ctx.save_for_backward(output)
        ctx.logits_multiplier = logits_multiplier
        return output

    @staticmethod
    def backward(ctx, output_grad: torch.Tensor) -> tuple[torch.Tensor, None]:
        """Return ``(grad wrt x, None)``; ``logits_multiplier`` gets no gradient."""
        (output,) = ctx.saved_tensors
        # Incoming gradients may be non-contiguous (for example an expanded
        # tensor), while the kernel assumes the same dense row-major layout as
        # `output`.
        output_grad = output_grad.contiguous()
        x_grad = torch.empty_like(output)
        B, H = x_grad.size()

        if x_grad.numel() > 0:
            BLOCK_SIZE_B = 1
            BLOCK_SIZE_H = min(triton.next_power_of_2(H), 4096 if output.dtype == torch.float32 else 8192)
            softmax_backward_triton_kernel[(B + BLOCK_SIZE_B - 1) // BLOCK_SIZE_B,](
                output_ptr=output,
                output_grad_ptr=output_grad,
                x_grad_ptr=x_grad,
                logits_multiplier=ctx.logits_multiplier,
                B=B,
                H=H,
                BLOCK_SIZE_B=BLOCK_SIZE_B,
                BLOCK_SIZE_H=BLOCK_SIZE_H,
            )

        return x_grad, None
@torch.compile(fullgraph=True)
def softmax(
    x: torch.Tensor,
    logits_multiplier: float | None = None
) -> torch.Tensor:
    """Compute softmax over the last dimension of a 2-D tensor with the Triton kernels.

    Args:
        x (torch.Tensor): input activation tensor of shape ``(B, H)``.
        logits_multiplier (float, optional): if given, ``x`` is multiplied by it
            before the softmax is computed. Defaults to None (no scaling).

    Returns:
        torch.Tensor: tensor of the same shape and dtype as ``x`` whose rows sum to 1.
    """
    return _Softmax.apply(x, logits_multiplier)
if __name__ == "__main__":
    # Minimal reproducer: forward pass of the compiled softmax on a small CUDA tensor.
    # Guarded so that importing this module does not allocate on the GPU or trigger compilation.
    x = torch.randn(5, 4, device=torch.cuda.current_device(), requires_grad=True)
    print(softmax(x))
error:
Traceback (most recent call last):
File "/proj/checkpoints/shawntan/mayank/flash-model-architectures/a.py", line 214, in <module>
print(softmax(x))
^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 900, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2213, in _call_user_compiler
raise BackendCompilerFailed(
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2188, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/__init__.py", line 2388, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2681, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 117, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1107, in aot_module_simplified
compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 240, in aot_stage2_compile
return aot_stage2_autograd(aot_state, aot_graph_capture)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 1393, in aot_stage2_autograd
fw_module, bw_module = aot_config.partition_fn(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2118, in partition_fn
return min_cut_rematerialization_partition(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_functorch/partitioners.py", line 2792, in min_cut_rematerialization_partition
node_info = classify_nodes(joint_module, static_lifetime_input_indices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_functorch/partitioners.py", line 2759, in classify_nodes
forward_only_graph = _extract_graph_with_inputs_outputs(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/u/shawntan/.conda/envs/ai-mayank/lib/python3.12/site-packages/torch/_functorch/partitioners.py", line 238, in _extract_graph_with_inputs_outputs
assert not isinstance(env[x], InvalidNodeBase), (
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: Node getitem_2 was invalid, but is output
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Versions
Collecting environment information...
PyTorch version: 2.10.0.dev20250916+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Red Hat Enterprise Linux 9.4 (Plow) (x86_64)
GCC version: (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.34
Python version: 3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.14.0-427.42.1.el9_4.x86_64-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 560.35.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8468
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
Stepping: 8
CPU(s) scaling MHz: 100%
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
L1d cache: 4.5 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 192 MiB (96 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==2.3.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pytorch-triton==3.5.0+git5ae38bdb
[pip3] torch==2.10.0.dev20250916+cu128
[pip3] triton==3.4.0
[conda] numpy 2.3.3 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] pytorch-triton 3.5.0+git5ae38bdb pypi_0 pypi
[conda] torch 2.10.0.dev20250916+cu128 pypi_0 pypi
[conda] triton 3.4.0 pypi_0 pypi
Metadata
Metadata
Assignees
Labels
module: custom-operators (custom operators, custom ops) · module: pt2-dispatcher (PT2 dispatcher-related issues, e.g., aotdispatch, functionalization, faketensor, custom-op) · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)