In [59]:
import polars as pl
from fastauc.fastauc.fast_auc import CppAuc
import numpy as np

In [60]:
# Single CppAuc instance, created once; get_score calls its roc_auc_score method.
cpp_auc = CppAuc()
def get_score(df):
    """Print and return the mean per-impression ROC AUC.

    Rows are grouped by ``impression_id``. For every group the ``target``
    values (cast to bool) are scored against the ``prediction`` values (cast
    to float32) with ``cpp_auc.roc_auc_score``; the per-group scores are then
    averaged with ``np.mean``.

    Args:
        df: polars DataFrame containing at least the columns
            ``impression_id``, ``target`` and ``prediction``.

    Returns:
        The mean of the per-impression AUC values. The value is also printed,
        so end a notebook call with ``;`` to avoid seeing it echoed twice.
    """
    # Only the grouping key and the two scored columns are needed.
    grouped = (
        df.select(['impression_id', 'target', 'prediction'])
        .group_by('impression_id')
        .agg(pl.col('target'), pl.col('prediction'))
    )

    # NOTE(review): how CppAuc scores an impression whose targets are all of
    # one class is not visible from this file -- confirm before relying on it.
    per_impression_auc = [
        cpp_auc.roc_auc_score(
            np.array(targets).astype(bool),
            np.array(scores).astype(np.float32),
        )
        for targets, scores in zip(grouped['target'].to_list(),
                                   grouped['prediction'].to_list())
    ]

    result = np.mean(per_impression_auc)
    print(result)
    return result

In [61]:
# Load the scored validation rows from an absolute path on this machine
# (presumably the stacked ensemble's output, going by the file name -- confirm).
predictions = pl.read_parquet('/home/ubuntu/experiments/analyze_prediction/scores_validation_stacking.parquet')
# Plain-text dump of the frame; the captured output reports shape (2_928_942, 390).
print(predictions)
# Score the `prediction` column as loaded, before any adjustment.
get_score(predictions)


shape: (2_928_942, 390)
┌────────────┬─────────┬────────┬────────────┬───┬────────────┬────────────┬───────────┬───────────┐
│ impression ┆ article ┆ target ┆ prediction ┆ … ┆ roberta_em ┆ w_2_vec_em ┆ emotions_ ┆ constrast │
│ _id        ┆ ---     ┆ ---    ┆ ---        ┆   ┆ b_icm_minu ┆ b_icm_minu ┆ emb_icm_m ┆ ive_emb_i │
│ ---        ┆ i32     ┆ i8     ┆ f64        ┆   ┆ s_median_a ┆ s_median_a ┆ inus_medi ┆ cm_minus_ │
│ u32        ┆         ┆        ┆            ┆   ┆ …          ┆ …          ┆ an_…      ┆ med…      │
│            ┆         ┆        ┆            ┆   ┆ ---        ┆ ---        ┆ ---       ┆ ---       │
│            ┆         ┆        ┆            ┆   ┆ f32        ┆ f32        ┆ f32       ┆ f32       │
╞════════════╪═════════╪════════╪════════════╪═══╪════════════╪════════════╪═══════════╪═══════════╡
│ 96791      ┆ 9783865 ┆ 0      ┆ 0.075048   ┆ … ┆ -1.846875  ┆ -0.003178  ┆ -0.00133  ┆ -0.001478 │
│ 96791      ┆ 9784591 ┆ 0      ┆ 0.19992    ┆ … ┆ -0.00018   ┆ -0.

In [62]:
# List every column of the prediction frame, one name per line.
for column_name in predictions.columns:
    print(column_name)

