In [1]:
import polars as pl
import tensorflow as tf
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
import numpy as np
import logging
import random

seed = 42
np.random.seed(seed)
random.seed(seed)
tf.autograph.set_verbosity(0)
tf.get_logger().setLevel(logging.ERROR)
tf.random.set_seed(seed)
print(tf.__version__)

2024-06-16 08:55:29.797158: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-16 08:55:29.862380: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


2.16.1


In [2]:
history = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/train/history.parquet')
behaviors = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/train/behaviors.parquet')
articles = pl.read_parquet('/home/ubuntu/dataset/ebnerd_small/articles.parquet')
history.head(2)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
13538,"[2023-04-27 10:17:43, 2023-04-27 10:18:01, … 2023-05-17 20:36:34]","[100.0, 35.0, … 100.0]","[9738663, 9738569, … 9769366]","[17.0, 12.0, … 16.0]"
14241,"[2023-04-27 09:40:18, 2023-04-27 09:40:33, … 2023-05-17 17:08:41]","[100.0, 46.0, … 100.0]","[9738557, 9738528, … 9767852]","[8.0, 9.0, … 12.0]"


In [3]:
topics = articles['topics'].explode().unique().drop_nans().drop_nulls().sort().to_frame().with_row_index()
category = articles['category'].unique().drop_nans().drop_nulls().sort().to_frame().with_row_index(offset=1)
subcategory = articles['subcategory'].explode().unique().drop_nans().drop_nulls().sort().to_frame().with_row_index()
sentiment_label = articles['sentiment_label'].explode().unique().drop_nans().drop_nulls().sort().to_frame().with_row_index(offset=1)
mask = 0

articles = articles.select(['article_id', 'category', 'subcategory', 'premium', 'topics', 'sentiment_label'])\
    .with_columns(
        pl.col('topics').fill_null(pl.lit([])),
        pl.col('subcategory').fill_null(pl.lit([]))
    )\
    .with_columns(
        pl.col('topics').list.eval(pl.element().replace(topics['topics'], topics['index'], default=None)).list.drop_nulls().cast(pl.List(pl.Int32)),
        pl.col('category').replace(category['category'], category['index'], default=None).fill_null(mask).cast(pl.Int32),
        pl.col('sentiment_label').replace(sentiment_label['sentiment_label'], sentiment_label['index'], default=None).fill_null(mask).cast(pl.Int32),
        pl.col('subcategory').list.eval(pl.element().replace(subcategory['subcategory'], subcategory['index'], default=None)).list.drop_nulls().cast(pl.List(pl.Int32)),
        pl.col('premium').cast(pl.Int8)
)
articles.head(2)

article_id,category,subcategory,premium,topics,sentiment_label
i32,i32,list[i32],i8,list[i32],i32
3001353,5,[],0,"[25, 47]",1
3003065,7,"[87, 88]",0,"[69, 13, 77]",3


In [4]:
dummies_topics = articles.select('article_id', 'topics').explode('topics').drop_nulls().to_dummies(columns=['topics'])\
    .group_by('article_id').agg(pl.all().sum())
dummies_subcategories = articles.select('article_id', 'subcategory').explode('subcategory').drop_nulls().to_dummies(columns=['subcategory'])\
    .group_by('article_id').agg(pl.all().sum())
    
dummies_subcategories.head(2)

article_id,subcategory_0,subcategory_1,subcategory_10,subcategory_100,subcategory_101,subcategory_102,subcategory_103,subcategory_104,subcategory_105,subcategory_106,subcategory_107,subcategory_108,subcategory_109,subcategory_11,subcategory_110,subcategory_111,subcategory_112,subcategory_113,subcategory_114,subcategory_115,subcategory_116,subcategory_117,subcategory_118,subcategory_119,subcategory_12,subcategory_120,subcategory_121,subcategory_122,subcategory_123,subcategory_124,subcategory_125,subcategory_126,subcategory_127,subcategory_128,subcategory_129,subcategory_13,…,subcategory_66,subcategory_67,subcategory_68,subcategory_69,subcategory_7,subcategory_70,subcategory_71,subcategory_72,subcategory_73,subcategory_74,subcategory_75,subcategory_76,subcategory_77,subcategory_78,subcategory_79,subcategory_8,subcategory_80,subcategory_81,subcategory_82,subcategory_83,subcategory_84,subcategory_85,subcategory_86,subcategory_87,subcategory_88,subcategory_89,subcategory_9,subcategory_90,subcategory_91,subcategory_92,subcategory_93,subcategory_94,subcategory_95,subcategory_96,subcategory_97,subcategory_98,subcategory_99
i32,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,…,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
9311529,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
9713814,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [5]:
from polimi.utils._polars import reduce_polars_df_memory_size

