diff --git a/pystreamapi/__stream.py b/pystreamapi/__stream.py index c16a21b..290c7e3 100644 --- a/pystreamapi/__stream.py +++ b/pystreamapi/__stream.py @@ -87,7 +87,7 @@ def concat(*streams: "BaseStream[_K]"): :param streams: The streams to concatenate :return: The concatenated stream """ - return streams[0].__class__(itertools.chain(*list(streams))) + return streams[0].__class__(itertools.chain(*iter(streams))) @staticmethod def iterate(seed: _K, func: Callable[[_K], _K]) -> BaseStream[_K]: diff --git a/pystreamapi/_itertools/tools.py b/pystreamapi/_itertools/tools.py index 37bf413..547a8ba 100644 --- a/pystreamapi/_itertools/tools.py +++ b/pystreamapi/_itertools/tools.py @@ -1,8 +1,10 @@ # pylint: disable=protected-access +from typing import Iterable + from pystreamapi._streams.error.__error import ErrorHandler, _sentinel -def dropwhile(predicate, iterable, handler: ErrorHandler=None): +def dropwhile(predicate, iterable, handler: ErrorHandler = None): """ Drop items from the iterable while predicate(item) is true. Afterward, return every element until the iterable is exhausted. @@ -22,7 +24,7 @@ def dropwhile(predicate, iterable, handler: ErrorHandler=None): _initial_missing = object() -def reduce(function, sequence, initial=_initial_missing, handler: ErrorHandler=None): +def reduce(function, sequence, initial=_initial_missing, handler: ErrorHandler = None): """ Apply a function of two arguments cumulatively to the items of a sequence or iterable, from left to right, to reduce the iterable to a single @@ -37,8 +39,7 @@ def reduce(function, sequence, initial=_initial_missing, handler: ErrorHandler=N try: value = next(it) except StopIteration: - raise TypeError( - "reduce() of empty iterable with no initial value") from None + raise TypeError("reduce() of empty iterable with no initial value") from None else: value = initial @@ -51,3 +52,38 @@ def reduce(function, sequence, initial=_initial_missing, handler: ErrorHandler=N value = function(value, element) return value + + +def peek(iterable: Iterable, mapper): + """ + Generator wrapper that applies a function to every item of the iterable + and yields the item unchanged. + """ + for item in iterable: + mapper(item) + yield item + + +def distinct(iterable: Iterable): + """Generator wrapper that returns unique elements from the iterable.""" + seen = set() + for item in iterable: + if item not in seen: + seen.add(item) + yield item + + +def limit(source: Iterable, max_nr: int): + """Generator wrapper that returns the first n elements of the iterable.""" + iterator = iter(source) + for _ in range(max_nr): + try: + yield next(iterator) + except StopIteration: + break + + +def flat_map(iterable: Iterable): + """Generator wrapper that flattens the Stream iterable.""" + for stream in iterable: + yield from stream.to_list() diff --git a/pystreamapi/_streams/__base_stream.py b/pystreamapi/_streams/__base_stream.py index 9fc1db1..4d8215a 100644 --- a/pystreamapi/_streams/__base_stream.py +++ b/pystreamapi/_streams/__base_stream.py @@ -1,5 +1,6 @@ # pylint: disable=protected-access from __future__ import annotations + import functools import itertools from abc import abstractmethod @@ -8,7 +9,7 @@ from typing import Iterable, Callable, Any, TypeVar, Iterator, TYPE_CHECKING, Union from pystreamapi.__optional import Optional -from pystreamapi._itertools.tools import dropwhile +from pystreamapi._itertools.tools import dropwhile, distinct, limit from pystreamapi._lazy.process import Process from pystreamapi._lazy.queue import ProcessQueue from pystreamapi._streams.error.__error import ErrorHandler @@ -85,8 +86,7 @@ def _verify_open(self): def __iter__(self) -> Iterator[K]: return iter(self._source) - @classmethod - def concat(cls, *streams: "BaseStream[K]"): + def concat(self, *streams: "BaseStream[K]") -> BaseStream[K]: """ Creates a lazily concatenated stream whose elements are all the elements of the first stream followed by all the elements of the other streams. @@ -94,7 +94,11 @@ def concat(cls, *streams: "BaseStream[K]"): :param streams: The streams to concatenate :return: The concatenated stream """ - return cls(itertools.chain(*list(streams))) + self._queue.execute_all() + for stream in streams: + stream._queue.execute_all() + self._source = itertools.chain(self._source, *[stream._source for stream in streams]) + return self @_operation def distinct(self) -> 'BaseStream[K]': @@ -104,7 +108,7 @@ def distinct(self) -> 'BaseStream[K]': def __distinct(self): """Removes duplicate elements from the stream.""" - self._source = list(set(self._source)) + self._source = distinct(self._source) @_operation def drop_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]': @@ -119,7 +123,7 @@ def drop_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]': def __drop_while(self, predicate: Callable[[Any], bool]): """Drops elements from the stream while the predicate is true.""" - self._source = list(dropwhile(predicate, self._source, self)) + self._source = dropwhile(predicate, self._source, self) def error_level(self, level: ErrorLevel, *exceptions)\ -> Union["BaseStream[K]", NumericBaseStream]: @@ -160,7 +164,7 @@ def flat_map(self, predicate: Callable[[K], Iterable[_V]]) -> 'BaseStream[_V]': return self @abstractmethod - def _flat_map(self, predicate: Callable[[K], Iterable[_V]]): + def _flat_map(self, mapper: Callable[[K], Iterable[_V]]): """Implementation of flat_map. Should be implemented by subclasses.""" @_operation @@ -196,7 +200,7 @@ def limit(self, max_size: int) -> 'BaseStream[K]': def __limit(self, max_size: int): """Limits the stream to the first n elements.""" - self._source = itertools.islice(self._source, max_size) + self._source = limit(self._source, max_size) @_operation def map(self, mapper: Callable[[K], _V]) -> 'BaseStream[_V]': @@ -283,6 +287,7 @@ def reversed(self) -> 'BaseStream[K]': """ Returns a stream consisting of the elements of this stream, with their order being reversed. + This does not work on infinite generators. """ self._queue.append(Process(self.__reversed)) return self @@ -314,7 +319,7 @@ def skip(self, n: int) -> 'BaseStream[K]': def __skip(self, n: int): """Skips the first n elements of the stream.""" - self._source = self._source[n:] + self._source = itertools.islice(self._source, n, None) @_operation def sorted(self, comparator: Callable[[K], int] = None) -> 'BaseStream[K]': @@ -345,7 +350,7 @@ def take_while(self, predicate: Callable[[K], bool]) -> 'BaseStream[K]': def __take_while(self, predicate: Callable[[Any], bool]): """Takes elements from the stream while the predicate is true.""" - self._source = list(itertools.takewhile(predicate, self._source)) + self._source = itertools.takewhile(predicate, self._source) @abstractmethod @terminal @@ -363,7 +368,13 @@ def any_match(self, predicate: Callable[[K], bool]): :param predicate: The callable predicate """ - return any(self._itr(self._source, predicate)) + def _one_wrapper(iterable, mapper): + """Generator wrapper for any_match.""" + for i in iterable: + yield self._one(mapper, item=i) + + self._source = _one_wrapper(self._source, predicate) + return any(self._source) @terminal def count(self): @@ -413,6 +424,7 @@ def none_match(self, predicate: Callable[[K], bool]): @terminal def min(self): """Returns the minimum element of this stream.""" + self._source = list(self._source) if len(self._source) > 0: return Optional.of(min(self._source)) return Optional.empty() @@ -420,6 +432,7 @@ def min(self): @terminal def max(self): """Returns the maximum element of this stream.""" + self._source = list(self._source) if len(self._source) > 0: return Optional.of(max(self._source)) return Optional.empty() diff --git a/pystreamapi/_streams/__parallel_stream.py b/pystreamapi/_streams/__parallel_stream.py index 833c9bd..5aa238f 100644 --- a/pystreamapi/_streams/__parallel_stream.py +++ b/pystreamapi/_streams/__parallel_stream.py @@ -34,14 +34,15 @@ def _filter(self, predicate: Callable[[Any], bool]): @terminal def find_any(self): - if len(self._source) > 0: - return Optional.of(self._source[0]) - return Optional.empty() + try: + return Optional.of(next(iter(self._source))) + except StopIteration: + return Optional.empty() - def _flat_map(self, predicate: Callable[[Any], stream.BaseStream]): + def _flat_map(self, mapper: Callable[[Any], stream.BaseStream]): new_src = [] for element in Parallel(n_jobs=-1, prefer="threads", handler=self)( - delayed(self.__mapper(predicate))(element) for element in self._source): + delayed(self.__mapper(mapper))(element) for element in self._source): new_src.extend(element.to_list()) self._source = new_src diff --git a/pystreamapi/_streams/__sequential_stream.py b/pystreamapi/_streams/__sequential_stream.py index 1273992..db92915 100644 --- a/pystreamapi/_streams/__sequential_stream.py +++ b/pystreamapi/_streams/__sequential_stream.py @@ -3,8 +3,8 @@ import pystreamapi._streams.__base_stream as stream from pystreamapi.__optional import Optional +from pystreamapi._itertools.tools import reduce, flat_map, peek from pystreamapi._streams.error.__error import _sentinel -from pystreamapi._itertools.tools import reduce _identity_missing = object() @@ -21,15 +21,13 @@ def _filter(self, predicate: Callable[[Any], bool]): @stream.terminal def find_any(self): - if len(self._source) > 0: - return Optional.of(self._source[0]) - return Optional.empty() + try: + return Optional.of(next(iter(self._source))) + except StopIteration: + return Optional.empty() - def _flat_map(self, predicate: Callable[[Any], stream.BaseStream]): - new_src = [] - for element in self._itr(self._source, mapper=predicate): - new_src.extend(element.to_list()) - self._source = new_src + def _flat_map(self, mapper: Callable[[Any], stream.BaseStream]): + self._source = flat_map(self._itr(self._source, mapper=mapper)) def _group_to_dict(self, key_mapper: Callable[[Any], Any]): groups = defaultdict(list) @@ -43,13 +41,14 @@ def _group_to_dict(self, key_mapper: Callable[[Any], Any]): @stream.terminal def for_each(self, action: Callable): - self._peek(action) + for item in self._source: + self._one(mapper=action, item=item) def _map(self, mapper: Callable[[Any], Any]): self._source = self._itr(self._source, mapper=mapper) def _peek(self, action: Callable): - self._itr(self._source, mapper=action) + self._source = peek(self._source, lambda x: self._one(mapper=action, item=x)) @stream.terminal def reduce(self, predicate: Callable, identity=_identity_missing, depends_on_state=False): diff --git a/pystreamapi/_streams/error/__error.py b/pystreamapi/_streams/error/__error.py index e88c13c..c05c311 100644 --- a/pystreamapi/_streams/error/__error.py +++ b/pystreamapi/_streams/error/__error.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from typing import Iterable from pystreamapi._streams.error.__levels import ErrorLevel from pystreamapi._streams.error.__sentinel import Sentinel @@ -37,20 +38,18 @@ def _get_error_level(self): """Get the error level""" return self.__error_level - def _itr(self, src, mapper=nothing, condition=true_condition) -> list: + def _itr(self, src, mapper=nothing, condition=true_condition) -> Iterable: """Iterate over the source and apply the mapper and condition""" - new_src = [] for i in src: try: if condition(i): - new_src.append(mapper(i)) + yield mapper(i) except self.__exceptions_to_ignore as e: if self.__error_level == ErrorLevel.RAISE: raise e if self.__error_level == ErrorLevel.IGNORE: continue self.__log(e) - return new_src def _one(self, mapper=nothing, condition=true_condition, item=None): """ diff --git a/tests/_streams/error/test_error_handler.py b/tests/_streams/error/test_error_handler.py index 22e4bc4..aa42302 100644 --- a/tests/_streams/error/test_error_handler.py +++ b/tests/_streams/error/test_error_handler.py @@ -16,42 +16,42 @@ def setUp(self) -> None: def test_iterate_raise(self): self.handler._error_level(ErrorLevel.RAISE) - self.assertRaises(ValueError, lambda: self.handler._itr([1, 2, 3, 4, 5, "a"], int)) + self.assertRaises(ValueError, lambda: list(self.handler._itr([1, 2, 3, 4, 5, "a"], int))) def test_iterate_raise_with_condition(self): self.handler._error_level(ErrorLevel.RAISE) - self.assertRaises(ValueError, lambda: self.handler._itr( - [1, 2, 3, 4, 5, "a"], int, lambda x: x != "")) + self.assertRaises(ValueError, lambda: list(self.handler._itr( + [1, 2, 3, 4, 5, "a"], int, lambda x: x != ""))) def test_iterate_ignore(self): self.handler._error_level(ErrorLevel.IGNORE) - self.assertEqual(self.handler._itr([1, 2, 3, 4, 5, "a"], int), [1, 2, 3, 4, 5]) + self.assertEqual(list(self.handler._itr([1, 2, 3, 4, 5, "a"], int)), [1, 2, 3, 4, 5]) def test_iterate_ignore_with_condition(self): self.handler._error_level(ErrorLevel.IGNORE) - self.assertEqual(self.handler._itr( - [1, 2, 3, 4, 5, "a"], int, lambda x: x != ""), [1, 2, 3, 4, 5]) + self.assertEqual(list(self.handler._itr( + [1, 2, 3, 4, 5, "a"], int, lambda x: x != "")), [1, 2, 3, 4, 5]) def test_iterate_ignore_specific_exceptions(self): self.handler._error_level(ErrorLevel.IGNORE, ValueError, AttributeError) - self.assertEqual(self.handler._itr( - ["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split()), [["b"], ["a"]]) + self.assertEqual(list(self.handler._itr( + ["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split())), [["b"], ["a"]]) def test_iterate_ignore_specific_exception_raise_another(self): self.handler._error_level(ErrorLevel.IGNORE, ValueError) - self.assertRaises(AttributeError, lambda: self.handler._itr( - ["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split())) + self.assertRaises(AttributeError, lambda: list(self.handler._itr( + ["b", 2, 3, 4, 5, "a"], mapper=lambda x: x.split()))) def test_iterate_warn(self): self.handler._error_level(ErrorLevel.WARN) - self.assertEqual(self.handler._itr([1, 2, 3, 4, 5, "a"], int), [1, 2, 3, 4, 5]) + self.assertEqual(list(self.handler._itr([1, 2, 3, 4, 5, "a"], int)), [1, 2, 3, 4, 5]) def test_iterate_warn_with_condition(self): self.handler._error_level(ErrorLevel.WARN) - self.assertEqual(self.handler._itr( - [1, 2, 3, 4, 5, "a"], int, lambda x: x != ""), [1, 2, 3, 4, 5]) + self.assertEqual(list(self.handler._itr( + [1, 2, 3, 4, 5, "a"], int, lambda x: x != "")), [1, 2, 3, 4, 5]) def test_one_raise(self): self.handler._error_level(ErrorLevel.RAISE) diff --git a/tests/_streams/test_base_stream.py b/tests/_streams/test_base_stream.py index 35fc2de..1fae535 100644 --- a/tests/_streams/test_base_stream.py +++ b/tests/_streams/test_base_stream.py @@ -1,3 +1,4 @@ +import itertools import unittest from pystreamapi.__optional import Optional @@ -26,6 +27,66 @@ def test_concat_unsorted(self): result = Stream.concat(Stream.of([9, 6, 1]), Stream.of([3, 5, 99])) self.assertListEqual(result.to_list(), [9, 6, 1, 3, 5, 99]) + def test_concat_generator(self): + def finite_generator(): + index = 1 + while index < 5: + yield index + index += 1 + + result = Stream.concat(Stream.of([1, 2, 3]), + Stream.of(finite_generator())).to_list() + self.assertListEqual(result, [1, 2, 3, 1, 2, 3, 4]) + + def test_concat_infinite_generator(self): + result = Stream.concat(Stream.of([1, 2, 3]), + Stream.of(itertools.count()).limit(2)).limit(5).to_list() + self.assertListEqual(result, [1, 2, 3, 0, 1]) + + def test_concat_two_generators_limited(self): + result = Stream.concat(Stream.of(itertools.count()).limit(2), + Stream.of(itertools.count()).limit(2)).to_list() + self.assertListEqual(result, [0, 1, 0, 1]) + + def test_concat_two_generators(self): + result = Stream.concat(Stream.of(itertools.count()), + Stream.of(itertools.count())).limit(4).to_list() + self.assertListEqual(result, [0, 1, 2, 3]) + + def test_concat_ten_lists(self): + result = Stream.concat( + Stream.of([1, 2, 3]), + Stream.of([4, 5, 6]), + Stream.of([7, 8, 9]), + Stream.of([10, 11, 12]), + Stream.of([13, 14, 15]), + Stream.of([16, 17, 18]), + Stream.of([19, 20, 21]), + Stream.of([22, 23, 24]), + Stream.of([25, 26, 27]), + Stream.of([28, 29, 30]) + ).to_list() + self.assertListEqual(result, list(range(1, 31))) + + def test_concat_after_initialization(self): + stream1 = Stream.of([1, 2, 3]) + stream2 = Stream.of([4, 5, 6]) + stream3 = Stream.of([7, 8, 9]) + result = stream1.concat(stream2, stream3).to_list() + self.assertListEqual(result, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + + def test_concat_after_initialization_generators(self): + stream1 = Stream.of([1, 2, 3]) + stream2 = Stream.of(itertools.count()).limit(2) + result = stream1.concat(stream2).to_list() + self.assertListEqual(result, [1, 2, 3, 0, 1]) + + def test_concat_after_initialization_infinite_generators(self): + stream1 = Stream.of(itertools.count()) + stream2 = Stream.of(itertools.count()) + result = stream1.concat(stream2).limit(4).to_list() + self.assertListEqual(result, [0, 1, 2, 3]) + def test_iterate(self): result = Stream.iterate(1, lambda x: x + 1).limit(3).to_list() self.assertListEqual(result, [1, 2, 3]) @@ -58,13 +119,6 @@ def test_of_noneable_valid(self): result = Stream.of_noneable([1, 2, 3]).to_list() self.assertListEqual(result, [1, 2, 3]) - def test_concat_after_initialization(self): - stream1 = Stream.of([1, 2, 3]) - stream2 = Stream.of([4, 5, 6]) - stream3 = Stream.of([7, 8, 9]) - result = stream1.concat(stream2, stream3).to_list() - self.assertListEqual(result, [4, 5, 6, 7, 8, 9]) - def test_sort_unsorted(self): result = Stream.of([3, 2, 9, 1]).sorted().to_list() self.assertListEqual(result, [1, 2, 3, 9]) @@ -123,6 +177,10 @@ def test_skip_empty(self): result = Stream.of([]).skip(2).to_list() self.assertListEqual(result, []) + def test_skip_infinite_generator(self): + result = Stream.of(itertools.count()).skip(2).limit(3).to_list() + self.assertListEqual(result, [2, 3, 4]) + def test_distinct(self): result = Stream.of([1, 2, 3, 9, 1, 2, 3, 9]).distinct().to_list() self.assertListEqual(result, [1, 2, 3, 9]) @@ -131,6 +189,14 @@ def test_distinct_empty(self): result = Stream.of([]).distinct().to_list() self.assertListEqual(result, []) + def test_distinct_infinite_generator_unique(self): + result = Stream.of(itertools.count()).distinct().limit(5).to_list() + self.assertListEqual(result, [0, 1, 2, 3, 4]) + + def test_distinct_infinite_generator_not_unique(self): + result = Stream.of(itertools.cycle([1, 2, 2, 3])).distinct().limit(3).to_list() + self.assertListEqual(result, [1, 2, 3]) + def test_drop_while(self): result = Stream.of([1, 2, 3, 9]).drop_while(lambda x: x < 3).to_list() self.assertListEqual(result, [3, 9]) @@ -147,6 +213,10 @@ def test_take_while_empty(self): result = Stream.of([]).take_while(lambda x: x < 3).to_list() self.assertListEqual(result, []) + def test_take_while_infinite_generator(self): + result = Stream.of(itertools.count()).take_while(lambda x: x < 4).limit(4).to_list() + self.assertListEqual(result, [0, 1, 2, 3]) + def test_count(self): result = Stream.of([1, 2, "3", None]).count() self.assertEqual(result, 4) @@ -159,6 +229,10 @@ def test_any_match_empty(self): result = Stream.of([]).any_match(lambda x: x > 3) self.assertFalse(result) + def test_any_match_infinite_generator(self): + result = Stream.of(itertools.count()).any_match(lambda x: x > 3) + self.assertTrue(result) + def test_none_match(self): result = Stream.of([1, 2, 3, 9]).none_match(lambda x: x > 3) self.assertFalse(result) diff --git a/tests/_streams/test_stream_implementation.py b/tests/_streams/test_stream_implementation.py index 17c1e5f..76961cb 100644 --- a/tests/_streams/test_stream_implementation.py +++ b/tests/_streams/test_stream_implementation.py @@ -1,3 +1,4 @@ +import itertools import unittest from parameterized import parameterized_class @@ -11,6 +12,14 @@ from pystreamapi._streams.numeric.__sequential_numeric_stream import SequentialNumericStream +def throwing_generator(): + i = 0 + while True: + yield i + i = i + 1 + if i > 1000: + raise RecursionError("Infinite generator consumed wrong") + @parameterized_class("stream", [ [SequentialStream], [ParallelStream], @@ -68,9 +77,15 @@ def test_convert_to_numeric_stream_is_already_numeric(self): self.assertIsInstance(result, NumericBaseStream) def test_flat_map(self): - result = self.stream([1, 2, 3, 9]).flat_map(lambda x: self.stream([x, x])).to_list() + result = (self.stream([1, 2, 3, 9]) + .flat_map(lambda x: self.stream([x, x])).to_list()) self.assertListEqual(result, [1, 1, 2, 2, 3, 3, 9, 9]) + def test_flat_map_infinite_generator(self): + result = (self.stream(throwing_generator()) + .flat_map(lambda x: self.stream([x, x*2])).limit(6).to_list()) + self.assertListEqual(result, [0, 0, 1, 2, 2, 4]) + def test_filter_not_none(self): result = self.stream([1, 2, "3", None]).filter(lambda x: x is not None).to_list() self.assertListEqual(result, [1, 2, "3"]) @@ -121,6 +136,10 @@ def test_find_any_empty(self): result = self.stream([]).find_any() self.assertEqual(result, Optional.empty()) + def test_find_any_infinite_generator(self): + result = self.stream(itertools.count()).find_any() + self.assertEqual(result, Optional.of(0)) + def test_limit(self): result = self.stream([1, 2, 3, 9]).limit(2).to_list() self.assertListEqual(result, [1, 2]) @@ -184,6 +203,11 @@ def test_to_dict_empty(self): result = self.stream([]).to_dict(lambda x: x) self.assertDictEqual(result, {}) + def test_handling_of_generator(self): + result = (self.stream(throwing_generator()) + .map(lambda x: x * 2).filter(lambda x: x < 10).limit(5).to_list()) + self.assertListEqual(result, [0, 2, 4, 6, 8]) + if __name__ == '__main__': unittest.main()