Skip to content

Commit

Permalink
Refactor code to make simpler (#12)
Browse files Browse the repository at this point in the history
* Move optypes into separate file

* Rename ChainStart to Chain and StreamStart to Stream

* Add .then and pipe operations, as aliases to flat_map

* Bump version to 0.3

* Remove functions in chain, move to a more functional approach of modifying the functions

* Remove .map, .filter etc functions in Stream, moving to a more functional approach

* Rename batching to batched
  • Loading branch information
simw committed Jan 23, 2024
1 parent dde1b41 commit 8d3888d
Show file tree
Hide file tree
Showing 15 changed files with 309 additions and 225 deletions.
49 changes: 24 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import zipfile

import pyarrow.parquet as pq

from pipedata.core import StreamStart
from pipedata.core import Stream
from pipedata.ops import json_records, parquet_writer, zipped_files


Expand All @@ -46,10 +46,10 @@ with zipfile.ZipFile("test_input.json.zip", "w") as zipped:
zipped.writestr("file2.json", json.dumps(data2))

result = (
StreamStart(["test_input.json.zip"])
.flat_map(zipped_files)
.flat_map(json_records())
.flat_map(parquet_writer("test_output.parquet"))
Stream(["test_input.json.zip"])
.then(zipped_files)
.then(json_records())
.then(parquet_writer("test_output.parquet"))
.to_list()
)

Expand All @@ -63,17 +63,17 @@ Alternatively, you can construct the pipeline as a chain:
```py
import pyarrow.parquet as pq

from pipedata.core import ChainStart, StreamStart
from pipedata.core import Chain, Stream
from pipedata.ops import json_records, parquet_writer, zipped_files

# Running this after input file created in above example
chain = (
ChainStart()
.flat_map(zipped_files)
.flat_map(json_records())
.flat_map(parquet_writer("test_output_2.parquet"))
Chain()
.then(zipped_files)
.then(json_records())
.then(parquet_writer("test_output_2.parquet"))
)
result = StreamStart(["test_input.json.zip"]).flat_map(chain).to_list()
result = Stream(["test_input.json.zip"]).then(chain).to_list()
table = pq.read_table("test_output_2.parquet")
print(table.to_pydict())
#> {'col1': [1, 2, 3], 'col2': ['Hello', 'world', '!']}
Expand All @@ -86,33 +86,34 @@ The core framework provides the building blocks for chaining operations.

Running a stream:
```py
from pipedata.core import StreamStart
from pipedata.core import Stream, ops


result = (
StreamStart(range(10))
.filter(lambda x: x % 2 == 0)
.map(lambda x: x ^ 2)
.batched_map(lambda x: x, 2)
Stream(range(10))
.then(ops.filtering(lambda x: x % 2 == 0))
.then(ops.mapping(lambda x: x ^ 2))
.then(ops.batched(lambda x: x, 2))
.to_list()
)
print(result)
#> [(2, 0), (6, 4), (10,)]
```

Creating a chain and then using it:
Creating a chain and then using it, this time with the
pipe (`|`) notation:
```py
import json
from pipedata.core import ChainStart, Stream, StreamStart
from pipedata.core import Chain, Stream, ops


chain = (
ChainStart()
.filter(lambda x: x % 2 == 0)
.map(lambda x: x ^ 2)
.batched_map(lambda x: sum(x), 2)
Chain()
| ops.filtering(lambda x: x % 2 == 0)
| ops.mapping(lambda x: x ^ 2)
| ops.batched(lambda x: sum(x), 2)
)
print(Stream(range(10), chain).to_list())
print(Stream(range(10)).then(chain).to_list())
#> [2, 10, 10]
print(json.dumps(chain.get_counts(), indent=4))
#> [
Expand All @@ -137,8 +138,6 @@ print(json.dumps(chain.get_counts(), indent=4))
#> "outputs": 3
#> }
#> ]
print(StreamStart(range(10)).flat_map(chain).to_list())
#> [2, 10, 10]
```

## Similar Functionality
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "pipedata"
version = "0.2.2"
version = "0.3"
description = "Framework for building pipelines for data processing"
authors = ["Simon Wicks <simw@users.noreply.github.com>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/pipedata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.2"
__version__ = "0.3"

__all__ = [
"__version__",
Expand Down
8 changes: 4 additions & 4 deletions src/pipedata/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .chain import Chain, ChainStart
from .stream import Stream, StreamStart
from .chain import Chain, ChainType
from .stream import Stream, StreamType

__all__ = [
"ChainType",
"Chain",
"ChainStart",
"StreamType",
"Stream",
"StreamStart",
]
121 changes: 13 additions & 108 deletions src/pipedata/core/chain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import functools
import itertools
from typing import (
Any,
Callable,
Expand All @@ -10,78 +8,28 @@
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
overload,
)

from .links import ChainLink

TStart = TypeVar("TStart")
TEnd = TypeVar("TEnd")
TOther = TypeVar("TOther")


def batched(iterable: Iterator[TEnd], n: Optional[int]) -> Iterator[Tuple[TEnd, ...]]:
    """Yield successive tuples of up to ``n`` elements taken from ``iterable``.

    Can be replaced by itertools.batched once using Python 3.12+.

    Args:
        iterable: Source of elements. It is consumed lazily, one batch at a
            time.
        n: Maximum number of elements per batch. ``None`` places every
            remaining element into a single tuple.

    Yields:
        Non-empty tuples of consecutive elements; only the final tuple may be
        shorter than ``n``. Nothing is yielded for an empty input.
    """
    # Take a single iterator up front: calling islice directly on a
    # re-iterable object (e.g. a list) would restart from the first element on
    # every pass and never terminate.
    iterator = iter(iterable)
    while batch := tuple(itertools.islice(iterator, n)):
        yield batch


def _identity(input_iterator: Iterator[TEnd]) -> Iterator[TEnd]:
yield from input_iterator


class CountingIterator(Iterator[TStart]):
    """Iterator wrapper that counts how many elements it has handed out."""

    def __init__(self, iterator: Iterator[TStart]) -> None:
        """Wrap ``iterator``; the count starts at zero."""
        self._iterator = iterator
        self._count = 0

    def __iter__(self) -> Iterator[TStart]:
        """Return this object, as required by the iterator protocol."""
        return self

    def __next__(self) -> TStart:
        """Return the next element of the wrapped iterator and count it.

        Raises:
            StopIteration: When the wrapped iterator is exhausted; the count
                is left unchanged.
        """
        try:
            element = next(self._iterator)
        except StopIteration as err:
            # Exhaustion is not an element, so it is never counted.
            raise StopIteration from err
        # Count only after a successful fetch, so an exception raised by the
        # wrapped iterator cannot leave the count one too high.
        self._count += 1
        return element

    def get_count(self) -> int:
        """Return the number of elements returned so far."""
        return self._count


class ChainLink(Generic[TStart, TEnd]):
    """A single step of a chain that tracks how many elements pass through.

    The wrapped function maps an input iterator to an output iterator. Each
    call wraps both sides in a fresh ``CountingIterator`` so the number of
    elements consumed and produced can be read back afterwards.
    """

    def __init__(
        self,
        func: Callable[[Iterator[TStart]], Iterator[TEnd]],
    ) -> None:
        """Store ``func``; no counters exist until the link is called."""
        self._func = func
        self._input: Optional[CountingIterator[TStart]] = None
        self._output: Optional[CountingIterator[TEnd]] = None

    @property
    def __name__(self) -> str:  # noqa: A003
        """Name of the wrapped function."""
        return self._func.__name__

    def __call__(self, input_iterator: Iterator[TStart]) -> Iterator[TEnd]:
        """Apply the wrapped function to ``input_iterator`` with counting.

        The counters from any previous call are replaced.
        """
        counted_input = CountingIterator(input_iterator)
        self._input = counted_input
        counted_output = CountingIterator(self._func(counted_input))
        self._output = counted_output
        return counted_output

    def get_counts(self) -> Tuple[int, int]:
        """Return ``(inputs, outputs)`` counts; zeros if never called."""
        consumed = self._input.get_count() if self._input is not None else 0
        produced = self._output.get_count() if self._output is not None else 0
        return (consumed, produced)


class Chain(Generic[TStart, TEnd]):
class ChainType(Generic[TStart, TEnd]):
@overload
def __init__(
self,
previous_steps: Chain[TStart, TOther],
previous_steps: ChainType[TStart, TOther],
func: Callable[[Iterator[TOther]], Iterator[TEnd]],
):
...
Expand All @@ -96,7 +44,7 @@ def __init__(

def __init__(
self,
previous_steps: Optional[Chain[TStart, TOther]],
previous_steps: Optional[ChainType[TStart, TOther]],
func: Union[
Callable[[Iterator[TOther]], Iterator[TEnd]],
Callable[[Iterator[TStart]], Iterator[TEnd]],
Expand All @@ -112,58 +60,15 @@ def __call__(self, input_iterator: Iterator[TStart]) -> Iterator[TEnd]:

return self._func(self._previous_steps(input_iterator)) # type: ignore

def flat_map(
def then(
self, func: Callable[[Iterator[TEnd]], Iterator[TOther]]
) -> Chain[TStart, TOther]:
"""
Output zero or more elements from one or more input elements.
This is a fully general operation, that can arbitrarily transform the
stream of elements. It is the most powerful operation, and all the
other operations are implemented in terms of it.
"""
return Chain(self, func)

def filter(self, func: Callable[[TEnd], bool]) -> Chain[TStart, TEnd]: # noqa: A003
"""
Remove elements from the stream that do not pass the filter function.
"""

@functools.wraps(func)
def new_action(previous_step: Iterator[TEnd]) -> Iterator[TEnd]:
return filter(func, previous_step)

return self.flat_map(new_action)
) -> ChainType[TStart, TOther]:
return ChainType(self, func)

def map( # noqa: A003
self, func: Callable[[TEnd], TOther]
) -> Chain[TStart, TOther]:
"""
Return a single transformed element from each input element.
"""

@functools.wraps(func)
def new_action(previous_step: Iterator[TEnd]) -> Iterator[TOther]:
return map(func, previous_step)

return self.flat_map(new_action)

def batched_map(
self, func: Callable[[Tuple[TEnd, ...]], TOther], n: Optional[int] = None
) -> Chain[TStart, TOther]:
"""
Return a single transformed element from (up to) n input elements.
If n is None, then apply the function to all the elements, and return
an iterator of 1 element.
"""

@functools.wraps(func)
def new_action(previous_step: Iterator[TEnd]) -> Iterator[TOther]:
for elements in batched(previous_step, n):
yield func(elements)

return self.flat_map(new_action)
def __or__(
self, func: Callable[[Iterator[TEnd]], Iterator[TOther]]
) -> ChainType[TStart, TOther]:
return self.then(func)

def get_counts(self) -> List[Dict[str, Any]]:
step_counts = []
Expand All @@ -181,6 +86,6 @@ def get_counts(self) -> List[Dict[str, Any]]:
return step_counts


class ChainStart(Chain[TOther, TOther]):
class Chain(ChainType[TOther, TOther]):
    """Starting point for building a chain.

    Constructed with no previous steps and ``_identity`` as its function, so
    its input and output element types are the same.
    """

    def __init__(self) -> None:
        # No previous steps: this is the first element of the chain.
        super().__init__(None, _identity)
57 changes: 57 additions & 0 deletions src/pipedata/core/links.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import (
Callable,
Generic,
Iterator,
Optional,
Tuple,
TypeVar,
)

TStart = TypeVar("TStart")
TEnd = TypeVar("TEnd")
TOther = TypeVar("TOther")


class CountingIterator(Iterator[TStart]):
    """Iterator wrapper that counts how many elements it has handed out."""

    def __init__(self, iterator: Iterator[TStart]) -> None:
        """Wrap ``iterator``; the count starts at zero."""
        self._iterator = iterator
        self._count = 0

    def __iter__(self) -> Iterator[TStart]:
        """Return this object, as required by the iterator protocol."""
        return self

    def __next__(self) -> TStart:
        """Return the next element of the wrapped iterator and count it.

        Raises:
            StopIteration: When the wrapped iterator is exhausted; the count
                is left unchanged.
        """
        try:
            element = next(self._iterator)
        except StopIteration as err:
            # Exhaustion is not an element, so it is never counted.
            raise StopIteration from err
        # Count only after a successful fetch, so an exception raised by the
        # wrapped iterator cannot leave the count one too high.
        self._count += 1
        return element

    def get_count(self) -> int:
        """Return the number of elements returned so far."""
        return self._count


class ChainLink(Generic[TStart, TEnd]):
    """A single step of a chain that tracks how many elements pass through.

    The wrapped function maps an input iterator to an output iterator. Each
    call wraps both sides in a fresh ``CountingIterator`` so the number of
    elements consumed and produced can be read back afterwards.
    """

    def __init__(
        self,
        func: Callable[[Iterator[TStart]], Iterator[TEnd]],
    ) -> None:
        """Store ``func``; no counters exist until the link is called."""
        self._func = func
        self._input: Optional[CountingIterator[TStart]] = None
        self._output: Optional[CountingIterator[TEnd]] = None

    @property
    def __name__(self) -> str:  # noqa: A003
        """Name of the wrapped function."""
        return self._func.__name__

    def __call__(self, input_iterator: Iterator[TStart]) -> Iterator[TEnd]:
        """Apply the wrapped function to ``input_iterator`` with counting.

        The counters from any previous call are replaced.
        """
        counted_input = CountingIterator(input_iterator)
        self._input = counted_input
        counted_output = CountingIterator(self._func(counted_input))
        self._output = counted_output
        return counted_output

    def get_counts(self) -> Tuple[int, int]:
        """Return ``(inputs, outputs)`` counts; zeros if never called."""
        consumed = self._input.get_count() if self._input is not None else 0
        produced = self._output.get_count() if self._output is not None else 0
        return (consumed, produced)
Loading

0 comments on commit 8d3888d

Please sign in to comment.