Skip to content

Conversation

yiliu30
Copy link
Contributor

@yiliu30 yiliu30 commented May 13, 2024

Summary:
Added set_module_name_qconfig support to allow users to set configurations based on module name in X86InductorQuantizer.

For example, only quantize the sub:

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.sub = Sub()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub(x)
        return x

m = M().eval()
example_inputs = (torch.randn(3, 5),)
# Set config for a specific submodule.
quantizer = X86InductorQuantizer()
quantizer.set_module_name_qconfig("sub", xiq.get_default_x86_inductor_quantization_config())
  • Added set_module_name_qconfig to allow user set the configuration at the module_name level.
  • Unified the annotation process to follow this order: module_name_qconfig, operator_type_qconfig, and global_config.
  • Added config_checker to validate all user configurations and prevent mixing of static/dynamic or QAT/non-QAT configs.
  • Moved _get_module_name_filter from xnnpack_quantizer.py into utils.py as it common for all quantizer.

Test Plan

python -m pytest quantization/pt2e/test_x86inductor_quantizer.py -k test_set_module_name

@Xia-Weiwen @leslie-fang-intel @jgong5

yiliu30 added 8 commits May 10, 2024 16:50
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
@yiliu30 yiliu30 requested a review from jerryzh168 as a code owner May 13, 2024 03:47
Copy link

pytorch-bot bot commented May 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126044

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5378b56 with merge base 5001f41 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: quantization release notes category label May 13, 2024
@leslie-fang-intel leslie-fang-intel self-requested a review May 13, 2024 03:48
@yiliu30 yiliu30 marked this pull request as draft May 13, 2024 03:51
@Xia-Weiwen Xia-Weiwen requested review from Xia-Weiwen and jgong5 and removed request for jerryzh168 May 13, 2024 03:53
Signed-off-by: yiliu30 <yi4.liu@intel.com>
@leslie-fang-intel
Copy link
Collaborator

I guess we also need to check here self.module_name_qconfig for dynamic quant here

if self.global_config and self.global_config.input_activation.is_dynamic: # type: ignore[union-attr]

Copy link
Collaborator

@leslie-fang-intel leslie-fang-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Add some comments.

Copy link
Collaborator

@leslie-fang-intel leslie-fang-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Add some comments.

@leslie-fang-intel leslie-fang-intel added the ciflow/trunk Trigger trunk jobs on your pull request label May 14, 2024
Copy link

pytorch-bot bot commented May 14, 2024

Please seek CI approval before scheduling CIFlow labels

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label May 14, 2024
yiliu30 added 4 commits May 15, 2024 21:33
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 added 2 commits May 29, 2024 15:13
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
@yiliu30 yiliu30 requested a review from leslie-fang-intel May 30, 2024 01:02
@yiliu30 yiliu30 requested a review from leslie-fang-intel May 31, 2024 05:39
Signed-off-by: yiliu30 <yi4.liu@intel.com>
@leslie-fang-intel
Copy link
Collaborator

Thanks for the PR again. Looks good to me. @yiliu30 please kindly help to rebase this PR.

@leslie-fang-intel
Copy link
Collaborator

Hi @jerryzh168, looks like this PR needs your kindly approve to check the preCI.

@yiliu30 yiliu30 marked this pull request as ready for review June 11, 2024 01:24
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 12, 2024
yiliu30 added 3 commits June 13, 2024 13:35
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
is_qat = qconfig.is_qat
input_activation_spec = qconfig.input_activation
if input_activation_spec is not None:
is_dynamic = input_activation_spec.is_dynamic
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we write as

if is_dynamic is None:
  is_dynamic = input_activation_spec.is_dynamic
else:
  assert is_dynamic == input_activation_spec.is_dynamic

and add some explanation for the code logic here? Same as above is_qat

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated in 621ff11

yiliu30 added 2 commits June 13, 2024 15:21
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
@leslie-fang-intel
Copy link
Collaborator

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this pull request Jun 14, 2024
…ctorQuantizer (pytorch#126044)

Summary:
Added `set_module_name_qconfig` support to allow users to set configurations based on module name in `X86InductorQuantizer`.

For example, only quantize the `sub`:

```python
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.sub = Sub()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub(x)
        return x

m = M().eval()
example_inputs = (torch.randn(3, 5),)
# Set config for a specific submodule.
quantizer = X86InductorQuantizer()
quantizer.set_module_name_qconfig("sub", xiq.get_default_x86_inductor_quantization_config())
```

- Added `set_module_name_qconfig` to allow user set the configuration at the `module_name` level.
- Unified the annotation process to follow this order:  `module_name_qconfig`, `operator_type_qconfig`, and `global_config`.
- Added `config_checker` to validate all user configurations and prevent mixing of static/dynamic or QAT/non-QAT configs.
- Moved `_get_module_name_filter` from `xnnpack_quantizer.py` into `utils.py` as it common for all quantizer.

Test Plan

```bash
python -m pytest quantization/pt2e/test_x86inductor_quantizer.py -k test_set_module_name
```

@Xia-Weiwen @leslie-fang-intel  @jgong5
Pull Request resolved: pytorch#126044
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jerryzh168
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: quantization release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants