fake_quant: more memory efficient per-channel backward #51255
Closed
Conversation
Summary: This is the same as #50561, but for per-channel fake_quant. TODO before land: write up better

Test Plan:

```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
Summary: This is the same as #50561, but for per-channel fake_quant. TODO before land: write up better

Memory and performance impact (MobileNetV2): TODO

Performance impact (microbenchmarks): https://gist.github.com/vkuzo/fbe1968d2bbb79b3f6dd776309fbcffc

* forward pass on cpu: 512ms -> 750ms (+46%)
* forward pass on cuda: 99ms -> 128ms (+30%)
* note: the overall performance impact to training jobs should be minimal, because this is used for weights, and relative importance of fq is dominated by fq'ing the activations
* note: we can optimize the perf in a future PR by reading once and writing twice

Test Plan:

```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request on Jan 28, 2021
Summary: This is the same as #50561, but for per-channel fake_quant. TODO before land: write up better

Test Plan:

```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7498ee6ff77ae53fe30587cc0efe12f3a3b87428
Pull Request resolved: #51255
Summary: This is the same as #50561, but for per-channel fake_quant. We add an alternative definition of fake quantize per channel's backward which computes a mask of what is clipped in the forward, and reuses that mask in the backward (instead of recomputing it):

```
# before - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val = clamp(nearby_int(x / scale) + zp, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val

# before - backward (pseudocode)
def fq_backward(dy, x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    return dy * mask

# after - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    q_val = clamp(q_val_unclamped, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val, mask

# after - backward (pseudocode)
def fq_backward(dy, mask):
    return dy * mask
```

There is a slight memory efficiency win (75% of whatever per-channel fq contributes, although it does not contribute much).

There is also a nice side effect that fake_quant_per_channel will now support a module calling it twice in the same forward. Previously, this was broken because (1) scale + zp were passed to the backward as arguments, and (2) scale + zp were updated in-place during the forward. The combination of (1) and (2) made it illegal to use the same fake_quant twice, since it would modify in-place the information needed for the backward. After this PR, (1) no longer applies, so this use case can be enabled.

There are two things left for future PRs:

1. kernels for mask and fq value are duplicated, instead of reading once and writing twice. We will hopefully optimize that in a future PR. Impact is low in the real world because this is not a bottleneck.
2. we use `BoolTensor` to pass the mask, which takes 1 byte per element; in the future we can pack the bits to save more memory.

Memory and performance impact (MobileNetV2):

```
# qat_fp32: model with fake_quants turned off (baseline)
# qat_1: step 2 of qat, with observers disabled and fake_quants enabled
#   (all of the overhead is the fake_quants)

# before: fbgemm - qat_fp32 -> qat_1
max memory usage (mib): 3302 -> 3538 (overhead: 7.1%)
latency (ms): 147 -> 187 (overhead: 27%)

# after: fbgemm - qat_fp32 -> qat_1
max memory usage (mib): 3302 -> 3532 (overhead: 7.0%)
latency (ms): 147 -> 167 (overhead: 14%)
```

Performance impact (microbenchmarks): https://gist.github.com/vkuzo/fbe1968d2bbb79b3f6dd776309fbcffc

* forward pass on cpu: 512ms -> 750ms (+46%)
* forward pass on cuda: 99ms -> 128ms (+30%)
* note: the overall performance impact to training jobs should be minimal, because this is used for weights, and relative importance of fq is dominated by fq'ing the activations. The data collected from real benchmarks (MobileNetV2 QAT) matches this hypothesis, and we actually see a speedup there.
* note: we can optimize the perf in a future PR by changing the kernels to read once and write twice

Test Plan:

```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D26117721](https://our.internmc.facebook.com/intern/diff/D26117721)

[ghstack-poisoned]
jerryzh168 approved these changes on Jan 28, 2021
This pull request has been merged in 267e243.
Stack from ghstack:
Summary:
This is the same as #50561, but for per-channel fake_quant. We add an alternative definition
of fake quantize per channel's backward which computes a mask of what is clipped in the
forward, and reuses that mask in the backward (instead of recomputing it):
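```
# before - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val = clamp(nearby_int(x / scale) + zp, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val

# before - backward (pseudocode)
def fq_backward(dy, x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    return dy * mask

# after - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    q_val = clamp(q_val_unclamped, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val, mask

# after - backward (pseudocode)
def fq_backward(dy, mask):
    return dy * mask
```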
There is a slight memory efficiency win (75% of whatever per-channel fq contributes, although it does not contribute much): the backward now needs only the 1-byte-per-element BoolTensor mask, rather than the 4-byte-per-element fp32 input it previously recomputed the mask from.
There is also a nice side effect that fake_quant_per_channel will now support a module calling it twice in the same forward. Previously, this was broken because (1) scale + zp were passed to the backward as arguments, and (2) scale + zp were updated in-place during the forward. The combination of (1) and (2) made it illegal to use the same fake_quant twice, since it would modify in-place the information needed for the backward. After this PR, (1) no longer applies, so this use case can be enabled.
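To make the mechanism concrete, here is a minimal Python sketch of the cachemask idea as a custom autograd function. This is illustrative only: the real implementation lives in the ATen CPU/CUDA kernels, and the class name and broadcasting details below are hypothetical.

```python
import torch

class FakeQuantPerChannelCachemask(torch.autograd.Function):
    """Sketch: per-channel fake quantize that caches the clip mask for backward."""

    @staticmethod
    def forward(ctx, x, scale, zero_point, axis, quant_min, quant_max):
        # reshape scale/zero_point so they broadcast along `axis`
        shape = [1] * x.dim()
        shape[axis] = x.size(axis)
        scale = scale.reshape(shape)
        zero_point = zero_point.reshape(shape)
        q_unclamped = torch.round(x / scale) + zero_point
        # compute the mask once here, instead of recomputing it from
        # x / scale / zero_point in the backward
        mask = (q_unclamped >= quant_min) & (q_unclamped <= quant_max)
        ctx.save_for_backward(mask)
        q = torch.clamp(q_unclamped, quant_min, quant_max)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, dy):
        (mask,) = ctx.saved_tensors
        # gradient flows only where the forward did not clip;
        # no gradients for scale, zero_point, axis, quant_min, quant_max
        return dy * mask, None, None, None, None, None

# example: fake-quantize along channel axis 1
x = torch.randn(8, 4, requires_grad=True)
scale = torch.full((4,), 0.1)
zero_point = torch.zeros(4)
y = FakeQuantPerChannelCachemask.apply(x, scale, zero_point, 1, -128, 127)
y.sum().backward()
```

Because the backward reads only the cached mask, later in-place updates to scale and zero_point no longer corrupt the state the backward depends on, which is what enables calling the same fake_quant twice in one forward.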
There are two things left for future PRs:

1. kernels for mask and fq value are duplicated, instead of reading once and writing twice. We will hopefully optimize that in a future PR. Impact is low in the real world because this is not a bottleneck.
2. we use `BoolTensor` to pass the mask, which takes 1 byte per element; in the future we can pack the bits to save more memory.

Memory and performance impact (MobileNetV2):
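```
# qat_fp32: model with fake_quants turned off (baseline)
# qat_1: step 2 of qat, with observers disabled and fake_quants enabled
#   (all of the overhead is the fake_quants)

# before: fbgemm - qat_fp32 -> qat_1
max memory usage (mib): 3302 -> 3538 (overhead: 7.1%)
latency (ms): 147 -> 187 (overhead: 27%)

# after: fbgemm - qat_fp32 -> qat_1
max memory usage (mib): 3302 -> 3532 (overhead: 7.0%)
latency (ms): 147 -> 167 (overhead: 14%)
```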
Performance impact (microbenchmarks): https://gist.github.com/vkuzo/fbe1968d2bbb79b3f6dd776309fbcffc

* forward pass on cpu: 512ms -> 750ms (+46%)
* forward pass on cuda: 99ms -> 128ms (+30%)
* note: the overall performance impact to training jobs should be minimal, because this is used for weights, and relative importance of fq is dominated by fq'ing the activations. The data collected from real benchmarks (MobileNetV2 QAT) matches this hypothesis, and we actually see a speedup there.
* note: we can optimize the perf in a future PR by changing the kernels to read once and write twice
Test Plan:
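```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```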
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D26117721