Modularize default load and save argument handling (#15)
deepyaman authored and idanov committed Jul 23, 2019
1 parent e11bef3 commit 9733fc6
Showing 23 changed files with 314 additions and 161 deletions.
2 changes: 2 additions & 0 deletions kedro/contrib/io/__init__.py
@@ -31,3 +31,5 @@
 `kedro.io` module (e.g. additional ``AbstractDataSet``s and
 extensions/alternative ``DataCatalog``s.
 """
+
+from .core import DefaultArgumentsMixIn  # NOQA
11 changes: 5 additions & 6 deletions kedro/contrib/io/azure/csv_blob.py
@@ -35,10 +35,11 @@
 import pandas as pd
 from azure.storage.blob import BlockBlobService
 
+from kedro.contrib.io import DefaultArgumentsMixIn
 from kedro.io import AbstractDataSet
 
 
-class CSVBlobDataSet(AbstractDataSet):
+class CSVBlobDataSet(DefaultArgumentsMixIn, AbstractDataSet):
     """``CSVBlobDataSet`` loads and saves csv files in Microsoft's Azure
     blob storage. It uses azure storage SDK to read and write in azure and
     pandas to handle the csv file locally.
@@ -61,6 +62,8 @@ class CSVBlobDataSet(AbstractDataSet):
     >>> assert data.equals(reloaded)
     """
 
+    DEFAULT_SAVE_ARGS = {"index": False}
+
     def _describe(self) -> Dict[str, Any]:
         return dict(
             filepath=self._filepath,
@@ -106,16 +109,12 @@ def __init__(
         All defaults are preserved, but "index", which is set to False.
         """
-        default_save_args = {"index": False}
-        self._save_args = (
-            {**default_save_args, **save_args} if save_args else default_save_args
-        )
-        self._load_args = load_args if load_args else {}
         self._filepath = filepath
         self._container_name = container_name
         self._credentials = credentials if credentials else {}
         self._blob_to_text_args = blob_to_text_args if blob_to_text_args else {}
         self._blob_from_text_args = blob_from_text_args if blob_from_text_args else {}
+        super().__init__(load_args, save_args)
 
     def _load(self) -> pd.DataFrame:
         blob_service = BlockBlobService(**self._credentials)
16 changes: 3 additions & 13 deletions kedro/contrib/io/bioinformatics/sequence_dataset.py
@@ -35,10 +35,11 @@
 
 from Bio import SeqIO
 
+from kedro.contrib.io import DefaultArgumentsMixIn
 from kedro.io import AbstractDataSet
 
 
-class BioSequenceLocalDataSet(AbstractDataSet):
+class BioSequenceLocalDataSet(DefaultArgumentsMixIn, AbstractDataSet):
     """``BioSequenceLocalDataSet`` loads and saves data to a sequence file.
 
     Example:
@@ -95,18 +96,7 @@ def __init__(
         """
         self._filepath = filepath
-        default_load_args = {}  # type: Dict[str, Any]
-        default_save_args = {}  # type: Dict[str, Any]
-        self._load_args = (
-            {**default_load_args, **load_args}
-            if load_args is not None
-            else default_load_args
-        )
-        self._save_args = (
-            {**default_save_args, **save_args}
-            if save_args is not None
-            else default_save_args
-        )
+        super().__init__(load_args, save_args)
 
     def _load(self) -> List:
         return list(SeqIO.parse(self._filepath, **self._load_args))
53 changes: 53 additions & 0 deletions kedro/contrib/io/core.py
@@ -0,0 +1,53 @@
+# Copyright 2018-2019 QuantumBlack Visual Analytics Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
+# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
+# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
+# (either separately or in combination, "QuantumBlack Trademarks") are
+# trademarks of QuantumBlack. The License does not grant you any right or
+# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
+# Trademarks or any confusingly similar mark as a trademark for your product,
+# or use the QuantumBlack Trademarks in any other manner that might cause
+# confusion in the marketplace, including but not limited to in advertising,
+# on websites, or on software.
+#
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module extends the set of classes ``kedro.io.core`` provides."""
+
+import copy
+from typing import Any, Dict, Optional
+
+
+# pylint: disable=too-few-public-methods
+class DefaultArgumentsMixIn:
+    """Mixin class that helps handle default load and save arguments."""
+
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]
+
+    def __init__(
+        self,
+        load_args: Optional[Dict[str, Any]] = None,
+        save_args: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        super().__init__()
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
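
For illustration, a minimal sketch of how a dataset opts into the new mixin. The `TextLocalDataSet` class below is hypothetical and not part of this commit; only `DefaultArgumentsMixIn` and `AbstractDataSet` come from kedro. Note the mixin must precede `AbstractDataSet` in the base-class list so that Python's MRO reaches the mixin's `__init__`:

from typing import Any, Dict

from kedro.contrib.io import DefaultArgumentsMixIn
from kedro.io import AbstractDataSet


class TextLocalDataSet(DefaultArgumentsMixIn, AbstractDataSet):
    """Hypothetical dataset that reads and writes a local text file."""

    # Class-level defaults picked up by the mixin's __init__.
    DEFAULT_LOAD_ARGS = {"encoding": "utf-8"}  # type: Dict[str, Any]

    def __init__(self, filepath, load_args=None, save_args=None):
        self._filepath = filepath
        # Deep-copies the class defaults, then layers user overrides on top.
        super().__init__(load_args, save_args)

    def _describe(self):
        return dict(filepath=self._filepath, load_args=self._load_args)

    def _load(self):
        with open(self._filepath, **self._load_args) as fh:
            return fh.read()

    def _save(self, data):
        with open(self._filepath, mode="w", **self._save_args) as fh:
            fh.write(data)


# User-supplied arguments merge over, rather than replace, the defaults:
data_set = TextLocalDataSet("out.txt", load_args={"errors": "ignore"})
assert data_set._load_args == {"encoding": "utf-8", "errors": "ignore"}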
6 changes: 3 additions & 3 deletions kedro/contrib/io/pyspark/spark_data_set.py
@@ -36,10 +36,11 @@
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.utils import AnalysisException
 
+from kedro.contrib.io import DefaultArgumentsMixIn
 from kedro.io import AbstractDataSet
 
 
-class SparkDataSet(AbstractDataSet):
+class SparkDataSet(DefaultArgumentsMixIn, AbstractDataSet):
     """``SparkDataSet`` loads and saves Spark data frames.
 
     Example:
@@ -106,8 +107,7 @@ def __init__(
 
         self._filepath = filepath
         self._file_format = file_format
-        self._load_args = load_args if load_args is not None else {}
-        self._save_args = save_args if save_args is not None else {}
+        super().__init__(load_args, save_args)
 
     @staticmethod
     def _get_spark():
6 changes: 3 additions & 3 deletions kedro/contrib/io/pyspark/spark_jdbc.py
@@ -31,12 +31,13 @@
 
 from pyspark.sql import DataFrame, SparkSession
 
+from kedro.contrib.io import DefaultArgumentsMixIn
 from kedro.io import AbstractDataSet, DataSetError
 
 __all__ = ["SparkJDBCDataSet"]
 
 
-class SparkJDBCDataSet(AbstractDataSet):
+class SparkJDBCDataSet(DefaultArgumentsMixIn, AbstractDataSet):
     """``SparkJDBCDataSet`` loads data from a database table accessible
     via JDBC URL url and connection properties and saves the content of
     a PySpark DataFrame to an external database table via JDBC. It uses
@@ -140,8 +141,7 @@ def __init__(
 
         self._url = url
         self._table = table
-        self._load_args = load_args if load_args is not None else {}
-        self._save_args = save_args if save_args is not None else {}
+        super().__init__(load_args, save_args)
 
         # Update properties in load_args and save_args with credentials.
         if credentials is not None:
24 changes: 12 additions & 12 deletions kedro/io/csv_local.py
@@ -30,6 +30,7 @@
 underlying functionality is supported by pandas, so it supports all
 allowed pandas options for loading and saving csv files.
 """
+import copy
 from pathlib import Path
 from typing import Any, Dict
 
@@ -61,6 +62,9 @@ class CSVLocalDataSet(AbstractVersionedDataSet):
     """
 
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {"index": False}  # type: Dict[str, Any]
+
     def __init__(
         self,
         filepath: str,
@@ -87,18 +91,14 @@ def __init__(
             attribute is None, save version will be autogenerated.
         """
         super().__init__(Path(filepath), version)
-        default_save_args = {"index": False}  # type: Dict[str, Any]
-        default_load_args = {}  # type: Dict[str, Any]
-        self._load_args = (
-            {**default_load_args, **load_args}
-            if load_args is not None
-            else default_load_args
-        )
-        self._save_args = (
-            {**default_save_args, **save_args}
-            if save_args is not None
-            else default_save_args
-        )
+
+        # Handle default load and save arguments
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
 
     def _load(self) -> pd.DataFrame:
         load_path = Path(self._get_load_path())
21 changes: 14 additions & 7 deletions kedro/io/csv_s3.py
@@ -29,7 +29,7 @@
 """``CSVS3DataSet`` loads and saves data to a file in S3. It uses s3fs
 to read and write from S3 and pandas to handle the csv file.
 """
-from copy import deepcopy
+import copy
 from pathlib import PurePosixPath
 from typing import Any, Dict, Optional
 
@@ -62,6 +62,9 @@ class CSVS3DataSet(AbstractVersionedDataSet):
     >>> assert data.equals(reloaded)
     """
 
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {"index": False}  # type: Dict[str, Any]
+
     # pylint: disable=too-many-arguments
     def __init__(
         self,
@@ -94,21 +97,25 @@ def __init__(
             attribute is None, save version will be autogenerated.
         """
-        _credentials = deepcopy(credentials) or {}
+        _credentials = copy.deepcopy(credentials) or {}
         _s3 = S3FileSystem(client_kwargs=_credentials)
         super().__init__(
             PurePosixPath("{}/{}".format(bucket_name, filepath)),
             version,
             exists_function=_s3.exists,
             glob_function=_s3.glob,
         )
-        default_save_args = {"index": False}
-        self._save_args = (
-            {**default_save_args, **save_args} if save_args else default_save_args
-        )
-        self._load_args = load_args if load_args else {}
         self._bucket_name = bucket_name
         self._credentials = _credentials
+
+        # Handle default load and save arguments
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
 
         self._s3 = _s3
 
     def _describe(self) -> Dict[str, Any]:
25 changes: 12 additions & 13 deletions kedro/io/excel_local.py
@@ -30,6 +30,7 @@
 underlying functionality is supported by pandas, so it supports all
 allowed pandas options for loading and saving Excel files.
 """
+import copy
 from pathlib import Path
 from typing import Any, Dict, Union
 
@@ -61,6 +62,9 @@ class ExcelLocalDataSet(AbstractVersionedDataSet):
     """
 
+    DEFAULT_LOAD_ARGS = {"engine": "xlrd"}
+    DEFAULT_SAVE_ARGS = {"index": False}
+
     def _describe(self) -> Dict[str, Any]:
         return dict(
             filepath=self._filepath,
@@ -105,21 +109,16 @@ def __init__(
         """
         super().__init__(Path(filepath), version)
-        default_save_args = {"index": False}
-        default_load_args = {"engine": "xlrd"}
-
-        self._load_args = (
-            {**default_load_args, **load_args}
-            if load_args is not None
-            else default_load_args
-        )
-        self._save_args = (
-            {**default_save_args, **save_args}
-            if save_args is not None
-            else default_save_args
-        )
         self._engine = engine
+
+        # Handle default load and save arguments
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
 
     def _load(self) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
         load_path = Path(self._get_load_path())
         return pd.read_excel(load_path, **self._load_args)
24 changes: 12 additions & 12 deletions kedro/io/hdf_local.py
@@ -30,6 +30,7 @@
 underlying functionality is supported by pandas, so it supports all
 allowed pandas options for loading and saving hdf files.
 """
+import copy
 from pathlib import Path
 from typing import Any, Dict
 
@@ -63,6 +64,9 @@ class HDFLocalDataSet(AbstractVersionedDataSet):
     """
 
+    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
+    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]
+
     # pylint: disable=too-many-arguments
     def __init__(
         self,
@@ -93,19 +97,15 @@ def __init__(
         """
         super().__init__(Path(filepath), version)
-        default_load_args = {}  # type: Dict[str, Any]
-        default_save_args = {}  # type: Dict[str, Any]
         self._key = key
-        self._load_args = (
-            {**default_load_args, **load_args}
-            if load_args is not None
-            else default_load_args
-        )
-        self._save_args = (
-            {**default_load_args, **save_args}
-            if save_args is not None
-            else default_save_args
-        )
+
+        # Handle default load and save arguments
+        self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
+        if load_args is not None:
+            self._load_args.update(load_args)
+        self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
+        if save_args is not None:
+            self._save_args.update(save_args)
 
     def _load(self) -> pd.DataFrame:
         load_path = Path(self._get_load_path())

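The versioned datasets in `kedro.io` above repeat the same merge inline rather than importing the mixin, but the semantics are identical everywhere in this commit. A standalone sketch of that merge, with illustrative values:

import copy

DEFAULT_SAVE_ARGS = {"index": False}  # e.g. CSVLocalDataSet's class default
user_save_args = {"sep": "|"}  # what a caller passes as save_args

save_args = copy.deepcopy(DEFAULT_SAVE_ARGS)
if user_save_args is not None:
    save_args.update(user_save_args)

# Defaults survive unless explicitly overridden.
assert save_args == {"index": False, "sep": "|"}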