In [1]:
import polars as pl
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
import numpy as np
import logging
import random

# Quieten TensorFlow first: no autograph chatter, Python-side logger at ERROR only.
tf.get_logger().setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)

# One shared seed for every random number generator used in this notebook
# (NumPy, the stdlib `random` module, TensorFlow).
seed = 42
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)

print(tf.__version__)

2024-06-18 20:42:06.843321: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-18 20:42:06.906157: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


2.16.1


In [2]:
# Raw inputs. `b_sub` is the sub-sampled training set the embeddings are computed for.
b_sub = pl.read_parquet('/home/ubuntu/dset_complete/subsample/train_ds.parquet')

# Article metadata and the pre-training click history, both from the *small* split.
articles = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/articles.parquet')
history = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/train/history.parquet')

# NOTE(review): `behaviors` is read from the *large* split while the other frames
# come from the small one, and it is not referenced again in the cells visible
# here -- confirm whether it is still needed.
behaviors = pl.read_parquet('/home/ubuntu/dataset/ebnerd_large/train/behaviors.parquet')

history.head(2)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
13538,"[2023-04-27 10:17:43, 2023-04-27 10:18:01, … 2023-05-17 20:36:34]","[100.0, 35.0, … 100.0]","[9738663, 9738569, … 9769366]","[17.0, 12.0, … 16.0]"
14241,"[2023-04-27 09:40:18, 2023-04-27 09:40:33, … 2023-05-17 17:08:41]","[100.0, 46.0, … 100.0]","[9738557, 9738528, … 9767852]","[8.0, 9.0, … 12.0]"


In [3]:
# Build a per-user click history out of the training impressions themselves:
# keep clicked rows only (target == 1), order them chronologically, gather every
# remaining column into one list per user, and rename the columns to the
# `*_fixed` names that `history` uses.
history_sub = (
    b_sub
    .filter(pl.col('target') == 1)
    .select('user_id', 'impression_time', 'scroll_percentage', 'article', 'read_time')
    .sort('impression_time')
    .group_by('user_id', maintain_order=True)
    .agg(pl.all())
    .rename({
        'impression_time': 'impression_time_fixed',
        'scroll_percentage': 'scroll_percentage_fixed',
        'article': 'article_id_fixed',
        'read_time': 'read_time_fixed',
    })
)
history_sub.head(1)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
1260010,"[2023-05-18 07:00:01, 2023-05-19 06:46:44, … 2023-05-25 06:26:48]","[null, null, … null]","[9767697, 9772088, … 9780195]","[38.0, 64.0, … 42.0]"


In [4]:
# Append each user's in-training clicks (`history_sub`) after their existing history.
# The join uses polars' default strategy, so only users present in both frames remain.
history_all = (
    history
    .join(history_sub, on='user_id', suffix='_r')
    .with_columns(
        *[
            # old list followed by the new one, stored back under the original name
            pl.concat_list([name, f'{name}_r']).alias(name)
            for name in history.columns
            if name != 'user_id'
        ]
    )
    .drop([f'{name}_r' for name in history.columns if name != 'user_id'])
)
history_all.head(1)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
1260010,"[2023-04-28 06:01:24, 2023-04-28 06:01:55, … 2023-05-25 06:26:48]","[null, 87.0, … null]","[9739837, 9739888, … 9780195]","[18.0, 20.0, … 42.0]"


In [5]:
# For every user, number their impressions 0, 1, 2, ... in chronological order.
# The resulting `index` column is the impression's rank within that user.
df_order = (
    b_sub
    .select('impression_id', 'user_id', 'impression_time')
    .unique(['impression_id', 'user_id'], keep='first')
    .sort('impression_time')
    .drop('impression_time')
    .group_by('user_id', maintain_order=True)
    .map_groups(lambda user_rows: user_rows.with_row_index())
)
df_order.head(1)

