
Conversation

Collaborator

@Jiseong-oh Jiseong-oh commented Sep 22, 2025

Summary

  • Implemented quantization strategies for the ENN backend.
  • Added support for applying ENN's quantization strategies when lowering models.
  • Verified multiple quantized models successfully.

Test plan

python -m executorch.examples.samsung.scripts.${MODEL_NAME} -c e9955 -p A8W8

cc @SS-JIA @digantdesai @kimishpatel


pytorch-bot bot commented Sep 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/14464

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures, 27 Pending, 3 Unrelated Failures

As of commit 3d58d14 with merge base d39992f:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Sep 22, 2025
@Jiseong-oh Jiseong-oh added the partner: samsung and release notes: exynos labels Sep 22, 2025
)


def get_enn_pass_list(edge_program: ExportedProgram) -> List[PassType]:
Contributor

Why not apply these passes in enn_preprocess.py instead? With the current pattern, users will have to remember to call get_enn_pass_list() and pass the result to to_edge_transform_and_lower(), which creates a point of failure if they forget to call it.

It seems to me that at the very least AnnotateQparamsPass should be moved to enn_preprocess.py, since it is required to preserve the quantization parameters. If it is not applied, then the FoldQDQPass() applied at the start of enn_preprocess.py will erase all quantization parameter information.
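For context, here is a rough sketch of the calling pattern being discussed, assuming to_edge_transform_and_lower comes from executorch.exir; the tiny model and the commented-out partitioner argument are placeholders, and only get_enn_pass_list is taken from the code under review (its import path is omitted here).

import torch
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):  # stand-in for a real (quantized) model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
exported = torch.export.export(TinyModel().eval(), example_inputs)

# The pattern being questioned: the caller must remember to fetch the ENN
# pass list and forward it explicitly; if they forget, the quantization
# annotations are lost once FoldQDQPass runs inside enn_preprocess.py.
enn_passes = get_enn_pass_list(exported)  # from the code under review
edge = to_edge_transform_and_lower(
    exported,
    transform_passes=enn_passes,
    # partitioner=[EnnPartitioner(...)],  # backend partitioner, assumed name
)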

Contributor

We want some passes to run before to_backend, and we define to_edge_transform_and_lower_enn to make this easy to use. AnnotateQparamsPass can be moved to enn_preprocess.py, so we have moved it now. Thanks for the recommendation.


Thanks for pointing this out. I'll address this item in the next update.

# Quantize constant tensors that still carry float data at serialization time
# (see the discussion below).
if need_quantize and data is not None:
    if isinstance(data, np.ndarray):
        data = torch.tensor(data)
    data = quantize_tensor(
Contributor

I'm curious in what situations this is needed. Typically, after the pt2e quantization process, weight tensors should already be quantized and have int8 data type.

Contributor

It seems constant tensors are not quantized after convert_pt2e; they just have quant/dequant nodes inserted after the constant tensor nodes. So we quantize all the constant tensors here.

Contributor

Would you happen to have an example of this behaviour? My understanding is that the activation tensor will have quant/dequant, but the weight tensor will have only a quant node. If you inspect the tensor data backing the weight tensor, it should have int8 type.

For instance, here are some logs I collected a while back while debugging MobileNet V2 quantized with the XNNPACK quantizer:

# dq input node:
torch.float32: torch.Size([1, 384, 14, 14]) = quantized_decomposed::dequantize_per_tensor(torch.int8: torch.Size([1, 384, 14, 14]), 0.06534181535243988, 16, -128, 127, torch.int8,  ...)
# dq weight node:
torch.float32: torch.Size([96, 384, 1, 1]) = quantized_decomposed::dequantize_per_channel(torch.int8: torch.Size([96, 384, 1, 1]), torch.float32: torch.Size([96]), torch.int64: torch.Size([96]), 0, -127, 127, torch.int8,  ...)
# weight tensor:
tensor: torch.Size([96, 384, 1, 1]), torch.int8, -127, 127
# scales tensor:
tensor: torch.Size([96]), torch.float32, 0.004114292096346617, 0.009169017896056175
# zeros tensor:
tensor: torch.Size([96]), torch.int64, 0, 0
# conv node:
torch.float32: torch.Size([1, 96, 14, 14]) = aten::conv2d(torch.float32: torch.Size([1, 384, 14, 14]), torch.float32: torch.Size([96, 384, 1, 1]), torch.float32: torch.Size([96]),  ...)
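For anyone reproducing this check, here is a small sketch of the inspection described above, assuming converted is the GraphModule returned by convert_pt2e (plain torch.fx traversal, nothing backend-specific):

import torch


def dump_constant_dtypes(converted: torch.fx.GraphModule) -> None:
    # Per the behaviour described above, conv/linear weights should already
    # appear as int8 constants after convert_pt2e, while activations only get
    # quant/dequant node pairs inserted around them.
    for node in converted.graph.nodes:
        if node.op != "get_attr":
            continue
        obj = converted
        for atom in node.target.split("."):  # get_attr targets may be dotted
            obj = getattr(obj, atom)
        if isinstance(obj, torch.Tensor):
            print(f"{node.target}: dtype={obj.dtype}, shape={tuple(obj.shape)}")

# For a quantized MV2 conv weight this prints something like
# "<weight attr>: dtype=torch.int8, shape=(96, 384, 1, 1)", matching the logs above.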

@Sangsooko Sangsooko Oct 2, 2025

As you can see in the following code, which is part of MobileBERT, the div has a quantized input and a constant input.
The code above changes the constant 5.656854249492381 into a quantized value.
Because this change is only needed for the internal operation on real hardware, the code can be moved into our backend code.

    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%dequantize_per_tensor_default_424, 5.656854249492381), kwargs = {})
    %quantize_per_tensor_default_425 :   [num_users=1] = call_function[target=torch.ops.quantized_decomposed.quantize_per_tensor.default](args = (%div, 0.15715381503105164, -78, -128, 127, torch.int8), kwargs = {})
    %dequantize_per_tensor_default_425 : [num_users=1] = call_function[target=torch.ops.quantized_decomposed.dequantize_per_tensor.default](args = (%quantize_per_tensor_default_425, 0.15715381503105164, -78, -128, 127, torch.int8), kwargs = {out_dtype: torch.float16})

If you think the additional quantization process after convert_pt2e is not appropriate according to the ExecuTorch code policy, we will move this code to our backend in the next update.
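To make the step concrete, here is a minimal sketch of what quantizing such a scalar constant amounts to; quantize_tensor in this PR is the real implementation, and the scale/zero-point below are simply the ones visible in the snippet, reused for illustration (which parameters the backend actually assigns to the constant is decided by the annotation pass):

import torch


def quantize_constant(value, scale, zero_point, qmin=-128, qmax=127, dtype=torch.int8):
    # Affine-quantize a float constant, e.g. the 5.656854... divisor above.
    data = torch.as_tensor(value, dtype=torch.float32)
    q = torch.clamp(torch.round(data / scale) + zero_point, qmin, qmax)
    return q.to(dtype)


print(quantize_constant(5.656854249492381, 0.15715381503105164, -78))
# tensor(-42, dtype=torch.int8)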

Contributor

Understood. I figured it might have been for binary ops that operate with constant args. In that case LGTM!

@SS-JIA
Contributor

SS-JIA commented Sep 23, 2025

Also, please run lintrunner to fix the lint errors.

# install lintrunner dependencies
pip install lintrunner
pip install lintrunner_adapters

cd executorch
lintrunner -a --verbose

@SS-JIA
Contributor

SS-JIA commented Sep 30, 2025

Overall LGTM, but before I stamp:

  1. Could you guys rebase to latest main and re-submit the PR? I want to see if the failures get resolved.
  2. There's a remaining comment regarding manually quantizing constant tensors when serializing the model. I don't think this should be required, since the quantization workflow should already be quantizing the constant tensors for you.

@Jiseong-oh Jiseong-oh force-pushed the exynos-quantize-support branch 2 times, most recently from 1eee0b5 to 1d1f37c on October 2, 2025 09:16
Contributor

@SS-JIA SS-JIA left a comment

LGTM! Thanks for addressing the comments.

Would you mind rebasing one more time before merging? Just to ensure that the merge base is fairly up to date. Thanks!

Contributor

@SS-JIA SS-JIA left a comment

Sorry, one more thing. In addition to rebasing, can you include the fix for getting the AI_LITECORE_API_KEY that's implemented in #14866?

The fix was to move up the point where secrets-env is declared.

@SS-JIA
Contributor

SS-JIA commented Oct 7, 2025

I decided to just merge #14866, so please just rebase past it and this PR LGTM!

@Jiseong-oh Jiseong-oh force-pushed the exynos-quantize-support branch 2 times, most recently from 44353d3 to b8215f4 on October 9, 2025 06:50
Jiseong-oh and others added 10 commits October 10, 2025 08:43
1. Add quantization strategies for the enn-backend
2. Add support for ENN's quantization strategies
3. Provide example code for MV2

Co-authored-by:  chen.zhao <chen03.zhao@samsung.com>
Co-authored-by:  sangsoo.Ko <sangsoo.ko@samsung.com>
Models contain: dlv3/edsr/iv3/iv4/mv3/resnet50/vit/w2l

Co-authored-by:  chen.zhao <chen03.zhao@samsung.com>
Co-authored-by:  sangsoo.Ko <sangsoo.ko@samsung.com>
The current models are supported in each script; execute these scripts to verify the validity of the quantization.

Co-authored-by: chong-chen <chong.chen@samsung.com>
- This model needs an updated version of LiteCore.
- This SDK can support the mv3 quant model.
Fix comments

Co-authored-by: chen03.zhao <chen03.zhao@samsung.com>
As the title shows

Co-authored-by: chong-chen <chong.chen@samsung.com>
- This model needs an updated version of LiteCore.

Signed-off-by: jiseong.oh <jiseong.oh@samsung.com>
- This SDK can support the mv3 quant model.

Signed-off-by: jiseong.oh <jiseong.oh@samsung.com>
For ic4, the image shape should be (299, 299); it doesn't need CenterCrop.

Co-authored-by: xz-linghu <xz.linghu@samsung.com>
@Jiseong-oh Jiseong-oh force-pushed the exynos-quantize-support branch from d21a2ae to 44a5e9e on October 9, 2025 23:43
@SS-JIA
Contributor

SS-JIA commented Oct 10, 2025

Validated that the Samsung test is passing via #14977. I believe this PR cannot access the repository API secret because the merge branch is from a fork.

@SS-JIA SS-JIA merged commit 8b67236 into pytorch:main Oct 10, 2025
260 of 269 checks passed
@Jiseong-oh Jiseong-oh deleted the exynos-quantize-support branch October 10, 2025 04:26