Skip to content

Commit

Permalink
[quant][qat] Ensure observer respects device affinity (#47514)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #47514

Previously the scale and zero_point were returned on the CPU even if
the input tensor was on the GPU.
This is because `copy_()` doesn't respect the device when copying over the tensor.

Also fixed a bug where we were always setting the device to 'cuda' (irrespective of the device id)
in the calculate_qparams function

Test Plan:
python test/test_quantization.py TestObserver.test_observer_qparams_respects_device_affinity

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D24800495

fbshipit-source-id: d7a76c59569842ed69029d0eb4fa9df63f87e28c
  • Loading branch information
supriyar authored and facebook-github-bot committed Nov 10, 2020
1 parent abae12b commit 6bb18b2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
23 changes: 23 additions & 0 deletions test/quantization/test_workflow_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,29 @@ def test_save_load_state_dict_script(self):
# Verify that state_dict matches exactly with original one.
self.assertEqual(scripted.state_dict(), scripted_2.state_dict())


@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_observer_qparams_respects_device_affinity(self):
    """
    Ensure that the scale and zero_point returned by the observer
    are on the same device as the input tensor.
    """
    observers = [
        MinMaxObserver(),
        MovingAverageMinMaxObserver(),
        PerChannelMinMaxObserver(),
        MovingAveragePerChannelMinMaxObserver(),
    ]
    # cuda:1 (not the default device 0) so a hard-coded 'cuda' in the
    # observer would still surface as a device mismatch below.
    # Loop-invariant, so constructed once outside the loop.
    device = torch.device('cuda:1')
    for obs in observers:
        x = torch.randn(1, 2, device=device)
        obs.to(device)
        # Forward pass records min/max statistics; its return value is
        # not needed for this test.
        obs(x)
        scale, zero_point = obs.calculate_qparams()

        self.assertEqual(x.device, scale.device)
        self.assertEqual(x.device, zero_point.device)


# HistogramObserver that works like it does on master
class _ReferenceHistogramObserver(HistogramObserver):
def __init__(self, *args, **kwargs):
Expand Down
7 changes: 3 additions & 4 deletions torch/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor) -> Tu
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

scale = torch.ones(min_val_neg.size(), dtype=torch.float32)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64)
device = 'cuda' if min_val_neg.is_cuda else 'cpu'
device = min_val_neg.device
scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
Expand Down Expand Up @@ -297,7 +297,6 @@ def _calculate_qparams(self, min_val: torch.Tensor, max_val: torch.Tensor) -> Tu
if self.qscheme == torch.per_channel_affine_float_qparams:
zero_point = torch.tensor([float(zero_point)], dtype=zero_point.dtype, device=device)


return scale, zero_point


Expand Down

0 comments on commit 6bb18b2

Please sign in to comment.