In [1]:
import pandas as pd
import polars as pl

import feature_engine

import duckdb


x = pd.DataFrame({
    'id': [1,2,3,4,5,6],
    'a': [str(x) for x in [1,2,3,1,2,3]],
    'b': [str(x) for x in [1,2,1,2,1,2]],
    'c': ['o','p','q','r','s','t'],
})
#https://sekuel.com/sql-courses/duckdb-cookbook/correlation-matrix-duckdb/

rows = 1_000_000
kv = {
    chr(i+65): list([str(x%i) for x in range(rows)]) 
    #i: list([str(x%i) for x in range(rows)]),
    #i: list([chr(x%26+65) for x in range(rows)]),
    for i in range(2,23)
}
kv['id'] = list(range(rows))
x_pd = pd.DataFrame(kv)
x_pl = pl.DataFrame(kv)

"""
# https://stackoverflow.com/questions/15274305/is-it-possible-to-have-multiple-pivots-using-the-same-pivot-column-using-sql-ser
# https://duckdb.org/docs/sql/statements/pivot.html
# https://www.reddit.com/r/SQL/comments/ul4hth/one_hot_encoding_through_sql/

"""

def pdf():
    print(
        duckdb.sql("""
            pivot x
            on a
            using count(*)
            group by id
            order by id
        """)
    )
def bprint(*args, **kwargs):
    print("=================")
    print(*args, **kwargs)
    print("=================")

from functools import wraps
import time 
def time_it(func):
    @wraps(func)
    def with_timing(*args, **kwargs):
        s = time.time()
        res = func(*args, **kwargs)
        e = time.time()
        bprint(e-s)
        return res
    return with_timing

@time_it
def ohe(x):
    columns = ','.join([c for c in x.columns if c != 'id'])
    print(columns)
    db = duckdb.connect()
    
    sql = f"""
        --create table tmp_data_store as
        with long_form as (
            unpivot x
            on {columns}
            into
                NAME categorical_col_name
                VALUE value
        )
        , transformed_kv as (
            select * exclude(categorical_col_name, value), categorical_col_name || '__' ||cast(value as varchar) as kv 
            from long_form
        )
        pivot transformed_kv
        on kv 
        using bool_or(id>0)
        group by id
        order by id;

        --select * from tmp_data_store
    """
    
    if False: # Enable to run sql statement-by-statement to workaround https://github.com/duckdb/duckdb/issues/13863
        print(sql)
        stmts = db.extract_statements(sql)
        print(stmts)
        for stmt in stmts[:-1]:
            print(stmt)
            print(db.sql(stmt.query))
        sql = stmts[-1].query
    memo_sql = "SELECT value AS memlimit FROM duckdb_settings() WHERE name = 'memory_limit';"
    change_memo_sql = "SET memory_limit = '20GiB';"
    change_memo_back_sql = "SET memory_limit = '12.1GiB';"
    db.execute(change_memo_sql)
    
    print(db.execute(memo_sql).df())
    
    #res = db.sql(sql).show(max_width=250)#.df()
    res = db.sql(sql).df()
    #print(sorted(res.columns))
    with pd.option_context('display.max_rows', None, 'display.max_columns', None):
        print(res)
    
    return db.execute(change_memo_back_sql)

    #with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    #    print(res)

#ohe()

In [11]:
# res = ohe(x_pd)
# longer than 1.5 min
# crash the kernel

C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V,W
   memlimit
0  20.0 GiB


: 

In [2]:
x_pd

Unnamed: 0,C,D,E,F,G,H,I,J,K,L,...,O,P,Q,R,S,T,U,V,W,id
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
2,0,2,2,2,2,2,2,2,2,2,...,2,2,2,2,2,2,2,2,2,2
3,1,0,3,3,3,3,3,3,3,3,...,3,3,3,3,3,3,3,3,3,3
4,0,1,0,4,4,4,4,4,4,4,...,4,4,4,4,4,4,4,4,4,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
999995,1,2,3,0,5,3,3,5,5,7,...,3,5,11,4,5,6,15,17,7,999995
999996,0,0,0,1,0,4,4,6,6,8,...,4,6,12,5,6,7,16,18,8,999996
999997,1,1,1,2,1,5,5,7,7,9,...,5,7,13,6,7,8,17,19,9,999997
999998,0,2,2,3,2,6,6,8,8,10,...,6,8,14,7,8,9,18,20,10,999998


