In [1]:
import polars as pl
from fastauc.fastauc.fast_auc import CppAuc
import numpy as np

In [2]:
cpp_auc = CppAuc()


def get_score(df):
    """Print and return the mean per-impression ROC AUC of ``df['prediction']``.

    Rows are grouped by ``impression_id``. For every group the ``target`` values
    are cast to ``bool`` and the ``prediction`` values to ``float32``, then scored
    with ``cpp_auc.roc_auc_score``. The per-impression scores are averaged with
    ``np.mean``.

    Args:
        df: polars DataFrame containing at least the columns ``impression_id``,
            ``target`` and ``prediction``.

    Returns:
        The mean of the per-impression scores (a NumPy float). The value is also
        printed, so existing cells that only call the function keep their output.
    """
    # Only the grouping key and the two aggregated columns are needed.
    per_impression = (
        df.select(['impression_id', 'target', 'prediction'])
        .group_by('impression_id')
        .agg(pl.col('target'), pl.col('prediction'))
    )

    # NOTE(review): how CppAuc scores an impression whose targets are all equal
    # is not visible from this notebook -- confirm such groups cannot occur.
    impression_scores = [
        cpp_auc.roc_auc_score(
            np.array(y_true).astype(bool),
            np.array(y_score).astype(np.float32),
        )
        for y_true, y_score in zip(
            per_impression['target'].to_list(),
            per_impression['prediction'].to_list(),
        )
    ]

    result = np.mean(impression_scores)
    print(result)
    return result

In [3]:
# Load the prediction/feature table and report the score of the raw
# `prediction` column; later cells compare against this number.
predictions = pl.read_parquet(
    '/home/ubuntu/experiments/analyze_prediction/scores_validation_stacking.parquet'
)
print(predictions)
get_score(predictions)


shape: (2_928_942, 390)
┌────────────┬─────────┬────────┬────────────┬───┬────────────┬────────────┬───────────┬───────────┐
│ impression ┆ article ┆ target ┆ prediction ┆ … ┆ roberta_em ┆ w_2_vec_em ┆ emotions_ ┆ constrast │
│ _id        ┆ ---     ┆ ---    ┆ ---        ┆   ┆ b_icm_minu ┆ b_icm_minu ┆ emb_icm_m ┆ ive_emb_i │
│ ---        ┆ i32     ┆ i8     ┆ f64        ┆   ┆ s_median_a ┆ s_median_a ┆ inus_medi ┆ cm_minus_ │
│ u32        ┆         ┆        ┆            ┆   ┆ …          ┆ …          ┆ an_…      ┆ med…      │
│            ┆         ┆        ┆            ┆   ┆ ---        ┆ ---        ┆ ---       ┆ ---       │
│            ┆         ┆        ┆            ┆   ┆ f32        ┆ f32        ┆ f32       ┆ f32       │
╞════════════╪═════════╪════════╪════════════╪═══╪════════════╪════════════╪═══════════╪═══════════╡
│ 96791      ┆ 9783865 ┆ 0      ┆ 0.075048   ┆ … ┆ -1.846875  ┆ -0.003178  ┆ -0.00133  ┆ -0.001478 │
│ 96791      ┆ 9784591 ┆ 0      ┆ 0.19992    ┆ … ┆ -0.00018   ┆ -0.

In [4]:
# List every column name, one per line.
print(*predictions.columns, sep='\n')

