
fake_quant: more memory efficient per-channel backward #51255

Closed

vkuzo wants to merge 3 commits

Conversation

vkuzo (Contributor) commented Jan 28, 2021

Stack from ghstack:

Summary:

This is the same as #50561, but for per-channel fake_quant. We add an alternative definition
of the per-channel fake quantize backward which computes a mask of which elements are clipped
in the forward, and reuses that mask in the backward (instead of recomputing it):

```
# before - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val = clamp(nearby_int(x / scale) + zp, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val

# before - backward (pseudocode)
def fq_backward(dy, x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    return dy * mask

# after - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    q_val = clamp(q_val_unclamped, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val, mask

# after - backward (pseudocode)
def fq_backward(dy, mask):
    return dy * mask
```
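
For concreteness, here is a minimal runnable sketch of the cached-mask idea expressed as a Python autograd function. It is illustrative only; the class name, shapes, and structure below are assumptions for the example, while the PR itself implements the same idea inside the ATen per-channel fake_quant kernels.

```python
import torch

class FakeQuantPerChannelCacheMask(torch.autograd.Function):
    # Illustrative sketch only; not the actual ATen implementation.

    @staticmethod
    def forward(ctx, x, scale, zp, axis, qmin, qmax):
        # broadcast the per-channel scale/zp along the quantization axis
        shape = [1] * x.dim()
        shape[axis] = -1
        scale_b = scale.reshape(shape)
        zp_b = zp.reshape(shape)

        q_unclamped = torch.round(x / scale_b) + zp_b
        mask = (q_unclamped >= qmin) & (q_unclamped <= qmax)
        q = torch.clamp(q_unclamped, qmin, qmax)
        fq = (q - zp_b) * scale_b

        # save only the 1-byte-per-element bool mask; x, scale and zp are not saved
        ctx.save_for_backward(mask)
        return fq

    @staticmethod
    def backward(ctx, dy):
        (mask,) = ctx.saved_tensors
        # straight-through estimator: gradient passes only where nothing was clipped
        return dy * mask, None, None, None, None, None

# usage sketch
x = torch.randn(4, 8, requires_grad=True)
scale = torch.full((8,), 0.1)
zp = torch.zeros(8)
y = FakeQuantPerChannelCacheMask.apply(x, scale, zp, 1, -128, 127)
y.sum().backward()
```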

There is a slight memory efficiency win (roughly 75% of whatever per-channel fake_quant contributes,
although it does not contribute much): the saved bool mask takes 1 byte per element, versus the
4 bytes per element of the float32 input that previously had to be kept around for the backward.

There is also a nice side effect: fake_quant_per_channel will now support
a module calling it twice in the same forward. Previously this was broken because
(1) scale + zp were passed to the backward as arguments, and
(2) scale + zp were updated in-place during the forward.
The combination of (1) and (2) made it illegal to use the same fake_quant twice, since
the forward would modify in-place the information needed for the backward. After this PR, (1)
no longer applies, so this use case can be enabled (an illustrative sketch follows this paragraph).
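
As an illustration of the use case this unblocks, a hypothetical module could run the same per-channel FakeQuantize module on two tensors in one forward. The module name and shapes below are made up for the example; it is a sketch of the pattern, not code from this PR.

```python
import torch
import torch.nn as nn
from torch.quantization import FakeQuantize, MovingAveragePerChannelMinMaxObserver

class SharedFakeQuant(nn.Module):
    # Hypothetical module: one per-channel fake_quant applied to two tensors in
    # a single forward. Before this PR, the in-place scale/zp update during the
    # second call could clobber values the first call's backward depended on.
    def __init__(self):
        super().__init__()
        self.fq = FakeQuantize(
            observer=MovingAveragePerChannelMinMaxObserver,
            quant_min=-128, quant_max=127,
            dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0)

    def forward(self, w1, w2):
        return self.fq(w1), self.fq(w2)

m = SharedFakeQuant()
w1 = torch.randn(16, 3, requires_grad=True)
w2 = torch.randn(16, 3, requires_grad=True)
a, b = m(w1, w2)
(a.sum() + b.sum()).backward()
```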

There are two things left for future PRs:

  1. The kernels for the mask and the fq value are duplicated, instead of reading once and writing twice. We will hopefully optimize that in a future PR. Impact is low in the real world because this is not a bottleneck.
  2. We use a BoolTensor to pass the mask, which takes 1 byte per element; in the future we can pack the bits to save more memory (a rough packing sketch follows this list).
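
As a rough illustration of item 2, the bool mask could in principle be packed eight elements to a byte. The helpers below are hypothetical and are not what this PR does (the PR passes a plain BoolTensor); they only show that the packing itself is straightforward.

```python
import torch

# Hypothetical packing helpers; not part of this PR.
_BIT_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_mask(mask: torch.Tensor) -> torch.Tensor:
    """Pack a bool mask into a uint8 tensor, 8 mask elements per byte."""
    flat = mask.flatten().to(torch.uint8)
    pad = (-flat.numel()) % 8
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    return (flat.reshape(-1, 8) * _BIT_WEIGHTS).sum(dim=1, dtype=torch.uint8)

def unpack_mask(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Inverse of pack_mask; returns a flat bool mask with numel elements."""
    bits = packed.unsqueeze(1).bitwise_and(_BIT_WEIGHTS).ne(0)
    return bits.reshape(-1)[:numel]

mask = torch.rand(5, 7) > 0.5
assert torch.equal(unpack_mask(pack_mask(mask), mask.numel()), mask.flatten())
```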

Memory and performance impact (MobileNetV2):

```
# qat_fp32: model with fake_quants turned off (baseline)
# qat_1: step 2 of qat, with observers disabled and fake_quants enabled (all of the overhead is the fake_quants)

# before: fbgemm - qat_fp32 -> qat_1
max memory usage (MiB): 3302 -> 3538 (overhead: 7.1%)
latency (ms): 147 -> 187 (overhead: 27%)

# after: fbgemm - qat_fp32 -> qat_1
max memory usage (MiB): 3302 -> 3532 (overhead: 7.0%)
latency (ms): 147 -> 167 (overhead: 14%)
```

Performance impact (microbenchmarks; a rough reproduction sketch follows this list): https://gist.github.com/vkuzo/fbe1968d2bbb79b3f6dd776309fbcffc

  • forward pass on cpu: 512ms -> 750ms (+46%)
  • forward pass on cuda: 99ms -> 128ms (+30%)
  • note: the overall performance impact on training jobs should be minimal, because per-channel fake_quant is used for weights, and the total cost of fake_quant is dominated by fake_quant'ing the activations. The data collected from the real benchmark (MobileNetV2 QAT) matches this hypothesis, and we actually see a speedup there.
  • note: we can optimize the perf in a future PR by changing the kernels to read once and write twice
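
For reference, the forward-pass microbenchmark can be approximated with torch.utils.benchmark as sketched below; the tensor shapes and quant range are assumptions and may differ from what the linked gist uses.

```python
import torch
import torch.utils.benchmark as benchmark

# Assumed weight-like shape and quant range; the gist may use different ones.
x = torch.randn(128, 256, 3, 3)
scale = torch.rand(128) + 0.01
zero_point = torch.zeros(128, dtype=torch.int64)

timer = benchmark.Timer(
    stmt="torch.fake_quantize_per_channel_affine(x, scale, zero_point, 0, -128, 127)",
    globals={"torch": torch, "x": x, "scale": scale, "zero_point": zero_point},
)
print(timer.blocked_autorange())
```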

Test Plan:

```
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_forward_per_channel_cachemask_cuda
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cpu
python test/test_quantization.py TestFakeQuantize.test_backward_per_channel_cachemask_cuda
```


Differential Revision: D26117721

@facebook-github-bot (Contributor) commented Jan 28, 2021

💊 CI failures summary and remediations

As of commit e42c44a (more details on the Dr. CI page):

  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

This comment was automatically generated by Dr. CI.

vkuzo added a commit that referenced this pull request Jan 28, 2021
@facebook-github-bot (Contributor)

This pull request has been merged in 267e243.
