Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fake_quantize: respect device affinity #39031

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions test/quantization/test_workflow_module.py
Expand Up @@ -41,6 +41,7 @@
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
override_qengines,
)

# Reference method for fake quantize
Expand Down Expand Up @@ -782,3 +783,37 @@ def forward(self, x):
self.assertTrue(
isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
"Expected BN to be converted to SyncBN")

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@override_qengines
def test_device_affinity(self):
    """
    Tests that converting a model to QAT respects device affinity
    """
    class Model(nn.Module):

        def __init__(self):
            super(Model, self).__init__()
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            return self.linear(x)

    device = torch.device('cuda:0')
    model = Model()
    model.qconfig = torch.quantization.get_default_qat_qconfig(
        torch.backends.quantized.engine)
    model.to(device)
    torch.quantization.prepare_qat(model, inplace=True)

    # After QAT preparation, every parameter and buffer (including any
    # newly inserted observer state) should still live on `device`.
    seen_devices = {p.device for p in model.parameters()}
    seen_devices |= {b.device for b in model.buffers()}
    self.assertEqual(len(seen_devices), 1,
                     message="Model must only have one device after QAT")
    self.assertEqual(next(iter(seen_devices)), device,
                     message="QAT must respect device affinity")

    # ensure that running an input on CUDA works without any needed changes
    model(torch.randn(2, device=device))
27 changes: 24 additions & 3 deletions torch/quantization/quantize.py
Expand Up @@ -85,10 +85,20 @@ def add_observer_(module):
Return:
None, module is modified inplace with added observer modules and forward_hooks
"""
# respect device affinity when adding observers
# devices = {p.device for p in module.parameters()}
devices = get_unique_devices_(module)
assert len(devices) == 1, (
"add_observer_ only works with cpu or single-device CUDA modules, "
"but got devices {}".format(devices)
)
device = next(iter(devices))

for child in module.children():
if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional:
if hasattr(child, 'qconfig') and child.qconfig is not None:
child.activation_post_process = child.qconfig.activation()
child.activation_post_process = \
child.qconfig.activation().to(device)
else:
add_observer_(child)

Expand All @@ -97,9 +107,13 @@ def add_observer_(module):
if hasattr(module, 'qconfig') and module.qconfig is not None and \
len(module._modules) == 0 and not isinstance(module, torch.nn.Sequential):
# observer and hook will be gone after we swap the module
module.add_module('activation_post_process', module.qconfig.activation())
module.add_module('activation_post_process', module.qconfig.activation().to(device))
module.register_forward_hook(_observer_forward_hook)

def get_unique_devices_(module):
    """Return the set of distinct devices that hold the given module's
    parameters and buffers (including those of all submodules)."""
    devices = {param.device for param in module.parameters()}
    devices.update(buf.device for buf in module.buffers())
    return devices

def add_quant_dequant(module):
r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
Note that this function will modify the children of module inplace and it
Expand Down Expand Up @@ -352,7 +366,14 @@ def swap_module(mod, mapping):
# Always replace dequantstub with dequantize
if hasattr(mod, 'qconfig') and mod.qconfig is not None or type(mod) == DeQuantStub:
if type(mod) in mapping:
new_mod = mapping[type(mod)].from_float(mod)
# respect device affinity when swapping modules
devices = get_unique_devices_(mod)
assert len(devices) == 1, (
"swap_module only works with cpu or single-device CUDA modules, "
"but got devices {}".format(devices)
)
device = next(iter(devices))
new_mod = mapping[type(mod)].from_float(mod).to(device)
return new_mod

def get_observer_dict(mod, target_dict, prefix=""):
Expand Down