impression_id
article
target
prediction
prediction_catboost_ranker
prediction_catboost_classifier
prediction_dcn
prediction_GANDALF
prediction_mlp
prediction_wd
normalized_prediction_catboost_ranker
normalized_prediction_catboost_classifier
normalized_prediction_dcn
normalized_prediction_GANDALF
normalized_prediction_mlp
normalized_prediction_wd
art_norm_prediction_catboost_ranker
art_norm_prediction_catboost_classifier
art_norm_prediction_dcn
art_norm_prediction_GANDALF
art_norm_prediction_mlp
art_norm_prediction_wd
prediction_hybrid
mean_prediction_catboost_ranker
mean_prediction_catboost_classifier
mean_prediction_dcn
mean_prediction_GANDALF
mean_prediction_mlp
mean_prediction_wd
mean_prediction_hybrid
skew_prediction_catboost_ranker
skew_prediction_catboost_classifier
skew_prediction_dcn
skew_prediction_GANDALF
skew_prediction_mlp
skew_prediction_wd
skew_prediction_hybrid
std_prediction_catboost_ranker
std_prediction_catboost_classifier
std_prediction_dcn
std_prediction_GANDALF
std

In [5]:
query = 'delay'

# Print every column whose name contains `query` as a substring.
for matching_name in filter(lambda name: query in name, predictions.columns):
    print(matching_name)

article_delay_days
article_delay_hours
mean_topics_mean_delay_days
mean_topics_mean_delay_hours
user_mean_delay_days
user_mean_delay_hours
article_delay_hours_l_inf_impression
article_delay_hours_minus_median_impression
std_impression_article_delay_hours
skew_impression_article_delay_hours
entropy_impression_article_delay_hours
article_delay_hours_rank_impression
mean_topics_mean_delay_hours_rank_impression
article_delay_hours_l_inf_article


In [72]:
# Hand-tuned post-processing of `prediction`, expressed as an ordered list of
# (row condition, multiplier) rules. Rules are applied one after another, so a
# row matching several conditions is rescaled by each matching multiplier in
# turn. `predictions` itself is left untouched.
prediction_rules = [
    (pl.col('article_delay_days') > 3, 0.5),
    (pl.col('article_delay_hours') > pl.col('user_mean_delay_hours') * 3, 1.2),
    (
        # Rows holding the maximum `normalized_endorsement_10h` of their impression.
        pl.col('normalized_endorsement_10h')
        == pl.col('normalized_endorsement_10h').max().over('impression_id'),
        0.95,
    ),
    (pl.col('category') == 414, 0.80),
    (pl.col('category') == 512, 0.90),
    (pl.col('category') == 561, 1.3),
    (pl.col('category') == 2077, 1.1),
]

adjusted_predictions = predictions
for condition, factor in prediction_rules:
    adjusted_predictions = adjusted_predictions.with_columns(
        pl.when(condition)
        .then(pl.col('prediction') * factor)
        .otherwise(pl.col('prediction'))
        # when/then/otherwise takes its output name from the `then` branch;
        # the alias makes the overwrite of `prediction` explicit.
        .alias('prediction')
    )

get_score(adjusted_predictions)


0.8255250041907629


In [7]:
# NOTE(review): despite the `_train` suffix, both behaviors and history are read
# from the `validation/` directory of the dataset -- the names are kept because
# later cells reference them.
behaviors_train, history_train, articles = (
    pl.read_parquet(parquet_path)
    for parquet_path in (
        '/home/ubuntu/dataset/ebnerd_small/validation/behaviors.parquet',
        '/home/ubuntu/dataset/ebnerd_small/validation/history.parquet',
        '/home/ubuntu/dataset/ebnerd_small/articles.parquet',
    )
)

In [8]:
# Per-user statistics of the pageviews/inviews ratio over the articles in each
# user's history:
#   1. one row per (user, history article),
#   2. attach each article's total inviews/pageviews (nulls counted as 0),
#   3. keep articles with at least one inview so the ratio is defined,
#   4. aggregate the ratio per user (mean, std, median).
hf = (
    history_train
    .select(['user_id', 'article_id_fixed'])
    .explode('article_id_fixed')
    .rename({'article_id_fixed': 'article_id'})
    # Inner join: history articles missing from `articles` are dropped.
    .join(
        articles.select(['article_id', 'total_inviews', 'total_pageviews']).fill_null(0),
        on='article_id',
    )
    .filter(pl.col('total_inviews') > 0)
    .with_columns(
        (pl.col('total_pageviews') / pl.col('total_inviews')).alias('total_pageviews/inviews')
    )
    .group_by('user_id')
    .agg(
        pl.col('total_pageviews/inviews').mean().alias('history_total_pageviews/inviews_avg'),
        pl.col('total_pageviews/inviews').std().alias('history_total_pageviews/inviews_std'),
        pl.col('total_pageviews/inviews').median().alias('history_total_pageviews/inviews_median'),
    )
)