index,impression_id,user_id
u32,u32,u32
0,41650737,1260010


In [6]:
# Per-impression rank within the user (from df_order) and per-user length of the
# pre-existing history; their sum is where this impression's history ends inside
# the concatenated lists.
impression_rank = df_order.rename({'index': 'history_pos'})
history_len = history.select(
    'user_id',
    pl.col('article_id_fixed').list.len().alias('history_l'),
)

b_sub = (
    b_sub
    .join(impression_rank, on=['user_id', 'impression_id'], how='left')
    .join(history_len, on='user_id', how='left')
    .with_columns(
        (pl.col('history_pos') + pl.col('history_l')).alias('history_all_end_pos')
    )
    .drop('history_pos')
    .drop('history_l')
)
b_sub.head(1)

impression_id,user_id,article,target,device_type,read_time,scroll_percentage,is_sso_user,gender,age,is_subscriber,postcode,trendiness_score_1d,trendiness_score_3d,trendiness_score_5d,trendiness_score_3d_leak,weekday,hour,trendiness_score_1d/3d,trendiness_score_1d/5d,normalized_trendiness_score_overall,premium,category,sentiment_score,sentiment_label,num_images,title_len,subtitle_len,body_len,num_topics,total_pageviews,total_inviews,total_read_time,total_pageviews/inviews,article_type,article_delay_days,article_delay_hours,…,std_article_kenneth_emb_icm,std_article_distilbert_emb_icm,std_article_bert_emb_icm,std_article_roberta_emb_icm,std_article_w_to_vec_emb_icm,std_article_emotions_emb_icm,std_article_constrastive_emb_icm,skew_article_kenneth_emb_icm,skew_article_distilbert_emb_icm,skew_article_bert_emb_icm,skew_article_roberta_emb_icm,skew_article_w_to_vec_emb_icm,skew_article_emotions_emb_icm,skew_article_constrastive_emb_icm,kurtosis_article_kenneth_emb_icm,kurtosis_article_distilbert_emb_icm,kurtosis_article_bert_emb_icm,kurtosis_article_roberta_emb_icm,kurtosis_article_w_to_vec_emb_icm,kurtosis_article_emotions_emb_icm,kurtosis_article_constrastive_emb_icm,entropy_article_kenneth_emb_icm,entropy_article_distilbert_emb_icm,entropy_article_bert_emb_icm,entropy_article_roberta_emb_icm,entropy_article_w_to_vec_emb_icm,entropy_article_emotions_emb_icm,entropy_article_constrastive_emb_icm,kenneth_emb_icm_minus_median_article,distilbert_emb_icm_minus_median_article,bert_emb_icm_minus_median_article,roberta_emb_icm_minus_median_article,w_to_vec_emb_icm_minus_median_article,emotions_emb_icm_minus_median_article,constrastive_emb_icm_minus_median_article,impression_time,history_all_end_pos
u32,u32,i32,i8,i8,f32,f32,bool,i8,i8,bool,i8,i16,i16,i16,i16,i8,i8,f32,f32,f32,bool,i16,f32,str,u32,u8,u8,u16,u32,i32,i32,f32,f32,str,i16,i32,…,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,datetime[μs],u32
149474,139836,9778728,0,2,13.0,,False,2,,False,5,150,521,836,419,3,7,0.287908,0.179426,0.880068,False,142,0.9654,"""Negative""",1,5,18,251,7,22415,220247,1004828.0,0.101772,"""article_default""",0,0,…,0.003313,0.010775,0.020001,6.258241,0.005474,4.929348,0.062682,1.300164,2.060066,1.152196,1.754251,2.088241,1.214502,1.406736,1.339344,5.427761,1.025453,3.529418,5.37705,1.054931,1.963193,,,,,,,,-0.002187,-0.004914,-0.014626,-3.694031,-0.002645,-3.802204,-0.039383,2023-05-24 07:47:53,26


