From 2364eb05293a309e221e319ddc93f4959d9737a3 Mon Sep 17 00:00:00 2001 From: Oleksandr Golovatyi Date: Thu, 28 Mar 2019 14:08:26 -0700 Subject: [PATCH 1/6] Cherry-pick flake8 code improvements from master --- tabpy-server/server_tests/test_config.py | 141 +++ .../tabpy_server/common/endpoint_file_mgr.py | 18 +- tabpy-server/tabpy_server/common/messages.py | 52 +- tabpy-server/tabpy_server/common/util.py | 5 +- tabpy-server/tabpy_server/management/state.py | 63 +- tabpy-server/tabpy_server/management/util.py | 16 + tabpy-server/tabpy_server/psws/callbacks.py | 3 +- .../tabpy_server/psws/python_service.py | 98 +- tabpy-server/tabpy_server/tabpy.py | 980 ++++++++++++++++++ 9 files changed, 1287 insertions(+), 89 deletions(-) diff --git a/tabpy-server/server_tests/test_config.py b/tabpy-server/server_tests/test_config.py index cee5a806..0618043c 100644 --- a/tabpy-server/server_tests/test_config.py +++ b/tabpy-server/server_tests/test_config.py @@ -1,14 +1,18 @@ import os import unittest +<<<<<<< HEAD from argparse import Namespace from tempfile import NamedTemporaryFile from tabpy_server.app.util import validate_cert from tabpy_server.app.app import TabPyApp +======= +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting from unittest.mock import patch, call + def assert_raises_runtime_error(message, fn, args={}): try: fn(*args) @@ -17,6 +21,7 @@ def assert_raises_runtime_error(message, fn, args={}): assert err.args[0] == message +<<<<<<< HEAD class TestConfigEnvironmentCalls(unittest.TestCase): @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', return_value=Namespace(config=None)) @@ -31,6 +36,44 @@ def test_no_config_file(self, mock_os, mock_file_exists, mock_management_util, mock_tabpy_state, mock_parse_arguments): TabPyApp(None) +======= +def append_logger_settings_to_config_file(config_file): + config_file.write("[loggers]\n" + "keys=root\n" + "[handlers]\n" + "keys=rotatingFileHandler\n" + "[formatters]\n" + "keys=rootFormatter\n" + "[logger_root]\n" + "level=ERROR\n" + "handlers=rotatingFileHandler\n" + "qualname=root\n" + "propagete=0\n" + "[handler_rotatingFileHandler]\n" + "class=handlers.RotatingFileHandler\n" + "level=ERROR\n" + "formatter=rootFormatter\n" + "args=('tabpy_server_tests_log.log', 'w', 1000000, 5)\n" + "[formatter_rootFormatter]\n" + "format=%(asctime)s [%(levelname)s] (%(filename)s:" + "%(module)s:%(lineno)d): %(message)s\n" + "datefmt=%Y-%m-%d,%H:%M:%S\n".encode()) + + +class TestConfigEnvironmentCalls(unittest.TestCase): + + @patch('tabpy_server.tabpy.TabPyState') + @patch('tabpy_server.tabpy._get_state_from_file') + @patch('tabpy_server.tabpy.shutil') + @patch('tabpy_server.tabpy.PythonServiceHandler') + @patch('tabpy_server.tabpy.os.path.exists', return_value=True) + @patch('tabpy_server.tabpy.os.path.isfile', return_value=False) + @patch('tabpy_server.tabpy.os') + def test_no_config_file(self, mock_os, mock_file_exists, mock_path_exists, + mock_psws, mock_shutil, mock_management_util, + mock_tabpy_state): + get_config(None) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting getenv_calls = [call('TABPY_PORT', 9004), call('TABPY_QUERY_OBJECT_PATH', '/tmp/query_objects'), @@ -43,6 +86,7 @@ def test_no_config_file(self, mock_os, mock_file_exists, self.assertTrue(len(mock_management_util.mock_calls) > 0) mock_os.makedirs.assert_not_called() +<<<<<<< HEAD @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', return_value=Namespace(config=None)) @patch('tabpy_server.app.app.TabPyState') @@ -57,10 +101,25 @@ def test_no_state_ini_file_or_state_dir(self, mock_os, mock_file_exists, mock_tabpy_state, mock_parse_arguments): TabPyApp(None) +======= + @patch('tabpy_server.tabpy.TabPyState') + @patch('tabpy_server.tabpy._get_state_from_file') + @patch('tabpy_server.tabpy.shutil') + @patch('tabpy_server.tabpy.PythonServiceHandler') + @patch('tabpy_server.tabpy.os.path.exists', return_value=False) + @patch('tabpy_server.tabpy.os.path.isfile', return_value=False) + @patch('tabpy_server.tabpy.os') + def test_no_state_ini_file_or_state_dir(self, mock_os, mock_file_exists, + mock_path_exists, mock_psws, + mock_shutil, mock_management_util, + mock_tabpy_state): + get_config(None) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting self.assertEqual(len(mock_os.makedirs.mock_calls), 1) class TestPartialConfigFile(unittest.TestCase): +<<<<<<< HEAD @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments') @patch('tabpy_server.app.app.TabPyState') @patch('tabpy_server.app.app._get_state_from_file') @@ -69,6 +128,18 @@ class TestPartialConfigFile(unittest.TestCase): @patch('tabpy_server.app.app.os') def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, mock_management_util, +======= + + @patch('tabpy_server.tabpy.parse_arguments') + @patch('tabpy_server.tabpy.TabPyState') + @patch('tabpy_server.tabpy._get_state_from_file') + @patch('tabpy_server.tabpy.shutil') + @patch('tabpy_server.tabpy.PythonServiceHandler') + @patch('tabpy_server.tabpy.os.path.exists', return_value=True) + @patch('tabpy_server.tabpy.os') + def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, + mock_shutil, mock_management_util, +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting mock_tabpy_state, mock_parse_arguments): config_file = NamedTemporaryFile(delete=False) @@ -77,7 +148,12 @@ def test_config_file_present(self, mock_os, mock_path_exists, "TABPY_STATE_PATH = bar\n".encode()) config_file.close() +<<<<<<< HEAD mock_parse_arguments.return_value = Namespace(config=config_file.name) +======= + mock_parse_arguments.return_value = Namespace( + config=config_file.name, port=None) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting mock_os.getenv.side_effect = [1234] mock_os.path.realpath.return_value = 'bar' @@ -86,6 +162,7 @@ def test_config_file_present(self, mock_os, mock_path_exists, getenv_calls = [call('TABPY_PORT', 9004)] mock_os.getenv.assert_has_calls(getenv_calls, any_order=True) +<<<<<<< HEAD self.assertEqual(app.settings['port'], 1234) self.assertEqual(app.settings['server_version'], open( 'VERSION').read().strip()) @@ -94,6 +171,17 @@ def test_config_file_present(self, mock_os, mock_path_exists, self.assertEqual(app.settings['transfer_protocol'], 'http') self.assertTrue('certificate_file' not in app.settings) self.assertTrue('key_file' not in app.settings) +======= + self.assertEqual(settings['port'], 1234) + self.assertEqual(settings['server_version'], + open('VERSION').read().strip()) + self.assertEquals(settings['bind_ip'], '0.0.0.0') + self.assertEquals(settings['upload_dir'], 'foo') + self.assertEquals(settings['state_file_path'], 'bar') + self.assertEqual(settings['transfer_protocol'], 'http') + self.assertTrue('certificate_file' not in settings) + self.assertTrue('key_file' not in settings) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting os.remove(config_file.name) @@ -117,7 +205,18 @@ def __init__(self, *args, **kwargs): def setUp(self): os.chdir(self.tabpy_cwd) +<<<<<<< HEAD self.fp = NamedTemporaryFile(mode='w+t', delete=False) +======= + self.fp = NamedTemporaryFile(delete=False) + self.config_name = self.fp.name + append_logger_settings_to_config_file(self.fp) + + patcher = patch('tabpy_server.tabpy.parse_arguments', + return_value=Namespace(config=self.fp.name, port=None)) + patcher.start() + self.addCleanup(patcher.stop) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting def tearDown(self): os.chdir(self.cwd) @@ -139,8 +238,12 @@ def test_https_without_cert_and_key(self): assert_raises_runtime_error( 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' +<<<<<<< HEAD 'and TABPY_KEY_FILE must be set.', TabPyApp, {self.fp.name}) +======= + 'and TABPY_KEY_FILE must be set.', get_config, {self.config_name}) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting def test_https_without_cert(self): self.fp.write( @@ -151,7 +254,11 @@ def test_https_without_cert(self): assert_raises_runtime_error('Error using HTTPS: The parameter(s) ' 'TABPY_CERTIFICATE_FILE must be set.', +<<<<<<< HEAD TabPyApp, {self.fp.name}) +======= + get_config, [self.config_name]) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting def test_https_without_key(self): self.fp.write("[TabPy]\n" @@ -161,7 +268,11 @@ def test_https_without_key(self): assert_raises_runtime_error('Error using HTTPS: The parameter(s) ' 'TABPY_KEY_FILE must be set.', +<<<<<<< HEAD TabPyApp, {self.fp.name}) +======= + get_config, [self.config_name]) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path') def test_https_cert_and_key_file_not_found(self, mock_path): @@ -171,6 +282,7 @@ def test_https_cert_and_key_file_not_found(self, mock_path): "TABPY_KEY_FILE = bar") self.fp.close() +<<<<<<< HEAD mock_path.isfile.side_effect =\ lambda x: self.mock_isfile(x, {self.fp.name}) @@ -178,6 +290,15 @@ def test_https_cert_and_key_file_not_found(self, mock_path): 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' 'and TABPY_KEY_FILE must point to an existing file.', TabPyApp, {self.fp.name}) +======= + mock_path.isfile.side_effect = lambda x: self.mock_isfile( + x, {self.fp.name}) + + assert_raises_runtime_error( + 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE and ' + 'TABPY_KEY_FILE must point to an existing file.', + get_config, {self.config_name}) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path') def test_https_cert_file_not_found(self, mock_path): @@ -191,9 +312,14 @@ def test_https_cert_file_not_found(self, mock_path): x, {self.fp.name, 'bar'}) assert_raises_runtime_error( +<<<<<<< HEAD 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' 'must point to an existing file.', TabPyApp, {self.fp.name}) +======= + 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE must ' + 'point to an existing file.', get_config, {self.config_name}) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path') def test_https_key_file_not_found(self, mock_path): @@ -207,8 +333,13 @@ def test_https_key_file_not_found(self, mock_path): x, {self.fp.name, 'foo'}) assert_raises_runtime_error( +<<<<<<< HEAD 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE must ' 'point to an existing file.', TabPyApp, {self.fp.name}) +======= + 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE must point to ' + 'an existing file.', get_config, {self.config_name}) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path.isfile', return_value=True) @patch('tabpy_server.app.util.validate_cert') @@ -235,14 +366,24 @@ def __init__(self, *args, **kwargs): def test_expired_cert(self): path = os.path.join(self.resources_path, 'expired.crt') +<<<<<<< HEAD message = 'Error using HTTPS: The certificate provided expired '\ 'on 2018-08-18 19:47:18.' +======= + message = ('Error using HTTPS: The certificate provided expired ' + 'on 2018-08-18 19:47:18.') +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting assert_raises_runtime_error(message, validate_cert, {path}) def test_future_cert(self): path = os.path.join(self.resources_path, 'future.crt') +<<<<<<< HEAD message = 'Error using HTTPS: The certificate provided is not '\ 'valid until 3001-01-01 00:00:00.' +======= + message = ('Error using HTTPS: The certificate provided is not valid ' + 'until 3001-01-01 00:00:00.') +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting assert_raises_runtime_error(message, validate_cert, {path}) def test_valid_cert(self): diff --git a/tabpy-server/tabpy_server/common/endpoint_file_mgr.py b/tabpy-server/tabpy_server/common/endpoint_file_mgr.py index b5b38286..b49bac26 100644 --- a/tabpy-server/tabpy_server/common/endpoint_file_mgr.py +++ b/tabpy-server/tabpy_server/common/endpoint_file_mgr.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -_name_checker = _compile('^[a-zA-Z0-9-_\ ]+$') +_name_checker = _compile(r'^[a-zA-Z0-9-_\ ]+$') def _check_endpoint_name(name): @@ -28,8 +28,13 @@ def _check_endpoint_name(name): log_and_raise("Endpoint name cannot be empty", ValueError) if not _name_checker.match(name): +<<<<<<< HEAD log_and_raise('Endpoint name can only contain: a-z, A-Z, 0-9,' ' underscore, hyphens and spaces.', ValueError) +======= + raise ValueError('Endpoint name can only contain: a-z, A-Z, 0-9,' + ' underscore, hyphens and spaces.') +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting def grab_files(directory): @@ -48,12 +53,13 @@ def grab_files(directory): elif os.path.isfile(full_path): yield full_path + def get_local_endpoint_file_path(name, version, query_path): _check_endpoint_name(name) return os.path.join(query_path, name, str(version)) -def cleanup_endpoint_files(name, query_path, retain_versions = None): +def cleanup_endpoint_files(name, query_path, retain_versions=None): ''' Cleanup the disk space a certain endpiont uses. @@ -63,8 +69,9 @@ def cleanup_endpoint_files(name, query_path, retain_versions = None): The endpoint name retain_version : int, optional - If given, then all files for this endpoint are removed except the folder - for the given version, otherwise, all files for that endpoint are removed. + If given, then all files for this endpoint are removed except the + folder for the given version, otherwise, all files for that endpoint + are removed. ''' _check_endpoint_name(name) local_dir = os.path.join(query_path, name) @@ -84,5 +91,6 @@ def cleanup_endpoint_files(name, query_path, retain_versions = None): for file_or_dir in os.listdir(local_dir): candidate_dir = os.path.join(local_dir, file_or_dir) - if os.path.isdir(candidate_dir) and candidate_dir not in retain_folders: + if os.path.isdir(candidate_dir) and ( + candidate_dir not in retain_folders): shutil.rmtree(candidate_dir) diff --git a/tabpy-server/tabpy_server/common/messages.py b/tabpy-server/tabpy_server/common/messages.py index 76b1006e..c8c95fe1 100644 --- a/tabpy-server/tabpy_server/common/messages.py +++ b/tabpy-server/tabpy_server/common/messages.py @@ -3,6 +3,7 @@ from abc import ABCMeta from collections import namedtuple + class Msg(object): """ An abstract base class for all messages used for communicating between @@ -36,98 +37,125 @@ def from_json(str): return eval(type_str)(**d) -class LoadSuccessful(namedtuple( - 'LoadSuccessful', ['uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): +class LoadSuccessful(namedtuple('LoadSuccessful', [ + 'uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): __slots__ = () -class LoadFailed(namedtuple('LoadFailed', ['uri', 'version', 'error_msg']), Msg): + +class LoadFailed(namedtuple('LoadFailed', [ + 'uri', 'version', 'error_msg']), Msg): __slots__ = () -class LoadInProgress(namedtuple( - 'LoadInProgress', ['uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): + +class LoadInProgress(namedtuple('LoadInProgress', [ + 'uri', 'path', 'version', 'is_update', 'endpoint_type']), Msg): __slots__ = () + class Query(namedtuple('Query', ['uri', 'params']), Msg): __slots__ = () + class QuerySuccessful(namedtuple( 'QuerySuccessful', ['uri', 'version', 'response']), Msg): __slots__ = () -class LoadObject(namedtuple( - 'LoadObject', ['uri', 'url', 'version', 'is_update', 'endpoint_type']), Msg): + +class LoadObject(namedtuple('LoadObject', [ + 'uri', 'url', 'version', 'is_update', 'endpoint_type']), Msg): __slots__ = () + class DeleteObjects(namedtuple('DeleteObjects', ['uris']), Msg): __slots__ = () + # Used for testing to flush out objects class FlushObjects(namedtuple('FlushObjects', []), Msg): __slots__ = () + class ObjectsDeleted(namedtuple('ObjectsDeleted', ['uris']), Msg): __slots__ = () + class ObjectsFlushed(namedtuple( 'ObjectsFlushed', ['n_before', 'n_after']), Msg): __slots__ = () + class CountObjects(namedtuple('CountObjects', []), Msg): __slots__ = () + class ObjectCount(namedtuple('ObjectCount', ['count']), Msg): __slots__ = () + class ListObjects(namedtuple('ListObjects', []), Msg): __slots__ = () + class ObjectList(namedtuple('ObjectList', ['objects']), Msg): __slots__ = () + class UnknownURI(namedtuple('UnknownURI', ['uri']), Msg): __slots__ = () + class UnknownMessage(namedtuple('UnknownMessage', ['msg']), Msg): __slots__ = () -class DownloadSkipped(namedtuple('DownloadSkipped', ['uri', 'version', 'msg','host']), - Msg): + +class DownloadSkipped(namedtuple('DownloadSkipped', [ + 'uri', 'version', 'msg', 'host']), Msg): __slots__ = () + class QueryFailed(namedtuple('QueryFailed', ['uri', 'error']), Msg): __slots__ = () + class QueryError(namedtuple('QueryError', ['uri', 'error']), Msg): __slots__ = () + class CheckHealth(namedtuple('CheckHealth', []), Msg): __slots__ = () + class Healthy(namedtuple('Healthy', []), Msg): __slots__ = () + class Unhealthy(namedtuple('Unhealthy', []), Msg): __slots__ = () + class Ping(namedtuple('Ping', ['id']), Msg): __slots__ = () + class Pong(namedtuple('Pong', ['id']), Msg): __slots__ = () + class Listening(namedtuple('Listening', []), Msg): __slots__ = () + class EngineFailure(namedtuple('EngineFailure', ['error']), Msg): __slots__ = () + class FlushLogs(namedtuple('FlushLogs', []), Msg): __slots__ = () + class LogsFlushed(namedtuple('LogsFlushed', []), Msg): __slots__ = () -class ServiceError(namedtuple( - 'ServiceError', ['error']), Msg): - __slots__ = () +class ServiceError(namedtuple('ServiceError', ['error']), Msg): + __slots__ = () diff --git a/tabpy-server/tabpy_server/common/util.py b/tabpy-server/tabpy_server/common/util.py index 8a8d7ea7..b90d3023 100644 --- a/tabpy-server/tabpy_server/common/util.py +++ b/tabpy-server/tabpy_server/common/util.py @@ -1,10 +1,13 @@ import traceback + + def format_exception(e, context): err_msg = "%s : " % e.__class__.__name__ err_msg += "%s" % str(e) return err_msg -def format_exception_DEBUG(e, context,detail=0): + +def format_exception_DEBUG(e, context, detail=0): trace = traceback.format_exc() err_msg = "Traceback\n %s\n" % trace err_msg += "Error type : %s\n" % e.__class__.__name__ diff --git a/tabpy-server/tabpy_server/management/state.py b/tabpy-server/tabpy_server/management/state.py index 58b2ba8d..a0f5b84b 100644 --- a/tabpy-server/tabpy_server/management/state.py +++ b/tabpy-server/tabpy_server/management/state.py @@ -33,6 +33,8 @@ Lock to change the TabPy State. ''' _PS_STATE_LOCK = Lock() + + def state_lock(func): ''' Mutex for changing PS state @@ -46,6 +48,7 @@ def wrapper(self, *args, **kwargs): _PS_STATE_LOCK.release() return wrapper + def load_state_from_str(state_string): ''' Convert from String to ConfigParser @@ -61,6 +64,7 @@ def load_state_from_str(state_string): else: log_and_raise("State string is empty!", ValueError) + def save_state_to_str(config): ''' Convert from ConfigParser to String @@ -72,18 +76,20 @@ def save_state_to_str(config): string_f = StringIO() config.write(string_f) value = string_f.getvalue() - except: + except Exception: logger.error("Cannot convert config to string") finally: string_f.close() return value + def _get_root_path(state_path): if state_path[-1] != '/': return state_path + '/' else: return state_path + def get_query_object_path(state_file_path, name, version): ''' Returns the query object path @@ -99,6 +105,7 @@ def get_query_object_path(state_file_path, name, version): '/'.join([_QUERY_OBJECT_DIR, name]) return full_path + class TabPyState(object): ''' The TabPy state object that stores attributes @@ -165,22 +172,24 @@ def get_endpoints(self, name=None): if name: endpoint_info = simplejson.loads(endpoint_names) docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, name) - if sys.version_info > (3,0): - endpoint_info['docstring'] = str(bytes(docstring,"utf-8").decode('unicode_escape')) + if sys.version_info > (3, 0): + endpoint_info['docstring'] = str( + bytes(docstring, "utf-8").decode('unicode_escape')) else: endpoint_info['docstring'] = docstring.decode('string_escape') endpoints = {name: endpoint_info} else: for endpoint_name in endpoint_names: endpoint_info = simplejson.loads(self._get_config_value( - _DEPLOYMENT_SECTION_NAME, - endpoint_name)) + _DEPLOYMENT_SECTION_NAME, endpoint_name)) docstring = self._get_config_value(_QUERY_OBJECT_DOCSTRING, - endpoint_name, True, '') + endpoint_name, True, '') if sys.version_info > (3, 0): - endpoint_info['docstring'] = str(bytes(docstring, "utf-8").decode('unicode_escape')) + endpoint_info['docstring'] = str( + bytes(docstring, "utf-8").decode('unicode_escape')) else: - endpoint_info['docstring'] = docstring.decode('string_escape') + endpoint_info['docstring'] = docstring.decode( + 'string_escape') endpoints[endpoint_name] = endpoint_info return endpoints @@ -212,8 +221,10 @@ def add_endpoint(self, name, description=None, ''' try: endpoints = self.get_endpoints() - if name is None or not isinstance(name, (str, unicode)) or len(name) == 0: - raise ValueError("name of the endpoint must be a valid string.") + if name is None or not isinstance( + name, (str, unicode)) or len(name) == 0: + raise ValueError( + "name of the endpoint must be a valid string.") elif name in endpoints: raise ValueError("endpoint %s already exists." % name) if description and not isinstance(description, (str, unicode)): @@ -224,7 +235,8 @@ def add_endpoint(self, name, description=None, raise ValueError("docstring must be a string.") elif not docstring: docstring = '-- no docstring found in query function --' - if not endpoint_type or not isinstance(endpoint_type, (str, unicode)): + if not endpoint_type or not isinstance( + endpoint_type, (str, unicode)): raise ValueError("endpoint type must be a string.") if dependencies and not isinstance(dependencies, list): raise ValueError("dependencies must be a list.") @@ -253,12 +265,13 @@ def add_endpoint(self, name, description=None, def _add_update_endpoints_config(self, endpoints): # save the endpoint info to config - dstring='' + dstring = '' for endpoint_name in endpoints: try: info = endpoints[endpoint_name] if sys.version_info > (3, 0): - dstring = str(bytes(info['docstring'], "utf-8").decode('unicode_escape')) + dstring = str(bytes(info['docstring'], "utf-8").decode( + 'unicode_escape')) else: dstring = info['docstring'].decode('string_escape') self._set_config_value(_QUERY_OBJECT_DOCSTRING, @@ -395,7 +408,7 @@ def delete_endpoint(self, name): # check if other endpoints are depending on this endpoint if len(deps) > 0: raise ValueError("Cannot remove endpoint %s, it is currently " - "used by %s endpoints." % (name, list(deps))) + "used by %s endpoints." % (name, list(deps))) del endpoints[name] @@ -500,8 +513,10 @@ def get_access_control_allow_origin(self): ''' _cors_origin = '' try: - logger.debug("Collecting Access-Control-Allow-Origin from state file...") - _cors_origin = self._get_config_value('Service Info', 'Access-Control-Allow-Origin') + logger.debug("Collecting Access-Control-Allow-Origin from " + "state file...") + _cors_origin = self._get_config_value( + 'Service Info', 'Access-Control-Allow-Origin') except Exception as e: logger.error(e) pass @@ -513,8 +528,9 @@ def get_access_control_allow_headers(self): ''' _cors_headers = '' try: - _cors_headers = self._get_config_value('Service Info', 'Access-Control-Allow-Headers') - except Exception as e: + _cors_headers = self._get_config_value( + 'Service Info', 'Access-Control-Allow-Headers') + except Exception: pass return _cors_headers @@ -524,8 +540,9 @@ def get_access_control_allow_methods(self): ''' _cors_methods = '' try: - _cors_methods = self._get_config_value('Service Info', 'Access-Control-Allow-Methods') - except Exception as e: + _cors_methods = self._get_config_value( + 'Service Info', 'Access-Control-Allow-Methods') + except Exception: pass return _cors_methods @@ -583,8 +600,8 @@ def _get_config_items(self, section_name): raise ValueError("State configuration not yet loaded.") return self.config.items(section_name) - def _get_config_value(self, section_name, option_name, optional = False, - default_value = None): + def _get_config_value(self, section_name, option_name, optional=False, + default_value=None): if not self.config: raise ValueError("State configuration not yet loaded.") @@ -597,7 +614,7 @@ def _get_config_value(self, section_name, option_name, optional = False, return default_value else: raise ValueError("Cannot find option name %s under section %s" - % (option_name, section_name)) + % (option_name, section_name)) def _write_state(self): ''' diff --git a/tabpy-server/tabpy_server/management/util.py b/tabpy-server/tabpy_server/management/util.py index d2cc9fb2..f63cbefa 100644 --- a/tabpy-server/tabpy_server/management/util.py +++ b/tabpy-server/tabpy_server/management/util.py @@ -4,10 +4,13 @@ from ConfigParser import ConfigParser as _ConfigParser except ImportError: from configparser import ConfigParser as _ConfigParser +<<<<<<< HEAD try: from StringIO import StringIO as _StringIO except ImportError: from io import StringIO as _StringIO +======= +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting from datetime import datetime, timedelta, tzinfo from tabpy_server.app.ConfigParameters import ConfigParameters from tabpy_server.app.util import log_and_raise @@ -20,10 +23,14 @@ def write_state_config(state, settings): if 'state_file_path' in settings: state_path = settings['state_file_path'] else: +<<<<<<< HEAD log_and_raise( '{} is not set'.format( ConfigParameters.TABPY_STATE_PATH), ValueError) +======= + raise ValueError('TABPY_STATE_PATH is not set') +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting logger.debug("State path is {}".format(state_path)) state_key = os.path.join(state_path, 'state.ini') @@ -47,9 +54,14 @@ def _get_state_from_file(state_path): config.read(tmp_state_file) if not config.has_section('Service Info'): +<<<<<<< HEAD log_and_raise( "Config error: Expected 'Service Info' section in %s" % (tmp_state_file,), ValueError) +======= + raise ValueError("Config error: Expected 'Service Info' " + "section in %s" % (tmp_state_file,)) +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting return config @@ -96,4 +108,8 @@ def _dt_to_utc_timestamp(t): elif not t.tzinfo: return mktime(t.timetuple()) else: +<<<<<<< HEAD log_and_raise('Only local time and UTC time is supported', ValueError) +======= + raise ValueError('Only local time and UTC time is supported') +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting diff --git a/tabpy-server/tabpy_server/psws/callbacks.py b/tabpy-server/tabpy_server/psws/callbacks.py index c2b4eb3f..b3e35666 100644 --- a/tabpy-server/tabpy_server/psws/callbacks.py +++ b/tabpy-server/tabpy_server/psws/callbacks.py @@ -177,4 +177,5 @@ def on_state_change(settings, tabpy_state, python_service): except Exception as e: err_msg = format_exception(e, 'on_state_change') - logger.error("Error submitting update model request: error={}".format(err_msg)) + logger.error( + "Error submitting update model request: error={}".format(err_msg)) diff --git a/tabpy-server/tabpy_server/psws/python_service.py b/tabpy-server/tabpy_server/psws/python_service.py index e4bbe1ca..4fc86e6f 100644 --- a/tabpy-server/tabpy_server/psws/python_service.py +++ b/tabpy-server/tabpy_server/psws/python_service.py @@ -1,5 +1,4 @@ import concurrent.futures -import tabpy_tools import logging import sys @@ -75,8 +74,9 @@ def __init__(self, def _load_object(self, object_uri, object_url, object_version, is_update, object_type): try: - logger.info("Loading object:, URI={}, URL={}, version={}, is_updated={}".format( - object_uri, object_url,object_version, is_update)) + logger.info("Loading object:, URI={}, URL={}, version={}, " + "is_updated={}".format( + object_uri, object_url, object_version, is_update)) if object_type == 'model': po = QueryObject.load(object_url) elif object_type == 'alias': @@ -90,7 +90,8 @@ def _load_object(self, object_uri, object_url, object_version, is_update, 'status': 'LoadSuccessful', 'last_error': None} except Exception as e: - logger.error("Unable to load QueryObject: path={}, error={}".format(object_url, str(e))) + logger.error("Unable to load QueryObject: path={}, " + "error={}".format(object_url, str(e))) self.query_objects[object_uri] = { 'version': object_version, @@ -101,46 +102,47 @@ def _load_object(self, object_uri, object_url, object_version, is_update, def load_object(self, object_uri, object_url, object_version, is_update, object_type): - try: - obj_info = self.query_objects.get(object_uri) - if obj_info and obj_info['endpoint_obj'] and ( - obj_info['version'] >= object_version): - logger.info( - "Received load message for object already loaded") - - return DownloadSkipped( - object_uri, obj_info['version'], "Object with greater " - "or equal version already loaded") + try: + obj_info = self.query_objects.get(object_uri) + if obj_info and obj_info['endpoint_obj'] and ( + obj_info['version'] >= object_version): + logger.info( + "Received load message for object already loaded") + + return DownloadSkipped( + object_uri, obj_info['version'], "Object with greater " + "or equal version already loaded") + else: + if object_uri not in self.query_objects: + self.query_objects[object_uri] = { + 'version': object_version, + 'type': object_type, + 'endpoint_obj': None, + 'status': 'LoadInProgress', + 'last_error': None} else: - if object_uri not in self.query_objects: - self.query_objects[object_uri] = { - 'version': object_version, - 'type': object_type, - 'endpoint_obj': None, - 'status': 'LoadInProgress', - 'last_error': None} - else: - self.query_objects[ - object_uri]['status'] = 'LoadInProgress' - - self.EXECUTOR.submit( - self._load_object, object_uri, object_url, - object_version, is_update, object_type) - - return LoadInProgress( - object_uri, object_url, object_version, is_update, - object_type) - except Exception as e: - logger.error("Unable to load QueryObject: path={}, error={}".format(object_url, str(e))) - - self.query_objects[object_uri] = { - 'version': object_version, - 'type': object_type, - 'endpoint_obj': None, - 'status': 'LoadFailed', - 'last_error': str(e)} - - return LoadFailed(object_uri, object_version, str(e)) + self.query_objects[ + object_uri]['status'] = 'LoadInProgress' + + self.EXECUTOR.submit( + self._load_object, object_uri, object_url, + object_version, is_update, object_type) + + return LoadInProgress( + object_uri, object_url, object_version, is_update, + object_type) + except Exception as e: + logger.error("Unable to load QueryObject: path={}, " + "error={}".format(object_url, str(e))) + + self.query_objects[object_uri] = { + 'version': object_version, + 'type': object_type, + 'endpoint_obj': None, + 'status': 'LoadFailed', + 'last_error': str(e)} + + return LoadFailed(object_uri, object_version, str(e)) def delete_objects(self, object_uris): """Delete one or more objects from the query_objects map""" @@ -155,12 +157,14 @@ def delete_objects(self, object_uris): return ObjectsDeleted([object_uris]) else: logger.warning("Received message to delete query object " - "that doesn't exist: object_uris={}".format(object_uris)) + "that doesn't exist: object_uris={}".format( + object_uris)) return ObjectsDeleted([]) else: - logger.error("Unexpected input to delete objects: input={}, info={}".format( - object_uris, - "Input should be list or str. Type: %s" % type(object_uris))) + logger.error( + "Unexpected input to delete objects: input={}, info={}".format( + object_uris, "Input should be list or str. " + "Type: %s" % type(object_uris))) return ObjectsDeleted([]) def flush_objects(self): diff --git a/tabpy-server/tabpy_server/tabpy.py b/tabpy-server/tabpy_server/tabpy.py index f7936c5a..90b3b5ae 100644 --- a/tabpy-server/tabpy_server/tabpy.py +++ b/tabpy-server/tabpy_server/tabpy.py @@ -1,3 +1,4 @@ +<<<<<<< HEAD from tabpy_server import __version__ from tabpy_server.app.app import TabPyApp @@ -5,6 +6,985 @@ def main(): app = TabPyApp() app.run() +======= +from argparse import ArgumentParser +import concurrent.futures +import configparser +from datetime import datetime +from hashlib import md5 +import logging +import logging.config +import multiprocessing +from OpenSSL import crypto +import os +from re import compile as _compile +import requests +import shutil +import simplejson +import sys +from tabpy_server import __version__ +from tabpy_server.psws.python_service import PythonService +from tabpy_server.psws.python_service import PythonServiceHandler +from tabpy_server.common.util import format_exception +from tabpy_server.common.messages import ( + Query, QuerySuccessful, QueryError, UnknownURI) +from tabpy_server.psws.callbacks import ( + init_ps_server, init_model_evaluator, on_state_change) +from tabpy_server.management.util import _get_state_from_file +from tabpy_server.management.state import TabPyState, get_query_object_path +import time +import tornado +import tornado.options +import tornado.web +import tornado.ioloop +from tornado import gen +from tornado_json.constants import TORNADO_MAJOR +from uuid import uuid4 as random_uuid +import urllib +import uuid + + +STAGING_THREAD = concurrent.futures.ThreadPoolExecutor(max_workers=3) +_QUERY_OBJECT_STAGING_FOLDER = 'staging' + +if sys.version_info.major == 3: + unicode = str + + +def parse_arguments(): + ''' + Parse input arguments and return the parsed arguments. Expected arguments: + * --port : int + ''' + parser = ArgumentParser(description='Run Python27 Service.') + parser.add_argument('--port', type=int, + help='Listening port for this service.') + parser.add_argument('--config', help='Path to a config file.') + return parser.parse_args() + + +cli_args = parse_arguments() +config_file = (cli_args.config if cli_args.config is not None else + os.path.join(os.path.dirname(__file__), 'common', + 'default.conf')) +loggingConfigured = False + + +if os.path.isfile(config_file): + try: + logging.config.fileConfig(config_file, disable_existing_loggers=False) + loggingConfigured = True + except Exception: + pass + + +if not loggingConfigured: + logging.basicConfig(level=logging.DEBUG) + + +logger = logging.getLogger(__name__) + + +def copy_from_local(localpath, remotepath, is_dir=False): + if is_dir: + if not os.path.exists(remotepath): + # remote folder does not exist + shutil.copytree(localpath, remotepath) + else: + # remote folder exists, copy each file + src_files = os.listdir(localpath) + for file_name in src_files: + full_file_name = os.path.join(localpath, file_name) + if os.path.isdir(full_file_name): + # copy folder recursively + full_remote_path = os.path.join(remotepath, file_name) + shutil.copytree(full_file_name, full_remote_path) + else: + # copy each file + shutil.copy(full_file_name, remotepath) + else: + shutil.copy(localpath, remotepath) + + +def _sanitize_request_data(data): + if not isinstance(data, dict): + raise RuntimeError("Expect input data to be a dictionary") + + if "method" in data: + return {"data": data.get("data"), "method": data.get("method")} + elif "data" in data: + return data.get("data") + else: + raise RuntimeError("Expect input data is a dictionary with at least a " + "key called 'data'") + + +def _get_uuid(): + """Generate a unique identifier string""" + return str(uuid.uuid4()) + + +class BaseHandler(tornado.web.RequestHandler): + KEYS_TO_SANITIZE = ("api key", "api_key", "admin key", "admin_key") + + def initialize(self): + self.tabpy = self.settings['tabpy'] + # set content type to application/json + self.set_header("Content-Type", "application/json") + self.port = self.settings['port'] + self.py_handler = self.settings['py_handler'] + + def error_out(self, code, log_message, info=None): + self.set_status(code) + self.write(simplejson.dumps( + {'message': log_message, 'info': info or {}})) + + # We want to duplicate error message in console for + # loggers are misconfigured or causing the failure + # themselves + print(info) + logger.error('message: {}, info: {}'.format(log_message, info)) + self.finish() + + def options(self): + # add CORS headers if TabPy has a cors_origin specified + self._add_CORS_header() + self.write({}) + + def _add_CORS_header(self): + """ + Add CORS header if the TabPy has attribute _cors_origin + and _cors_origin is not an empty string. + """ + origin = self.tabpy.get_access_control_allow_origin() + if len(origin) > 0: + self.set_header("Access-Control-Allow-Origin", origin) + logger.debug("Access-Control-Allow-Origin:{}".format(origin)) + + headers = self.tabpy.get_access_control_allow_headers() + if len(headers) > 0: + self.set_header("Access-Control-Allow-Headers", headers) + logger.debug("Access-Control-Allow-Headers:{}".format(headers)) + + methods = self.tabpy.get_access_control_allow_methods() + if len(methods) > 0: + self.set_header("Access-Control-Allow-Methods", methods) + logger.debug("Access-Control-Allow-Methods:{}".format(methods)) + + def _sanitize_request_data(self, data, keys=KEYS_TO_SANITIZE): + """Remove keys so that we can log safely""" + for key in keys: + data.pop(key, None) + + +class MainHandler(BaseHandler): + + def get(self): + self._add_CORS_header() + self.render('/static/index.html') + + +class ManagementHandler(MainHandler): + def initialize(self): + super(ManagementHandler, self).initialize() + self.port = self.settings['port'] + + def _get_protocol(self): + return 'http://' + + @gen.coroutine + def _add_or_update_endpoint(self, action, name, version, request_data): + ''' + Add or update an endpoint + ''' + logging.debug("Adding/updating model {}...".format(name)) + _name_checker = _compile(r'^[a-zA-Z0-9-_\ ]+$') + if not isinstance(name, (str, unicode)): + raise TypeError("Endpoint name must be a string or unicode") + + if not _name_checker.match(name): + raise gen.Return('endpoint name can only contain: a-z, A-Z, 0-9,' + ' underscore, hyphens and spaces.') + + if self.settings.get('add_or_updating_endpoint'): + raise RuntimeError("Another endpoint update is already in progress" + ", please wait a while and try again") + + request_uuid = random_uuid() + self.settings['add_or_updating_endpoint'] = request_uuid + try: + description = (request_data['description'] if 'description' in + request_data else None) + if 'docstring' in request_data: + if sys.version_info > (3, 0): + docstring = str(bytes(request_data['docstring'], + "utf-8").decode('unicode_escape')) + else: + docstring = request_data['docstring'].decode( + 'string_escape') + else: + docstring = None + endpoint_type = (request_data['type'] if 'type' in request_data + else None) + methods = (request_data['methods'] if 'methods' in request_data + else []) + dependencies = (request_data['dependencies'] if 'dependencies' in + request_data else None) + target = (request_data['target'] if 'target' in request_data + else None) + schema = (request_data['schema'] if 'schema' in request_data + else None) + + src_path = (request_data['src_path'] if 'src_path' in request_data + else None) + target_path = get_query_object_path( + self.settings['state_file_path'], name, version) + _path_checker = _compile(r'^[\\a-zA-Z0-9-_\ /]+$') + # copy from staging + if src_path: + if not isinstance(request_data['src_path'], (str, unicode)): + raise gen.Return("src_path must be a string.") + if not _path_checker.match(src_path): + raise gen.Return('Endpoint name can only contain: a-z, A-' + 'Z, 0-9,underscore, hyphens and spaces.') + + yield self._copy_po_future(src_path, target_path) + elif endpoint_type != 'alias': + raise gen.Return("src_path is required to add/update an " + "endpoint.") + + # alias special logic: + if endpoint_type == 'alias': + if not target: + raise gen.Return('Target is required for alias endpoint.') + dependencies = [target] + + # update local config + try: + if action == 'add': + self.tabpy.add_endpoint( + name=name, + description=description, + docstring=docstring, + endpoint_type=endpoint_type, + methods=methods, + dependencies=dependencies, + target=target, + schema=schema) + else: + self.tabpy.update_endpoint( + name=name, + description=description, + docstring=docstring, + endpoint_type=endpoint_type, + methods=methods, + dependencies=dependencies, + target=target, + schema=schema, + version=version) + + except Exception as e: + raise gen.Return("Error when changing TabPy state: %s" % e) + + on_state_change(self.settings) + + finally: + self.settings['add_or_updating_endpoint'] = None + + @gen.coroutine + def _copy_po_future(self, src_path, target_path): + future = STAGING_THREAD.submit(copy_from_local, src_path, + target_path, is_dir=True) + ret = yield future + raise gen.Return(ret) + + +class ServiceInfoHandler(ManagementHandler): + + def get(self): + self._add_CORS_header() + info = {} + info['state_path'] = self.settings['state_file_path'] + info['name'] = self.tabpy.name + info['description'] = self.tabpy.get_description() + info['server_version'] = self.settings['server_version'] + info['creation_time'] = self.tabpy.creation_time + self.write(simplejson.dumps(info)) + + +class StatusHandler(BaseHandler): + + def get(self): + self._add_CORS_header() + + logger.debug("Obtaining service status") + status_dict = {} + for k, v in self.py_handler.ps.query_objects.items(): + status_dict[k] = { + 'version': v['version'], + 'type': v['type'], + 'status': v['status'], + 'last_error': v['last_error']} + + logger.debug("Found models: {}".format(status_dict)) + self.write(simplejson.dumps(status_dict)) + self.finish() + return + + +class UploadDestinationHandler(ManagementHandler): + + def get(self): + path = self.settings['state_file_path'] + path = os.path.join(path, _QUERY_OBJECT_STAGING_FOLDER) + self.write({"path": path}) + + +class EndpointsHandler(ManagementHandler): + + def get(self): + self._add_CORS_header() + self.write(simplejson.dumps(self.tabpy.get_endpoints())) + + @tornado.web.asynchronous + @gen.coroutine + def post(self): + try: + if not self.request.body: + self.error_out(400, "Input body cannot be empty") + self.finish() + return + + try: + request_data = simplejson.loads( + self.request.body.decode('utf-8')) + except Exception: + self.error_out(400, "Failed to decode input body") + self.finish() + return + + if 'name' not in request_data: + self.error_out(400, + "name is required to add an endpoint.") + self.finish() + return + + name = request_data['name'] + + # check if endpoint already exist + if name in self.tabpy.get_endpoints(): + self.error_out(400, "endpoint %s already exists." % name) + self.finish() + return + + logger.debug("Adding endpoint '{}'".format(name)) + err_msg = yield self._add_or_update_endpoint('add', name, 1, + request_data) + if err_msg: + self.error_out(400, err_msg) + else: + logger.debug("Endopoint {} successfully added".format(name)) + self.set_status(201) + self.write(self.tabpy.get_endpoints(name)) + self.finish() + return + + except Exception as e: + err_msg = format_exception(e, '/add_endpoint') + self.error_out(500, "error adding endpoint", err_msg) + self.finish() + return + + +class EndpointHandler(ManagementHandler): + + def get(self, endpoint_name): + self._add_CORS_header() + if not endpoint_name: + self.write(simplejson.dumps(self.tabpy.get_endpoints())) + else: + if endpoint_name in self.tabpy.get_endpoints(): + self.write(simplejson.dumps( + self.tabpy.get_endpoints()[endpoint_name])) + else: + self.error_out(404, 'Unknown endpoint', + info='Endpoint %s is not found' % endpoint_name) + + @tornado.web.asynchronous + @gen.coroutine + def put(self, name): + try: + if not self.request.body: + self.error_out(400, "Input body cannot be empty") + self.finish() + return + try: + request_data = simplejson.loads( + self.request.body.decode('utf-8')) + except Exception: + self.error_out(400, "Failed to decode input body") + self.finish() + return + + # check if endpoint exists + endpoints = self.tabpy.get_endpoints(name) + if len(endpoints) == 0: + self.error_out(404, + "endpoint %s does not exist." % name) + self.finish() + return + + new_version = int(endpoints[name]['version']) + 1 + logger.info('Endpoint info: %s' % request_data) + err_msg = yield self._add_or_update_endpoint( + 'update', name, new_version, request_data) + if err_msg: + self.error_out(400, err_msg) + self.finish() + else: + self.write(self.tabpy.get_endpoints(name)) + self.finish() + + except Exception as e: + err_msg = format_exception(e, 'update_endpoint') + self.error_out(500, err_msg) + self.finish() + + @tornado.web.asynchronous + @gen.coroutine + def delete(self, name): + try: + endpoints = self.tabpy.get_endpoints(name) + if len(endpoints) == 0: + self.error_out(404, + "endpoint %s does not exist." % name) + self.finish() + return + + # update state + try: + endpoint_info = self.tabpy.delete_endpoint(name) + except Exception as e: + self.error_out(400, + "Error when removing endpoint: %s" % e.message) + self.finish() + return + + # delete files + if endpoint_info['type'] != 'alias': + delete_path = get_query_object_path( + self.settings['state_file_path'], name, None) + try: + yield self._delete_po_future(delete_path) + except Exception as e: + self.error_out(400, + "Error while deleting: %s" % e) + self.finish() + return + + self.set_status(204) + self.finish() + + except Exception as e: + err_msg = format_exception(e, 'delete endpoint') + self.error_out(500, err_msg) + self.finish() + + on_state_change(self.settings) + + @gen.coroutine + def _delete_po_future(self, delete_path): + future = STAGING_THREAD.submit(shutil.rmtree, delete_path) + ret = yield future + raise gen.Return(ret) + + +class EvaluationPlaneHandler(BaseHandler): + ''' + EvaluationPlaneHandler is responsible for running arbitrary python scripts. + ''' + + def initialize(self, executor): + super(EvaluationPlaneHandler, self).initialize() + self.executor = executor + + @tornado.web.asynchronous + @gen.coroutine + def post(self): + self._add_CORS_header() + try: + body = simplejson.loads(self.request.body.decode('utf-8')) + if 'script' not in body: + self.error_out(400, 'Script is empty.') + return + + # Transforming user script into a proper function. + user_code = body['script'] + arguments = None + arguments_str = '' + if 'data' in body: + arguments = body['data'] + + if arguments is not None: + if not isinstance(arguments, dict): + self.error_out(400, 'Script parameters need to be ' + 'provided as a dictionary.') + return + else: + arguments_expected = [] + for i in range(1, len(arguments.keys()) + 1): + arguments_expected.append('_arg' + str(i)) + if sorted(arguments_expected) == sorted(arguments.keys()): + arguments_str = ', ' + ', '.join(arguments.keys()) + else: + self.error_out(400, 'Variables names should follow ' + 'the format _arg1, _arg2, _argN') + return + + function_to_evaluate = ('def _user_script(tabpy' + + arguments_str + '):\n') + for u in user_code.splitlines(): + function_to_evaluate += ' ' + u + '\n' + + logger.info( + "function to evaluate=%s" % function_to_evaluate) + + result = yield self.call_subprocess(function_to_evaluate, + arguments) + if result is None: + self.error_out(400, 'Error running script. No return value') + else: + self.write(simplejson.dumps(result)) + self.finish() + + except Exception as e: + err_msg = "%s : " % e.__class__.__name__ + err_msg += "%s" % str(e) + if err_msg != "KeyError : 'response'": + err_msg = format_exception(e, 'POST /evaluate') + self.error_out(500, 'Error processing script', info=err_msg) + else: + self.error_out( + 404, 'Error processing script', + info=("The endpoint you're trying to query did not respond" + ". Please make sure the endpoint exists and the " + "correct set of arguments are provided.")) + + @gen.coroutine + def call_subprocess(self, function_to_evaluate, arguments): + # Exec does not run the function, so it does not block. + if sys.version_info > (3, 0): + exec(function_to_evaluate, globals()) + else: + exec(function_to_evaluate) + + +class RestrictedTabPy: + def __init__(self, port): + self.port = port + + def query(self, name, *args, **kwargs): + url = 'http://localhost:%d/query/%s' % (self.port, name) + internal_data = {'data': args or kwargs} + data = simplejson.dumps(internal_data) + headers = {'content-type': 'application/json'} + response = requests.post(url=url, data=data, headers=headers, + timeout=30) + + return response.json() + + +class QueryPlaneHandler(BaseHandler): + + def _query(self, po_name, data, uid, qry): + """ + Parameters + ---------- + po_name : str + The name of the query object to query + + data : dict + The deserialized request body + + uid: str + A unique identifier for the request + + qry: str + The incoming query object. This object maintains + raw incoming request, which is different from the sanitied data + + Returns + ------- + out : (result type, dict, int) + A triple containing a result type, the result message + as a dictionary, and the time in seconds that it took to complete + the request. + """ + start_time = time.time() + response = self.py_handler.ps.query(po_name, data, uid) + gls_time = time.time() - start_time + + if isinstance(response, QuerySuccessful): + response_json = response.to_json() + self.set_header("Etag", '"%s"' % md5(response_json.encode( + 'utf-8')).hexdigest()) + return (QuerySuccessful, response.for_json(), gls_time) + else: + logger.error("Failed query, response: {}".format(response)) + return (type(response), response.for_json(), gls_time) + + # handle HTTP Options requests to support CORS + # don't check API key (client does not send or receive data for OPTIONS, + # it just allows the client to subsequently make a POST request) + def options(self, pred_name): + # add CORS headers if TabPy has a cors_origin specified + self._add_CORS_header() + self.write({}) + + def _handle_result(self, po_name, data, qry, uid): + + (response_type, response, gls_time) = \ + self._query(po_name, data, uid, qry) + + if response_type == QuerySuccessful: + result_dict = { + 'response': response['response'], + 'version': response['version'], + 'model': po_name, + 'uuid': uid + } + self.write(result_dict) + self.finish() + return (gls_time, response['response']) + else: + if response_type == UnknownURI: + self.error_out(404, 'UnknownURI', + info="No query object has been registered" + " with the name '%s'" % po_name) + elif response_type == QueryError: + self.error_out(400, 'QueryError', info=response) + else: + self.error_out(500, 'Error querying GLS', info=response) + + return (None, None) + + def _process_query(self, endpoint_name, start): + try: + self._add_CORS_header() + + if not self.request.body: + self.request.body = {} + + # extract request data explicitly for caching purpose + request_json = self.request.body.decode('utf-8') + + # Sanitize input data + data = _sanitize_request_data(simplejson.loads(request_json)) + except Exception as e: + err_msg = format_exception(e, "Invalid Input Data") + self.error_out(400, err_msg) + return + + try: + (po_name, all_endpoint_names) = self._get_actual_model( + endpoint_name) + + # po_name is None if self.py_handler.ps.query_objects.get( + # endpoint_name) is None + if not po_name: + self.error_out(404, 'UnknownURI', + info="Endpoint '%s' does not exist" + % endpoint_name) + return + + po_obj = self.py_handler.ps.query_objects.get(po_name) + + if not po_obj: + self.error_out(404, 'UnknownURI', + info="Endpoint '%s' does not exist" % po_name) + return + + if po_name != endpoint_name: + logger.info( + "Querying actual model: po_name={}".format(po_name)) + + uid = _get_uuid() + + # record query w/ request ID in query log + qry = Query(po_name, request_json) + gls_time = 0 + # send a query to PythonService and return + (gls_time, result) = self._handle_result(po_name, data, qry, uid) + + # if error occurred, GLS time is None. + if not gls_time: + return + + except Exception as e: + err_msg = format_exception(e, 'process query') + self.error_out(500, 'Error processing query', info=err_msg) + return + + def _get_actual_model(self, endpoint_name): + # Find the actual query to run from given endpoint + all_endpoint_names = [] + + while True: + endpoint_info = self.py_handler.ps.query_objects.get(endpoint_name) + if not endpoint_info: + return [None, None] + + all_endpoint_names.append(endpoint_name) + + endpoint_type = endpoint_info.get('type', 'model') + + if endpoint_type == 'alias': + endpoint_name = endpoint_info['endpoint_obj'] + elif endpoint_type == 'model': + break + else: + self.error_out(500, 'Unknown endpoint type', + info="Endpoint type '%s' does not exist" + % endpoint_type) + return + + return (endpoint_name, all_endpoint_names) + + @tornado.web.asynchronous + def get(self, endpoint_name): + start = time.time() + if sys.version_info > (3, 0): + endpoint_name = urllib.parse.unquote(endpoint_name) + else: + endpoint_name = urllib.unquote(endpoint_name) + logger.debug("GET /query/{}".format(endpoint_name)) + self._process_query(endpoint_name, start) + + @tornado.web.asynchronous + def post(self, endpoint_name): + start = time.time() + if sys.version_info > (3, 0): + endpoint_name = urllib.parse.unquote(endpoint_name) + else: + endpoint_name = urllib.unquote(endpoint_name) + logger.debug("POST /query/{}".format(endpoint_name)) + self._process_query(endpoint_name, start) + + +def get_config(config_file): + """Provide consistent mechanism for pulling in configuration. + + Attempt to retain backward compatibility for existing implementations by + grabbing port setting from CLI first. + + Take settings in the following order: + + 1. CLI arguments, if present - port only - may be able to deprecate + 2. common.config file, and + 3. OS environment variables (for ease of setting defaults if not present) + 4. current defaults if a setting is not present in any location + + Additionally provide similar configuration capabilities in between + common.config and environment variables. + For consistency use the same variable name in the config file as in the os + environment. + For naming standards use all capitals and start with 'TABPY_' + """ + parser = configparser.ConfigParser() + + if os.path.isfile(config_file): + with open(config_file) as f: + parser.read_string(f.read()) + else: + logger.warning("Unable to find config file at '{}', " + "using default settings.".format(config_file)) + + settings = {} + for section in parser.sections(): + if section == "TabPy": + for key, val in parser.items(section): + settings[key] = val + break + + def set_parameter(settings_key, + config_key, + default_val=None, + check_env_var=False): + if config_key is not None and parser.has_option('TabPy', config_key): + settings[settings_key] = parser.get('TabPy', config_key) + elif check_env_var: + settings[settings_key] = os.getenv(config_key, default_val) + elif default_val is not None: + settings[settings_key] = default_val + + if cli_args is not None and cli_args.port is not None: + settings['port'] = cli_args.port + else: + set_parameter( + 'port', 'TABPY_PORT', default_val=9004, check_env_var=True) + try: + settings['port'] = int(settings['port']) + except ValueError: + logger.warning('Error during config validation, invalid port: {}. ' + 'Using default port 9004'.format(settings['port'])) + settings['port'] = 9004 + + set_parameter('server_version', None, default_val=__version__) + set_parameter( + 'bind_ip', 'TABPY_BIND_IP', default_val='0.0.0.0', check_env_var=True) + + set_parameter('upload_dir', 'TABPY_QUERY_OBJECT_PATH', + default_val='/tmp/query_objects', check_env_var=True) + if not os.path.exists(settings['upload_dir']): + os.makedirs(settings['upload_dir']) + + set_parameter('state_file_path', 'TABPY_STATE_PATH', + default_val='./', check_env_var=True) + settings['state_file_path'] = os.path.realpath( + os.path.normpath( + os.path.expanduser(settings['state_file_path']))) + + # set and validate transfer protocol + set_parameter('transfer_protocol', 'TABPY_TRANSFER_PROTOCOL', + default_val='http') + settings['transfer_protocol'] = settings['transfer_protocol'].lower() + + set_parameter('certificate_file', 'TABPY_CERTIFICATE_FILE') + set_parameter('key_file', 'TABPY_KEY_FILE') + validate_transfer_protocol_settings(settings) + + # if state.ini does not exist try and create it - remove last dependence + # on batch/shell script + state_file_path = settings['state_file_path'] + logger.info("Loading state from state file {}".format( + os.path.join(state_file_path, "state.ini"))) + tabpy_state = _get_state_from_file(state_file_path) + settings['tabpy'] = TabPyState(config=tabpy_state, settings=settings) + + settings['py_handler'] = PythonServiceHandler(PythonService()) + settings['compress_response'] = True if TORNADO_MAJOR >= 4 else "gzip" + settings['static_path'] = os.path.join(os.path.dirname(__file__), "static") + + # Set subdirectory from config if applicable + subdirectory = "" + if tabpy_state.has_option("Service Info", "Subdirectory"): + subdirectory = "/" + tabpy_state.get("Service Info", "Subdirectory") + + return settings, subdirectory + + +def validate_transfer_protocol_settings(settings): + if 'transfer_protocol' not in settings: + logger.error('Missing transfer protocol information.') + raise RuntimeError('Missing transfer protocol information.') + + protocol = settings['transfer_protocol'] + + if protocol == 'http': + return + + if protocol != 'https': + err = 'Unsupported transfer protocol: {}.'.format(protocol) + logger.fatal(err) + raise RuntimeError(err) + + validate_cert_key_state('The parameter(s) {} must be set.', + 'certificate_file' in settings, + 'key_file' in settings) + cert = settings['certificate_file'] + + validate_cert_key_state( + 'The parameter(s) {} must point to an existing file.', + os.path.isfile(cert), os.path.isfile(settings['key_file'])) + validate_cert(cert) + return + + +def validate_cert_key_state(msg, cert_valid, key_valid): + cert_param, key_param = 'TABPY_CERTIFICATE_FILE', 'TABPY_KEY_FILE' + cert_and_key_param = '{} and {}'.format(cert_param, key_param) + https_error = 'Error using HTTPS: ' + err = None + if not cert_valid and not key_valid: + err = https_error + msg.format(cert_and_key_param) + elif not cert_valid: + err = https_error + msg.format(cert_param) + elif not key_valid: + err = https_error + msg.format(key_param) + if err is not None: + logger.fatal(err) + raise RuntimeError(err) + + +def validate_cert(cert_file_path): + with open(cert_file_path, 'r') as f: + cert_buf = f.read() + + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_buf) + + date_format, encoding = '%Y%m%d%H%M%SZ', 'ascii' + not_before = datetime.strptime( + cert.get_notBefore().decode(encoding), date_format) + not_after = datetime.strptime( + cert.get_notAfter().decode(encoding), date_format) + now = datetime.now() + + https_error = 'Error using HTTPS: ' + if now < not_before: + raise RuntimeError(https_error + 'The certificate provided is not ' + 'valid until {}.'.format(not_before)) + if now > not_after: + raise RuntimeError(https_error + 'The certificate provided expired ' + 'on {}.'.format(not_after)) + + +def main(): + settings, subdirectory = get_config(config_file) + + logger.info('Initializing TabPy...') + tornado.ioloop.IOLoop.instance().run_sync(lambda: init_ps_server(settings)) + logger.info('Done initializing TabPy.') + + executor = concurrent.futures.ThreadPoolExecutor( + max_workers=multiprocessing.cpu_count()) + + # initialize Tornado application + application = tornado.web.Application([ + # skip MainHandler to use StaticFileHandler .* page requests and + # default to index.html + # (r"/", MainHandler), + (subdirectory + r'/query/([^/]+)', QueryPlaneHandler), + (subdirectory + r'/status', StatusHandler), + (subdirectory + r'/info', ServiceInfoHandler), + (subdirectory + r'/endpoints', EndpointsHandler), + (subdirectory + r'/endpoints/([^/]+)?', EndpointHandler), + (subdirectory + r'/evaluate', EvaluationPlaneHandler, + dict(executor=executor)), + (subdirectory + r'/configurations/endpoint_upload_destination', + UploadDestinationHandler), + (subdirectory + r'/(.*)', tornado.web.StaticFileHandler, + dict(path=settings['static_path'], default_filename="index.html")), + ], debug=False, **settings) + + settings = application.settings + + init_model_evaluator(settings) + + if settings['transfer_protocol'] == 'http': + application.listen(settings['port'], address=settings['bind_ip']) + elif settings['transfer_protocol'] == 'https': + application.listen(settings['port'], address=settings['bind_ip'], + ssl_options={ + 'certfile': settings['certificate_file'], + 'keyfile': settings['key_file'] + }) + else: + raise RuntimeError('Unsupported transfer protocol.') + + logger.info('Web service listening on {} port {}'.format( + settings['bind_ip'], str(settings['port']))) + tornado.ioloop.IOLoop.instance().start() +>>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting if __name__ == '__main__': From bea03bb376d7c3c6549acfd84fa2e7fcd54cb228 Mon Sep 17 00:00:00 2001 From: ogolovatyi Date: Thu, 28 Mar 2019 14:28:01 -0700 Subject: [PATCH 2/6] Fix merging conflicts --- tabpy-server/server_tests/test_config.py | 112 ------------------ .../tabpy_server/common/endpoint_file_mgr.py | 11 +- tabpy-server/tabpy_server/management/util.py | 20 ---- 3 files changed, 4 insertions(+), 139 deletions(-) diff --git a/tabpy-server/server_tests/test_config.py b/tabpy-server/server_tests/test_config.py index 0618043c..539b5eed 100644 --- a/tabpy-server/server_tests/test_config.py +++ b/tabpy-server/server_tests/test_config.py @@ -1,18 +1,14 @@ import os import unittest -<<<<<<< HEAD from argparse import Namespace from tempfile import NamedTemporaryFile from tabpy_server.app.util import validate_cert from tabpy_server.app.app import TabPyApp -======= ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting from unittest.mock import patch, call - def assert_raises_runtime_error(message, fn, args={}): try: fn(*args) @@ -21,22 +17,6 @@ def assert_raises_runtime_error(message, fn, args={}): assert err.args[0] == message -<<<<<<< HEAD -class TestConfigEnvironmentCalls(unittest.TestCase): - @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace(config=None)) - @patch('tabpy_server.app.app.TabPyState') - @patch('tabpy_server.app.app._get_state_from_file') - @patch('tabpy_server.app.app.PythonServiceHandler') - @patch('tabpy_server.app.app.os.path.exists', return_value=True) - @patch('tabpy_server.app.app.os.path.isfile', return_value=False) - @patch('tabpy_server.app.app.os') - def test_no_config_file(self, mock_os, mock_file_exists, - mock_path_exists, mock_psws, - mock_management_util, mock_tabpy_state, - mock_parse_arguments): - TabPyApp(None) -======= def append_logger_settings_to_config_file(config_file): config_file.write("[loggers]\n" "keys=root\n" @@ -73,7 +53,6 @@ def test_no_config_file(self, mock_os, mock_file_exists, mock_path_exists, mock_psws, mock_shutil, mock_management_util, mock_tabpy_state): get_config(None) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting getenv_calls = [call('TABPY_PORT', 9004), call('TABPY_QUERY_OBJECT_PATH', '/tmp/query_objects'), @@ -86,22 +65,6 @@ def test_no_config_file(self, mock_os, mock_file_exists, mock_path_exists, self.assertTrue(len(mock_management_util.mock_calls) > 0) mock_os.makedirs.assert_not_called() -<<<<<<< HEAD - @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace(config=None)) - @patch('tabpy_server.app.app.TabPyState') - @patch('tabpy_server.app.app._get_state_from_file') - @patch('tabpy_server.app.app.PythonServiceHandler') - @patch('tabpy_server.app.app.os.path.exists', return_value=False) - @patch('tabpy_server.app.app.os.path.isfile', return_value=False) - @patch('tabpy_server.app.app.os') - def test_no_state_ini_file_or_state_dir(self, mock_os, mock_file_exists, - mock_path_exists, mock_psws, - mock_management_util, - mock_tabpy_state, - mock_parse_arguments): - TabPyApp(None) -======= @patch('tabpy_server.tabpy.TabPyState') @patch('tabpy_server.tabpy._get_state_from_file') @patch('tabpy_server.tabpy.shutil') @@ -114,22 +77,10 @@ def test_no_state_ini_file_or_state_dir(self, mock_os, mock_file_exists, mock_shutil, mock_management_util, mock_tabpy_state): get_config(None) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting self.assertEqual(len(mock_os.makedirs.mock_calls), 1) class TestPartialConfigFile(unittest.TestCase): -<<<<<<< HEAD - @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments') - @patch('tabpy_server.app.app.TabPyState') - @patch('tabpy_server.app.app._get_state_from_file') - @patch('tabpy_server.app.app.PythonServiceHandler') - @patch('tabpy_server.app.app.os.path.exists', return_value=True) - @patch('tabpy_server.app.app.os') - def test_config_file_present(self, mock_os, mock_path_exists, - mock_psws, mock_management_util, -======= - @patch('tabpy_server.tabpy.parse_arguments') @patch('tabpy_server.tabpy.TabPyState') @patch('tabpy_server.tabpy._get_state_from_file') @@ -139,7 +90,6 @@ def test_config_file_present(self, mock_os, mock_path_exists, @patch('tabpy_server.tabpy.os') def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, mock_shutil, mock_management_util, ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting mock_tabpy_state, mock_parse_arguments): config_file = NamedTemporaryFile(delete=False) @@ -148,12 +98,8 @@ def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, "TABPY_STATE_PATH = bar\n".encode()) config_file.close() -<<<<<<< HEAD - mock_parse_arguments.return_value = Namespace(config=config_file.name) -======= mock_parse_arguments.return_value = Namespace( config=config_file.name, port=None) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting mock_os.getenv.side_effect = [1234] mock_os.path.realpath.return_value = 'bar' @@ -162,16 +108,6 @@ def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, getenv_calls = [call('TABPY_PORT', 9004)] mock_os.getenv.assert_has_calls(getenv_calls, any_order=True) -<<<<<<< HEAD - self.assertEqual(app.settings['port'], 1234) - self.assertEqual(app.settings['server_version'], open( - 'VERSION').read().strip()) - self.assertEqual(app.settings['upload_dir'], 'foo') - self.assertEqual(app.settings['state_file_path'], 'bar') - self.assertEqual(app.settings['transfer_protocol'], 'http') - self.assertTrue('certificate_file' not in app.settings) - self.assertTrue('key_file' not in app.settings) -======= self.assertEqual(settings['port'], 1234) self.assertEqual(settings['server_version'], open('VERSION').read().strip()) @@ -181,7 +117,6 @@ def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, self.assertEqual(settings['transfer_protocol'], 'http') self.assertTrue('certificate_file' not in settings) self.assertTrue('key_file' not in settings) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting os.remove(config_file.name) @@ -205,9 +140,6 @@ def __init__(self, *args, **kwargs): def setUp(self): os.chdir(self.tabpy_cwd) -<<<<<<< HEAD - self.fp = NamedTemporaryFile(mode='w+t', delete=False) -======= self.fp = NamedTemporaryFile(delete=False) self.config_name = self.fp.name append_logger_settings_to_config_file(self.fp) @@ -216,7 +148,6 @@ def setUp(self): return_value=Namespace(config=self.fp.name, port=None)) patcher.start() self.addCleanup(patcher.stop) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting def tearDown(self): os.chdir(self.cwd) @@ -238,12 +169,8 @@ def test_https_without_cert_and_key(self): assert_raises_runtime_error( 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' -<<<<<<< HEAD 'and TABPY_KEY_FILE must be set.', TabPyApp, {self.fp.name}) -======= - 'and TABPY_KEY_FILE must be set.', get_config, {self.config_name}) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting def test_https_without_cert(self): self.fp.write( @@ -254,11 +181,7 @@ def test_https_without_cert(self): assert_raises_runtime_error('Error using HTTPS: The parameter(s) ' 'TABPY_CERTIFICATE_FILE must be set.', -<<<<<<< HEAD TabPyApp, {self.fp.name}) -======= - get_config, [self.config_name]) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting def test_https_without_key(self): self.fp.write("[TabPy]\n" @@ -268,11 +191,7 @@ def test_https_without_key(self): assert_raises_runtime_error('Error using HTTPS: The parameter(s) ' 'TABPY_KEY_FILE must be set.', -<<<<<<< HEAD TabPyApp, {self.fp.name}) -======= - get_config, [self.config_name]) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path') def test_https_cert_and_key_file_not_found(self, mock_path): @@ -282,15 +201,6 @@ def test_https_cert_and_key_file_not_found(self, mock_path): "TABPY_KEY_FILE = bar") self.fp.close() -<<<<<<< HEAD - mock_path.isfile.side_effect =\ - lambda x: self.mock_isfile(x, {self.fp.name}) - - assert_raises_runtime_error( - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' - 'and TABPY_KEY_FILE must point to an existing file.', - TabPyApp, {self.fp.name}) -======= mock_path.isfile.side_effect = lambda x: self.mock_isfile( x, {self.fp.name}) @@ -298,7 +208,6 @@ def test_https_cert_and_key_file_not_found(self, mock_path): 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE and ' 'TABPY_KEY_FILE must point to an existing file.', get_config, {self.config_name}) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path') def test_https_cert_file_not_found(self, mock_path): @@ -312,14 +221,8 @@ def test_https_cert_file_not_found(self, mock_path): x, {self.fp.name, 'bar'}) assert_raises_runtime_error( -<<<<<<< HEAD - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' - 'must point to an existing file.', - TabPyApp, {self.fp.name}) -======= 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE must ' 'point to an existing file.', get_config, {self.config_name}) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path') def test_https_key_file_not_found(self, mock_path): @@ -333,13 +236,8 @@ def test_https_key_file_not_found(self, mock_path): x, {self.fp.name, 'foo'}) assert_raises_runtime_error( -<<<<<<< HEAD - 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE must ' - 'point to an existing file.', TabPyApp, {self.fp.name}) -======= 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE must point to ' 'an existing file.', get_config, {self.config_name}) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting @patch('tabpy_server.app.app.os.path.isfile', return_value=True) @patch('tabpy_server.app.util.validate_cert') @@ -366,24 +264,14 @@ def __init__(self, *args, **kwargs): def test_expired_cert(self): path = os.path.join(self.resources_path, 'expired.crt') -<<<<<<< HEAD - message = 'Error using HTTPS: The certificate provided expired '\ - 'on 2018-08-18 19:47:18.' -======= message = ('Error using HTTPS: The certificate provided expired ' 'on 2018-08-18 19:47:18.') ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting assert_raises_runtime_error(message, validate_cert, {path}) def test_future_cert(self): path = os.path.join(self.resources_path, 'future.crt') -<<<<<<< HEAD - message = 'Error using HTTPS: The certificate provided is not '\ - 'valid until 3001-01-01 00:00:00.' -======= message = ('Error using HTTPS: The certificate provided is not valid ' 'until 3001-01-01 00:00:00.') ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting assert_raises_runtime_error(message, validate_cert, {path}) def test_valid_cert(self): diff --git a/tabpy-server/tabpy_server/common/endpoint_file_mgr.py b/tabpy-server/tabpy_server/common/endpoint_file_mgr.py index b49bac26..dfc3f6f4 100644 --- a/tabpy-server/tabpy_server/common/endpoint_file_mgr.py +++ b/tabpy-server/tabpy_server/common/endpoint_file_mgr.py @@ -28,13 +28,10 @@ def _check_endpoint_name(name): log_and_raise("Endpoint name cannot be empty", ValueError) if not _name_checker.match(name): -<<<<<<< HEAD - log_and_raise('Endpoint name can only contain: a-z, A-Z, 0-9,' - ' underscore, hyphens and spaces.', ValueError) -======= - raise ValueError('Endpoint name can only contain: a-z, A-Z, 0-9,' - ' underscore, hyphens and spaces.') ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting + log_and_raise( + 'Endpoint name can only contain: a-z, A-Z, 0-9,' + ' underscore, hyphens and spaces.', + ValueError) def grab_files(directory): diff --git a/tabpy-server/tabpy_server/management/util.py b/tabpy-server/tabpy_server/management/util.py index f63cbefa..bdc175d5 100644 --- a/tabpy-server/tabpy_server/management/util.py +++ b/tabpy-server/tabpy_server/management/util.py @@ -4,13 +4,6 @@ from ConfigParser import ConfigParser as _ConfigParser except ImportError: from configparser import ConfigParser as _ConfigParser -<<<<<<< HEAD -try: - from StringIO import StringIO as _StringIO -except ImportError: - from io import StringIO as _StringIO -======= ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting from datetime import datetime, timedelta, tzinfo from tabpy_server.app.ConfigParameters import ConfigParameters from tabpy_server.app.util import log_and_raise @@ -23,14 +16,10 @@ def write_state_config(state, settings): if 'state_file_path' in settings: state_path = settings['state_file_path'] else: -<<<<<<< HEAD log_and_raise( '{} is not set'.format( ConfigParameters.TABPY_STATE_PATH), ValueError) -======= - raise ValueError('TABPY_STATE_PATH is not set') ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting logger.debug("State path is {}".format(state_path)) state_key = os.path.join(state_path, 'state.ini') @@ -54,14 +43,9 @@ def _get_state_from_file(state_path): config.read(tmp_state_file) if not config.has_section('Service Info'): -<<<<<<< HEAD log_and_raise( "Config error: Expected 'Service Info' section in %s" % (tmp_state_file,), ValueError) -======= - raise ValueError("Config error: Expected 'Service Info' " - "section in %s" % (tmp_state_file,)) ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting return config @@ -108,8 +92,4 @@ def _dt_to_utc_timestamp(t): elif not t.tzinfo: return mktime(t.timetuple()) else: -<<<<<<< HEAD log_and_raise('Only local time and UTC time is supported', ValueError) -======= - raise ValueError('Only local time and UTC time is supported') ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting From 699fae0f141355fb05a8ea29f46a388f7d661aa9 Mon Sep 17 00:00:00 2001 From: ogolovatyi Date: Thu, 28 Mar 2019 14:39:48 -0700 Subject: [PATCH 3/6] Fix merging conflicts --- tabpy-server/server_tests/test_config.py | 101 +-- tabpy-server/tabpy_server/tabpy.py | 980 ----------------------- 2 files changed, 36 insertions(+), 1045 deletions(-) diff --git a/tabpy-server/server_tests/test_config.py b/tabpy-server/server_tests/test_config.py index 539b5eed..503e848c 100644 --- a/tabpy-server/server_tests/test_config.py +++ b/tabpy-server/server_tests/test_config.py @@ -17,31 +17,9 @@ def assert_raises_runtime_error(message, fn, args={}): assert err.args[0] == message -def append_logger_settings_to_config_file(config_file): - config_file.write("[loggers]\n" - "keys=root\n" - "[handlers]\n" - "keys=rotatingFileHandler\n" - "[formatters]\n" - "keys=rootFormatter\n" - "[logger_root]\n" - "level=ERROR\n" - "handlers=rotatingFileHandler\n" - "qualname=root\n" - "propagete=0\n" - "[handler_rotatingFileHandler]\n" - "class=handlers.RotatingFileHandler\n" - "level=ERROR\n" - "formatter=rootFormatter\n" - "args=('tabpy_server_tests_log.log', 'w', 1000000, 5)\n" - "[formatter_rootFormatter]\n" - "format=%(asctime)s [%(levelname)s] (%(filename)s:" - "%(module)s:%(lineno)d): %(message)s\n" - "datefmt=%Y-%m-%d,%H:%M:%S\n".encode()) - - class TestConfigEnvironmentCalls(unittest.TestCase): - + @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', + return_value=Namespace(config=None)) @patch('tabpy_server.tabpy.TabPyState') @patch('tabpy_server.tabpy._get_state_from_file') @patch('tabpy_server.tabpy.shutil') @@ -49,10 +27,11 @@ class TestConfigEnvironmentCalls(unittest.TestCase): @patch('tabpy_server.tabpy.os.path.exists', return_value=True) @patch('tabpy_server.tabpy.os.path.isfile', return_value=False) @patch('tabpy_server.tabpy.os') - def test_no_config_file(self, mock_os, mock_file_exists, mock_path_exists, - mock_psws, mock_shutil, mock_management_util, - mock_tabpy_state): - get_config(None) + def test_no_config_file(self, mock_os, mock_file_exists, + mock_path_exists, mock_psws, + mock_management_util, mock_tabpy_state, + mock_parse_arguments): + TabPyApp(None) getenv_calls = [call('TABPY_PORT', 9004), call('TABPY_QUERY_OBJECT_PATH', '/tmp/query_objects'), @@ -65,31 +44,32 @@ def test_no_config_file(self, mock_os, mock_file_exists, mock_path_exists, self.assertTrue(len(mock_management_util.mock_calls) > 0) mock_os.makedirs.assert_not_called() - @patch('tabpy_server.tabpy.TabPyState') - @patch('tabpy_server.tabpy._get_state_from_file') - @patch('tabpy_server.tabpy.shutil') - @patch('tabpy_server.tabpy.PythonServiceHandler') - @patch('tabpy_server.tabpy.os.path.exists', return_value=False) - @patch('tabpy_server.tabpy.os.path.isfile', return_value=False) - @patch('tabpy_server.tabpy.os') + @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', + return_value=Namespace(config=None)) + @patch('tabpy_server.app.app.TabPyState') + @patch('tabpy_server.app.app._get_state_from_file') + @patch('tabpy_server.app.app.PythonServiceHandler') + @patch('tabpy_server.app.app.os.path.exists', return_value=False) + @patch('tabpy_server.app.app.os.path.isfile', return_value=False) + @patch('tabpy_server.app.app.os') def test_no_state_ini_file_or_state_dir(self, mock_os, mock_file_exists, mock_path_exists, mock_psws, - mock_shutil, mock_management_util, - mock_tabpy_state): - get_config(None) + mock_management_util, + mock_tabpy_state, + mock_parse_arguments): + TabPyApp(None) self.assertEqual(len(mock_os.makedirs.mock_calls), 1) class TestPartialConfigFile(unittest.TestCase): - @patch('tabpy_server.tabpy.parse_arguments') - @patch('tabpy_server.tabpy.TabPyState') - @patch('tabpy_server.tabpy._get_state_from_file') - @patch('tabpy_server.tabpy.shutil') - @patch('tabpy_server.tabpy.PythonServiceHandler') - @patch('tabpy_server.tabpy.os.path.exists', return_value=True) - @patch('tabpy_server.tabpy.os') - def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, - mock_shutil, mock_management_util, + @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments') + @patch('tabpy_server.app.app.TabPyState') + @patch('tabpy_server.app.app._get_state_from_file') + @patch('tabpy_server.app.app.PythonServiceHandler') + @patch('tabpy_server.app.app.os.path.exists', return_value=True) + @patch('tabpy_server.app.app.os') + def test_config_file_present(self, mock_os, mock_path_exists, + mock_psws, mock_management_util, mock_tabpy_state, mock_parse_arguments): config_file = NamedTemporaryFile(delete=False) @@ -98,8 +78,7 @@ def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, "TABPY_STATE_PATH = bar\n".encode()) config_file.close() - mock_parse_arguments.return_value = Namespace( - config=config_file.name, port=None) + mock_parse_arguments.return_value = Namespace(config=config_file.name) mock_os.getenv.side_effect = [1234] mock_os.path.realpath.return_value = 'bar' @@ -108,15 +87,14 @@ def test_config_file_present(self, mock_os, mock_path_exists, mock_psws, getenv_calls = [call('TABPY_PORT', 9004)] mock_os.getenv.assert_has_calls(getenv_calls, any_order=True) - self.assertEqual(settings['port'], 1234) - self.assertEqual(settings['server_version'], + self.assertEqual(app.settings['port'], 1234) + self.assertEqual(app.settings['server_version'], open('VERSION').read().strip()) - self.assertEquals(settings['bind_ip'], '0.0.0.0') - self.assertEquals(settings['upload_dir'], 'foo') - self.assertEquals(settings['state_file_path'], 'bar') - self.assertEqual(settings['transfer_protocol'], 'http') - self.assertTrue('certificate_file' not in settings) - self.assertTrue('key_file' not in settings) + self.assertEqual(app.settings['upload_dir'], 'foo') + self.assertEqual(app.settings['state_file_path'], 'bar') + self.assertEqual(app.settings['transfer_protocol'], 'http') + self.assertTrue('certificate_file' not in app.settings) + self.assertTrue('key_file' not in app.settings) os.remove(config_file.name) @@ -140,14 +118,7 @@ def __init__(self, *args, **kwargs): def setUp(self): os.chdir(self.tabpy_cwd) - self.fp = NamedTemporaryFile(delete=False) - self.config_name = self.fp.name - append_logger_settings_to_config_file(self.fp) - - patcher = patch('tabpy_server.tabpy.parse_arguments', - return_value=Namespace(config=self.fp.name, port=None)) - patcher.start() - self.addCleanup(patcher.stop) + self.fp = NamedTemporaryFile(mode='w+t', delete=False) def tearDown(self): os.chdir(self.cwd) diff --git a/tabpy-server/tabpy_server/tabpy.py b/tabpy-server/tabpy_server/tabpy.py index 90b3b5ae..f7936c5a 100644 --- a/tabpy-server/tabpy_server/tabpy.py +++ b/tabpy-server/tabpy_server/tabpy.py @@ -1,4 +1,3 @@ -<<<<<<< HEAD from tabpy_server import __version__ from tabpy_server.app.app import TabPyApp @@ -6,985 +5,6 @@ def main(): app = TabPyApp() app.run() -======= -from argparse import ArgumentParser -import concurrent.futures -import configparser -from datetime import datetime -from hashlib import md5 -import logging -import logging.config -import multiprocessing -from OpenSSL import crypto -import os -from re import compile as _compile -import requests -import shutil -import simplejson -import sys -from tabpy_server import __version__ -from tabpy_server.psws.python_service import PythonService -from tabpy_server.psws.python_service import PythonServiceHandler -from tabpy_server.common.util import format_exception -from tabpy_server.common.messages import ( - Query, QuerySuccessful, QueryError, UnknownURI) -from tabpy_server.psws.callbacks import ( - init_ps_server, init_model_evaluator, on_state_change) -from tabpy_server.management.util import _get_state_from_file -from tabpy_server.management.state import TabPyState, get_query_object_path -import time -import tornado -import tornado.options -import tornado.web -import tornado.ioloop -from tornado import gen -from tornado_json.constants import TORNADO_MAJOR -from uuid import uuid4 as random_uuid -import urllib -import uuid - - -STAGING_THREAD = concurrent.futures.ThreadPoolExecutor(max_workers=3) -_QUERY_OBJECT_STAGING_FOLDER = 'staging' - -if sys.version_info.major == 3: - unicode = str - - -def parse_arguments(): - ''' - Parse input arguments and return the parsed arguments. Expected arguments: - * --port : int - ''' - parser = ArgumentParser(description='Run Python27 Service.') - parser.add_argument('--port', type=int, - help='Listening port for this service.') - parser.add_argument('--config', help='Path to a config file.') - return parser.parse_args() - - -cli_args = parse_arguments() -config_file = (cli_args.config if cli_args.config is not None else - os.path.join(os.path.dirname(__file__), 'common', - 'default.conf')) -loggingConfigured = False - - -if os.path.isfile(config_file): - try: - logging.config.fileConfig(config_file, disable_existing_loggers=False) - loggingConfigured = True - except Exception: - pass - - -if not loggingConfigured: - logging.basicConfig(level=logging.DEBUG) - - -logger = logging.getLogger(__name__) - - -def copy_from_local(localpath, remotepath, is_dir=False): - if is_dir: - if not os.path.exists(remotepath): - # remote folder does not exist - shutil.copytree(localpath, remotepath) - else: - # remote folder exists, copy each file - src_files = os.listdir(localpath) - for file_name in src_files: - full_file_name = os.path.join(localpath, file_name) - if os.path.isdir(full_file_name): - # copy folder recursively - full_remote_path = os.path.join(remotepath, file_name) - shutil.copytree(full_file_name, full_remote_path) - else: - # copy each file - shutil.copy(full_file_name, remotepath) - else: - shutil.copy(localpath, remotepath) - - -def _sanitize_request_data(data): - if not isinstance(data, dict): - raise RuntimeError("Expect input data to be a dictionary") - - if "method" in data: - return {"data": data.get("data"), "method": data.get("method")} - elif "data" in data: - return data.get("data") - else: - raise RuntimeError("Expect input data is a dictionary with at least a " - "key called 'data'") - - -def _get_uuid(): - """Generate a unique identifier string""" - return str(uuid.uuid4()) - - -class BaseHandler(tornado.web.RequestHandler): - KEYS_TO_SANITIZE = ("api key", "api_key", "admin key", "admin_key") - - def initialize(self): - self.tabpy = self.settings['tabpy'] - # set content type to application/json - self.set_header("Content-Type", "application/json") - self.port = self.settings['port'] - self.py_handler = self.settings['py_handler'] - - def error_out(self, code, log_message, info=None): - self.set_status(code) - self.write(simplejson.dumps( - {'message': log_message, 'info': info or {}})) - - # We want to duplicate error message in console for - # loggers are misconfigured or causing the failure - # themselves - print(info) - logger.error('message: {}, info: {}'.format(log_message, info)) - self.finish() - - def options(self): - # add CORS headers if TabPy has a cors_origin specified - self._add_CORS_header() - self.write({}) - - def _add_CORS_header(self): - """ - Add CORS header if the TabPy has attribute _cors_origin - and _cors_origin is not an empty string. - """ - origin = self.tabpy.get_access_control_allow_origin() - if len(origin) > 0: - self.set_header("Access-Control-Allow-Origin", origin) - logger.debug("Access-Control-Allow-Origin:{}".format(origin)) - - headers = self.tabpy.get_access_control_allow_headers() - if len(headers) > 0: - self.set_header("Access-Control-Allow-Headers", headers) - logger.debug("Access-Control-Allow-Headers:{}".format(headers)) - - methods = self.tabpy.get_access_control_allow_methods() - if len(methods) > 0: - self.set_header("Access-Control-Allow-Methods", methods) - logger.debug("Access-Control-Allow-Methods:{}".format(methods)) - - def _sanitize_request_data(self, data, keys=KEYS_TO_SANITIZE): - """Remove keys so that we can log safely""" - for key in keys: - data.pop(key, None) - - -class MainHandler(BaseHandler): - - def get(self): - self._add_CORS_header() - self.render('/static/index.html') - - -class ManagementHandler(MainHandler): - def initialize(self): - super(ManagementHandler, self).initialize() - self.port = self.settings['port'] - - def _get_protocol(self): - return 'http://' - - @gen.coroutine - def _add_or_update_endpoint(self, action, name, version, request_data): - ''' - Add or update an endpoint - ''' - logging.debug("Adding/updating model {}...".format(name)) - _name_checker = _compile(r'^[a-zA-Z0-9-_\ ]+$') - if not isinstance(name, (str, unicode)): - raise TypeError("Endpoint name must be a string or unicode") - - if not _name_checker.match(name): - raise gen.Return('endpoint name can only contain: a-z, A-Z, 0-9,' - ' underscore, hyphens and spaces.') - - if self.settings.get('add_or_updating_endpoint'): - raise RuntimeError("Another endpoint update is already in progress" - ", please wait a while and try again") - - request_uuid = random_uuid() - self.settings['add_or_updating_endpoint'] = request_uuid - try: - description = (request_data['description'] if 'description' in - request_data else None) - if 'docstring' in request_data: - if sys.version_info > (3, 0): - docstring = str(bytes(request_data['docstring'], - "utf-8").decode('unicode_escape')) - else: - docstring = request_data['docstring'].decode( - 'string_escape') - else: - docstring = None - endpoint_type = (request_data['type'] if 'type' in request_data - else None) - methods = (request_data['methods'] if 'methods' in request_data - else []) - dependencies = (request_data['dependencies'] if 'dependencies' in - request_data else None) - target = (request_data['target'] if 'target' in request_data - else None) - schema = (request_data['schema'] if 'schema' in request_data - else None) - - src_path = (request_data['src_path'] if 'src_path' in request_data - else None) - target_path = get_query_object_path( - self.settings['state_file_path'], name, version) - _path_checker = _compile(r'^[\\a-zA-Z0-9-_\ /]+$') - # copy from staging - if src_path: - if not isinstance(request_data['src_path'], (str, unicode)): - raise gen.Return("src_path must be a string.") - if not _path_checker.match(src_path): - raise gen.Return('Endpoint name can only contain: a-z, A-' - 'Z, 0-9,underscore, hyphens and spaces.') - - yield self._copy_po_future(src_path, target_path) - elif endpoint_type != 'alias': - raise gen.Return("src_path is required to add/update an " - "endpoint.") - - # alias special logic: - if endpoint_type == 'alias': - if not target: - raise gen.Return('Target is required for alias endpoint.') - dependencies = [target] - - # update local config - try: - if action == 'add': - self.tabpy.add_endpoint( - name=name, - description=description, - docstring=docstring, - endpoint_type=endpoint_type, - methods=methods, - dependencies=dependencies, - target=target, - schema=schema) - else: - self.tabpy.update_endpoint( - name=name, - description=description, - docstring=docstring, - endpoint_type=endpoint_type, - methods=methods, - dependencies=dependencies, - target=target, - schema=schema, - version=version) - - except Exception as e: - raise gen.Return("Error when changing TabPy state: %s" % e) - - on_state_change(self.settings) - - finally: - self.settings['add_or_updating_endpoint'] = None - - @gen.coroutine - def _copy_po_future(self, src_path, target_path): - future = STAGING_THREAD.submit(copy_from_local, src_path, - target_path, is_dir=True) - ret = yield future - raise gen.Return(ret) - - -class ServiceInfoHandler(ManagementHandler): - - def get(self): - self._add_CORS_header() - info = {} - info['state_path'] = self.settings['state_file_path'] - info['name'] = self.tabpy.name - info['description'] = self.tabpy.get_description() - info['server_version'] = self.settings['server_version'] - info['creation_time'] = self.tabpy.creation_time - self.write(simplejson.dumps(info)) - - -class StatusHandler(BaseHandler): - - def get(self): - self._add_CORS_header() - - logger.debug("Obtaining service status") - status_dict = {} - for k, v in self.py_handler.ps.query_objects.items(): - status_dict[k] = { - 'version': v['version'], - 'type': v['type'], - 'status': v['status'], - 'last_error': v['last_error']} - - logger.debug("Found models: {}".format(status_dict)) - self.write(simplejson.dumps(status_dict)) - self.finish() - return - - -class UploadDestinationHandler(ManagementHandler): - - def get(self): - path = self.settings['state_file_path'] - path = os.path.join(path, _QUERY_OBJECT_STAGING_FOLDER) - self.write({"path": path}) - - -class EndpointsHandler(ManagementHandler): - - def get(self): - self._add_CORS_header() - self.write(simplejson.dumps(self.tabpy.get_endpoints())) - - @tornado.web.asynchronous - @gen.coroutine - def post(self): - try: - if not self.request.body: - self.error_out(400, "Input body cannot be empty") - self.finish() - return - - try: - request_data = simplejson.loads( - self.request.body.decode('utf-8')) - except Exception: - self.error_out(400, "Failed to decode input body") - self.finish() - return - - if 'name' not in request_data: - self.error_out(400, - "name is required to add an endpoint.") - self.finish() - return - - name = request_data['name'] - - # check if endpoint already exist - if name in self.tabpy.get_endpoints(): - self.error_out(400, "endpoint %s already exists." % name) - self.finish() - return - - logger.debug("Adding endpoint '{}'".format(name)) - err_msg = yield self._add_or_update_endpoint('add', name, 1, - request_data) - if err_msg: - self.error_out(400, err_msg) - else: - logger.debug("Endopoint {} successfully added".format(name)) - self.set_status(201) - self.write(self.tabpy.get_endpoints(name)) - self.finish() - return - - except Exception as e: - err_msg = format_exception(e, '/add_endpoint') - self.error_out(500, "error adding endpoint", err_msg) - self.finish() - return - - -class EndpointHandler(ManagementHandler): - - def get(self, endpoint_name): - self._add_CORS_header() - if not endpoint_name: - self.write(simplejson.dumps(self.tabpy.get_endpoints())) - else: - if endpoint_name in self.tabpy.get_endpoints(): - self.write(simplejson.dumps( - self.tabpy.get_endpoints()[endpoint_name])) - else: - self.error_out(404, 'Unknown endpoint', - info='Endpoint %s is not found' % endpoint_name) - - @tornado.web.asynchronous - @gen.coroutine - def put(self, name): - try: - if not self.request.body: - self.error_out(400, "Input body cannot be empty") - self.finish() - return - try: - request_data = simplejson.loads( - self.request.body.decode('utf-8')) - except Exception: - self.error_out(400, "Failed to decode input body") - self.finish() - return - - # check if endpoint exists - endpoints = self.tabpy.get_endpoints(name) - if len(endpoints) == 0: - self.error_out(404, - "endpoint %s does not exist." % name) - self.finish() - return - - new_version = int(endpoints[name]['version']) + 1 - logger.info('Endpoint info: %s' % request_data) - err_msg = yield self._add_or_update_endpoint( - 'update', name, new_version, request_data) - if err_msg: - self.error_out(400, err_msg) - self.finish() - else: - self.write(self.tabpy.get_endpoints(name)) - self.finish() - - except Exception as e: - err_msg = format_exception(e, 'update_endpoint') - self.error_out(500, err_msg) - self.finish() - - @tornado.web.asynchronous - @gen.coroutine - def delete(self, name): - try: - endpoints = self.tabpy.get_endpoints(name) - if len(endpoints) == 0: - self.error_out(404, - "endpoint %s does not exist." % name) - self.finish() - return - - # update state - try: - endpoint_info = self.tabpy.delete_endpoint(name) - except Exception as e: - self.error_out(400, - "Error when removing endpoint: %s" % e.message) - self.finish() - return - - # delete files - if endpoint_info['type'] != 'alias': - delete_path = get_query_object_path( - self.settings['state_file_path'], name, None) - try: - yield self._delete_po_future(delete_path) - except Exception as e: - self.error_out(400, - "Error while deleting: %s" % e) - self.finish() - return - - self.set_status(204) - self.finish() - - except Exception as e: - err_msg = format_exception(e, 'delete endpoint') - self.error_out(500, err_msg) - self.finish() - - on_state_change(self.settings) - - @gen.coroutine - def _delete_po_future(self, delete_path): - future = STAGING_THREAD.submit(shutil.rmtree, delete_path) - ret = yield future - raise gen.Return(ret) - - -class EvaluationPlaneHandler(BaseHandler): - ''' - EvaluationPlaneHandler is responsible for running arbitrary python scripts. - ''' - - def initialize(self, executor): - super(EvaluationPlaneHandler, self).initialize() - self.executor = executor - - @tornado.web.asynchronous - @gen.coroutine - def post(self): - self._add_CORS_header() - try: - body = simplejson.loads(self.request.body.decode('utf-8')) - if 'script' not in body: - self.error_out(400, 'Script is empty.') - return - - # Transforming user script into a proper function. - user_code = body['script'] - arguments = None - arguments_str = '' - if 'data' in body: - arguments = body['data'] - - if arguments is not None: - if not isinstance(arguments, dict): - self.error_out(400, 'Script parameters need to be ' - 'provided as a dictionary.') - return - else: - arguments_expected = [] - for i in range(1, len(arguments.keys()) + 1): - arguments_expected.append('_arg' + str(i)) - if sorted(arguments_expected) == sorted(arguments.keys()): - arguments_str = ', ' + ', '.join(arguments.keys()) - else: - self.error_out(400, 'Variables names should follow ' - 'the format _arg1, _arg2, _argN') - return - - function_to_evaluate = ('def _user_script(tabpy' - + arguments_str + '):\n') - for u in user_code.splitlines(): - function_to_evaluate += ' ' + u + '\n' - - logger.info( - "function to evaluate=%s" % function_to_evaluate) - - result = yield self.call_subprocess(function_to_evaluate, - arguments) - if result is None: - self.error_out(400, 'Error running script. No return value') - else: - self.write(simplejson.dumps(result)) - self.finish() - - except Exception as e: - err_msg = "%s : " % e.__class__.__name__ - err_msg += "%s" % str(e) - if err_msg != "KeyError : 'response'": - err_msg = format_exception(e, 'POST /evaluate') - self.error_out(500, 'Error processing script', info=err_msg) - else: - self.error_out( - 404, 'Error processing script', - info=("The endpoint you're trying to query did not respond" - ". Please make sure the endpoint exists and the " - "correct set of arguments are provided.")) - - @gen.coroutine - def call_subprocess(self, function_to_evaluate, arguments): - # Exec does not run the function, so it does not block. - if sys.version_info > (3, 0): - exec(function_to_evaluate, globals()) - else: - exec(function_to_evaluate) - - -class RestrictedTabPy: - def __init__(self, port): - self.port = port - - def query(self, name, *args, **kwargs): - url = 'http://localhost:%d/query/%s' % (self.port, name) - internal_data = {'data': args or kwargs} - data = simplejson.dumps(internal_data) - headers = {'content-type': 'application/json'} - response = requests.post(url=url, data=data, headers=headers, - timeout=30) - - return response.json() - - -class QueryPlaneHandler(BaseHandler): - - def _query(self, po_name, data, uid, qry): - """ - Parameters - ---------- - po_name : str - The name of the query object to query - - data : dict - The deserialized request body - - uid: str - A unique identifier for the request - - qry: str - The incoming query object. This object maintains - raw incoming request, which is different from the sanitied data - - Returns - ------- - out : (result type, dict, int) - A triple containing a result type, the result message - as a dictionary, and the time in seconds that it took to complete - the request. - """ - start_time = time.time() - response = self.py_handler.ps.query(po_name, data, uid) - gls_time = time.time() - start_time - - if isinstance(response, QuerySuccessful): - response_json = response.to_json() - self.set_header("Etag", '"%s"' % md5(response_json.encode( - 'utf-8')).hexdigest()) - return (QuerySuccessful, response.for_json(), gls_time) - else: - logger.error("Failed query, response: {}".format(response)) - return (type(response), response.for_json(), gls_time) - - # handle HTTP Options requests to support CORS - # don't check API key (client does not send or receive data for OPTIONS, - # it just allows the client to subsequently make a POST request) - def options(self, pred_name): - # add CORS headers if TabPy has a cors_origin specified - self._add_CORS_header() - self.write({}) - - def _handle_result(self, po_name, data, qry, uid): - - (response_type, response, gls_time) = \ - self._query(po_name, data, uid, qry) - - if response_type == QuerySuccessful: - result_dict = { - 'response': response['response'], - 'version': response['version'], - 'model': po_name, - 'uuid': uid - } - self.write(result_dict) - self.finish() - return (gls_time, response['response']) - else: - if response_type == UnknownURI: - self.error_out(404, 'UnknownURI', - info="No query object has been registered" - " with the name '%s'" % po_name) - elif response_type == QueryError: - self.error_out(400, 'QueryError', info=response) - else: - self.error_out(500, 'Error querying GLS', info=response) - - return (None, None) - - def _process_query(self, endpoint_name, start): - try: - self._add_CORS_header() - - if not self.request.body: - self.request.body = {} - - # extract request data explicitly for caching purpose - request_json = self.request.body.decode('utf-8') - - # Sanitize input data - data = _sanitize_request_data(simplejson.loads(request_json)) - except Exception as e: - err_msg = format_exception(e, "Invalid Input Data") - self.error_out(400, err_msg) - return - - try: - (po_name, all_endpoint_names) = self._get_actual_model( - endpoint_name) - - # po_name is None if self.py_handler.ps.query_objects.get( - # endpoint_name) is None - if not po_name: - self.error_out(404, 'UnknownURI', - info="Endpoint '%s' does not exist" - % endpoint_name) - return - - po_obj = self.py_handler.ps.query_objects.get(po_name) - - if not po_obj: - self.error_out(404, 'UnknownURI', - info="Endpoint '%s' does not exist" % po_name) - return - - if po_name != endpoint_name: - logger.info( - "Querying actual model: po_name={}".format(po_name)) - - uid = _get_uuid() - - # record query w/ request ID in query log - qry = Query(po_name, request_json) - gls_time = 0 - # send a query to PythonService and return - (gls_time, result) = self._handle_result(po_name, data, qry, uid) - - # if error occurred, GLS time is None. - if not gls_time: - return - - except Exception as e: - err_msg = format_exception(e, 'process query') - self.error_out(500, 'Error processing query', info=err_msg) - return - - def _get_actual_model(self, endpoint_name): - # Find the actual query to run from given endpoint - all_endpoint_names = [] - - while True: - endpoint_info = self.py_handler.ps.query_objects.get(endpoint_name) - if not endpoint_info: - return [None, None] - - all_endpoint_names.append(endpoint_name) - - endpoint_type = endpoint_info.get('type', 'model') - - if endpoint_type == 'alias': - endpoint_name = endpoint_info['endpoint_obj'] - elif endpoint_type == 'model': - break - else: - self.error_out(500, 'Unknown endpoint type', - info="Endpoint type '%s' does not exist" - % endpoint_type) - return - - return (endpoint_name, all_endpoint_names) - - @tornado.web.asynchronous - def get(self, endpoint_name): - start = time.time() - if sys.version_info > (3, 0): - endpoint_name = urllib.parse.unquote(endpoint_name) - else: - endpoint_name = urllib.unquote(endpoint_name) - logger.debug("GET /query/{}".format(endpoint_name)) - self._process_query(endpoint_name, start) - - @tornado.web.asynchronous - def post(self, endpoint_name): - start = time.time() - if sys.version_info > (3, 0): - endpoint_name = urllib.parse.unquote(endpoint_name) - else: - endpoint_name = urllib.unquote(endpoint_name) - logger.debug("POST /query/{}".format(endpoint_name)) - self._process_query(endpoint_name, start) - - -def get_config(config_file): - """Provide consistent mechanism for pulling in configuration. - - Attempt to retain backward compatibility for existing implementations by - grabbing port setting from CLI first. - - Take settings in the following order: - - 1. CLI arguments, if present - port only - may be able to deprecate - 2. common.config file, and - 3. OS environment variables (for ease of setting defaults if not present) - 4. current defaults if a setting is not present in any location - - Additionally provide similar configuration capabilities in between - common.config and environment variables. - For consistency use the same variable name in the config file as in the os - environment. - For naming standards use all capitals and start with 'TABPY_' - """ - parser = configparser.ConfigParser() - - if os.path.isfile(config_file): - with open(config_file) as f: - parser.read_string(f.read()) - else: - logger.warning("Unable to find config file at '{}', " - "using default settings.".format(config_file)) - - settings = {} - for section in parser.sections(): - if section == "TabPy": - for key, val in parser.items(section): - settings[key] = val - break - - def set_parameter(settings_key, - config_key, - default_val=None, - check_env_var=False): - if config_key is not None and parser.has_option('TabPy', config_key): - settings[settings_key] = parser.get('TabPy', config_key) - elif check_env_var: - settings[settings_key] = os.getenv(config_key, default_val) - elif default_val is not None: - settings[settings_key] = default_val - - if cli_args is not None and cli_args.port is not None: - settings['port'] = cli_args.port - else: - set_parameter( - 'port', 'TABPY_PORT', default_val=9004, check_env_var=True) - try: - settings['port'] = int(settings['port']) - except ValueError: - logger.warning('Error during config validation, invalid port: {}. ' - 'Using default port 9004'.format(settings['port'])) - settings['port'] = 9004 - - set_parameter('server_version', None, default_val=__version__) - set_parameter( - 'bind_ip', 'TABPY_BIND_IP', default_val='0.0.0.0', check_env_var=True) - - set_parameter('upload_dir', 'TABPY_QUERY_OBJECT_PATH', - default_val='/tmp/query_objects', check_env_var=True) - if not os.path.exists(settings['upload_dir']): - os.makedirs(settings['upload_dir']) - - set_parameter('state_file_path', 'TABPY_STATE_PATH', - default_val='./', check_env_var=True) - settings['state_file_path'] = os.path.realpath( - os.path.normpath( - os.path.expanduser(settings['state_file_path']))) - - # set and validate transfer protocol - set_parameter('transfer_protocol', 'TABPY_TRANSFER_PROTOCOL', - default_val='http') - settings['transfer_protocol'] = settings['transfer_protocol'].lower() - - set_parameter('certificate_file', 'TABPY_CERTIFICATE_FILE') - set_parameter('key_file', 'TABPY_KEY_FILE') - validate_transfer_protocol_settings(settings) - - # if state.ini does not exist try and create it - remove last dependence - # on batch/shell script - state_file_path = settings['state_file_path'] - logger.info("Loading state from state file {}".format( - os.path.join(state_file_path, "state.ini"))) - tabpy_state = _get_state_from_file(state_file_path) - settings['tabpy'] = TabPyState(config=tabpy_state, settings=settings) - - settings['py_handler'] = PythonServiceHandler(PythonService()) - settings['compress_response'] = True if TORNADO_MAJOR >= 4 else "gzip" - settings['static_path'] = os.path.join(os.path.dirname(__file__), "static") - - # Set subdirectory from config if applicable - subdirectory = "" - if tabpy_state.has_option("Service Info", "Subdirectory"): - subdirectory = "/" + tabpy_state.get("Service Info", "Subdirectory") - - return settings, subdirectory - - -def validate_transfer_protocol_settings(settings): - if 'transfer_protocol' not in settings: - logger.error('Missing transfer protocol information.') - raise RuntimeError('Missing transfer protocol information.') - - protocol = settings['transfer_protocol'] - - if protocol == 'http': - return - - if protocol != 'https': - err = 'Unsupported transfer protocol: {}.'.format(protocol) - logger.fatal(err) - raise RuntimeError(err) - - validate_cert_key_state('The parameter(s) {} must be set.', - 'certificate_file' in settings, - 'key_file' in settings) - cert = settings['certificate_file'] - - validate_cert_key_state( - 'The parameter(s) {} must point to an existing file.', - os.path.isfile(cert), os.path.isfile(settings['key_file'])) - validate_cert(cert) - return - - -def validate_cert_key_state(msg, cert_valid, key_valid): - cert_param, key_param = 'TABPY_CERTIFICATE_FILE', 'TABPY_KEY_FILE' - cert_and_key_param = '{} and {}'.format(cert_param, key_param) - https_error = 'Error using HTTPS: ' - err = None - if not cert_valid and not key_valid: - err = https_error + msg.format(cert_and_key_param) - elif not cert_valid: - err = https_error + msg.format(cert_param) - elif not key_valid: - err = https_error + msg.format(key_param) - if err is not None: - logger.fatal(err) - raise RuntimeError(err) - - -def validate_cert(cert_file_path): - with open(cert_file_path, 'r') as f: - cert_buf = f.read() - - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_buf) - - date_format, encoding = '%Y%m%d%H%M%SZ', 'ascii' - not_before = datetime.strptime( - cert.get_notBefore().decode(encoding), date_format) - not_after = datetime.strptime( - cert.get_notAfter().decode(encoding), date_format) - now = datetime.now() - - https_error = 'Error using HTTPS: ' - if now < not_before: - raise RuntimeError(https_error + 'The certificate provided is not ' - 'valid until {}.'.format(not_before)) - if now > not_after: - raise RuntimeError(https_error + 'The certificate provided expired ' - 'on {}.'.format(not_after)) - - -def main(): - settings, subdirectory = get_config(config_file) - - logger.info('Initializing TabPy...') - tornado.ioloop.IOLoop.instance().run_sync(lambda: init_ps_server(settings)) - logger.info('Done initializing TabPy.') - - executor = concurrent.futures.ThreadPoolExecutor( - max_workers=multiprocessing.cpu_count()) - - # initialize Tornado application - application = tornado.web.Application([ - # skip MainHandler to use StaticFileHandler .* page requests and - # default to index.html - # (r"/", MainHandler), - (subdirectory + r'/query/([^/]+)', QueryPlaneHandler), - (subdirectory + r'/status', StatusHandler), - (subdirectory + r'/info', ServiceInfoHandler), - (subdirectory + r'/endpoints', EndpointsHandler), - (subdirectory + r'/endpoints/([^/]+)?', EndpointHandler), - (subdirectory + r'/evaluate', EvaluationPlaneHandler, - dict(executor=executor)), - (subdirectory + r'/configurations/endpoint_upload_destination', - UploadDestinationHandler), - (subdirectory + r'/(.*)', tornado.web.StaticFileHandler, - dict(path=settings['static_path'], default_filename="index.html")), - ], debug=False, **settings) - - settings = application.settings - - init_model_evaluator(settings) - - if settings['transfer_protocol'] == 'http': - application.listen(settings['port'], address=settings['bind_ip']) - elif settings['transfer_protocol'] == 'https': - application.listen(settings['port'], address=settings['bind_ip'], - ssl_options={ - 'certfile': settings['certificate_file'], - 'keyfile': settings['key_file'] - }) - else: - raise RuntimeError('Unsupported transfer protocol.') - - logger.info('Web service listening on {} port {}'.format( - settings['bind_ip'], str(settings['port']))) - tornado.ioloop.IOLoop.instance().start() ->>>>>>> fab86dd... Merge pull request #206 from WillAyd/server-linting if __name__ == '__main__': From 8ede838e93b08ae573db5ee396e17b823ef9290d Mon Sep 17 00:00:00 2001 From: ogolovatyi Date: Thu, 28 Mar 2019 14:45:45 -0700 Subject: [PATCH 4/6] Fix merging conflicts --- tabpy-server/server_tests/test_config.py | 27 ++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tabpy-server/server_tests/test_config.py b/tabpy-server/server_tests/test_config.py index 503e848c..27b38857 100644 --- a/tabpy-server/server_tests/test_config.py +++ b/tabpy-server/server_tests/test_config.py @@ -19,14 +19,13 @@ def assert_raises_runtime_error(message, fn, args={}): class TestConfigEnvironmentCalls(unittest.TestCase): @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', - return_value=Namespace(config=None)) - @patch('tabpy_server.tabpy.TabPyState') - @patch('tabpy_server.tabpy._get_state_from_file') - @patch('tabpy_server.tabpy.shutil') - @patch('tabpy_server.tabpy.PythonServiceHandler') - @patch('tabpy_server.tabpy.os.path.exists', return_value=True) - @patch('tabpy_server.tabpy.os.path.isfile', return_value=False) - @patch('tabpy_server.tabpy.os') + return_value=Namespace(config=None)) + @patch('tabpy_server.tabpy.app.app.TabPyState') + @patch('tabpy_server.tabpy.app.app._get_state_from_file') + @patch('tabpy_server.tabpy.app.app.PythonServiceHandler') + @patch('tabpy_server.tabpy.app.app.os.path.exists', return_value=True) + @patch('tabpy_server.tabpy.app.app.os.path.isfile', return_value=False) + @patch('tabpy_server.tabpy.app.app.os') def test_no_config_file(self, mock_os, mock_file_exists, mock_path_exists, mock_psws, mock_management_util, mock_tabpy_state, @@ -178,7 +177,7 @@ def test_https_cert_and_key_file_not_found(self, mock_path): assert_raises_runtime_error( 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE and ' 'TABPY_KEY_FILE must point to an existing file.', - get_config, {self.config_name}) + TabPyApp, {self.config_name}) @patch('tabpy_server.app.app.os.path') def test_https_cert_file_not_found(self, mock_path): @@ -192,8 +191,9 @@ def test_https_cert_file_not_found(self, mock_path): x, {self.fp.name, 'bar'}) assert_raises_runtime_error( - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE must ' - 'point to an existing file.', get_config, {self.config_name}) + 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' + 'must point to an existing file.', + TabPyApp, {self.fp.name}) @patch('tabpy_server.app.app.os.path') def test_https_key_file_not_found(self, mock_path): @@ -207,8 +207,9 @@ def test_https_key_file_not_found(self, mock_path): x, {self.fp.name, 'foo'}) assert_raises_runtime_error( - 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE must point to ' - 'an existing file.', get_config, {self.config_name}) + 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' + 'and TABPY_KEY_FILE must point to an existing file.', + TabPyApp, {self.fp.name}) @patch('tabpy_server.app.app.os.path.isfile', return_value=True) @patch('tabpy_server.app.util.validate_cert') From 79a65a44a41e67a2f0ee4cc0f03325758b96ab99 Mon Sep 17 00:00:00 2001 From: ogolovatyi Date: Thu, 28 Mar 2019 14:49:41 -0700 Subject: [PATCH 5/6] Fix unit tests --- tabpy-server/server_tests/test_config.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tabpy-server/server_tests/test_config.py b/tabpy-server/server_tests/test_config.py index 27b38857..d0025f3e 100644 --- a/tabpy-server/server_tests/test_config.py +++ b/tabpy-server/server_tests/test_config.py @@ -18,14 +18,14 @@ def assert_raises_runtime_error(message, fn, args={}): class TestConfigEnvironmentCalls(unittest.TestCase): - @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', + @patch('tabpy_server.app.app.TabPyApp._parse_cli_arguments', return_value=Namespace(config=None)) - @patch('tabpy_server.tabpy.app.app.TabPyState') - @patch('tabpy_server.tabpy.app.app._get_state_from_file') - @patch('tabpy_server.tabpy.app.app.PythonServiceHandler') - @patch('tabpy_server.tabpy.app.app.os.path.exists', return_value=True) - @patch('tabpy_server.tabpy.app.app.os.path.isfile', return_value=False) - @patch('tabpy_server.tabpy.app.app.os') + @patch('tabpy_server.app.app.TabPyState') + @patch('tabpy_server.app.app._get_state_from_file') + @patch('tabpy_server.app.app.PythonServiceHandler') + @patch('tabpy_server.app.app.os.path.exists', return_value=True) + @patch('tabpy_server.app.app.os.path.isfile', return_value=False) + @patch('tabpy_server.app.app.os') def test_no_config_file(self, mock_os, mock_file_exists, mock_path_exists, mock_psws, mock_management_util, mock_tabpy_state, @@ -177,7 +177,7 @@ def test_https_cert_and_key_file_not_found(self, mock_path): assert_raises_runtime_error( 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE and ' 'TABPY_KEY_FILE must point to an existing file.', - TabPyApp, {self.config_name}) + TabPyApp, {self.fp.name}) @patch('tabpy_server.app.app.os.path') def test_https_cert_file_not_found(self, mock_path): @@ -207,8 +207,8 @@ def test_https_key_file_not_found(self, mock_path): x, {self.fp.name, 'foo'}) assert_raises_runtime_error( - 'Error using HTTPS: The parameter(s) TABPY_CERTIFICATE_FILE ' - 'and TABPY_KEY_FILE must point to an existing file.', + 'Error using HTTPS: The parameter(s) TABPY_KEY_FILE ' + 'must point to an existing file.', TabPyApp, {self.fp.name}) @patch('tabpy_server.app.app.os.path.isfile', return_value=True) From 99410f897bb1701f53f73adea4a68e88bcefd206 Mon Sep 17 00:00:00 2001 From: ogolovatyi Date: Thu, 28 Mar 2019 15:30:01 -0700 Subject: [PATCH 6/6] More code style fixes --- tabpy-server/server_tests/test_config.py | 2 +- tabpy-server/tabpy_server/app/util.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tabpy-server/server_tests/test_config.py b/tabpy-server/server_tests/test_config.py index d0025f3e..7f150a34 100644 --- a/tabpy-server/server_tests/test_config.py +++ b/tabpy-server/server_tests/test_config.py @@ -87,7 +87,7 @@ def test_config_file_present(self, mock_os, mock_path_exists, mock_os.getenv.assert_has_calls(getenv_calls, any_order=True) self.assertEqual(app.settings['port'], 1234) - self.assertEqual(app.settings['server_version'], + self.assertEqual(app.settings['server_version'], open('VERSION').read().strip()) self.assertEqual(app.settings['upload_dir'], 'foo') self.assertEqual(app.settings['state_file_path'], 'bar') diff --git a/tabpy-server/tabpy_server/app/util.py b/tabpy-server/tabpy_server/app/util.py index 63c4b1c5..a5b2a13e 100644 --- a/tabpy-server/tabpy_server/app/util.py +++ b/tabpy-server/tabpy_server/app/util.py @@ -36,7 +36,8 @@ def validate_cert(cert_file_path): not_before), RuntimeError) if now > not_after: log_and_raise(https_error + - 'The certificate provided expired on {}.'.format(not_after), RuntimeError) + f'The certificate provided expired on {not_after}.', + RuntimeError) def parse_pwd_file(pwd_file_name):