In [90]:
import polars as pl
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
import numpy as np
import logging
import random

seed = 42
np.random.seed(seed)
random.seed(seed)
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel(logging.ERROR)
tf.random.set_seed(seed)
print(tf.__version__)

2.16.1


In [91]:
history = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/train/history.parquet')
behaviors = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/train/behaviors.parquet')
articles = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/articles.parquet')
history.head(2)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
13538,"[2023-04-27 10:17:43, 2023-04-27 10:18:01, … 2023-05-17 20:36:34]","[100.0, 35.0, … 100.0]","[9738663, 9738569, … 9769366]","[17.0, 12.0, … 16.0]"
14241,"[2023-04-27 09:40:18, 2023-04-27 09:40:33, … 2023-05-17 17:08:41]","[100.0, 46.0, … 100.0]","[9738557, 9738528, … 9767852]","[8.0, 9.0, … 12.0]"


In [92]:
topics = articles['topics'].explode().unique().drop_nans().drop_nulls().sort().to_frame().with_row_index()
category = articles['category'].unique().drop_nans().drop_nulls().sort().to_frame().with_row_index(offset=1)
subcategory = articles['subcategory'].explode().unique().drop_nans().drop_nulls().sort().to_frame().with_row_index()
sentiment_label = articles['sentiment_label'].explode().unique().drop_nans().drop_nulls().sort().to_frame().with_row_index(offset=1)
mask = 0

articles = articles.select(['article_id', 'category', 'subcategory', 'premium', 'topics', 'sentiment_label'])\
    .with_columns(
        pl.col('topics').fill_null(pl.lit([])),
        pl.col('subcategory').fill_null(pl.lit([]))
    )\
    .with_columns(
        pl.col('topics').list.eval(pl.element().replace(topics['topics'], topics['index'], default=None)).list.drop_nulls().cast(pl.List(pl.Int32)),
        pl.col('category').replace(category['category'], category['index'], default=None).fill_null(mask).cast(pl.Int32),
        pl.col('sentiment_label').replace(sentiment_label['sentiment_label'], sentiment_label['index'], default=None).fill_null(mask).cast(pl.Int32),
        pl.col('subcategory').list.eval(pl.element().replace(subcategory['subcategory'], subcategory['index'], default=None)).list.drop_nulls().cast(pl.List(pl.Int32)),
        pl.col('premium').cast(pl.Int8)
)
articles.head(2)

article_id,category,subcategory,premium,topics,sentiment_label
i32,i32,list[i32],i8,list[i32],i32
3001353,5,[],0,"[25, 47]",1
3003065,7,"[87, 88]",0,"[69, 13, 77]",3


In [93]:
dummies_topics = articles.select('article_id', 'topics').explode('topics').drop_nulls().to_dummies(columns=['topics'])\
    .group_by('article_id').agg(pl.all().sum())
dummies_subcategories = articles.select('article_id', 'subcategory').explode('subcategory').drop_nulls().to_dummies(columns=['subcategory'])\
    .group_by('article_id').agg(pl.all().sum())
    
dummies_subcategories.head(2)

article_id,subcategory_0,subcategory_1,subcategory_10,subcategory_100,subcategory_101,subcategory_102,subcategory_103,subcategory_104,subcategory_105,subcategory_106,subcategory_107,subcategory_108,subcategory_109,subcategory_11,subcategory_110,subcategory_111,subcategory_112,subcategory_113,subcategory_114,subcategory_115,subcategory_116,subcategory_117,subcategory_118,subcategory_119,subcategory_12,subcategory_120,subcategory_121,subcategory_122,subcategory_123,subcategory_124,subcategory_125,subcategory_126,subcategory_127,subcategory_128,subcategory_129,subcategory_13,…,subcategory_66,subcategory_67,subcategory_68,subcategory_69,subcategory_7,subcategory_70,subcategory_71,subcategory_72,subcategory_73,subcategory_74,subcategory_75,subcategory_76,subcategory_77,subcategory_78,subcategory_79,subcategory_8,subcategory_80,subcategory_81,subcategory_82,subcategory_83,subcategory_84,subcategory_85,subcategory_86,subcategory_87,subcategory_88,subcategory_89,subcategory_9,subcategory_90,subcategory_91,subcategory_92,subcategory_93,subcategory_94,subcategory_95,subcategory_96,subcategory_97,subcategory_98,subcategory_99
i32,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,…,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
9791280,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3236498,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [94]:
from polimi.utils._polars import reduce_polars_df_memory_size

articles = articles.join(dummies_topics, on='article_id', how='left')\
    .join(dummies_subcategories, on='article_id', how='left')\
    .drop('topics', 'subcategory')
articles = reduce_polars_df_memory_size(articles)

