diff --git a/app.py b/app.py
index 76bd933..6202733 100644
--- a/app.py
+++ b/app.py
@@ -11,9 +11,14 @@
import hmac
import logging
import os
+import re
import signal
+from struct import (
+ Struct,
+)
import sys
import urllib.parse
+from xml.sax.saxutils import escape as escape_xml
from flask import (
Flask,
@@ -55,17 +60,33 @@ def proxy(path):
logger.debug('Attempt to proxy: %s', request)
url = endpoint_url + path
- body_hash = hashlib.sha256(b'').hexdigest()
+ parsed_url = urllib.parse.urlsplit(url)
+ method, body, params, parse_response = \
+ (
+ 'POST',
+ aws_select_post_body(request.args['query_sql']),
+ (('select', ''), ('select-type', '2')),
+ aws_select_parse_result,
+ ) if 'query_sql' in request.args else \
+ (
+ 'GET',
+ b'',
+ {},
+ lambda x, _: x,
+ )
+
+ body_hash = hashlib.sha256(body).hexdigest()
pre_auth_headers = tuple((
(key, request.headers[key])
for key in proxied_request_headers if key in request.headers
))
- parsed_url = urllib.parse.urlsplit(url)
+ encoded_params = urllib.parse.urlencode(params)
request_headers = aws_sigv4_headers(
- pre_auth_headers, 's3', parsed_url.netloc, 'GET', parsed_url.path, (), body_hash,
+ pre_auth_headers, 's3', parsed_url.netloc, method, parsed_url.path, params, body_hash,
)
+ response = http.request(method, f'{url}?{encoded_params}', headers=dict(
+ request_headers), body=body, preload_content=False)
- response = http.request('GET', url, headers=dict(request_headers), preload_content=False)
response_headers = tuple((
(key, response.headers[key])
for key in proxied_response_headers if key in response.headers
@@ -84,7 +105,7 @@ def body_empty():
pass
downstream_response = \
- Response(response.stream(65536, decode_content=False),
+ Response(parse_response(response.stream(65536, decode_content=False), 65536),
status=response.status, headers=response_headers) if allow_proxy else \
Response(body_empty(), status=500)
downstream_response.call_on_close(response.release_conn)
@@ -144,6 +165,215 @@ def sign(key, msg):
(b'x-amz-content-sha256', body_hash.encode('ascii')),
) + pre_auth_headers
+ def aws_select_post_body(sql):
+ sql_xml_escaped = escape_xml(sql)
+ return \
+ f'''
+ <SelectObjectContentRequest xmlns="http://s3.amazonaws.com/doc/2006-03-06/">
+ <Expression>{sql_xml_escaped}</Expression>
+ <ExpressionType>SQL</ExpressionType>
+ <InputSerialization>
+ <JSON>
+ <Type>Document</Type>
+ </JSON>
+ </InputSerialization>
+ <OutputSerialization>
+ <JSON>
+ <RecordDelimiter>,</RecordDelimiter>
+ </JSON>
+ </OutputSerialization>
+ </SelectObjectContentRequest>
+ '''.encode('utf-8')
+
+ def aws_select_parse_result(input_iterable, output_chunk_size):
+ # Returns an iterator that yields payload data in fixed size chunks. It does not depend
+ # on the input_stream yielding chunks of any particular size, and internal copying or
+ # concatenation of chunks is avoided
+
+ class NoMoreBytes(Exception):
+ pass
+
+ def get_byte_readers(_input_iterable):
+ chunk = b''
+ offset = 0
+ it = iter(_input_iterable)
+
+ def _read_multiple_chunks(amt):
+ nonlocal chunk
+ nonlocal offset
+
+ # Yield anything we already have
+ if chunk:
+ to_yield = min(amt, len(chunk) - offset)
+ yield chunk[offset:offset + to_yield]
+ amt -= to_yield
+ offset += to_yield % len(chunk)
+
+ # Yield the rest as it comes in
+ while amt:
+ try:
+ chunk = next(it)
+ except StopIteration:
+ raise NoMoreBytes()
+ to_yield = min(amt, len(chunk))
+ yield chunk[:to_yield]
+ amt -= to_yield
+ offset = to_yield % len(chunk)
+ chunk = chunk if offset else b''
+
+ def _read_single_chunk(amt):
+ raw = b''.join(chunk for chunk in _read_multiple_chunks(amt))
+ if raw:
+ return raw
+ raise NoMoreBytes()
+
+ return _read_multiple_chunks, _read_single_chunk
+
+ ################################
+ # Extract records from the bytes
+
+ def yield_messages(_read_multiple_chunks, _read_single_chunk):
+ # Yields a series of messages. Each is a dict of headers together with a generator that
+ # itself yields the bytes of the payload of the message. The payload generator must
+ # be read by calling code before the next iteration of this generator
+ prelude_struct = Struct('!III')
+ byte_struct = Struct('!B')
+ header_value_struct = Struct('!H')
+
+ while True:
+ try:
+ total_length, header_length, _ = prelude_struct.unpack(_read_single_chunk(12))
+ except NoMoreBytes:
+ return
+ payload_length = total_length - header_length - 16
+
+ # Read headers. Any given header type can only appear once, so a dict
+ # type => value is fine
+ headers = {}
+ while header_length:
+ header_key_length, = byte_struct.unpack(_read_single_chunk(1))
+ header_key = _read_single_chunk(header_key_length).decode('utf-8')
+ _ = _read_single_chunk(1) # Header value type is ignored for S3
+ header_value_length, = header_value_struct.unpack(_read_single_chunk(2))
+ header_value = _read_single_chunk(header_value_length).decode('utf-8')
+ header_length -= (1 + header_key_length + 1 + 2 + header_value_length)
+ headers[header_key] = header_value
+
+ def payload():
+ for chunk in _read_multiple_chunks(payload_length):
+ yield chunk
+
+ # Ignore final CRC
+ final_crc_length = 4
+ for _ in _read_multiple_chunks(final_crc_length):
+ pass
+
+ yield headers, payload()
+
+ def yield_records(_messages):
+ for headers, payload in _messages:
+ if headers[':message-type'] == 'event' and headers[':event-type'] == 'Records':
+ yield from payload
+ else:
+ for _ in payload:
+ pass
+
+ def yield_as_json(_records):
+ yield b'{"rows":['
+
+ # Slightly faffy to remove the trailing "," from S3 Select output
+ try:
+ last = next(_records)
+ except StopIteration:
+ pass
+ else:
+ for val in _records:
+ yield last
+ last = val
+
+ yield last[:len(last) - 1]
+
+ yield b']}'
+
+ def yield_as_utf_8(_as_json):
+ # The output from S3 Select [at least from minio] appears to include unicode escape
+ # sequences, even for characters like > and &. A plain .decode('unicode-escape') isn't
+ # enough to convert them, since an escape sequence can be truncated if it crosses into
+ # the next chunk, and in fact even using .decode('unicode-escape') where you're sure
+ # there is no truncated unicode escape sequence breaks non-ASCII UTF-8 data, since it
+ # appears to treat them as Latin-1. So we have to do our own search and replace.
+
+ def even_slashes_before(_chunk, index):
+ count = 0
+ index -= 1
+ while index >= 0 and _chunk[index:index + 1] == b'\\':
+ count += 1
+ index -= 1
+ return count % 2 == 0
+
+ def split_trailing_escape(_chunk):
+ # \, \u, \uX, \uXX, \uXXX, with an even number of \ before are trailing escapes
+ if _chunk[-1:] == b'\\':
+ if even_slashes_before(_chunk, len(_chunk) - 1):
+ return _chunk[:-1], _chunk[-1:]
+ elif _chunk[-2:] == b'\\u':
+ if even_slashes_before(_chunk, len(_chunk) - 2):
+ return _chunk[:-2], _chunk[-2:]
+ elif _chunk[-3:-1] == b'\\u':
+ if even_slashes_before(_chunk, len(_chunk) - 3):
+ return _chunk[:-3], _chunk[-3:]
+ elif _chunk[-4:-2] == b'\\u':
+ if even_slashes_before(_chunk, len(_chunk) - 4):
+ return _chunk[:-4], _chunk[-4:]
+ elif _chunk[-5:-3] == b'\\u':
+ if even_slashes_before(_chunk, len(_chunk) - 5):
+ return _chunk[:-5], _chunk[-5:]
+ return _chunk, b''
+
+ def unicode_escapes_to_utf_8(_chunk):
+ def to_utf_8(match):
+ group = match.group()
+ if even_slashes_before(_chunk, match.span()[0]):
+ return group.decode('unicode-escape').encode('utf-8')
+ return group
+
+ return re.sub(b'\\\\u[0-9a-fA-F]{4}', to_utf_8, _chunk)
+
+ trailing_escape = b''
+ for chunk in as_json:
+ chunk, trailing_escape = split_trailing_escape(trailing_escape + chunk)
+
+ if chunk:
+ yield unicode_escapes_to_utf_8(chunk)
+
+ def yield_output(_as_json_utf_8, _output_chunk_size):
+ # Typically web servers send an HTTP chunk for every yield of the body generator, which
+ # can result in quite small chunks so more packets/bytes over the wire. We avoid this.
+ chunks = []
+ num_bytes = 0
+ for chunk in _as_json_utf_8:
+ chunks.append(chunk)
+ num_bytes += len(chunk)
+ if num_bytes < _output_chunk_size:
+ continue
+ chunk = b''.join(chunks)
+ output, chunk = chunk[:_output_chunk_size], chunk[_output_chunk_size:]
+ yield output
+ num_bytes = len(chunk)
+ chunks = [chunk] if chunk else []
+
+ if chunks:
+ yield b''.join(chunks)
+
+ read_multiple_chunks, read_single_chunk = get_byte_readers(input_iterable)
+ messages = yield_messages(read_multiple_chunks, read_single_chunk)
+ records = yield_records(messages)
+ as_json = yield_as_json(records)
+ as_json_utf_8 = yield_as_utf_8(as_json)
+ output = yield_output(as_json_utf_8, output_chunk_size)
+
+ return output
+
app = Flask('app')
app.add_url_rule('/', view_func=proxy)
diff --git a/test.py b/test.py
index 11dab95..da72b98 100644
--- a/test.py
+++ b/test.py
@@ -3,6 +3,7 @@
)
import hashlib
import hmac
+import json
import os
import time
import socket
@@ -292,6 +293,126 @@ def test_healthcheck(self):
self.assertEqual(resp_1.status_code, 200)
self.assertEqual(resp_1.content, b'OK')
+ def test_select_all(self):
+ wait_until_started, stop_application = create_application(8080)
+ self.addCleanup(stop_application)
+ wait_until_started()
+
+ key = str(uuid.uuid4()) + '/' + str(uuid.uuid4())
+ content = json.dumps({
+ 'topLevel': (
+ [{'a': '>&', 'd': 'e'}] * 100000
+ + [{'a': 'c'}] * 1
+ + [{'a': '🍰', 'd': 'f'}] * 100000
+ )
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+ put_object(key, content)
+
+ params = {
+ 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]'
+ }
+ expected_content = json.dumps({
+ 'rows': (
+ [{'a': '>&', 'd': 'e'}] * 100000
+ + [{'a': 'c'}] * 1
+ + [{'a': '🍰', 'd': 'f'}] * 100000
+ ),
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+
+ with \
+ requests.Session() as session, \
+ session.get(f'http://127.0.0.1:8080/{key}', params=params) as response:
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, expected_content)
+ self.assertEqual(len(response.history), 0)
+
+ def test_select_newlines(self):
+ wait_until_started, stop_application = create_application(8080)
+ self.addCleanup(stop_application)
+ wait_until_started()
+
+ key = str(uuid.uuid4()) + '/' + str(uuid.uuid4())
+ content = json.dumps({
+ 'topLevel': (
+ [{'a': '\n' * 10000}] * 100
+ )
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+ put_object(key, content)
+
+ params = {
+ 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]'
+ }
+ expected_content = json.dumps({
+ 'rows': (
+ [{'a': '\n' * 10000}] * 100
+ ),
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+
+ with \
+ requests.Session() as session, \
+ session.get(f'http://127.0.0.1:8080/{key}', params=params) as response:
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, expected_content)
+ self.assertEqual(len(response.history), 0)
+
+ def test_select_strings_that_are_almost_unicode_escapes(self):
+ wait_until_started, stop_application = create_application(8080)
+ self.addCleanup(stop_application)
+ wait_until_started()
+
+ key = str(uuid.uuid4()) + '/' + str(uuid.uuid4())
+ content = json.dumps({
+ 'topLevel': (
+ [{'a': '\\u003e🍰\\u0026>&\\u003e\\u0026>\\u0026\\u002\\\\u0026\\n' * 10000}] * 10
+ )
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+ put_object(key, content)
+
+ params = {
+ 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]'
+ }
+ expected_content = json.dumps({
+ 'rows': (
+ [{'a': '\\u003e🍰\\u0026>&\\u003e\\u0026>\\u0026\\u002\\\\u0026\\n' * 10000}] * 10
+ ),
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+
+ with \
+ requests.Session() as session, \
+ session.get(f'http://127.0.0.1:8080/{key}', params=params) as response:
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, expected_content)
+ self.assertEqual(len(response.history), 0)
+
+ def test_select_subset(self):
+ wait_until_started, stop_application = create_application(8080)
+ self.addCleanup(stop_application)
+ wait_until_started()
+
+ key = str(uuid.uuid4()) + '/' + str(uuid.uuid4())
+ content = json.dumps({
+ 'topLevel': (
+ [{'a': '>&', 'd': 'e'}] * 100000
+ + [{'a': 'c'}] * 1
+ + [{'a': '🍰', 'd': 'f'}] * 100000
+ )
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+ put_object(key, content)
+
+ params = {
+ 'query_sql': "SELECT * FROM S3Object[*].topLevel[*] AS t WHERE t.a = '>&' OR t.a='🍰'"
+ }
+ expected_content = json.dumps({
+ 'rows': [{'a': '>&', 'd': 'e'}] * 100000 + [{'a': '🍰', 'd': 'f'}] * 100000,
+ }, separators=(',', ':'), ensure_ascii=False).encode('utf-8')
+
+ with \
+ requests.Session() as session, \
+ session.get(f'http://127.0.0.1:8080/{key}', params=params) as response:
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, expected_content)
+ self.assertEqual(len(response.history), 0)
+
def create_application(
port, max_attempts=100,