Skip to content

Commit

Permalink
New command pld-stream allows making predictions via stdin and stdout.
Browse files Browse the repository at this point in the history
  • Loading branch information
dnouri committed Jun 24, 2015
1 parent c9a7890 commit c7c15d9
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 0 deletions.
55 changes: 55 additions & 0 deletions palladium/server.py
Expand Up @@ -258,3 +258,58 @@ def devserver_cmd(argv=sys.argv[1:]): # pragma: no cover
port=int(arguments['--port']),
debug=int(arguments['--debug']),
)


class PredictStream:
    """Make predictions for requests read from a text stream.

    Every input line is handed to :meth:`process_line`; for each
    processed line exactly one line is written to the output stream,
    so a consumer can pair responses with requests by position.
    """

    def __init__(self):
        # The model is read once here and reused for every input line.
        self.model = get_config()['model_persister'].read()
        self.predict_service = get_config()['predict_service']

    def process_line(self, line):
        """Return the predict service's prediction for one input line.

        *line* is parsed as JSON and is expected to hold a list of data
        entries.  Each entry is turned into a sample through the predict
        service's ``sample_from_data``; the samples are stacked into one
        numpy array and predicted in a single call.

        Prediction parameters are derived from the first entry only and
        are applied to the whole batch.  An empty list therefore raises
        ``IndexError``.
        """
        predict_service = self.predict_service
        datas = ujson.loads(line)
        samples = np.array([
            predict_service.sample_from_data(self.model, data)
            for data in datas
        ])
        params = predict_service.params_from_data(self.model, datas[0])
        return predict_service.predict(self.model, samples, **params)

    def listen(self, io_in, io_out, io_err):
        """Listens to provided io stream and writes predictions
        to output. In case of errors, the error stream will be used.

        Reading stops when *io_in* is exhausted or when a line that
        consists of ``exit`` (in any letter case, surrounding whitespace
        ignored) is read.  A line that cannot be processed yields the
        placeholder ``[]`` on *io_out* and a description of the error
        on *io_err*.
        """
        for line in io_in:
            if line.strip().lower() == 'exit':
                break

            try:
                # Serializing inside the guarded section means that a
                # prediction that cannot be converted is reported like
                # any other failure instead of terminating the stream.
                response = ujson.dumps(self.process_line(line).tolist())
            except Exception as e:
                io_out.write('[]\n')
                # Flush the placeholder as well: a consumer reading
                # from a buffered pipe would otherwise never see the
                # response for the failed line.
                io_out.flush()
                io_err.write(
                    "Error while processing input row: {}"
                    "{}: {}\n".format(line, type(e), e))
                io_err.flush()
            else:
                io_out.write(response)
                io_out.write('\n')
                io_out.flush()


def stream_cmd(argv=sys.argv[1:]):  # pragma: no cover
    """Entry point of the ``pld-stream`` command line script.

    Parses *argv* with docopt, initializes the configuration and then
    lets a :class:`PredictStream` answer requests read from stdin,
    writing predictions to stdout and errors to stderr.
    """
    # docopt derives the command line interface from this text, so its
    # layout matters: usage patterns have to be indented below
    # "Usage:", sections are separated by blank lines, and an option
    # is separated from its description by at least two spaces.
    usage = """
Start the streaming server, which listens to stdin, processes line
by line, and returns predictions.

Usage:
  pld-stream [options]

Options:
  -h --help                  Show this screen.
"""
    docopt(usage, argv=argv)
    initialize_config()
    stream = PredictStream()
    stream.listen(sys.stdin, sys.stdout, sys.stderr)
137 changes: 137 additions & 0 deletions palladium/tests/test_server.py
@@ -1,13 +1,16 @@
from datetime import datetime
import io
import json
import math
from threading import Thread
from unittest.mock import Mock
from unittest.mock import patch

import dateutil.parser
from flask import request
import numpy as np
import pytest
import ujson
from werkzeug.exceptions import BadRequest


Expand Down Expand Up @@ -317,3 +320,137 @@ def test_missing_process_state(self, config, process_store, flask_client):

assert resp_data['model']['metadata'] == {'hello': 'is it me'}
assert resp_data['data'] == 'N/A'


