From 382675b201d76cbbf4e61202d02eb3a869bb882e Mon Sep 17 00:00:00 2001 From: Michal Charemza Date: Sun, 2 Aug 2020 11:41:03 +0100 Subject: [PATCH] feat: S3 Select happy path Not using boto3 deliberately. This does mean we have low-level code, but... - To actually output JSON, we do have to do a little bit of faffing with bytes anyway, since S3 Select does not output valid JSON: it outputs JSON objects contatanated together with a delimeter. - S3 Select [at least, minios's implementation used for testing], appears to output unicode escape sequences for certain characters that don't need it, i.e. '\u0026' instead of just '&'. Not using boto3 means we can address issues like this in as a performant way as possible [even if we don't do much optimisation now, we are free to later]. - boto3 does not always support all of AWS, specifically with S3. For example https://github.com/boto/botocore/pull/996 has been open for 4 years (to the day!) at this point, so we should not be thwarted or have to workaround any limitation of boto3: I suspect its architecture is not optimized for low-level/streaming access, which is exactly the sort of thing that is suspected to be useful in this project. - Am pro keeping the option to use asyncio open, at least for now in this early stage of the project, and boto3 doesn't appear to support it out of the box. In its current form, to move to asyncio shouldn't be a massive project since the dependencies are a web server and a web client [both of which would likely be aiohttp]. --- app.py | 240 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- test.py | 121 ++++++++++++++++++++++++++++ 2 files changed, 356 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 76bd933..6202733 100644 --- a/app.py +++ b/app.py @@ -11,9 +11,14 @@ import hmac import logging import os +import re import signal +from struct import ( + Struct, +) import sys import urllib.parse +from xml.sax.saxutils import escape as escape_xml from flask import ( Flask, @@ -55,17 +60,33 @@ def proxy(path): logger.debug('Attempt to proxy: %s', request) url = endpoint_url + path - body_hash = hashlib.sha256(b'').hexdigest() + parsed_url = urllib.parse.urlsplit(url) + method, body, params, parse_response = \ + ( + 'POST', + aws_select_post_body(request.args['query_sql']), + (('select', ''), ('select-type', '2')), + aws_select_parse_result, + ) if 'query_sql' in request.args else \ + ( + 'GET', + b'', + {}, + lambda x, _: x, + ) + + body_hash = hashlib.sha256(body).hexdigest() pre_auth_headers = tuple(( (key, request.headers[key]) for key in proxied_request_headers if key in request.headers )) - parsed_url = urllib.parse.urlsplit(url) + encoded_params = urllib.parse.urlencode(params) request_headers = aws_sigv4_headers( - pre_auth_headers, 's3', parsed_url.netloc, 'GET', parsed_url.path, (), body_hash, + pre_auth_headers, 's3', parsed_url.netloc, method, parsed_url.path, params, body_hash, ) + response = http.request(method, f'{url}?{encoded_params}', headers=dict( + request_headers), body=body, preload_content=False) - response = http.request('GET', url, headers=dict(request_headers), preload_content=False) response_headers = tuple(( (key, response.headers[key]) for key in proxied_response_headers if key in response.headers @@ -84,7 +105,7 @@ def body_empty(): pass downstream_response = \ - Response(response.stream(65536, decode_content=False), + Response(parse_response(response.stream(65536, decode_content=False), 65536), status=response.status, headers=response_headers) if allow_proxy else \ Response(body_empty(), status=500) downstream_response.call_on_close(response.release_conn) @@ -144,6 +165,215 @@ def sign(key, msg): (b'x-amz-content-sha256', body_hash.encode('ascii')), ) + pre_auth_headers + def aws_select_post_body(sql): + sql_xml_escaped = escape_xml(sql) + return \ + f''' + + {sql_xml_escaped} + SQL + + + Document + + + + + , + + + + '''.encode('utf-8') + + def aws_select_parse_result(input_iterable, output_chunk_size): + # Returns a iterator that yields payload data in fixed size chunks. It does not depend + # on the input_stream yielding chunks of any particular size, and internal copying or + # concatanation of chunks is avoided + + class NoMoreBytes(Exception): + pass + + def get_byte_readers(_input_iterable): + chunk = b'' + offset = 0 + it = iter(_input_iterable) + + def _read_multiple_chunks(amt): + nonlocal chunk + nonlocal offset + + # Yield anything we already have + if chunk: + to_yield = min(amt, len(chunk) - offset) + yield chunk[offset:offset + to_yield] + amt -= to_yield + offset += to_yield % len(chunk) + + # Yield the rest as it comes in + while amt: + try: + chunk = next(it) + except StopIteration: + raise NoMoreBytes() + to_yield = min(amt, len(chunk)) + yield chunk[:to_yield] + amt -= to_yield + offset = to_yield % len(chunk) + chunk = chunk if offset else b'' + + def _read_single_chunk(amt): + raw = b''.join(chunk for chunk in _read_multiple_chunks(amt)) + if raw: + return raw + raise NoMoreBytes() + + return _read_multiple_chunks, _read_single_chunk + + ################################ + # Extract records from the bytes + + def yield_messages(_read_multiple_chunks, _read_single_chunk): + # Yields a series of messages. Each is a dict of headers together with a generator that + # itself yields the bytes of the payload of the message. The payload generator must + # be read by calling code before the next iteration of this generator + prelude_struct = Struct('!III') + byte_struct = Struct('!B') + header_value_struct = Struct('!H') + + while True: + try: + total_length, header_length, _ = prelude_struct.unpack(_read_single_chunk(12)) + except NoMoreBytes: + return + payload_length = total_length - header_length - 16 + + # Read headers. Any given header type can only appear once, so a dict + # type => value is fine + headers = {} + while header_length: + header_key_length, = byte_struct.unpack(_read_single_chunk(1)) + header_key = _read_single_chunk(header_key_length).decode('utf-8') + _ = _read_single_chunk(1) # Header value type is ignored for S3 + header_value_length, = header_value_struct.unpack(_read_single_chunk(2)) + header_value = _read_single_chunk(header_value_length).decode('utf-8') + header_length -= (1 + header_key_length + 1 + 2 + header_value_length) + headers[header_key] = header_value + + def payload(): + for chunk in _read_multiple_chunks(payload_length): + yield chunk + + # Ignore final CRC + final_crc_length = 4 + for _ in _read_multiple_chunks(final_crc_length): + pass + + yield headers, payload() + + def yield_records(_messages): + for headers, payload in _messages: + if headers[':message-type'] == 'event' and headers[':event-type'] == 'Records': + yield from payload + else: + for _ in payload: + pass + + def yield_as_json(_records): + yield b'{"rows":[' + + # Slightly faffy to remove the trailing "," from S3 Select output + try: + last = next(_records) + except StopIteration: + pass + else: + for val in _records: + yield last + last = val + + yield last[:len(last) - 1] + + yield b']}' + + def yield_as_utf_8(_as_json): + # The output from S3 Select [at least from minio] appears to include unicode escape + # sequences, even for characters like > and &. A plain .decode('unicode-escape') isn't + # enough to convert them, since an excape sequence can be truncated if it crosses into + # the next chunk, and in fact even using .decode('unicode-escape') where you're sure + # there is no truncated unicode escape sequence breaks non-ASCII UTF-8 data, since it + # appears to treat them as Latin-1. So we have to do our own search and and replace. + + def even_slashes_before(_chunk, index): + count = 0 + index -= 1 + while index >= 0 and _chunk[index:index + 1] == b'\\': + count += 1 + index -= 1 + return count % 2 == 0 + + def split_trailing_escape(_chunk): + # \, \u, \uX, \uXX, \uXXX, with an even number of \ before are trailing escapes + if _chunk[-1:] == b'\\': + if even_slashes_before(_chunk, len(_chunk) - 1): + return _chunk[:-1], _chunk[-1:] + elif _chunk[-2:] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 2): + return _chunk[:-2], _chunk[-2:] + elif _chunk[-3:-1] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 3): + return _chunk[:-3], _chunk[-3:] + elif _chunk[-4:-2] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 4): + return _chunk[:-4], _chunk[-4:] + elif _chunk[-5:-3] == b'\\u': + if even_slashes_before(_chunk, len(_chunk) - 5): + return _chunk[:-5], _chunk[-5:] + return _chunk, b'' + + def unicode_escapes_to_utf_8(_chunk): + def to_utf_8(match): + group = match.group() + if even_slashes_before(_chunk, match.span()[0]): + return group.decode('unicode-escape').encode('utf-8') + return group + + return re.sub(b'\\\\u[0-9a-fA-F]{4}', to_utf_8, _chunk) + + trailing_escape = b'' + for chunk in as_json: + chunk, trailing_escape = split_trailing_escape(trailing_escape + chunk) + + if chunk: + yield unicode_escapes_to_utf_8(chunk) + + def yield_output(_as_json_utf_8, _output_chunk_size): + # Typically web servers send an HTTP chunk for every yield of the body generator, which + # can result in quite small chunks so more packets/bytes over the wire. We avoid this. + chunks = [] + num_bytes = 0 + for chunk in _as_json_utf_8: + chunks.append(chunk) + num_bytes += len(chunk) + if num_bytes < _output_chunk_size: + continue + chunk = b''.join(chunks) + output, chunk = chunk[:_output_chunk_size], chunk[_output_chunk_size:] + yield output + num_bytes = len(chunk) + chunks = [chunk] if chunk else [] + + if chunks: + yield b''.join(chunks) + + read_multiple_chunks, read_single_chunk = get_byte_readers(input_iterable) + messages = yield_messages(read_multiple_chunks, read_single_chunk) + records = yield_records(messages) + as_json = yield_as_json(records) + as_json_utf_8 = yield_as_utf_8(as_json) + output = yield_output(as_json_utf_8, output_chunk_size) + + return output + app = Flask('app') app.add_url_rule('/', view_func=proxy) diff --git a/test.py b/test.py index 11dab95..da72b98 100644 --- a/test.py +++ b/test.py @@ -3,6 +3,7 @@ ) import hashlib import hmac +import json import os import time import socket @@ -292,6 +293,126 @@ def test_healthcheck(self): self.assertEqual(resp_1.status_code, 200) self.assertEqual(resp_1.content, b'OK') + def test_select_all(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '>&', 'd': 'e'}] * 100000 + + [{'a': 'c'}] * 1 + + [{'a': '🍰', 'd': 'f'}] * 100000 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]' + } + expected_content = json.dumps({ + 'rows': ( + [{'a': '>&', 'd': 'e'}] * 100000 + + [{'a': 'c'}] * 1 + + [{'a': '🍰', 'd': 'f'}] * 100000 + ), + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + + def test_select_newlines(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '\n' * 10000}] * 100 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]' + } + expected_content = json.dumps({ + 'rows': ( + [{'a': '\n' * 10000}] * 100 + ), + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + + def test_select_strings_that_are_almost_unicode_escapes(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '\\u003e🍰\\u0026>&\\u003e\\u0026>\\u0026\\u002\\\\u0026\\n' * 10000}] * 10 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': 'SELECT * FROM S3Object[*].topLevel[*]' + } + expected_content = json.dumps({ + 'rows': ( + [{'a': '\\u003e🍰\\u0026>&\\u003e\\u0026>\\u0026\\u002\\\\u0026\\n' * 10000}] * 10 + ), + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + + def test_select_subset(self): + wait_until_started, stop_application = create_application(8080) + self.addCleanup(stop_application) + wait_until_started() + + key = str(uuid.uuid4()) + '/' + str(uuid.uuid4()) + content = json.dumps({ + 'topLevel': ( + [{'a': '>&', 'd': 'e'}] * 100000 + + [{'a': 'c'}] * 1 + + [{'a': '🍰', 'd': 'f'}] * 100000 + ) + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + put_object(key, content) + + params = { + 'query_sql': "SELECT * FROM S3Object[*].topLevel[*] AS t WHERE t.a = '>&' OR t.a='🍰'" + } + expected_content = json.dumps({ + 'rows': [{'a': '>&', 'd': 'e'}] * 100000 + [{'a': '🍰', 'd': 'f'}] * 100000, + }, separators=(',', ':'), ensure_ascii=False).encode('utf-8') + + with \ + requests.Session() as session, \ + session.get(f'http://127.0.0.1:8080/{key}', params=params) as response: + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, expected_content) + self.assertEqual(len(response.history), 0) + def create_application( port, max_attempts=100,