Skip to content

Commit

Permalink
Merge pull request #51 from dnouri/bugfix/get-config-multithreading
Browse files Browse the repository at this point in the history
Use a lock to make get_config thread-safe
  • Loading branch information
alattner committed Sep 22, 2017
2 parents 851ba20 + e390cbf commit 5c5f088
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 5 deletions.
42 changes: 42 additions & 0 deletions palladium/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from logging.config import dictConfig
import os
import sys
import threading


PALLADIUM_CONFIG_ERROR = """
Expand All @@ -12,6 +13,30 @@
"""


PALLADIUM_RECURSIVE_GET_CONFIG_ERROR = """
You're trying to call `get_config` from code that's already called by
`get_config`. Thus, there's unfortunately no way we can guarantee
that the part of the config that you're interested in is already
properly resolved and initialized.
Please consider either implementing the `initialize_component(config)`
method in your class, which gets passed the initialized configuration
as an argument, and use that to do final initialization of your model
using the configuration. Or pass in the configuration that you want
to access as part of your components config. So if possible, instead
of:
{'mycomponent': {...}, 'mydependency': {...}}
write this:
{'mycomponent': {'mydependency': {...}, ...}}
Use of @args_from_config for functions that get called by
configuration code is prohibited for the same reason.
"""


class Config(dict):
"""A dictionary that represents the app's configuration.
Expand Down Expand Up @@ -198,7 +223,24 @@ def process_config(
return config_final


_get_config_lock = threading.Lock()
_get_config_lock_owner = None


def get_config(**extra):
global _get_config_lock_owner
if _get_config_lock_owner == threading.get_ident():
raise ValueError(PALLADIUM_RECURSIVE_GET_CONFIG_ERROR)
with _get_config_lock:
_get_config_lock_owner = threading.get_ident()
try:
config = _get_config(**extra)
finally:
_get_config_lock_owner = None
return config


def _get_config(**extra):
if not _config.initialized:
_config.update(extra)
_config.initialized = True
Expand Down
6 changes: 3 additions & 3 deletions palladium/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(

def initialize_component(self, config):
create_predict_function(
self.entry_point, self, self.decorator_list_name)
self.entry_point, self, self.decorator_list_name, config)

def __call__(self, model, request):
try:
Expand Down Expand Up @@ -244,7 +244,7 @@ def alive(alive=None):


def create_predict_function(
route, predict_service, decorator_list_name):
route, predict_service, decorator_list_name, config):
"""Creates a predict function and registers it to
the Flask app using the route decorator.
Expand All @@ -262,7 +262,7 @@ def create_predict_function(
A predict service function that will be used to process
predict requests.
"""
model_persister = get_config().get('model_persister')
model_persister = config.get('model_persister')

@app.route(route, methods=['GET', 'POST'], endpoint=route)
@PluggableDecorator(decorator_list_name)
Expand Down
61 changes: 60 additions & 1 deletion palladium/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from functools import reduce
import operator
import os
import threading
import time
from unittest.mock import patch

import pytest
Expand All @@ -23,6 +27,17 @@ def __eq__(self, other):
])


class BlockingDummy:
def __init__(self):
time.sleep(0.1)


class BadDummy:
def __init__(self):
from palladium.config import get_config
self.cfg = get_config().copy()


def test_config_class_keyerror():
from palladium.config import Config
with pytest.raises(KeyError) as e:
Expand Down Expand Up @@ -57,7 +72,13 @@ def get_config(self):
@pytest.fixture
def config1_fname(self, tmpdir):
path = tmpdir.join('config1.py')
path.write("{'env': environ['ENV1'], 'here': here}")
path.write("""{
'env': environ['ENV1'],
'here': here,
'blocking': {
'__factory__': 'palladium.tests.test_config.BlockingDummy',
}
}""")
return str(path)

@pytest.fixture
Expand All @@ -66,6 +87,16 @@ def config2_fname(self, tmpdir):
path.write("{'env': environ['ENV2']}")
return str(path)

@pytest.fixture
def config3_fname(self, tmpdir):
path = tmpdir.join('config3.py')
path.write("""{
'bad': {
'__factory__': 'palladium.tests.test_config.BadDummy'
}
}""")
return str(path)

def test_extras(self, get_config):
assert get_config(foo='bar')['foo'] == 'bar'

Expand All @@ -86,6 +117,34 @@ def test_multiple_files(self, get_config, config1_fname, config2_fname,
assert config['env'] == 'two'
assert config['here'] == os.path.dirname(config1_fname)

def test_multithreaded(self, get_config, config1_fname, monkeypatch):
monkeypatch.setitem(os.environ, 'PALLADIUM_CONFIG', config1_fname)
monkeypatch.setitem(os.environ, 'ENV1', 'one')

cfg = {}

def get_me_config():
cfg[threading.get_ident()] = get_config().copy()

threads = [threading.Thread(target=get_me_config) for i in range(2)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()

assert reduce(operator.eq, cfg.values())

def test_recursive_call_of_get_config(
self,
get_config,
config3_fname,
monkeypatch,
):
monkeypatch.setitem(os.environ, 'PALLADIUM_CONFIG', config3_fname)
with pytest.raises(ValueError) as exc:
get_config()
assert "You're trying to call `get_config` from code" in str(exc.value)


class TestProcessConfig:
@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion palladium/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def test_predict_functional(self, config, flask_app_test, flask_client):
with flask_app_test.test_request_context():
from palladium.server import create_predict_function
create_predict_function(
'/predict', predict_service, 'predict_decorators')
'/predict', predict_service, 'predict_decorators', config)
predict_service.return_value = make_ujson_response(
'a', status_code=200)

Expand Down

0 comments on commit 5c5f088

Please sign in to comment.