diff --git a/tabpy-server/server_tests/test_config.py b/tabpy-server/server_tests/test_config.py index cee5a806..7f150a34 100644 --- a/tabpy-server/server_tests/test_config.py +++ b/tabpy-server/server_tests/test_config.py @@ -87,8 +87,8 @@ def test_config_file_present(self, mock_os, mock_path_exists, mock_os.getenv.assert_has_calls(getenv_calls, any_order=True) self.assertEqual(app.settings['port'], 1234) - self.assertEqual(app.settings['server_version'], open( - 'VERSION').read().strip()) + self.assertEqual(app.settings['server_version'], + open('VERSION').read().strip()) self.assertEqual(app.settings['upload_dir'], 'foo') self.assertEqual(app.settings['state_file_path'], 'bar') self.assertEqual(app.settings['transfer_protocol'], 'http') @@ -171,12 +171,12 @@ def test_https_cert_and_key_file_not_found(self, mock_path): "TABPY_KEY_FILE = bar") self.fp.close() - mock_path.isfile.side_effect =\ - lambda x: self.mock_isfile(x, {self.fp.name}) + mock_path.isfile.side_effect = lambda x: self.mock_isfile( + x, {self.fp.name}) assert_raises_runtime_error( - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' - 'and TABPY_KEY_FILE must point to an existing file.', + 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE and ' + 'TABPY_KEY_FILE must point to an existing file.', TabPyApp, {self.fp.name}) @patch('tabpy_server.app.app.os.path') @@ -207,8 +207,9 @@ def test_https_key_file_not_found(self, mock_path): x, {self.fp.name, 'foo'}) assert_raises_runtime_error( - 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE must ' - 'point to an existing file.', TabPyApp, {self.fp.name}) + 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE ' + 'must point to an existing file.', + TabPyApp, {self.fp.name}) @patch('tabpy_server.app.app.os.path.isfile', return_value=True) @patch('tabpy_server.app.util.validate_cert') @@ -235,14 +236,14 @@ def __init__(self, *args, **kwargs): def test_expired_cert(self): path = os.path.join(self.resources_path, 'expired.crt') - message = 'Error using HTTPS: The certificate provided expired '\ - 'on 2018-08-18 19:47:18.' + message = ('Error using HTTPS: The certificate provided expired ' + 'on 2018-08-18 19:47:18.') assert_raises_runtime_error(message, validate_cert, {path}) def test_future_cert(self): path = os.path.join(self.resources_path, 'future.crt') - message = 'Error using HTTPS: The certificate provided is not '\ - 'valid until 3001-01-01 00:00:00.' + message = ('Error using HTTPS: The certificate provided is not valid ' + 'until 3001-01-01 00:00:00.') assert_raises_runtime_error(message, validate_cert, {path}) def test_valid_cert(self): diff --git a/tabpy-server/tabpy_server/app/util.py b/tabpy-server/tabpy_server/app/util.py index 63c4b1c5..a5b2a13e 100644 --- a/tabpy-server/tabpy_server/app/util.py +++ b/tabpy-server/tabpy_server/app/util.py @@ -36,7 +36,8 @@ def validate_cert(cert_file_path): not_before), RuntimeError) if now > not_after: log_and_raise(https_error + - 'The certificate provided expired on {}.'.format(not_after), RuntimeError) + f'The certificate provided expired on {not_after}.', + RuntimeError) def parse_pwd_file(pwd_file_name): diff --git a/tabpy-server/tabpy_server/common/endpoint_file_mgr.py b/tabpy-server/tabpy_server/common/endpoint_file_mgr.py index b5b38286..dfc3f6f4 100644 --- a/tabpy-server/tabpy_server/common/endpoint_file_mgr.py +++ b/tabpy-server/tabpy_server/common/endpoint_file_mgr.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -_name_checker = _compile('^[a-zA-Z0-9-_\ ]+$') +_name_checker = _compile(r'^[a-zA-Z0-9-_\ ]+$') def _check_endpoint_name(name): @@ -28,8 +28,10 @@ def _check_endpoint_name(name): log_and_raise("Endpoint name cannot be empty", ValueError) if not _name_checker.match(name): - log_and_raise('Endpoint name can only contain: a-z, A-Z, 0-9,' - ' underscore, hyphens and spaces.', ValueError) + log_and_raise( + 'Endpoint name can only contain: a-z, A-Z, 0-9,' + ' underscore, hyphens and spaces.', + ValueError) def grab_files(directory): @@ -48,12 +50,13 @@ def grab_files(directory): elif os.path.isfile(full_path): yield full_path + def get_local_endpoint_file_path(name, version, query_path): _check_endpoint_name(name) return os.path.join(query_path, name, str(version)) -def cleanup_endpoint_files(name, query_path, retain_versions = None): +def cleanup_endpoint_files(name, query_path, retain_versions=None): ''' Cleanup the disk space a certain endpiont uses. @@ -63,8 +66,9 @@ def cleanup_endpoint_files(name, query_path, retain_versions = None): The endpoint name retain_version : int, optional - If given, then all files for this endpoint are removed except the folder - for the given version, otherwise, all files for that endpoint are removed. + If given, then all files for this endpoint are removed except the + folder for the given version, otherwise, all files for that endpoint + are removed. ''' _check_endpoint_name(name) local_dir = os.path.join(query_path, name) @@ -84,5 +88,6 @@ def cleanup_endpoint_files(name, query_path, retain_versions = None): for file_or_dir in os.listdir(local_dir): candidate_dir = os.path.join(local_dir, file_or_dir) - if os.path.isdir(candidate_dir) and candidate_dir not in retain_folders: + if os.path.isdir(candidate_dir) and ( + candidate_dir not in retain_folders): shutil.rmtree(candidate_dir) diff --git a/tabpy-server/tabpy_server/common/messages.py b/tabpy-server/tabpy_server/common/messages.py index 76b1006e..c8c95fe1 100644 --- a/tabpy-server/tabpy_server/common/messages.py +++ b/tabpy-server/tabpy_server/common/messages.py @@ -3,6 +3,7 @@ from abc import ABCMeta from collections import namedtuple + class Msg(object): """ An abstract base class for all messages used for communicating between @@ -36,98 +37,125 @@ def from_json(str): return eval(type_str)(**d) -class LoadSuccessful(namedtuple( - 'LoadSuccessful', ['uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): +class LoadSuccessful(namedtuple('LoadSuccessful', [ + 'uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): __slots__ = () -class LoadFailed(namedtuple('LoadFailed', ['uri', 'version', 'error_msg']), Msg): + +class LoadFailed(namedtuple('LoadFailed', [ + 'uri', 'version', 'error_msg']), Msg): __slots__ = () -class LoadInProgress(namedtuple( - 'LoadInProgress', ['uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): + +class LoadInProgress(namedtuple('LoadInProgress', [ + 'uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): __slots__ = () + class Query(namedtuple('Query', ['uri', 'params']), Msg): __slots__ = () + class QuerySuccessful(namedtuple( 'QuerySuccessful', ['uri', 'version', 'response']), Msg): __slots__ = () -class LoadObject(namedtuple( - 'LoadObject', ['uri', 'url', 'version', 'is_update', 'endpoint_type']), Msg): + +class LoadObject(namedtuple('LoadObject', [ + 'uri', 'url', 'version', 'is_update', 'endpoint_type']), Msg): __slots__ = () + class DeleteObjects(namedtuple('DeleteObjects', ['uris']), Msg): __slots__ = () + # Used for testing to flush out objects class FlushObjects(namedtuple('FlushObjects', []), Msg): __slots__ = () + class ObjectsDeleted(namedtuple('ObjectsDeleted', ['uris']), Msg): __slots__ = () + class ObjectsFlushed(namedtuple( 'ObjectsFlushed', ['n_before', 'n_after']), Msg): __slots__ = () + class CountObjects(namedtuple('CountObjects', []), Msg): __slots__ = () + class ObjectCount(namedtuple('ObjectCount', ['count']), Msg): __slots__ = () + class ListObjects(namedtuple('ListObjects', []), Msg): __slots__ = () + class ObjectList(namedtuple('ObjectList', ['objects']), Msg): __slots__ = () + class UnknownURI(namedtuple('UnknownURI', ['uri']), Msg): __slots__ = () + class UnknownMessage(namedtuple('UnknownMessage', ['msg']), Msg): __slots__ = () -class DownloadSkipped(namedtuple('DownloadSkipped', ['uri', 'version', 'msg','host']), - Msg): + +class DownloadSkipped(namedtuple('DownloadSkipped', [ + 'uri', 'version', 'msg', 'host']), Msg): __slots__ = () + class QueryFailed(namedtuple('QueryFailed', ['uri', 'error']), Msg): __slots__ = () + class QueryError(namedtuple('QueryError', ['uri', 'error']), Msg): __slots__ = () + class CheckHealth(namedtuple('CheckHealth', []), Msg): __slots__ = () + class Healthy(namedtuple('Healthy', []), Msg): __slots__ = () + class Unhealthy(namedtuple('Unhealthy', []), Msg): __slots__ = () + class Ping(namedtuple('Ping', ['id']), Msg): __slots__ = () + class Pong(namedtuple('Pong', ['id']), Msg): __slots__ = () + class Listening(namedtuple('Listening', []), Msg): __slots__ = () + class EngineFailure(namedtuple('EngineFailure', ['error']), Msg): __slots__ = () + class FlushLogs(namedtuple('FlushLogs', []), Msg): __slots__ = () + class LogsFlushed(namedtuple('LogsFlushed', []), Msg): __slots__ = () -class ServiceError(namedtuple( - 'ServiceError', ['error']), Msg): - __slots__ = () +class ServiceError(namedtuple('ServiceError', ['error']), Msg): + __slots__ = () diff --git a/tabpy-server/tabpy_server/common/util.py b/tabpy-server/tabpy_server/common/util.py index 8a8d7ea7..b90d3023 100644 --- a/tabpy-server/tabpy_server/common/util.py +++ b/tabpy-server/tabpy_server/common/util.py @@ -1,10 +1,13 @@ import traceback + + def format_exception(e, context): err_msg = "%s : " % e.__class__.__name__ err_msg += "%s" % str(e) return err_msg -def format_exception_DEBUG(e, context,detail=0): + +def format_exception_DEBUG(e, context, detail=0): trace = traceback.format_exc() err_msg = "Traceback\n %s\n" % trace err_msg += "Error type : %s\n" % e.__class__.__name__ diff --git a/tabpy-server/tabpy_server/management/state.py b/tabpy-server/tabpy_server/management/state.py index 58b2ba8d..a0f5b84b 100644 --- a/tabpy-server/tabpy_server/management/state.py +++ b/tabpy-server/tabpy_server/management/state.py @@ -33,6 +33,8 @@ Lock to change the TabPy State. ''' _PS_STATE_LOCK = Lock() + + def state_lock(func): ''' Mutex for changing PS state @@ -46,6 +48,7 @@ def wrapper(self, *args, **kwargs): _PS_STATE_LOCK.release() return wrapper + def load_state_from_str(state_string): ''' Convert from String to ConfigParser @@ -61,6 +64,7 @@ def load_state_from_str(state_string): else: log_and_raise("State string is empty!", ValueError) + def save_state_to_str(config): ''' Convert from ConfigParser to String @@ -72,18 +76,20 @@ def save_state_to_str(config): string_f = StringIO() config.write(string_f) value = string_f.getvalue() - except: + except Exception: logger.error("Cannot convert config to string") finally: string_f.close() return value + def _get_root_path(state_path): if state_path[-1] != '/': return state_path + '/' else: return state_path + def get_query_object_path(state_file_path, name, version): ''' Returns the query object path @@ -99,6 +105,7 @@ def get_query_object_path(state_file_path, name, version): '/'.join([_QUERY_OBJECT_DIR, name]) return full_path + class TabPyState(object): ''' The TabPy state object that stores attributes @@ -165,22 +172,24 @@ def get_endpoints(self, name=None): if name: endpoint_info = simplejson.loads(endpoint_names) docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, name) - if sys.version_info > (3,0): - endpoint_info['docstring'] = str(bytes(docstring,"utf-8").decode('unicode_escape')) + if sys.version_info > (3, 0): + endpoint_info['docstring'] = str( + bytes(docstring, "utf-8").decode('unicode_escape')) else: endpoint_info['docstring'] = docstring.decode('string_escape') endpoints = {name: endpoint_info} else: for endpoint_name in endpoint_names: endpoint_info = simplejson.loads(self._get_config_value( - _DEPLOYMENT_SECTION_NAME, - endpoint_name)) + _DEPLOYMENT_SECTION_NAME, endpoint_name)) docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, - endpoint_name, True, '') + endpoint_name, True, '') if sys.version_info > (3, 0): - endpoint_info['docstring'] = str(bytes(docstring, "utf-8").decode('unicode_escape')) + endpoint_info['docstring'] = str( + bytes(docstring, "utf-8").decode('unicode_escape')) else: - endpoint_info['docstring'] = docstring.decode('string_escape') + endpoint_info['docstring'] = docstring.decode( + 'string_escape') endpoints[endpoint_name] = endpoint_info return endpoints @@ -212,8 +221,10 @@ def add_endpoint(self, name, description=None, ''' try: endpoints = self.get_endpoints() - if name is None or not isinstance(name, (str, unicode)) or len(name) == 0: - raise ValueError("name of the endpoint must be a valid string.") + if name is None or not isinstance( + name, (str, unicode)) or len(name) == 0: + raise ValueError( + "name of the endpoint must be a valid string.") elif name in endpoints: raise ValueError("endpoint %s already exists." % name) if description and not isinstance(description, (str, unicode)): @@ -224,7 +235,8 @@ def add_endpoint(self, name, description=None, raise ValueError("docstring must be a string.") elif not docstring: docstring = '-- no docstring found in query function --' - if not endpoint_type or not isinstance(endpoint_type, (str, unicode)): + if not endpoint_type or not isinstance( + endpoint_type, (str, unicode)): raise ValueError("endpoint type must be a string.") if dependencies and not isinstance(dependencies, list): raise ValueError("dependencies must be a list.") @@ -253,12 +265,13 @@ def add_endpoint(self, name, description=None, def _add_update_endpoints_config(self, endpoints): # save the endpoint info to config - dstring='' + dstring = '' for endpoint_name in endpoints: try: info = endpoints[endpoint_name] if sys.version_info > (3, 0): - dstring = str(bytes(info['docstring'], "utf-8").decode('unicode_escape')) + dstring = str(bytes(info['docstring'], "utf-8").decode( + 'unicode_escape')) else: dstring = info['docstring'].decode('string_escape') self._set_config_value(_QUERY_OBJECT_DOCSTRING, @@ -395,7 +408,7 @@ def delete_endpoint(self, name): # check if other endpoints are depending on this endpoint if len(deps) > 0: raise ValueError("Cannot remove endpoint %s, it is currently " - "used by %s endpoints." % (name, list(deps))) + "used by %s endpoints." % (name, list(deps))) del endpoints[name] @@ -500,8 +513,10 @@ def get_access_control_allow_origin(self): ''' _cors_origin = '' try: - logger.debug("Collecting Access-Control-Allow-Origin from state file...") - _cors_origin = self._get_config_value('Service Info', 'Access-Control-Allow-Origin') + logger.debug("Collecting Access-Control-Allow-Origin from " + "state file...") + _cors_origin = self._get_config_value( + 'Service Info', 'Access-Control-Allow-Origin') except Exception as e: logger.error(e) pass @@ -513,8 +528,9 @@ def get_access_control_allow_headers(self): ''' _cors_headers = '' try: - _cors_headers = self._get_config_value('Service Info', 'Access-Control-Allow-Headers') - except Exception as e: + _cors_headers = self._get_config_value( + 'Service Info', 'Access-Control-Allow-Headers') + except Exception: pass return _cors_headers @@ -524,8 +540,9 @@ def get_access_control_allow_methods(self): ''' _cors_methods = '' try: - _cors_methods = self._get_config_value('Service Info', 'Access-Control-Allow-Methods') - except Exception as e: + _cors_methods = self._get_config_value( + 'Service Info', 'Access-Control-Allow-Methods') + except Exception: pass return _cors_methods @@ -583,8 +600,8 @@ def _get_config_items(self, section_name): raise ValueError("State configuration not yet loaded.") return self.config.items(section_name) - def _get_config_value(self, section_name, option_name, optional = False, - default_value = None): + def _get_config_value(self, section_name, option_name, optional=False, + default_value=None): if not self.config: raise ValueError("State configuration not yet loaded.") @@ -597,7 +614,7 @@ def _get_config_value(self, section_name, option_name, optional = False, return default_value else: raise ValueError("Cannot find option name %s under section %s" - % (option_name, section_name)) + % (option_name, section_name)) def _write_state(self): ''' diff --git a/tabpy-server/tabpy_server/management/util.py b/tabpy-server/tabpy_server/management/util.py index d2cc9fb2..bdc175d5 100644 --- a/tabpy-server/tabpy_server/management/util.py +++ b/tabpy-server/tabpy_server/management/util.py @@ -4,10 +4,6 @@ from ConfigParser import ConfigParser as _ConfigParser except ImportError: from configparser import ConfigParser as _ConfigParser -try: - from StringIO import StringIO as _StringIO -except ImportError: - from io import StringIO as _StringIO from datetime import datetime, timedelta, tzinfo from tabpy_server.app.ConfigParameters import ConfigParameters from tabpy_server.app.util import log_and_raise diff --git a/tabpy-server/tabpy_server/psws/callbacks.py b/tabpy-server/tabpy_server/psws/callbacks.py index c2b4eb3f..b3e35666 100644 --- a/tabpy-server/tabpy_server/psws/callbacks.py +++ b/tabpy-server/tabpy_server/psws/callbacks.py @@ -177,4 +177,5 @@ def on_state_change(settings, tabpy_state, python_service): except Exception as e: err_msg = format_exception(e, 'on_state_change') - logger.error("Error submitting update model request: error={}".format(err_msg)) + logger.error( + "Error submitting update model request: error={}".format(err_msg)) diff --git a/tabpy-server/tabpy_server/psws/python_service.py b/tabpy-server/tabpy_server/psws/python_service.py index e4bbe1ca..4fc86e6f 100644 --- a/tabpy-server/tabpy_server/psws/python_service.py +++ b/tabpy-server/tabpy_server/psws/python_service.py @@ -1,5 +1,4 @@ import concurrent.futures -import tabpy_tools import logging import sys @@ -75,8 +74,9 @@ def __init__(self, def _load_object(self, object_uri, object_url, object_version, is_update, object_type): try: - logger.info("Loading object:, URI={}, URL={}, version={}, is_updated={}".format( - object_uri, object_url,object_version, is_update)) + logger.info("Loading object:, URI={}, URL={}, version={}, " + "is_updated={}".format( + object_uri, object_url, object_version, is_update)) if object_type == 'model': po = QueryObject.load(object_url) elif object_type == 'alias': @@ -90,7 +90,8 @@ def _load_object(self, object_uri, object_url, object_version, is_update, 'status': 'LoadSuccessful', 'last_error': None} except Exception as e: - logger.error("Unable to load QueryObject: path={}, error={}".format(object_url, str(e))) + logger.error("Unable to load QueryObject: path={}, " + "error={}".format(object_url, str(e))) self.query_objects[object_uri] = { 'version': object_version, @@ -101,46 +102,47 @@ def _load_object(self, object_uri, object_url, object_version, is_update, def load_object(self, object_uri, object_url, object_version, is_update, object_type): - try: - obj_info = self.query_objects.get(object_uri) - if obj_info and obj_info['endpoint_obj'] and ( - obj_info['version'] >= object_version): - logger.info( - "Received load message for object already loaded") - - return DownloadSkipped( - object_uri, obj_info['version'], "Object with greater " - "or equal version already loaded") + try: + obj_info = self.query_objects.get(object_uri) + if obj_info and obj_info['endpoint_obj'] and ( + obj_info['version'] >= object_version): + logger.info( + "Received load message for object already loaded") + + return DownloadSkipped( + object_uri, obj_info['version'], "Object with greater " + "or equal version already loaded") + else: + if object_uri not in self.query_objects: + self.query_objects[object_uri] = { + 'version': object_version, + 'type': object_type, + 'endpoint_obj': None, + 'status': 'LoadInProgress', + 'last_error': None} else: - if object_uri not in self.query_objects: - self.query_objects[object_uri] = { - 'version': object_version, - 'type': object_type, - 'endpoint_obj': None, - 'status': 'LoadInProgress', - 'last_error': None} - else: - self.query_objects[ - object_uri]['status'] = 'LoadInProgress' - - self.EXECUTOR.submit( - self._load_object, object_uri, object_url, - object_version, is_update, object_type) - - return LoadInProgress( - object_uri, object_url, object_version, is_update, - object_type) - except Exception as e: - logger.error("Unable to load QueryObject: path={}, error={}".format(object_url, str(e))) - - self.query_objects[object_uri] = { - 'version': object_version, - 'type': object_type, - 'endpoint_obj': None, - 'status': 'LoadFailed', - 'last_error': str(e)} - - return LoadFailed(object_uri, object_version, str(e)) + self.query_objects[ + object_uri]['status'] = 'LoadInProgress' + + self.EXECUTOR.submit( + self._load_object, object_uri, object_url, + object_version, is_update, object_type) + + return LoadInProgress( + object_uri, object_url, object_version, is_update, + object_type) + except Exception as e: + logger.error("Unable to load QueryObject: path={}, " + "error={}".format(object_url, str(e))) + + self.query_objects[object_uri] = { + 'version': object_version, + 'type': object_type, + 'endpoint_obj': None, + 'status': 'LoadFailed', + 'last_error': str(e)} + + return LoadFailed(object_uri, object_version, str(e)) def delete_objects(self, object_uris): """Delete one or more objects from the query_objects map""" @@ -155,12 +157,14 @@ def delete_objects(self, object_uris): return ObjectsDeleted([object_uris]) else: logger.warning("Received message to delete query object " - "that doesn't exist: object_uris={}".format(object_uris)) + "that doesn't exist: object_uris={}".format( + object_uris)) return ObjectsDeleted([]) else: - logger.error("Unexpected input to delete objects: input={}, info={}".format( - object_uris, - "Input should be list or str. Type: %s" % type(object_uris))) + logger.error( + "Unexpected input to delete objects: input={}, info={}".format( + object_uris, "Input should be list or str. " + "Type: %s" % type(object_uris))) return ObjectsDeleted([]) def flush_objects(self):