In [7]:
history_cols = [col for col in history.columns if col != 'user_id']
window = 20

# Exclusive end of the history visible to each impression. The column is unsigned,
# so cast it to a signed type before subtracting: otherwise `end - window` cannot
# go negative and the clip below would not protect short histories.
end_pos = pl.col('history_all_end_pos').cast(pl.Int64)
start_pos = (end_pos - window).clip(lower_bound=0)

# Take at most `window` items ending at `end_pos`. The slice length is
# `end_pos - start_pos` rather than a fixed `window`: with a fixed length, a user
# whose end position is below `window` would get items located *after* `end_pos`,
# i.e. clicks from the current and future impressions.
history_f = (
    b_sub
    .join(history_all, on='user_id', how='left')
    .with_columns(
        pl.col(history_cols).list.slice(start_pos, end_pos - start_pos).name.keep()
    )
    .select('impression_id', 'user_id', 'article', *history_cols)
    # `index` becomes the per-row key used to join the embeddings back later;
    # only the first 100 rows are kept.
    .with_row_index()[:100]
)
history_f.head(2)

index,impression_id,user_id,article,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,u32,u32,i32,list[datetime[μs]],list[f32],list[i32],list[f32]
0,149474,139836,9778728,"[2023-05-05 10:09:28, 2023-05-05 10:09:37, … 2023-05-20 15:22:31]","[39.0, 29.0, … null]","[9750829, 9750793, … 9771113]","[4.0, 8.0, … 7.0]"
1,149474,139836,9778669,"[2023-05-05 10:09:28, 2023-05-05 10:09:37, … 2023-05-20 15:22:31]","[39.0, 29.0, … null]","[9750829, 9750793, … 9771113]","[4.0, 8.0, … 7.0]"


In [8]:
from polimi.utils.tf_models.utils.build_sequences import build_history_seq, build_sequences_seq_iterator, N_CATEGORY, N_SENTIMENT_LABEL, N_SUBCATEGORY, N_TOPICS, N_HOUR_GROUP, N_WEEKDAY
from polimi.utils.tf_models import TemporalHistorySequenceModel
import joblib
import tensorflow as tf

# Per-input embedding configuration as 3-tuples. The first element is built from
# the vocabulary-size constants (one extra slot is added to cover missings, where
# needed).
# NOTE(review): the meaning of the second and third elements is defined by
# TemporalHistorySequenceModel -- presumably (embedding dim, multi-hot flag); confirm.
seq_embedding_dims = {
    'input_topics': (N_TOPICS + 1, 10, True),
    'input_subcategory': (N_SUBCATEGORY + 1, 10, True),
    'input_category': (N_CATEGORY + 1, 10, False),
    'input_weekday': (N_WEEKDAY, 3, False),
    'input_hour_group': (N_HOUR_GROUP, 3, False),
    'input_sentiment_label': (N_SENTIMENT_LABEL + 1, 2, False),
}

model = TemporalHistorySequenceModel(
    seq_embedding_dims=seq_embedding_dims,
    seq_numerical_features=['scroll_percentage', 'read_time', 'premium'],
    n_recurrent_layers=1,
    recurrent_embedding_dim=64,
    l1_lambda=1e-4,
    l2_lambda=1e-4,
)

# Build the underlying Keras model, then restore the trained weights into it.
model._build()
model.model.load_weights(
    '/home/ubuntu/experiments/rnn_seq_2024-06-18_18-22-14/checkpoints/checkpoint.weights.h5'
)

In [9]:
# Build a second Keras model that shares the trained model's inputs and weights
# but stops at the GRU, so it outputs the recurrent representation instead of
# the final prediction.

# Symbolic output of the layer named 'concatenate' in the trained model.
concatenate_layer = model.model.get_layer('concatenate').output

