forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable UFMT on all of test/quantization/ao_migration &bc (pytorch#123994
) Partially addresses pytorch#123062 Ran lintrunner on: - test/quantization/ao_migration - test/quantization/bc Detail: ``` $ lintrunner -a --take UFMT --all-files ok No lint issues. Successfully applied all patches. ``` @ezyang Pull Request resolved: pytorch#123994 Approved by: https://github.com/ezyang
- Loading branch information
1 parent
f741b76
commit 1b12f6e
Showing
6 changed files
with
706 additions
and
515 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,52 @@ | ||
from torch.testing._internal.common_utils import TestCase | ||
|
||
import importlib | ||
from typing import List, Optional | ||
|
||
from torch.testing._internal.common_utils import TestCase | ||
|
||
|
||
class AOMigrationTestCase(TestCase): | ||
def _test_function_import(self, package_name: str, function_list: List[str], | ||
base: Optional[str] = None, new_package_name: Optional[str] = None): | ||
def _test_function_import( | ||
self, | ||
package_name: str, | ||
function_list: List[str], | ||
base: Optional[str] = None, | ||
new_package_name: Optional[str] = None, | ||
): | ||
r"""Tests individual function list import by comparing the functions | ||
and their hashes.""" | ||
if base is None: | ||
base = 'quantization' | ||
old_base = 'torch.' + base | ||
new_base = 'torch.ao.' + base | ||
base = "quantization" | ||
old_base = "torch." + base | ||
new_base = "torch.ao." + base | ||
if new_package_name is None: | ||
new_package_name = package_name | ||
old_location = importlib.import_module(f'{old_base}.{package_name}') | ||
new_location = importlib.import_module(f'{new_base}.{new_package_name}') | ||
old_location = importlib.import_module(f"{old_base}.{package_name}") | ||
new_location = importlib.import_module(f"{new_base}.{new_package_name}") | ||
for fn_name in function_list: | ||
old_function = getattr(old_location, fn_name) | ||
new_function = getattr(new_location, fn_name) | ||
assert old_function == new_function, f"Functions don't match: {fn_name}" | ||
assert hash(old_function) == hash(new_function), \ | ||
f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \ | ||
assert hash(old_function) == hash(new_function), ( | ||
f"Hashes don't match: {old_function}({hash(old_function)}) vs. " | ||
f"{new_function}({hash(new_function)})" | ||
) | ||
|
||
def _test_dict_import(self, package_name: str, dict_list: List[str], | ||
base: Optional[str] = None): | ||
def _test_dict_import( | ||
self, package_name: str, dict_list: List[str], base: Optional[str] = None | ||
): | ||
r"""Tests individual function list import by comparing the functions | ||
and their hashes.""" | ||
if base is None: | ||
base = 'quantization' | ||
old_base = 'torch.' + base | ||
new_base = 'torch.ao.' + base | ||
old_location = importlib.import_module(f'{old_base}.{package_name}') | ||
new_location = importlib.import_module(f'{new_base}.{package_name}') | ||
base = "quantization" | ||
old_base = "torch." + base | ||
new_base = "torch.ao." + base | ||
old_location = importlib.import_module(f"{old_base}.{package_name}") | ||
new_location = importlib.import_module(f"{new_base}.{package_name}") | ||
for dict_name in dict_list: | ||
old_dict = getattr(old_location, dict_name) | ||
new_dict = getattr(new_location, dict_name) | ||
assert old_dict == new_dict, f"Dicts don't match: {dict_name}" | ||
for key in new_dict.keys(): | ||
assert old_dict[key] == new_dict[key], f"Dicts don't match: {dict_name} for key {key}" | ||
assert ( | ||
old_dict[key] == new_dict[key] | ||
), f"Dicts don't match: {dict_name} for key {key}" |
Oops, something went wrong.