Fixed the priority of parameters defined in the register curl command vs. model-config.yaml #2858

Merged: 4 commits, Dec 20, 2023
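In short: parameters passed explicitly on the register call (batch_size, max_batch_delay) now take priority; anything left unset falls back to the model-config.yaml packaged in the model archive (and then to config.properties), and only after that to the built-in defaults (batch size 1, max batch delay 100 ms). Below is a quick illustration against the management API, assuming TorchServe's default management address; the endpoint and parameters mirror the pytest added in this PR, but the snippet itself is only a sketch and not part of the change:

import requests

MANAGEMENT_API = "http://localhost:8081"  # default TorchServe management address

# foo.mar is assumed to be packaged with batchSize: 4 in its model-config.yaml.
# Passing batch_size explicitly on the register call overrides that value ...
requests.post(
    f"{MANAGEMENT_API}/models",
    params={
        "url": "foo.mar",
        "model_name": "foo",
        "initial_workers": 2,
        "synchronous": "true",
        "batch_size": 2,  # describe should now report batchSize == 2
    },
)

# ... while omitting batch_size would let the yaml value (4) win instead.
print(requests.get(f"{MANAGEMENT_API}/models/foo/1.0").json())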
@@ -8,6 +8,9 @@

/** Register Model Request for Model server */
public class RegisterModelRequest {
public static final Integer DEFAULT_BATCH_SIZE = 1;
public static final Integer DEFAULT_MAX_BATCH_DELAY = 100;

@SerializedName("model_name")
private String modelName;

@@ -42,15 +45,18 @@ public RegisterModelRequest(QueryStringDecoder decoder) {
modelName = NettyUtils.getParameter(decoder, "model_name", null);
runtime = NettyUtils.getParameter(decoder, "runtime", null);
handler = NettyUtils.getParameter(decoder, "handler", null);
batchSize = NettyUtils.getIntParameter(decoder, "batch_size", 1);
maxBatchDelay = NettyUtils.getIntParameter(decoder, "max_batch_delay", 100);
batchSize = NettyUtils.getIntParameter(decoder, "batch_size", -1 * DEFAULT_BATCH_SIZE);
maxBatchDelay =
NettyUtils.getIntParameter(
decoder, "max_batch_delay", -1 * DEFAULT_MAX_BATCH_DELAY);
initialWorkers =
NettyUtils.getIntParameter(
decoder,
"initial_workers",
ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel());
synchronous = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", "true"));
responseTimeout = NettyUtils.getIntParameter(decoder, "response_timeout", -1);
responseTimeout =
NettyUtils.getIntParameter(decoder, "response_timeout", -1 * DEFAULT_BATCH_SIZE);
modelUrl = NettyUtils.getParameter(decoder, "url", null);
s3SseKms = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "s3_sse_kms", "false"));
}
@@ -59,8 +65,10 @@ public RegisterModelRequest(org.pytorch.serve.grpc.management.RegisterModelReque
modelName = GRPCUtils.getRegisterParam(request.getModelName(), null);
runtime = GRPCUtils.getRegisterParam(request.getRuntime(), null);
handler = GRPCUtils.getRegisterParam(request.getHandler(), null);
batchSize = GRPCUtils.getRegisterParam(request.getBatchSize(), 1);
maxBatchDelay = GRPCUtils.getRegisterParam(request.getMaxBatchDelay(), 100);
batchSize = GRPCUtils.getRegisterParam(request.getBatchSize(), -1 * DEFAULT_BATCH_SIZE);
maxBatchDelay =
GRPCUtils.getRegisterParam(
request.getMaxBatchDelay(), -1 * DEFAULT_MAX_BATCH_DELAY);
initialWorkers =
GRPCUtils.getRegisterParam(
request.getInitialWorkers(),
@@ -72,8 +80,8 @@ public RegisterModelRequest(org.pytorch.serve.grpc.management.RegisterModelReque
}

public RegisterModelRequest() {
batchSize = 1;
maxBatchDelay = 100;
batchSize = -1 * DEFAULT_BATCH_SIZE;
maxBatchDelay = -1 * DEFAULT_MAX_BATCH_DELAY;
synchronous = true;
initialWorkers = ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel();
responseTimeout = -1;
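The new defaults are negative sentinels (-1 * DEFAULT_BATCH_SIZE, -1 * DEFAULT_MAX_BATCH_DELAY), so downstream code can tell "the caller never set this" apart from "the caller explicitly asked for the default value". A rough Python illustration of the same idea; register_param and its arguments are hypothetical and only mirror the Java logic above:

DEFAULT_BATCH_SIZE = 1
DEFAULT_MAX_BATCH_DELAY = 100

def register_param(raw_value, default):
    # Hypothetical helper: a missing parameter becomes a negative sentinel
    # (-1 * default) rather than the default itself, so later code can still
    # tell that the caller did not set it explicitly.
    return int(raw_value) if raw_value is not None else -1 * default

batch_size = register_param(None, DEFAULT_BATCH_SIZE)              # -> -1 (unset)
max_batch_delay = register_param("200", DEFAULT_MAX_BATCH_DELAY)   # -> 200 (explicit)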
@@ -30,6 +30,7 @@
import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
import org.pytorch.serve.http.ConflictStatusException;
import org.pytorch.serve.http.InvalidModelVersionException;
import org.pytorch.serve.http.messages.RegisterModelRequest;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.messages.EnvironmentUtils;
@@ -300,43 +301,47 @@ private Model createModel(
boolean isWorkflowModel) {
Model model = new Model(archive, configManager.getJobQueueSize());

if (archive.getModelConfig() != null) {
int marBatchSize = archive.getModelConfig().getBatchSize();
batchSize =
marBatchSize > 0
? marBatchSize
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
batchSize);
} else {
batchSize =
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
batchSize);
if (batchSize == -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE) {
if (archive.getModelConfig() != null) {
int marBatchSize = archive.getModelConfig().getBatchSize();
batchSize =
marBatchSize > 0
? marBatchSize
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
RegisterModelRequest.DEFAULT_BATCH_SIZE);
} else {
batchSize =
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
RegisterModelRequest.DEFAULT_BATCH_SIZE);
}
}
model.setBatchSize(batchSize);

if (archive.getModelConfig() != null) {
int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay();
maxBatchDelay =
marMaxBatchDelay > 0
? marMaxBatchDelay
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
maxBatchDelay);
} else {
maxBatchDelay =
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
maxBatchDelay);
if (maxBatchDelay == -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY) {
if (archive.getModelConfig() != null) {
int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay();
maxBatchDelay =
marMaxBatchDelay > 0
? marMaxBatchDelay
: configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
} else {
maxBatchDelay =
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
}
}
model.setMaxBatchDelay(maxBatchDelay);

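createModel now consults the archive's model-config.yaml (and then config.properties) only when the incoming value still equals the sentinel, i.e. when the register request did not set it. The precedence chain, sketched in Python for clarity; resolve_batch_size and its arguments are hypothetical stand-ins for the Java above:

DEFAULT_BATCH_SIZE = 1

def resolve_batch_size(requested, archive_yaml_value, config_properties_value):
    # Mirrors the precedence in createModel (illustrative only):
    # 1. value passed on the register call (anything other than the sentinel)
    # 2. batchSize from the archive's model-config.yaml, if positive
    # 3. per-model value from config.properties, if present
    # 4. built-in default
    if requested != -1 * DEFAULT_BATCH_SIZE:
        return requested
    if archive_yaml_value and archive_yaml_value > 0:
        return archive_yaml_value
    if config_properties_value and config_properties_value > 0:
        return config_properties_value
    return DEFAULT_BATCH_SIZE

assert resolve_batch_size(2, 4, None) == 2    # explicit register parameter wins
assert resolve_batch_size(-1, 4, None) == 4   # falls back to model-config.yaml
assert resolve_batch_size(-1, 0, None) == 1   # nothing configured -> default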
@@ -207,6 +207,7 @@ private void addThreads(
List<WorkerThread> threads, Model model, int count, CompletableFuture<Integer> future) {
WorkerStateListener listener = new WorkerStateListener(future, count);
int maxGpu = model.getNumCores();
int stride = model.getParallelLevel() > 0 ? model.getParallelLevel() : 1;
for (int i = 0; i < count; ++i) {
int gpuId = -1;

@@ -215,12 +216,7 @@
gpuId =
model.getGpuCounter()
.getAndAccumulate(
maxGpu,
(prev, maxGpuId) ->
(prev + model.getParallelLevel() > 0
? model.getParallelLevel()
: 1)
% maxGpuId);
stride, (prev, myStride) -> (prev + myStride) % maxGpu);
if (model.getParallelLevel() == 0) {
gpuId = model.getDeviceIds().get(gpuId);
}
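The rewritten lambda also fixes an operator-precedence problem in the old expression: since + binds tighter than > in Java, the old accumulator reduced to the parallel level (or 1) modulo maxGpuId instead of advancing from prev. The new code advances the shared GPU counter by a fixed stride, wrapping modulo the number of visible GPUs. The same round-robin idea in Python, with hypothetical names; it only illustrates the accumulator, not TorchServe's worker management:

import threading

class GpuRoundRobin:
    # Illustrative counterpart of the AtomicInteger gpuCounter: each new worker
    # takes the current index, then the counter advances by `stride` (the
    # parallel level, or 1) modulo the number of GPUs.
    def __init__(self, max_gpu, stride=1):
        self.max_gpu = max_gpu
        self.stride = max(stride, 1)
        self._next = 0
        self._lock = threading.Lock()

    def next_gpu(self):
        with self._lock:
            gpu_id = self._next
            self._next = (self._next + self.stride) % self.max_gpu
            return gpu_id

rr = GpuRoundRobin(max_gpu=4, stride=2)
print([rr.next_gpu() for _ in range(4)])  # [0, 2, 0, 2]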
1 change: 1 addition & 0 deletions requirements/developer.txt
@@ -18,3 +18,4 @@ intel_extension_for_pytorch==2.1.0; sys_platform != 'win32' and sys_platform !=
onnxruntime==1.15.0
googleapis-common-protos
onnx==1.14.1
orjson
142 changes: 142 additions & 0 deletions test/pytest/test_model_config.py
@@ -0,0 +1,142 @@
import shutil
from pathlib import Path
from unittest.mock import patch

import pytest
import test_utils
from model_archiver import ModelArchiverConfig

CURR_FILE_PATH = Path(__file__).parent
REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent

MODEL_PY = """
import torch
import torch.nn as nn

class Foo(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x
"""

HANDLER_PY = """
import os
import torch
from ts.torch_handler.base_handler import BaseHandler

class FooHandler(BaseHandler):
def initialize(self, ctx):
super().initialize(ctx)

def preprocess(self, data):
return torch.as_tensor(int(data[0].get('body').decode('utf-8')), device=self.device)

def postprocess(self, x):
return [x.item()]
"""

MODEL_CONFIG_YAML = f"""
#frontend settings
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 4
maxBatchDelay: 100
batchSize: 4
"""


@pytest.fixture(scope="module")
def model_name():
yield "foo"


@pytest.fixture(scope="module")
def work_dir(tmp_path_factory, model_name):
return Path(tmp_path_factory.mktemp(model_name))


@pytest.fixture(scope="module", name="mar_file_path")
def create_mar_file(work_dir, model_archiver, model_name):
mar_file_path = work_dir.joinpath(model_name + ".mar")

model_config_yaml_file = work_dir / "model_config.yaml"
model_config_yaml_file.write_text(MODEL_CONFIG_YAML)

model_py_file = work_dir / "model.py"
model_py_file.write_text(MODEL_PY)

handler_py_file = work_dir / "handler.py"
handler_py_file.write_text(HANDLER_PY)

config = ModelArchiverConfig(
model_name=model_name,
version="1.0",
serialized_file=None,
model_file=model_py_file.as_posix(),
handler=handler_py_file.as_posix(),
extra_files=None,
export_path=work_dir,
requirements_file=None,
runtime="python",
force=False,
archive_format="default",
config_file=model_config_yaml_file.as_posix(),
)

with patch("archiver.ArgParser.export_model_args_parser", return_value=config):
model_archiver.generate_model_archive()

assert mar_file_path.exists()

yield mar_file_path.as_posix()

# Clean up files
mar_file_path.unlink(missing_ok=True)


def register_model(mar_file_path, model_store, params, torchserve):
shutil.copy(mar_file_path, model_store)

file_name = Path(mar_file_path).name

model_name = Path(file_name).stem

params = params + (
("model_name", model_name),
("url", file_name),
)

test_utils.reg_resp = test_utils.register_model_with_params(params)
return model_name


def test_register_model_with_batch_size(mar_file_path, model_store, torchserve):
params = (
("initial_workers", "2"),
("synchronous", "true"),
("batch_size", "2"),
)

model_name = register_model(mar_file_path, model_store, params, torchserve)

describe_resp = test_utils.describe_model(model_name, "1.0")

assert describe_resp[0]["batchSize"] == 2

test_utils.unregister_model(model_name)


def test_register_model_without_batch_size(mar_file_path, model_store, torchserve):
params = (
("initial_workers", "2"),
("synchronous", "true"),
)
model_name = register_model(mar_file_path, model_store, params, torchserve)

describe_resp = test_utils.describe_model(model_name, "1.0")

assert describe_resp[0]["batchSize"] == 4

test_utils.unregister_model(model_name)
8 changes: 8 additions & 0 deletions test/pytest/test_utils.py
@@ -12,6 +12,7 @@
from queue import Queue
from subprocess import PIPE, STDOUT, Popen

import orjson
Collaborator: nit: Do we need this? Can't we achieve the same with json?

Collaborator (author): According to its description, orjson is the fastest JSON library. I'm also considering whether we should adopt orjson in our handlers for JSON input.

import requests

# To help discover margen modules
@@ -125,6 +126,13 @@ def unregister_model(model_name):
return response


def describe_model(model_name, version):
response = requests.get(
"http://localhost:8081/models/{}/{}".format(model_name, version)
)
return orjson.loads(response.content)


def delete_mar_file_from_model_store(model_store=None, model_mar=None):
model_store = (
model_store
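As the review thread above notes, describe_model would work equally well with the standard library; orjson is used here as a (faster) drop-in for parsing. A stdlib-only equivalent, shown purely for comparison and not part of the PR:

import json

import requests

def describe_model_stdlib(model_name, version):
    # Same request as describe_model above, parsed with the built-in json module.
    response = requests.get(
        "http://localhost:8081/models/{}/{}".format(model_name, version)
    )
    return json.loads(response.content)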