[ET-VK] Fix pack_fp_linear_weight for devices without VK_KHR_16bit_storage #18642
meta-codesync[bot] merged 1 commit into gh/SS-JIA/515/base from gh/SS-JIA/515/head
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18642
Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending, 2 Unrelated Failures as of commit 44b193b with merge base ad235f8.

NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
meta-codesync[bot] merged commit c979227 into gh/SS-JIA/515/base
[ET-VK] Fix pack_fp_linear_weight for devices without VK_KHR_16bit_storage (#18653)

This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #18642 by @SS-JIA (use this as the source of truth for the PR details, comments, and reviews)
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/515/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/515/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/515/orig

Differential Revision: [D99133993](https://our.internmc.facebook.com/intern/diff/D99133993/)

@diff-train-skip-merge

Co-authored-by: ssjia <ssjia@devvm26340.ftw0.facebook.com>
Stack from ghstack (oldest at bottom):
The `pack_fp_linear_weight` prepack shader crashes on devices that lack `VK_KHR_16bit_storage` support because the half-precision variant reads from a `float16_t[]` staging buffer, which requires that extension.

This applies the same two-dtype pattern used by `nchw_to_image` and `conv2d_dw_prepack_weights`: a new `BUF_DTYPE` shader parameter allows the staging buffer to use float32 (the `[half, float]` combo) while the packed output remains half-precision. The runtime selects the correct variant via `get_staging_dtype_for()`, which returns `kFloat` when the device lacks fp16 buffer support.
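As a rough illustration of that selection logic, here is a minimal C++ sketch, not the actual implementation: the `get_staging_dtype_for()` name comes from this PR, but the `ScalarType` enum and the `supports_fp16_buffers` capability flag are hypothetical stand-ins for however the backend actually queries `VK_KHR_16bit_storage`.

```cpp
#include <cassert>

// Hypothetical dtype enum; the real backend defines its own scalar types.
enum class ScalarType { kFloat, kHalf };

// Hypothetical capability flag standing in for the real adapter query for
// VK_KHR_16bit_storage support.
struct DeviceCaps {
  bool supports_fp16_buffers;
};

// Mirrors the behavior described above: half-precision staging buffers fall
// back to float32 when the device cannot read float16_t[] storage buffers.
ScalarType get_staging_dtype_for(ScalarType tensor_dtype, const DeviceCaps& caps) {
  if (tensor_dtype == ScalarType::kHalf && !caps.supports_fp16_buffers) {
    return ScalarType::kFloat;
  }
  return tensor_dtype;
}

int main() {
  DeviceCaps no_fp16{/*supports_fp16_buffers=*/false};
  // Without 16-bit storage, a half tensor stages through a float buffer.
  assert(get_staging_dtype_for(ScalarType::kHalf, no_fp16) == ScalarType::kFloat);
  return 0;
}
```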
All three call sites that construct the `pack_fp_linear_weight` shader name (Linear.cpp, Conv1dPW.cpp, Conv2dPW.cpp) are updated to append the staging dtype suffix.
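For illustration, a sketch of what such a call site might look like after the change; the `add_dtype_suffix` helper and the exact suffix format are assumptions modeled on the shader-name convention described above, not copied from the actual sources.

```cpp
#include <iostream>
#include <string>

enum class ScalarType { kFloat, kHalf };

// Assumed suffix helper: appends a dtype tag to a shader variant name.
void add_dtype_suffix(std::string& kernel_name, ScalarType dtype) {
  kernel_name += (dtype == ScalarType::kHalf) ? "_half" : "_float";
}

// Sketch of a call site (e.g. in Linear.cpp): the packed-output dtype suffix
// is now followed by the staging-buffer dtype suffix chosen at runtime.
std::string make_pack_fp_linear_weight_name(
    ScalarType weight_dtype, ScalarType staging_dtype) {
  std::string kernel_name = "pack_fp_linear_weight";
  add_dtype_suffix(kernel_name, weight_dtype);   // packed output dtype
  add_dtype_suffix(kernel_name, staging_dtype);  // staging buffer (BUF_DTYPE)
  return kernel_name;
}

int main() {
  // On a device without VK_KHR_16bit_storage, the staging dtype resolves to
  // kFloat, selecting the [half, float] shader variant.
  std::cout << make_pack_fp_linear_weight_name(ScalarType::kHalf, ScalarType::kFloat)
            << "\n";  // prints: pack_fp_linear_weight_half_float
  return 0;
}
```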
Authored with Claude.
Differential Revision: D99133993