# Keep the cell idempotent: if it has already been run, `predictions` carries
# these feature columns and a second join would add `_right`-suffixed copies.
# Drop any previous copies before joining again.
stale_feature_cols = [
    name for name in hf.columns
    if name != 'user_id' and name in predictions.columns
]
if stale_feature_cols:
    predictions = predictions.drop(stale_feature_cols)

# Left join so users without usable history keep their rows (features are null).
predictions = predictions.join(hf, on='user_id', how='left')

In [14]:
# Show one row per `impression_id`; which row of each impression is kept is
# decided by `unique`'s default `keep` policy.
predictions.unique(subset='impression_id')

impression_id,article,target,prediction,prediction_catboost_ranker,prediction_catboost_classifier,prediction_dcn,prediction_GANDALF,prediction_mlp,prediction_wd,normalized_prediction_catboost_ranker,normalized_prediction_catboost_classifier,normalized_prediction_dcn,normalized_prediction_GANDALF,normalized_prediction_mlp,normalized_prediction_wd,art_norm_prediction_catboost_ranker,art_norm_prediction_catboost_classifier,art_norm_prediction_dcn,art_norm_prediction_GANDALF,art_norm_prediction_mlp,art_norm_prediction_wd,prediction_hybrid,mean_prediction_catboost_ranker,mean_prediction_catboost_classifier,mean_prediction_dcn,mean_prediction_GANDALF,mean_prediction_mlp,mean_prediction_wd,mean_prediction_hybrid,skew_prediction_catboost_ranker,skew_prediction_catboost_classifier,skew_prediction_dcn,skew_prediction_GANDALF,skew_prediction_mlp,skew_prediction_wd,skew_prediction_hybrid,…,std_article_distilbert_emb_icm,std_article_bert_emb_icm,std_article_roberta_emb_icm,std_article_w_2_vec_emb_icm,std_article_emotions_emb_icm,std_article_constrastive_emb_icm,skew_article_kenneth_emb_icm,skew_article_distilbert_emb_icm,skew_article_bert_emb_icm,skew_article_roberta_emb_icm,skew_article_w_2_vec_emb_icm,skew_article_emotions_emb_icm,skew_article_constrastive_emb_icm,kurtosis_article_kenneth_emb_icm,kurtosis_article_distilbert_emb_icm,kurtosis_article_bert_emb_icm,kurtosis_article_roberta_emb_icm,kurtosis_article_w_2_vec_emb_icm,kurtosis_article_emotions_emb_icm,kurtosis_article_constrastive_emb_icm,entropy_article_kenneth_emb_icm,entropy_article_distilbert_emb_icm,entropy_article_bert_emb_icm,entropy_article_roberta_emb_icm,entropy_article_w_2_vec_emb_icm,entropy_article_emotions_emb_icm,entropy_article_constrastive_emb_icm,kenneth_emb_icm_minus_median_article,distilbert_emb_icm_minus_median_article,bert_emb_icm_minus_median_article,roberta_emb_icm_minus_median_article,w_2_vec_emb_icm_minus_median_article,emotions_emb_icm_minus_median_article,constrastive_emb_icm_minus_median_ar
ticle,history_total_pageviews/inviews_avg,history_total_pageviews/inviews_std,history_total_pageviews/inviews_median
u32,i32,i8,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,…,f32,f32,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64
466616468,9788576,1,0.200186,0.410536,0.358013,0.285043,0.108904,0.154077,0.073095,0.502709,0.496884,0.40397,0.0,0.0,0.0,0.35309,0.359364,0.291747,0.108765,0.152183,0.072822,0.334548,0.210867,0.281212,0.263552,0.369086,0.332147,0.236565,0.397691,0.662059,0.867065,0.94269,1.254149,0.672242,0.918029,1.029912,…,9.428566,0.043782,4.034845,0.005602,0.008537,0.002076,1.347945,1.338961,1.225559,1.469521,1.263587,1.310116,1.569947,1.706475,1.734821,1.184969,2.24453,1.28404,1.508709,2.485618,,,,,,,,0.005165,9.932929,0.069678,12.004709,0.019945,0.007979,0.006523,0.226252,0.060943,0.221927
181135546,7184889,0,0.003716,-4.374781,0.00237,0.000196,0.000021,0.000038,0.001082,0.024296,0.0,0.00012,0.0,0.000029,0.0011,0.11619,0.001023,0.000199,0.000022,0.000035,0.001179,0.015831,-1.737497,0.23082,0.247421,0.220219,0.196941,0.23532,0.337074,0.765169,0.964014,0.968717,1.083183,1.085879,0.962499,0.894323,…,8.515834,0.055129,0.949957,0.002568,0.012781,0.000963,1.654848,0.655124,0.737623,2.769909,3.858868,0.79574,2.26755,3.458012,-0.295578,-0.199014,13.924756,23.803068,0.03194,10.87664,,,,,,,,0.000943,1.925414,-0.015148,0.923537,0.00166,-0.002659,0.0,0.21412,0.065238,0.214103
434408561,9777950,0,0.150125,-0.789998,0.195015,0.043196,0.054314,0.113592,0.098158,0.317028,0.28534,0.047383,0.040152,0.132422,0.119598,0.333494,0.195227,0.042895,0.053755,0.113438,0.099996,0.24678,-0.495165,0.194334,0.115816,0.182435,0.192526,0.174167,0.372989,0.490994,1.247005,1.430332,1.346423,1.695936,1.429039,0.833782,…,8.149134,0.049538,4.089873,0.008084,0.018078,0.002521,1.346631,1.073569,1.036617,1.20194,1.232099,1.165708,1.450549,1.745771,0.657796,0.533671,1.237353,1.24167,1.117331,2.403681,,,,,,,,0.001306,0.060627,0.033328,5.540361,0.016311,0.007978,0.002253,0.215749,0.064758,0.219972
455607329,6741781,0,0.005318,-1.580378,0.028869,0.00751,0.000469,0.009125,0.009218,0.215536,0.033673,0.007886,0.000096,0.00937,0.010039,0.504808,0.036559,0.009309,0.00061,0.012503,0.011108,0.143169,0.32635,0.177274,0.198701,0.164337,0.17368,0.176648,0.363934,0.328293,1.429937,1.216671,1.401829,1.346754,1.367407,0.814268,…,6.68057,0.028271,2.233563,0.003277,0.014566,0.001419,1.625104,1.057565,1.149429,2.850819,3.168119,1.125724,4.061207,2.89059,0.495455,0.883546,11.440269,15.620951,0.686762,27.084297,,,,,,,,0.004688,-0.908599,0.00909,0.923592,0.008594,-0.007977,0.001463,0.186736,0.070253,0.200735
542029241,9203696,0,0.024434,-1.544941,0.077473,0.013194,0.008828,0.030792,0.015861,0.0,0.0,0.0,0.0,0.0,0.0,0.416693,0.092309,0.016421,0.010407,0.042032,0.019545,0.0,0.283708,0.279142,0.273703,0.241631,0.314878,0.271652,0.378648,0.725799,1.438303,1.319823,1.517037,0.994715,1.199145,0.94255,…,8.726703,0.04658,2.910646,0.005924,0.01456,0.002754,1.654433,0.861155,0.984398,1.059821,1.440068,0.907932,1.428251,3.625127,-0.154971,0.486249,0.418673,1.828675,0.252385,1.768458,,,,,,,,0.0,-4.521471,-0.024238,-0.923441,-0.000008,-0.009308,-0.000091,0.20975,0.065162,0.216495
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
219090908,9769155,0,0.199436,-0.356438,0.150871,0.339135,0.260861,0.105849,0.195931,0.0,0.0,0.543481,0.341069,0.049129,0.300948,0.326117,0.150002,0.347064,0.262747,0.108091,0.200677,0.040818,0.55917,0.276487,0.303046,0.261586,0.244233,0.268766,0.407667,0.454701,0.614104,0.249865,0.629715,0.69547,0.549454,0.583092,…,9.177736,0.044453,5.211512,0.004433,0.012189,0.003271,1.16413,1.132542,1.107671,1.238481,1.146935,1.144706,1.297057,1.25167,1.02671,0.820553,1.408498,1.019593,0.934026,1.425347,,,,,,,,0.00226,4.774179,0.048476,5.540766,0.000037,0.013297,0.002113,0.211621,0.0695,0.213246
275156541,9777858,0,0.171856,0.289145,0.221737,0.140449,0.124301,0.234315,0.211557,0.645782,0.389328,0.329108,0.253287,0.423265,0.32054,0.396965,0.224412,0.149175,0.1314,0.25749,0.221333,0.555736,0.315597,0.268607,0.152717,0.222952,0.200108,0.241824,0.558484,-0.856843,0.367949,1.076201,0.106896,1.039403,0.965498,-0.796228,…,6.62546,0.040341,2.97747,0.006171,0.008635,0.003008,1.204465,1.277506,1.161931,1.523974,1.378817,1.395276,1.844985,1.138423,1.655059,0.890875,2.261726,1.727611,1.766794,3.763301,,,,,,,,-0.00083,-3.81962,-0.024237,-1.846935,-0.003224,-0.003989,-0.001313,0.217811,0.076293,0.222455
482214189,9366571,0,0.07841,-0.402126,0.054986,0.089655,0.124331,0.046937,0.062027,0.0,0.0,0.0,0.0,0.0,0.0,0.57819,0.066642,0.110551,0.178136,0.071709,0.079293,0.0,1.153503,0.366695,0.333912,0.403939,0.427915,0.346386,0.711964,-1.232293,-0.697271,-0.740019,-1.052811,-1.374655,-0.737767,-1.399009,…,8.395556,0.056226,13.644414,0.022021,0.014398,0.016876,0.657441,0.83283,0.730011,0.677215,0.651346,0.780793,0.676229,-0.367898,0.047627,-0.155785,-0.380409,-0.413521,-0.07283,-0.244502,,,,,,,,0.001998,-0.015533,0.039386,3.694004,0.009987,0.009308,0.000229,0.211346,0.064938,0.211914
438218782,9766007,0,0.039822,-1.561564,0.052188,0.030028,0.029174,0.031817,0.031283,0.065119,0.0,0.019304,0.016057,0.0,0.011708,0.366445,0.193122,0.0452,0.050352,0.055327,0.04773,0.043655,-0.561423,0.344528,0.302639,0.288788,0.331217,0.324667,0.41642,0.43552,0.428817,0.448285,0.398113,0.246602,0.403542,0.423346,…,7.162863,0.053167,8.473217,0.013954,0.017422,0.006111,1.289302,1.167315,1.028914,1.193776,1.12911,1.010507,1.400839,1.233523,0.464261,0.154289,0.733091,0.439969,-0.052226,1.591715,,,,,,,,0.00309,5.882044,0.018178,5.079128,0.008223,0.003989,0.00558,0.2281,0.0525,0.234614


