Add FuseQATConvBN to fuse_ops (#19442)
CI status as of commit 84d7498 (merge base a49171d): 1 new failure, 1 unrelated failure (likely trunk flakiness).
@ethansfng has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104497938.
Summary:
Adds a `FuseQATConvBN` pass that folds the QAT Conv-BN simulation chain (`conv → q → dq → div(scale) → add(orig_bias) → batch_norm`) inserted by `prepare_qat_pt2e` into the conv's quantized bias and removes the chain.
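As a rough illustration (not the actual pass implementation, which walks the exported FX graph and checks node users and args), the chain match can be sketched as a linear scan over op names:

```python
# Hypothetical, simplified sketch: the graph is reduced to a flat list of op
# names. The real pass matches the chain on an exported FX graph, not a list.
QAT_CHAIN = ["conv", "quantize", "dequantize", "div", "add", "batch_norm"]

def match_qat_conv_bn(op_names):
    """Return start indices where the full simulation chain occurs back to back."""
    n = len(QAT_CHAIN)
    return [
        i
        for i in range(len(op_names) - n + 1)
        if op_names[i : i + n] == QAT_CHAIN
    ]
```

A partial chain (e.g. a conv whose quantize/dequantize pair is not followed by the div/add/batch_norm tail) is left untouched, which mirrors the pass only folding fully matched chains.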
The pass runs in two steps inside a single `call()`:
1. Bias prep: for each conv, create a zero-filled quantized bias if one is missing, or quantize an existing float bias as per-tensor int32. This ensures step 2 has a quantized bias slot to write the BN correction into.
2. Fold: for each matched chain, compute the BN correction
`C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias`
and absorb it into the conv's quantized bias in place, then erase the chain and the batch_norm.
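Numerically, the fold amounts to evaluating `C` per output channel, quantizing it with the bias scale, and adding it to the int32 bias. A minimal NumPy sketch (function names are mine, not the pass's API):

```python
import numpy as np

def bn_correction(orig_bias, running_mean, running_var, bn_weight, bn_bias, eps=1e-5):
    # C = (orig_bias - running_mean) * bn_weight / sqrt(running_var + eps) + bn_bias
    return (orig_bias - running_mean) * bn_weight / np.sqrt(running_var + eps) + bn_bias

def fold_into_quantized_bias(q_bias_int32, bias_scale, correction):
    # Quantize the float correction with the bias's per-tensor scale and
    # add it to the existing int32 bias, keeping the int32 dtype.
    q_corr = np.round(correction / bias_scale).astype(np.int64)
    return (q_bias_int32.astype(np.int64) + q_corr).astype(np.int32)
```

The intermediate int64 accumulation is just a defensive choice in this sketch to avoid wraparound before casting back to int32.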
Differential Revision: D104497938
Force-pushed from d5c07d4 to 84d7498.