impression_id
article
target
prediction
prediction_catboost_ranker
prediction_catboost_classifier
prediction_dcn
prediction_GANDALF
prediction_mlp
prediction_wd
normalized_prediction_catboost_ranker
normalized_prediction_catboost_classifier
normalized_prediction_dcn
normalized_prediction_GANDALF
normalized_prediction_mlp
normalized_prediction_wd
art_norm_prediction_catboost_ranker
art_norm_prediction_catboost_classifier
art_norm_prediction_dcn
art_norm_prediction_GANDALF
art_norm_prediction_mlp
art_norm_prediction_wd
prediction_hybrid
mean_prediction_catboost_ranker
mean_prediction_catboost_classifier
mean_prediction_dcn
mean_prediction_GANDALF
mean_prediction_mlp
mean_prediction_wd
mean_prediction_hybrid
skew_prediction_catboost_ranker
skew_prediction_catboost_classifier
skew_prediction_dcn
skew_prediction_GANDALF
skew_prediction_mlp
skew_prediction_wd
skew_prediction_hybrid
std_prediction_catboost_ranker
std_prediction_catboost_classifier
std_prediction_dcn
std_prediction_GANDALF
std

In [75]:
# Substring to look for in the column names.
query = 'delay'

# Print only the columns whose name contains `query`, in frame order.
for column_name in predictions.columns:
    if query not in column_name:
        continue
    print(column_name)

article_delay_days
article_delay_hours
mean_topics_mean_delay_days
mean_topics_mean_delay_hours
user_mean_delay_days
user_mean_delay_hours
article_delay_hours_l_inf_impression
article_delay_hours_minus_median_impression
std_impression_article_delay_hours
skew_impression_article_delay_hours
entropy_impression_article_delay_hours
article_delay_hours_rank_impression
mean_topics_mean_delay_hours_rank_impression
article_delay_hours_l_inf_article


In [68]:
# Score `prediction` after three multiplicative adjustments applied in order.
# Per the polars docs, a when/then/otherwise expression is named after its
# `then` branch, so each step is expected to overwrite `prediction` and feed
# the next step.
get_score(
    predictions
    # 1) halve the score when article_delay_days is above 3
    .with_columns(
        pl.when(pl.col('article_delay_days') > 3)
        .then(pl.col('prediction') * 0.5)
        .otherwise(pl.col('prediction'))
    )
    # 2) scale by 1.2 when article_delay_hours exceeds 3x user_mean_delay_hours
    .with_columns(
        pl.when(pl.col('article_delay_hours') > pl.col('user_mean_delay_hours')*3)
        .then(pl.col('prediction') * 1.2)
        .otherwise(pl.col('prediction'))
    )
    # 3) scale by 0.95 on rows holding the impression's maximum normalized_endorsement_10h
    .with_columns(
        pl.when(pl.col('normalized_endorsement_10h') == pl.col('normalized_endorsement_10h').max().over('impression_id'))
        .then(pl.col('prediction') * 0.95)
        .otherwise(pl.col('prediction'))
    )
)

0.8250008049549922


In [65]:
# Load the raw ebnerd_small tables from absolute paths on this machine.
# NOTE(review): the two `*_train` names are read from the `validation/` split --
# confirm this is intended (the predictions being analysed are validation scores).
behaviors_train = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/validation/behaviors.parquet')
history_train = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/validation/history.parquet')
# Article table; the feature cell reads its total_inviews / total_pageviews columns.
articles = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/articles.parquet')

In [67]:
# Per-user mean / std / median of the pageviews-to-inviews ratio over the
# articles listed in each user's history.
hf = (
    history_train
    .select(['user_id', 'article_id_fixed'])
    # one row per (user, article in history)
    .explode('article_id_fixed')
    .rename({'article_id_fixed': 'article_id'})
    .join(
        articles.select(['article_id', 'total_inviews', 'total_pageviews']).fill_null(0),
        on='article_id',
    )
    # nulls were filled with 0 above; filtering them out avoids dividing by zero
    .filter(pl.col('total_inviews') > 0)
    .with_columns(
        (pl.col('total_pageviews') / pl.col('total_inviews')).alias('total_pageviews/inviews')
    )
    .group_by('user_id')
    .agg(
        pl.col('total_pageviews/inviews').mean().alias('history_total_pageviews/inviews_avg'),
        pl.col('total_pageviews/inviews').std().alias('history_total_pageviews/inviews_std'),
        pl.col('total_pageviews/inviews').median().alias('history_total_pageviews/inviews_median'),
    )
)