one_hot_cols = [col for col in articles.columns if col.startswith('topics_') or col.startswith('subcategory_')]
articles = articles.with_columns(
    pl.col(one_hot_cols).fill_null(0)
)

articles.head(2)

Memory usage of dataframe is 40.75 MB
Memory usage after optimization is: 5.75 MB
Decreased by 85.9%


article_id,category,premium,sentiment_label,topics_0,topics_1,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_2,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,topics_28,topics_29,topics_3,topics_30,topics_31,topics_32,topics_33,topics_34,topics_35,topics_36,topics_37,topics_38,…,subcategory_66,subcategory_67,subcategory_68,subcategory_69,subcategory_7,subcategory_70,subcategory_71,subcategory_72,subcategory_73,subcategory_74,subcategory_75,subcategory_76,subcategory_77,subcategory_78,subcategory_79,subcategory_8,subcategory_80,subcategory_81,subcategory_82,subcategory_83,subcategory_84,subcategory_85,subcategory_86,subcategory_87,subcategory_88,subcategory_89,subcategory_9,subcategory_90,subcategory_91,subcategory_92,subcategory_93,subcategory_94,subcategory_95,subcategory_96,subcategory_97,subcategory_98,subcategory_99
i32,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,…,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8
3001353,5,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3003065,7,0,3,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0


In [95]:
df = pl.concat([
    slice.explode(pl.all().exclude('user_id'))\
        .with_columns(
            pl.col('scroll_percentage_fixed').fill_null(0.),
            pl.col('read_time_fixed').fill_null(0.),
        )\
        .with_columns(
            (pl.col('impression_time_fixed').dt.hour() // 4).alias('hour_group'),
            pl.col('impression_time_fixed').dt.weekday().alias('weekday'),
        ).drop('impression_time_fixed')\
        .rename({'scroll_percentage_fixed': 'scroll_percentage', 'read_time_fixed': 'read_time'})
        .join(articles, left_on='article_id_fixed', right_on='article_id', how='left').drop('article_id_fixed')\
        .group_by('user_id').agg(pl.all())
    for slice in history.iter_slices(10000)
])

df.head(3)

user_id,scroll_percentage,read_time,hour_group,weekday,category,premium,sentiment_label,topics_0,topics_1,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_2,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,topics_28,topics_29,topics_3,topics_30,topics_31,topics_32,topics_33,topics_34,…,subcategory_66,subcategory_67,subcategory_68,subcategory_69,subcategory_7,subcategory_70,subcategory_71,subcategory_72,subcategory_73,subcategory_74,subcategory_75,subcategory_76,subcategory_77,subcategory_78,subcategory_79,subcategory_8,subcategory_80,subcategory_81,subcategory_82,subcategory_83,subcategory_84,subcategory_85,subcategory_86,subcategory_87,subcategory_88,subcategory_89,subcategory_9,subcategory_90,subcategory_91,subcategory_92,subcategory_93,subcategory_94,subcategory_95,subcategory_96,subcategory_97,subcategory_98,subcategory_99
u32,list[f32],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
2053634,"[100.0, 0.0, … 0.0]","[16.0, 0.0, … 0.0]","[4, 4, … 5]","[4, 4, … 3]","[4, 5, … 4]","[0, 0, … 0]","[2, 1, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"
1899092,"[100.0, 69.0, … 60.0]","[71.0, 18.0, … 8.0]","[3, 3, … 1]","[4, 4, … 4]","[5, 4, … 4]","[0, 0, … 0]","[1, 2, … 3]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"
1586072,"[100.0, 0.0, … 39.0]","[90.0, 1.0, … 86.0]","[4, 5, … 1]","[4, 4, … 4]","[4, 4, … 7]","[0, 0, … 0]","[2, 1, … 2]","[0, 1, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [96]:
from polimi.utils._polars import reduce_polars_df_memory_size
cols = df.columns
topics_cols = sorted([col for col in cols if col.startswith('topics_')], key=lambda x: int(x.split('_')[-1]))
subcategory_cols = sorted([col for col in cols if col.startswith('subcategory_')], key=lambda x: int(x.split('_')[-1]))
all_others = set(cols) - set(topics_cols) - set(subcategory_cols) - {'user_id'}
cols = ['user_id'] + list(all_others) + topics_cols + subcategory_cols
df = df.select(cols)
df.head(1)

user_id,sentiment_label,category,premium,read_time,scroll_percentage,hour_group,weekday,topics_0,topics_1,topics_2,topics_3,topics_4,topics_5,topics_6,topics_7,topics_8,topics_9,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,topics_28,…,subcategory_137,subcategory_138,subcategory_139,subcategory_140,subcategory_141,subcategory_142,subcategory_143,subcategory_144,subcategory_145,subcategory_146,subcategory_147,subcategory_148,subcategory_149,subcategory_150,subcategory_151,subcategory_152,subcategory_153,subcategory_154,subcategory_155,subcategory_156,subcategory_157,subcategory_158,subcategory_159,subcategory_160,subcategory_161,subcategory_162,subcategory_163,subcategory_164,subcategory_165,subcategory_166,subcategory_167,subcategory_168,subcategory_169,subcategory_170,subcategory_171,subcategory_172,subcategory_173
u32,list[i8],list[i8],list[i8],list[f32],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
2053634,"[2, 1, … 1]","[4, 5, … 4]","[0, 0, … 0]","[16.0, 0.0, … 0.0]","[100.0, 0.0, … 0.0]","[4, 4, … 5]","[4, 4, … 3]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [97]:
from tqdm import tqdm

def build_sequences(df: pl.DataFrame, w: int, stride: int):
    all_features = df.drop('user_id').columns
    singular_cols = ['topics', 'subcategory', 'category', 'weekday', 'hour_group']
    name_idx_dict = {key: [i for i, col in enumerate(all_features) if col.startswith(key)] for key in singular_cols}
    numerical_cols = ['scroll_percentage', 'read_time', 'premium']
    name_idx_dict['numerical'] = [i for i, col in enumerate(all_features) if col in numerical_cols]
        
    res = {key: ([], []) for key in name_idx_dict.keys()}

    for user_df in tqdm(df.partition_by('user_id')):
        x = user_df.drop('user_id').to_numpy()[0]
        x = np.array([np.array(x_i) for x_i in x])
                
        i = 0
        if i + w >= x.shape[1]:
            # in case history is shorter than the window then we pad it and select the last element as target
            pad_width = w - x[:, :-1].shape[1]
            pad_m = np.zeros((x.shape[0], pad_width))
            padded_x = np.concatenate((pad_m, x[:, :-1]), axis=1)
            y_i = x[:, -1]
            
            for key, idx in name_idx_dict.items():
                res[key][0].append(padded_x[idx, :].T)
                res[key][1].append(y_i[idx].T)
            
        else:
            while i + w < x.shape[1]:
                # in case history is larger than the window then we select the window and the target randomly between the next elements
                x_i = x[:, i:i+w]
                target_random_id = np.random.randint(i+w, x.shape[1])
                y_i = x[:, target_random_id]
                
                for key, idx in name_idx_dict.items():
                    res[key][0].append(x_i[idx, :].T)
                    res[key][1].append(y_i[idx].T)
                
                i+=stride
                         
            #TODO: add padding for the last sequence, if we want to keep it
                

    for key in res.keys():
        res[key] = (np.array(res[key][0]), np.array(res[key][1]))
    
    return res

In [98]:
res = build_sequences(df[:10], w=10, stride=5)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 183.11it/s]