articles = articles.join(dummies_topics, on='article_id', how='left')\
    .join(dummies_subcategories, on='article_id', how='left')\
    .drop('topics', 'subcategory')
articles = reduce_polars_df_memory_size(articles)

one_hot_cols = [col for col in articles.columns if col.startswith('topics_') or col.startswith('subcategory_')]
articles = articles.with_columns(
    pl.col(one_hot_cols).fill_null(0)
)

articles.head(2)

Memory usage of dataframe is 40.75 MB
Memory usage after optimization is: 5.75 MB
Decreased by 85.9%


article_id,category,premium,sentiment_label,topics_0,topics_1,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_2,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,topics_28,topics_29,topics_3,topics_30,topics_31,topics_32,topics_33,topics_34,topics_35,topics_36,topics_37,topics_38,…,subcategory_66,subcategory_67,subcategory_68,subcategory_69,subcategory_7,subcategory_70,subcategory_71,subcategory_72,subcategory_73,subcategory_74,subcategory_75,subcategory_76,subcategory_77,subcategory_78,subcategory_79,subcategory_8,subcategory_80,subcategory_81,subcategory_82,subcategory_83,subcategory_84,subcategory_85,subcategory_86,subcategory_87,subcategory_88,subcategory_89,subcategory_9,subcategory_90,subcategory_91,subcategory_92,subcategory_93,subcategory_94,subcategory_95,subcategory_96,subcategory_97,subcategory_98,subcategory_99
i32,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,…,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8
3001353,5,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3003065,7,0,3,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,…,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0