# try to store distinct kv in duckdb



In [10]:
columns = ", ".join([c for c in x_pd.columns if c != 'id'])
duckdb.sql(f"""
        create or replace table distinct_kv (col_name varchar, value varchar);
        
        insert into distinct_kv 
        with stg as (
            unpivot x_pd
            on {columns}
            into
                NAME col_name
                VALUE value
        )
        select col_name, value
        from stg;
""")

In [11]:
@time_it
def ohe_ddb(df):
    @time_it
    def construct_query():
        columns = [c for c in df.columns if c != 'id']
        insert_statements = "\n".join([
            f"""insert into distinct_kv select distinct '{c}' as col_name, "{c}" from x_pd;""" 
            for c in columns 
        ])
        print(insert_statements)
        constructed_sql = "select id \n" + duckdb.sql(f"""
                create or replace table distinct_kv (col_name varchar, value varchar);
                
                {insert_statements}

                with case_strings as (
                    select  'case when "' || col_name || '" = ' || value || ' then 1 else 0 end as ' || col_name::varchar || '__' || value::varchar as stm 
                    from distinct_kv
                )
                select ', ' || string_agg(stm, '\n, ') as stm
                from case_strings

        """).df().loc[0,'stm'] + "\n from df"
        return constructed_sql
    #duckdb.sql("pragma enable_profile; " + constructed_sql).explain('analyze')
    sql = construct_query()
    print(sql)
    @time_it
    def res_wrap():
        res = duckdb.sql("pragma disable_profile; " + sql).df()#.explain('analyze')
    res_wrap()

ohe_ddb(x_pd)


insert into distinct_kv select distinct 'C' as col_name, "C" from x_pd;
insert into distinct_kv select distinct 'D' as col_name, "D" from x_pd;
insert into distinct_kv select distinct 'E' as col_name, "E" from x_pd;
insert into distinct_kv select distinct 'F' as col_name, "F" from x_pd;
insert into distinct_kv select distinct 'G' as col_name, "G" from x_pd;
insert into distinct_kv select distinct 'H' as col_name, "H" from x_pd;
insert into distinct_kv select distinct 'I' as col_name, "I" from x_pd;
insert into distinct_kv select distinct 'J' as col_name, "J" from x_pd;
insert into distinct_kv select distinct 'K' as col_name, "K" from x_pd;
insert into distinct_kv select distinct 'L' as col_name, "L" from x_pd;
insert into distinct_kv select distinct 'M' as col_name, "M" from x_pd;
insert into distinct_kv select distinct 'N' as col_name, "N" from x_pd;
insert into distinct_kv select distinct 'O' as col_name, "O" from x_pd;
insert into distinct_kv select distinct 'P' as col_name, "P" fro

In [105]:


columns = ", ".join([c for c in x_pd.columns if c != 'id'])
print(duckdb.sql(f"""
    pragma disable_profile;
    create or replace table distinct_kv (col_name varchar, value varchar);
                 
    insert into distinct_kv 
        with stg as (
            unpivot x_pd
            on {columns}
            into
                NAME col_name
                VALUE value
        )
        select col_name, value
        from stg
    ;
                 
    with case_strings as (
        select  'case when "' || col_name || '" = ' || value || ' then 1 end as ' || col_name::varchar || '__' || value::varchar as stm 
        from distinct_kv
    )
    select ', ' || string_agg(stm, '\n, ') as stm
    from case_strings
""").df().loc[0,'stm'])

In [None]:
# compare performance of encoder 

# join pivots

In [46]:

