-
Notifications
You must be signed in to change notification settings - Fork 879
/
Copy pathtest_torch_export.py
138 lines (105 loc) · 3.92 KB
/
test_torch_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Standard library
import os
from pathlib import Path

# Third-party
import packaging.version
import pytest
import torch

# Local (TorchServe)
from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.utils.util import load_label_mapping
from ts_scripts.utils import try_and_handle

CURR_FILE_PATH = Path(__file__).parent.absolute()
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "pt2", "torch_export_aot_compile")
TEST_DATA = REPO_ROOT_DIR.joinpath("examples", "image_classifier", "kitten.jpg")
MAPPING_DATA = REPO_ROOT_DIR.joinpath(
    "examples", "image_classifier", "index_to_name.json"
)
MODEL_SO_FILE = "resnet18_pt2.so"
MODEL_YAML_CFG_FILE = EXAMPLE_ROOT_DIR.joinpath("model-config.yaml")

# torch.export AOT compile needs torch newer than 2.2.2 (i.e. >= 2.3.0 releases).
PT_230_AVAILABLE = packaging.version.parse(torch.__version__) > packaging.version.parse(
    "2.2.2"
)

# Top-1 through top-5 labels expected for the kitten test image.
EXPECTED_RESULTS = ["tabby", "tiger_cat", "Egyptian_cat", "lynx", "bucket"]
TEST_CASES = [
    ("kitten.jpg", EXPECTED_RESULTS[0]),
]
BATCH_SIZE = 32
@pytest.fixture
def custom_working_directory(tmp_path):
    """Chdir into a fresh ``model_dir`` under ``tmp_path`` for the test.

    Yields the created directory. On teardown the *original* working
    directory is restored — the previous implementation chdir'd to
    ``tmp_path`` instead, so the process cwd leaked into a pytest temp dir
    for the rest of the session.
    """
    original_dir = os.getcwd()
    custom_dir = tmp_path / "model_dir"
    custom_dir.mkdir()
    os.chdir(custom_dir)
    try:
        yield custom_dir
    finally:
        # Restore the caller's cwd even if the test body raised.
        os.chdir(original_dir)
@pytest.mark.skipif(
    # NOTE(review): env values are strings, so any non-empty value (even "0")
    # skips the test — presumably intentional; confirm against CI config.
    os.environ.get("TS_RUN_IN_DOCKER", False),
    reason="Test to be run outside docker",
)
@pytest.mark.skipif(not PT_230_AVAILABLE, reason="torch version is < 2.3.0")
def test_torch_export_aot_compile(custom_working_directory):
    """End-to-end single-image inference through an AOT-compiled resnet18.

    Exports/compiles the model via the example script, initializes the
    ImageClassifier handler against the produced ``.so``, and checks the
    top-5 labels for the kitten image match ``EXPECTED_RESULTS`` exactly.
    """
    # The fixture has already chdir'd here; the export script writes the
    # compiled artifact into the current working directory.
    model_dir = custom_working_directory

    # Run the example export script to produce the AOT-compiled model.
    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")
    try_and_handle(f"python {script_path}")

    # Handler for image classification.
    handler = ImageClassifier()

    # Mock TorchServe context pointing at the compiled model + YAML config.
    ctx = MockContext(
        model_pt_file=MODEL_SO_FILE,
        model_dir=model_dir.as_posix(),
        model_file=None,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )
    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    # Build a request body from the raw image bytes, as the frontend would.
    with open(TEST_DATA, "rb") as image:
        data = {"body": bytearray(image.read())}

    result = handler.handle([data], ctx)
    labels = list(result[0].keys())
    assert labels == EXPECTED_RESULTS
@pytest.mark.skipif(
    # NOTE(review): env values are strings, so any non-empty value (even "0")
    # skips the test — presumably intentional; confirm against CI config.
    os.environ.get("TS_RUN_IN_DOCKER", False),
    reason="Test to be run outside docker",
)
@pytest.mark.skipif(not PT_230_AVAILABLE, reason="torch version is < 2.3.0")
def test_torch_export_aot_compile_dynamic_batching(custom_working_directory):
    """Verify the AOT-compiled resnet18 handles a batch of BATCH_SIZE requests.

    Same setup as ``test_torch_export_aot_compile``, but submits the same
    image BATCH_SIZE times in one call and only asserts the response count —
    exercising the dynamic-batching path of the compiled model.
    """
    # The fixture has already chdir'd here; the export script writes the
    # compiled artifact into the current working directory.
    model_dir = custom_working_directory

    # Run the example export script to produce the AOT-compiled model.
    script_path = os.path.join(EXAMPLE_ROOT_DIR, "resnet18_torch_export.py")
    try_and_handle(f"python {script_path}")

    # Handler for image classification.
    handler = ImageClassifier()

    # Mock TorchServe context pointing at the compiled model + YAML config.
    ctx = MockContext(
        model_pt_file=MODEL_SO_FILE,
        model_dir=model_dir.as_posix(),
        model_file=None,
        model_yaml_config_file=MODEL_YAML_CFG_FILE,
    )
    torch.manual_seed(42 * 42)
    handler.initialize(ctx)
    handler.context = ctx
    handler.mapping = load_label_mapping(MAPPING_DATA)

    # Build a request body from the raw image bytes, as the frontend would.
    with open(TEST_DATA, "rb") as image:
        data = {"body": bytearray(image.read())}

    # Send a batch of BATCH_SIZE identical requests; one response each.
    result = handler.handle([data] * BATCH_SIZE, ctx)
    assert len(result) == BATCH_SIZE