[pytorch][ao] force weight observer/fake_quant to be on the same device as the weight tensor (#106755)

Summary:
Pull Request resolved: #106755

As title.
There's a corner case where both CPU and GPU are available: even though the model has been moved to CPU, the newly created PTQ weight observer is still on GPU. As a result, this line fails during convert: https://fburl.com/4rhipfvb
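As an aside, here is a minimal standalone illustration of the failure mode and of the fix (not the exact repro from this diff; it assumes a machine where a CUDA device is visible, and uses `MinMaxObserver` as a stand-in for whatever observer `qconfig.weight()` constructs):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

weight = torch.randn(8, 16)              # weight tensor stays on CPU
observer = MinMaxObserver().to("cuda")   # freshly created observer ends up on GPU

# Calling the observer on the CPU weight mixes CPU and CUDA tensors and raises
# a device-mismatch RuntimeError -- the kind of failure convert hit before this change.
# observer(weight)

# The fix: move the observer onto the weight's device before observing it.
observer.to(weight.device)
observer(weight)
print(observer.calculate_qparams())
```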

Test Plan: CI

Differential Revision: D48141494

fbshipit-source-id: 8736e84a6242e18edde862408f11c9d3f8c5b4d3
jiaxuzhu92 authored and facebook-github-bot committed Aug 8, 2023
1 parent 2764ead commit 2e00d76
4 changes: 4 additions & 0 deletions torch/ao/quantization/fx/convert.py
@@ -51,6 +51,7 @@
     _get_module,
     _is_custom_module_lstm,
     _is_custom_module_mha,
+    assert_and_get_unique_device,
     get_custom_module_class_keys,
     create_getattr_from_value,
     collect_producer_nodes,
@@ -733,6 +734,9 @@ def convert_weighted_module(
     is_ptq = weight_post_process is None
     if is_ptq:
         weight_post_process = qconfig.weight() # type: ignore[union-attr, operator]
+        device = assert_and_get_unique_device(float_module)
+        if device:
+            weight_post_process.to(device)
 
     # Call weight observer/fake_quant at least once to ensure the scales and zero points
     # have the right shapes. Note: there are two cases where we don't have to do this:
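For readers unfamiliar with the helper added to the import list, here is a rough sketch of the behavior it is relied on for here: gather the devices of a module's parameters and buffers, assert there is at most one, and return it (or None for a module that owns no tensors). The function below is a hypothetical stand-in, not PyTorch's implementation:

```python
import torch
import torch.nn as nn
from typing import Optional

def unique_device_sketch(module: nn.Module) -> Optional[torch.device]:
    # Collect every device that holds one of the module's parameters or buffers.
    devices = {p.device for p in module.parameters()}
    devices |= {b.device for b in module.buffers()}
    assert len(devices) <= 1, f"expected all tensors on one device, got {devices}"
    # A module with no tensors yields None, in which case the observer is left untouched.
    return next(iter(devices)) if devices else None

print(unique_device_sketch(nn.Linear(4, 4)))  # cpu
print(unique_device_sketch(nn.ReLU()))        # None
```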
