Skip to content

Commit

Permalink
Relax reconstructor checks. Add test for simple lists.
Browse files Browse the repository at this point in the history
  • Loading branch information
chkoar committed Feb 3, 2020
1 parent 92dab47 commit 35d7af9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
19 changes: 5 additions & 14 deletions imblearn/utils/_validation.py
Expand Up @@ -49,24 +49,15 @@ def _gets_props(self, array):

def _transfrom(self, array, props):
type_ = props["type"].lower()
msg = "Could not convert to {}".format(type_)
if type_ == "list":
ret = array.tolist()
elif type_ == "dataframe":
try:
import pandas as pd
ret = pd.DataFrame(array, columns=props["columns"])
ret = ret.astype(props["dtypes"])
except Exception:
warnings.warn(msg)
import pandas as pd
ret = pd.DataFrame(array, columns=props["columns"])
ret = ret.astype(props["dtypes"])
elif type_ == "series":
try:
import pandas as pd
ret = pd.Series(array,
dtype=props["dtypes"],
name=props["name"])
except Exception:
warnings.warn(msg)
import pandas as pd
ret = pd.Series(array, dtype=props["dtypes"], name=props["name"])
else:
ret = array
return ret
Expand Down
32 changes: 31 additions & 1 deletion imblearn/utils/estimator_checks.py
Expand Up @@ -258,7 +258,7 @@ def check_samplers_pandas(name, Sampler):
X_res_df, y_res_df = sampler.fit_resample(X_df, y_df)
X_res, y_res = sampler.fit_resample(X, y)

# check that we return the same type for dataframes or seires types
# check that we return the same type for dataframes or series types
assert isinstance(X_res_df, pd.DataFrame)
assert isinstance(y_res_df, pd.DataFrame)
assert isinstance(y_res_s, pd.Series)
Expand All @@ -272,6 +272,36 @@ def check_samplers_pandas(name, Sampler):
assert_allclose(y_res_s.to_numpy(), y_res)


def check_samplers_list(name, Sampler):
# Check that the can samplers handle simple lists
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X_list = X.tolist()
y_list = y.tolist()
sampler = Sampler()
if isinstance(Sampler(), NearMiss):
samplers = [Sampler(version=version) for version in (1, 2, 3)]

else:
samplers = [Sampler()]

for sampler in samplers:
set_random_state(sampler)
X_res, y_res = sampler.fit_resample(X, y)
X_res_list, y_res_list = sampler.fit_resample(X_list, y_list)

assert isinstance(X_res_list, list)
assert isinstance(y_res_list, list)

assert_allclose(X_res, X_res_list)
assert_allclose(y_res, y_res_list)


def check_samplers_multiclass_ova(name, Sampler):
# Check that multiclass target lead to the same results than OVA encoding
X, y = make_classification(
Expand Down

0 comments on commit 35d7af9

Please sign in to comment.