renames format_mapping to extension_mapping
edublancas committed Aug 16, 2021
1 parent 3991573 commit 160ccf7
Showing 3 changed files with 48 additions and 47 deletions.
59 changes: 30 additions & 29 deletions src/ploomber/io/serialize.py
@@ -78,29 +78,30 @@ def _df2parquet(obj, product):
 }
 
 
-def _format_mapping_validate(format_mapping, fn):
-    if format_mapping is not None:
-        if not isinstance(format_mapping, Mapping):
-            raise TypeError(f'Invalid format_mapping {format_mapping!r} for '
-                            f'decorated function {fn.__name__!r}. Expected '
-                            'it to be a dictionary but got a '
-                            f'{type(format_mapping).__name__}')
+def _extension_mapping_validate(extension_mapping, fn):
+    if extension_mapping is not None:
+        if not isinstance(extension_mapping, Mapping):
+            raise TypeError(
+                f'Invalid extension_mapping {extension_mapping!r} for '
+                f'decorated function {fn.__name__!r}. Expected '
+                'it to be a dictionary but got a '
+                f'{type(extension_mapping).__name__}')
 
         invalid_keys = {
             k
-            for k in format_mapping.keys() if not k.startswith('.')
+            for k in extension_mapping.keys() if not k.startswith('.')
         }
 
         if invalid_keys:
             raise ValueError(
-                f'Invalid format_mapping {format_mapping!r} for '
+                f'Invalid extension_mapping {extension_mapping!r} for '
                 f'decorated function {fn.__name__!r}. Expected '
                 'keys to start with a dot (e.g., ".csv"). Invalid '
                 f'keys found: {invalid_keys!r}')
 
 
-def _build_format_mapping_final(format_mapping, defaults, fn,
-                                defaults_provided, name):
+def _build_extension_mapping_final(extension_mapping, defaults, fn,
+                                   defaults_provided, name):
     defaults_keys = set(defaults_provided)
 
     if defaults:
@@ -112,13 +113,13 @@ def _build_format_mapping_final(format_mapping, defaults, fn,
 
         passed_defaults = set(defaults)
 
-        if format_mapping:
-            overlap = passed_defaults & set(format_mapping)
+        if extension_mapping:
+            overlap = passed_defaults & set(extension_mapping)
             if overlap:
                 raise ValueError(
                     f'Error when adding @{name} decorator '
                     f'to function {fn.__name__!r}: '
-                    'Keys in \'format_mapping\' and \'defaults\' must not '
+                    'Keys in \'extension_mapping\' and \'defaults\' must not '
                     f'overlap (overlapping keys: {overlap})')
 
         unexpected_defaults = passed_defaults - defaults_keys
@@ -135,27 +136,27 @@ def _build_format_mapping_final(format_mapping, defaults, fn,
             k: v
             for k, v in defaults_provided.items() if k in defaults
         }
-        format_mapping_final = {**defaults_map, **(format_mapping or {})}
+        extension_mapping_final = {**defaults_map, **(extension_mapping or {})}
     else:
-        format_mapping_final = format_mapping
+        extension_mapping_final = extension_mapping
 
-    _format_mapping_validate(format_mapping_final, fn)
+    _extension_mapping_validate(extension_mapping_final, fn)
 
-    return format_mapping_final
+    return extension_mapping_final
 
 
-def serializer(format_mapping=None, *, fallback=False, defaults=None):
+def serializer(extension_mapping=None, *, fallback=False, defaults=None):
     """Decorator for serializing functions
 
     Parameters
     ----------
-    format_mapping : dict, default=None
+    extension_mapping : dict, default=None
         An extension -> function mapping. Calling the decorated function with a
         File of a given extension will use the one in the mapping if it exists,
         e.g., {'.csv': to_csv, '.json': to_json}.
     fallback : bool or str, default=False
-        Determines what method to use if format_mapping does not match the
+        Determines what method to use if extension_mapping does not match the
         product to serialize. Valid values are True (uses the pickle module),
         'joblib', and 'cloudpickle'. If you use any of the last two, the
         corresponding module must be installed. If this is enabled, the
@@ -168,12 +169,12 @@ def serializer(format_mapping=None, *, fallback=False, defaults=None):
         to .txt, the returned object must be a string, for .json it must be
         a json serializable object (e.g., a list or a dict), for .csv and
         .parquet it must be a pandas.DataFrame. If using .parquet, a parquet
-        library must be installed (e.g., pyarrow). If format_mapping
+        library must be installed (e.g., pyarrow). If extension_mapping
         and defaults contain overlapping keys, an error is raised
     """
     def _serializer(fn):
-        format_mapping_final = _build_format_mapping_final(
-            format_mapping, defaults, fn, _DEFAULTS, 'serializer')
+        extension_mapping_final = _build_extension_mapping_final(
+            extension_mapping, defaults, fn, _DEFAULTS, 'serializer')
 
         try:
             serializer_fallback = _EXTERNAL[fallback]
@@ -206,10 +207,10 @@ def wrapper(obj, product):
 
                 for key, value in obj.items():
                     _serialize_product(value, product[key],
-                                       format_mapping_final, fallback,
+                                       extension_mapping_final, fallback,
                                        serializer_fallback, fn)
             else:
-                _serialize_product(obj, product, format_mapping_final,
+                _serialize_product(obj, product, extension_mapping_final,
                                    fallback, serializer_fallback, fn)
 
         return wrapper
@@ -249,12 +250,12 @@ def _validate_obj(obj, product):
         raise ValueError(error)
 
 
-def _serialize_product(obj, product, format_mapping, fallback,
+def _serialize_product(obj, product, extension_mapping, fallback,
                        serializer_fallback, fn):
     suffix = Path(product).suffix
 
-    if format_mapping and suffix in format_mapping:
-        format_mapping[suffix](obj, product)
+    if extension_mapping and suffix in extension_mapping:
+        extension_mapping[suffix](obj, product)
     elif fallback:
         with open(product, 'wb') as f:
             serializer_fallback(obj, f)
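
For orientation, here is a minimal usage sketch of the renamed parameter on the serializer side. It assumes serializer is importable from ploomber.io as in the module above; the handler write_text and the function name my_serializer are illustrative, not part of this commit.

from pathlib import Path

from ploomber.io import serializer


def write_text(obj, product):
    # illustrative handler for .txt products: obj is expected to be a string
    Path(product).write_text(obj)


@serializer(extension_mapping={'.txt': write_text}, fallback=True)
def my_serializer(obj, product):
    # reached only when the product's suffix is not in extension_mapping and
    # fallback is disabled; with fallback=True the pickle module handles
    # everything else, so this body is never called
    pass
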
26 changes: 13 additions & 13 deletions src/ploomber/io/unserialize.py
@@ -19,7 +19,7 @@
     pd = None
 
 from ploomber.products import MetaProduct
-from ploomber.io.serialize import _build_format_mapping_final
+from ploomber.io.serialize import _build_extension_mapping_final
 
 _EXTERNAL = {
     False: None,
@@ -62,18 +62,18 @@ def _parquet2df(product):
 }
 
 
-def unserializer(format_mapping=None, *, fallback=False, defaults=None):
+def unserializer(extension_mapping=None, *, fallback=False, defaults=None):
     """Decorator for unserializing functions
 
     Parameters
     ----------
-    format_mapping : dict, default=None
+    extension_mapping : dict, default=None
         An extension -> function mapping. Calling the decorated function with a
         File of a given extension will use the one in the mapping if it exists,
         e.g., {'.csv': from_csv, '.json': from_json}.
     fallback : bool or str, default=False
-        Determines what method to use if format_mapping does not match the
+        Determines what method to use if extension_mapping does not match the
         product to unserialize. Valid values are True (uses the pickle module),
         'joblib', and 'cloudpickle'. If you use any of the last two, the
         corresponding module must be installed. If this is enabled, the
@@ -86,12 +86,12 @@ def unserializer(format_mapping=None, *, fallback=False, defaults=None):
         Unserializing .txt returns a string, for .json returns the
         deserialized object (e.g., a list or a dict), .csv and
         .parquet return a pandas.DataFrame. If using .parquet, a parquet
-        library must be installed (e.g., pyarrow). If format_mapping
+        library must be installed (e.g., pyarrow). If extension_mapping
         and defaults contain overlapping keys, an error is raised
     """
     def _unserializer(fn):
-        format_mapping_final = _build_format_mapping_final(
-            format_mapping, defaults, fn, _DEFAULTS, 'unserializer')
+        extension_mapping_final = _build_extension_mapping_final(
+            extension_mapping, defaults, fn, _DEFAULTS, 'unserializer')
 
         try:
             unserializer_fallback = _EXTERNAL[fallback]
@@ -122,13 +122,13 @@ def wrapper(product):
             if isinstance(product, MetaProduct):
                 return {
                     key:
-                    _unserialize_product(value, format_mapping_final, fallback,
-                                         unserializer_fallback, fn)
+                    _unserialize_product(value, extension_mapping_final,
+                                         fallback, unserializer_fallback, fn)
                     for key, value in product.products.products.items()
                 }
 
             else:
-                return _unserialize_product(product, format_mapping_final,
+                return _unserialize_product(product, extension_mapping_final,
                                             fallback, unserializer_fallback,
                                             fn)
 
@@ -145,12 +145,12 @@ def unserializer_pickle(product):
         raise RuntimeError('Error when unserializing with pickle module')
 
 
-def _unserialize_product(product, format_mapping, fallback,
+def _unserialize_product(product, extension_mapping, fallback,
                          unserializer_fallback, fn):
     suffix = Path(product).suffix
 
-    if format_mapping and suffix in format_mapping:
-        obj = format_mapping[suffix](product)
+    if extension_mapping and suffix in extension_mapping:
+        obj = extension_mapping[suffix](product)
     elif fallback:
         with open(product, 'rb') as f:
             obj = unserializer_fallback(f)
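
And the unserializer counterpart, again only a sketch under the same assumptions (read_json and my_unserializer are illustrative names):

import json
from pathlib import Path

from ploomber.io import unserializer


def read_json(product):
    # illustrative handler for .json products: returns the parsed object
    return json.loads(Path(product).read_text())


@unserializer(extension_mapping={'.json': read_json}, fallback=True)
def my_unserializer(product):
    # reached only when the suffix is not in extension_mapping and fallback
    # is disabled; with fallback=True anything else is unpickled instead
    pass
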
10 changes: 5 additions & 5 deletions tests/io_mod/test_un_serializer.py
@@ -493,7 +493,7 @@ def my_fn():
     [unserialize.unserializer, 'unserializer'],
     [serialize.serializer, 'serializer'],
 ])
-def test_un_serializer_error_on_format_mapping_and_defaults_overlap(
+def test_un_serializer_error_on_extension_mapping_and_defaults_overlap(
         decorator_factory, name):
     decorator = decorator_factory({
         '.txt': None,
@@ -518,26 +518,26 @@ def my_fn():
     [serialize.serializer, serializer_undecorated],
     [unserialize.unserializer, unserializer_undecorated],
 ])
-def test_validates_format_mapping_type(decorator_factory, fn):
+def test_validates_extension_mapping_type(decorator_factory, fn):
     decorator = decorator_factory(['.csv'])
 
     with pytest.raises(TypeError) as excinfo:
         decorator(fn)
 
-    assert 'Invalid format_mapping' in str(excinfo.value)
+    assert 'Invalid extension_mapping' in str(excinfo.value)
 
 
 @pytest.mark.parametrize('decorator_factory, fn', [
     [serialize.serializer, serializer_undecorated],
     [unserialize.unserializer, unserializer_undecorated],
 ])
-def test_validates_format_mapping_keys(decorator_factory, fn):
+def test_validates_extension_mapping_keys(decorator_factory, fn):
     decorator = decorator_factory({'csv': None})
 
     with pytest.raises(ValueError) as excinfo:
         decorator(fn)
 
-    assert 'Invalid format_mapping' in str(excinfo.value)
+    assert 'Invalid extension_mapping' in str(excinfo.value)
 
 
 @pytest.mark.parametrize('decorator_factory, fn', [
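
The overlap test above exercises the validation in _build_extension_mapping_final; below is a hedged sketch of the behavior it expects, with illustrative names, assuming '.txt' is one of the built-in default keys as the docstrings suggest:

from ploomber.io import serializer


def my_fn(obj, product):
    pass


# '.txt' appears in both extension_mapping and defaults, so applying the
# decorator raises ValueError mentioning the overlapping keys
decorator = serializer(extension_mapping={'.txt': None}, defaults=['.txt'])

try:
    decorator(my_fn)
except ValueError as e:
    print(e)
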
