Skip to content

Commit

Permalink
User must now manually register /alive and /update-model-cache
Browse files Browse the repository at this point in the history
  • Loading branch information
dnouri committed Apr 9, 2018
1 parent fc482a8 commit aa119ab
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 15 deletions.
28 changes: 26 additions & 2 deletions docs/user/web-service.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ method:

- In multi-server or multi-process environments, you must take care of
updating existing model caches (e.g. when running
:class:`palladium.persistence.CachedUpdatePersister`) by hand. This
:class:`~palladium.persistence.CachedUpdatePersister`) by hand. This
can be done by calling the */update-model-cache* endpoint for each
server process.

Expand Down Expand Up @@ -242,4 +242,28 @@ the same way that */refit* does, that is, by returning an id and
storing information about the job inside of ``process_metadata``.
*/update-model-cache* will update the cache of any caching model
persisters, such as
:class:`palladium.persistence.CachedUpdatePersister`.
:class:`~palladium.persistence.CachedUpdatePersister`.

The */refit* and */update-model-cache* endpoints aren't registered by
default with the Flask app. To register the two endpoints, you can
either call the Flask app's ``add_url_rules`` directly or use the
convenience function :func:`palladium.server.add_url_rule` instead
inside of your configuration file. An example of registering the two
endpoints is this:

.. code-block:: python
'flask_add_url_rules': [
{
'__factory__': 'palladium.server.add_url_rule',
'rule': '/refit',
'view_func': 'palladium.server.refit',
'methods': ['POST'],
},
{
'__factory__': 'palladium.server.add_url_rule',
'rule': '/update-model-cache',
'view_func': 'palladium.server.update_model_cache',
'methods': ['POST'],
},
],
8 changes: 7 additions & 1 deletion palladium/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .util import PluggableDecorator
from .util import process_store
from .util import run_job
from .util import resolve_dotted_name

app = Flask(__name__)

Expand Down Expand Up @@ -372,7 +373,6 @@ def stream_cmd(argv=sys.argv[1:]): # pragma: no cover
stream.listen(sys.stdin, sys.stdout, sys.stderr)


@app.route('/refit', methods=['POST'])
@PluggableDecorator('refit_decorators')
@args_from_config
def refit():
Expand Down Expand Up @@ -401,3 +401,9 @@ def update_model_cache(model_persister):
return make_ujson_response({'job_id': job_id}, status_code=200)
else:
return make_ujson_response({}, status_code=503)


def add_url_rule(rule, endpoint=None, view_func=None, app=app, **options):
if isinstance(view_func, str):
view_func = resolve_dotted_name(view_func)
app.add_url_rule(rule, endpoint=endpoint, view_func=view_func, **options)
53 changes: 45 additions & 8 deletions palladium/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import math
from threading import Thread
from time import sleep
from unittest.mock import call
from unittest.mock import Mock
from unittest.mock import patch
Expand Down Expand Up @@ -563,20 +564,27 @@ def test_predict_params(self, config, stream):


class TestRefitFunctional:
@pytest.fixture
def refit(self):
from palladium.server import refit
return refit

@pytest.fixture
def jobs(self, process_store):
jobs = process_store['process_metadata'].setdefault('jobs', {})
yield jobs
jobs.clear()

def test_it(self, config, jobs, flask_client):
def test_it(self, refit, config, jobs, flask_app):
dsl, model, model_persister = Mock(), Mock(), Mock()
X, y = Mock(), Mock()
dsl.return_value = X, y
config['dataset_loader_train'] = dsl
config['model'] = model
config['model_persister'] = model_persister
resp = flask_client.post('refit')
with flask_app.test_request_context(method='POST'):
resp = refit()
sleep(0.005)
resp_json = json.loads(resp.get_data(as_text=True))
job = jobs[resp_json['job_id']]
assert job['status'] == 'finished'
Expand All @@ -592,32 +600,61 @@ def test_it(self, config, jobs, flask_client):
{'persist_if_better_than': 0.234},
),
])
def test_pass_args(self, flask_client, args, args_expected):
def test_pass_args(self, refit, flask_app, args, args_expected):
with patch('palladium.server.fit') as fit:
fit.__name__ = 'mock'
flask_client.post('refit', data=args)
with flask_app.test_request_context(method='POST', data=args):
refit()
sleep(0.005)
assert fit.call_args == call(**args_expected)


class TestUpdateModelCacheFunctional:
@pytest.fixture
def update_model_cache(self):
from palladium.server import update_model_cache
return update_model_cache

@pytest.fixture
def jobs(self, process_store):
jobs = process_store['process_metadata'].setdefault('jobs', {})
yield jobs
jobs.clear()

def test_success(self, config, jobs, flask_client):
def test_success(self, update_model_cache, config, jobs, flask_app):
model_persister = Mock()
config['model_persister'] = model_persister
resp = flask_client.post('update-model-cache')
with flask_app.test_request_context(method='POST'):
resp = update_model_cache()
sleep(0.005)
resp_json = json.loads(resp.get_data(as_text=True))
job = jobs[resp_json['job_id']]
assert job['status'] == 'finished'
assert job['info'] == repr(model_persister.update_cache())

def test_unavailable(self, config, jobs, flask_client):
def test_unavailable(self, update_model_cache, config, jobs, flask_app):
model_persister = Mock()
del model_persister.update_cache
config['model_persister'] = model_persister
resp = flask_client.post('update-model-cache')
with flask_app.test_request_context(method='POST'):
resp = update_model_cache()
assert resp.status_code == 503


def _test_add_url_rule_func():
return b'A OK'


class TestAddUrlRule:
@pytest.fixture
def add_url_rule(self):
from palladium.server import add_url_rule
return add_url_rule

def test_it(self, add_url_rule, flask_client):
add_url_rule(
'/okay',
view_func='palladium.tests.test_server._test_add_url_rule_func',
)
resp = flask_client.get('/okay')
assert resp.data == b'A OK'
8 changes: 4 additions & 4 deletions palladium/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def test_mtime_no_entry(self, store):
def test_mtime_setitem(self, store):
dt0 = datetime.now()
store['somekey'] = '1'
sleep(0.001) # make sure that we're not too fast
sleep(0.005) # make sure that we're not too fast
dt1 = datetime.now()
assert dt0 < store.mtime['somekey'] < dt1
store['somekey'] = '2'
sleep(0.001) # make sure that we're not too fast
sleep(0.005) # make sure that we're not too fast
dt2 = datetime.now()
assert dt1 < store.mtime['somekey'] < dt2

Expand Down Expand Up @@ -362,7 +362,7 @@ def myfunc(add):
results = []
for i in range(3):
results.append(run_job(myfunc, add=i))
sleep(0.01)
sleep(0.005)
assert result == 3
assert len(jobs) == len(results) == 3
assert set(jobs.keys()) == set(r[1] for r in results)
Expand All @@ -378,7 +378,7 @@ def myfunc(divisor):
num_threads_before = len(threading.enumerate())
for i in range(3):
run_job(myfunc, divisor=i)
sleep(0.01)
sleep(0.005)
num_threads_after = len(threading.enumerate())

assert num_threads_before == num_threads_after
Expand Down

0 comments on commit aa119ab

Please sign in to comment.