add_observer: respect device affinity for ReLU
Summary:

In #39031 we made fake quantize respect device affinity of the
original module. However, that PR only handled modules with parameters
or buffers, and did not work properly for `ReLU`.

This fixes the logic to also work for `ReLU` by passing the parent's
device down when adding observers to child modules.
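
For context, a short sketch (not part of the commit) of why the previous logic missed `ReLU`: a parameterless module owns no parameters or buffers, so inferring its device from its own state yields an empty set and the observer would stay on the default (CPU) device.

```python
# Illustration only: device inference in the style of get_unique_devices_
# (parameters and buffers) comes up empty for a parameterless child like ReLU.
import torch
import torch.nn as nn

def unique_devices(module):
    # mirrors the idea behind torch.quantization's device-affinity checks
    return {p.device for p in module.parameters()} | \
           {b.device for b in module.buffers()}

parent = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU())
if torch.cuda.is_available():
    parent = parent.cuda()

print(unique_devices(parent))     # e.g. {device(type='cuda', index=0)}
print(unique_devices(parent[1]))  # set() -- ReLU alone gives no device hint,
                                  # hence the need to pass the parent's device
```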

Test Plan:

```
python test/test_quantization.py TestDistributed.test_device_affinity
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jun 1, 2020
1 parent d4345c8 commit 1dee066
Showing 2 changed files with 17 additions and 12 deletions.
test/quantization/test_workflow_module.py (7 additions & 3 deletions)

```diff
@@ -823,10 +823,14 @@ class Model(nn.Module):

             def __init__(self):
                 super(Model, self).__init__()
-                self.linear = nn.Linear(2, 2)
+                self.conv = nn.Conv2d(1, 1, 1)
+                self.bn = nn.BatchNorm2d(1)
+                self.relu = nn.ReLU()

             def forward(self, x):
-                x = self.linear(x)
+                x = self.conv(x)
+                x = self.bn(x)
+                x = self.relu(x)
                 return x

         model = Model()
@@ -841,5 +845,5 @@ def forward(self, x):
         self.assertEqual(model_device, device)

         # ensure that running an input on CUDA works without any needed changes
-        input = torch.randn(2, device=device)
+        input = torch.randn(4, 1, 4, 4, device=device)
         model(input)
```
torch/quantization/quantize.py (10 additions & 9 deletions)

```diff
@@ -73,26 +73,27 @@ def _observer_forward_hook(self, input, output):
     """
     return self.activation_post_process(output)

-def add_observer_(module):
+def add_observer_(module, device=None):
     r"""Add observer for the leaf child of the module.
     This function insert observer module to all leaf child module that
     has a valid qconfig attribute.
     Args:
         module: input module with qconfig attributes for all the leaf modules that we want to quantize
+        device: parent device, if any
     Return:
         None, module is modified inplace with added observer modules and forward_hooks
     """
     # respect device affinity when adding observers
-    # devices = {p.device for p in module.parameters()}
-    devices = get_unique_devices_(module)
-    assert len(devices) <= 1, (
-        "add_observer_ only works with cpu or single-device CUDA modules, "
-        "but got devices {}".format(devices)
-    )
-    device = next(iter(devices)) if len(devices) > 0 else None
+    if device is None:
+        devices = get_unique_devices_(module)
+        assert len(devices) <= 1, (
+            "add_observer_ only works with cpu or single-device CUDA modules, "
+            "but got devices {}".format(devices)
+        )
+        device = next(iter(devices)) if len(devices) > 0 else None

     for child in module.children():
         if type(child) == nnq.FloatFunctional or type(child) == nnq.QFunctional:
@@ -102,7 +103,7 @@ def add_observer_(module):
                 activation.to(device)
             child.activation_post_process = activation
         else:
-            add_observer_(child)
+            add_observer_(child, device)

     # Insert observers only for leaf nodes, note that this observer is for
     # the output of the module, for input QuantStub will observe them
```
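
For illustration, a minimal end-to-end sketch of the behavior this change enables (my own example, not from the commit; assumes a CUDA device is available and eager-mode `prepare`): observers attached to parameterless children such as `ReLU` now follow the parent's device, so a CUDA input runs through the prepared model without manual fixes.

```python
# Sketch only: after prepare(), observers (including the one on ReLU) should
# sit on the same device as the model, so a CUDA input just works.
import torch
import torch.nn as nn
import torch.quantization as tq

model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU())
model.qconfig = tq.default_qconfig
model = model.cuda()                 # assumes a CUDA device is available

tq.prepare(model, inplace=True)      # inserts observers, respecting device affinity

x = torch.randn(4, 1, 4, 4, device='cuda')
model(x)                             # no manual moves of observer state needed
```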