feat: add PyTorch/XLA support (#2182)
* enable torch_xla

* add venv to .gitignore

* add torch_xla 2.0 backend; update map_location and device; remove mark_step

* clear map_location for torchxla

* format

* add torch_xla test
morgandu authored Apr 11, 2023
1 parent c37da18 commit 4ea172d
Showing 7 changed files with 159 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ plugins/*/bin
 *.backup
 docs/sphinx/src/
 ts_scripts/spellcheck_conf/wordlist.dic
+venv/
 
 # Postman files
 test/artifacts/
1 change: 1 addition & 0 deletions test/pytest/test_data/torch_xla/compile.json
@@ -0,0 +1 @@
{"pt2" : "torchxla_trace_once"}
9 changes: 9 additions & 0 deletions test/pytest/test_data/torch_xla/config.properties
@@ -0,0 +1,9 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
model_store=/home/model-server/model-store
load_models=half_plus_two.mar
min_workers=1
max_workers=1
default_workers_per_model=1
service_envelope=json
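Note service_envelope=json: requests must wrap inputs as {"instances": ...} and responses come back as {"predictions": ...}, which is exactly the shape test_serve_inference below sends and checks. A hypothetical client call against a server started with this config:

import json
import urllib.request

# Assumes TorchServe is already running with the config above and has
# loaded half_plus_two.mar.
body = json.dumps({"instances": [[1.0], [2.0], [3.0]]}).encode()
req = urllib.request.Request(
    "http://localhost:8080/predictions/half_plus_two",
    data=body,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))  # expected: {"predictions": [[2.5], [3.0], [3.5]]}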
13 changes: 13 additions & 0 deletions test/pytest/test_data/torch_xla/model.py
@@ -0,0 +1,13 @@
import torch


class HalfPlusTwoModel(torch.nn.Module):
    def forward(self, *input_args):
        w = torch.tensor(0.5)
        b = torch.tensor(2.0)
        return torch.add(torch.multiply(w, input_args[0]), b)


if __name__ == "__main__":
    model = HalfPlusTwoModel()
    torch.save(model.state_dict(), "model.pt")
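The model computes 0.5 * x + 2, so its outputs are easy to verify by hand. A quick illustrative sanity check (assuming the file above is importable as model):

import torch

from model import HalfPlusTwoModel

print(HalfPlusTwoModel()(torch.tensor([[1.0], [2.0], [3.0]])))
# tensor([[2.5000], [3.0000], [3.5000]]) -- the values test_serve_inference expects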
106 changes: 106 additions & 0 deletions test/pytest/test_torch_xla.py
@@ -0,0 +1,106 @@
import glob
import json
import os
import subprocess
import time
from pathlib import Path

import pytest
from pkg_resources import packaging

try:
    import torch_xla

    TORCHXLA_AVAILABLE = (
        True
        if packaging.version.parse(torch_xla.__version__)
        >= packaging.version.parse("2.0")
        else False
    )
except:
    TORCHXLA_AVAILABLE = False

CURR_FILE_PATH = Path(__file__).parent
TORCH_XLA_TEST_DATA_DIR = os.path.join(CURR_FILE_PATH, "test_data")

MODEL_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.py")
EXTRA_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "compile.json")
CONFIG_PROPERTIES = os.path.join(TORCH_XLA_TEST_DATA_DIR, "config.properties")

SERIALIZED_FILE = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model.pt")
MODEL_STORE_DIR = os.path.join(TORCH_XLA_TEST_DATA_DIR, "model_store")
MODEL_NAME = "half_plus_two"


@pytest.mark.skipif(TORCHXLA_AVAILABLE == False, reason="PyTorch/XLA is not installed")
class TestTorchXLA:
    def teardown_class(self):
        subprocess.run("torchserve --stop", shell=True, check=True)
        time.sleep(10)

    def test_archive_model_artifacts(self):
        assert len(glob.glob(MODEL_FILE)) == 1
        assert len(glob.glob(EXTRA_FILE)) == 1
        assert len(glob.glob(CONFIG_PROPERTIES)) == 1
        subprocess.run(
            f"cd {TORCH_XLA_TEST_DATA_DIR} && python model.py", shell=True, check=True
        )
        subprocess.run(f"mkdir -p {MODEL_STORE_DIR}", shell=True, check=True)
        subprocess.run(
            f"torch-model-archiver --model-name {MODEL_NAME} --version 1.0 --model-file {MODEL_FILE} --serialized-file {SERIALIZED_FILE} --extra-files {EXTRA_FILE} --export-path {MODEL_STORE_DIR} --handler base_handler -f",
            shell=True,
            check=True,
        )
        assert len(glob.glob(SERIALIZED_FILE)) == 1
        assert len(glob.glob(os.path.join(MODEL_STORE_DIR, f"{MODEL_NAME}.mar"))) == 1

    def test_start_torchserve(self):
        subprocess.run(
            f"torchserve --start --ncs --models {MODEL_NAME}.mar --model-store {MODEL_STORE_DIR} --ts-config {CONFIG_PROPERTIES}",
            shell=True,
            check=True,
        )
        time.sleep(10)
        assert len(glob.glob("logs/access_log.log")) == 1
        assert len(glob.glob("logs/model_log.log")) == 1
        assert len(glob.glob("logs/ts_log.log")) == 1

    def test_server_status(self):
        result = subprocess.run(
            "curl http://localhost:8080/ping",
            shell=True,
            capture_output=True,
            check=True,
        )
        expected_server_status_str = '{"status": "Healthy"}'
        expected_server_status = json.loads(expected_server_status_str)
        assert json.loads(result.stdout) == expected_server_status

    def test_registered_model(self):
        result = subprocess.run(
            "curl http://localhost:8081/models",
            shell=True,
            capture_output=True,
            check=True,
        )
        expected_registered_model_str = '{"models": [{"modelName": "half_plus_two", "modelUrl": "half_plus_two.mar"}]}'
        expected_registered_model = json.loads(expected_registered_model_str)
        assert json.loads(result.stdout) == expected_registered_model

    def test_serve_inference(self):
        request = "'{\"" 'instances"' ": [[1.0], [2.0], [3.0]]}'"
        result = subprocess.run(
            f'curl -s -X POST -H "Content-Type: application/json;" http://localhost:8080/predictions/half_plus_two -d {request}',
            shell=True,
            capture_output=True,
            check=True,
        )
        expected_result_str = '{"predictions": [[2.5], [3.0], [3.5]]}'
        expected_result = json.loads(expected_result_str)
        assert json.loads(result.stdout) == expected_result

        model_log_path = glob.glob("logs/model_log.log")[0]
        with open(model_log_path, "rt") as model_log_file:
            model_log = model_log_file.read()
            assert "Compiled model with backend torchxla_trace_once" in model_log
            assert "done compiler function torchxla_trace_once" in model_log
40 changes: 28 additions & 12 deletions ts/torch_handler/base_handler.py
@@ -25,9 +25,17 @@
 else:
     PROFILER_AVAILABLE = False
 
+try:
+    import torch_xla.core.xla_model as xm
+
+    TORCHXLA_AVAILABLE = True
+except ImportError as error:
+    TORCHXLA_AVAILABLE = False
+
+
 logger = logging.getLogger(__name__)
 
 
 # Possible values for backend in utils.py
 def check_pt2_enabled():
     try:
@@ -102,16 +110,17 @@ def initialize(self, context):
         )
 
         properties = context.system_properties
-        self.map_location = (
-            "cuda"
-            if torch.cuda.is_available() and properties.get("gpu_id") is not None
-            else "cpu"
-        )
-        self.device = torch.device(
-            self.map_location + ":" + str(properties.get("gpu_id"))
-            if torch.cuda.is_available() and properties.get("gpu_id") is not None
-            else self.map_location
-        )
+        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
+            self.map_location = "cuda"
+            self.device = torch.device(
+                self.map_location + ":" + str(properties.get("gpu_id"))
+            )
+        elif TORCHXLA_AVAILABLE:
+            self.device = xm.xla_device()
+        else:
+            self.map_location = "cpu"
+            self.device = torch.device(self.map_location)
 
         self.manifest = context.manifest
 
         model_dir = properties.get("model_dir")
@@ -181,7 +190,9 @@ def initialize(self, context):
             # Compilation will delay your model initialization
             try:
                 self.model = torch.compile(
-                    self.model, backend=backend, mode="reduce-overhead"
+                    self.model,
+                    backend=backend,
+                    mode="default" if TORCHXLA_AVAILABLE else "reduce-overhead",
                 )
                 logger.info(f"Compiled model with backend {backend}")
             except:
@@ -245,7 +256,12 @@ def _load_pickled_model(self, model_dir, model_file, model_pt_path):
         model_class = model_class_definitions[0]
         model = model_class()
         if model_pt_path:
-            state_dict = torch.load(model_pt_path, map_location=self.device)
+            map_location = (
+                None
+                if (TORCHXLA_AVAILABLE and self.map_location is None)
+                else self.device
+            )
+            state_dict = torch.load(model_pt_path, map_location=map_location)
             model.load_state_dict(state_dict)
         return model

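Condensed, the new selection order in initialize() is: an explicit CUDA device when a gpu_id is assigned, then the XLA device when torch_xla imports, then CPU. A standalone restatement of that logic (an illustration of the diff above, not the commit's code; pick_device is a made-up helper name):

import torch

try:
    import torch_xla.core.xla_model as xm

    TORCHXLA_AVAILABLE = True
except ImportError:
    TORCHXLA_AVAILABLE = False


def pick_device(gpu_id):
    # CUDA wins when a gpu_id is assigned; XLA is next; CPU is the fallback.
    # On the XLA path map_location stays None, which is why _load_pickled_model
    # now passes map_location=None so torch.load keeps tensors on CPU until
    # the model is moved to the XLA device.
    if torch.cuda.is_available() and gpu_id is not None:
        return torch.device(f"cuda:{gpu_id}")
    if TORCHXLA_AVAILABLE:
        return xm.xla_device()
    return torch.device("cpu")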
1 change: 1 addition & 0 deletions ts/utils/util.py
@@ -23,6 +23,7 @@ class PT2Backend(str, enum.Enum):
     FX2TRT = "fx2trt"
     ONNXRT = "onnxrt"
     IPEX = "ipex"
+    TORCHXLA_TRACE_ONCE = "torchxla_trace_once"
 
 
 logger = logging.getLogger(__name__)
