Status: Closed
Labels: module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged
Description
🐛 Describe the bug
```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

flex_attention = torch.compile(flex_attention)

# Pre-defined causal mask with shape (B=4, H=1, Q_LEN=2048, KV_LEN=2048)
attn_mask = torch.ones((4, 1, 2048, 2048), dtype=torch.bool, device='cuda').tril()

def causal(b, h, q_idx, kv_idx):
    # attn_mask has a single head, so always index head 0
    h_ = h.new_zeros(h.shape)
    # print(b)  # uncomment this line to make the code work
    return attn_mask[b][h_][q_idx][kv_idx]

block_mask = create_block_mask(causal, B=4, H=None, Q_LEN=2048, KV_LEN=2048)
print(block_mask)

q = torch.randn(4, 1, 2048, 64, device='cuda')
k = torch.randn(4, 1, 2048, 64, device='cuda')
v = torch.randn(4, 1, 2048, 64, device='cuda')
print(flex_attention(q, k, v, block_mask=block_mask))
```

I want to create a block mask from a pre-defined mask tensor. However, if I return attn_mask[b][h_][q_idx][kv_idx] directly, it raises:
Traceback (most recent call last):
File "/home/drisspg/meta/scripts/flex/key.py", line 17, in <module>
print(flex_attention(q, k, v, block_mask=block_mask))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/eval_frame.py", line 465, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 1294, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 1089, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 526, in __call__
return _compile(
^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 929, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 671, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_utils_internal.py", line 87, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 704, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
transformations(instructions, code_options)
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 219, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/convert_frame.py", line 639, in transform
tracer.run()
File "/home/drisspg/meta/pytorch/torch/_dynamo/symbolic_convert.py", line 2766, in run
super().run()
File "/home/drisspg/meta/pytorch/torch/_dynamo/symbolic_convert.py", line 973, in run
while self.step():
^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/symbolic_convert.py", line 885, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/drisspg/meta/pytorch/torch/_dynamo/symbolic_convert.py", line 2957, in RETURN_VALUE
self._return(inst)
File "/home/drisspg/meta/pytorch/torch/_dynamo/symbolic_convert.py", line 2942, in _return
self.output.compile_subgraph(
File "/home/drisspg/meta/pytorch/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/drisspg/meta/pytorch/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/drisspg/meta/pytorch/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/__init__.py", line 2235, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1545, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/backends/common.py", line 72, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 1080, in aot_module_simplified
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 1065, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 524, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 761, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1371, in fw_compiler_base
return _fw_compiler_base(model, example_inputs, is_inference)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 1442, in _fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 460, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 646, in _compile_fx_inner
compiled_graph = FxGraphCache.load(
^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/codecache.py", line 1427, in load
compiled_graph = compile_fx_fn(
^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 555, in codegen_and_compile
compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/compile_fx.py", line 863, in fx_codegen_and_compile
compiled_fn = graph.compile_to_fn()
^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 1948, in compile_to_fn
return self.compile_to_module().call
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 1874, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 1880, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/graph.py", line 1815, in codegen
self.scheduler = Scheduler(self.operations)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/drisspg/meta/pytorch/torch/_inductor/scheduler.py", line 1741, in __init__
self._init(nodes)
File "/home/drisspg/meta/pytorch/torch/_inductor/scheduler.py", line 1795, in _init
self.compute_ancestors()
File "/home/drisspg/meta/pytorch/torch/_inductor/scheduler.py", line 2243, in compute_ancestors
dep_node_name = self.name_to_buf[dep.name].defining_op.get_name()
~~~~~~~~~~~~~~~~^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: 'b'
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
However, if I add print(b), the code works perfectly! I would like to ask:
- Is this the right way to use flex attention with a pre-defined mask?
- Why is its behaviour so strange? Is this related to torch.compile?
- How can I fix this?
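For context on the first question: as far as we understand the mask_mod contract, create_block_mask evaluates mask_mod once per (b, h, q_idx, kv_idx) coordinate (vectorized with torch.vmap rather than a Python loop) and expects a boolean "may attend" answer. The sketch below models only that contract in plain Python on a tiny nested-list stand-in for attn_mask; the materialize helper and the small sizes are hypothetical, not PyTorch APIs, and the sketch does not explain why the compiled path fails.

```python
# Hypothetical pure-Python model of mask_mod semantics (no torch needed).
B, H, Q_LEN, KV_LEN = 2, 1, 4, 4

# Stand-in for attn_mask with shape (B, 1, Q_LEN, KV_LEN): causal (kv <= q).
attn_mask = [
    [[[kv <= q for kv in range(KV_LEN)] for q in range(Q_LEN)]]
    for _ in range(B)
]

def causal(b, h, q_idx, kv_idx):
    # The stored mask has a single head, so pin the head index to 0,
    # mirroring h.new_zeros(h.shape) in the repro.
    return attn_mask[b][0][q_idx][kv_idx]

def materialize(mask_mod, B, H, Q_LEN, KV_LEN):
    # Evaluate mask_mod at every (b, h, q_idx, kv_idx) coordinate and
    # collect the results into a dense nested-list mask.
    return [
        [
            [[mask_mod(b, h, q, kv) for kv in range(KV_LEN)] for q in range(Q_LEN)]
            for h in range(H)
        ]
        for b in range(B)
    ]

dense = materialize(causal, B, H, Q_LEN, KV_LEN)
print(dense[0][0][0])  # first query row of batch 0: only kv 0 is allowed
```

If the mask_mod is a faithful lookup, the materialized mask is identical to the pre-defined one, which is a cheap eager-mode sanity check to run before compiling.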
Here is the complete script to reproduce this: https://gist.github.com/why-in-Shanghaitech/8b8205f98568c6741a2e38dfcdb9d362/e859567ddcc3b6dfc2aaa027640fdf8f2ee196ce
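To relate the dense mask to the BlockMask printed by the repro: as we understand it, create_block_mask summarizes the evaluated mask per tile so the kernel can skip work. Tiles with no allowed entries are skipped, tiles where every entry is allowed are stored as full blocks (no mask_mod call needed), and the rest are partial blocks on which mask_mod is still applied elementwise. The sketch below reproduces only that classification in plain Python; classify_blocks and the "full"/"partial"/"empty" labels are hypothetical names, and the 4×4 tile size is chosen for readability, not taken from the library.

```python
def classify_blocks(mask2d, block_size):
    """Label each block_size x block_size tile of a 2-D boolean mask as
    'full' (all True), 'partial' (some True) or 'empty' (no True)."""
    q_len, kv_len = len(mask2d), len(mask2d[0])
    labels = []
    for qb in range(0, q_len, block_size):
        row = []
        for kb in range(0, kv_len, block_size):
            # Gather every entry of the current tile (edge tiles may be smaller).
            tile = [
                mask2d[q][kv]
                for q in range(qb, min(qb + block_size, q_len))
                for kv in range(kb, min(kb + block_size, kv_len))
            ]
            if all(tile):
                row.append("full")
            elif any(tile):
                row.append("partial")
            else:
                row.append("empty")
        labels.append(row)
    return labels

# Causal 8x8 mask summarized with 4x4 tiles (illustrative sizes only).
causal_mask = [[kv <= q for kv in range(8)] for q in range(8)]
labels = classify_blocks(causal_mask, 4)
print(labels)
```

With a block size of 128, the same classification of the repro's 2048×2048 causal mask gives a 16×16 grid of tiles: 16 partial tiles on the diagonal, 120 full tiles below it and 120 empty tiles above it.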
Versions
PyTorch version: 2.5.0.dev20240829
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-84-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6226R CPU @ 2.90GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 2
Stepping: 7
CPU max MHz: 3900.0000
CPU min MHz: 1200.0000
BogoMIPS: 5800.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 32 MiB (32 instances)
L3 cache: 44 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-15,32-47
NUMA node1 CPU(s): 16-31,48-63
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.5.0.dev20240829
[pip3] triton==3.0.0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch 2.5.0.dev20240829 py3.12_cuda12.1_cudnn9.1.0_0 pytorch-nightly
[conda] pytorch-cuda 12.1 ha16c6d3_6 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchtriton 3.0.0+dedb7bdf33 py312 pytorch-nightly
cc @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @yf225 @Chillee @drisspg @yanboliang @BoyuanFeng