ENH Add pandas dataframe support to fetch_openml (#13902)
thomasjpfan authored and glemaitre committed Jul 12, 2019
1 parent da66111 commit cf3e303
Showing 9 changed files with 635 additions and 65 deletions.
1 change: 0 additions & 1 deletion .circleci/config.yml
@@ -12,7 +12,6 @@ jobs:
- PYTHON_VERSION: 3.5
- NUMPY_VERSION: 1.11.0
- SCIPY_VERSION: 0.17.0
-      - PANDAS_VERSION: 0.18.0
- MATPLOTLIB_VERSION: 1.5.1
- SCIKIT_IMAGE_VERSION: 0.12.3
steps:
1 change: 1 addition & 0 deletions azure-pipelines.yml
@@ -20,6 +20,7 @@ jobs:
INSTALL_MKL: 'false'
NUMPY_VERSION: '1.11.0'
SCIPY_VERSION: '0.17.0'
+      PANDAS_VERSION: '*'
CYTHON_VERSION: '*'
PILLOW_VERSION: '4.0.0'
MATPLOTLIB_VERSION: '1.5.1'
6 changes: 6 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -43,6 +43,12 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.
:mod:`sklearn.datasets`
.......................

- |Feature| :func:`datasets.fetch_openml` now supports heterogeneous data using pandas
  by setting `as_frame=True` (see the usage sketch below). :pr:`13902` by `Thomas Fan`_.
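A minimal usage sketch of the new option (using data_id=40945, the Titanic dataset from the example updated in this commit; the printed dtypes are illustrative):

from sklearn.datasets import fetch_openml

# Fetch a heterogeneous dataset as pandas objects (Titanic, data_id=40945).
titanic = fetch_openml(data_id=40945, as_frame=True)
X = titanic.data       # pandas DataFrame of features
y = titanic.target     # pandas Series, since there is one target column
frame = titanic.frame  # single DataFrame holding both data and target
print(X.dtypes)        # numeric columns as float64, nominal ones as category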

:mod:`sklearn.decomposition`
............................

15 changes: 8 additions & 7 deletions examples/compose/plot_column_transformer_mixed_types.py
@@ -24,10 +24,10 @@
#
# License: BSD 3 clause

import numpy as np

from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
@@ -37,9 +37,13 @@
np.random.seed(0)

# Read data from Titanic dataset.
titanic = fetch_openml(data_id=40945, as_frame=True)
X = titanic.data
y = titanic.target

# Alternatively X and y can be obtained directly from the frame attribute:
# X = titanic.frame.drop('survived', axis=1)
# y = titanic.frame['survived']

# We will train our classifier with the following features:
# Numeric Features:
@@ -71,9 +75,6 @@
clf = Pipeline(steps=[('preprocessor', preprocessor),
('classifier', LogisticRegression())])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

clf.fit(X_train, y_train)
212 changes: 158 additions & 54 deletions sklearn/datasets/openml.py
@@ -8,6 +8,7 @@
from functools import wraps
import itertools
from collections.abc import Generator
from collections import OrderedDict

from urllib.request import urlopen, Request

@@ -18,6 +19,9 @@
from .base import get_data_home
from urllib.error import HTTPError
from ..utils import Bunch
from ..utils import get_chunk_n_rows
from ..utils import _chunk_generator
from ..utils import check_pandas_support # noqa

__all__ = ['fetch_openml']

@@ -263,6 +267,69 @@ def _convert_arff_data(arff_data, col_slice_x, col_slice_y, shape=None):
raise ValueError('Unexpected Data Type obtained from arff.')


def _feature_to_dtype(feature):
"""Map feature to dtype for pandas DataFrame
"""
if feature['data_type'] == 'string':
return object
elif feature['data_type'] == 'nominal':
return 'category'
# only numeric, integer, real are left
elif (feature['number_of_missing_values'] != '0' or
feature['data_type'] in ['numeric', 'real']):
# cast to floats when there are any missing values
return np.float64
elif feature['data_type'] == 'integer':
return np.int64
raise ValueError('Unsupported feature: {}'.format(feature))
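For illustration, _feature_to_dtype can be exercised on hand-written feature dicts shaped like the OpenML feature metadata used above (hypothetical entries, not real OpenML responses; assumes numpy is imported as np, as in this module):

import numpy as np

# Hypothetical feature descriptions with the keys the function reads.
nominal = {'data_type': 'nominal', 'number_of_missing_values': '0'}
integer = {'data_type': 'integer', 'number_of_missing_values': '0'}
with_missing = {'data_type': 'integer', 'number_of_missing_values': '3'}

assert _feature_to_dtype(nominal) == 'category'
assert _feature_to_dtype(integer) is np.int64
# any missing values force a float dtype so they can be stored as NaN
assert _feature_to_dtype(with_missing) is np.float64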


def _convert_arff_data_dataframe(arff, columns, features_dict):
    """Convert the ARFF object into a pandas DataFrame.

    Parameters
    ----------
    arff : dict
        As obtained from liac-arff object.

    columns : list
        Columns from dataframe to return.

    features_dict : dict
        Maps feature name to feature info from openml.

    Returns
    -------
    dataframe : pandas DataFrame
    """
    pd = check_pandas_support('fetch_openml with as_frame=True')

    attributes = OrderedDict(arff['attributes'])
    arff_columns = list(attributes)

    # calculate chunksize from the memory footprint of the first row
    first_row = next(arff['data'])
    first_df = pd.DataFrame([first_row], columns=arff_columns)

    row_bytes = first_df.memory_usage(deep=True).sum()
    chunksize = get_chunk_n_rows(row_bytes)

    # read the ARFF data in chunks to bound memory usage
    columns_to_keep = [col for col in arff_columns if col in columns]
    dfs = [first_df[columns_to_keep]]
    for data in _chunk_generator(arff['data'], chunksize):
        dfs.append(pd.DataFrame(data, columns=arff_columns)[columns_to_keep])
    df = pd.concat(dfs)

    for column in columns_to_keep:
        dtype = _feature_to_dtype(features_dict[column])
        if dtype == 'category':
            dtype = pd.api.types.CategoricalDtype(attributes[column])
        df[column] = df[column].astype(dtype, copy=False)
    return df
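The chunked conversion above depends on _chunk_generator, whose implementation is not part of this diff; a standalone sketch of the assumed behavior, using itertools:

from itertools import islice

def chunk_generator(gen, chunksize):
    # yield successive lists of at most `chunksize` items from `gen`
    while True:
        chunk = list(islice(gen, chunksize))
        if not chunk:
            return
        yield chunk

rows = iter(range(10))
assert [len(c) for c in chunk_generator(rows, 4)] == [4, 4, 2]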


def _get_data_info_by_name(name, version, data_home):
"""
Utilizes the openml dataset listing api to find a dataset by
@@ -436,7 +503,8 @@ def _valid_data_column_names(features_list, target_columns):


def fetch_openml(name=None, version='active', data_id=None, data_home=None,
                 target_column='default-target', cache=True, return_X_y=False,
                 as_frame=False):
    """Fetch dataset from openml by name or dataset id.

    Datasets are uniquely identified by either an integer ID or by a
@@ -489,26 +557,39 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
        If True, returns ``(data, target)`` instead of a Bunch object. See
        below for more information about the `data` and `target` objects.

    as_frame : boolean, default=False
        If True, the data is a pandas DataFrame including columns with
        appropriate dtypes (numeric, string or categorical). The target is
        a pandas DataFrame or Series depending on the number of target
        columns. The Bunch will contain a ``frame`` attribute with the
        target and the data. If ``return_X_y`` is True, then
        ``(data, target)`` will be pandas DataFrames or Series as described
        above.

    Returns
    -------

    data : Bunch
        Dictionary-like object, with attributes:

        data : np.array, scipy.sparse.csr_matrix of floats, or pandas DataFrame
            The feature matrix. Categorical features are encoded as ordinals.
        target : np.array, pandas Series or DataFrame
            The regression target or classification labels, if applicable.
            Dtype is float if numeric, and object if categorical. If
            ``as_frame`` is True, ``target`` is a pandas object.
        DESCR : str
            The full description of the dataset.
        feature_names : list
            The names of the dataset columns.
        categories : dict or None
            Maps each categorical feature name to a list of values, such
            that the value encoded as i is the ith in the list. If
            ``as_frame`` is True, this is None.
        details : dict
            More metadata from OpenML.
        frame : pandas DataFrame
            Only present when ``as_frame=True``. DataFrame with ``data`` and
            ``target``.

    (data, target) : tuple if ``return_X_y`` is True
@@ -568,41 +649,52 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
warn("OpenML raised a warning on the dataset. It might be "
"unusable. Warning: {}".format(data_description['warning']))

return_sparse = False
if data_description['format'].lower() == 'sparse_arff':
return_sparse = True

if as_frame and return_sparse:
raise ValueError('Cannot return dataframe with sparse data')

# download data features, meta-info about column types
features_list = _get_data_features(data_id, data_home)

    if not as_frame:
        for feature in features_list:
            if 'true' in (feature['is_ignore'], feature['is_row_identifier']):
                continue
            if feature['data_type'] == 'string':
                raise ValueError('STRING attributes are not supported for '
                                 'array representation. Try as_frame=True')

    if target_column == "default-target":
        # determines the default target based on the data feature results
        # (which is currently more reliable than the data description;
        # see issue: https://github.com/openml/OpenML/issues/768)
        target_columns = [feature['name'] for feature in features_list
                          if feature['is_target'] == 'true']
    elif isinstance(target_column, str):
        # for code-simplicity, make target_column by default a list
        target_columns = [target_column]
    elif target_column is None:
        target_columns = []
    elif isinstance(target_column, list):
        target_columns = target_column
    else:
        raise TypeError("Did not recognize type of target_column. "
                        "Should be str, list or None. Got: "
                        "{}".format(type(target_column)))
    data_columns = _valid_data_column_names(features_list,
                                            target_columns)

# prepare which columns and data types should be returned for the X and y
features_dict = {feature['name']: feature for feature in features_list}

# XXX: col_slice_y should be all nominal or all numeric
    _verify_target_data_type(features_dict, target_columns)

    col_slice_y = [int(features_dict[col_name]['index'])
                   for col_name in target_columns]

col_slice_x = [int(features_dict[col_name]['index'])
for col_name in data_columns]
@@ -615,10 +707,6 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
'columns. '.format(feat['name'], nr_missing))

if not return_sparse:
data_qualities = _get_data_qualities(data_id, data_home)
shape = _get_data_shape(data_qualities)
@@ -631,46 +719,62 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,

# obtain the data
arff = _download_data_arff(data_description['file_id'], return_sparse,
                               data_home, encode_nominal=not as_frame)

    description = "{}\n\nDownloaded from openml.org.".format(
        data_description.pop('description'))

nominal_attributes = None
frame = None
if as_frame:
columns = data_columns + target_columns
frame = _convert_arff_data_dataframe(arff, columns, features_dict)
X = frame[data_columns]
if len(target_columns) >= 2:
y = frame[target_columns]
elif len(target_columns) == 1:
y = frame[target_columns[0]]
else:
y = None
else:
        # nominal_attributes is a dict mapping from the attribute name to
        # the possible values. It also includes the target column, which
        # will be popped off below, before it is packed in the Bunch object.
nominal_attributes = {k: v for k, v in arff['attributes']
if isinstance(v, list) and
k in data_columns + target_columns}

X, y = _convert_arff_data(arff['data'], col_slice_x,
col_slice_y, shape)

is_classification = {col_name in nominal_attributes
for col_name in target_columns}
if not is_classification:
# No target
pass
elif all(is_classification):
y = np.hstack([
np.take(
np.asarray(nominal_attributes.pop(col_name), dtype='O'),
y[:, i:i + 1].astype(int, copy=False))
for i, col_name in enumerate(target_columns)
])
elif any(is_classification):
raise ValueError('Mix of nominal and non-nominal targets is not '
'currently supported')

        # reshape y back to 1-D array, if there is only 1 target column;
        # back to None if there are no target columns
if y.shape[1] == 1:
y = y.reshape((-1,))
elif y.shape[1] == 0:
y = None

if return_X_y:
return X, y

bunch = Bunch(
        data=X, target=y, frame=frame, feature_names=data_columns,
DESCR=description, details=data_description,
categories=nominal_attributes,
url="https://www.openml.org/d/{}".format(data_id))
Binary file not shown.
Binary file not shown.
