In [282]:
import polars as pl
from pathlib import Path
import numpy as np
import datetime

In [263]:
# Dataset root directory and the variant name used in the folder suffix.
dpath = Path('../dataset')
dtype = 'small'

# Articles file of the selected variant.
articles = pl.read_parquet(f'{dpath}/ebnerd_{dtype}/articles.parquet')

# Each split folder is read as a (behaviors, history) pair, in that order.
behaviors_train, history_train = (
    pl.read_parquet(f'{dpath}/ebnerd_{dtype}/train/{table}.parquet')
    for table in ('behaviors', 'history')
)
behaviors_val, history_val = (
    pl.read_parquet(f'{dpath}/ebnerd_{dtype}/validation/{table}.parquet')
    for table in ('behaviors', 'history')
)

In [264]:
# Time span (earliest, latest, difference) of the training history and behaviors.
# The history timestamps are stored as a list column, so they are flattened once
# and reused; `impression_time` in behaviors is already a plain datetime column
# and needs no explode.
for label, times in (
    ('History train: ', history_train['impression_time_fixed'].explode()),
    ('Behaviors train: ', behaviors_train['impression_time']),
):
    earliest, latest = times.min(), times.max()
    print(label, earliest, latest, latest - earliest)

History train:  2023-04-27 07:00:00 2023-05-18 06:59:59 20 days, 23:59:59
Behaviors train:  2023-05-18 07:00:01 2023-05-25 06:59:58 6 days, 23:59:57


In [265]:
# Time span (earliest, latest, difference) of the validation history and behaviors.
# The history timestamps are stored as a list column, so they are flattened once
# and reused; `impression_time` in behaviors is already a plain datetime column
# and needs no explode.
for label, times in (
    ('History val: ', history_val['impression_time_fixed'].explode()),
    ('Behaviors val: ', behaviors_val['impression_time']),
):
    earliest, latest = times.min(), times.max()
    print(label, earliest, latest, latest - earliest)

History val:  2023-05-04 07:00:00 2023-05-25 06:59:59 20 days, 23:59:59
Behaviors val:  2023-05-25 07:00:02 2023-06-01 06:59:59 6 days, 23:59:57


In [266]:
# Total time covered from the first training impression to the last validation
# impression (`impression_time` is a plain datetime column, so no explode is needed).
behaviors_val['impression_time'].max() - behaviors_train['impression_time'].min()

datetime.timedelta(days=13, seconds=86398)

In [320]:
def behaviors_to_history(behaviors: pl.DataFrame) -> pl.DataFrame:
    """Reshape a behaviors frame into the per-user list layout of a history frame.

    Rows are ordered by ``impression_time``, reduced to the user id plus four
    columns that are renamed to their ``*_fixed`` history names, expanded to one
    row per clicked article, and finally collected into one row per ``user_id``
    whose remaining columns are lists.

    Args:
        behaviors: frame containing ``user_id``, ``impression_time``,
            ``article_ids_clicked``, ``next_scroll_percentage`` and ``next_read_time``.

    Returns:
        Frame with columns ``user_id``, ``impression_time_fixed``,
        ``article_id_fixed``, ``scroll_percentage_fixed`` and ``read_time_fixed``.
    """
    # Behaviors column -> history column, listed in the output column order.
    history_names = {
        'impression_time': 'impression_time_fixed',
        'article_ids_clicked': 'article_id_fixed',
        'next_scroll_percentage': 'scroll_percentage_fixed',
        'next_read_time': 'read_time_fixed',
    }

    chronological = behaviors.sort('impression_time')
    per_click = chronological.select(
        'user_id',
        *[pl.col(source).alias(target) for source, target in history_names.items()],
    ).explode('article_id_fixed')  # one row per clicked article

    return per_click.group_by('user_id').agg(pl.all())
        
# Preview the first two rows of the history-style frame built from the training behaviors.
behaviors_to_history(behaviors_train).head(2)

user_id,impression_time_fixed,article_id_fixed,scroll_percentage_fixed,read_time_fixed
u32,list[datetime[μs]],list[i32],list[f32],list[f32]
1629856,"[2023-05-19 03:52:13, 2023-05-19 03:52:23, … 2023-05-20 19:18:05]","[9771187, 9771896, … 9773987]","[62.0, 60.0, … null]","[3.0, 7.0, … 5.0]"
518721,"[2023-05-19 04:54:36, 2023-05-22 04:30:17]","[9771938, 9775485]","[30.0, 32.0]","[12.0, 14.0]"


