-
Notifications
You must be signed in to change notification settings - Fork 850
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add PyTorch/XLA support (#2182)
* enable torch_xla * add venv to .gitignore * add torch_xla 2.0 baackend; update map_location and device; remove mark_step * clear map_location for torchxla * format * add torch_xla test
- Loading branch information
Showing
7 changed files
with
159 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
{"pt2" : "torchxla_trace_once"} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
inference_address=http://0.0.0.0:8080 | ||
management_address=http://0.0.0.0:8081 | ||
metrics_address=http://0.0.0.0:8082 | ||
model_store=/home/model-server/model-store | ||
load_models=half_plus_two.mar | ||
min_workers=1 | ||
max_workers=1 | ||
default_workers_per_model=1 | ||
service_envelope=json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import torch | ||
|
||
|
||
class HalfPlusTwoModel(torch.nn.Module): | ||
def forward(self, *input_args): | ||
w = torch.tensor(0.5) | ||
b = torch.tensor(2.0) | ||
return torch.add(torch.multiply(w, input_args[0]), b) | ||
|
||
|
||
if __name__ == "__main__": | ||
model = HalfPlusTwoModel() | ||
torch.save(model.state_dict(), "model.pt") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import glob | ||
import json | ||
import os | ||
import subprocess | ||
import time | ||
from pathlib import Path | ||
|
||
import pytest | ||
from pkg_resources import packaging | ||
|
||
try: | ||
import torch_xla | ||
|
||
TORCHXLA_AVAILABLE = ( | ||
True | ||
if packaging.version.parse(torch_xla.__version__) | ||
>= packaging.version.parse("2.0") | ||
else False | ||
) | ||
except: | ||
TORCHXLA_AVAILABLE = False | ||
|
||
CURR_FILE_PATH = Path(__file__).parent | ||
TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data") | ||
|
||
MODEL_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.py") | ||
EXTRA_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "compile.json") | ||
CONFIG_PROPERTIES = os.path.join(TORCH_XLA_TEST_DATA_DIR, "config.properties") | ||
|
||
SERIALIZED_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.pt") | ||
MODEL_STORE_DIR = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model_store") | ||
MODEL_NAME = "half_plus_two" | ||
|
||
|
||
@pytest.mark.skipif(TORCHXLA_AVAILABLE == False, reason="PyTorch/XLA is not installed") | ||
class TestTorchXLA: | ||
def teardown_class(self): | ||
subprocess.run("torchserve --stop", shell=True, check=True) | ||
time.sleep(10) | ||
|
||
def test_archive_model_artifacts(self): | ||
assert len(glob.glob(MODEL_FILE)) == 1 | ||
assert len(glob.glob(EXTRA_FILE)) == 1 | ||
assert len(glob.glob(CONFIG_PROPERTIES)) == 1 | ||
subprocess.run( | ||
f"cd {TORCH_XLA_TEST_DATA_DIR} && python model.py", shell=True, check=True | ||
) | ||
subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True) | ||
subprocess.run( | ||
f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --extra-files {EXTRA_FILE} --export-path {MODEL_STORE_DIR} --handler base_handler -f", | ||
shell=True, | ||
check=True, | ||
) | ||
assert len(glob.glob(SERIALIZED_FILE)) == 1 | ||
assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1 | ||
|
||
def test_start_torchserve(self): | ||
subprocess.run( | ||
f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR} --ts-config {CONFIG_PROPERTIES}", | ||
shell=True, | ||
check=True, | ||
) | ||
time.sleep(10) | ||
assert len(glob.glob("logs/access_log.log")) == 1 | ||
assert len(glob.glob("logs/model_log.log")) == 1 | ||
assert len(glob.glob("logs/ts_log.log")) == 1 | ||
|
||
def test_server_status(self): | ||
result = subprocess.run( | ||
"curl http://localhost:8080/ping", | ||
shell=True, | ||
capture_output=True, | ||
check=True, | ||
) | ||
expected_server_status_str = '{"status": "Healthy"}' | ||
expected_server_status = json.loads(expected_server_status_str) | ||
assert json.loads(result.stdout) == expected_server_status | ||
|
||
def test_registered_model(self): | ||
result = subprocess.run( | ||
"curl http://localhost:8081/models", | ||
shell=True, | ||
capture_output=True, | ||
check=True, | ||
) | ||
expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}' | ||
expected_registered_model = json.loads(expected_registered_model_str) | ||
assert json.loads(result.stdout) == expected_registered_model | ||
|
||
def test_serve_inference(self): | ||
request = "'{\"" 'instances"' ": [[1.0], [2.0], [3.0]]}'" | ||
result = subprocess.run( | ||
f'curl -s -X POST -H "Content-Type: application/json;" http://localhost:8080/predictions/half_plus_two -d {request}', | ||
shell=True, | ||
capture_output=True, | ||
check=True, | ||
) | ||
expected_result_str = '{"predictions": [[2.5], [3.0], [3.5]]}' | ||
expected_result = json.loads(expected_result_str) | ||
assert json.loads(result.stdout) == expected_result | ||
|
||
model_log_path = glob.glob("logs/model_log.log")[0] | ||
with open(model_log_path, "rt") as model_log_file: | ||
model_log = model_log_file.read() | ||
assert "Compiled model with backend torchxla_trace_once" in model_log | ||
assert "done compiler function torchxla_trace_once" in model_log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters