diff --git a/app.py b/app.py index 76bd933..6202733 100644 --- a/app.py +++ b/app.py @@ -11,9 +11,14 @@ import hmac import logging import os +import re import signal +from struct import ( + Struct, +) import sys import urllib.parse +from xml.sax.saxutils import escape as escape_xml from flask import ( Flask, @@ -55,17 +60,33 @@ def proxy(path): logger.debug('Attempt to proxy: %s', request) url = endpoint_url + path - body_hash = hashlib.sha256(b'').hexdigest() + parsed_url = urllib.parse.urlsplit(url) + method, body, params, parse_response = \ + ( + 'POST', + aws_select_post_body(request.args['query_sql']), + (('select', ''), ('select-type', '2')), + aws_select_parse_result, + ) if 'query_sql' in request.args else \ + ( + 'GET', + b'', + {}, + lambda x, _: x, + ) + + body_hash = hashlib.sha256(body).hexdigest() pre_auth_headers = tuple(( (key, request.headers[key]) for key in proxied_request_headers if key in request.headers )) - parsed_url = urllib.parse.urlsplit(url) + encoded_params = urllib.parse.urlencode(params) request_headers = aws_sigv4_headers( - pre_auth_headers, 's3', parsed_url.netloc, 'GET', parsed_url.path, (), body_hash, + pre_auth_headers, 's3', parsed_url.netloc, method, parsed_url.path, params, body_hash, ) + response = http.request(method, f'{url}?{encoded_params}', headers=dict( + request_headers), body=body, preload_content=False) - response = http.request('GET', url, headers=dict(request_headers), preload_content=False) response_headers = tuple(( (key, response.headers[key]) for key in proxied_response_headers if key in response.headers @@ -84,7 +105,7 @@ def body_empty(): pass downstream_response = \ - Response(response.stream(65536, decode_content=False), + Response(parse_response(response.stream(65536, decode_content=False), 65536), status=response.status, headers=response_headers) if allow_proxy else \ Response(body_empty(), status=500) downstream_response.call_on_close(response.release_conn) @@ -144,6 +165,215 @@ def sign(key, msg): 
(b'x-amz-content-sha256', body_hash.encode('ascii')), ) + pre_auth_headers + def aws_select_post_body(sql): + sql_xml_escaped = escape_xml(sql) + return \ + f''' + + {sql_xml_escaped} + SQL + + + Document + + + + + , + + + + '''.encode('utf-8') + + def aws_select_parse_result(input_iterable, output_chunk_size): + # Returns an iterator that yields payload data in fixed size chunks. It does not depend + # on the input_stream yielding chunks of any particular size, and internal copying or + # concatenation of chunks is avoided + + class NoMoreBytes(Exception): + pass + + def get_byte_readers(_input_iterable): + chunk = b'' + offset = 0 + it = iter(_input_iterable) + + def _read_multiple_chunks(amt): + nonlocal chunk + nonlocal offset + + # Yield anything we already have + if chunk: + to_yield = min(amt, len(chunk) - offset) + yield chunk[offset:offset + to_yield] + amt -= to_yield + offset += to_yield % len(chunk) + + # Yield the rest as it comes in + while amt: + try: + chunk = next(it) + except StopIteration: + raise NoMoreBytes() + to_yield = min(amt, len(chunk)) + yield chunk[:to_yield] + amt -= to_yield + offset = to_yield % len(chunk) + chunk = chunk if offset else b'' + + def _read_single_chunk(amt): + raw = b''.join(chunk for chunk in _read_multiple_chunks(amt)) + if raw: + return raw + raise NoMoreBytes() + + return _read_multiple_chunks, _read_single_chunk + + ################################ + # Extract records from the bytes + + def yield_messages(_read_multiple_chunks, _read_single_chunk): + # Yields a series of messages. Each is a dict of headers together with a generator that + # itself yields the bytes of the payload of the message. 
The payload generator must + # be read by calling code before the next iteration of this generator + prelude_struct = Struct('!III') + byte_struct = Struct('!B') + header_value_struct = Struct('!H') + + while True: + try: + total_length, header_length, _ = prelude_struct.unpack(_read_single_chunk(12)) + except NoMoreBytes: + return + payload_length = total_length - header_length - 16 + + # Read headers. Any given header type can only appear once, so a dict + # type => value is fine + headers = {} + while header_length: + header_key_length, = byte_struct.unpack(_read_single_chunk(1)) + header_key = _read_single_chunk(header_key_length).decode('utf-8') + _ = _read_single_chunk(1) # Header value type is ignored for S3 + header_value_length, = header_value_struct.unpack(_read_single_chunk(2)) + header_value = _read_single_chunk(header_value_length).decode('utf-8') + header_length -= (1 + header_key_length + 1 + 2 + header_value_length) + headers[header_key] = header_value + + def payload(): + for chunk in _read_multiple_chunks(payload_length): + yield chunk + + # Ignore final CRC + final_crc_length = 4 + for _ in _read_multiple_chunks(final_crc_length): + pass + + yield headers, payload() + + def yield_records(_messages): + for headers, payload in _messages: + if headers[':message-type'] == 'event' and headers[':event-type'] == 'Records': + yield from payload + else: + for _ in payload: + pass + + def yield_as_json(_records): + yield b'{"rows":[' + + # Slightly faffy to remove the trailing "," from S3 Select output + try: + last = next(_records) + except StopIteration: + pass + else: + for val in _records: + yield last + last = val + + yield last[:len(last) - 1] + + yield b']}' + + def yield_as_utf_8(_as_json): + # The output from S3 Select [at least from minio] appears to include unicode escape + # sequences, even for characters like > and &. 
A plain .decode('unicode-escape') isn't + enough to convert them, since an escape sequence can be truncated if it crosses into + the next chunk, and in fact even using .decode('unicode-escape') where you're sure + there is no truncated unicode escape sequence breaks non-ASCII UTF-8 data, since it + appears to treat them as Latin-1. So we have to do our own search and replace. + + def even_slashes_before(_chunk, index): + count = 0 + index -= 1 + while index >= 0 and _chunk[index:index + 1] == b'\\': + count += 1 + index -= 1 + return count % 2 == 0 + + def split_trailing_escape(_chunk): + # \, \u, \uX, \uXX, \uXXX, with an even number of \ before are trailing escapes + if _chunk[-1:] == b'\\': + if even_slashes_before(_chunk, len(_chunk) - 1): + return _chunk[:-1], _chunk[-1:] + elif _chunk[-2:] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 2): + return _chunk[:-2], _chunk[-2:] + elif _chunk[-3:-1] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 3): + return _chunk[:-3], _chunk[-3:] + elif _chunk[-4:-2] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 4): + return _chunk[:-4], _chunk[-4:] + elif _chunk[-5:-3] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 5): + return _chunk[:-5], _chunk[-5:] + return _chunk, b'' + + def unicode_escapes_to_utf_8(_chunk): + def to_utf_8(match): + group = match.group() + if even_slashes_before(_chunk, match.span()[0]): + return group.decode('unicode-escape').encode('utf-8') + return group + + return re.sub(b'\\\\u[0-9a-fA-F]{4}', to_utf_8, _chunk) + + trailing_escape = b'' + for chunk in as_json: + chunk, trailing_escape = split_trailing_escape(trailing_escape + chunk) + + if chunk: + yield unicode_escapes_to_utf_8(chunk) + + def yield_output(_as_json_utf_8, _output_chunk_size): + # Typically web servers send an HTTP chunk for every yield of the body generator, which + # can result in quite small chunks so more packets/bytes over the wire. We avoid this. 
+ chunks = [] + num_bytes = 0 + for chunk in _as_json_utf_8: + chunks.append(chunk) + num_bytes += len(chunk) + if num_bytes < _output_chunk_size: + continue + chunk = b''.join(chunks) + output, chunk = chunk[:_output_chunk_size], chunk[_output_chunk_size:] + yield output + num_bytes = len(chunk) + chunks = [chunk] if chunk else [] + + if chunks: + yield b''.join(chunks) + + read_multiple_chunks, read_single_chunk = get_byte_readers(input_iterable) + messages = yield_messages(read_multiple_chunks, read_single_chunk) + records = yield_records(messages) + as_json = yield_as_json(records) + as_json_utf_8 = yield_as_utf_8(as_json) + output = yield_output(as_json_utf_8, output_chunk_size) + + return output + app = Flask('app') app.add_url_rule('/', view_func=proxy) diff --git a/test.py b/test.py index 11dab95..da72b98 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ ) import hashlib import hmac +import json import os import time import socket @@ -292,6 +293,126 @@ def test_healthcheck(self): self.assertEqual(resp_1.status_code, 200) self.assertEqual(resp_1.content, b'OK') + def test_select_all(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '>&', 'd': 'e'}] * 100000 + + [{'a': 'c'}] * 1 + + [{'a': '🍰', 'd': 'f'}] * 100000 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]' + } + expected_content = json.dumps({ + 'rows': ( + [{'a': '>&', 'd': 'e'}] * 100000 + + [{'a': 'c'}] * 1 + + [{'a': '🍰', 'd': 'f'}] * 100000 + ), + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + 
self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + + def test_select_newlines(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '\n' * 10000}] * 100 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]' + } + expected_content = json.dumps({ + 'rows': ( + [{'a': '\n' * 10000}] * 100 + ), + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + + def test_select_strings_that_are_almost_unicode_escapes(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '\\u003e🍰\\u0026>&\\u003e\\u0026>\\u0026\\u002\\\\u0026\\n' * 10000}] * 10 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]' + } + expected_content = json.dumps({ + 'rows': ( + [{'a': '\\u003e🍰\\u0026>&\\u003e\\u0026>\\u0026\\u002\\\\u0026\\n' * 10000}] * 10 + ), + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + + def 
test_select_subset(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '>&', 'd': 'e'}] * 100000 + + [{'a': 'c'}] * 1 + + [{'a': '🍰', 'd': 'f'}] * 100000 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': "SELECT * FROM S3Object[*].topLevel[*] AS t WHERE t.a = '>&' OR t.a='🍰'" + } + expected_content = json.dumps({ + 'rows': [{'a': '>&', 'd': 'e'}] * 100000 + [{'a': '🍰', 'd': 'f'}] * 100000, + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + def create_application( port, max_attempts=100,