In [319]:
# Display the training behaviors frame to inspect its columns and dtypes.
behaviors_train

impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage
u32,i32,datetime[μs],f32,f32,i8,list[i32],list[i32],u32,bool,i8,i8,i8,bool,u32,f32,f32
149474,,2023-05-24 07:47:53,13.0,,2,"[9778623, 9778682, … 9778728]",[9778657],139836,false,,,,false,759,7.0,22.0
150528,,2023-05-24 07:33:25,25.0,,2,"[9778718, 9778728, … 9778682]",[9778623],143471,false,,,,false,1240,287.0,100.0
153068,9778682,2023-05-24 07:09:04,78.0,100.0,1,"[9778657, 9778669, … 9778682]",[9778669],151570,false,,,,false,1976,45.0,100.0
153070,9777492,2023-05-24 07:13:14,26.0,100.0,1,"[9020783, 9778444, … 9778628]",[9778628],151570,false,,,,false,1976,4.0,18.0
153071,9778623,2023-05-24 07:11:08,125.0,100.0,1,"[9777492, 9774568, … 9775990]",[9777492],151570,false,,,,false,1976,26.0,100.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
580099643,9769306,2023-05-18 10:01:05,121.0,100.0,3,"[9233208, 9771242, … 9521144]",[9770886],2106715,false,,,,false,1416293,121.0,
580099644,9770882,2023-05-18 10:05:07,176.0,100.0,3,"[9771065, 9767697, … 9769762]",[9769306],2106715,false,,,,false,1416293,148.0,100.0
580099645,9769306,2023-05-18 10:11:03,24.0,100.0,3,"[9771042, 9440508, … 9767697]",[9771042],2106715,false,,,,false,1416293,4.0,
580100695,9771242,2023-05-18 10:00:08,5.0,100.0,1,"[9440508, 9142581, … 8422665]",[9767697],2110744,false,,,,false,747086,75.0,100.0


## Moving window

In [313]:
def moving_window_split_iterator(history: pl.DataFrame, behaviors: pl.DataFrame, window:int=4, window_val:int=2, stride:int=2):
    """Yield moving-window train/validation folds over the impression dates.

    The distinct calendar dates of ``behaviors['impression_time']`` are listed in
    ascending order and fold ``k`` starts at date index ``i = k * stride``. Every
    cut-off is taken at 07:00 on the respective date.

    Args:
        history: frame with a ``user_id`` column plus list columns (including
            ``impression_time_fixed``) that are exploded and re-aggregated per user.
        behaviors: impressions frame whose ``impression_time`` column is already sorted.
        window: number of consecutive dates spanned by the training window
            (``dates[i]`` .. ``dates[i + window - 1]``).
        window_val: number of consecutive dates spanned by the validation window,
            whose first date is the last training date.
        stride: number of dates (and of history days) the windows advance per fold.

    Yields:
        ``(history_k_train, behaviors_k_train, history_k_val, behaviors_k_val)``:

        * ``behaviors_k_train``: impressions in ``[dates[i] 07:00, dates[i + window - 1] 07:00)``.
        * ``history_k_train``: history entries at or after 07:00 of the history start
          date shifted by ``i`` days, regrouped per user.
        * ``behaviors_k_val``: impressions in
          ``[dates[i + window - 1] 07:00, dates[i + window + window_val - 2] 07:00)``.
        * ``history_k_val``: history entries at or after 07:00 of the history start
          date shifted by ``window + i`` days, with the behaviors that precede the
          end of the training window appended to each user's lists.

    Raises:
        ValueError: if ``window``, ``window_val`` or ``stride`` is smaller than 1.
        AssertionError: if ``behaviors['impression_time']`` is not sorted.
    """
    # A non-positive stride would never advance; non-positive windows would index
    # the date list with negative (wrapping) positions.
    if window < 1 or window_val < 1 or stride < 1:
        raise ValueError('window, window_val and stride must all be >= 1')
    assert behaviors['impression_time'].is_sorted()

    def day_start(date):
        # Every window boundary is placed at 07:00 of the given date.
        return datetime.datetime.combine(date, datetime.time(7, 0, 0))

    # `unique()` does not promise to keep row order by default, so deduplicate
    # first and sort afterwards to get a reliably ascending list of dates.
    behaviors_dates_mapping = (
        behaviors.select(pl.col('impression_time').dt.date().alias('impression_date'))
        .unique()
        .sort('impression_date')
        .with_row_index()
    )
    dates = behaviors_dates_mapping['impression_date'].to_list()
    idx = behaviors_dates_mapping['index'].to_list()

    history_cols = [col for col in history.columns if col != 'user_id']
    # Loop-invariant: flatten the per-user lists once instead of twice per fold.
    history_exploded = history.explode(pl.all().exclude('user_id'))
    history_start_date = history_exploded['impression_time_fixed'].min().date()

    # Number of date positions a single fold touches (train and validation share one date).
    span = window + window_val - 1
    for i in range(0, len(idx) - span + 1, stride):
        # The history windows slide by the same number of days as the date index.
        history_window_train_start_date = history_start_date + datetime.timedelta(days=i)
        history_window_val_start_date = history_window_train_start_date + datetime.timedelta(days=window)

        start_window_train_date = dates[i]
        end_window_train_date = dates[i + window - 1]
        end_window_val_date = dates[i + window + window_val - 2]
        print(idx[i: i + window], idx[i + window - 1: i + window + window_val - 1], '-> ', start_window_train_date, end_window_train_date, end_window_val_date)

        train_start = day_start(start_window_train_date)
        train_end = day_start(end_window_train_date)
        val_end = day_start(end_window_val_date)

        behaviors_k_train = behaviors.filter(
            pl.col('impression_time') >= train_start,
            pl.col('impression_time') < train_end,
        )

        # NOTE(review): unlike the validation history below, the training history is
        # not extended with behaviors that precede the training window - confirm
        # that this asymmetry is intended.
        history_k_train = history_exploded.filter(
            pl.col('impression_time_fixed') >= day_start(history_window_train_start_date)
        ).group_by('user_id').agg(pl.all())

        behaviors_k_val = behaviors.filter(
            pl.col('impression_time') >= train_end,
            pl.col('impression_time') < val_end,
        )

        # Everything seen before the validation window starts becomes extra history.
        behaviors_k_prev = behaviors.filter(
            pl.col('impression_time') < train_end,
        )

        # NOTE(review): newer polars releases appear to spell this join as
        # how='full', coalesce=True - TODO confirm against the installed version.
        history_k_val = history_exploded.filter(
            pl.col('impression_time_fixed') >= day_start(history_window_val_start_date)
        ).group_by('user_id').agg(pl.all()).join(
            behaviors_to_history(behaviors_k_prev), on='user_id', suffix='_next', how='outer_coalesce'
        ).with_columns(
            # Users present on only one side get empty lists so the concat below works.
            pl.all().exclude('user_id').fill_null(pl.lit([]))
        ).with_columns(
            *[pl.col(col).list.concat(f'{col}_next').alias(col) for col in history_cols],
        ).drop([f'{col}_next' for col in history_cols])

        yield history_k_train, behaviors_k_train, history_k_val, behaviors_k_val

