Skip to content

Commit

Permalink
Enable UFMT on all of test/quantization/ao_migration &bc (pytorch#123994
Browse files Browse the repository at this point in the history
)

Partially addresses pytorch#123062
Ran lintrunner on:
- test/quantization/ao_migration
- test/quantization/bc

Detail:
```
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

@ezyang

Pull Request resolved: pytorch#123994
Approved by: https://github.com/ezyang
  • Loading branch information
WeiChunyu-star authored and petrex committed May 3, 2024
1 parent f741b76 commit 1b12f6e
Show file tree
Hide file tree
Showing 6 changed files with 706 additions and 515 deletions.
7 changes: 0 additions & 7 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1334,13 +1334,6 @@ exclude_patterns = [
'test/profiler/test_profiler.py',
'test/profiler/test_profiler_tree.py',
'test/quantization/__init__.py',
'test/quantization/ao_migration/__init__.py',
'test/quantization/ao_migration/common.py',
'test/quantization/ao_migration/test_ao_migration.py',
'test/quantization/ao_migration/test_quantization.py',
'test/quantization/ao_migration/test_quantization_fx.py',
'test/quantization/bc/__init__.py',
'test/quantization/bc/test_backward_compatibility.py',
'test/quantization/core/__init__.py',
'test/quantization/core/experimental/apot_fx_graph_mode_ptq.py',
'test/quantization/core/experimental/apot_fx_graph_mode_qat.py',
Expand Down
48 changes: 29 additions & 19 deletions test/quantization/ao_migration/common.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,52 @@
from torch.testing._internal.common_utils import TestCase

import importlib
from typing import List, Optional

from torch.testing._internal.common_utils import TestCase


class AOMigrationTestCase(TestCase):
def _test_function_import(self, package_name: str, function_list: List[str],
base: Optional[str] = None, new_package_name: Optional[str] = None):
def _test_function_import(
self,
package_name: str,
function_list: List[str],
base: Optional[str] = None,
new_package_name: Optional[str] = None,
):
r"""Tests individual function list import by comparing the functions
and their hashes."""
if base is None:
base = 'quantization'
old_base = 'torch.' + base
new_base = 'torch.ao.' + base
base = "quantization"
old_base = "torch." + base
new_base = "torch.ao." + base
if new_package_name is None:
new_package_name = package_name
old_location = importlib.import_module(f'{old_base}.{package_name}')
new_location = importlib.import_module(f'{new_base}.{new_package_name}')
old_location = importlib.import_module(f"{old_base}.{package_name}")
new_location = importlib.import_module(f"{new_base}.{new_package_name}")
for fn_name in function_list:
old_function = getattr(old_location, fn_name)
new_function = getattr(new_location, fn_name)
assert old_function == new_function, f"Functions don't match: {fn_name}"
assert hash(old_function) == hash(new_function), \
f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
assert hash(old_function) == hash(new_function), (
f"Hashes don't match: {old_function}({hash(old_function)}) vs. "
f"{new_function}({hash(new_function)})"
)

def _test_dict_import(self, package_name: str, dict_list: List[str],
base: Optional[str] = None):
def _test_dict_import(
self, package_name: str, dict_list: List[str], base: Optional[str] = None
):
r"""Tests individual function list import by comparing the functions
and their hashes."""
if base is None:
base = 'quantization'
old_base = 'torch.' + base
new_base = 'torch.ao.' + base
old_location = importlib.import_module(f'{old_base}.{package_name}')
new_location = importlib.import_module(f'{new_base}.{package_name}')
base = "quantization"
old_base = "torch." + base
new_base = "torch.ao." + base
old_location = importlib.import_module(f"{old_base}.{package_name}")
new_location = importlib.import_module(f"{new_base}.{package_name}")
for dict_name in dict_list:
old_dict = getattr(old_location, dict_name)
new_dict = getattr(new_location, dict_name)
assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
for key in new_dict.keys():
assert old_dict[key] == new_dict[key], f"Dicts don't match: {dict_name} for key {key}"
assert (
old_dict[key] == new_dict[key]
), f"Dicts don't match: {dict_name} for key {key}"
Loading

0 comments on commit 1b12f6e

Please sign in to comment.