diff --git a/palladium/server.py b/palladium/server.py index ac44776..1257b6c 100644 --- a/palladium/server.py +++ b/palladium/server.py @@ -59,9 +59,15 @@ class PredictService: } def __init__( - self, mapping, params=(), entry_point='/predict', - decorator_list_name='predict_decorators', - predict_proba=False, **kwargs): + self, + mapping, + params=(), + entry_point='/predict', + decorator_list_name='predict_decorators', + predict_proba=False, + unwrap_sample=False, + **kwargs + ): """ :param mapping: A list of query parameters and their type that should be @@ -85,12 +91,19 @@ def __init__( Instead of returning a single class (the default), when *predict_proba* is set to true, the result will instead contain a list of class probabilities. + + :param unwrap_sample: + When working with text, scikit-learn and others will + sometimes expect the input to be a 1d array of strings + rather than a 2d array. Setting *unwrap_sample* to true + will use this representation. """ self.mapping = mapping self.params = params self.entry_point = entry_point self.decorator_list_name = decorator_list_name self.predict_proba = predict_proba + self.unwrap_sample = unwrap_sample vars(self).update(kwargs) def initialize_component(self, config): @@ -132,7 +145,11 @@ def sample_from_data(self, model, data): for key, type_name in self.mapping: value_type = self.types[type_name] values.append(value_type(data[key])) - return np.array(values, dtype=object) + if self.unwrap_sample: + assert len(values) == 1 + return np.array(values[0]) + else: + return np.array(values, dtype=object) def params_from_data(self, model, data): """Retrieve additional parameters (keyword arguments) for diff --git a/palladium/tests/test_server.py b/palladium/tests/test_server.py index 85fb8b5..783ec79 100644 --- a/palladium/tests/test_server.py +++ b/palladium/tests/test_server.py @@ -171,6 +171,67 @@ def test_sample_from_data(self, PredictService): assert sample[0] == 'myflower' assert sample[1] == 3 + def test_unwrap_sample_get(self, PredictService, flask_app): + predict_service = PredictService( + mapping=[('text', 'str')], + unwrap_sample=True, + ) + model = Mock() + model.predict.return_value = np.array([1]) + with flask_app.test_request_context(): + request = Mock( + args=dict([ + ('text', 'Hi this is text'), + ]), + method='GET', + ) + resp = predict_service(model, request) + + assert model.predict.call_args[0][0].ndim == 1 + model.predict.assert_called_with(np.array(['Hi this is text'])) + resp_data = json.loads(resp.get_data(as_text=True)) + assert resp.status_code == 200 + assert resp_data == { + "metadata": { + "status": "OK", + "error_code": 0, + }, + "result": 1, + } + + def test_unwrap_sample_post(self, PredictService, flask_app): + predict_service = PredictService( + mapping=[('text', 'str')], + unwrap_sample=True, + ) + model = Mock() + model.predict.return_value = np.array([1, 2]) + with flask_app.test_request_context(): + request = Mock( + json=[ + {'text': 'First piece of text'}, + {'text': 'Second piece of text'}, + ], + method='POST', + mimetype='application/json', + ) + resp = predict_service(model, request) + + assert model.predict.call_args[0][0].ndim == 1 + assert ( + model.predict.call_args[0] == + np.array(['First piece of text', 'Second piece of text']) + ).all() + resp_data = json.loads(resp.get_data(as_text=True)) + assert resp.status_code == 200 + assert resp_data == { + "metadata": { + "status": "OK", + "error_code": 0, + }, + "result": [1, 2], + } + def test_probas(self, PredictService, flask_app): model = Mock() model.predict_proba.return_value = np.array([[0.1, 0.5, math.pi]])