# Exercise: implementing a `TableVectorizer` from its components
Replicate the behavior of a `TableVectorizer` using `ApplyToCols`, the skrub 
selectors, and the given transformers. 

In [1]:
from skrub import Cleaner, ApplyToCols, StringEncoder, DatetimeEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import make_pipeline
import skrub.selectors as s

Notes on the implementation: 

- In the first step, the TableVectorizer cleans the data to parse datetimes and other
dtypes.
- Numeric features are left untouched, i.e., they use the `PassThrough` transformer provided below. 
- String and categorical features are split into high and low cardinality features. 
- For this exercise, set the cardinality `threshold` to 4. 
- High cardinality features are transformed with a `StringEncoder`. In this exercise,
set `n_components` to 2. 
- Low cardinality features are transformed with a `OneHotEncoder`, and the first 
category in binary features is dropped (hint: check the docs of the `OneHotEncoder`
for the `drop` parameter). Set `sparse_output=False`, so that the encoder returns a
dense output.
- Remember that `cardinality_below` is one of the skrub selectors. 
- Datetimes are transformed by a default `DatetimeEncoder`. 
- Everything should be wrapped in a scikit-learn `Pipeline`. 


Use the following dataframe to test the result. 

In [2]:
import pandas as pd
import datetime

# Small test table with one column per kind of feature handled in the exercise:
# two numeric columns, two string columns, a boolean column, and a column of
# ISO-formatted datetime strings whose last entry is missing.
data = {
    "int": [15, 56, 63, 12, 44],
    "float": [5.2, 2.4, 6.2, 10.45, 9.0],
    "str1": ["public", "private", "private", "private", "public"],
    "str2": ["officer", "manager", "lawyer", "chef", "teacher"],
    "bool": [True, False, True, False, True],
    "datetime-col": [
        "2020-02-03T12:30:05",
        "2021-03-15T00:37:15",
        "2022-02-13T17:03:25",
        "2023-05-22T08:45:55",
        None,  # missing value in the last row
    ],
}
df = pd.DataFrame(data)
df

Unnamed: 0,int,float,str1,str2,bool,datetime-col
0,15,5.2,public,officer,True,2020-02-03T12:30:05
1,56,2.4,private,manager,False,2021-03-15T00:37:15
2,63,6.2,private,lawyer,True,2022-02-13T17:03:25
3,12,10.45,private,chef,False,2023-05-22T08:45:55
4,44,9.0,public,teacher,True,


Use the following `PassThrough` transformer where needed. 

In [3]:
from skrub._apply_to_cols import SingleColumnTransformer
class PassThrough(SingleColumnTransformer):
    """Single-column transformer that returns its input column unchanged."""

    def transform(self, column):
        """Return ``column`` as-is."""
        return column

    def fit_transform(self, column, y=None):
        """Return ``column`` as-is; nothing is learned and ``y`` is ignored."""
        return self.transform(column)

You can test the correctness of your solution by comparing it with the equivalent
`TableVectorizer`:

In [4]:
from skrub import TableVectorizer

# Reference vectorizer, configured as described in the implementation notes:
# cardinality threshold of 4 and a 2-component StringEncoder for the
# high-cardinality columns.
tv = TableVectorizer(
    cardinality_threshold=4,
    high_cardinality=StringEncoder(n_components=2),
)
tv.fit_transform(df)

Unnamed: 0,int,float,str1_public,str2_0,str2_1,bool,datetime-col_year,datetime-col_month,datetime-col_day,datetime-col_hour,datetime-col_total_seconds
0,15.0,5.2,1.0,0.820968,-0.926887,1.0,2020.0,2.0,3.0,12.0,1580733000.0
1,56.0,2.4,0.0,0.820966,-0.926906,0.0,2021.0,3.0,15.0,0.0,1615769000.0
2,63.0,6.2,0.0,0.862895,-0.936516,1.0,2022.0,2.0,13.0,17.0,1644772000.0
3,12.0,10.45,0.0,1.029683,1.353004,0.0,2023.0,5.0,22.0,8.0,1684745000.0
4,44.0,9.0,1.0,1.419118,0.660165,1.0,,,,,


In [5]:
# Write your code here
#
# 
# 
# 
# 
# 
# 
# 
# 
# 
# 
# 

In [6]:
# Solution
# Step 1: clean the data so that datetimes and other dtypes are parsed.
cleaner = ApplyToCols(Cleaner())

# Numeric features are left untouched.
numeric = ApplyToCols(PassThrough(), cols=s.numeric())

# String columns whose cardinality is NOT below the threshold (4) are
# encoded with a 2-component StringEncoder.
high_cardinality = ApplyToCols(
    StringEncoder(n_components=2),
    cols=~s.cardinality_below(4) & s.string(),
)

# String columns whose cardinality is below the threshold are one-hot encoded;
# drop="if_binary" drops the first category of binary features.
low_cardinality = ApplyToCols(
    OneHotEncoder(sparse_output=False, drop="if_binary"),
    cols=s.cardinality_below(4) & s.string(),
)

# Datetime columns go through a default DatetimeEncoder.
# Named `datetime_encoder` (not `datetime`) so that the `datetime` module
# imported earlier in the notebook is not shadowed.
datetime_encoder = ApplyToCols(DatetimeEncoder(), cols=s.any_date())

# Chain all the steps in a scikit-learn Pipeline.
my_table_vectorizer = make_pipeline(
    cleaner, numeric, high_cardinality, low_cardinality, datetime_encoder
)

my_table_vectorizer.fit_transform(df)

Unnamed: 0,int,float,str1_public,str2_0,str2_1,bool,datetime-col_year,datetime-col_month,datetime-col_day,datetime-col_hour,datetime-col_total_seconds
0,15,5.2,1.0,0.820964,-0.926893,True,2020.0,2.0,3.0,12.0,1580733000.0
1,56,2.4,0.0,0.820965,-0.926902,False,2021.0,3.0,15.0,0.0,1615769000.0
2,63,6.2,0.0,0.862892,-0.936521,True,2022.0,2.0,13.0,17.0,1644772000.0
3,12,10.45,0.0,1.029686,1.353002,False,2023.0,5.0,22.0,8.0,1684745000.0
4,44,9.0,1.0,1.419119,0.660161,True,,,,,
