Skip to content

Commit

Permalink
add validate(..., inplace=False) keyword to prevent mutation (#305)
Browse files Browse the repository at this point in the history
* add validate(..., inplace=False) keyword

Fixes #301

This diff addresses an issue where schemas mutate the original
dataframe being validated. By default, the validate method will create
a copy of the dataframe before coercing the type. Users can
still specify inplace=True in cases where mutating the original
dataframe doesn't matter, e.g. at the end of a method chain
where the original dataframe doesn't need to be preserved.

* fix pylint

* fix inplace validation tests for windows

* use inplace in decorators, set line length to 79

* update travis ci line length

* revert line length

* apply new line length param in isort and black to all files

* fix test

* fix tests
  • Loading branch information
cosmicBboy committed Oct 24, 2020
1 parent 5a3ed31 commit 586ebf3
Show file tree
Hide file tree
Showing 29 changed files with 861 additions and 291 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ repos:
rev: v5.6.4
hooks:
- id: isort
args: ["--line-length=79"]

- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
args: ["--line-length=79"]

- repo: https://github.com/pycqa/pylint
rev: pylint-2.6.0
Expand Down
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ script:
# Check that requirements-dev.text is generated exclusively by environment.yml
- python ./scripts/generate_pip_deps_from_conda.py --compare
# Formatting
- isort --check-only pandera tests
- black --check pandera tests
- isort --line-length=79 --check-only pandera tests
- black --line-length=79 --check pandera tests
# Linting
- pylint pandera tests
# Type checking
Expand Down
73 changes: 52 additions & 21 deletions pandera/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
)


GroupbyObject = Union[pd.core.groupby.SeriesGroupBy, pd.core.groupby.DataFrameGroupBy]
GroupbyObject = Union[
pd.core.groupby.SeriesGroupBy, pd.core.groupby.DataFrameGroupBy
]