# Reuse the trained GRU layer itself (same weights). `return_state` is switched on
# *before* the layer is called again, so the new call also returns the final state;
# the prediction cell below unpacks two arrays from this model.
gru_layer = model.model.get_layer('gru')
gru_layer.return_state = True

# Re-apply the GRU to the concatenated features to obtain the new output tensors.
output = gru_layer(concatenate_layer)

embedding_model = tf.keras.Model(inputs=model.model.inputs, outputs=output)
embedding_model.summary()

In [10]:
# Key every row by its unique row `index` (exposed under the name 'user_id')
# together with its history lists, so each (impression, article) row is treated
# as its own "user" by the sequence builder.
history_seq_final = (
    history_f
    .select('index', *history_cols)
    .rename({'index': 'user_id'})
)
history_seq_final.head(2)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
0,"[2023-05-05 10:09:28, 2023-05-05 10:09:37, … 2023-05-20 15:22:31]","[39.0, 29.0, … null]","[9750829, 9750793, … 9771113]","[4.0, 8.0, … 7.0]"
1,"[2023-05-05 10:09:28, 2023-05-05 10:09:37, … 2023-05-20 15:22:31]","[39.0, 29.0, … null]","[9750829, 9750793, … 9771113]","[4.0, 8.0, … 7.0]"


In [11]:
# Turn the raw history lists into the per-step feature lists shown in the output
# below (category, hour_group, topics_*, subcategory_*, ...).
# NOTE(review): presumably the article ids are looked up in `articles` to obtain
# these features; the exact behaviour is defined by `build_history_seq` -- confirm there.
history_seq = build_history_seq(history_seq_final, articles)
history_seq.head(1)

user_id,category,hour_group,impression_time_fixed,premium,read_time,scroll_percentage,sentiment_label,weekday,topics_0,topics_1,topics_2,topics_3,topics_4,topics_5,topics_6,topics_7,topics_8,topics_9,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,…,subcategory_226,subcategory_227,subcategory_228,subcategory_229,subcategory_230,subcategory_231,subcategory_232,subcategory_233,subcategory_234,subcategory_235,subcategory_236,subcategory_237,subcategory_238,subcategory_239,subcategory_240,subcategory_241,subcategory_242,subcategory_243,subcategory_244,subcategory_245,subcategory_246,subcategory_247,subcategory_248,subcategory_249,subcategory_250,subcategory_251,subcategory_252,subcategory_253,subcategory_254,subcategory_255,subcategory_256,subcategory_257,subcategory_258,subcategory_259,subcategory_260,subcategory_261,subcategory_262
u32,list[i8],list[i8],list[datetime[μs]],list[i8],list[f32],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
0,"[8, 4, … 5]","[2, 2, … 3]","[2023-05-05 10:09:28, 2023-05-05 10:09:37, … 2023-05-20 15:22:31]","[0, 0, … 1]","[4.0, 8.0, … 7.0]","[39.0, 29.0, … 0.0]","[1, 2, … 1]","[4, 4, … 5]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [13]:
# Left-pad every list column with zeros so that each sequence has exactly
# `window` steps:
#   1. reverse the list,
#   2. append `window` zeros to it (extend_constant),
#   3. reverse it back, which moves the zeros to the front,
#   4. keep only the last `window` elements.
# NOTE(review): `pl.all().exclude('user_id')` also covers the datetime column
# `impression_time_fixed`, which is therefore padded with 0 as well -- confirm
# that this is intended.
history_seq_trucated = history_seq.with_columns(
    pl.all().exclude('user_id').list.reverse().list.eval(pl.element().extend_constant(0, window)).list.reverse().list.tail(window).name.keep()
)
history_seq_trucated.head(2)