In [64]:
# For each value of `group_key`, compare the sum of `target` with the number of
# rows that hold the highest prediction of their impression (`literal`). Only
# groups with more than 20 positives whose two counts differ by more than 10%
# are kept, sorted by ratio = literal / target (ascending).
group_key = 'article_delay_days'  # 'category' is also selected below and can be used instead

is_top_prediction = (
    # Ties for the maximum flag every tied row of the impression.
    pl.when(pl.col('prediction') == pl.col('prediction').max().over('impression_id'))
    .then(1)
    .otherwise(0)
    # An un-aliased literal `then` branch is named 'literal'; state it
    # explicitly because the aggregation below refers to that name.
    .alias('literal')
)

test = (
    predictions
    .select(['impression_id', 'article_delay_days', 'category', 'target', 'prediction'])
    .with_columns(is_top_prediction)
    .group_by(group_key)
    .agg(
        pl.col('target').sum(),
        pl.col('literal').sum(),
    )
    .filter(pl.col('target') > 20)
    .filter(
        (pl.col('target') * 1.1 < pl.col('literal'))
        | (pl.col('target') * 0.9 > pl.col('literal'))
    )
    .with_columns(
        (pl.col('literal') / pl.col('target')).alias('ratio')
    )
    .sort(by='ratio', descending=False)
)