# Make the cell safe to re-run: `predictions` is reassigned below, so a second
# run would otherwise join the same feature columns onto a frame that already
# has them. Drop any earlier copies first; on a first run nothing is removed
# and the column order is unchanged.
hf_feature_columns = [name for name in hf.columns if name != 'user_id']
predictions = (
    predictions
    .select([name for name in predictions.columns if name not in hf_feature_columns])
    .join(hf, on='user_id', how='left')
)

In [69]:
# Display the frame after the join; in the captured output the three
# history_total_pageviews/inviews_* columns are the last ones listed.
predictions

impression_id,article,target,prediction,prediction_catboost_ranker,prediction_catboost_classifier,prediction_dcn,prediction_GANDALF,prediction_mlp,prediction_wd,normalized_prediction_catboost_ranker,normalized_prediction_catboost_classifier,normalized_prediction_dcn,normalized_prediction_GANDALF,normalized_prediction_mlp,normalized_prediction_wd,art_norm_prediction_catboost_ranker,art_norm_prediction_catboost_classifier,art_norm_prediction_dcn,art_norm_prediction_GANDALF,art_norm_prediction_mlp,art_norm_prediction_wd,prediction_hybrid,mean_prediction_catboost_ranker,mean_prediction_catboost_classifier,mean_prediction_dcn,mean_prediction_GANDALF,mean_prediction_mlp,mean_prediction_wd,mean_prediction_hybrid,skew_prediction_catboost_ranker,skew_prediction_catboost_classifier,skew_prediction_dcn,skew_prediction_GANDALF,skew_prediction_mlp,skew_prediction_wd,skew_prediction_hybrid,…,std_article_distilbert_emb_icm,std_article_bert_emb_icm,std_article_roberta_emb_icm,std_article_w_2_vec_emb_icm,std_article_emotions_emb_icm,std_article_constrastive_emb_icm,skew_article_kenneth_emb_icm,skew_article_distilbert_emb_icm,skew_article_bert_emb_icm,skew_article_roberta_emb_icm,skew_article_w_2_vec_emb_icm,skew_article_emotions_emb_icm,skew_article_constrastive_emb_icm,kurtosis_article_kenneth_emb_icm,kurtosis_article_distilbert_emb_icm,kurtosis_article_bert_emb_icm,kurtosis_article_roberta_emb_icm,kurtosis_article_w_2_vec_emb_icm,kurtosis_article_emotions_emb_icm,kurtosis_article_constrastive_emb_icm,entropy_article_kenneth_emb_icm,entropy_article_distilbert_emb_icm,entropy_article_bert_emb_icm,entropy_article_roberta_emb_icm,entropy_article_w_2_vec_emb_icm,entropy_article_emotions_emb_icm,entropy_article_constrastive_emb_icm,kenneth_emb_icm_minus_median_article,distilbert_emb_icm_minus_median_article,bert_emb_icm_minus_median_article,roberta_emb_icm_minus_median_article,w_2_vec_emb_icm_minus_median_article,emotions_emb_icm_minus_median_article,constrastive_emb_icm_minus_median_ar
ticle,history_total_pageviews/inviews_avg,history_total_pageviews/inviews_std,history_total_pageviews/inviews_median
u32,i32,i8,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,…,f32,f32,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64
96791,9783865,0,0.075048,-0.666797,0.082062,0.043175,0.14195,0.060346,0.057008,0.0,0.0,0.0,0.0,0.0,0.0,0.324926,0.088514,0.045872,0.153701,0.067696,0.062322,0.0,0.220082,0.253237,0.350458,0.343151,0.367336,0.337175,0.543405,-0.262577,0.296461,-0.821649,0.419907,-0.58474,-1.18355,-0.353738,…,6.779265,0.030801,3.140027,0.006229,0.017181,0.002783,1.272064,1.046977,1.088959,1.297803,1.29522,1.078533,1.461609,1.638777,0.643103,0.767744,1.485673,1.405132,0.826377,2.443438,,,,,,,,-0.000422,-2.730756,-0.024237,-1.846875,-0.003178,-0.00133,-0.001478,0.205572,0.053941,0.204381
96791,9784591,0,0.19992,-0.201826,0.24496,0.496269,0.278617,0.492531,0.376784,0.27914,0.44268,0.957946,0.303862,0.834344,0.744446,0.400545,0.24594,0.499679,0.278779,0.498223,0.378448,0.435855,0.220082,0.253237,0.350458,0.343151,0.367336,0.337175,0.543405,-0.262577,0.296461,-0.821649,0.419907,-0.58474,-1.18355,-0.353738,…,7.509755,0.047205,6.185093,0.010945,0.016172,0.004107,1.588415,1.063038,1.128102,1.475852,2.023667,1.084642,2.014237,2.528819,0.664436,0.903211,2.306099,5.010594,0.659787,4.696737,,,,,,,,-0.002441,0.977748,-0.021208,-0.00018,-0.001672,-0.003989,-0.001263,0.205572,0.053941,0.204381
96791,9784679,0,0.306644,0.490838,0.283759,0.408114,0.406924,0.437615,0.381464,0.694972,0.548115,0.771566,0.589136,0.728329,0.75534,0.451658,0.285161,0.414432,0.407297,0.448214,0.385264,0.696486,0.220082,0.253237,0.350458,0.343151,0.367336,0.337175,0.543405,-0.262577,0.296461,-0.821649,0.419907,-0.58474,-1.18355,-0.353738,…,8.403144,0.040637,5.682543,0.010008,0.012445,0.005754,1.578662,1.08893,1.100653,1.98333,2.297778,1.144221,2.065043,2.484456,0.967195,0.767086,5.043804,6.952426,1.172618,5.555449,,,,,,,,0.000079,-2.890614,-0.015149,-1.847095,0.001612,-0.005319,-0.002653,0.205572,0.053941,0.204381
96791,9784696,1,0.543125,0.998933,0.450044,0.51616,0.591718,0.57834,0.486557,1.0,1.0,1.0,1.0,1.0,1.0,0.428148,0.477463,0.545293,0.61163,0.646852,0.526188,1.0,0.220082,0.253237,0.350458,0.343151,0.367336,0.337175,0.543405,-0.262577,0.296461,-0.821649,0.419907,-0.58474,-1.18355,-0.353738,…,8.681229,0.049315,5.904523,0.009892,0.004723,0.004614,1.622447,1.100682,1.608831,1.952653,2.036736,1.65544,1.904608,2.769404,0.926011,2.755922,4.322889,4.383129,4.038434,3.958584,,,,,,,,-0.000657,-0.003477,4.0233e-7,-0.000004,-0.001678,-0.00266,0.001201,0.205572,0.053941,0.204381
96791,9784710,0,0.246063,0.479261,0.205358,0.288571,0.296547,0.267848,0.384064,0.688021,0.335059,0.518825,0.343727,0.400589,0.761393,0.363708,0.218965,0.312793,0.304385,0.304941,0.409638,0.584685,0.220082,0.253237,0.350458,0.343151,0.367336,0.337175,0.543405,-0.262577,0.296461,-0.821649,0.419907,-0.58474,-1.18355,-0.353738,…,6.637402,0.068946,5.388244,0.011082,0.009704,0.003886,1.787499,1.042659,1.018408,1.570231,1.99189,1.262127,1.944808,3.66696,0.719079,0.474365,2.751674,4.55359,1.531854,5.028709,,,,,,,,-0.001664,-3.841198,-0.006057,-0.923683,0.001605,-0.00133,-0.001465,0.205572,0.053941,0.204381
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
579552453,9782656,1,0.55835,1.011055,0.320655,0.550256,0.513169,0.397782,0.506223,1.0,0.612771,0.82151,0.71654,0.610015,0.700638,0.508244,0.342105,0.58319,0.52827,0.46131,0.563561,0.873054,0.096367,0.266925,0.283593,0.318598,0.289854,0.290079,0.528563,-0.371188,-0.075086,0.405572,0.125059,0.102622,0.431123,-0.29871,…,6.368006,0.048577,2.739619,0.007318,0.006686,0.002871,1.345641,1.084181,1.104169,1.143345,1.250609,1.10688,1.498162,1.787051,0.895667,0.68701,0.890969,1.386189,0.766253,2.352604,,,,,,,,0.002114,1.959111,0.00606,-0.923343,-0.001631,0.002659,0.001786,0.192703,0.064912,0.187862
579552453,9784559,0,0.026516,-0.913502,0.106565,0.020037,0.011794,0.028081,0.014113,0.113908,0.109009,0.0,0.0,0.004567,0.0,0.322986,0.105262,0.019287,0.011418,0.027072,0.013217,0.076153,0.096367,0.266925,0.283593,0.318598,0.289854,0.290079,0.528563,-0.371188,-0.075086,0.405572,0.125059,0.102622,0.431123,-0.29871,…,7.467955,0.06488,6.27112,0.013966,0.011188,0.006094,1.096043,1.12984,1.101092,1.19727,1.378717,1.063071,1.592374,0.850862,0.910113,0.757573,1.167805,1.68162,0.751886,2.515387,,,,,,,,-0.000071,4.567139,0.030286,-1.846838,-0.001583,0.006648,0.00285,0.192703,0.064912,0.187862
579552453,9784575,0,0.358186,0.615032,0.361946,0.134701,0.33531,0.3622,0.191993,0.817666,0.709933,0.177657,0.462353,0.551743,0.253256,0.482383,0.363077,0.136377,0.335953,0.366307,0.195157,0.716575,0.096367,0.266925,0.283593,0.318598,0.289854,0.290079,0.528563,-0.371188,-0.075086,0.405572,0.125059,0.102622,0.431123,-0.29871,…,9.200513,0.045309,3.450484,0.00751,0.018695,0.004103,1.220666,1.07949,1.079275,1.218456,1.287567,1.139891,1.296978,1.19197,0.717156,0.684186,1.024205,1.382959,0.80121,1.395573,,,,,,,,0.00077,2.743022,0.009089,2.770691,0.003361,-0.002659,-0.001299,0.192703,0.064912,0.187862
579552453,9784642,0,0.563919,0.930159,0.485219,0.665457,0.711512,0.635916,0.716487,0.962754,1.0,1.0,1.0,1.0,1.0,0.44422,0.495563,0.679435,0.716259,0.649014,0.732073,0.975745,0.096367,0.266925,0.283593,0.318598,0.289854,0.290079,0.528563,-0.371188,-0.075086,0.405572,0.125059,0.102622,0.431123,-0.29871,…,6.825923,0.054073,4.459253,0.007828,0.010464,0.003565,1.222041,1.159264,1.096896,1.14425,1.250952,1.027117,1.38026,1.364627,1.146529,0.751426,0.85459,1.293319,0.665794,1.712454,,,,,,,,0.002409,0.960852,0.012115,0.923409,0.004757,0.005319,0.003241,0.192703,0.064912,0.187862


In [74]:
# Score `prediction` again with the same three ordered adjustments:
# x0.5 for old articles, x1.2 for articles much older than the user's mean
# delay, x0.95 for the impression's top normalized_endorsement_10h rows.
get_score(
    predictions
    .with_columns(
        pl.when(pl.col('article_delay_days') > 3)
        .then(pl.col('prediction') * 0.5)
        .otherwise(pl.col('prediction'))
    )
    .with_columns(
        pl.when(pl.col('article_delay_hours') > pl.col('user_mean_delay_hours')*3)
        .then(pl.col('prediction') * 1.2)
        .otherwise(pl.col('prediction'))
    )
    .with_columns(
        pl.when(pl.col('normalized_endorsement_10h') == pl.col('normalized_endorsement_10h').max().over('impression_id'))
        .then(pl.col('prediction') * 0.95)
        .otherwise(pl.col('prediction'))
    )
)

0.8249911278934282