user_id,category,hour_group,impression_time_fixed,premium,read_time,scroll_percentage,sentiment_label,weekday,topics_0,topics_1,topics_2,topics_3,topics_4,topics_5,topics_6,topics_7,topics_8,topics_9,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,…,subcategory_226,subcategory_227,subcategory_228,subcategory_229,subcategory_230,subcategory_231,subcategory_232,subcategory_233,subcategory_234,subcategory_235,subcategory_236,subcategory_237,subcategory_238,subcategory_239,subcategory_240,subcategory_241,subcategory_242,subcategory_243,subcategory_244,subcategory_245,subcategory_246,subcategory_247,subcategory_248,subcategory_249,subcategory_250,subcategory_251,subcategory_252,subcategory_253,subcategory_254,subcategory_255,subcategory_256,subcategory_257,subcategory_258,subcategory_259,subcategory_260,subcategory_261,subcategory_262
u32,list[i8],list[i8],list[datetime[μs]],list[i8],list[f32],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
0,"[8, 4, … 5]","[2, 2, … 3]","[2023-05-05 10:09:28, 2023-05-05 10:09:37, … 2023-05-20 15:22:31]","[0, 0, … 1]","[4.0, 8.0, … 7.0]","[39.0, 29.0, … 0.0]","[1, 2, … 1]","[4, 4, … 5]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"
1,"[8, 4, … 5]","[2, 2, … 3]","[2023-05-05 10:09:28, 2023-05-05 10:09:37, … 2023-05-20 15:22:31]","[0, 0, … 1]","[4.0, 8.0, … 7.0]","[39.0, 29.0, … 0.0]","[1, 2, … 1]","[4, 4, … 5]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [20]:
import numpy as np

# Feature groups; the generator below emits one `input_<group>` array per entry
# of `name_idx_dict`.
multi_one_hot_cols = ['topics', 'subcategory']
categorical_cols = ['category', 'weekday', 'hour_group', 'sentiment_label']
numerical_cols = ['scroll_percentage', 'read_time', 'premium']

# Hard-coded number of classes per categorical feature. To derive them from the
# data instead, use:
# caterical_cols_num_classes = {key: history_seq[key].explode().max() + 1 for key in categorical_cols}
caterical_cols_num_classes = dict(
    category=N_CATEGORY + 1,  # +1 to handle null values
    weekday=N_WEEKDAY,
    hour_group=N_HOUR_GROUP,
    sentiment_label=N_SENTIMENT_LABEL + 1,  # +1 to handle null
)

# Column positions are relative to `history_seq` without 'user_id' (this could be
# hard-coded if needed). Every group maps to the positions of the columns whose
# name starts with the group name.
all_features = history_seq.drop('user_id').columns
name_idx_dict = {
    prefix: [pos for pos, name in enumerate(all_features) if name.startswith(prefix)]
    for prefix in multi_one_hot_cols + categorical_cols
}
name_idx_dict['numerical'] = [
    pos for pos, name in enumerate(all_features) if name in numerical_cols
]

def last_history_window_generator(history_seq_trucated):
    """Yield one dictionary of model inputs per row of the padded history frame.

    The frame is split by 'user_id' (input order preserved). For each partition
    the first row's list columns are stacked into a (n_features, n_steps) array,
    and for every group in the module-level `name_idx_dict` the rows at that
    group's positions are selected and transposed to (n_steps, n_group_features).

    Args:
        history_seq_trucated: frame with a 'user_id' column plus one list column
            per feature, in the column order `name_idx_dict` was built from.

    Yields:
        dict mapping 'input_<group>' to a NumPy array: float32 for the
        'numerical' group, int16 for every other group.
    """
    for user_history in history_seq_trucated.partition_by(['user_id'], maintain_order=True):
        # Only the first row of the partition is used.
        x = user_history.drop('user_id').to_numpy()[0]
        x = np.array([np.array(x_i) for x_i in x])
        res_x = {}
        for key, idx in name_idx_dict.items():
            # The numerical group is stored under the key 'numerical'. The previous
            # test (`key in numerical_cols`) compared that key against the feature
            # names, never matched, and so cast the numerical features to int16.
            dtype = np.float32 if key == 'numerical' else np.int16
            res_x[f'input_{key}'] = x[idx, :].T.astype(dtype)

        yield res_x

