diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py
index 8ec82bba0..faed2eaff 100644
--- a/splunklib/searchcommands/search_command.py
+++ b/splunklib/searchcommands/search_command.py
@@ -851,7 +851,8 @@ def _execute(self, ifile, process):
 
     @staticmethod
     def _as_binary_stream(ifile):
-        if six.PY2:
+        naught = ifile.read(0)
+        if isinstance(naught, bytes):
             return ifile
 
         try:
diff --git a/tests/searchcommands/chunked_data_stream.py b/tests/searchcommands/chunked_data_stream.py
new file mode 100644
index 000000000..ae5363eff
--- /dev/null
+++ b/tests/searchcommands/chunked_data_stream.py
@@ -0,0 +1,100 @@
+import collections
+import csv
+import io
+import json
+
+import splunklib.searchcommands.internals
+from splunklib import six
+
+
+class Chunk(object):
+    def __init__(self, version, meta, data):
+        self.version = six.ensure_str(version)
+        self.meta = json.loads(meta)
+        dialect = splunklib.searchcommands.internals.CsvDialect
+        self.data = csv.DictReader(io.StringIO(data.decode("utf-8")),
+                                   dialect=dialect)
+
+
+class ChunkedDataStreamIter(collections.Iterator):
+    def __init__(self, chunk_stream):
+        self.chunk_stream = chunk_stream
+
+    def __next__(self):
+        return self.next()
+
+    def next(self):
+        try:
+            return self.chunk_stream.read_chunk()
+        except EOFError:
+            raise StopIteration
+
+
+class ChunkedDataStream(collections.Iterable):
+    def __iter__(self):
+        return ChunkedDataStreamIter(self)
+
+    def __init__(self, stream):
+        empty = stream.read(0)
+        assert isinstance(empty, bytes)
+        self.stream = stream
+
+    def read_chunk(self):
+        header = self.stream.readline()
+
+        while len(header) > 0 and header.strip() == b'':
+            header = self.stream.readline()  # Skip empty lines
+        if len(header) == 0:
+            raise EOFError
+
+        version, meta, data = header.rstrip().split(b',')
+        metabytes = self.stream.read(int(meta))
+        databytes = self.stream.read(int(data))
+        return Chunk(version, metabytes, databytes)
+
+
+def build_chunk(keyval, data=None):
+    metadata = six.ensure_binary(json.dumps(keyval), 'utf-8')
+    data_output = _build_data_csv(data)
+    return b"chunked 1.0,%d,%d\n%s%s" % (len(metadata), len(data_output), metadata, data_output)
+
+
+def build_empty_searchinfo():
+    return {
+        'earliest_time': 0,
+        'latest_time': 0,
+        'search': "",
+        'dispatch_dir': "",
+        'sid': "",
+        'args': [],
+        'splunk_version': "42.3.4",
+    }
+
+
+def build_getinfo_chunk():
+    return build_chunk({
+        'action': 'getinfo',
+        'preview': False,
+        'searchinfo': build_empty_searchinfo()})
+
+
+def build_data_chunk(data, finished=True):
+    return build_chunk({'action': 'execute', 'finished': finished}, data)
+
+
+def _build_data_csv(data):
+    if data is None:
+        return b''
+    if isinstance(data, bytes):
+        return data
+    csvout = splunklib.six.StringIO()
+
+    headers = set()
+    for datum in data:
+        headers.update(datum.keys())
+    writer = csv.DictWriter(csvout, headers,
+                            dialect=splunklib.searchcommands.internals.CsvDialect)
+    writer.writeheader()
+    for datum in data:
+        writer.writerow(datum)
+    return six.ensure_binary(csvout.getvalue())
diff --git a/tests/searchcommands/test_generator_command.py b/tests/searchcommands/test_generator_command.py
new file mode 100644
index 000000000..4af61a5d2
--- /dev/null
+++ b/tests/searchcommands/test_generator_command.py
@@ -0,0 +1,44 @@
+import io
+import time
+
+from . import chunked_data_stream as chunky
+
+from splunklib.searchcommands import Configuration, GeneratingCommand
+
+
+def test_simple_generator():
+    @Configuration()
+    class GeneratorTest(GeneratingCommand):
+        def generate(self):
+            for num in range(1, 10):
+                yield {'_time': time.time(), 'event_index': num}
+    generator = GeneratorTest()
+    in_stream = io.BytesIO()
+    in_stream.write(chunky.build_getinfo_chunk())
+    in_stream.write(chunky.build_chunk({'action': 'execute'}))
+    in_stream.seek(0)
+    out_stream = io.BytesIO()
+    generator._process_protocol_v2([], in_stream, out_stream)
+    out_stream.seek(0)
+
+    ds = chunky.ChunkedDataStream(out_stream)
+    is_first_chunk = True
+    finished_seen = False
+    expected = set(map(lambda i: str(i), range(1, 10)))
+    seen = set()
+    for chunk in ds:
+        if is_first_chunk:
+            assert chunk.meta["generating"] is True
+            assert chunk.meta["type"] == "stateful"
+            is_first_chunk = False
+        finished_seen = chunk.meta.get("finished", False)
+        for row in chunk.data:
+            seen.add(row["event_index"])
+    print(out_stream.getvalue())
+    print(expected)
+    print(seen)
+    assert expected.issubset(seen)
+    assert finished_seen
+
+
+
diff --git a/tests/searchcommands/test_reporting_command.py b/tests/searchcommands/test_reporting_command.py
new file mode 100644
index 000000000..e5add818c
--- /dev/null
+++ b/tests/searchcommands/test_reporting_command.py
@@ -0,0 +1,34 @@
+import io
+
+import splunklib.searchcommands as searchcommands
+from . import chunked_data_stream as chunky
+
+
+def test_simple_reporting_command():
+    @searchcommands.Configuration()
+    class TestReportingCommand(searchcommands.ReportingCommand):
+        def reduce(self, records):
+            value = 0
+            for record in records:
+                value += int(record["value"])
+            yield {'sum': value}
+
+    cmd = TestReportingCommand()
+    ifile = io.BytesIO()
+    data = list()
+    for i in range(0, 10):
+        data.append({"value": str(i)})
+    ifile.write(chunky.build_getinfo_chunk())
+    ifile.write(chunky.build_data_chunk(data))
+    ifile.seek(0)
+    ofile = io.BytesIO()
+    cmd._process_protocol_v2([], ifile, ofile)
+    ofile.seek(0)
+    chunk_stream = chunky.ChunkedDataStream(ofile)
+    getinfo_response = chunk_stream.read_chunk()
+    assert getinfo_response.meta['type'] == 'reporting'
+    data_chunk = chunk_stream.read_chunk()
+    assert data_chunk.meta['finished'] is True  # Should only be one row
+    data = list(data_chunk.data)
+    assert len(data) == 1
+    assert int(data[0]['sum']) == sum(range(0, 10))
diff --git a/tests/searchcommands/test_streaming_command.py b/tests/searchcommands/test_streaming_command.py
new file mode 100644
index 000000000..dcc00b53e
--- /dev/null
+++ b/tests/searchcommands/test_streaming_command.py
@@ -0,0 +1,29 @@
+import io
+
+from . import chunked_data_stream as chunky
+from splunklib.searchcommands import StreamingCommand, Configuration
+
+
+def test_simple_streaming_command():
+    @Configuration()
+    class TestStreamingCommand(StreamingCommand):
+
+        def stream(self, records):
+            for record in records:
+                record["out_index"] = record["in_index"]
+                yield record
+
+    cmd = TestStreamingCommand()
+    ifile = io.BytesIO()
+    ifile.write(chunky.build_getinfo_chunk())
+    data = list()
+    for i in range(0, 10):
+        data.append({"in_index": str(i)})
+    ifile.write(chunky.build_data_chunk(data, finished=True))
+    ifile.seek(0)
+    ofile = io.BytesIO()
+    cmd._process_protocol_v2([], ifile, ofile)
+    ofile.seek(0)
+    output = chunky.ChunkedDataStream(ofile)
+    getinfo_response = output.read_chunk()
+    assert getinfo_response.meta["type"] == "streaming"
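
Note (not part of the patch): the test helpers above speak the chunked search-command protocol v2, whose frames consist of a header line "chunked 1.0,<metadata length>,<body length>" followed by that many bytes of JSON metadata and CSV body. Below is a minimal, self-contained sketch of just the framing that build_chunk() and ChunkedDataStream.read_chunk() implement; the write_frame/read_frame names are illustrative, not SDK API.

# Illustrative round-trip of one protocol-v2 frame, under the framing
# assumptions described above (write_frame/read_frame are hypothetical
# helper names, not part of splunklib).
import io
import json


def write_frame(stream, metadata, body=b''):
    # Header: "chunked 1.0,<metadata length>,<body length>\n", then the
    # JSON metadata bytes followed immediately by the raw body bytes.
    meta = json.dumps(metadata).encode('utf-8')
    stream.write(b'chunked 1.0,%d,%d\n' % (len(meta), len(body)))
    stream.write(meta)
    stream.write(body)


def read_frame(stream):
    # Inverse of write_frame: parse the header, then read exactly the
    # advertised number of metadata and body bytes.
    header = stream.readline()
    if not header:
        raise EOFError
    version, meta_len, body_len = header.rstrip().split(b',')
    assert version == b'chunked 1.0'
    metadata = json.loads(stream.read(int(meta_len)).decode('utf-8'))
    body = stream.read(int(body_len))
    return metadata, body


buf = io.BytesIO()
write_frame(buf, {'action': 'getinfo', 'preview': False})
buf.seek(0)
meta, body = read_frame(buf)
assert meta['action'] == 'getinfo' and body == b''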