class TestPredictStream:
    """Tests for ``palladium.server.PredictStream``."""

    @pytest.fixture
    def PredictStream(self):
        from palladium.server import PredictStream
        return PredictStream

    @pytest.fixture
    def stream(self, config, PredictStream):
        # Persister and predict service are mocks; the service passes
        # the incoming data through unchanged so that tests can inspect
        # exactly what reaches ``predict``.
        config['model_persister'] = Mock()
        predict_service = config['predict_service'] = Mock()
        predict_service.sample_from_data.side_effect = (
            lambda model, data: data)
        predict_service.params_from_data.side_effect = (
            lambda model, data: data)
        return PredictStream()

    @staticmethod
    def _listen_in_thread(stream, io_in, io_out, io_err):
        """Run ``stream.listen`` in a worker thread and wait for it.

        The bound method and its arguments are passed separately;
        ``Thread(target=stream.listen(...))`` would call ``listen`` in
        the calling thread and start a thread that does nothing.
        """
        stream_thread = Thread(
            target=stream.listen, args=(io_in, io_out, io_err))
        stream_thread.start()
        stream_thread.join()

    def test_listen_direct_exit(self, stream):
        # The input has to be in place before listening starts,
        # otherwise the empty stream ends the loop and the exit
        # command is never seen.
        io_in = io.StringIO('EXIT\n')
        io_out = io.StringIO()
        io_err = io.StringIO()

        self._listen_in_thread(stream, io_in, io_out, io_err)

        io_out.seek(0)
        io_err.seek(0)
        assert len(io_out.read()) == 0
        assert len(io_err.read()) == 0
        assert stream.predict_service.predict.call_count == 0

    def test_listen(self, stream):
        io_in = io.StringIO()
        io_out = io.StringIO()
        io_err = io.StringIO()
        lines = [
            '[{"id": 1, "color": "blue", "length": 1.0}]\n',
            '[{"id": 1, "color": "{\\"a\\": 1, \\"b\\": 2}", "length": 1.0}]\n',
            '[{"id": 1, "color": "blue", "length": 1.0}, {"id": 2, "color": "{\\"a\\": 1, \\"b\\": 2}", "length": 1.0}]\n',
        ]
        for line in lines:
            io_in.write(line)

        io_in.write('EXIT\n')
        io_in.seek(0)
        predict = stream.predict_service.predict
        predict.side_effect = (
            lambda model, samples, **params:
            np.array([{'result': 1}] * len(samples))
        )

        self._listen_in_thread(stream, io_in, io_out, io_err)

        io_out.seek(0)
        io_err.seek(0)
        assert len(io_err.read()) == 0
        assert io_out.read() == (
            ('[{"result":1}]\n' * 2) + ('[{"result":1},{"result":1}]\n'))
        assert predict.call_count == 3
        # check if the correct arguments are passed to predict call;
        # comparing arrays is element-wise, hence the explicit all()
        assert (predict.call_args_list[0][0][1] == np.array([
            {'id': 1, 'color': 'blue', 'length': 1.0}])).all()
        assert (predict.call_args_list[1][0][1] == np.array([
            {'id': 1, 'color': '{"a": 1, "b": 2}', 'length': 1.0}])).all()
        assert (predict.call_args_list[2][0][1] == np.array([
            {'id': 1, 'color': 'blue', 'length': 1.0},
            {'id': 2, 'color': '{"a": 1, "b": 2}', 'length': 1.0},
        ])).all()

        # check if string representation of attribute can be converted to json
        assert ujson.loads(predict.call_args_list[1][0][1][0]['color']) == {
            "a": 1, "b": 2}

    def test_predict_error(self, stream):
        from palladium.interfaces import PredictError

        io_in = io.StringIO()
        io_out = io.StringIO()
        io_err = io.StringIO()

        line = '[{"hey": "1"}]\n'
        io_in.write(line)
        io_in.write('EXIT\n')
        io_in.seek(0)
        stream.predict_service.predict.side_effect = PredictError('error')

        self._listen_in_thread(stream, io_in, io_out, io_err)

        io_out.seek(0)
        io_err.seek(0)
        # a failed line yields an empty result on the output stream and
        # the error description on the error stream
        assert io_out.read() == '[]\n'
        assert io_err.read() == (
            "Error while processing input row: {}"
            "<class 'palladium.interfaces.PredictError'>: "
            "error (-1)\n".format(line))
        assert stream.predict_service.predict.call_count == 1

    def test_predict_params(self, config, stream):
        from palladium.server import PredictService
        line = '[{"length": 1.0, "width": 1.0, "turbo": "true"}]'

        model = Mock()
        model.predict.return_value = np.array([[{'class': 'a'}]])
        model.turbo = False
        model.magic = False
        stream.model = model

        mapping = [
            ('length', 'float'),
            ('width', 'float'),
        ]
        params = [
            ('turbo', 'bool'),  # will be set by request args
            ('magic', 'bool'),  # default value will be used
        ]
        stream.predict_service = PredictService(
            mapping=mapping,
            params=params,
        )

        expected = [{'class': 'a'}]
        result = stream.process_line(line)
        assert result == expected
        assert model.predict.call_count == 1
        assert (model.predict.call_args[0][0] == np.array([[1.0, 1.0]])).all()
        assert model.predict.call_args[1]['turbo'] is True
        assert model.predict.call_args[1]['magic'] is False
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -67,6 +67,7 @@
'pld-fit = palladium.fit:fit_cmd',
'pld-grid-search = palladium.fit:grid_search_cmd',
'pld-list = palladium.eval:list_cmd',
'pld-stream = palladium.server:stream_cmd',
'pld-test = palladium.eval:test_cmd',
'pld-upgrade = palladium.util:upgrade_cmd',
'pld-version = palladium.util:version_cmd',
Expand Down

0 comments on commit c7c15d9

Please sign in to comment.