Skip to content

Commit

Permalink
refactor(test): rewrite wrapper test to be more complete
Browse files Browse the repository at this point in the history
It now validates that the protocol version is passed properly as well
as failure if an invalid protocol is given.
  • Loading branch information
sgaist committed Feb 27, 2024
1 parent 8710263 commit effb601
Showing 1 changed file with 41 additions and 14 deletions.
55 changes: 41 additions & 14 deletions test/pytest/test_kserve_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,17 @@ def generate_protobuf_code():
subprocess.run(" ".join(cmd), shell=True, check=True)


@pytest.fixture
def kserve_wrapper(protobuf):
@pytest.mark.parametrize(
"protocol_version,expected,is_valid",
[
(None, "v1", True),
("v1", "v1", True),
("v2", "v2", True),
("grpc-v2", "grpc-v2", True),
("invalid", "unused", False),
],
)
def test_kserve_wrapper(protobuf, protocol_version, expected, is_valid):
kserve_config = str(CONFIG_DIR / "kserve_config.properties")
with open(kserve_config, "wt") as config:
config.write(KSERVE_CONFIG_CONTENT)
Expand All @@ -138,28 +147,46 @@ def kserve_wrapper(protobuf):
env = os.environ.copy()
env["CONFIG_PATH"] = kserve_config

if protocol_version is not None:
env["PROTOCOL_VERSION"] = protocol_version

error_occured = None

p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=STDOUT, env=env)
for line in p.stdout:
print(line.decode("utf8").strip())
cleaned_line = str(line).strip()
cleaned_line = str(line.decode()).strip()
if "Started server process" in cleaned_line:
wrapper_pid = int(
cleaned_line[cleaned_line.rfind("[") + 1 : cleaned_line.rfind("]")]
)
if "Protocol version is" in cleaned_line:
actual_version = cleaned_line
if "ValueError" in cleaned_line:
error_occured = cleaned_line
if "Application startup complete" in cleaned_line:
break
yield

os.kill(wrapper_pid, signal.SIGINT)
if error_occured is None:
json_data = {
"inputs": [
{"name": "uuid", "shape": -1, "datatype": "BYTES", "data": ["test"]}
]
}

response = requests.post(
f"http://127.0.0.1:8080/v1/models/test:predict", json=json_data, timeout=120
)

def test_inference(kserve_wrapper):
json_data = {
"inputs": [{"name": "uuid", "shape": -1, "datatype": "BYTES", "data": ["test"]}]
}
os.kill(wrapper_pid, signal.SIGINT)

response = requests.post(
f"http://127.0.0.1:8080/v1/models/test:predict", json=json_data, timeout=120
)
assert actual_version.endswith(expected)
assert response.status_code == 200

assert response.status_code == 200
else:
if not is_valid:
assert (
error_occured
== f"ValueError: '{protocol_version}' is not a valid PredictorProtocol"
)
else:
assert False, f"Unexpected error: {error_occured}"

0 comments on commit effb601

Please sign in to comment.