New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH Add parameter as_frame to fetch_covtype #17491
Changes from 14 commits
322ea45
c5f17eb
9ab6f5c
24df8fa
5a34ebf
04feec5
2413b3b
f8ff415
0134c73
7e06f8f
8fc854e
6729f9a
7b77c03
f9ee162
5ecc820
75c5e8c
335b4bf
85718f3
93053f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
import joblib | ||
|
||
from . import get_data_home | ||
from ._base import _convert_data_dataframe | ||
from ._base import _fetch_remote | ||
from ._base import RemoteFileMetadata | ||
from ..utils import Bunch | ||
|
@@ -41,10 +42,27 @@ | |
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Column names reference: | ||
# https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.info | ||
FEATURE_NAMES = ["Elevation", | ||
"Aspect", | ||
"Slope", | ||
"Horizontal_Distance_To_Hydrology", | ||
"Vertical_Distance_To_Hydrology", | ||
"Horizontal_Distance_To_Roadways", | ||
"Hillshade_9am", | ||
"Hillshade_Noon", | ||
"Hillshade_3pm", | ||
"Horizontal_Distance_To_Fire_Points"] | ||
FEATURE_NAMES += [f"Wilderness_Area_{i}" for i in range(4)] | ||
FEATURE_NAMES += [f"Soil_Type_{i}" for i in range(40)] | ||
TARGET_NAMES = ["Cover_Type"] | ||
|
||
|
||
@_deprecate_positional_args | ||
def fetch_covtype(*, data_home=None, download_if_missing=True, | ||
random_state=None, shuffle=False, return_X_y=False): | ||
random_state=None, shuffle=False, return_X_y=False, | ||
as_frame=False): | ||
"""Load the covertype dataset (classification). | ||
|
||
Download it if necessary. | ||
|
@@ -82,6 +100,15 @@ def fetch_covtype(*, data_home=None, download_if_missing=True, | |
|
||
.. versionadded:: 0.20 | ||
|
||
as_frame : bool, default=False | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the part that's missing in the other PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Got it, thanks |
||
If True, the data is a pandas DataFrame including columns with | ||
appropriate dtypes (numeric). The target is a pandas DataFrame or | ||
Series depending on the number of target columns. If `return_X_y` is | ||
True, then (`data`, `target`) will be pandas DataFrames or Series as | ||
described below. | ||
|
||
.. versionadded:: 0.24 | ||
|
||
Returns | ||
------- | ||
dataset : :class:`~sklearn.utils.Bunch` | ||
|
@@ -93,12 +120,19 @@ def fetch_covtype(*, data_home=None, download_if_missing=True, | |
Each value corresponds to one of | ||
the 7 forest covertypes with values | ||
ranging between 1 to 7. | ||
frame : dataframe of shape (581012, 55) | ||
Only present when `as_frame=True`. Contains `data` and `target`. | ||
DESCR : str | ||
Description of the forest covertype dataset. | ||
feature_names : list | ||
The names of the dataset columns | ||
target_names : list | ||
The names of the target columns | ||
|
||
(data, target) : tuple if ``return_X_y`` is True | ||
|
||
.. versionadded:: 0.20 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This shouldn't be changed, I think? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. changing it back |
||
|
||
""" | ||
|
||
data_home = get_data_home(data_home=data_home) | ||
|
@@ -142,7 +176,19 @@ def fetch_covtype(*, data_home=None, download_if_missing=True, | |
with open(join(module_path, 'descr', 'covtype.rst')) as rst_file: | ||
fdescr = rst_file.read() | ||
|
||
frame = None | ||
if as_frame: | ||
frame, X, y = _convert_data_dataframe(caller_name="fetch_covtype", | ||
data=X, | ||
target=y, | ||
feature_names=FEATURE_NAMES, | ||
target_names=TARGET_NAMES) | ||
if return_X_y: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. adjusted accordingly, thanks for pointing it out. |
||
return X, y | ||
|
||
return Bunch(data=X, target=y, DESCR=fdescr) | ||
return Bunch(data=X, | ||
target=y, | ||
frame=frame, | ||
target_names=TARGET_NAMES, | ||
feature_names=FEATURE_NAMES, | ||
DESCR=fdescr) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
"""Test the covtype loader, if the data is available, | ||
or if specifically requested via environment variable | ||
(e.g. for travis cron job).""" | ||
|
||
from sklearn.datasets.tests.test_common import check_return_X_y | ||
from functools import partial | ||
import pytest | ||
from sklearn.datasets.tests.test_common import check_return_X_y | ||
|
||
|
||
def test_fetch(fetch_covtype_fxt): | ||
|
@@ -23,3 +23,24 @@ def test_fetch(fetch_covtype_fxt): | |
# test return_X_y option | ||
fetch_func = partial(fetch_covtype_fxt) | ||
check_return_X_y(data1, fetch_func) | ||
|
||
|
||
def test_fetch_asframe(fetch_covtype_fxt): | ||
bunch = fetch_covtype_fxt(as_frame=True) | ||
assert hasattr(bunch, 'frame') | ||
frame = bunch.frame | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The assert below will not be useful because the test will already crash here if frame does not exist There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. DONE. |
||
assert frame.shape == (581012, 55) | ||
|
||
column_names = set(frame.columns) | ||
|
||
# enumerated names are added correctly | ||
assert set(f"Wilderness_Area_{i}" for i in range(4)) < column_names | ||
assert set(f"Soil_Type_{i}" for i in range(40)) < column_names | ||
|
||
|
||
def test_pandas_dependency_message(fetch_covtype_fxt, | ||
hide_available_pandas): | ||
expected_msg = ('fetch_covtype with as_frame=True' | ||
' requires pandas') | ||
with pytest.raises(ImportError, match=expected_msg): | ||
fetch_covtype_fxt(as_frame=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this should be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I will add it back, thanks for pointing it out.