[ExecuTorch][Llama][ez] Remove unwrap_tensor_subclass from 4w quantization path#18957
Conversation
Merged 4ca6270 into gh/SS-JIA/521/base
Stack from ghstack (oldest at bottom):
D100066455 changed the Llama export pipeline to quantize weights in the checkpoint dtype (typically bfloat16) before casting to the computation dtype (fp32). This introduced a regression for Vulkan 4w export: `dequantize_affine` ops produced bfloat16 outputs, which Vulkan doesn't support, causing the graph to be split into multiple partitions. When `sym_constrain_range_for_size` constraint nodes were partitioned into a different delegate than the `_local_scalar_dense` + `slice_copy` ops they constrain, ExportPass re-tracing (in ConvertToLinearPass, SpecPropPass, etc.) would crash with `GuardOnDataDependentSymNode: Could not guard on data-dependent expression u539 < 0`.

The root cause is `unwrap_tensor_subclass()`. This function decomposes `IntxUnpackedToInt8Tensor` into plain tensors via `torch.nn.utils.parametrize`, capturing the subclass's metadata — including its `dtype` attribute (which controls `dequantize_affine`'s output dtype) — as a frozen snapshot in `UnwrapTensorSubclass.rebuild_stack`. A subsequent `model.to(dtype=fp32)` casts the plain tensors but cannot update the frozen metadata, so `dequantize_affine` continues to output bfloat16.

`unwrap_tensor_subclass()` was originally needed as a workaround because `torch.export` did not support tensor subclasses natively. This is no longer the case — the 8da4w path already works without it (using the same `IntxUnpackedToInt8Tensor` subclass), and `torch.export` traces through the subclass correctly. Removing it makes 4w consistent with 8da4w and avoids the metadata-freezing issue entirely.

This change was authored with Claude.
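The metadata-freezing failure mode can be illustrated with a torch-free toy sketch. All names below (`QuantizedParam`, `unwrap`, `rebuild`) are hypothetical stand-ins, not the torchao/ExecuTorch API; the point is only that a rebuild closure captures `dtype` at unwrap time, so a later cast of the plain data never reaches it:

```python
# Toy illustration (no torch) of the frozen-snapshot problem described above.
# QuantizedParam stands in for a tensor subclass carrying a dtype attribute;
# unwrap() stands in for unwrap_tensor_subclass().

class QuantizedParam:
    """A quantized parameter whose dtype controls the dequantized output dtype."""
    def __init__(self, packed, dtype):
        self.packed = packed
        self.dtype = dtype

    def dequantize(self):
        # The dequantized output takes this object's current dtype.
        return (self.packed, self.dtype)

def unwrap(param):
    """Decompose into plain data plus a rebuild closure.

    The closure captures dtype as a frozen snapshot, analogous to the
    metadata stored in UnwrapTensorSubclass.rebuild_stack.
    """
    frozen_dtype = param.dtype  # snapshot taken here, at unwrap time
    def rebuild(plain_value):
        # Rebuilds with the dtype captured above, not the current one.
        return QuantizedParam(plain_value, frozen_dtype)
    return param.packed, rebuild

p = QuantizedParam(packed=[1, 2, 3], dtype="bfloat16")
plain, rebuild = unwrap(p)

# A later "model.to(dtype=fp32)" casts the plain data, but the rebuild
# closure still holds the bfloat16 snapshot:
rebuilt = rebuild(plain)
assert rebuilt.dequantize()[1] == "bfloat16"  # still bfloat16, not fp32
```

Exporting through the subclass directly (as the 8da4w path does) avoids the snapshot entirely, because `dtype` is read live from the subclass rather than from a closure created at unwrap time.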
Differential Revision: D101252037