fake_quantize: respect device affinity
Summary:

Makes the eager mode QAT prepare logic respect device affinity.
This fixes the issue where, for a module on `cuda:0`, running
the QAT prepare script would add observers on `cpu`.  Now they
are added on the original device.
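
For illustration, a minimal sketch of the intended behavior (not part of
this commit; the toy module and backend string are assumptions, and a
CUDA device is required):

```
import torch
import torch.nn as nn

# hypothetical toy module; any float module with a QAT qconfig would do
model = nn.Sequential(nn.Linear(2, 2))
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model.to('cuda:0')

torch.quantization.prepare_qat(model, inplace=True)

# with this change, the inserted observers / fake_quantize modules are
# created on cuda:0 as well, so the prepared model accepts CUDA inputs
out = model(torch.randn(2, device='cuda:0'))
```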

Test Plan:

```
python test/test_quantization.py TestDistributed.test_device_affinity
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f14b10e950fb2c477126f9a137161a7f97239b5d
Pull Request resolved: #39031
vkuzo committed May 29, 2020
1 parent c213317 commit 331af68
Showing 2 changed files with 63 additions and 2 deletions.
test/quantization/test_workflow_module.py (33 additions, 0 deletions)
@@ -41,6 +41,7 @@
from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
    override_qengines,
)

# Reference method for fake quantize
@@ -810,3 +811,35 @@ def forward(self, x):
        self.assertTrue(
            isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
            "Expected BN to be converted to SyncBN")

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_device_affinity(self):
        """
        Tests that converting a model to QAT respects device affinity
        """
        class Model(nn.Module):

            def __init__(self):
                super(Model, self).__init__()
                self.linear = nn.Linear(2, 2)

            def forward(self, x):
                x = self.linear(x)
                return x

        model = Model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(torch.backends.quantized.engine)
        device = torch.device('cuda:0')
        model.to(device)
        torch.quantization.prepare_qat(model, inplace=True)
        model_devices = {p.device for p in model.parameters()} | \
            {p.device for p in model.buffers()}
        self.assertEqual(len(model_devices), 1)
        model_device = next(iter(model_devices))
        self.assertEqual(model_device, device)

        # ensure that running an input on CUDA works without any needed changes
        input = torch.randn(2, device=device)
        model(input)
torch/quantization/quantize.py (30 additions, 2 deletions)
@@ -85,10 +85,22 @@ def add_observer_(module):
    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    # respect device affinity when adding observers
    # devices = {p.device for p in module.parameters()}
    devices = get_unique_devices_(module)
    assert len(devices) <= 1, (
        "add_observer_ only works with cpu or single-device CUDA modules, "
        "but got devices {}".format(devices)
    )
    device = next(iter(devices)) if len(devices) > 0 else None

    for child in module.children():
        if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional:
            if hasattr(child, 'qconfig') and child.qconfig is not None:
                child.activation_post_process = child.qconfig.activation()
                activation = child.qconfig.activation()
                if device is not None:
                    activation.to(device)
                child.activation_post_process = activation
        else:
            add_observer_(child)

@@ -97,9 +109,16 @@ def add_observer_(module):
    if hasattr(module, 'qconfig') and module.qconfig is not None and \
       len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential):
        # observer and hook will be gone after we swap the module
        module.add_module('activation_post_process', module.qconfig.activation())
        activation = module.qconfig.activation()
        if device is not None:
            activation.to(device)
        module.add_module('activation_post_process', activation)
        module.register_forward_hook(_observer_forward_hook)

def get_unique_devices_(module):
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}

def add_quant_dequant(module):
r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
Note that this function will modify the children of module inplace and it
@@ -352,7 +371,16 @@ def swap_module(mod, mapping):
    # Always replace dequantstub with dequantize
    if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
        if type(mod) in mapping:
            # respect device affinity when swapping modules
            devices = get_unique_devices_(mod)
            assert len(devices) <= 1, (
                "swap_module only works with cpu or single-device CUDA modules, "
                "but got devices {}".format(devices)
            )
            device = next(iter(devices)) if len(devices) > 0 else None
            new_mod = mapping[type(mod)].from_float(mod)
            if device:
                new_mod.to(device)
    return new_mod

def get_observer_dict(mod, target_dict, prefix=""):