feat: add PyTorch/XLA support #2182

Merged · 6 commits · Apr 11, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ plugins/*/bin
*.backup
docs/sphinx/src/
ts_scripts/spellcheck_conf/wordlist.dic
venv/

# Postman files
test/artifacts/
1 change: 1 addition & 0 deletions test/pytest/test_data/torch_xla/compile.json
@@ -0,0 +1 @@
{"pt2" : "torchxla_trace_once"}
9 changes: 9 additions & 0 deletions test/pytest/test_data/torch_xla/config.properties
@@ -0,0 +1,9 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
model_store=/home/model-server/model-store
load_models=half_plus_two.mar
min_workers=1
max_workers=1
default_workers_per_model=1
service_envelope=json
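Note: service_envelope=json is what makes the server expect the {"instances": [...]} request shape used below. A hedged client-side sketch of the same call as test_serve_inference, assuming the requests package is available (the tests themselves shell out to curl instead):

import requests  # assumption: requests is installed

resp = requests.post(
    "http://localhost:8080/predictions/half_plus_two",
    json={"instances": [[1.0], [2.0], [3.0]]},
)
print(resp.json())  # expected: {"predictions": [[2.5], [3.0], [3.5]]}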
13 changes: 13 additions & 0 deletions test/pytest/test_data/torch_xla/model.py
@@ -0,0 +1,13 @@
import torch


class HalfPlusTwoModel(torch.nn.Module):
def forward(self, *input_args):
w = torch.tensor(0.5)
b = torch.tensor(2.0)
return torch.add(torch.multiply(w, input_args[0]), b)


if __name__ == "__main__":
model = HalfPlusTwoModel()
torch.save(model.state_dict(), "model.pt")
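The model computes y = 0.5 * x + 2 elementwise, which is where the expected predictions in the inference test come from. A quick sanity check, assuming model.py is importable from the current directory:

import torch

from model import HalfPlusTwoModel

model = HalfPlusTwoModel()
print(model(torch.tensor([[1.0], [2.0], [3.0]])))
# tensor([[2.5000], [3.0000], [3.5000]])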
106 changes: 106 additions & 0 deletions test/pytest/test_torch_xla.py
@@ -0,0 +1,106 @@
import glob
import json
import os
import subprocess
import time
from pathlib import Path

import pytest
from pkg_resources import packaging

try:
    import torch_xla

    TORCHXLA_AVAILABLE = packaging.version.parse(
        torch_xla.__version__
    ) >= packaging.version.parse("2.0")
except ImportError:
    TORCHXLA_AVAILABLE = False

CURR_FILE_PATH = Path(__file__).parent
TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data", "torch_xla")

MODEL_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.py")
EXTRA_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "compile.json")
CONFIG_PROPERTIES = os.path.join(TORCH_XLA_TEST_DATA_DIR, "config.properties")

SERIALIZED_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.pt")
MODEL_STORE_DIR = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model_store")
MODEL_NAME = "half_plus_two"


@pytest.mark.skipif(not TORCHXLA_AVAILABLE, reason="PyTorch/XLA is not installed")
class TestTorchXLA:
def teardown_class(self):
subprocess.run("torchserve --stop", shell=True, check=True)
time.sleep(10)

def test_archive_model_artifacts(self):
assert len(glob.glob(MODEL_FILE)) == 1
assert len(glob.glob(EXTRA_FILE)) == 1
assert len(glob.glob(CONFIG_PROPERTIES)) == 1
subprocess.run(
f"cd {TORCH_XLA_TEST_DATA_DIR} && python model.py", shell=True, check=True
)
subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
subprocess.run(
f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --extra-files {EXTRA_FILE} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
shell=True,
check=True,
)
assert len(glob.glob(SERIALIZED_FILE)) == 1
assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1

def test_start_torchserve(self):
subprocess.run(
f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR} --ts-config {CONFIG_PROPERTIES}",
shell=True,
check=True,
)
time.sleep(10)
assert len(glob.glob("logs/access_log.log")) == 1
assert len(glob.glob("logs/model_log.log")) == 1
assert len(glob.glob("logs/ts_log.log")) == 1

def test_server_status(self):
result = subprocess.run(
"curl http://localhost:8080/ping",
shell=True,
capture_output=True,
check=True,
)
expected_server_status_str = '{"status": "Healthy"}'
expected_server_status = json.loads(expected_server_status_str)
assert json.loads(result.stdout) == expected_server_status

def test_registered_model(self):
result = subprocess.run(
"curl http://localhost:8081/models",
shell=True,
capture_output=True,
check=True,
)
expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}'
expected_registered_model = json.loads(expected_registered_model_str)
assert json.loads(result.stdout) == expected_registered_model

def test_serve_inference(self):
request = "'{\"" 'instances"' ": [[1.0], [2.0], [3.0]]}'"
result = subprocess.run(
f'curl -s -X POST -H "Content-Type: application/json;" http://localhost:8080/predictions/half_plus_two -d {request}',
shell=True,
capture_output=True,
check=True,
)
expected_result_str = '{"predictions": [[2.5], [3.0], [3.5]]}'
expected_result = json.loads(expected_result_str)
assert json.loads(result.stdout) == expected_result

model_log_path = glob.glob("logs/model_log.log")[0]
with open(model_log_path, "rt") as model_log_file:
model_log = model_log_file.read()
assert "Compiled model with backend torchxla_trace_once" in model_log
assert "done compiler function torchxla_trace_once" in model_log
40 changes: 28 additions & 12 deletions ts/torch_handler/base_handler.py
@@ -25,9 +25,17 @@
else:
PROFILER_AVAILABLE = False

try:
import torch_xla.core.xla_model as xm

TORCHXLA_AVAILABLE = True
except ImportError:
TORCHXLA_AVAILABLE = False


logger = logging.getLogger(__name__)


# Possible values for backend in utils.py
def check_pt2_enabled():
try:
@@ -102,16 +110,17 @@ def initialize(self, context):
)

properties = context.system_properties
self.map_location = (
"cuda"
if torch.cuda.is_available() and properties.get("gpu_id") is not None
else "cpu"
)
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
if torch.cuda.is_available() and properties.get("gpu_id") is not None
else self.map_location
)
if torch.cuda.is_available() and properties.get("gpu_id") is not None:
self.map_location = "cuda"
self.device = torch.device(
self.map_location + ":" + str(properties.get("gpu_id"))
)
elif TORCHXLA_AVAILABLE:
self.device = xm.xla_device()
else:
self.map_location = "cpu"
self.device = torch.device(self.map_location)

self.manifest = context.manifest

model_dir = properties.get("model_dir")
@@ -181,7 +190,9 @@ def initialize(self, context):
# Compilation will delay your model initialization
try:
self.model = torch.compile(
self.model, backend=backend, mode="reduce-overhead"
self.model,
backend=backend,
mode="default" if TORCHXLA_AVAILABLE else "reduce-overhead",
)
logger.info(f"Compiled model with backend {backend}")
except:
@@ -245,7 +256,12 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
model_class = model_class_definitions[0]
model = model_class()
if model_pt_path:
state_dict = torch.load(model_pt_path, map_location=self.device)
map_location = (
None
if (TORCHXLA_AVAILABLE and self.map_location is None)
else self.device
)
state_dict = torch.load(model_pt_path, map_location=map_location)
model.load_state_dict(state_dict)
return model

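The device-selection change above prefers CUDA, falls back to the XLA device only when the torch_xla import succeeded, and otherwise uses CPU. A minimal standalone sketch of the same pattern, assuming torch_xla >= 2.0 is installed:

import torch

try:
    import torch_xla.core.xla_model as xm

    TORCHXLA_AVAILABLE = True
except ImportError:
    TORCHXLA_AVAILABLE = False

if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif TORCHXLA_AVAILABLE:
    device = xm.xla_device()  # e.g. an xla:0 device
else:
    device = torch.device("cpu")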
1 change: 1 addition & 0 deletions ts/utils/util.py
@@ -23,6 +23,7 @@ class PT2Backend(str, enum.Enum):
FX2TRT = "fx2trt"
ONNXRT = "onnxrt"
IPEX = "ipex"
TORCHXLA_TRACE_ONCE = "torchxla_trace_once"


logger = logging.getLogger(__name__)
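Because PT2Backend subclasses str, its members compare equal to the raw strings used in compile.json. A small hedged sketch of validating a configured backend name against the enum:

from ts.utils.util import PT2Backend  # module path as in this PR

backend = "torchxla_trace_once"
assert backend in {b.value for b in PT2Backend}  # fails if the name is not a known pt2 backend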