forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_tests_from_mixin.py
61 lines (44 loc) · 1.64 KB
/
extract_tests_from_mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import argparse
import inspect
import sys
from pathlib import Path
from typing import List, Type
# Put the repository root (the parent of this script's directory) first on
# sys.path so the `tests.*` mixin modules imported under __main__ resolve.
root_dir = Path(__file__).parent.parent.absolute()
sys.path.insert(0, str(root_dir))

parser = argparse.ArgumentParser(
    description="Print a pytest -k expression selecting every test_* method of a tester mixin."
)
parser.add_argument(
    "--type",
    type=str,
    default=None,
    help="Mixin to inspect: 'pipeline', 'models' or 'lora'. Any other value prints an empty pattern.",
)
# Only consume the real command line when run as a script. Importing this
# module (e.g. from a test) must not parse, and possibly reject, the
# importer's argv; in that case fall back to the parser defaults.
args = parser.parse_args() if __name__ == "__main__" else parser.parse_args([])
def get_test_methods_from_class(cls: Type) -> List[str]:
"""
Get all test method names from a given class.
Only returns methods that start with 'test_'.
"""
test_methods = []
for name, obj in inspect.getmembers(cls):
if name.startswith("test_") and inspect.isfunction(obj):
test_methods.append(name)
return sorted(test_methods)
def generate_pytest_pattern(test_methods: List[str]) -> str:
    """
    Combine test names into a single pytest ``-k`` expression.

    Names are joined with ``" or "``; an empty list yields an empty string.
    """
    separator = " or "
    pattern = separator.join(test_methods)
    return pattern
def generate_pattern_for_mixin(mixin_class: Type) -> str:
    """
    Generate the pytest ``-k`` pattern selecting every test method of a mixin.

    Args:
        mixin_class: The tester mixin class to inspect, or ``None`` when no
            mixin was selected.

    Returns:
        The names of the mixin's ``test_*`` functions joined with ``" or "``,
        or an empty string when ``mixin_class`` is ``None``.
    """
    # Check the argument itself; relying on a module-level name would raise
    # NameError for callers that import this function instead of running
    # the script.
    if mixin_class is None:
        return ""
    test_methods = get_test_methods_from_class(mixin_class)
    return generate_pytest_pattern(test_methods)
if __name__ == "__main__":
    # Resolve the requested tester mixin. Each import is deferred to its own
    # branch so only the selected test module is loaded. `mixin_cls` stays a
    # module-level name and remains None for a missing or unrecognised --type.
    mixin_cls = None
    requested_type = args.type
    if requested_type == "pipeline":
        from tests.pipelines.test_pipelines_common import PipelineTesterMixin as mixin_cls
    elif requested_type == "models":
        from tests.models.test_modeling_common import ModelTesterMixin as mixin_cls
    elif requested_type == "lora":
        from tests.lora.utils import PeftLoraLoaderMixinTests as mixin_cls

    pattern = generate_pattern_for_mixin(mixin_cls)
    print(pattern)