Skip to content

Commit

Permalink
bugfix: samples attr empty implies one sample ttest (#214)
Browse files Browse the repository at this point in the history
* bugfix: samples attr empty implies one sample ttest

* add tests for hypothesis bugfix
  • Loading branch information
cosmicBboy committed May 30, 2020
1 parent f5051fa commit 0b0d6cc
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 14 deletions.
6 changes: 4 additions & 2 deletions pandera/error_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ def __init__(self, lazy: bool) -> None:
self._lazy = lazy
self._collected_errors = [] # type: ignore

def collect_error(self, reason_code: str, schema_error: SchemaError):
def collect_error(
self, reason_code: str, schema_error: SchemaError,
original_exc: BaseException = None):
"""Collect schema error, raising exception if lazy is False.
:param reason_code: string representing reason for error
:param schema_error: ``SchemaError`` object.
"""
if not self._lazy:
raise schema_error
raise schema_error from original_exc

# delete data of validated object from SchemaError object to prevent
# storing copies of the validated DataFrame/Series for every
Expand Down
35 changes: 24 additions & 11 deletions pandera/hypotheses.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(
@property
def is_one_sample_test(self):
"""Return True if hypothesis is a one-sample test."""
return len(self.samples) == 1
return len(self.samples) <= 1

def _prepare_series_input(
self,
Expand Down Expand Up @@ -202,16 +202,16 @@ def _relationships(self, relationship: Union[str, Callable]):
)
return relationship

def _hypothesis_check(self, check_obj: Dict[str, pd.Series]):
def _hypothesis_check(
self, check_obj: Union[pd.Series, Dict[str, pd.Series]]
):
"""Create a function fn which is checked via the Check parent class.
:param dict check_obj: a dictionary of pd.Series to be used by
`_hypothesis_check` and `_vectorized_check`
"""
if self.is_one_sample_test:
# one-sample case where no groupby argument supplied, apply to
# entire column
if isinstance(check_obj, pd.Series):
return self.relationship(*self.test(check_obj))
return self.relationship(
*self.test(*[check_obj.get(s) for s in self.samples]))
Expand Down Expand Up @@ -330,18 +330,31 @@ def two_sample_ttest(
@classmethod
def one_sample_ttest(
cls,
sample: str,
popmean: float,
relationship: str,
sample: Optional[str] = None,
groupby: Union[str, List[str], Callable, None] = None,
relationship: str = "equal",
alpha: float = DEFAULT_ALPHA,
raise_warning=False,
):
"""Calculate a t-test for the mean of one sample.
:param sample: The sample group to test. For `Column` and
`SeriesSchema` hypotheses, refers to the `groupby` level in the
`Column`. For `DataFrameSchema` hypotheses, refers to column in
the `DataFrame`.
`SeriesSchema` hypotheses, this refers to the `groupby` level that
is used to subset the `Column` being checked. For `DataFrameSchema`
hypotheses, refers to column in the `DataFrame`.
:param groupby: If a string or list of strings is provided, then these
columns are used to group the Column Series by `groupby`. If a
callable is passed, the expected signature is
DataFrame -> DataFrameGroupby. The function has access to the
entire dataframe, but the Column.name is selected from this
DataFrameGroupby object so that a SeriesGroupBy object is passed
into `fn`.
Specifying this argument changes the `fn` signature to:
dict[str|tuple[str], Series] -> bool|pd.Series[bool]
Where specific groups can be obtained from the input dict.
:param popmean: population mean to compare `sample` to.
:param relationship: Represents what relationship conditions are
imposed on the hypothesis test. Available relationships
Expand Down Expand Up @@ -369,7 +382,6 @@ def one_sample_ttest(
... "height_in_feet": pa.Column(
... pa.Float, [
... pa.Hypothesis.one_sample_ttest(
... sample="height_in_feet",
... popmean=5,
... relationship="greater_than",
... alpha=0.1),
Expand All @@ -396,6 +408,7 @@ def one_sample_ttest(
return cls(
test=stats.ttest_1samp,
samples=sample,
groupby=groupby,
relationship=relationship,
test_kwargs={"popmean": popmean},
relationship_kwargs={"alpha": alpha},
Expand Down
3 changes: 2 additions & 1 deletion pandera/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,8 @@ def validate(
failure_cases=scalar_failure_case(err_str),
check=check,
check_index=check_index
)
),
original_exc=err
)

if lazy and error_handler.collected_errors:
Expand Down
37 changes: 37 additions & 0 deletions tests/test_hypotheses.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,40 @@ def test_two_sample_ttest_hypothesis_relationships():
]),
"sex": Column(String)
})


def test_one_sample_hypothesis():
    """Validate one-sample t-tests on a whole column and on a subset.

    Two schemas are built against the same data: one whose hypothesis is
    given no ``sample`` (so it is treated as a one-sample test over the
    column itself) and one that selects ``sample="A"`` after grouping by
    the ``"group"`` column. Both validations are expected to pass without
    raising.
    """
    data = pd.DataFrame({
        "height_in_feet": [8.1, 7, 6.5, 6.7, 5.1],
        "group": ["A", "A", "B", "B", "A"],
    })

    # Settings shared by both hypotheses under test.
    ttest_kwargs = dict(popmean=5, relationship="greater_than", alpha=0.1)

    # No ``sample``/``groupby``: the check is applied to the entire column.
    whole_column_check = Hypothesis.one_sample_ttest(**ttest_kwargs)
    # ``sample`` names the ``groupby`` level used to subset the column.
    group_a_check = Hypothesis.one_sample_ttest(
        sample="A",
        groupby="group",
        **ttest_kwargs
    )

    schema = DataFrameSchema({
        "height_in_feet": Column(Float, [whole_column_check]),
    })
    subset_schema = DataFrameSchema({
        "group": Column(String),
        "height_in_feet": Column(Float, [group_a_check]),
    })

    schema.validate(data)
    subset_schema.validate(data)

0 comments on commit 0b0d6cc

Please sign in to comment.