Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test template #3140

Merged
merged 4 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,21 @@ Your contributions will fall into two categories:

Once you finish implementing a feature or bug-fix, please send a Pull Request to https://github.com/pytorch/serve.

For more non-technical guidance about how to contribute to PyTorch, see the Contributing Guide.
New features should always be covered by at least one integration test.
For guidance please have a look at our [current suite of pytest tests](https://github.com/pytorch/serve/tree/master/test/pytest) and use a test that covers a similar use case to your new feature as a starting point.
A simplified version of an example test can be found in the [mnist template test](https://github.com/pytorch/serve/blob/master/test/pytest/test_mnist_template.py) which shows how to create a mar file on the fly and register it with TorchServe from within a test.
You can run most tests by simply executing:
```bash
pytest test/pytest/test_mnist_template.py
```
To have a look at the TorchServe and/or test output add `-s` like this:
```bash
pytest -s test/pytest/test_mnist_template.py
```
To run only a subset or a single test from a file use `-k` like this:
```bash
pytest test/pytest/test_mnist_template.py -k test_mnist_template
```

### Install TorchServe for development

Expand All @@ -50,7 +64,7 @@ Ensure that you have `python3` installed, and the user has access to the site-pa

Run the following script from the top of the source directory.

NOTE: This script force reinstalls `torchserve`, `torch-model-archiver` and `torch-workflow-archiver` if existing installations are found
NOTE: This script force re-installs `torchserve`, `torch-model-archiver` and `torch-workflow-archiver` if existing installations are found

#### For Debian Based Systems/ MacOS

Expand Down
120 changes: 120 additions & 0 deletions test/pytest/test_mnist_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import shutil
from pathlib import Path
from unittest.mock import patch

import pytest
import requests
import test_utils
from model_archiver import ModelArchiverConfig

# Resolve repository-relative paths from this test file's location:
# test/pytest/ -> two levels up is the repo root.
CURR_FILE_PATH = Path(__file__).parent
REPO_ROOT_DIR = CURR_FILE_PATH.parents[1]
# Pre-existing assets reused from the repo's MNIST example.
config_file = REPO_ROOT_DIR / "test/resources/config_token.properties"
data_file_zero = REPO_ROOT_DIR / "test/pytest/test_data/0.png"
handler_py_file = REPO_ROOT_DIR / "examples/image_classifier/mnist/mnist_handler.py"
model_py_file = REPO_ROOT_DIR / "examples/image_classifier/mnist/mnist.py"
model_pt_file = REPO_ROOT_DIR / "examples/image_classifier/mnist/mnist_cnn.pt"


# NOTE(review): HANDLER_PY is not referenced anywhere in this file — it appears
# to be a template snippet for writing a custom handler on the fly. The class
# body below also looks like it lost its indentation; confirm before reusing.
HANDLER_PY = """
import torch
from ts.torch_handler.base_handler import BaseHandler

class customHandler(BaseHandler):

def initialize(self, context):
super().initialize(context)
"""

# Minimal TorchServe model-config written into the .mar by create_mar_file.
MODEL_CONFIG_YAML = """
#frontend settings
# TorchServe frontend parameters
minWorkers: 1
batchSize: 1
maxWorkers: 1
"""


@pytest.fixture(scope="module")
def model_name():
    """Module-scoped fixture: the name under which the test model is archived."""
    name = "some_model"
    yield name


@pytest.fixture(scope="module")
def work_dir(tmp_path_factory, model_name):
    """Module-scoped temporary working directory, named after the model."""
    tmp_dir = tmp_path_factory.mktemp(model_name)
    return Path(tmp_dir)


@pytest.fixture(scope="module", name="mar_file_path")
def create_mar_file(work_dir, model_archiver, model_name):
    """Build a model archive (.mar) for the MNIST example on the fly.

    Writes a minimal model-config YAML into *work_dir*, drives the model
    archiver with a patched argument parser, yields the path of the
    generated .mar file and deletes it again on teardown.
    """
    mar_file_path = work_dir / f"{model_name}.mar"

    model_config_yaml_file = work_dir / "model_config.yaml"
    model_config_yaml_file.write_text(MODEL_CONFIG_YAML)

    archiver_options = dict(
        model_name=model_name,
        version="1.0",
        serialized_file=model_pt_file.as_posix(),
        model_file=model_py_file.as_posix(),
        handler=handler_py_file.as_posix(),
        extra_files=None,
        export_path=work_dir,
        requirements_file=None,
        runtime="python",
        force=False,
        archive_format="default",
        config_file=model_config_yaml_file.as_posix(),
    )
    config = ModelArchiverConfig(**archiver_options)

    # The archiver normally parses CLI args; feed it our config instead.
    with patch("archiver.ArgParser.export_model_args_parser", return_value=config):
        model_archiver.generate_model_archive()

    assert mar_file_path.exists()

    yield mar_file_path.as_posix()

    # Teardown: remove the generated archive.
    mar_file_path.unlink(missing_ok=True)


@pytest.fixture(scope="module", name="model_name")
def register_model(mar_file_path, model_store, torchserve):
    """Copy the mar file into the model store and register it with TorchServe.

    Yields the registered model's name; unregisters it on teardown.
    """
    shutil.copy(mar_file_path, model_store)

    mar_name = Path(mar_file_path).name
    model_name = Path(mar_name).stem

    # Register synchronously with a single worker and batch size 1.
    params = (
        ("model_name", model_name),
        ("url", mar_name),
        ("initial_workers", "1"),
        ("synchronous", "true"),
        ("batch_size", "1"),
    )

    test_utils.reg_resp = test_utils.register_model_with_params(params)

    yield model_name

    # Teardown: remove the model from the running TorchServe instance.
    test_utils.unregister_model(model_name)


def test_mnist_template(model_name):
    """End-to-end check against a running TorchServe instance.

    First describes the registered model via the management API, then posts
    an image of the digit zero to the inference API and expects the handler
    to answer with the string "0".
    """
    response = requests.get(f"http://localhost:8081/models/{model_name}")
    assert response.status_code == 200, "Describe Failed"

    with open(data_file_zero, "rb") as f:
        response = requests.post(
            f"http://localhost:8080/predictions/{model_name}",
            data=f,
        )

    # Check the transport-level result first so an HTTP failure produces a
    # clear message instead of a misleading "Wrong prediction".
    assert response.status_code == 200, "Prediction request failed"
    assert response.content.decode("utf-8") == "0", "Wrong prediction"
1 change: 1 addition & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1238,3 +1238,4 @@ lora
vllm
sql
TimeUnit
Aopen
Loading