In [1]:
import copy

from hypex.dataset import Dataset, InfoRole, TreatmentRole, TargetRole
from hypex.transformers.filters import NanFilter, CorrFilter, ConstFilter, CVFilter, OutliersFilter
from hypex.transformers.category_agg import CategoryAggregator


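# Test dataset creation

Load `data.csv` into a hypex `Dataset`, declaring the info (`user_id`), treatment (`treat`) and target (`post_spends`) roles; the remaining columns receive default roles.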

In [2]:
# Only three columns get an explicit role here; every other column in the CSV
# is left for hypex to assign (the role listings further down show them as Feature).
DATA_PATH = "data.csv"  # relative to the notebook's working directory
dataset_roles = {
    "user_id": InfoRole(),
    "treat": TreatmentRole(),
    "post_spends": TargetRole(),
}
data = Dataset(roles=dataset_roles, data=DATA_PATH)
data

      user_id  signup_month  treat  pre_spends  post_spends   age gender  \
0           0             0      0       488.0   414.444444   NaN      M   
1           1             8      1       512.5   462.222222  26.0    NaN   
2           2             7      1       483.0   479.444444  25.0      M   
3           3             0      0       501.5   424.333333  39.0      M   
4           4             1      1       543.0   514.555556  18.0      F   
...       ...           ...    ...         ...          ...   ...    ...   
9995     9995            10      1       538.5   450.444444  42.0      M   
9996     9996             0      0       500.5   430.888889  26.0      F   
9997     9997             3      1       473.0   534.111111  22.0      F   
9998     9998             2      1       495.0   523.222222  67.0      F   
9999     9999             7      1       508.0   475.888889  38.0      F   

        industry  
0     E-commerce  
1     E-commerce  
2      Logistics  
3     E-com

## Category aggregator

In [3]:
# Pass a deep copy so the shared `data` object is not affected if `calc` mutates its input.
# `category_threshold` is presumably the minimum count a category needs to keep its own
# label; in the output, some `signup_month` values are relabelled "99".
# NOTE(review): `new_group_name` is the string "99" while `signup_month` holds ints,
# so the column may end up with mixed int/str values; confirm this is intended.
category_threshold = 450
ca_data = CategoryAggregator.calc(
    copy.deepcopy(data),
    target_cols=["signup_month"],
    threshold=category_threshold,
    new_group_name="99",
)
ca_data.data


Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,99,1,543.0,514.555556,18.0,F,E-commerce
...,...,...,...,...,...,...,...,...
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics
9997,9997,99,1,473.0,534.111111,22.0,F,E-commerce
9998,9998,2,1,495.0,523.222222,67.0,F,E-commerce


## NaN filter

In [4]:
# Run the NaN filter on an independent copy of the dataset.
# threshold=0.05 is presumably the tolerated share of missing values. In the role
# listing below, both `age` and `gender` end up as Info(None).
nan_data = NanFilter.calc(
    copy.deepcopy(data),
    target_cols=["age", "gender"],
    threshold=0.05,
)
nan_data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Info(None),
 'gender': Info(None),
 'industry': Feature(<class 'str'>)}

## Outliers filter

In [5]:
# Percentile bounds for `post_spends`. Judging by the output, rows outside them
# are dropped from the returned dataset.
outlier_bounds = {"lower_percentile": 0.05, "upper_percentile": 0.95}
outliers_data = OutliersFilter.calc(
    copy.deepcopy(data), target_cols=["post_spends"], **outlier_bounds
)
outliers_data.data

Unnamed: 0,user_id,signup_month,treat,pre_spends,post_spends,age,gender,industry
0,0,0,0,488.0,414.444444,,M,E-commerce
1,1,8,1,512.5,462.222222,26.0,,E-commerce
2,2,7,1,483.0,479.444444,25.0,M,Logistics
3,3,0,0,501.5,424.333333,39.0,M,E-commerce
4,4,1,1,543.0,514.555556,18.0,F,E-commerce
...,...,...,...,...,...,...,...,...
9993,9993,5,1,462.0,509.888889,65.0,F,E-commerce
9994,9994,0,0,486.0,423.777778,69.0,F,Logistics
9995,9995,10,1,538.5,450.444444,42.0,M,Logistics
9996,9996,0,0,500.5,430.888889,26.0,F,Logistics


## Correlation filter

In [6]:
# Columns the filter may drop, plus the wider set they are correlated against.
# NOTE(review): `gender` is a string column and `post_spends` is the target. Confirm
# CorrFilter handles non-numeric columns meaningfully, and that correlating against
# the target is intended.
corr_candidates = ["age", "gender"]
corr_space = ["pre_spends", "post_spends", *corr_candidates]
corr_data = CorrFilter.calc(
    copy.deepcopy(data),
    target_cols=corr_candidates,
    corr_space_cols=corr_space,
)
corr_data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>)}

## Constant filter

In [7]:
# Constant-column check on `gender`. In the role listing below, `gender` becomes Info(None).
const_threshold = 0.4
const_data = ConstFilter.calc(copy.deepcopy(data), target_cols=["gender"], threshold=const_threshold)
const_data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Target(<class 'float'>),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Info(None),
 'industry': Feature(<class 'str'>)}

## CV filter

In [8]:
# CV filter on `post_spends` with an upper bound of 0.05, applied to an independent copy.
# NOTE(review): the output shows the Target column `post_spends` demoted to Info(None).
# Confirm that filtering the target itself is what this demo intends.
cv_data = CVFilter.calc(
    copy.deepcopy(data),
    target_cols=["post_spends"],
    upper_bound=0.05,
)
cv_data.roles

{'user_id': Info(<class 'int'>),
 'treat': Treatment(<class 'int'>),
 'post_spends': Info(None),
 'signup_month': Feature(<class 'int'>),
 'pre_spends': Feature(<class 'float'>),
 'age': Feature(<class 'float'>),
 'gender': Feature(<class 'str'>),
 'industry': Feature(<class 'str'>)}