Conversation

@vkuzo (Contributor) commented Feb 16, 2022

Stack from ghstack:

Summary:

Before this PR, DBR quantization could not handle user code that iterates
over all module children. For example, imagine a forward function such as

```
def forward(self, x):
    for module in self:
        x = module(x)
    return x
```

Before this PR, this code would break with DBR quantization, because we
attach `AutoQuantizationState` objects to each child; those objects live in
the child's module hierarchy and appear in these kinds of iterations,
changing the meaning of the user program.
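
As a toy illustration of why (a minimal sketch, not the DBR code itself):

```python
import torch.nn as nn

# Toy module whose forward iterates over its own children,
# mirroring the example above.
class IteratingModel(nn.Sequential):
    def forward(self, x):
        for module in self:
            x = module(x)
        return x

m = IteratingModel(nn.Conv2d(1, 1, 1), nn.ReLU())
print(len(list(m)))  # 2

# Attaching an extra module the normal way registers it as a child,
# so the loop in forward now also calls it. nn.Identity stands in
# for AutoQuantizationState, whose forward raises instead.
m._auto_quant_state = nn.Identity()
print(len(list(m)))  # 3
```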

This PR reduces the scope of this problem to just the top-level module.
Instead of attaching `AutoQuantizationState` objects to each child,
we register them in a map on the parent. Here is a before and after:

```
// toy model
model
 |--> child1

// toy model with AutoQuantizationState objects, before this PR
model
 |--> child1
 |  |--> _auto_quant_state
 |--> _auto_quant_state

// toy model with AutoQuantizationState objects, after this PR
model
 |--> child1
 |--> _fqn_to_auto_quant_state_map
    |--> ( ) --> _auto_quant_state // of `model`
    |--> (child1) --> _auto_quant_state // of `model.child1`
```

Note: `child1._auto_quant_state` works as before for convenience,
but the `child1` object now stores a soft link to its `_auto_quant_state`
instead of properly registering it in its module hierarchy. This is
somewhat hacky. If we need to improve this in the future, we could
remove this soft link and refactor the code to call the FQN map
instead.
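
For illustration, the attach step under the new scheme looks roughly like the
sketch below (hypothetical helper names and simplified keying; the real DBR
code differs in its details):

```python
import torch.nn as nn

class AutoQuantizationStateModuleDict(nn.ModuleDict):
    """Stand-in for the container type used by DBR."""

def attach_auto_quant_states(model: nn.Module, make_state) -> None:
    # `make_state` is a hypothetical factory standing in for however DBR
    # builds an AutoQuantizationState for a given module.
    fqn_map = AutoQuantizationStateModuleDict()
    for fqn, module in model.named_modules():
        state = make_state(fqn, module)
        # ModuleDict keys cannot be empty or contain '.', so the FQN is
        # sanitized first (the real code has a dedicated helper for this)
        key = fqn.replace('.', ':') if fqn != '' else ' '
        fqn_map[key] = state
        # "soft link": bypass nn.Module.__setattr__ so `state` is not
        # registered in this module's own hierarchy
        object.__setattr__(module, '_auto_quant_state', state)
    # the map itself is registered only on the top-level module
    model._fqn_to_auto_quant_state_map = fqn_map
```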

Note: if the top-level module iterates over its children, things will
still be broken. This is less likely, and we will recommend that the
user work around this by wrapping their model, or checking for the
`AutoQuantizationStateModuleDict` type in their iteration loop.
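
For example, the two workarounds could look roughly like this (hypothetical
user code, not part of this PR):

```python
import torch.nn as nn

# Workaround 1: wrap the model, so the module whose forward iterates
# over children is no longer the top-level module.
class Wrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

# Workaround 2: skip the quantization-state container inside the loop.
# Real code would import AutoQuantizationStateModuleDict and use
# isinstance; the name-based check just keeps this sketch self-contained.
class TolerantModel(nn.Sequential):
    def forward(self, x):
        for module in self:
            if type(module).__name__ == 'AutoQuantizationStateModuleDict':
                continue
            x = module(x)
        return x
```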

This change should improve coverage of user models; we expect it to drive
our coverage of torchbenchmark models from 89% to 100%.

Test plan:

```
// previously disabled test cases with user code iterating
// over module children are now enabled, with wrappers
python test/test_quantization.py -k test_module_calls_items
python test/test_quantization.py -k test_vovnet_sequential
```

Differential Revision: D34281074

@pytorch-bot bot commented Feb 16, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/b4099f365f2851b56398347318b8717afdaa169e/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk, ciflow/xla ✅ triggered
linux-bionic-rocm4.5-py3.7 ciflow/all, ciflow/default, ciflow/linux, ciflow/rocm, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
macos-arm64-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-arm64-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
macos-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
macos-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
macos-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped

@facebook-github-bot (Contributor) commented Feb 16, 2022

💊 CI failures summary and remediations

As of commit 97ded85 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

vkuzo added a commit that referenced this pull request Feb 16, 2022
ghstack-source-id: 097d5aa
Pull Request resolved: #72934
@vkuzo (Contributor, Author) commented Feb 16, 2022

@vkuzo has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@jerryzh168 (Contributor) left a comment

Can you also comment on how this affects autograd?

@vkuzo (Contributor, Author) commented Feb 18, 2022

Can you also comment on how this affects autograd?

I don't think there are any changes to autograd as long as the user runs forward on the top-level model. Running forward on a submodule (without running it on the top-level model) is no longer supported after this PR, which also matches FX graph mode quantization.

```
return x

m = SequentialAppendList(torch.nn.Conv2d(1, 1, 1)).eval()
class Wrapper(nn.Module):
```
Contributor

Is there a way to detect this use case (iterating over child modules) and throw a nicer exception that tells the user to use a wrapper? What does the current exception look like?

Contributor Author

`AutoQuantizationState.forward` throws an exception; that exception will be hit if user code tries to call it. In a future PR, we can make the exception more verbose and point to documentation.
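
Roughly, the kind of guard being described (a sketch, not the exact DBR code):

```python
import torch.nn as nn

class AutoQuantizationState(nn.Module):
    # sketch: forward exists only to fail loudly if user iteration
    # code ends up calling the state object as if it were a layer
    def forward(self, x):
        raise NotImplementedError(
            'AutoQuantizationState is not callable; if your model iterates '
            'over its children, wrap the model or skip these objects.')
```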


```
def get_fqn_valid_for_module_dict_key(fqn: str) -> str:
    """
    Modifies `fqn` to make it a valid key to a ModuleDict.
```
Contributor

Why is this necessary? Is this just the syntax expected by torch.nn.ModuleDict?

Contributor Author

One cannot have empty strings or dots in ModuleDict keys. I don't know exactly why.
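
For reference, a minimal sketch of what this helper might do (the exact
rewriting in the real implementation may differ):

```python
def get_fqn_valid_for_module_dict_key(fqn: str) -> str:
    """
    Modifies `fqn` to make it a valid key to a ModuleDict.
    """
    # nn.Module.add_module (used by ModuleDict) rejects empty names and
    # names containing '.', so both cases are rewritten here.
    if fqn == '':
        fqn = ' '
    return fqn.replace('.', ':')
```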

```
# On the child, manually set the attribute without
# going through the `torch.nn.Module.__setattr__`
# function, to prevent this object from appearing in
# the child's module hierarchy.
```
Contributor

Does this affect the serialized output? I guess AutoQuantizationState is not currently serialized, is that right?

Contributor Author

This PR breaks BC for models quantized with DBR before this PR, because the quantization state is now always stored on the top-level module. We are not calling out the BC break because no one is using this yet. Other than that, serialization will work as before.

facebook-github-bot pushed a commit that referenced this pull request Feb 22, 2022
Summary:
Pull Request resolved: #72934

Reviewed By: dzdang

Differential Revision: D34281074

Pulled By: vkuzo

fbshipit-source-id: 0e25fc1ec529c47f72478a1875fe43219feac6b1
@facebook-github-bot deleted the gh/vkuzo/473/head branch February 26, 2022 15:17
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 3, 2022
(cherry picked from commit 4008f899671f643f5a54c311254af3f68ae22e2e)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 3, 2022