In [1]:
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.flight as flight
import numpy as np
import pandas as pd
import time
import threading

# Implement a Flight server in Python

This server has a few goals:

* Clients can send ("put") datasets, to be kept in memory by the server
* Clients can request a list of cached datasets ("list-tables")
* Clients can request ("get") a cached table

Note that this server is very simple: it does not demonstrate the more sophisticated "query planning" capabilities of Arrow Flight, nor parallel or multi-part access. My goal is to show you that:

* It's easy to write a Flight service in Python
* The performance of Flight is **very, very good**

In [2]:
class DemoServer(flight.FlightServerBase):
    """In-memory Arrow Flight server that caches uploaded tables by name.

    Tables uploaded with ``do_put`` are stored in ``self._cache``, keyed by the
    raw bytes of the descriptor command. ``do_get`` returns the table whose key
    equals the ticket bytes. Two custom actions are exposed via ``do_action``:
    ``list-tables`` and ``drop-table``.
    """

    def __init__(self):
        # NOTE(review): FlightServerBase.__init__ is not called. This notebook
        # uses the older pyarrow API (``server.init(location)`` followed by
        # ``server.run()``); confirm against the installed pyarrow version.
        self._cache = {}

    def list_actions(self, context):
        """Advertise the custom actions understood by ``do_action``."""
        return [flight.ActionType('list-tables', 'List stored tables'),
                flight.ActionType('drop-table', 'Drop a stored table')]

    # -----------------------------------------------------------------
    # Implement actions

    def do_action(self, context, action):
        """Dispatch a custom action to its handler and return its results.

        Raises:
            NotImplementedError: if ``action.type`` is not a known action.
        """
        handlers = {
            'list-tables': self._list_tables,
            'drop-table': self._drop_table,
        }
        handler = handlers.get(action.type)
        if handler is None:
            raise NotImplementedError(f'Unknown action: {action.type!r}')
        return handler(action)

    def _drop_table(self, action):
        """Remove the table named by the action body from the cache.

        ``action.body`` is an Arrow buffer, not ``bytes``. The cache is keyed by
        the ``bytes`` from ``descriptor.command``, so convert before deleting;
        otherwise the lookup can never match.

        Raises:
            KeyError: if no table with that name is cached.
        """
        del self._cache[action.body.to_pybytes()]
        # The Flight runtime iterates whatever do_action returns, so hand back
        # an empty result stream instead of None.
        return iter([])

    def _list_tables(self, action):
        """Return one ``flight.Result`` per cached table name, in sorted order."""
        return iter([flight.Result(cache_key)
                     for cache_key in sorted(self._cache)])

    # -----------------------------------------------------------------
    # Implement puts

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole uploaded stream into a Table and cache it.

        The cache key is ``descriptor.command`` (bytes); a later upload under
        the same name replaces the earlier table.
        """
        self._cache[descriptor.command] = reader.read_all()

    # -----------------------------------------------------------------
    # Implement gets

    def do_get(self, context, ticket):
        """Stream back the cached table whose key equals ``ticket.ticket``.

        Raises:
            KeyError: if no table is cached under the ticket bytes.
        """
        table = self._cache[ticket.ticket]
        return flight.RecordBatchStream(table)

Some helper utilities; you can skip this part.

In [3]:
import contextlib
import socket
def find_free_port(host: str = '') -> int:
    """Return a TCP port number that was free at the time of the call.

    The OS assigns the port (bind to port 0). The probe socket is closed before
    returning, so another process could in principle take the port before the
    server binds it. That is acceptable for a local demo.

    Args:
        host: interface to probe on; '' (the default) means all interfaces.

    Returns:
        The port number chosen by the OS.
    """
    # The socket's own context manager closes it on exit. No SO_REUSEADDR:
    # options only affect bind() when set before it. A never-connected probe
    # socket leaves no TIME_WAIT state for the option to work around anyway.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]

def wait_for_available(client, timeout=5.0, poll_interval=0.025):
    """Block until the Flight server behind ``client`` responds to a request.

    Polls ``client.list_flights()``:

    * If a call raises an exception whose message contains 'Connect Failed',
      it is retried every ``poll_interval`` seconds. Once ``timeout`` seconds
      have passed, that exception is re-raised.
    * Any other exception ends the wait: the server answered, even if with an
      error. ``DemoServer`` does not override ``list_flights``, so a running
      server may reply with an error rather than a result.

    Args:
        client: object with a ``list_flights()`` method (a ``FlightClient``).
        timeout: seconds to keep retrying connection failures.
        poll_interval: seconds to sleep between attempts.
    """
    # monotonic() is not affected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while True:
        try:
            list(client.list_flights())
        except Exception as e:
            if 'Connect Failed' not in str(e):
                return  # the server responded, just not with a result
            if time.monotonic() >= deadline:
                raise
            time.sleep(poll_interval)
        else:
            return

## Start the server in the background and connect a client

In [4]:
# Serve on a localhost port picked by the OS so the cell does not collide
# with anything already listening.
location = flight.Location.for_grpc_tcp("localhost", find_free_port())

server = DemoServer()
server.init(location)

# Run the server on a background daemon thread so the notebook stays
# interactive; daemon=True means the thread will not keep the kernel alive.
thread = threading.Thread(target=server.run, daemon=True)
thread.start()

client = flight.FlightClient.connect(location)
# The server starts asynchronously; poll until it accepts connections.
wait_for_available(client)

### Ask the server for its supported actions

In [5]:
# Returns the ActionType entries advertised by DemoServer.list_actions
client.list_actions()

[ActionType(type='list-tables', description='List stored tables'),
 ActionType(type='drop-table', description='Drop a stored table')]

### Implement convenience functions for invoking the server's RPC methods

In [6]:
def list_tables(client):
    """Run the server's "list-tables" action and return the table names as a list of str."""
    results = client.do_action(flight.Action('list-tables', b''))
    names = []
    for result in results:
        # Each result body is an Arrow buffer holding the UTF-8 encoded name.
        raw_name = result.body.to_pybytes()
        names.append(raw_name.decode('utf8'))
    return names

