-
Notifications
You must be signed in to change notification settings - Fork 91
/
tabular.py
165 lines (149 loc) · 7.11 KB
/
tabular.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
The class for re-sampling training data.
"""
import pandas as pd
from ..data.tabular import Tabular
class Sampler:
    """
    The class for re-sampling training data. It includes sub-sampling, under-sampling
    and over-sampling.

    Every public method returns a new ``Tabular`` built from a shuffled dataframe and
    carries over ``categorical_columns`` and ``target_column`` from the input.
    """

    @staticmethod
    def _get_categorical_values(df, categorical_columns):
        """
        Gets the categorical feature values.

        :param df: The input dataframe.
        :param categorical_columns: A list of categorical feature names.
        :return: A dict whose keys are feature names and whose values are the sets of values
            observed in the corresponding column. Empty if ``categorical_columns`` is
            ``None`` or empty.
        :rtype: Dict
        """
        if categorical_columns is None or len(categorical_columns) == 0:
            return {}
        return {col: set(df[col].values) for col in categorical_columns}

    @staticmethod
    def _find_extra_samples(df, feature_name, feature_value, n=1):
        """
        Returns a sub-dataframe whose column ``feature_name`` contains value ``feature_value``.

        :param df: The input dataframe.
        :param feature_name: The feature name.
        :param feature_value: The feature value.
        :param n: The maximum number of rows to select (the first ``n`` matches are used).
        :return: The selected rows with ``feature_name = feature_value``.
        :rtype: pd.DataFrame
        """
        return df[df[feature_name] == feature_value].head(n)

    @staticmethod
    def _add_extra_samples(original_df, sampled_df, categorical_columns):
        """
        Checks if all the categorical feature values in ``original_df`` are included in ``sampled_df``.
        If some values are missing, rows containing them are taken from ``original_df`` and
        appended to ``sampled_df``.

        Missing values are visited in order of first appearance in ``original_df``, so the
        result does not depend on set iteration order. A row added for one column also counts
        toward every other categorical column, so the same row is not appended twice.
        NaN values are skipped because they can never be matched by equality.

        :param original_df: The original dataframe.
        :param sampled_df: The sampled dataframe (via subsampling, undersampling, etc.).
        :param categorical_columns: A list of categorical feature names.
        :return: A new dataframe containing all the non-NaN categorical feature values in ``original_df``.
        :rtype: pd.DataFrame
        """
        covered = Sampler._get_categorical_values(sampled_df, categorical_columns)
        extras = []
        for col, seen in covered.items():
            for value in original_df[col].dropna().unique():
                if value in seen:
                    continue
                rows = Sampler._find_extra_samples(original_df, col, value)
                if rows.empty:
                    continue
                extras.append(rows)
                # The added rows also contribute values to the other categorical columns.
                for other, other_seen in covered.items():
                    other_seen.update(rows[other].values)
        return pd.concat([sampled_df] + extras)

    @staticmethod
    def _split_by_label(df, target_column):
        """
        Splits ``df`` into one sub-dataframe per distinct value of ``target_column``.

        Labels are visited in order of first appearance. NaN labels are skipped:
        ``df[target_column] == nan`` is never true, so they would only yield an empty split.
        An empty split would make ``undersample`` keep zero rows per class and break ``oversample``.

        :param df: The input dataframe.
        :param target_column: The name of the label column.
        :return: A dict mapping each non-NaN label to the rows carrying that label.
        :rtype: Dict
        """
        labels = df[target_column]
        return {label: df[labels == label] for label in labels.dropna().unique()}

    @staticmethod
    def subsample(tabular_data: Tabular, fraction: float, random_state=None) -> Tabular:
        """
        Samples a subset of the input dataset. It guarantees that all the categorical values
        are included in the sampled dataframe, i.e., there will be no missing categorical values.

        When ``tabular_data`` has a target column, ``fraction`` is applied to each class
        separately so that the label distribution is preserved.

        :param tabular_data: The input tabular data.
        :param fraction: The fraction of the sampled instance.
        :param random_state: The random seed.
        :return: A subset extracted from ``tabular_data``.
        :rtype: Tabular
        """
        df = tabular_data.to_pd(copy=False)
        if tabular_data.target_column is None:
            samples = df.sample(frac=fraction, random_state=random_state)
        else:
            splits = Sampler._split_by_label(df, tabular_data.target_column)
            parts = [split.sample(frac=fraction, random_state=random_state) for split in splits.values()]
            samples = pd.concat(parts) if parts else df.iloc[:0]
        # Add additional samples to make sure no categorical values are missing
        new_df = Sampler._add_extra_samples(
            original_df=df, sampled_df=samples, categorical_columns=tabular_data.categorical_columns
        )
        return Tabular(
            data=new_df.sample(frac=1, random_state=random_state),
            categorical_columns=tabular_data.categorical_columns,
            target_column=tabular_data.target_column,
        )

    @staticmethod
    def undersample(tabular_data: Tabular, random_state=None) -> Tabular:
        """
        Undersamples a class-imbalance dataset to make it more balance, i.e.,
        keeping all of the data in the minority class and decreasing the size of the majority class.
        It guarantees that all the categorical values are included in the sampled dataframe, i.e.,
        there will be no missing categorical values.

        :param tabular_data: The input tabular data.
        :param random_state: The random seed.
        :return: A subset extracted from ``tabular_data``.
        :rtype: Tabular
        """
        assert tabular_data.target_column is not None, "`tabular_data` doesn't have a target column."
        df = tabular_data.to_pd(copy=False)
        splits = Sampler._split_by_label(df, tabular_data.target_column)
        # Every class is reduced to the size of the smallest class (never larger than the split itself).
        min_count = min((len(split) for split in splits.values()), default=0)
        parts = [split.sample(n=min_count, random_state=random_state) for split in splits.values()]
        samples = pd.concat(parts) if parts else df.iloc[:0]
        # Add additional samples to make sure no categorical values are missing
        new_df = Sampler._add_extra_samples(
            original_df=df, sampled_df=samples, categorical_columns=tabular_data.categorical_columns
        )
        return Tabular(
            data=new_df.sample(frac=1, random_state=random_state),
            categorical_columns=tabular_data.categorical_columns,
            target_column=tabular_data.target_column,
        )

    @staticmethod
    def oversample(tabular_data: Tabular, random_state=None) -> Tabular:
        """
        Oversamples a class-imbalance dataset to make it more balance, i.e.,
        keeping all of the original data and increasing the size of the smaller classes
        (by duplicating randomly chosen rows) until every class matches the largest one.
        It guarantees that all the categorical values are included in the sampled dataframe, i.e.,
        there will be no missing categorical values.

        :param tabular_data: The input tabular data.
        :param random_state: The random seed.
        :return: An oversampled dataset.
        :rtype: Tabular
        """
        assert tabular_data.target_column is not None, "`tabular_data` doesn't have a target column."
        df = tabular_data.to_pd(copy=False)
        splits = Sampler._split_by_label(df, tabular_data.target_column)
        max_count = max((len(split) for split in splits.values()), default=0)
        parts = []
        for split in splits.values():
            # Keep every original row, then top the class up to ``max_count`` with
            # duplicates drawn with replacement (nothing is drawn for the largest class).
            parts.append(split)
            shortfall = max_count - len(split)
            if shortfall > 0:
                parts.append(split.sample(n=shortfall, replace=True, random_state=random_state))
        samples = pd.concat(parts) if parts else df.iloc[:0]
        # Only rows with a NaN label are excluded above; restore any categorical value
        # that appears exclusively in those rows, consistent with the other samplers.
        new_df = Sampler._add_extra_samples(
            original_df=df, sampled_df=samples, categorical_columns=tabular_data.categorical_columns
        )
        return Tabular(
            data=new_df.sample(frac=1, random_state=random_state),
            categorical_columns=tabular_data.categorical_columns,
            target_column=tabular_data.target_column,
        )