@time_it
def ohe_db():
    duckdb.sql("""
        PRAGMA enable_profiling;
        --PRAGMA disable_profiling;
        create or replace temp table __internal_calc__ as select * from x_pd;
        CREATE UNIQUE INDEX id ON __internal_calc__ (id);
        with p_c as (
            pivot __internal_calc__
            on C 
            using count(*)
            group by  
            id
        )
        , p_d as (
            pivot __internal_calc__
            on D 
            using count(*)
            group by  
            id
        )
        , p_e as (
            pivot __internal_calc__
            on D 
            using count(*)
            group by  
            id
        )
        , p_f as (
            pivot __internal_calc__
            on f
            using count(*)
            group by  
            id
        )
        , p_g as (
            pivot __internal_calc__
            on g
            using count(*)
            group by  
            id
        )
        , p_h as (
            pivot __internal_calc__
            on h
            using count(*)
            group by  
            id
        )
        select * 
        from p_c
        inner join p_d on
            p_c.id = p_d.id 
        inner join p_e on
            p_c.id = p_e.id 
        inner join p_f on
            p_c.id = p_f.id 
        inner join p_g on
            p_c.id = p_g.id 
        inner join p_h on
            p_c.id = p_h.id 
    """)#.explain('analyze')
ohe_db()

┌─────────────────────────────────────┐
│┌───────────────────────────────────┐│
││    Query Profiling Information    ││
│└───────────────────────────────────┘│
└─────────────────────────────────────┘
         PRAGMA enable_profiling;         --PRAGMA disable_profiling;         create or replace temp table __internal_calc__ as select * from x_pd;         CREATE UNIQUE INDEX id ON __internal_calc__ (id);         with p_c as (             pivot __internal_calc__             on C              using count(*)             group by               id         )         , p_d as (             pivot __internal_calc__             on D              using count(*)             group by               id         )         , p_e as (             pivot __internal_calc__             on D              using count(*)             group by               id         )         , p_f as (             pivot __internal_calc__             on f             using count(*)             group by               id         ) 

1.2579054832458496


┌─────────────────────────────────────┐
│┌───────────────────────────────────┐│
││    Query Profiling Information    ││
│└───────────────────────────────────┘│
└─────────────────────────────────────┘

┌─────────────────────────────────────┐
│┌───────────────────────────────────┐│
││         Total Time: 0.821s        ││
│└───────────────────────────────────┘│
└────────────────────────��────────────┘
┌───────────────────────────┐                                                                                                                                                 
│      RESULT_COLLECTOR     │                                                                                                                                                 
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                                                                                                                                                 
│             0             │                                            