In [99]:
from polimi.utils._polars import reduce_polars_df_memory_size

window = 20
mask = 0
df_trucated = df.with_columns(
    pl.all().exclude('user_id').list.reverse().list.eval(pl.element().extend_constant(mask, window)).list.reverse().list.tail(window).name.keep()
)

df_trucated.head(2)

user_id,sentiment_label,category,premium,read_time,scroll_percentage,hour_group,weekday,topics_0,topics_1,topics_2,topics_3,topics_4,topics_5,topics_6,topics_7,topics_8,topics_9,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,topics_28,…,subcategory_137,subcategory_138,subcategory_139,subcategory_140,subcategory_141,subcategory_142,subcategory_143,subcategory_144,subcategory_145,subcategory_146,subcategory_147,subcategory_148,subcategory_149,subcategory_150,subcategory_151,subcategory_152,subcategory_153,subcategory_154,subcategory_155,subcategory_156,subcategory_157,subcategory_158,subcategory_159,subcategory_160,subcategory_161,subcategory_162,subcategory_163,subcategory_164,subcategory_165,subcategory_166,subcategory_167,subcategory_168,subcategory_169,subcategory_170,subcategory_171,subcategory_172,subcategory_173
u32,list[i8],list[i8],list[i8],list[f32],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
2053634,"[1, 1, … 1]","[25, 4, … 4]","[0, 0, … 0]","[0.0, 116.0, … 0.0]","[0.0, 100.0, … 0.0]","[4, 0, … 5]","[1, 2, … 3]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"
1899092,"[3, 1, … 3]","[5, 25, … 4]","[0, 0, … 0]","[60.0, 17.0, … 8.0]","[100.0, 24.0, … 60.0]","[4, 4, … 1]","[3, 3, … 4]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [100]:
behaviors_df = pl.read_parquet('/home/ubuntu/dset_complete/subsample/train_ds.parquet')
behaviors_df.head(2)