In [327]:
# Fold over the train and validation behaviors stacked together, ordered by
# impression time; the history comes from the training split.
history = history_train
behaviors = (
    behaviors_train
    .vstack(behaviors_val)
    .sort('impression_time')
    .set_sorted('impression_time')
)

folds = moving_window_split_iterator(history, behaviors, window=4, window_val=2, stride=2)
for i, (history_k_train, behaviors_k_train, history_k_val, behaviors_k_val) in enumerate(folds):
    # Report the shape of every frame produced for this fold.
    fold_shapes = (history_k_train.shape, behaviors_k_train.shape, history_k_val.shape, behaviors_k_val.shape)
    print(f'Fold {i}:  ', *fold_shapes)

[0, 1, 2, 3] [3, 4] ->  2023-05-18 2023-05-21 2023-05-22
Fold 0:   (15143, 5) (101487, 17) (15129, 5) (32515, 17)
[2, 3, 4, 5] [5, 6] ->  2023-05-20 2023-05-23 2023-05-24
Fold 1:   (15134, 5) (97274, 17) (15124, 5) (31806, 17)
[4, 5, 6, 7] [7, 8] ->  2023-05-22 2023-05-25 2023-05-26
Fold 2:   (15108, 5) (98885, 17) (15143, 5) (32556, 17)
[6, 7, 8, 9] [9, 10] ->  2023-05-24 2023-05-27 2023-05-28
Fold 3:   (15085, 5) (98476, 17) (17413, 5) (31235, 17)
[8, 9, 10, 11] [11, 12] ->  2023-05-26 2023-05-29 2023-05-30
Fold 4:   (15040, 5) (99640, 17) (18050, 5) (36901, 17)
[10, 11, 12, 13] [13, 14] ->  2023-05-28 2023-05-31 2023-06-01
Fold 5:   (14984, 5) (110237, 17) (18603, 5) (37258, 17)