In [21]:
mask = 0

# Shapes and dtypes of the arrays produced by the generator, one entry per model input.
generator_signature = {
    # history topics sequence
    'input_topics': tf.TensorSpec(shape=(window, N_TOPICS+1), dtype=tf.int16),
    # history category sequence
    'input_category': tf.TensorSpec(shape=(window, 1), dtype=tf.int16),
    # history subcategory sequence
    'input_subcategory': tf.TensorSpec(shape=(window, N_SUBCATEGORY+1), dtype=tf.int16),
    # history weekday sequence
    'input_weekday': tf.TensorSpec(shape=(window, 1), dtype=tf.int16),
    # history hour_group sequence
    'input_hour_group': tf.TensorSpec(shape=(window, 1), dtype=tf.int16),
    # history sentiment_label sequence
    'input_sentiment_label': tf.TensorSpec(shape=(window, 1), dtype=tf.int16),
    # history (premium, read_time, scroll_percentage) sequence
    'input_numerical': tf.TensorSpec(shape=(window, 3), dtype=tf.float32),
}

# Stream the padded histories through the generator in batches of 512 rows.
inference_dataset = tf.data.Dataset.from_generator(
    lambda: last_history_window_generator(history_seq_trucated),
    output_signature=generator_signature,
).batch(512)

# The embedding model returns two arrays: the GRU output and its final state.
sequence_embeddings, state_embeddings = embedding_model.predict(inference_dataset)
sequence_embeddings.shape, state_embeddings.shape

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1s/step


2024-06-18 20:44:55.076197: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
  self.gen.throw(typ, value, traceback)


((100, 64), (100, 64))

In [28]:
# One column per dimension of the GRU final state. The names are sized from
# `state_embeddings` itself (the array that is actually stored) rather than from
# `sequence_embeddings`.
embedding_cols = [f'user_embedding_{i}' for i in range(state_embeddings.shape[1])]

# Predictions come back in the row order of `history_seq_trucated`, so the keys and
# the embeddings can be stacked side by side. `orient='row'` states explicitly that
# each row of the array is one record, instead of leaving the orientation to be
# inferred from the array shape. The key is renamed back to 'index' to match the
# row index column of `history_f`.
user_embeddings_train = (
    history_seq_trucated
    .select('user_id')
    .hstack(pl.DataFrame(state_embeddings, schema=embedding_cols, orient='row'))
    .rename({'user_id': 'index'})
)
user_embeddings_train.head(3)

index,user_embedding_0,user_embedding_1,user_embedding_2,user_embedding_3,user_embedding_4,user_embedding_5,user_embedding_6,user_embedding_7,user_embedding_8,user_embedding_9,user_embedding_10,user_embedding_11,user_embedding_12,user_embedding_13,user_embedding_14,user_embedding_15,user_embedding_16,user_embedding_17,user_embedding_18,user_embedding_19,user_embedding_20,user_embedding_21,user_embedding_22,user_embedding_23,user_embedding_24,user_embedding_25,user_embedding_26,user_embedding_27,user_embedding_28,user_embedding_29,user_embedding_30,user_embedding_31,user_embedding_32,user_embedding_33,user_embedding_34,user_embedding_35,user_embedding_36,user_embedding_37,user_embedding_38,user_embedding_39,user_embedding_40,user_embedding_41,user_embedding_42,user_embedding_43,user_embedding_44,user_embedding_45,user_embedding_46,user_embedding_47,user_embedding_48,user_embedding_49,user_embedding_50,user_embedding_51,user_embedding_52,user_embedding_53,user_embedding_54,user_embedding_55,user_embedding_56,user_embedding_57,user_embedding_58,user_embedding_59,user_embedding_60,user_embedding_61,user_embedding_62,user_embedding_63
u32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
0,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
1,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
2,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914


