Skip to content

Commit

Permalink
Merge pull request #21 from petbox-dev/pipe
Browse files Browse the repository at this point in the history
Add `pipe()` method
  • Loading branch information
dsfulf committed Nov 22, 2022
2 parents 8588fd4 + d22838e commit ca4ffe4
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 38 deletions.
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ and SQL-style "group by" and join operations.
+----------------------------+-----------------------------------------------------------------------------------------------------------------------------+
| Functional Methods | `row_map <https://tafra.readthedocs.io/en/latest/api.html#tafra.base.Tafra.row_map>`_, |
| | `tuple_map <https://tafra.readthedocs.io/en/latest/api.html#tafra.base.Tafra.tuple_map>`_, |
| | `col_map <https://tafra.readthedocs.io/en/latest/api.html#tafra.base.Tafra.col_map>`_ |
| | `col_map <https://tafra.readthedocs.io/en/latest/api.html#tafra.base.Tafra.col_map>`_, |
| | `pipe <https://tafra.readthedocs.io/en/latest/api.html#tafra.base.Tafra.pipe>`_ |
+----------------------------+-----------------------------------------------------------------------------------------------------------------------------+
| Dict-like Methods | `keys <https://tafra.readthedocs.io/en/latest/api.html#tafra.base.Tafra.keys>`_, |
| | `values <https://tafra.readthedocs.io/en/latest/api.html#tafra.base.Tafra.values>`_, |
Expand Down
3 changes: 3 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Methods
tuple_map
col_map
key_map
pipe
select
copy
update
Expand Down Expand Up @@ -168,6 +169,8 @@ Methods
.. automethod:: tuple_map
.. automethod:: col_map
.. automethod:: key_map
.. automethod:: pipe
.. automethod:: __rshift__
.. automethod:: select
.. automethod:: copy
.. automethod:: update
Expand Down
14 changes: 14 additions & 0 deletions docs/versions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@ Version History
.. automodule:: tafra
:noindex:

1.0.10
------

* Add ``pipe`` and overload ``>>`` operator for Tafra objects

1.0.9
-----

* Add test files to build

1.0.8
-----

* Check rows in constructor to ensure equal data length

1.0.7
-----
Expand Down
17 changes: 13 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
import sys
import re

from tafra import __version__

try:
from setuptools import setup
except ImportError:
from distutils.core import setup


def find_version() -> str:
v = {}
with open('tafra/version.py', 'r') as f:
exec(f.read(), globals(), v)

return v['__version__']


def get_long_description() -> str:
# Fix display issues on PyPI caused by RST markup
with open('README.rst', 'r') as f:
Expand Down Expand Up @@ -60,16 +66,19 @@ def replace(s: str) -> str:

return readme + '\n\n' + version_history


__version__ = find_version()

if sys.argv[-1] == 'build':
print(f'\nBuilding version {__version__}...\n')
print(f'\nBuilding __version__ {__version__}...\n')
os.system('rm -r dist\\') # clean out dist/
os.system('python setup.py sdist bdist_wheel')
sys.exit()


