Description
🐛 Describe the bug
python -m examples.models.llama.export_llama \
--disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
-c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
-p ~/.llama/checkpoints/Llama3.2-1B/params.json \
--metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
INFO:root:model.to torch.float32
INFO:root:Loading custom ops library: /opt/oss/executorch/0.4.0/lib/libcxx/aarch64-linux/python3.10/site-packages/executorch/extension/llm/custom_ops/libcustom_ops_aot_lib.so
INFO:root:Model after source transforms: Transformer(
(tok_embeddings): Embedding(128256, 2048)
(layers): ModuleList(
(0-15): 16 x TransformerBlock(
(attention): Attention(
(wq): Linear(in_features=2048, out_features=2048, bias=False)
(wk): Linear(in_features=2048, out_features=512, bias=False)
(wv): Linear(in_features=2048, out_features=512, bias=False)
(wo): Linear(in_features=2048, out_features=2048, bias=False)
(kv_cache): KVCache()
(SDPA): SDPACustom(
(kv_cache): KVCache()
)
)
(feed_forward): FeedForward(
(w1): Linear(in_features=2048, out_features=8192, bias=False)
(w2): Linear(in_features=8192, out_features=2048, bias=False)
(w3): Linear(in_features=2048, out_features=8192, bias=False)
)
(attention_norm): RMSNorm()
(ffn_norm): RMSNorm()
)
)
(norm): RMSNorm()
(output): Linear(in_features=2048, out_features=128256, bias=False)
)
INFO:root:Using pt2e [] to quantizing the model...
INFO:root:No quantizer provided, passing...
INFO:root:Lowering model using following partitioner(s):
INFO:root:--> VulkanPartitioner
INFO:root:Skipping node in Vulkan partitioning: %aten_mm_default_112 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%aten_mul_tensor_258, %aten_permute_copy_default_112), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_permute_copy_default_112 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.permute_copy.default](args = (%p_output_weight, [1, 0]), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_rsqrt_default_32 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.rsqrt.default](args = (%aten_add_tensor_96,), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_mean_dim_32 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten_mul_tensor_256, [-1], True), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_select_copy_int_16 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.select_copy.int](args = (%aten_add_tensor_95, 1, -1), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_view_copy_default_383 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_mm_default_111, [1, 1, 2048]), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_view_copy_default_382 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_mul_tensor_255, [1, 8192]), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_view_copy_default_381 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_mm_default_110, [1, 1, 8192]), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_view_copy_default_380 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_copy.default](args = (%aten_mul_tensor_253, [1, 2048]), kwargs = {})
INFO:root:Skipping node in Vulkan partitioning: %aten_view_copy_default_379 : [num_users=2] = call_function[target=executorch.exir.dialects.edge._ops.aten.view_cop
...
INFO:root:Skipping node in Vulkan partitioning: %aten_embedding_default : [num_users=3] = call_function[target=executorch.exir.dialects.edge._ops.aten.embedding.default](args = (%p_tok_embeddings_weight, %tokens), kwargs = {})
INFO:root:Found 212 Vulkan subgraphs to be partitioned.
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.add.Tensor
INFO:root: et_vk.prepack.default
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.mul.Tensor
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.add.Tensor
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.linear.default
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.mul.Tensor
INFO:root: aten.sigmoid.default
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.linear.default
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.mul.Tensor
INFO:root: et_vk.prepack.default
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.add.Tensor
INFO:root: et_vk.prepack.default
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.add.Tensor
INFO:root: aten.mul.Tensor
INFO:root:Operators included in this Vulkan partition:
INFO:root: aten.linear.default
Traceback (most recent call last):
File "/opt/miniforge3/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/miniforge3/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/scratch/code/meta/executorch/examples/models/llama/export_llama.py", line 30, in <module>
main() # pragma: no cover
File "/scratch/code/meta/executorch/examples/models/llama/export_llama.py", line 26, in main
export_llama(modelname, args)
File "/scratch/code/meta/executorch/examples/models/llama/export_llama_lib.py", line 478, in export_llama
builder = _export_llama(modelname, args)
File "/scratch/code/meta/executorch/examples/models/llama/export_llama_lib.py", line 687, in _export_llama
builder = builder_exported_to_edge.to_backend(partitioners)
File "/scratch/code/meta/executorch/extension/llm/export/builder.py", line 376, in to_backend
self.edge_manager = self.edge_manager.to_backend(partitioner)
File "/scratch/code/meta/executorch/exir/program/_program.py", line 1291, in to_backend
new_edge_programs[name] = to_backend(program, partitioner)
File "/opt/miniforge3/lib/python3.10/functools.py", line 889, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
File "/scratch/code/meta/executorch/exir/backend/backend_api.py", line 396, in _
tagged_graph_module = _partition_and_lower(
File "/scratch/code/meta/executorch/exir/backend/backend_api.py", line 319, in _partition_and_lower
partitioned_module = _partition_and_lower_one_graph_module(
File "/scratch/code/meta/executorch/exir/backend/backend_api.py", line 241, in _partition_and_lower_one_graph_module
) = create_exported_program_from_submodule(
File "/scratch/code/meta/executorch/exir/lowered_backend_module.py", line 700, in create_exported_program_from_submodule
ExportedProgram(
File "/opt/miniforge3/lib/python3.10/site-packages/torch/export/exported_program.py", line 700, in __init__
self.validate()
File "/opt/miniforge3/lib/python3.10/site-packages/torch/export/exported_program.py", line 1117, in validate
self._validate()
File "/opt/miniforge3/lib/python3.10/site-packages/torch/export/exported_program.py", line 1126, in _validate
v().check(self)
File "/opt/miniforge3/lib/python3.10/site-packages/torch/_export/verifier.py", line 157, in check
_verify_exported_program_signature(ep)
File "/opt/miniforge3/lib/python3.10/site-packages/torch/_export/verifier.py", line 439, in _verify_exported_program_signature
raise SpecViolationError(
torch._export.verifier.SpecViolationError: Mutation node getitem_45 is neither a buffer nor a user input. Buffers to mutate: {'getitem_46': 'layers.15.attention.SDPA.kv_cache.k_cache', 'getitem_47': 'layers.15.attention.SDPA.kv_cache.v_cache'}, User inputs to mutate: {}
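To make the error above easier to read, here is a minimal sketch of the kind of check that appears to be failing. It is a simplified paraphrase written for illustration, not the actual code in torch/_export/verifier.py. The names `SpecViolation` and `check_mutation_outputs` are hypothetical, and the output order used below is an assumption, since the log does not show the real output list of the partitioned submodule. The idea, as I understand it, is that the first outputs of the graph are treated as mutation results, and each one must be listed in the signature's buffers or user inputs to mutate.

```python
class SpecViolation(Exception):
    """Illustrative stand-in for torch._export.verifier.SpecViolationError."""


def check_mutation_outputs(output_names, buffers_to_mutate, user_inputs_to_mutate):
    """Simplified paraphrase: the leading outputs must all be declared mutations."""
    n_mutations = len(buffers_to_mutate) + len(user_inputs_to_mutate)
    for name in output_names[:n_mutations]:
        if name not in buffers_to_mutate and name not in user_inputs_to_mutate:
            raise SpecViolation(
                f"Mutation node {name} is neither a buffer nor a user input. "
                f"Buffers to mutate: {buffers_to_mutate}, "
                f"User inputs to mutate: {user_inputs_to_mutate}"
            )


# Signature entries copied from the error message in the traceback.
buffers = {
    "getitem_46": "layers.15.attention.SDPA.kv_cache.k_cache",
    "getitem_47": "layers.15.attention.SDPA.kv_cache.v_cache",
}

# ASSUMPTION: a hypothetical output order in which a non-mutation output
# (getitem_45) comes before the two KV-cache mutation outputs.
outputs = ["getitem_45", "getitem_46", "getitem_47"]

try:
    check_mutation_outputs(outputs, buffers, {})
    message = None
except SpecViolation as err:
    message = str(err)

print(message)
```

Under this assumption, the failure would come from a mismatch between the output order of the submodule and the mutation entries recorded in its signature, rather than from the Vulkan SDPA kernels themselves. That would be consistent with the passing tests listed below, but it is only a guess.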
Tests
- pytest works
kernels/prim_ops/test/prim_ops_test.py . [ 66%]
kernels/quantized/test/test_out_variants.py .......... [ 68%]
backends/xnnpack/test/passes/test_activation_fusion.py .......... [ 70%]
backends/xnnpack/test/passes/test_batch_norm_fusion.py ... [ 71%]
backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py ..... [ 72%]
backends/xnnpack/test/passes/test_convert_to_linear.py . [ 73%]
backends/xnnpack/test/passes/test_remove_get_item_pass.py .... [ 73%]
backends/xnnpack/test/passes/test_tag_implicit_q_dq_pass.py . [ 74%]
backends/xnnpack/test/serialization/test_serialization.py . [ 74%]
backends/xnnpack/test/serialization/test_xnnheader.py ...... [ 75%]
backends/vulkan/test/test_vulkan_delegate.py ...........................................s...................s.............. [ 93%]
backends/vulkan/test/test_vulkan_delegate_header.py ...... [ 95%]
extension/pybindings/test/test_pybindings.py . [ 95%]
runtime/test/test_runtime.py .... [ 96%]
test/end2end/test_end2end.py .ss..s..sssssss [100%]
=========================================================== 416 passed, 14 skipped in 71.11s (0:01:11) ============================================================
- vulkan_compute_api_test works:
[ OK ] VulkanComputeGraphOpsTest.test_to_copy (31 ms)
[----------] 7 tests from VulkanComputeGraphOpsTest (131 ms total)
[----------] 2 tests from VulkanInt4LinearTest
[ RUN ] VulkanInt4LinearTest.test_reference_impl
[ OK ] VulkanInt4LinearTest.test_reference_impl (32 ms)
[ RUN ] VulkanInt4LinearTest.test_vulkan_impl
[ OK ] VulkanInt4LinearTest.test_vulkan_impl (24 ms)
[----------] 2 tests from VulkanInt4LinearTest (57 ms total)
[----------] 4 tests from VulkanSDPATest
[ RUN ] VulkanSDPATest.test_sdpa_op_small_params
[ OK ] VulkanSDPATest.test_sdpa_op_small_params (7 ms)
[ RUN ] VulkanSDPATest.test_sdpa_op_small_params_dynamic
[ OK ] VulkanSDPATest.test_sdpa_op_small_params_dynamic (8 ms)
[ RUN ] VulkanSDPATest.test_sdpa_op_llama3_params_dynamic
[ OK ] VulkanSDPATest.test_sdpa_op_llama3_params_dynamic (154 ms)
[ RUN ] VulkanSDPATest.test_reference_impl
[ OK ] VulkanSDPATest.test_reference_impl (297 ms)
[----------] 4 tests from VulkanSDPATest (468 ms total)
[----------] Global test environment tear-down
[==========] 63 tests from 6 test suites ran. (1030 ms total)
[ PASSED ] 63 tests.
- If I remove the --use_sdpa_with_kv_cache flag, the model conversion succeeds and the Llama 3.2 demo works. @SS-JIA
ExecuTorch build info
-- Generating kernel bindings:
-- LIB_NAME: quantized_ops_pybind_lib
-- FUNCTIONS_YAML:
-- CUSTOM_OPS_YAML: /scratch/code/meta/executorch/kernels/quantized/quantized.yaml
-- Generating operator lib:
-- LIB_NAME: quantized_ops_pybind_lib
-- KERNEL_LIBS: quantized_pybind_kernels_lib
-- DEPS: portable_lib
-- Generating operator lib:
-- LIB_NAME: quantized_ops_lib
-- KERNEL_LIBS: quantized_kernels
-- DEPS: executorch
-- ******** Summary ********
-- CMAKE_BUILD_TYPE : RelWithDebInfo
-- CMAKE_CXX_STANDARD : 17
-- CMAKE_CXX_COMPILER_ID : Clang
-- CMAKE_TOOLCHAIN_FILE : ucc.cmake
-- BUCK2 : /usr/bin//buck2
-- PYTHON_EXECUTABLE : /opt/miniforge3/bin/python3
-- FLATC_EXECUTABLE : /opt/oss/flatbuffers/24.3.25/bin/libcxx/aarch64-linux/flatc
-- EXECUTORCH_ENABLE_LOGGING : ON
-- EXECUTORCH_ENABLE_PROGRAM_VERIFICATION : ON
-- EXECUTORCH_LOG_LEVEL : Info
-- EXECUTORCH_BUILD_ANDROID_JNI : OFF
-- EXECUTORCH_BUILD_ARM_BAREMETAL : OFF
-- EXECUTORCH_BUILD_COREML : OFF
-- EXECUTORCH_BUILD_KERNELS_CUSTOM : ON
-- EXECUTORCH_BUILD_EXECUTOR_RUNNER : ON
-- EXECUTORCH_BUILD_EXTENSION_DATA_LOADER : ON
-- EXECUTORCH_BUILD_EXTENSION_MODULE : ON
-- EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL : ON
-- EXECUTORCH_BUILD_EXTENSION_TENSOR : ON
-- EXECUTORCH_BUILD_EXTENSION_TRAINING : ON
-- EXECUTORCH_BUILD_FLATC : OFF
-- EXECUTORCH_BUILD_GFLAGS : ON
-- EXECUTORCH_BUILD_GTESTS : OFF
-- EXECUTORCH_BUILD_HOST_TARGETS : ON
-- EXECUTORCH_BUILD_MPS : OFF
-- EXECUTORCH_BUILD_PYBIND : ON
-- EXECUTORCH_BUILD_QNN : OFF
-- EXECUTORCH_BUILD_KERNELS_OPTIMIZED : ON
-- EXECUTORCH_BUILD_KERNELS_QUANTIZED : ON
-- EXECUTORCH_BUILD_DEVTOOLS : ON
-- EXECUTORCH_BUILD_SIZE_TEST : OFF
-- EXECUTORCH_BUILD_XNNPACK : ON
-- EXECUTORCH_BUILD_VULKAN : ON
-- EXECUTORCH_BUILD_PTHREADPOOL : ON
-- EXECUTORCH_BUILD_CPUINFO : ON
Versions
PyTorch version: 2.5.0+git8a0ce38
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
VULKAN used to build PyTorch: True
OS: Fedora Linux Asahi Remix 40 (KDE Plasma) (aarch64)
GCC version: (GCC) 14.2.1 20240912 (Red Hat 14.2.1-3)
Clang version: 18.1.7 (https://github.com/llvm/llvm-project 768118d1ad38bf13c545828f67bd6b474d61fc55)
CMake version: version 3.20.0
Libc version: glibc-2.39
Vulkan Driver Version: Mesa 24.3.0-devel (git-f05157f591)
Vulkan Instance Version: 1.3.290
Python version: 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 19:56:21) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-6.11.3-dimilar-4-1-edge-ARCH+-aarch64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU: Apple M1 Max
Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==23.9.16
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.25.0
[pip3] pytorch-sphinx-theme==0.0.19
[pip3] torchao==0.7.0+git6b529961
[pip3] torchvision==0.20.0a0+f851df1
[conda] numpy 1.25.0 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.19 pypi_0 pypi
[conda] torchao 0.7.0+git6b529961 pypi_0 pypi
[conda] torchfix 0.5.0 pypi_0 pypi
[conda] torchvision 0.20.0a0+f851df1 pypi_0 pypi