In [34]:
# Attach the embeddings to their (impression, user, article) row through the row
# index, then discard the index.
# NOTE: this reassigns `history_f` and drops 'index', so running the cell a second
# time fails on the `select` -- re-run the cell that builds `history_f` first.
history_f = (
    history_f
    .select('index', 'impression_id', 'user_id', 'article')
    .join(user_embeddings_train, on='index', how='left')
    .drop('index')
)

In [35]:
# Display the keyed embeddings (impression_id, user_id, article + user_embedding_*).
history_f

impression_id,user_id,article,user_embedding_0,user_embedding_1,user_embedding_2,user_embedding_3,user_embedding_4,user_embedding_5,user_embedding_6,user_embedding_7,user_embedding_8,user_embedding_9,user_embedding_10,user_embedding_11,user_embedding_12,user_embedding_13,user_embedding_14,user_embedding_15,user_embedding_16,user_embedding_17,user_embedding_18,user_embedding_19,user_embedding_20,user_embedding_21,user_embedding_22,user_embedding_23,user_embedding_24,user_embedding_25,user_embedding_26,user_embedding_27,user_embedding_28,user_embedding_29,user_embedding_30,user_embedding_31,user_embedding_32,user_embedding_33,user_embedding_34,user_embedding_35,user_embedding_36,user_embedding_37,user_embedding_38,user_embedding_39,user_embedding_40,user_embedding_41,user_embedding_42,user_embedding_43,user_embedding_44,user_embedding_45,user_embedding_46,user_embedding_47,user_embedding_48,user_embedding_49,user_embedding_50,user_embedding_51,user_embedding_52,user_embedding_53,user_embedding_54,user_embedding_55,user_embedding_56,user_embedding_57,user_embedding_58,user_embedding_59,user_embedding_60,user_embedding_61,user_embedding_62,user_embedding_63
u32,u32,i32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
149474,139836,9778728,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
149474,139836,9778669,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
149474,139836,9778657,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
150528,143471,9778682,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
150528,143471,9778669,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
163881,545277,9776047,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
164059,545856,9773392,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
164059,545856,9776442,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
164059,545856,9776190,1.0,0.999873,0.999965,-0.925864,-0.832986,0.97016,-0.935581,0.999996,-0.999901,0.999968,-0.441256,0.999989,-0.789612,0.999999,0.999977,0.999941,-0.999994,-0.999951,-0.999313,0.602215,-0.999608,-0.999913,-0.99695,1.0,-0.999959,0.920291,-0.9999,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914


In [36]:
from pathlib import Path


# Reload the untouched training set and attach the embeddings by
# (user_id, impression_id, article). It is a left join, so rows without a computed
# embedding keep null `user_embedding_*` values. The result is only displayed, not stored.
train_ds = pl.read_parquet('/home/ubuntu/dset_complete/subsample/train_ds.parquet')
train_ds.join(
    history_f,
    on=['user_id', 'impression_id', 'article'],
    how='left',
)