# Lowest and highest ratios. `tail(10)` replaces a hard-coded `test[52:]`,
# which only showed the last ten rows when the frame had exactly 62 rows.
print(test.head(10))
print(test.tail(10))

shape: (10, 4)
┌────────────────────┬────────┬─────────┬──────────┐
│ article_delay_days ┆ target ┆ literal ┆ ratio    │
│ ---                ┆ ---    ┆ ---     ┆ ---      │
│ i16                ┆ i64    ┆ i32     ┆ f64      │
╞════════════════════╪════════╪═════════╪══════════╡
│ 353                ┆ 80     ┆ 0       ┆ 0.0      │
│ 352                ┆ 81     ┆ 1       ┆ 0.012346 │
│ 618                ┆ 52     ┆ 1       ┆ 0.019231 │
│ 39                 ┆ 100    ┆ 3       ┆ 0.03     │
│ 474                ┆ 65     ┆ 2       ┆ 0.030769 │
│ 38                 ┆ 23     ┆ 1       ┆ 0.043478 │
│ 296                ┆ 243    ┆ 11      ┆ 0.045267 │
│ 295                ┆ 228    ┆ 11      ┆ 0.048246 │
│ 879                ┆ 164    ┆ 9       ┆ 0.054878 │
│ 40                 ┆ 68     ┆ 4       ┆ 0.058824 │
└────────────────────┴────────┴─────────┴──────────┘
shape: (10, 4)
┌────────────────────┬────────┬─────────┬──────────┐
│ article_delay_days ┆ target ┆ literal ┆ ratio    │
│ ---           

In [37]:
# Print `test` from row index 10 to the end.
# NOTE(review): the saved output of this cell has `category`/`target`/`literal`
# columns only, so it appears to come from an earlier definition of `test` --
# re-run after rebuilding `test` to refresh it.
print(test.slice(10))

shape: (9, 3)
┌──────────┬────────┬─────────┐
│ category ┆ target ┆ literal │
│ ---      ┆ ---    ┆ ---     │
│ i16      ┆ i64    ┆ i32     │
╞══════════╪════════╪═════════╡
│ 572      ┆ 224    ┆ 198     │
│ 2975     ┆ 9991   ┆ 8638    │
│ 2        ┆ 0      ┆ 0       │
│ 2737     ┆ 1      ┆ 0       │
│ 2889     ┆ 1      ┆ 0       │
│ 512      ┆ 10297  ┆ 12429   │
│ 565      ┆ 1301   ┆ 1165    │
│ 498      ┆ 10901  ┆ 9745    │
│ 529      ┆ 68     ┆ 52      │
└──────────┴────────┴─────────┘
