From c7c15d9c76caf2db76ea0c64a932e98bed7adcef Mon Sep 17 00:00:00 2001 From: Daniel Nouri Date: Wed, 24 Jun 2015 11:39:53 +0200 Subject: [PATCH] New command pld-stream allows making predictions via stdin and stdout. --- palladium/server.py | 55 +++++++++++++ palladium/tests/test_server.py | 137 +++++++++++++++++++++++++++++++++ setup.py | 1 + 3 files changed, 193 insertions(+) diff --git a/palladium/server.py b/palladium/server.py index d70f715..91c3b3f 100644 --- a/palladium/server.py +++ b/palladium/server.py @@ -258,3 +258,58 @@ def devserver_cmd(argv=sys.argv[1:]): # pragma: no cover port=int(arguments['--port']), debug=int(arguments['--debug']), ) + + +class PredictStream: + """A class that helps make predictions through stdin and stdout. + """ + def __init__(self): + self.model = get_config()['model_persister'].read() + self.predict_service = get_config()['predict_service'] + + def process_line(self, line): + predict_service = self.predict_service + datas = ujson.loads(line) + samples = [predict_service.sample_from_data(self.model, data) + for data in datas] + samples = np.array(samples) + params = predict_service.params_from_data(self.model, datas[0]) + return predict_service.predict(self.model, samples, **params) + + def listen(self, io_in, io_out, io_err): + """Listens to provided io stream and writes predictions + to output. In case of errors, the error stream will be used. + """ + for line in io_in: + if line.strip().lower() == 'exit': + break + + try: + y_pred = self.process_line(line) + except Exception as e: + io_out.write('[]\n') + io_err.write( + "Error while processing input row: {}" + "{}: {}\n".format(line, type(e), e)) + io_err.flush() + else: + io_out.write(ujson.dumps(y_pred.tolist())) + io_out.write('\n') + io_out.flush() + + +def stream_cmd(argv=sys.argv[1:]): # pragma: no cover + __doc__ = """ +Start the streaming server, which listens to stdin, processes line +by line, and returns predictions. + +Usage: + pld-stream [options] + +Options: + -h --help Show this screen. +""" + docopt(__doc__, argv=argv) + initialize_config() + stream = PredictStream() + stream.listen(sys.stdin, sys.stdout, sys.stderr) diff --git a/palladium/tests/test_server.py b/palladium/tests/test_server.py index c6ce7b7..a7b257c 100644 --- a/palladium/tests/test_server.py +++ b/palladium/tests/test_server.py @@ -1,6 +1,8 @@ from datetime import datetime +import io import json import math +from threading import Thread from unittest.mock import Mock from unittest.mock import patch @@ -8,6 +10,7 @@ from flask import request import numpy as np import pytest +import ujson from werkzeug.exceptions import BadRequest @@ -317,3 +320,137 @@ def test_missing_process_state(self, config, process_store, flask_client): assert resp_data['model']['metadata'] == {'hello': 'is it me'} assert resp_data['data'] == 'N/A' + + +class TestPredictStream: + @pytest.fixture + def PredictStream(self): + from palladium.server import PredictStream + return PredictStream + + @pytest.fixture + def stream(self, config, PredictStream): + config['model_persister'] = Mock() + predict_service = config['predict_service'] = Mock() + predict_service.sample_from_data.side_effect = ( + lambda model, data: data) + predict_service.params_from_data.side_effect = ( + lambda model, data: data) + return PredictStream() + + def test_listen_direct_exit(self, stream): + io_in = io.StringIO() + io_out = io.StringIO() + io_err = io.StringIO() + + stream_thread = Thread( + target=stream.listen(io_in, io_out, io_err)) + stream_thread.start() + io_in.write('EXIT\n') + stream_thread.join() + io_out.seek(0) + io_err.seek(0) + assert len(io_out.read()) == 0 + assert len(io_err.read()) == 0 + assert stream.predict_service.predict.call_count == 0 + + def test_listen(self, stream): + io_in = io.StringIO() + io_out = io.StringIO() + io_err = io.StringIO() + lines = [ + '[{"id": 1, "color": "blue", "length": 1.0}]\n', + '[{"id": 1, "color": "{\\"a\\": 1, \\"b\\": 2}", "length": 1.0}]\n', + '[{"id": 1, "color": "blue", "length": 1.0}, {"id": 2, "color": "{\\"a\\": 1, \\"b\\": 2}", "length": 1.0}]\n', + ] + for line in lines: + io_in.write(line) + + io_in.write('EXIT\n') + io_in.seek(0) + predict = stream.predict_service.predict + predict.side_effect = ( + lambda model, samples, **params: + np.array([{'result': 1}] * len(samples)) + ) + stream_thread = Thread( + target=stream.listen(io_in, io_out, io_err)) + stream_thread.start() + stream_thread.join() + io_out.seek(0) + io_err.seek(0) + assert len(io_err.read()) == 0 + assert io_out.read() == ( + ('[{"result":1}]\n' * 2) + ('[{"result":1},{"result":1}]\n')) + assert predict.call_count == 3 + # check if the correct arguments are passed to predict call + assert predict.call_args_list[0][0][1] == np.array([ + {'id': 1, 'color': 'blue', 'length': 1.0}]) + assert predict.call_args_list[1][0][1] == np.array([ + {'id': 1, 'color': '{"a": 1, "b": 2}', 'length': 1.0}]) + assert (predict.call_args_list[2][0][1] == np.array([ + {'id': 1, 'color': 'blue', 'length': 1.0}, + {'id': 2, 'color': '{"a": 1, "b": 2}', 'length': 1.0}, + ])).all() + + # check if string representation of attribute can be converted to json + assert ujson.loads(predict.call_args_list[1][0][1][0]['color']) == { + "a": 1, "b": 2} + + def test_predict_error(self, stream): + from palladium.interfaces import PredictError + + io_in = io.StringIO() + io_out = io.StringIO() + io_err = io.StringIO() + + line = '[{"hey": "1"}]\n' + io_in.write(line) + io_in.write('EXIT\n') + io_in.seek(0) + stream.predict_service.predict.side_effect = PredictError('error') + + stream_thread = Thread( + target=stream.listen(io_in, io_out, io_err)) + stream_thread.start() + stream_thread.join() + + io_out.seek(0) + io_err.seek(0) + assert io_out.read() == '[]\n' + assert io_err.read() == ( + "Error while processing input row: {}" + ": " + "error (-1)\n".format(line)) + assert stream.predict_service.predict.call_count == 1 + + def test_predict_params(self, config, stream): + from palladium.server import PredictService + line = '[{"length": 1.0, "width": 1.0, "turbo": "true"}]' + + model = Mock() + model.predict.return_value = np.array([[{'class': 'a'}]]) + model.turbo = False + model.magic = False + stream.model = model + + mapping = [ + ('length', 'float'), + ('width', 'float'), + ] + params = [ + ('turbo', 'bool'), # will be set by request args + ('magic', 'bool'), # default value will be used + ] + stream.predict_service = PredictService( + mapping=mapping, + params=params, + ) + + expected = [{'class': 'a'}] + result = stream.process_line(line) + assert result == expected + assert model.predict.call_count == 1 + assert (model.predict.call_args[0][0] == np.array([[1.0, 1.0]])).all() + assert model.predict.call_args[1]['turbo'] is True + assert model.predict.call_args[1]['magic'] is False diff --git a/setup.py b/setup.py index af408bf..360e322 100644 --- a/setup.py +++ b/setup.py @@ -67,6 +67,7 @@ 'pld-fit = palladium.fit:fit_cmd', 'pld-grid-search = palladium.fit:grid_search_cmd', 'pld-list = palladium.eval:list_cmd', + 'pld-stream = palladium.server:stream_cmd', 'pld-test = palladium.eval:test_cmd', 'pld-upgrade = palladium.util:upgrade_cmd', 'pld-version = palladium.util:version_cmd',