-
Notifications
You must be signed in to change notification settings - Fork 834
/
Copy pathtest_utils.py
39 lines (32 loc) · 1.22 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import pytest
from requests_mock import Mocker
from alibiexplainer.utils import (
SELDON_PREDICTOR_URL_FORMAT,
SELDON_SKIP_LOGGING_HEADER,
TENSORFLOW_PREDICTOR_URL_FORMAT,
Protocol,
construct_predict_fn,
)
@pytest.mark.parametrize("protocol", [Protocol.seldon_http, Protocol.tensorflow_http])
def test_construct_predict_fn(protocol: Protocol, requests_mock: Mocker):
    """Check the predict fn built by ``construct_predict_fn`` for each HTTP protocol.

    A POST endpoint is mocked at the URL the predict fn is expected to call
    for the given protocol. The predict fn is invoked once, and the test checks:

    * the returned value equals the mocked prediction payload;
    * exactly one HTTP request was made;
    * the request carried ``SELDON_SKIP_LOGGING_HEADER`` set to ``"true"``.
    """
    predictor_host = "fake-endpoint.com"
    model_name = "foo"

    # The Seldon URL format is filled with the host only; the TensorFlow
    # format is filled with both the host and the model name.
    if protocol == Protocol.tensorflow_http:
        predictor_endpoint = TENSORFLOW_PREDICTOR_URL_FORMAT.format(
            predictor_host, model_name
        )
    else:
        predictor_endpoint = SELDON_PREDICTOR_URL_FORMAT.format(predictor_host)

    res_value = [[7]]
    # One response body holds both the Seldon ("data" -> "ndarray") and the
    # TensorFlow ("predictions") layouts, so the same mock serves either protocol.
    requests_mock.post(
        predictor_endpoint,
        json={"data": {"ndarray": res_value}, "predictions": res_value},
    )

    predict_fn = construct_predict_fn(
        predictor_host, model_name=model_name, protocol=protocol
    )
    res = predict_fn(arr=[[0, 1, 2, 3]])

    # `res == np.array(...)` is an element-wise comparison; asserting on it only
    # works for single-element results and raises "truth value is ambiguous"
    # otherwise. Compare the arrays explicitly instead.
    np.testing.assert_array_equal(res, np.array(res_value))
    assert requests_mock.call_count == 1
    assert SELDON_SKIP_LOGGING_HEADER in requests_mock.last_request.headers
    assert requests_mock.last_request.headers[SELDON_SKIP_LOGGING_HEADER] == "true"