Skip to content

Commit

Permalink
In PredictService, allow handling data as a 1d array instead of 2d
Browse files Browse the repository at this point in the history
When working with text, scikit-learn and other libraries will sometimes
expect the input to be a 1d array of strings rather than a 2d array.
Setting the new *unwrap_sample* parameter to true makes PredictService
use this representation.
  • Loading branch information
dnouri committed May 29, 2018
1 parent 1909827 commit 8df98ee
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 4 deletions.
25 changes: 21 additions & 4 deletions palladium/server.py
Expand Up @@ -59,9 +59,15 @@ class PredictService:
}

def __init__(
self, mapping, params=(), entry_point='/predict',
decorator_list_name='predict_decorators',
predict_proba=False, **kwargs):
self,
mapping,
params=(),
entry_point='/predict',
decorator_list_name='predict_decorators',
predict_proba=False,
unwrap_sample=False,
**kwargs
):
"""
:param mapping:
A list of query parameters and their type that should be
Expand All @@ -85,12 +91,19 @@ def __init__(
Instead of returning a single class (the default), when
*predict_proba* is set to true, the result will instead
contain a list of class probabilities.
:param unwrap_sample:
When working with text, scikit-learn and others will
sometimes expect the input to be a 1d array of strings
rather than a 2d array. Setting *unwrap_sample* to true
will use this representation.
"""
self.mapping = mapping
self.params = params
self.entry_point = entry_point
self.decorator_list_name = decorator_list_name
self.predict_proba = predict_proba
self.unwrap_sample = unwrap_sample
vars(self).update(kwargs)

def initialize_component(self, config):
Expand Down Expand Up @@ -132,7 +145,11 @@ def sample_from_data(self, model, data):
for key, type_name in self.mapping:
value_type = self.types[type_name]
values.append(value_type(data[key]))
return np.array(values, dtype=object)
if self.unwrap_sample:
assert len(values) == 1
return np.array(values[0])
else:
return np.array(values, dtype=object)

def params_from_data(self, model, data):
"""Retrieve additional parameters (keyword arguments) for
Expand Down
61 changes: 61 additions & 0 deletions palladium/tests/test_server.py
Expand Up @@ -171,6 +171,67 @@ def test_sample_from_data(self, PredictService):
assert sample[0] == 'myflower'
assert sample[1] == 3

def test_unwrap_sample_get(self, PredictService, flask_app):
    """GET request with *unwrap_sample* enabled passes a 1d array.

    The service is configured with a single ``str`` query parameter.
    The mocked model's ``predict`` must receive a 1d array holding
    exactly that text, and the JSON response must carry the mocked
    prediction as a scalar ``result``.
    """
    predict_service = PredictService(
        mapping=[('text', 'str')],
        unwrap_sample=True,
    )
    model = Mock()
    model.predict.return_value = np.array([1])
    with flask_app.test_request_context():
        request = Mock(
            args={'text': 'Hi this is text'},
            method='GET',
        )
        resp = predict_service(model, request)

    # Inspect the captured positional argument directly instead of
    # using ``assert_called_with(ndarray)``: Mock compares arguments
    # with ``==``, which for ndarrays is elementwise and only yields
    # a usable truth value when the array has exactly one element.
    sample = model.predict.call_args[0][0]
    assert sample.ndim == 1
    assert sample.tolist() == ['Hi this is text']

    resp_data = json.loads(resp.get_data(as_text=True))
    assert resp.status_code == 200
    assert resp_data == {
        "metadata": {
            "status": "OK",
            "error_code": 0,
        },
        "result": 1,
    }

def test_unwrap_sample_post(self, PredictService, flask_app):
    """POST request with *unwrap_sample* enabled passes a 1d array.

    Two JSON samples are posted at once.  The mocked model's
    ``predict`` must receive a single 1d array containing both texts
    in order, and the JSON response must list both mocked predictions
    under ``result``.
    """
    predict_service = PredictService(
        mapping=[('text', 'str')],
        unwrap_sample=True,
    )
    model = Mock()
    model.predict.return_value = np.array([1, 2])
    with flask_app.test_request_context():
        request = Mock(
            json=[
                {'text': 'First piece of text'},
                {'text': 'Second piece of text'},
            ],
            method='POST',
            mimetype='application/json',
        )
        resp = predict_service(model, request)

    # ``call_args[0]`` is the tuple of positional arguments; the
    # array handed to ``predict`` is its first element.  Comparing
    # the tuple itself against an ndarray would only succeed through
    # broadcasting, so compare the argument's contents explicitly.
    samples = model.predict.call_args[0][0]
    assert samples.ndim == 1
    assert samples.tolist() == [
        'First piece of text',
        'Second piece of text',
    ]

    resp_data = json.loads(resp.get_data(as_text=True))
    assert resp.status_code == 200
    assert resp_data == {
        "metadata": {
            "status": "OK",
            "error_code": 0,
        },
        "result": [1, 2],
    }

def test_probas(self, PredictService, flask_app):
model = Mock()
model.predict_proba.return_value = np.array([[0.1, 0.5, math.pi]])
Expand Down

0 comments on commit 8df98ee

Please sign in to comment.