impression_id,user_id,article,target,device_type,read_time,scroll_percentage,is_sso_user,gender,age,is_subscriber,postcode,trendiness_score_1d,trendiness_score_3d,trendiness_score_5d,trendiness_score_3d_leak,weekday,hour,trendiness_score_1d/3d,trendiness_score_1d/5d,normalized_trendiness_score_overall,premium,category,sentiment_score,sentiment_label,num_images,title_len,subtitle_len,body_len,num_topics,total_pageviews,total_inviews,total_read_time,total_pageviews/inviews,article_type,article_delay_days,article_delay_hours,…,user_embedding_27,user_embedding_28,user_embedding_29,user_embedding_30,user_embedding_31,user_embedding_32,user_embedding_33,user_embedding_34,user_embedding_35,user_embedding_36,user_embedding_37,user_embedding_38,user_embedding_39,user_embedding_40,user_embedding_41,user_embedding_42,user_embedding_43,user_embedding_44,user_embedding_45,user_embedding_46,user_embedding_47,user_embedding_48,user_embedding_49,user_embedding_50,user_embedding_51,user_embedding_52,user_embedding_53,user_embedding_54,user_embedding_55,user_embedding_56,user_embedding_57,user_embedding_58,user_embedding_59,user_embedding_60,user_embedding_61,user_embedding_62,user_embedding_63
u32,u32,i32,i8,i8,f32,f32,bool,i8,i8,bool,i8,i16,i16,i16,i16,i8,i8,f32,f32,f32,bool,i16,f32,str,u32,u8,u8,u16,u32,i32,i32,f32,f32,str,i16,i32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
149474,139836,9778728,0,2,13.0,,false,2,,false,5,150,521,836,419,3,7,0.287908,0.179426,0.880068,false,142,0.9654,"""Negative""",1,5,18,251,7,22415,220247,1.004828e6,0.101772,"""article_default""",0,0,…,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
149474,139836,9778669,0,2,13.0,,false,2,,false,5,85,199,313,266,3,7,0.427136,0.271565,0.336149,false,118,0.9481,"""Negative""",1,5,11,150,4,74491,373488,4.365609e6,0.199447,"""article_default""",0,1,…,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
149474,139836,9778657,1,2,13.0,,false,2,,false,5,45,117,183,138,3,7,0.384615,0.245902,0.197635,false,118,0.8347,"""Neutral""",2,6,31,336,3,108389,478098,7.606737e6,0.226709,"""article_default""",0,1,…,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
150528,143471,9778682,0,2,25.0,,false,2,,false,5,69,206,334,201,3,7,0.334951,0.206587,0.347973,false,498,0.9546,"""Negative""",1,5,20,267,3,143520,455723,9.298546e6,0.314928,"""article_default""",0,1,…,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
150528,143471,9778669,0,2,25.0,,false,2,,false,5,85,199,313,266,3,7,0.427136,0.271565,0.336149,false,118,0.9481,"""Negative""",1,5,11,150,4,74491,373488,4.365609e6,0.199447,"""article_default""",0,1,…,0.999965,0.999977,0.999954,-0.999992,-0.999728,0.999771,-0.99997,-0.973959,-0.99998,-0.999961,-0.999997,-0.999987,0.999825,-0.999978,-0.999271,0.999996,0.53861,1.0,0.999927,0.825415,0.084491,0.999998,0.508373,0.999949,-0.8903,-0.999646,0.644044,-0.9973,0.999645,-0.999996,0.999428,-0.952411,-0.999984,0.999956,0.940429,0.999929,-0.999914
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
580100695,2110744,9769917,0,1,5.0,100.0,false,2,,false,5,46,105,152,76,4,10,0.438095,0.302632,0.203883,true,140,0.989,"""Negative""",4,5,32,826,2,203222,2163455,1.2661448e7,0.093934,"""article_default""",0,17,…,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
580100695,2110744,9767697,1,1,5.0,100.0,false,2,,false,5,50,187,238,116,4,10,0.26738,0.210084,0.363107,false,118,0.9613,"""Negative""",5,7,2,982,3,199205,954408,2.595362e7,0.208721,"""article_default""",0,3,…,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
580100697,2110744,9770997,0,1,14.0,100.0,false,2,,false,5,32,78,136,54,4,10,0.410256,0.235294,0.151456,false,414,0.845,"""Positive""",1,5,18,164,4,110632,485698,5.034287e6,0.227779,"""article_default""",0,3,…,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
580100697,2110744,9514481,0,1,14.0,100.0,false,2,,false,5,8,37,49,34,4,10,0.216216,0.163265,0.071845,true,414,0.9501,"""Neutral""",7,9,30,371,3,,,,,"""article_standard_feature""",182,4390,…,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