def cache_table_in_server(name, table, flight_client=None):
    """Upload ``table`` to the server, where it is cached under ``name``.

    Args:
        name: cache key; sent UTF-8 encoded as the descriptor command.
        table: the pyarrow.Table to upload.
        flight_client: client to use; defaults to the notebook-global ``client``.
    """
    if flight_client is None:
        flight_client = client
    desc = flight.FlightDescriptor.for_command(name.encode('utf8'))
    # do_put also returns a metadata reader, which this demo does not use.
    put_writer, _ = flight_client.do_put(desc, table.schema)
    try:
        put_writer.write(table)
    finally:
        # Close the stream even if write() fails so it is not left open.
        put_writer.close()
    
def get_table(name, flight_client=None):
    """Download the table cached under ``name`` and return it as a pyarrow.Table.

    Args:
        name: cache key; sent UTF-8 encoded as the ticket.
        flight_client: client to use; defaults to the notebook-global ``client``.
    """
    if flight_client is None:
        flight_client = client
    reader = flight_client.do_get(flight.Ticket(name.encode('utf8')))
    return reader.read_all()

# Nothing has been uploaded yet, so the server's cache is empty
list_tables(client)

[]

In [7]:
# A tiny single-column table ('f0') to exercise the put/list/get round trip
table = pa.table([pa.array([1,2,3,4,5])], names=['f0'])
cache_table_in_server('table1', table)

In [8]:
# 'table1' should now appear in the server's cache
list_tables(client)

['table1']

In [9]:
# Upload the same small table under three more names
for cache_name in ('table2', 'table3', 'table4'):
    cache_table_in_server(cache_name, table)

In [10]:
# Names come back sorted because the server's _list_tables sorts the cache keys
list_tables(client)

['table1', 'table2', 'table3', 'table4']

In [11]:
# Fetch a cached table back by name
get_table('table1')

pyarrow.Table
f0: int64

### Now let's make a much bigger table and test performance

In [12]:
fec_table = pq.read_table('fec-2012.parquet')

In [13]:
fec_table = pa.concat_tables([fec_table] * 10)

In [14]:
# How big is it? Serialize the table to an in-memory Arrow IPC stream and
# count the bytes; this approximates the payload Flight has to move.
out = pa.BufferOutputStream()
with pa.ipc.RecordBatchStreamWriter(out, fec_table.schema) as writer:
    writer.write(fec_table)
len(out.getvalue())

1780273284

In [15]:
# Derive the size from the stream measured in the previous cell instead of a
# number pasted from its output, so it stays correct if the data changes.
# 1 << 30 bytes = 1 GiB.
print(f'Table is {len(out.getvalue()) / (1 << 30)} gigabytes')

Table is 1.658008698374033 gigabytes


In [16]:
%%time
# Time uploading the large table to the server via do_put
cache_table_in_server('fec_table', fec_table)

CPU times: user 402 ms, sys: 1.02 s, total: 1.42 s
Wall time: 1.12 s


In [17]:
# The large table is now cached alongside the small ones
list_tables(client)

['fec_table', 'table1', 'table2', 'table3', 'table4']

In [18]:
%%time 

# Time downloading the large table back from the server via do_get
fec_table_received = get_table('fec_table')

CPU times: user 356 ms, sys: 988 ms, total: 1.34 s
Wall time: 1.09 s


### ~1.5 GiB/sec end-to-end over TCP (localhost loopback, so no physical network involved) — not bad