Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion tabpy/tabpy_server/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from tabpy.tabpy import __version__
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
from tabpy.tabpy_server.app.util import parse_pwd_file
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
from tabpy.tabpy_server.management.state import TabPyState
from tabpy.tabpy_server.management.util import _get_state_from_file
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
Expand Down Expand Up @@ -59,6 +61,7 @@ class TabPyApp:
tabpy_state = None
python_service = None
credentials = {}
arrow_server = None

def __init__(self, config_file):
if config_file is None:
Expand All @@ -75,6 +78,43 @@ def __init__(self, config_file):

self._parse_config(config_file)

def _get_tls_certificates(self, config):
tls_certificates = []
cert = config[SettingsParameters.CertificateFile]
key = config[SettingsParameters.KeyFile]
with open(cert, "rb") as cert_file:
tls_cert_chain = cert_file.read()
with open(key, "rb") as key_file:
tls_private_key = key_file.read()
tls_certificates.append((tls_cert_chain, tls_private_key))
return tls_certificates

def _get_arrow_server(self, config):
verify_client = None
tls_certificates = None
scheme = "grpc+tcp"
if config[SettingsParameters.TransferProtocol] == "https":
verify_client = True
scheme = "grpc+tls"
tls_certificates = self._get_tls_certificates(config)

host = "localhost"
port = config.get(SettingsParameters.ArrowFlightPort)
location = "{}://{}:{}".format(scheme, host, port)

auth_middleware = None
if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]:
_, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
auth_middleware = {
"basic": BasicAuthServerMiddlewareFactory(creds)
}

server = pa.FlightServer(host, location,
tls_certificates=tls_certificates,
verify_client=verify_client, auth_handler=NoOpAuthHandler(),
middleware=auth_middleware)
return server

def run(self):
application = self._create_tornado_web_app()
max_request_size = (
Expand Down Expand Up @@ -115,7 +155,8 @@ def run(self):

# Define a function for the thread
def start_pyarrow():
pa.start(self.settings)
self.arrow_server = self._get_arrow_server(self.settings)
pa.start(self.arrow_server)

try:
_thread.start_new_thread(start_pyarrow, ())
Expand Down
1 change: 0 additions & 1 deletion tabpy/tabpy_server/app/app_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class ConfigParameters:
TABPY_TRANSFER_PROTOCOL = "TABPY_TRANSFER_PROTOCOL"
TABPY_CERTIFICATE_FILE = "TABPY_CERTIFICATE_FILE"
TABPY_KEY_FILE = "TABPY_KEY_FILE"
TABPY_PWD_FILE = "TABPY_PWD_FILE"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intentional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was a duplicate parameter. You can see its copy at the top of the list.

TABPY_LOG_DETAILS = "TABPY_LOG_DETAILS"
TABPY_STATIC_PATH = "TABPY_STATIC_PATH"
TABPY_MAX_REQUEST_SIZE_MB = "TABPY_MAX_REQUEST_SIZE_MB"
Expand Down
39 changes: 3 additions & 36 deletions tabpy/tabpy_server/app/arrow_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self, host="localhost", location=None,
self.flights = {}
self.host = host
self.tls_certificates = tls_certificates
self.location = location

@classmethod
def descriptor_to_key(self, descriptor):
Expand Down Expand Up @@ -139,42 +140,8 @@ def _shutdown(self):
time.sleep(2)
self.shutdown()

def _get_tls_certificates(config):
tls_certificates = []
cert = config[SettingsParameters.CertificateFile]
key = config[SettingsParameters.KeyFile]
with open(cert, "rb") as cert_file:
tls_cert_chain = cert_file.read()
with open(key, "rb") as key_file:
tls_private_key = key_file.read()
tls_certificates.append((tls_cert_chain, tls_private_key))
return tls_certificates

def start(config):
verify_client = None
tls_certificates = None
scheme = "grpc+tcp"
if config[SettingsParameters.TABPY_TRANSFER_PROTOCOL] == "https":
verify_client = True
scheme = "grpc+tls"
tls_certificates = _get_tls_certificates(config)

host = "localhost"
port = config.get(SettingsParameters.ArrowFlightPort)
location = "{}://{}:{}".format(scheme, host, port)

auth_middleware = None
if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]:
_, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
auth_middleware = {
"basic": BasicAuthServerMiddlewareFactory(creds)
}

server = FlightServer(host, location,
tls_certificates=tls_certificates,
verify_client=verify_client, auth_handler=NoOpAuthHandler(),
middleware=auth_middleware)
logger.info(f"Serving on {location}")
def start(server):
logger.info(f"Serving on {server.location}")
server.serve()


Expand Down
200 changes: 0 additions & 200 deletions tabpy/tabpy_server/handlers/arrow_client.py

This file was deleted.

This file was deleted.

Loading