SeriesCheckObj = Union[pd.Series, Dict[str, pd.Series]]
DataFrameCheckObj = Union[pd.DataFrame, Dict[str, pd.DataFrame]]
Expand Down Expand Up @@ -177,7 +179,9 @@ def __init__(
"""

if element_wise and groupby is not None:
raise errors.SchemaInitError("Cannot use groupby when element_wise=True.")
raise errors.SchemaInitError(
"Cannot use groupby when element_wise=True."
)
self._check_fn = check_fn
self._check_kwargs = check_kwargs
self.element_wise = element_wise
Expand Down Expand Up @@ -236,11 +240,15 @@ def _format_groupby_input(
"key. Valid group keys: %s" % (invalid_groups, group_keys)
)
return {
group_key: group for group_key, group in groupby_obj if group_key in groups
group_key: group
for group_key, group in groupby_obj
if group_key in groups
}

def _prepare_series_input(
self, series: pd.Series, dataframe_context: Optional[pd.DataFrame] = None
self,
series: pd.Series,
dataframe_context: Optional[pd.DataFrame] = None,
) -> SeriesCheckObj:
"""Prepare input for Column check.
Expand All @@ -260,13 +268,15 @@ def _prepare_series_input(
).groupby(self.groupby)[series.name]
return self._format_groupby_input(groupby_obj, self.groups)
if callable(self.groupby):
groupby_obj = self.groupby(pd.concat([series, dataframe_context], axis=1))[
series.name
]
groupby_obj = self.groupby(
pd.concat([series, dataframe_context], axis=1)
)[series.name]
return self._format_groupby_input(groupby_obj, self.groups)
raise TypeError("Type %s not recognized for `groupby` argument.")

def _prepare_dataframe_input(self, dataframe: pd.DataFrame) -> DataFrameCheckObj:
def _prepare_dataframe_input(
self, dataframe: pd.DataFrame
) -> DataFrameCheckObj:
"""Prepare input for DataFrameSchema check.
:param dataframe: dataframe to validate.
Expand All @@ -279,7 +289,9 @@ def _prepare_dataframe_input(self, dataframe: pd.DataFrame) -> DataFrameCheckObj
return self._format_groupby_input(groupby_obj, self.groups)

def _handle_na(
self, df_or_series: Union[pd.DataFrame, pd.Series], column: Optional[str] = None
self,
df_or_series: Union[pd.DataFrame, pd.Series],
column: Optional[str] = None,
):
"""Handle nan values before passing object to check function."""
if not self.ignore_na:
Expand All @@ -294,14 +306,19 @@ def _handle_na(
for col in self.groupby:
# raise schema definition error if column is not in the
# validated dataframe
if isinstance(df_or_series, pd.DataFrame) and col not in df_or_series:
if (
isinstance(df_or_series, pd.DataFrame)
and col not in df_or_series
):
raise errors.SchemaDefinitionError(
"`groupby` column '%s' not found" % col
)
drop_na_columns.extend(self.groupby)

if drop_na_columns:
return df_or_series.loc[df_or_series[drop_na_columns].dropna().index]
return df_or_series.loc[
df_or_series[drop_na_columns].dropna().index
]
return df_or_series.dropna()

def __call__(
Expand Down Expand Up @@ -335,7 +352,9 @@ def __call__(

column_dataframe_context = None
if column is not None and isinstance(df_or_series, pd.DataFrame):
column_dataframe_context = df_or_series.drop(column, axis="columns")
column_dataframe_context = df_or_series.drop(
column, axis="columns"
)
df_or_series = df_or_series[column].copy()

# prepare check object
Expand Down Expand Up @@ -389,7 +408,8 @@ def __call__(
)
else:
raise TypeError(
"output type of check_fn not recognized: %s" % type(check_output)
"output type of check_fn not recognized: %s"
% type(check_output)
)

check_passed = (
Expand All @@ -400,7 +420,9 @@ def __call__(
else check_output
)

return CheckResult(check_output, check_passed, check_obj, failure_cases)
return CheckResult(
check_output, check_passed, check_obj, failure_cases
)

def __eq__(self, other):
are_fn_objects_equal = (
Expand Down Expand Up @@ -609,7 +631,9 @@ def _less_or_equal(series: pd.Series) -> pd.Series:
le = less_than_or_equal_to

@classmethod
@register_check_statistics(["min_value", "max_value", "include_min", "include_max"])
@register_check_statistics(
["min_value", "max_value", "include_min", "include_max"]
)
def in_range(
cls, min_value, max_value, include_min=True, include_max=True, **kwargs
) -> "Check":
Expand Down Expand Up @@ -683,7 +707,8 @@ def isin(cls, allowed_values: Iterable, **kwargs) -> "Check":
allowed_values = frozenset(allowed_values)
except TypeError as exc:
raise ValueError(
"Argument allowed_values must be iterable. Got %s" % allowed_values
"Argument allowed_values must be iterable. Got %s"
% allowed_values
) from exc

def _isin(series: pd.Series) -> pd.Series:
Expand Down Expand Up @@ -722,7 +747,8 @@ def notin(cls, forbidden_values: Iterable, **kwargs) -> "Check":
forbidden_values = frozenset(forbidden_values)
except TypeError as exc:
raise ValueError(
"Argument forbidden_values must be iterable. Got %s" % forbidden_values
"Argument forbidden_values must be iterable. Got %s"
% forbidden_values
) from exc

def _notin(series: pd.Series) -> pd.Series:
Expand Down Expand Up @@ -753,7 +779,8 @@ def str_matches(cls, pattern: str, **kwargs) -> "Check":
regex = re.compile(pattern)
except TypeError as exc:
raise ValueError(
'pattern="%s" cannot be compiled as regular expression' % pattern
'pattern="%s" cannot be compiled as regular expression'
% pattern
) from exc

def _match(series: pd.Series) -> pd.Series:
Expand Down Expand Up @@ -786,7 +813,8 @@ def str_contains(cls, pattern: str, **kwargs) -> "Check":
regex = re.compile(pattern)
except TypeError as exc:
raise ValueError(
'pattern="%s" cannot be compiled as regular expression' % pattern
'pattern="%s" cannot be compiled as regular expression'
% pattern
) from exc

def _contains(series: pd.Series) -> pd.Series:
Expand Down Expand Up @@ -859,7 +887,8 @@ def str_length(
"""
if min_value is None and max_value is None:
raise ValueError(
"At least a minimum or a maximum need to be specified. Got " "None."
"At least a minimum or a maximum need to be specified. Got "
"None."
)
if max_value is None:

Expand All @@ -877,7 +906,9 @@ def _str_length(series: pd.Series) -> pd.Series:

def _str_length(series: pd.Series) -> pd.Series:
"""Check for both, minimum and maximum string length"""
return (series.str.len() <= max_value) & (series.str.len() >= min_value)
return (series.str.len() <= max_value) & (
series.str.len() >= min_value
)

return cls(
_str_length,
Expand Down

0 comments on commit 586ebf3

Please sign in to comment.