In [6]:
df = pl.concat([
    slice.explode(pl.all().exclude('user_id'))\
        .with_columns(
            pl.col('scroll_percentage_fixed').fill_null(0.),
            pl.col('read_time_fixed').fill_null(0.),
        )\
        .with_columns(
            (pl.col('impression_time_fixed').dt.hour() // 4).alias('hour_group'),
            pl.col('impression_time_fixed').dt.weekday().alias('weekday'),
        )\
        .rename({'scroll_percentage_fixed': 'scroll_percentage', 'read_time_fixed': 'read_time'})
        .join(articles, left_on='article_id_fixed', right_on='article_id', how='left').drop('article_id_fixed')\
        .group_by('user_id').agg(pl.all())
    for slice in history.iter_slices(10000)
])

df.head(3)

user_id,impression_time_fixed,scroll_percentage,read_time,hour_group,weekday,category,premium,sentiment_label,topics_0,topics_1,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_2,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,topics_28,topics_29,topics_3,topics_30,topics_31,topics_32,topics_33,…,subcategory_66,subcategory_67,subcategory_68,subcategory_69,subcategory_7,subcategory_70,subcategory_71,subcategory_72,subcategory_73,subcategory_74,subcategory_75,subcategory_76,subcategory_77,subcategory_78,subcategory_79,subcategory_8,subcategory_80,subcategory_81,subcategory_82,subcategory_83,subcategory_84,subcategory_85,subcategory_86,subcategory_87,subcategory_88,subcategory_89,subcategory_9,subcategory_90,subcategory_91,subcategory_92,subcategory_93,subcategory_94,subcategory_95,subcategory_96,subcategory_97,subcategory_98,subcategory_99
u32,list[datetime[μs]],list[f32],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
247481,"[2023-05-15 12:29:51, 2023-05-15 12:30:27, … 2023-05-17 21:32:51]","[100.0, 96.0, … 100.0]","[36.0, 4.0, … 14.0]","[3, 3, … 5]","[1, 1, … 3]","[4, 4, … 6]","[0, 0, … 0]","[1, 1, … 2]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"
2273188,"[2023-04-27 07:21:03, 2023-04-27 11:46:42, … 2023-05-17 06:08:07]","[28.0, 0.0, … 0.0]","[8.0, 434.0, … 0.0]","[1, 2, … 1]","[4, 4, … 3]","[7, 7, … 4]","[0, 0, … 0]","[2, 2, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"
1121919,"[2023-04-27 09:33:20, 2023-04-27 09:33:46, … 2023-05-17 19:54:02]","[100.0, 100.0, … 27.0]","[11.0, 163.0, … 50.0]","[2, 2, … 4]","[4, 4, … 3]","[4, 4, … 4]","[0, 0, … 0]","[2, 2, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [7]:
from polimi.utils._polars import reduce_polars_df_memory_size
cols = df.columns
topics_cols = sorted([col for col in cols if col.startswith('topics_')], key=lambda x: int(x.split('_')[-1]))
subcategory_cols = sorted([col for col in cols if col.startswith('subcategory_')], key=lambda x: int(x.split('_')[-1]))
all_others = set(cols) - set(topics_cols) - set(subcategory_cols) - {'user_id'}
cols = ['user_id'] + list(all_others) + topics_cols + subcategory_cols
df = df.select(cols)
df.head(1)

user_id,premium,category,hour_group,weekday,read_time,impression_time_fixed,scroll_percentage,sentiment_label,topics_0,topics_1,topics_2,topics_3,topics_4,topics_5,topics_6,topics_7,topics_8,topics_9,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,…,subcategory_137,subcategory_138,subcategory_139,subcategory_140,subcategory_141,subcategory_142,subcategory_143,subcategory_144,subcategory_145,subcategory_146,subcategory_147,subcategory_148,subcategory_149,subcategory_150,subcategory_151,subcategory_152,subcategory_153,subcategory_154,subcategory_155,subcategory_156,subcategory_157,subcategory_158,subcategory_159,subcategory_160,subcategory_161,subcategory_162,subcategory_163,subcategory_164,subcategory_165,subcategory_166,subcategory_167,subcategory_168,subcategory_169,subcategory_170,subcategory_171,subcategory_172,subcategory_173
u32,list[i8],list[i8],list[i8],list[i8],list[f32],list[datetime[μs]],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
247481,"[0, 0, … 0]","[4, 4, … 6]","[3, 3, … 5]","[1, 1, … 3]","[36.0, 4.0, … 14.0]","[2023-05-15 12:29:51, 2023-05-15 12:30:27, … 2023-05-17 21:32:51]","[100.0, 96.0, … 100.0]","[1, 1, … 2]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [22]:
impression_time_idx = df.drop('user_id').columns.index('impression_time_fixed')
for user_df in df[:1].partition_by('user_id'):
        x = user_df.drop('user_id').to_numpy()[0]
        x = np.array([np.array(x_i) for x_i in x])
d = x[impression_time_idx][1]
d

datetime.datetime(2023, 5, 15, 12, 30, 27)

In [32]:
d.month

5

In [35]:
curr_date = x[:, 10][impression_time_idx]
last_step = [i for i in range(10, 20) if x[:, i][impression_time_idx].month == curr_date.month and x[:, i][impression_time_idx].day == curr_date.day]
last_step

[10, 11, 12, 13, 14, 15, 16, 17]

In [52]:
from tqdm import tqdm

def build_sequences(df: pl.DataFrame, w: int, stride: int):
    all_features = df.drop('user_id').columns
    singular_cols = ['topics', 'subcategory', 'category', 'weekday', 'hour_group']
    name_idx_dict = {key: [i for i, col in enumerate(all_features) if col.startswith(key)] for key in singular_cols}
    numerical_cols = ['scroll_percentage', 'read_time', 'premium']
    name_idx_dict['numerical'] = [i for i, col in enumerate(all_features) if col in numerical_cols]
    impression_time_idx = all_features.index('impression_time_fixed')
        
    res = {key: ([], []) for key in name_idx_dict.keys()}

    for user_df in tqdm(df.partition_by('user_id')):
        x = user_df.drop('user_id').to_numpy()[0]
        x = np.array([np.array(x_i) for x_i in x])
                
        i = 0
        if i + w >= x.shape[1]:
            # in case history is shorter than the window then we pad it and select the last element as target
            pad_width = w - x[:, :-1].shape[1]
            pad_m = np.zeros((x.shape[0], pad_width))
            padded_x = np.concatenate((pad_m, x[:, :-1]), axis=1)
            y_i = x[:, -1]
            
            for key, idx in name_idx_dict.items():
                res[key][0].append(padded_x[idx, :].T)
                res[key][1].append(y_i[idx].T)
            
        else:
            while i + w < x.shape[1]:
                # in case history is larger than the window then we select the window and the target randomly between the next elements
                x_i = x[:, i:i+w]
                last_window_date = x_i[impression_time_idx][-1]
                max_telescope = max([t for t in range(i+w+1, x.shape[1]) if x[:, t][impression_time_idx].month == last_window_date.month and 
                                        x[:, t][impression_time_idx].day == last_window_date.day], default=i+w)
                print(i+w, max_telescope)
                target_random_id = np.random.randint(i+w, max_telescope+1)
                y_i = x[:, target_random_id]
                
                for key, idx in name_idx_dict.items():
                    res[key][0].append(x_i[idx, :].T)
                    res[key][1].append(y_i[idx].T)
                
                i+=stride
                         
            #TODO: add padding for the last sequence, if we want to keep it
                

    for key in res.keys():
        res[key] = (np.array(res[key][0]), np.array(res[key][1]))
    
    return res

In [53]:
res = build_sequences(df[:100], w=10, stride=5)

  5%|▌         | 5/100 [00:00<00:03, 31.15it/s]

10 17
15 17
20 51
25 51
30 51
35 51
40 51
45 51
50 51
55 87
60 87
65 87
70 87
75 87
80 87
85 87
10 10
15 17
20 21
25 25
10 14
15 15
20 24
25 25
30 33
35 63
40 63
45 63
50 63
55 63
60 63
65 74
70 74
75 75
80 88
85 88
90 93
95 106
100 106
105 106
110 118
115 118
120 130
125 130
130 130
135 138
140 140
145 149
150 150
155 162
160 162
165 173
170 173
175 186
180 186
185 186
10 10
15 24
20 24
25 25
30 30
35 35
40 49
45 49
50 50
55 55
60 65
65 65
70 78
75 78
80 92
85 92
90 92
95 104
100 104
105 105
110 118
115 118
120 124
125 125
130 138
135 138
140 154
145 154
150 154
155 155
160 162
165 178
170 178
175 178
180 196
185 196
190 196
195 196
200 200
205 212
210 212
215 218
10 47
15 47
20 47
25 47
30 47
35 47
40 47
45 47
50 94
55 94
60 94
65 94
70 94
75 94
80 94
85 94
90 94
95 95
100 158
105 158
110 158
115 158
120 158
125 158
130 158
135 158
140 158
145 158
150 158
155 158
160 196
165 196
170 196
175 196
180 196
185 196
190 196
195 196
200 238
205 238
210 238
215 238
220 238
225 238
230 238
23

  9%|▉         | 9/100 [00:00<00:02, 31.63it/s]

220 266
225 266
230 266
235 266
240 266
245 266
250 266
255 266
260 266
265 266
270 312
275 312
280 312
285 312
290 312
295 312
300 312
305 312
310 312
315 355
320 355
325 355
330 355
335 355
340 355
345 355
350 355
355 355
360 403
365 403
370 403
375 403
380 403
385 403
390 403
395 403
400 403
405 428
410 428
415 428
420 428
425 428
430 453
435 453
440 453
445 453
450 453
455 484
460 484
465 484
470 484
475 484
480 484
485 485
490 512
495 512
500 512
505 512
510 512
515 563
520 563
525 563
530 563
535 563
540 563
545 563
550 563
555 563
560 563
565 599
570 599
575 599
580 599
585 599
590 599
595 599
600 600
605 606
610 616
615 616
620 636
625 636
630 636
635 636
640 719
645 719
650 719
655 719
660 719
665 719
670 719
675 719
680 719
685 719
690 719
695 719
700 719
705 719
710 719
715 719
720 720
725 777
730 777
735 777
740 777
745 777
750 777
755 777
760 777
765 777
770 777
775 777
10 25
15 25
20 25
25 25
30 30
35 49
40 49
45 49
50 50
10 18
15 18
20 27
25 27
30 40
35 40
40 40
45 49
50

 21%|██        | 21/100 [00:00<00:02, 38.32it/s]

365 365
370 429
375 429
380 429
385 429
390 429
395 429
400 429
405 429
410 429
415 429
420 429
425 429
430 430
435 469
440 469
445 469
450 469
455 469
460 469
465 469
470 470
475 539
480 539
485 539
490 539
495 539
500 539
505 539
510 539
515 539
520 539
525 539
530 539
535 539
540 540
545 598
550 598
555 598
560 598
565 598
570 598
575 598
580 598
585 598
590 598
595 598
600 667
605 667
610 667
615 667
620 667
625 667
630 667
635 667
640 667
645 667
650 667
655 667
660 667
665 667
670 706
675 706
680 706
685 706
690 706
695 706
700 706
705 706
710 734
715 734
720 734
725 734
730 734
735 735
740 768
745 768
750 768
755 768
760 768
765 768
770 775
775 775
780 833
785 833
790 833
795 833
800 833
805 833
810 833
815 833
820 833
825 833
830 833
835 856
840 856
845 856
850 856
855 856
860 904
865 904
870 904
875 904
880 904
885 904
890 904
895 904
900 904
905 905
910 946
915 946
920 946
925 946
930 946
935 946
940 946
945 946
10 24
15 24
20 24
25 25
30 49
35 49
40 49
45 49
50 50
55 74
60 7

 26%|██▌       | 26/100 [00:00<00:02, 25.96it/s]

260 285
265 285
270 285
275 285
280 285
285 285
290 317
295 317
300 317
305 317
310 317
315 317
320 328
325 328
330 354
335 354
340 354
345 354
350 354
355 355
360 377
365 377
370 377
375 377
380 418
385 418
390 418
395 418
400 418
405 418
410 418
415 418
420 462
425 462
430 462
435 462
440 462
445 462
450 462
455 462
460 462
465 511
470 511
475 511
480 511
485 511
490 511
495 511
500 511
505 511
510 511
515 571
520 571
525 571
530 571
535 571
540 571
545 571
550 571
555 571
560 571
565 571
570 571
575 622
580 622
585 622
590 622
595 622
600 622
605 622
610 622
615 622
620 622
625 654
630 654
635 654
640 654
645 654
650 654
655 655
660 720
665 720
670 720
675 720
680 720
685 720
690 720
695 720
700 720
705 720
710 720
715 720
720 720
725 780
730 780
735 780
740 780
745 780
750 780
755 780
760 780
765 780
770 780
775 780
780 780
785 818
790 818
795 818
800 818
805 818
810 818
815 818
820 848
825 848
830 848
835 848
840 848
845 848
10 15
15 15
20 24
25 25
30 30
10 12
15 31
20 31
25 31
30

 35%|███▌      | 35/100 [00:01<00:02, 30.94it/s]

150 164
155 164
160 164
165 165
170 177
175 177
180 208
185 208
190 208
195 208
200 208
205 208
210 240
215 240
220 240
225 240
230 240
235 240
240 240
245 254
250 254
255 255
260 275
265 275
270 275
275 275
280 293
285 293
290 293
295 306
300 306
305 306
310 316
315 316
320 326
325 326
330 343
335 343
340 343
345 358
350 358
355 358
360 377
365 377
370 377
375 377
10 15
15 15
20 24
25 25
30 32
35 47
40 47
45 47
50 50
10 10
15 15
20 23
25 27
30 31
10 22
15 22
20 22
25 32
30 32
35 47
40 47
45 47
50 50
55 78
60 78
65 78
70 78
75 78
80 93
85 93
90 93
95 117
100 117
105 117
110 117
115 117
120 124
125 125
130 135
135 135
140 156
145 156
150 156
155 156
160 175
165 175
170 175
175 175
180 192
185 192
190 192
195 210
200 210
205 210
210 210
215 217
220 228
225 228
230 230
235 239
240 240
245 249
10 18
15 18
20 40
25 40
30 40
35 40
40 40
45 55
50 55
55 55
60 72
65 72
70 72
75 89
80 89
85 89
90 90
95 114
100 114
105 114
110 114
115 115
120 137
125 137
130 137
135 137
140 156
145 156
150 156
15

 56%|█████▌    | 56/100 [00:01<00:00, 55.87it/s]

10 10
15 46
20 46
25 46
30 46
35 46
40 46
45 46
50 60
55 60
60 60
65 70
70 70
75 75
80 118
85 118
90 118
95 118
100 118
105 118
110 118
115 118
120 150
125 150
130 150
135 150
140 150
145 150
150 150
155 166
160 166
165 166
170 202
175 202
180 202
185 202
190 202
195 202
200 202
205 207
210 228
215 228
220 228
225 228
230 232
235 250
240 250
245 250
250 250
255 271
260 271
265 271
270 271
275 286
280 286
285 286
290 310
295 310
300 310
305 310
310 310
315 321
320 321
325 338
330 338
335 338
340 344
345 345
350 371
355 371
360 371
365 371
370 371
375 380
380 380
10 17
15 17
20 20
25 26
30 34
35 35
40 41
45 45
50 51
55 60
60 60
65 65
70 77
75 77
80 83
85 98
90 98
95 98
100 100
105 109
110 110
115 123
120 123
125 127
130 130
10 12
10 16
15 16
20 31
25 31
30 31
35 46
40 46
45 46
50 60
55 60
60 60
65 76
70 76
75 76
80 86
85 86
90 99
95 99
100 100
105 124
110 124
115 124
120 124
125 125
130 138
135 138
140 151
145 151
150 151
155 168
160 168
165 168
170 173
175 184
180 184
185 185
190 191
19

 63%|██████▎   | 63/100 [00:01<00:00, 45.21it/s]

125 132
130 132
135 147
140 147
145 147
150 158
155 158
160 160
165 176
170 176
175 176
180 186
185 186
190 206
195 206
200 206
205 206
210 225
215 225
220 225
225 225
230 243
235 243
240 243
245 264
250 264
255 264
260 264
265 265
270 272
275 291
280 291
285 291
290 291
295 324
300 324
305 324
310 324
315 324
320 324
325 325
330 340
335 340
340 340
345 371
350 371
355 371
360 371
365 371
370 371
375 378
10 10
15 15
10 17
15 17
20 47
25 47
30 47
35 47
40 47
45 47
50 91
55 91
60 91
65 91
70 91
75 91
80 91
85 91
90 91
95 105
100 105
105 105
110 175
115 175
120 175
125 175
130 175
135 175
140 175
145 175
150 175
155 175
160 175
165 175
170 175
175 175
180 206
185 206
190 206
195 206
200 206
205 206
210 234
215 234
220 234
225 234
230 234
235 235
240 258
245 258
250 258
255 258
260 301
265 301
270 301
275 301
280 301
285 301
290 301
295 301
300 301
305 312
310 312
315 337
320 337
325 337
330 337
335 337
340 357
345 357
350 357
355 357
360 370
365 370
370 370
375 402
380 402
385 402
390 402

 69%|██████▉   | 69/100 [00:01<00:00, 44.87it/s]

90 95
95 95
100 100
105 109
110 110
115 118
120 120
125 128
130 135
135 135
140 148
145 148
150 153
10 10
15 15
20 25
25 25
30 32
35 35
40 41
45 45
50 51
55 58
60 60
65 65
70 70
75 75
10 18
15 18
20 23
25 33
30 33
35 43
40 43
45 45
50 51
10 14
15 15
20 20
25 25
30 32
35 35
40 40
45 50
50 50
55 55
60 60
65 68
70 71
10 25
15 25
20 25
25 25
30 57
35 57
40 57
45 57
50 57
55 57
60 69
65 69
70 70
75 98
80 98
85 98
90 98
95 98
100 135
105 135
110 135
115 135
120 135
125 135
130 135
135 135
140 157
145 157
150 157
155 157
160 172
165 172
170 172
175 214
180 214
185 214
190 214
195 214
200 214
205 214
210 214
215 215
220 234
225 234
230 234
235 235
240 245
245 245
250 265
255 265
260 265
265 265
270 293
275 293
280 293
285 293
290 293
295 329
300 329
305 329
310 329
315 329
320 329
325 329
330 330
335 371
340 371
345 371
350 371
355 371
360 371
365 371
370 371
375 414
380 414
385 414
390 414
395 414
400 414
405 414
410 414
415 415
420 445
425 445
430 445
435 445
440 445
445 445
450 450
455 459


 88%|████████▊ | 88/100 [00:01<00:00, 64.51it/s]

175 183
180 183
185 214
190 214
195 214
200 214
205 214
210 214
215 215
220 235
225 235
230 235
235 235
240 250
245 250
250 250
255 277
260 277
265 277
270 277
275 277
280 293
285 293
290 293
295 299
300 300
305 309
310 310
315 328
320 328
325 328
330 341
335 341
340 341
345 385
350 385
355 385
360 385
365 385
370 385
375 385
380 385
385 385
390 394
10 28
15 28
20 28
25 28
30 38
35 38
40 59
45 59
50 59
55 59
60 60
65 75
70 75
75 75
80 84
85 85
90 100
95 100
100 100
105 107
110 110
115 126
120 126
125 126
130 133
135 139
140 140
145 154
150 154
155 155
160 167
165 167
170 176
175 176
180 185
185 185
190 198
195 198
200 216
205 216
210 216
215 216
220 233
225 233
230 233
10 10
15 15
20 28
25 28
30 35
35 35
40 48
45 48
50 64
55 64
60 64
65 65
70 75
75 75
80 85
85 85
90 109
95 109
100 109
105 109
110 110
115 121
120 121
125 136
130 136
135 136
140 155
145 155
150 155
155 155
160 171
165 171
170 171
175 175
180 180
185 201
190 201
195 201
200 201
205 206
10 15
15 15
20 21
10 10
15 17
20 26


100%|██████████| 100/100 [00:02<00:00, 47.48it/s]


115 143
120 143
125 143
130 143
135 143
140 143
145 173
150 173
155 173
160 173
165 173
170 173
175 192
180 192
185 192
190 192
195 214
200 214
205 214
210 214
215 215
220 231
225 231
230 231
235 247
240 247
245 247
250 284
255 284
260 284
265 284
270 284
275 284
280 284
285 285
290 300
295 300
300 300
305 326
310 326
315 326
320 326
325 326
330 353
335 353
340 353
345 353
350 353
355 372
360 372
365 372
370 372
375 391
380 391
385 391
390 391
395 407
400 407
405 407
410 421
415 421
420 421
425 435
430 435
435 435
440 441
10 19
15 19
20 20
25 31
30 31
35 35
40 46
45 46
10 22
15 22
20 22
25 31
30 31
35 45
40 45
45 45
50 54
55 55
60 60
65 70
70 70
75 75
80 99
85 99
90 99
95 99
100 100
105 128
110 128
115 128
120 128
125 128
130 143
135 143
140 143
145 149
150 150
155 161
160 161
165 174
170 174
175 175
180 221
185 221
190 221
195 221
200 221
205 221
210 221
215 221
220 221
225 238
230 238
235 238
240 240
245 260
250 260
255 260
260 260
10 11
10 12
15 20
20 20
25 25
30 36
35 36
40 44
45 4

In [99]:
from polimi.utils._polars import reduce_polars_df_memory_size

window = 20
mask = 0
df_trucated = df.with_columns(
    pl.all().exclude('user_id').list.reverse().list.eval(pl.element().extend_constant(mask, window)).list.reverse().list.tail(window).name.keep()
)

df_trucated.head(2)

user_id,sentiment_label,category,premium,read_time,scroll_percentage,hour_group,weekday,topics_0,topics_1,topics_2,topics_3,topics_4,topics_5,topics_6,topics_7,topics_8,topics_9,topics_10,topics_11,topics_12,topics_13,topics_14,topics_15,topics_16,topics_17,topics_18,topics_19,topics_20,topics_21,topics_22,topics_23,topics_24,topics_25,topics_26,topics_27,topics_28,…,subcategory_137,subcategory_138,subcategory_139,subcategory_140,subcategory_141,subcategory_142,subcategory_143,subcategory_144,subcategory_145,subcategory_146,subcategory_147,subcategory_148,subcategory_149,subcategory_150,subcategory_151,subcategory_152,subcategory_153,subcategory_154,subcategory_155,subcategory_156,subcategory_157,subcategory_158,subcategory_159,subcategory_160,subcategory_161,subcategory_162,subcategory_163,subcategory_164,subcategory_165,subcategory_166,subcategory_167,subcategory_168,subcategory_169,subcategory_170,subcategory_171,subcategory_172,subcategory_173
u32,list[i8],list[i8],list[i8],list[f32],list[f32],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],…,list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8],list[i8]
2053634,"[1, 1, … 1]","[25, 4, … 4]","[0, 0, … 0]","[0.0, 116.0, … 0.0]","[0.0, 100.0, … 0.0]","[4, 0, … 5]","[1, 2, … 3]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 1, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"
1899092,"[3, 1, … 3]","[5, 25, … 4]","[0, 0, … 0]","[60.0, 17.0, … 8.0]","[100.0, 24.0, … 60.0]","[4, 4, … 1]","[3, 3, … 4]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[1, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]",…,"[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]"


