From 51d2f0acde8621988cce2d517e6453343464df24 Mon Sep 17 00:00:00 2001 From: Andy Antes Date: Thu, 30 Mar 2023 15:08:05 -0400 Subject: [PATCH 1/3] Remove duplicate setting --- tabpy/tabpy_server/app/app_parameters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tabpy/tabpy_server/app/app_parameters.py b/tabpy/tabpy_server/app/app_parameters.py index 51362e84..7dc8ddb5 100644 --- a/tabpy/tabpy_server/app/app_parameters.py +++ b/tabpy/tabpy_server/app/app_parameters.py @@ -10,7 +10,6 @@ class ConfigParameters: TABPY_TRANSFER_PROTOCOL = "TABPY_TRANSFER_PROTOCOL" TABPY_CERTIFICATE_FILE = "TABPY_CERTIFICATE_FILE" TABPY_KEY_FILE = "TABPY_KEY_FILE" - TABPY_PWD_FILE = "TABPY_PWD_FILE" TABPY_LOG_DETAILS = "TABPY_LOG_DETAILS" TABPY_STATIC_PATH = "TABPY_STATIC_PATH" TABPY_MAX_REQUEST_SIZE_MB = "TABPY_MAX_REQUEST_SIZE_MB" From 1e26c8839fa62aaeb2c6fdeb2f4b79d16a9747af Mon Sep 17 00:00:00 2001 From: Andy Antes Date: Thu, 30 Mar 2023 15:12:26 -0400 Subject: [PATCH 2/3] Make server available to evaluation_plane_handler so it can operate on flights directly. Update evaluation_plane_handler to bypass using a local client to make client-server calls. Server object construction moved to app as part of startup. --- tabpy/tabpy_server/app/app.py | 43 +++++++++++++++++- tabpy/tabpy_server/app/arrow_server.py | 39 ++-------------- .../handlers/evaluation_plane_handler.py | 44 ++++++++++--------- 3 files changed, 69 insertions(+), 57 deletions(-) diff --git a/tabpy/tabpy_server/app/app.py b/tabpy/tabpy_server/app/app.py index e448da1d..6d67d04e 100644 --- a/tabpy/tabpy_server/app/app.py +++ b/tabpy/tabpy_server/app/app.py @@ -10,6 +10,8 @@ from tabpy.tabpy import __version__ from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters from tabpy.tabpy_server.app.util import parse_pwd_file +from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory +from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler from tabpy.tabpy_server.management.state import TabPyState from tabpy.tabpy_server.management.util import _get_state_from_file from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server @@ -59,6 +61,7 @@ class TabPyApp: tabpy_state = None python_service = None credentials = {} + arrow_server = None def __init__(self, config_file): if config_file is None: @@ -75,6 +78,43 @@ def __init__(self, config_file): self._parse_config(config_file) + def _get_tls_certificates(self, config): + tls_certificates = [] + cert = config[SettingsParameters.CertificateFile] + key = config[SettingsParameters.KeyFile] + with open(cert, "rb") as cert_file: + tls_cert_chain = cert_file.read() + with open(key, "rb") as key_file: + tls_private_key = key_file.read() + tls_certificates.append((tls_cert_chain, tls_private_key)) + return tls_certificates + + def _get_arrow_server(self, config): + verify_client = None + tls_certificates = None + scheme = "grpc+tcp" + if config[SettingsParameters.TransferProtocol] == "https": + verify_client = True + scheme = "grpc+tls" + tls_certificates = self._get_tls_certificates(config) + + host = "localhost" + port = config.get(SettingsParameters.ArrowFlightPort) + location = "{}://{}:{}".format(scheme, host, port) + + auth_middleware = None + if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]: + _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE]) + auth_middleware = { + "basic": BasicAuthServerMiddlewareFactory(creds) + } + + server = pa.FlightServer(host, location, + tls_certificates=tls_certificates, + verify_client=verify_client, auth_handler=NoOpAuthHandler(), + middleware=auth_middleware) + return server + def run(self): application = self._create_tornado_web_app() max_request_size = ( @@ -115,7 +155,8 @@ def run(self): # Define a function for the thread def start_pyarrow(): - pa.start(self.settings) + self.arrow_server = self._get_arrow_server(self.settings) + pa.start(self.arrow_server) try: _thread.start_new_thread(start_pyarrow, ()) diff --git a/tabpy/tabpy_server/app/arrow_server.py b/tabpy/tabpy_server/app/arrow_server.py index d3192181..9ffe69e5 100644 --- a/tabpy/tabpy_server/app/arrow_server.py +++ b/tabpy/tabpy_server/app/arrow_server.py @@ -42,6 +42,7 @@ def __init__(self, host="localhost", location=None, self.flights = {} self.host = host self.tls_certificates = tls_certificates + self.location = location @classmethod def descriptor_to_key(self, descriptor): @@ -139,42 +140,8 @@ def _shutdown(self): time.sleep(2) self.shutdown() -def _get_tls_certificates(config): - tls_certificates = [] - cert = config[SettingsParameters.CertificateFile] - key = config[SettingsParameters.KeyFile] - with open(cert, "rb") as cert_file: - tls_cert_chain = cert_file.read() - with open(key, "rb") as key_file: - tls_private_key = key_file.read() - tls_certificates.append((tls_cert_chain, tls_private_key)) - return tls_certificates - -def start(config): - verify_client = None - tls_certificates = None - scheme = "grpc+tcp" - if config[SettingsParameters.TABPY_TRANSFER_PROTOCOL] == "https": - verify_client = True - scheme = "grpc+tls" - tls_certificates = _get_tls_certificates(config) - - host = "localhost" - port = config.get(SettingsParameters.ArrowFlightPort) - location = "{}://{}:{}".format(scheme, host, port) - - auth_middleware = None - if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]: - _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE]) - auth_middleware = { - "basic": BasicAuthServerMiddlewareFactory(creds) - } - - server = FlightServer(host, location, - tls_certificates=tls_certificates, - verify_client=verify_client, auth_handler=NoOpAuthHandler(), - middleware=auth_middleware) - logger.info(f"Serving on {location}") +def start(server): + logger.info(f"Serving on {server.location}") server.serve() diff --git a/tabpy/tabpy_server/handlers/evaluation_plane_handler.py b/tabpy/tabpy_server/handlers/evaluation_plane_handler.py index 86867271..412ae2e4 100644 --- a/tabpy/tabpy_server/handlers/evaluation_plane_handler.py +++ b/tabpy/tabpy_server/handlers/evaluation_plane_handler.py @@ -2,7 +2,7 @@ import pyarrow import uuid -from tabpy.tabpy_server.handlers import BaseHandler, arrow_client +from tabpy.tabpy_server.handlers import BaseHandler import json import simplejson import logging @@ -10,9 +10,7 @@ import requests from tornado import gen from datetime import timedelta -from tabpy.tabpy_server.handlers.basic_auth_client_middleware_factory import BasicAuthClientMiddlewareFactory from tabpy.tabpy_server.handlers.util import AuthErrorStates -from tabpy.tabpy_server.app.app_parameters import SettingsParameters class RestrictedTabPy: def __init__(self, protocol, port, logger, timeout, headers): @@ -59,6 +57,7 @@ class EvaluationPlaneHandler(BaseHandler): def initialize(self, executor, app): super(EvaluationPlaneHandler, self).initialize(app) + self.arrow_server = app.arrow_server self.executor = executor self._error_message_timeout = ( f"User defined script timed out. " @@ -79,6 +78,7 @@ def _post_impl(self): arguments_str = "" if "dataPath" in body: # arrow flight scenario + print("arrow flight scenario") arrow_data = self.get_arrow_data(body["dataPath"]) if arrow_data is not None: arguments = {"_arg1": arrow_data} @@ -139,27 +139,31 @@ def _post_impl(self): else: self.write("null") self.finish() - - def _get_flight_client(self): - # TODO: handle TLS - scheme = "grpc+tcp" - host = "localhost" - port = self.settings[SettingsParameters.ArrowFlightPort] - middleware = None - if "authentication" in self.settings[SettingsParameters.ApiVersions]["v1"]["features"]: - middleware = [ - BasicAuthClientMiddlewareFactory(self.username, self.password) - ] - connection_args = {} - return pyarrow.flight.FlightClient(location=f"{scheme}://{host}:{port}", middleware=middleware, **connection_args) def get_arrow_data(self, filename): - client = self._get_flight_client() - return arrow_client.get_flight_by_path(filename, client, client_factory=self._get_flight_client) + descriptor = pyarrow.flight.FlightDescriptor.for_path(filename) + info = self.arrow_server.get_flight_info(None, descriptor) + for endpoint in info.endpoints: + print('Ticket:', endpoint.ticket) + for location in endpoint.locations: + print(location) + key = (descriptor.descriptor_type.value, descriptor.command, + tuple(descriptor.path or tuple())) + df = self.arrow_server.flights.pop(key).to_pandas() + return df + print('no data found for get') + return '' def upload_arrow_data(self, data, filename, metadata): - client = self._get_flight_client() - return arrow_client.upload_data(client, data, filename, metadata) + my_table = pyarrow.table(data) + if metadata is not None: + my_table.schema.with_metadata(metadata) + print('Table rows=', str(len(my_table))) + print("Uploading", data.head()) + descriptor = pyarrow.flight.FlightDescriptor.for_path(filename) + key = (descriptor.descriptor_type.value, descriptor.command, + tuple(descriptor.path or tuple())) + self.arrow_server.flights[key] = my_table @gen.coroutine def post(self): From f359eb161b9642a22dffb5bc142d6814540de12d Mon Sep 17 00:00:00 2001 From: Andy Antes Date: Thu, 30 Mar 2023 15:12:47 -0400 Subject: [PATCH 3/3] Remove unused client code. --- tabpy/tabpy_server/handlers/arrow_client.py | 200 ------------------ .../basic_auth_client_middleware_factory.py | 24 --- 2 files changed, 224 deletions(-) delete mode 100644 tabpy/tabpy_server/handlers/arrow_client.py delete mode 100644 tabpy/tabpy_server/handlers/basic_auth_client_middleware_factory.py diff --git a/tabpy/tabpy_server/handlers/arrow_client.py b/tabpy/tabpy_server/handlers/arrow_client.py deleted file mode 100644 index 9777bba6..00000000 --- a/tabpy/tabpy_server/handlers/arrow_client.py +++ /dev/null @@ -1,200 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""An example Flight CLI client.""" - -import argparse -import sys - -import pyarrow -import pyarrow.flight -import pyarrow.csv as csv - - -def list_flights(args, client, connection_args={}): - print('Flights\n=======') - for flight in client.list_flights(): - descriptor = flight.descriptor - if descriptor.descriptor_type == pyarrow.flight.DescriptorType.PATH: - print("Path:", descriptor.path) - elif descriptor.descriptor_type == pyarrow.flight.DescriptorType.CMD: - print("Command:", descriptor.command) - else: - print("Unknown descriptor type") - - print("Total records:", end=" ") - if flight.total_records >= 0: - print(flight.total_records) - else: - print("Unknown") - - print("Total bytes:", end=" ") - if flight.total_bytes >= 0: - print(flight.total_bytes) - else: - print("Unknown") - - print("Number of endpoints:", len(flight.endpoints)) - print("Schema:") - print(flight.schema) - print('---') - - print('\nActions\n=======') - for action in client.list_actions(): - print("Type:", action.type) - print("Description:", action.description) - print('---') - - -def do_action(args, client, connection_args={}): - try: - buf = pyarrow.allocate_buffer(0) - action = pyarrow.flight.Action(args.action_type, buf) - print('Running action', args.action_type) - for result in client.do_action(action): - print("Got result", result.body.to_pybytes()) - except pyarrow.lib.ArrowIOError as e: - print("Error calling action:", e) - - -def push_data(args, client, connection_args={}): - print('File Name:', args.file) - my_table = csv.read_csv(args.file) - print('Table rows=', str(len(my_table))) - df = my_table.to_pandas() - print(df.head()) - writer, _ = client.do_put( - pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema) - writer.write_table(my_table) - writer.close() - - -def upload_data(client, data, filename, metadata=None): - my_table = pyarrow.table(data) - if metadata is not None: - my_table.schema.with_metadata(metadata) - print('Table rows=', str(len(my_table))) - print("Uploading", data.head()) - writer, _ = client.do_put( - pyarrow.flight.FlightDescriptor.for_path(filename), my_table.schema) - writer.write_table(my_table) - writer.close() - - -def get_flight_by_path(path, client, client_factory): - descriptor = pyarrow.flight.FlightDescriptor.for_path(path) - - info = client.get_flight_info(descriptor) - for endpoint in info.endpoints: - print('Ticket:', endpoint.ticket) - for location in endpoint.locations: - print(location) - get_client = client_factory() - reader = get_client.do_get(endpoint.ticket) - df = reader.read_pandas() - print(df) - return df - print("no data found for get") - return '' - -def _add_common_arguments(parser): - parser.add_argument('--tls', action='store_true', - help='Enable transport-level security') - parser.add_argument('--tls-roots', default=None, - help='Path to trusted TLS certificate(s)') - parser.add_argument("--mtls", nargs=2, default=None, - metavar=('CERTFILE', 'KEYFILE'), - help="Enable transport-level security") - parser.add_argument('host', type=str, - help="Address or hostname to connect to") - - -def main(): - parser = argparse.ArgumentParser() - subcommands = parser.add_subparsers() - - cmd_list = subcommands.add_parser('list') - cmd_list.set_defaults(action='list') - _add_common_arguments(cmd_list) - cmd_list.add_argument('-l', '--list', action='store_true', - help="Print more details.") - - cmd_do = subcommands.add_parser('do') - cmd_do.set_defaults(action='do') - _add_common_arguments(cmd_do) - cmd_do.add_argument('action_type', type=str, - help="The action type to run.") - - cmd_put = subcommands.add_parser('put') - cmd_put.set_defaults(action='put') - _add_common_arguments(cmd_put) - cmd_put.add_argument('file', type=str, - help="CSV file to upload.") - - cmd_get = subcommands.add_parser('get') - cmd_get.set_defaults(action='get') - _add_common_arguments(cmd_get) - cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True) - cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append', - help="The path for the descriptor.") - cmd_get_descriptor.add_argument('-c', '--command', type=str, - help="The command for the descriptor.") - - args = parser.parse_args() - if not hasattr(args, 'action'): - parser.print_help() - sys.exit(1) - - commands = { - 'list': list_flights, - 'do': do_action, - 'get': get_flight_by_path, - 'put': push_data, - } - host, port = args.host.split(':') - port = int(port) - scheme = "grpc+tcp" - connection_args = {} - if args.tls: - scheme = "grpc+tls" - if args.tls_roots: - with open(args.tls_roots, "rb") as root_certs: - connection_args["tls_root_certs"] = root_certs.read() - if args.mtls: - with open(args.mtls[0], "rb") as cert_file: - tls_cert_chain = cert_file.read() - with open(args.mtls[1], "rb") as key_file: - tls_private_key = key_file.read() - connection_args["cert_chain"] = tls_cert_chain - connection_args["private_key"] = tls_private_key - client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}", - **connection_args) - while True: - try: - action = pyarrow.flight.Action("healthcheck", b"") - options = pyarrow.flight.FlightCallOptions(timeout=1) - list(client.do_action(action, options=options)) - break - except pyarrow.ArrowIOError as e: - if "Deadline" in str(e): - print("Server is not ready, waiting...") - commands[args.action](args, client, connection_args) - - - -if __name__ == '__main__': - main() diff --git a/tabpy/tabpy_server/handlers/basic_auth_client_middleware_factory.py b/tabpy/tabpy_server/handlers/basic_auth_client_middleware_factory.py deleted file mode 100644 index 19d5d2c8..00000000 --- a/tabpy/tabpy_server/handlers/basic_auth_client_middleware_factory.py +++ /dev/null @@ -1,24 +0,0 @@ -import base64 -from requests.auth import HTTPBasicAuth -from pyarrow.flight import ClientMiddlewareFactory, ClientMiddleware - -class BasicAuthClientMiddleware(ClientMiddleware): - def __init__(self, username, password): - self.username = username - self.password = password - - def sending_headers(self): - headers = {} - creds = f'{self.username}:{self.password}' - encoded_creds = base64.b64encode(creds.encode()).decode() - headers['authorization'] = f'Basic {encoded_creds}' - return headers - -class BasicAuthClientMiddlewareFactory(ClientMiddlewareFactory): - - def __init__(self, username, password): - self.username = username - self.password = password - - def start_call(self, instance): - return BasicAuthClientMiddleware(self.username, self.password)