setup(
name='tafra',
version=__version__,
__version__=__version__,
description='Tafra: innards of a dataframe',
long_description=get_long_description(),
long_description_content_type="text/x-rst",
Expand Down
3 changes: 2 additions & 1 deletion tafra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
-----
Created on April 25, 2020
"""
__version__ = '1.0.9'

from .version import __version__

from .base import Tafra, object_formatter
from .group import GroupBy, Transform, IterateBy, InnerJoin, LeftJoin
Expand Down
75 changes: 49 additions & 26 deletions tafra/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,19 @@

from typing import (Any, Callable, Dict, Mapping, List, Tuple, Optional, Union as _Union, Sequence,
Sized, Iterable, Iterator, Type, KeysView, ValuesView, ItemsView,
IO)
IO, Concatenate, ParamSpec)
from typing import cast
from io import TextIOWrapper

from .formatter import ObjectFormatter
from .csvreader import CSVReader


object_formatter = ObjectFormatter()
P = ParamSpec('P')


# default object formats
object_formatter = ObjectFormatter()
object_formatter['Decimal'] = lambda x: x.astype(float)


Expand Down Expand Up @@ -333,6 +335,14 @@ def __getitem__(
def __setitem__(self, item: str, value: _Union[np.ndarray, Sequence[Any], Any]) -> None:
self._ensure_valid(item, value, set_item=True)

def __repr__(self) -> str:
if not hasattr(self, '_rows'):
return f'Tafra(data={self._data}, dtypes={self._dtypes}, rows=n/a)'
return f'Tafra(data={self._data}, dtypes={self._dtypes}, rows={self._rows})'

def __str__(self) -> str:
return self.__repr__()

def __len__(self) -> int:
assert self._data is not None, \
'Interal error: Cannot construct a Tafra with no data.'
Expand All @@ -341,6 +351,9 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator['Tafra']:
return (self._iindex(i) for i in range(self._rows))

def __rshift__(self, other: Callable[['Tafra'], 'Tafra']) -> 'Tafra':
return self.pipe(other)

def iterrows(self) -> Iterator['Tafra']:
"""
Yield rows as :class:`Tafra`. Use :meth:`itertuples` for better performance.
Expand Down Expand Up @@ -386,12 +399,6 @@ def itercols(self) -> Iterator[Tuple[str, np.ndarray]]:
"""
return map(tuple, self.data.items()) # type: ignore

def __str__(self) -> str:
return self.__repr__()

def __repr__(self) -> str:
return f'Tafra(data={self._data}, dtypes={self._dtypes}, rows={self._rows})'

def _update_rows(self) -> None:
"""
Updates :attr:`_rows`. User should call this if they have directly assigned to
Expand Down Expand Up @@ -756,7 +763,7 @@ def _ensure_valid(self, column: str, value: _Union[np.ndarray, Sequence[Any], An
value = sq_value

assert value.ndim >= 1, \
'Interal error: `Tafra` only supports assigning ndim >= 1.'
'Interal error: `Tafra` only supports assigning ndim == 1.'

if check_rows and len(value) != rows:
raise ValueError(
Expand Down Expand Up @@ -1270,8 +1277,7 @@ def tuple_map(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Iterat
name = kwargs.pop('name', 'Tafra')
return (fn(tf, *args, **kwargs) for tf in self.itertuples(name))

def col_map(self, fn: Callable[..., Any], keys: bool = True,
*args: Any, **kwargs: Any) -> Iterator[Any]:
def col_map(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Iterator[Any]:
"""
Map a function over columns. To apply to specific columns, use :meth:`select`
first. The function must operate on :class:`Tuple[str, np.ndarray]`.
Expand All @@ -1281,9 +1287,6 @@ def col_map(self, fn: Callable[..., Any], keys: bool = True,
fn: Callable[..., Any]
The function to map.
keys: bool = True
Return a tuple
*args: Any
Additional positional arguments to ``fn``.
Expand All @@ -1298,7 +1301,7 @@ def col_map(self, fn: Callable[..., Any], keys: bool = True,

return (fn(value, *args, **kwargs) for column, value in self.itercols())

def key_map(self, fn: Callable[..., Any], keys: bool = True,
def key_map(self, fn: Callable[..., Any],
*args: Any, **kwargs: Any) -> Iterator[Tuple[str, Any]]:
"""
Map a function over columns like :meth:col_map, but return :class:`Tuple` of the
Expand All @@ -1310,9 +1313,6 @@ def key_map(self, fn: Callable[..., Any], keys: bool = True,
fn: Callable[..., Any]
The function to map.
keys: bool = True
Return a tuple
*args: Any
Additional positional arguments to ``fn``.
Expand All @@ -1324,23 +1324,31 @@ def key_map(self, fn: Callable[..., Any], keys: bool = True,
iter_tf: Iterator[Any]
An iterator to map the function.
"""
return ((column, fn(value, *args, **kwargs))
for column, value in self.itercols())
return ((column, fn(value, *args, **kwargs)) for column, value in self.itercols())

def head(self, n: int = 5) -> 'Tafra':
def pipe(self, fn: Callable[Concatenate['Tafra', P], 'Tafra'],
*args: Any, **kwargs: Any) -> 'Tafra':
"""
Display the head of the :class:`Tafra`.
Apply a function to the :class:`Tafra` and return the resulting :class:`Tafra`. Primarily
used to build a tranformer pipeline.
Parameters
----------
n: int = 5
The number of rows to display.
fn: Callable[[], 'Tafra']
The function to apply.
*args: Any
Additional positional arguments to ``fn``.
**kwargs: Any
Additional keyword arguments to ``fn``.
Returns
-------
None: None
tafra: Tafra
A new :class:`Tafra` result of the function.
"""
return self._slice(slice(n))
return fn(self, *args, **kwargs)

def select(self, columns: Iterable[str]) -> 'Tafra':
"""
Expand Down Expand Up @@ -1368,6 +1376,21 @@ def select(self, columns: Iterable[str]) -> 'Tafra':
validate=False
)

def head(self, n: int = 5) -> 'Tafra':
"""
Display the head of the :class:`Tafra`.
Parameters
----------
n: int = 5
The number of rows to display.
Returns
-------
None: None
"""
return self._slice(slice(n))

def keys(self) -> KeysView[str]:
"""
Return the keys of :attr:`data`, i.e. like :meth:`dict.keys()`.
Expand Down
1 change: 1 addition & 0 deletions tafra/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = '1.0.10'
35 changes: 29 additions & 6 deletions test/test_tafra.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,13 +542,13 @@ def test_update_dtypes() -> None:
t.update_dtypes_inplace({'x': float})
check_tafra(t)
assert t['x'].dtype == 'float'
assert isinstance(t['x'][0], np.float)
assert isinstance(t['x'][0], np.float64)

t = build_tafra()
_ = t.update_dtypes({'x': float})
check_tafra(_)
assert _['x'].dtype == 'float'
assert isinstance(_['x'][0], np.float)
assert isinstance(_['x'][0], np.float64)

def test_rename() -> None:
t = build_tafra()
Expand Down Expand Up @@ -704,11 +704,34 @@ def test_invalid_agg() -> None:

def test_map() -> None:
t = build_tafra()
_ = list(t.row_map(np.repeat, 6))
_ = list(t.tuple_map(np.repeat, 6))
_ = list(t.col_map(np.repeat, repeats=6))

def repeat(tf: Tafra, repeats: int) -> Tafra:
return [tf for _ in range(repeats)]

_ = list(t.row_map(repeat, 6))
_ = list(t.tuple_map(repeat, 6))
_ = list(t.col_map(repeat, repeats=6))
_ = Tafra(t.key_map(np.repeat, repeats=6))

def test_pipe() -> None:
def fn1(t: Tafra) -> Tafra:
return t[t['y'] == 'one']
def fn2(t: Tafra) -> Tafra:
return t[t['z'] == 0]

t = build_tafra()
check_tafra(t.pipe(fn1))
check_tafra(t >> fn1)
check_tafra(t.pipe(fn1).pipe(fn2))
check_tafra(t >> fn1 >> fn2)

def fn3(t: Tafra, i: int) -> Tafra:
return t[t['x'] == i]

check_tafra(t.pipe(fn3, 1))
check_tafra(t.pipe(fn3, i=1))
check_tafra(t >> (lambda t: fn3(t, i=1)))

def test_union() -> None:
t = build_tafra()
t2 = build_tafra()
Expand Down Expand Up @@ -1178,7 +1201,7 @@ def write_reread(t: Tafra) -> None:
check_tafra(t)

# force dtypes on missing columns
t = Tafra.read_csv('test/ex6.csv', missing=None, dtypes={'dp_prime': np.float, 'dp_prime_te': np.float32})
t = Tafra.read_csv('test/ex6.csv', missing=None, dtypes={'dp_prime': np.float64, 'dp_prime_te': np.float32})
assert t.dtypes['dp'] == 'float64'
assert t.dtypes['dp_prime'] == 'float64'
assert t.dtypes['dp_prime_te'] == 'float32'
Expand Down

0 comments on commit ca4ffe4

Please sign in to comment.