In [100]:
behaviors_df = pl.read_parquet('/home/ubuntu/dset_complete/subsample/train_ds.parquet')
behaviors_df.head(2)

impression_id,user_id,article,target,device_type,read_time,scroll_percentage,is_sso_user,gender,age,is_subscriber,postcode,trendiness_score_1d,trendiness_score_3d,trendiness_score_5d,trendiness_score_3d_leak,weekday,hour,trendiness_score_1d/3d,trendiness_score_1d/5d,normalized_trendiness_score_overall,premium,category,sentiment_score,sentiment_label,num_images,title_len,subtitle_len,body_len,num_topics,total_pageviews,total_inviews,total_read_time,total_pageviews/inviews,article_type,article_delay_days,article_delay_hours,…,constrastive_emb_icm_l_inf_article,std_article_kenneth_emb_icm,std_article_distilbert_emb_icm,std_article_bert_emb_icm,std_article_roberta_emb_icm,std_article_w_to_vec_emb_icm,std_article_emotions_emb_icm,std_article_constrastive_emb_icm,skew_article_kenneth_emb_icm,skew_article_distilbert_emb_icm,skew_article_bert_emb_icm,skew_article_roberta_emb_icm,skew_article_w_to_vec_emb_icm,skew_article_emotions_emb_icm,skew_article_constrastive_emb_icm,kurtosis_article_kenneth_emb_icm,kurtosis_article_distilbert_emb_icm,kurtosis_article_bert_emb_icm,kurtosis_article_roberta_emb_icm,kurtosis_article_w_to_vec_emb_icm,kurtosis_article_emotions_emb_icm,kurtosis_article_constrastive_emb_icm,entropy_article_kenneth_emb_icm,entropy_article_distilbert_emb_icm,entropy_article_bert_emb_icm,entropy_article_roberta_emb_icm,entropy_article_w_to_vec_emb_icm,entropy_article_emotions_emb_icm,entropy_article_constrastive_emb_icm,kenneth_emb_icm_minus_median_article,distilbert_emb_icm_minus_median_article,bert_emb_icm_minus_median_article,roberta_emb_icm_minus_median_article,w_to_vec_emb_icm_minus_median_article,emotions_emb_icm_minus_median_article,constrastive_emb_icm_minus_median_article,impression_time
u32,u32,i32,i8,i8,f32,f32,bool,i8,i8,bool,i8,i16,i16,i16,i16,i8,i8,f32,f32,f32,bool,i16,f32,str,u32,u8,u8,u16,u32,i32,i32,f32,f32,str,i16,i32,…,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,datetime[μs]
149474,139836,9778728,0,2,13.0,,False,2,,False,5,150,521,836,419,3,7,0.287908,0.179426,0.880068,False,142,0.9654,"""Negative""",1,5,18,251,7,22415,220247,1004828.0,0.101772,"""article_default""",0,0,…,0.017241,0.003313,0.010775,0.020001,6.258241,0.005474,4.929348,0.062682,1.300164,2.060066,1.152196,1.754251,2.088241,1.214502,1.406736,1.339344,5.427761,1.025453,3.529418,5.37705,1.054931,1.963193,,,,,,,,-0.002187,-0.004914,-0.014626,-3.694031,-0.002645,-3.802204,-0.039383,2023-05-24 07:47:53
149474,139836,9778669,0,2,13.0,,False,2,,False,5,85,199,313,266,3,7,0.427136,0.271565,0.336149,False,118,0.9481,"""Negative""",1,5,11,150,4,74491,373488,4365609.0,0.199447,"""article_default""",0,1,…,0.017094,0.003318,0.011131,0.019302,5.737217,0.003193,8.05171,0.06661,1.08501,1.063945,0.947301,1.095085,1.058148,1.020884,1.072975,0.822594,0.947987,0.355356,0.968488,1.051799,0.600099,0.791561,,,,,,,,-0.002933,-0.010537,-0.017286,-4.617104,-0.002901,-7.678095,-0.057551,2023-05-24 07:47:53


In [101]:
def build_sequences_cls_iterator(history_seq: pl.DataFrame, behaviors: pl.DataFrame, window:int):
    mask = 0
    history_seq_trucated = history_seq.with_columns(
        pl.all().exclude('user_id').list.reverse().list.eval(pl.element().extend_constant(mask, window)).list.reverse().list.tail(window).name.keep()
    )
    
    for user_id, user_history in tqdm(history_seq_trucated.partition_by(['user_id'], as_dict=True, maintain_order=False).items()): #order not maintained
        for b in behaviors.filter(pl.col('user_id') == user_id[0]).iter_slices(1):
            yield (b.drop('target'), user_history, b['target'].item())
        

In [103]:
for i, (features, user_history_trunc, target) in enumerate(build_sequences_cls_iterator(df, behaviors_df, window=20)):
    print(target)
    break

  0%|          | 0/15143 [00:00<?, ?it/s]


0
