In [144]:
from sklearn.base import BaseEstimator, TransformerMixin,OneToOneFeatureMixin
from sklearn.datasets import load_diabetes
import pandas as pd
from typing import List
from sklearn.utils import check_X_y
from sklearn import utils
import numpy as np

# TODO: compatibilizar con numpy y no solo con pandas.
# x_, y_ = check_X_y(X, y)
# utils._get_column_indices(X,'bmi')
# col_idx = utils._get_column_indices(X,'bmi')
# utils._safe_indexing(x_, indices=col_idx, axis=1)

class QuantileFeatureEncoder(BaseEstimator, TransformerMixin, OneToOneFeatureMixin):
    def __init__(self, col:List[str], qs=[0.25,0.5,0.75]):
        self.col = col
        self.qs = qs

    def fit(self, X:pd.DataFrame, y:pd.Series):
        
        if not isinstance(X, pd.DataFrame):
            return Exception("X must be of type pd.DataFrame")
        if not isinstance(y, pd.Series):
            return Exception("y must be of type pd.Series")
        
        category_means_ = pd.concat([X[self.col], y], axis=1).groupby(self.col)[y.name].mean()

        qs_ = [category_means_.quantile(q) for q in self.qs]

        def encode_qs(x, qs):
            for i,q in enumerate(qs):
                if x < q:
                    return i
            return len(qs)+1
            
        self.category_encodings_ = category_means_.apply(lambda x: encode_qs(x, qs_)).to_dict()

        return self
    
    def transform(self, X):
        X = X.copy()
        X[self.col] = X[self.col].map(self.category_encodings_)
        return X
    

class ThresholdFeatureEncoder(BaseEstimator, TransformerMixin,OneToOneFeatureMixin):
    methods = ['mean','median']

    def __init__(self, col, method='mean'):
        self.col = col
        if method not in self.methods:
            raise ValueError(f"method must be one of {self.methods}")
        self.method = method

    def fit(self, X, y):
        
        if not isinstance(X, pd.DataFrame):
            return Exception("X must be of type pd.DataFrame")
        if not isinstance(y, pd.Series):
            return Exception("y must be of type pd.Series")
        
        if self.method == 'mean':
            category_measure_ = pd.concat([X[self.col], y], axis=1).groupby(self.col)[y.name].mean()
            target_measure_ = np.mean(y)
        elif self.method == 'median':
            category_measure_ = pd.concat([X[self.col], y], axis=1).groupby(self.col)[y.name].median()
            target_measure_ = np.median(y)
        
        self.category_encodings_ = category_measure_.apply(lambda x: 0 if x < target_measure_ else 1).to_dict()

        return self 
    
    def transform(self, X):
        X = X.copy()
        X[self.col] = X[self.col].map(self.category_encodings_)
        return X

In [136]:
X, y = load_diabetes(as_frame=True, return_X_y=True)

# creamos una categoria para probar el transformer
X['bmi'] = pd.cut(X['bmi'], bins=15, labels=[f"cat_{i}" for i in range(15)])


In [138]:
qfe = QuantileFeatureEncoder(col='bmi')
qfe

In [139]:
qfe.fit(X, y)

In [140]:
X_trans = qfe.transform(X)

In [141]:
X_trans.bmi.value_counts()

0    168
1    153
2     98
4     23
Name: bmi, dtype: int64

In [142]:
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LinearRegression

m = make_pipeline(QuantileFeatureEncoder('bmi'),LinearRegression())
m.fit(X, y)

In [145]:
fe = ThresholdFeatureEncoder(col='bmi')
fe

In [146]:
fe.fit(X, y)

In [147]:
fe.transform(X)

Unnamed: 0,age,sex,bmi,bp,s1,s2,s3,s4,s5,s6
0,0.038076,0.050680,1,0.021872,-0.044223,-0.034821,-0.043401,-0.002592,0.019907,-0.017646
1,-0.001882,-0.044642,0,-0.026328,-0.008449,-0.019163,0.074412,-0.039493,-0.068332,-0.092204
2,0.085299,0.050680,1,-0.005670,-0.045599,-0.034194,-0.032356,-0.002592,0.002861,-0.025930
3,-0.089063,-0.044642,0,-0.036656,0.012191,0.024991,-0.036038,0.034309,0.022688,-0.009362
4,0.005383,-0.044642,0,0.021872,0.003935,0.015596,0.008142,-0.002592,-0.031988,-0.046641
...,...,...,...,...,...,...,...,...,...,...
437,0.041708,0.050680,1,0.059744,-0.005697,-0.002566,-0.028674,-0.002592,0.031193,0.007207
438,-0.005515,0.050680,0,-0.067642,0.049341,0.079165,-0.028674,0.034309,-0.018114,0.044485
439,0.041708,0.050680,0,0.017293,-0.037344,-0.013840,-0.024993,-0.011080,-0.046883,0.015491
440,-0.045472,-0.044642,1,0.001215,0.016318,0.015283,-0.028674,0.026560,0.044529,-0.025930