impression_id,user_id,article,target,device_type,read_time,scroll_percentage,is_sso_user,gender,age,is_subscriber,postcode,trendiness_score_1d,trendiness_score_3d,trendiness_score_5d,trendiness_score_3d_leak,weekday,hour,trendiness_score_1d/3d,trendiness_score_1d/5d,normalized_trendiness_score_overall,premium,category,sentiment_score,sentiment_label,num_images,title_len,subtitle_len,body_len,num_topics,total_pageviews,total_inviews,total_read_time,total_pageviews/inviews,article_type,article_delay_days,article_delay_hours,…,constrastive_emb_icm_l_inf_article,std_article_kenneth_emb_icm,std_article_distilbert_emb_icm,std_article_bert_emb_icm,std_article_roberta_emb_icm,std_article_w_to_vec_emb_icm,std_article_emotions_emb_icm,std_article_constrastive_emb_icm,skew_article_kenneth_emb_icm,skew_article_distilbert_emb_icm,skew_article_bert_emb_icm,skew_article_roberta_emb_icm,skew_article_w_to_vec_emb_icm,skew_article_emotions_emb_icm,skew_article_constrastive_emb_icm,kurtosis_article_kenneth_emb_icm,kurtosis_article_distilbert_emb_icm,kurtosis_article_bert_emb_icm,kurtosis_article_roberta_emb_icm,kurtosis_article_w_to_vec_emb_icm,kurtosis_article_emotions_emb_icm,kurtosis_article_constrastive_emb_icm,entropy_article_kenneth_emb_icm,entropy_article_distilbert_emb_icm,entropy_article_bert_emb_icm,entropy_article_roberta_emb_icm,entropy_article_w_to_vec_emb_icm,entropy_article_emotions_emb_icm,entropy_article_constrastive_emb_icm,kenneth_emb_icm_minus_median_article,distilbert_emb_icm_minus_median_article,bert_emb_icm_minus_median_article,roberta_emb_icm_minus_median_article,w_to_vec_emb_icm_minus_median_article,emotions_emb_icm_minus_median_article,constrastive_emb_icm_minus_median_article,impression_time
u32,u32,i32,i8,i8,f32,f32,bool,i8,i8,bool,i8,i16,i16,i16,i16,i8,i8,f32,f32,f32,bool,i16,f32,str,u32,u8,u8,u16,u32,i32,i32,f32,f32,str,i16,i32,…,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,datetime[μs]
149474,139836,9778728,0,2,13.0,,False,2,,False,5,150,521,836,419,3,7,0.287908,0.179426,0.880068,False,142,0.9654,"""Negative""",1,5,18,251,7,22415,220247,1004828.0,0.101772,"""article_default""",0,0,…,0.017241,0.003313,0.010775,0.020001,6.258241,0.005474,4.929348,0.062682,1.300164,2.060066,1.152196,1.754251,2.088241,1.214502,1.406736,1.339344,5.427761,1.025453,3.529418,5.37705,1.054931,1.963193,,,,,,,,-0.002187,-0.004914,-0.014626,-3.694031,-0.002645,-3.802204,-0.039383,2023-05-24 07:47:53
149474,139836,9778669,0,2,13.0,,False,2,,False,5,85,199,313,266,3,7,0.427136,0.271565,0.336149,False,118,0.9481,"""Negative""",1,5,11,150,4,74491,373488,4365609.0,0.199447,"""article_default""",0,1,…,0.017094,0.003318,0.011131,0.019302,5.737217,0.003193,8.05171,0.06661,1.08501,1.063945,0.947301,1.095085,1.058148,1.020884,1.072975,0.822594,0.947987,0.355356,0.968488,1.051799,0.600099,0.791561,,,,,,,,-0.002933,-0.010537,-0.017286,-4.617104,-0.002901,-7.678095,-0.057551,2023-05-24 07:47:53


In [101]:
def build_sequences_cls_iterator(history_seq: pl.DataFrame, behaviors: pl.DataFrame, window:int):
    mask = 0
    history_seq_trucated = history_seq.with_columns(
        pl.all().exclude('user_id').list.reverse().list.eval(pl.element().extend_constant(mask, window)).list.reverse().list.tail(window).name.keep()
    )
    
    for user_id, user_history in tqdm(history_seq_trucated.partition_by(['user_id'], as_dict=True, maintain_order=False).items()): #order not maintained
        for b in behaviors.filter(pl.col('user_id') == user_id[0]).iter_slices(1):
            yield (b.drop('target'), user_history, b['target'].item())
        

In [103]:
for i, (features, user_history_trunc, target) in enumerate(build_sequences_cls_iterator(df, behaviors_df, window=20)):
    print(target)
    break

  0%|          | 0/15143 [00:00<?, ?it/s]


0
