Skip to content

Commit

Permalink
Fix tensor parallelism (#2741)
Browse files Browse the repository at this point in the history
* Fix tensor parallelism + introduce test

* Fix linting error

* Fix java formatting issues

* Skip test_parallelism on non-linux system
  • Loading branch information
mreso committed Oct 28, 2023
1 parent 20a6e8b commit 7f4419f
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private void encodeRequest(RequestInput req, ByteBuf out) {
out.writeInt(buf.length);
out.writeBytes(buf);

if (req.isCached()) {
if (req.isCachedInBackend()) {
out.writeInt(-1); // End of List
out.writeInt(-1); // End of List
return;
Expand All @@ -92,7 +92,6 @@ private void encodeRequest(RequestInput req, ByteBuf out) {
encodeParameter(input, out);
}
out.writeInt(-1); // End of List
req.setCached(true);
}

private void encodeParameter(InputParameter parameter, ByteBuf out) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,10 @@ public void setRequestBatch(List<RequestInput> requestBatch) {
public void addRequest(RequestInput req) {
batch.add(req);
}

public void setCachedInBackend(boolean cached) {
for (RequestInput input : batch) {
input.setCachedInBackend(cached);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ public void setClientExpireTS(long clientTimeoutInMills) {
}
}

public boolean isCached() {
public boolean isCachedInBackend() {
return cached;
}

public void setCached(boolean cached) {
public void setCachedInBackend(boolean cached) {
this.cached = cached;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.pytorch.serve.util.codec.ModelResponseDecoder;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.InputParameter;
import org.pytorch.serve.util.messages.ModelInferenceRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
Expand Down Expand Up @@ -208,6 +209,9 @@ public void run() {
for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) {
backendChannel.get(i).writeAndFlush(req).sync();
}
if (req instanceof ModelInferenceRequest) {
((ModelInferenceRequest) req).setCachedInBackend(true);
}

ModelWorkerResponse reply = null;

Expand Down
148 changes: 148 additions & 0 deletions test/pytest/test_parallelism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import json
import platform
import shutil
from argparse import Namespace
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import requests
import test_utils

CURR_FILE_PATH = Path(__file__).parent
REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent

MODEL_PY = """
import torch
import torch.nn as nn
class Foo(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
torch.distributed.all_reduce(x)
return x
"""

HANDLER_PY = """
import os
import torch
from ts.torch_handler.base_handler import BaseHandler
class FooHandler(BaseHandler):
def initialize(self, ctx):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("gloo")
torch.set_default_device("cpu")
super().initialize(ctx)
def preprocess(self, data):
return torch.as_tensor(int(data[0].get('body').decode('utf-8')), device=self.device)
def postprocess(self, x):
return [x.item()]
"""

MODEL_CONFIG_YAML = f"""
#frontend settings
parallelType: "tp"
deviceType: "cpu"
torchrun:
nproc-per-node: 4
"""


@pytest.fixture(scope="module")
def model_name():
yield "tp_model"


@pytest.fixture(scope="module")
def work_dir(tmp_path_factory, model_name):
return Path(tmp_path_factory.mktemp(model_name))


@pytest.fixture(scope="module", name="mar_file_path")
def create_mar_file(work_dir, model_archiver, model_name):
mar_file_path = work_dir.joinpath(model_name + ".mar")

model_config_yaml_file = work_dir / "model_config.yaml"
model_config_yaml_file.write_text(MODEL_CONFIG_YAML)

model_py_file = work_dir / "model.py"
model_py_file.write_text(MODEL_PY)

handler_py_file = work_dir / "handler.py"
handler_py_file.write_text(HANDLER_PY)

args = Namespace(
model_name=model_name,
version="1.0",
serialized_file=None,
model_file=model_py_file.as_posix(),
handler=handler_py_file.as_posix(),
extra_files=None,
export_path=work_dir,
requirements_file=None,
runtime="python",
force=False,
archive_format="default",
config_file=model_config_yaml_file.as_posix(),
)

mock = MagicMock()
mock.parse_args = MagicMock(return_value=args)
with patch("archiver.ArgParser.export_model_args_parser", return_value=mock):
model_archiver.generate_model_archive()

assert mar_file_path.exists()

yield mar_file_path.as_posix()

# Clean up files
mar_file_path.unlink(missing_ok=True)


@pytest.fixture(scope="module", name="model_name")
def register_model(mar_file_path, model_store, torchserve):
"""
Register the model in torchserve
"""
shutil.copy(mar_file_path, model_store)

file_name = Path(mar_file_path).name

model_name = Path(file_name).stem

params = (
("model_name", model_name),
("url", file_name),
("initial_workers", "1"),
("synchronous", "true"),
("batch_size", "1"),
)

test_utils.reg_resp = test_utils.register_model_with_params(params)

yield model_name

test_utils.unregister_model(model_name)


@pytest.mark.skipif(
platform.system() != "Linux", reason="Skipping test on non-Linux system"
)
def test_tp_inference(model_name):
"""
Full circle test with torchserve
"""

response = requests.post(
url=f"http://localhost:8080/predictions/{model_name}", data=json.dumps(42)
)

assert int(response.text) == 4 * 42

assert response.status_code == 200

0 comments on commit 7f4419f

Please sign in to comment.