-
Notifications
You must be signed in to change notification settings - Fork 2
/
cat.py
79 lines (73 loc) · 3.11 KB
/
cat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# -*- coding: utf-8 -*-
"""
@file
@brief Jeux de données reliés aux catégories.
"""
import os
from io import StringIO, BytesIO
import pandas
from pyquickhelper.filehelper import read_content_ufs, ungzip_files
from .data_helper import get_data_folder
def load_adult_dataset(download=True, small=False, url='uci'):
    """
    Loads the
    `Adult Data Set <https://archive.ics.uci.edu/ml/datasets/adult>`_
    as a pair of dataframes.
    Most of its variables are categorical.
    Notebooks related to this dataset:

    .. runpython::
        :rst:

        from papierstat.datasets.documentation import list_notebooks_rst_links
        links = list_notebooks_rst_links('lectures', 'adult')
        links = [' * %s' % s for s in links]
        print('\\n'.join(links))

    @param      download    downloads the dataset; when False (and *small*
                            is False as well) no data source is available
                            and ``NotImplementedError`` is raised
    @param      small       reads the local copy (a lighter version) stored
                            as ``adult.data.gz`` and ``adult.test.gz`` in the
                            folder returned by ``get_data_folder()``;
                            checked before *download*
    @param      url         source: ``'uci'`` reads the files from the UCI
                            repository, any other value reads gzipped
                            copies from ``www.xavierdupre.fr``
    @return                 :epkg:`pandas:DataFrame` (train, test)
    """
    names = ['age', 'workclass', 'fnlwgt', 'education', 'education_num',
             'marital_status', 'occupation', 'relationship', 'race', 'sex',
             'capital_gain', 'capital_loss', 'hours_per_week',
             'native_country', '<=50K']
    target = '<=50K'

    if small:
        # Local gzipped copies shipped in the package data folder.
        folder = get_data_folder()
        train = pandas.read_csv(
            os.path.join(folder, 'adult.data.gz'), header=None)
        test = pandas.read_csv(
            os.path.join(folder, 'adult.test.gz'), header=None)
    elif download:
        if url == 'uci':
            base = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/"
            train = pandas.read_csv(base + "adult.data", header=None)
            # The first line of the test file is skipped before parsing.
            test = pandas.read_csv(base + "adult.test", header=None, skiprows=1)
        else:
            base = "http://www.xavierdupre.fr/enseignement/complements/"
            # (file name, minimum size passed to the downloader,
            #  extra read_csv options); train first, then test.
            sources = (
                ("adult.data.gz", 400000, {}),
                ("adult.test.gz", 200000, {'skiprows': 1}),
            )
            frames = []
            for name, min_size, options in sources:
                raw = read_content_ufs(base + name,
                                       asbytes=True, encoding=None,
                                       min_size=min_size)
                content = ungzip_files(BytesIO(raw), unzip=False)
                buffer = StringIO(content.decode('ascii'))
                frames.append(pandas.read_csv(buffer, header=None, **options))
            train, test = frames
    else:
        raise NotImplementedError(  # pragma: no cover
            "No local copy")

    # Every successful branch yields unnamed columns: name them once here.
    train.columns = names
    test.columns = names

    # Strip surrounding spaces and dots from the label column
    # in both frames.
    for frame in (train, test):
        frame[target] = frame[target].str.strip(' .')

    # Columns are selected on the training frame; the same list is then
    # applied to the test frame.
    text_columns = train.select_dtypes(object).columns
    for frame in (train, test):
        for col in text_columns:
            frame[col] = frame[col].str.strip()  # pylint: disable=E1136,E1137
    return train, test