�────┬─────────────┘└─────────────┬─────────────┘└─────────────┬─────────────┘└─────────────┬─────────────┘└─────────────┬─────────────┘                             
                             ┌─────────────┴─────────────┐┌─────────────┴─────────────┐┌─────────────┴─────────────┐┌─────────────┴─────────────┐┌─────────────┴─────────────┐
                             │         SEQ_SCAN          ││         PROJECTION        ││         PROJECTION        ││       HASH_GROUP_BY       ││       HASH_GROUP_BY       │
                             │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
                             │     __internal_calc__     ││__internal_compress_integra││             id            ││             #0            ││             #0            │
                             │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││     l_uinteger(#0, 0)     ││(D IS NOT DISTINCT FROM '0'││count_star() FILT

In [4]:
@time_it
def ohe_d(df, cat_col):
    res = []
    for c in cat_col:
        res.append(duckdb.sql(f"""
            select distinct "{c}"
            from df;
            select distinct "{c}"
            from df;
        """))
    return res
ohe_d(x_pd,[c for c in x_pd.columns if c != 'id'])


0.39345836639404297


[┌─────────┐
 │    C    │
 │ varchar │
 ├─────────┤
 │ 1       │
 │ 0       │
 └─────────┘,
 ┌─────────┐
 │    D    │
 │ varchar │
 ├─────────┤
 │ 1       │
 │ 0       │
 │ 2       │
 └─────────┘,
 ┌─────────┐
 │    E    │
 │ varchar │
 ├─────────┤
 │ 0       │
 │ 2       │
 │ 3       │
 │ 1       │
 └─────────┘,
 ┌─────────┐
 │    F    │
 │ varchar │
 ├─────────┤
 │ 4       │
 │ 1       │
 │ 0       │
 │ 2       │
 │ 3       │
 └─────────┘,
 ┌─────────┐
 │    G    │
 │ varchar │
 ├─────────┤
 │ 4       │
 │ 3       │
 │ 5       │
 │ 1       │
 │ 2       │
 │ 0       │
 └─────────┘,
 ┌─────────┐
 │    H    │
 │ varchar │
 ├─────────┤
 │ 4       │
 │ 6       │
 │ 5       │
 │ 3       │
 │ 0       │
 │ 2       │
 │ 1       │
 └─────────┘,
 ┌─────────┐
 │    I    │
 │ varchar │
 ├─────────┤
 │ 4       │
 │ 6       │
 │ 1       │
 │ 3       │
 │ 5       │
 │ 0       │
 │ 2       │
 │ 7       │
 └─────────┘,
 ┌─────────┐
 │    J    │
 │ varchar │
 ├─────────┤
 │ 6       │
 │ 3       │
 │ 5 

# compare get dummies 

- polars is faster 5x 
- pandas is slower 

In [5]:
@time_it
def ohe_pd(x):
    res = pd.get_dummies(x.drop(columns = 'id'))
    #bprint(res.tail())
ohe_pd(x_pd)
#https://stackoverflow.com/questions/68429779/is-there-a-way-to-make-get-dummies-work-faster

import polars as pl
@time_it
def ohe_pl(x):
    res = x.to_dummies([c for c in x.columns if c != 'id'])
    #bprint(res.tail())
ohe_pl(x_pl)

1.0780911445617676
0.27526354789733887


In [9]:
x_pl

C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T,U,V,W,id
str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,str,i64
"""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""","""0""",0
"""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""","""1""",1
"""0""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""","""2""",2
"""1""","""0""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""","""3""",3
"""0""","""1""","""0""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""","""4""",4
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""1""","""2""","""3""","""0""","""5""","""3""","""3""","""5""","""5""","""7""","""11""","""9""","""3""","""5""","""11""","""4""","""5""","""6""","""15""","""17""","""7""",999995
"""0""","""0""","""0""","""1""","""0""","""4""","""4""","""6""","""6""","""8""","""0""","""10""","""4""","""6""","""12""","""5""","""6""","""7""","""16""","""18""","""8""",999996
"""1""","""1""","""1""","""2""","""1""","""5""","""5""","""7""","""7""","""9""","""1""","""11""","""5""","""7""","""13""","""6""","""7""","""8""","""17""","""19""","""9""",999997
"""0""","""2""","""2""","""3""","""2""","""6""","""6""","""8""","""8""","""10""","""2""","""12""","""6""","""8""","""14""","""7""","""8""","""9""","""18""","""20""","""10""",999998


In [13]:
x_pl.select(pl.corr("G", ['H','I'], method="spearman"))

InvalidOperationError: cannot cast List type (inner: 'String', to: 'String')

In [9]:
columns = ','.join([c for c in x.columns if c != 'id'])
duckdb.sql("""PRAGMA enable_profiling;
""")
duckdb.sql(f"""

        --create table tmp_data_store as
        with long_form as (
            unpivot x
            on {columns}
            into
                NAME categorical_col_name
                VALUE value
        )
        , transformed_kv as (
            select * exclude(categorical_col_name, value), categorical_col_name || '__' ||cast(value as varchar) as kv 
            from long_form
        )
        pivot transformed_kv
        on kv 
        using bool_or(id>0)
        group by id
        order by id;
""").explain("analyze")

┌─────────────────────────────────────┐
│┌───────────────────────────────────┐│
││    Query Profiling Information    ││
│└───────────────────────────────────┘│
└─────────────────────────────────────┘

┌─────────────────────────────────────┐
│┌───────────────────────────────────┐│
││         Total Time: 1.62s         ││
│└───────────────────────────────────┘│
└────────────────────────��────────────┘
┌───────────────────────────┐
│      RESULT_COLLECTOR     │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│             0             │
│          (0.00s)          │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│        CREATE_TYPE        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│             0             │
│          (0.00s)          │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│          ORDER_BY         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          ORDERS:          │
│  CAST(kv AS VARCHAR) ASC  │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │


'┌─────────────────────────────────────┐\n│┌───────────────────────────────────┐│\n││    Query Profiling Information    ││\n│└───────────────────────────────────┘│\n└─────────────────────────────────────┘\n\n\n'