In [40]:
import numpy as np
import polars as pl
from tqdm import tqdm
from polimi.utils._polars import reduce_polars_df_memory_size

In [41]:
from pathlib import Path


dpath = Path('../../dataset')
emb_dir = dpath
dtype = 'small'
articles = pl.read_parquet(f'{dpath}/ebnerd_{dtype}/articles.parquet')

behaviors_train = pl.read_parquet(f'{dpath}/ebnerd_{dtype}/train/behaviors.parquet')
history_train = pl.read_parquet(f'{dpath}/ebnerd_{dtype}/train/history.parquet')

# Test

In [42]:
embeddings = pl.read_parquet(emb_dir / 'Ekstra_Bladet_image_embeddings' / 'image_embeddings.parquet').sort('article_id')
embeddings.columns = ['article_id', 'embedding']
emb_size = len(embeddings['embedding'][0])
missing_articles_in_embedding = list(set(articles['article_id'].to_numpy()) - set(embeddings['article_id'].to_numpy()))
null_vector = np.zeros(emb_size, dtype=np.float32)
embeddings = embeddings.vstack(pl.DataFrame({'article_id': missing_articles_in_embedding, 'embedding': [null_vector] * len(missing_articles_in_embedding)}))
embeddings = embeddings.with_row_index()
embeddings.head(2)

index,article_id,embedding
u32,i32,list[f32]
0,3000022,"[-0.033208, -0.013787, … -0.036042]"
1,3000063,"[-0.047797, -0.025657, … 0.018883]"


In [43]:
all_zero_embeddings = embeddings.with_columns(pl.col('embedding').list.eval(pl.element() == 0.0).list.all().alias('check'))
are_all_zero_embeddings_present = len(all_zero_embeddings.filter(pl.col('check') == True)) > 0
are_all_zero_embeddings_present

True

In [44]:
m_non_norm = np.array([np.array(x) for x in embeddings['embedding'].to_numpy()])
row_norms = np.linalg.norm(m_non_norm, axis=1, keepdims=True)
m = m_non_norm / (row_norms + 1e-6)
m.shape

(106346, 1024)

In [45]:
article_emb_mapping = embeddings.select('index', 'article_id')
article_emb_mapping.head(1)

index,article_id
u32,i32
0,3000022


In [46]:
history_m = history_train.select('user_id', pl.col('article_id_fixed').list.eval(pl.element().replace(article_emb_mapping['article_id'], article_emb_mapping['index'], default=None).drop_nulls())).with_row_index('user_index')
user_history_map = history_m.select('user_id', 'user_index')
history_m = history_m['article_id_fixed'].to_numpy()
history_m.shape

(15143,)

In [47]:
df = behaviors_train.select('impression_id', 'user_id', pl.col('article_ids_inview').alias('article'))\
    .join(user_history_map, on='user_id')\
    .with_columns(
        pl.col('article').list.eval(pl.element().replace(article_emb_mapping['article_id'], article_emb_mapping['index'], default=None)).name.suffix('_index'),
    ).drop('impression_time_fixed', 'scroll_percentage_fixed', 'read_time_fixed')

df = reduce_polars_df_memory_size(df)
df.head(2)

Memory usage of dataframe is 25.95 MB
Memory usage after optimization is: 25.95 MB
Decreased by 0.0%


impression_id,user_id,article,user_index,article_index
u32,u32,list[i32],u32,list[u32]
149474,139836,"[9778623, 9778682, … 9778728]",11894,"[100868, 100874, … 100879]"
150528,143471,"[9778718, 9778728, … 9778682]",7016,"[105241, 100879, … 100874]"


In [48]:
scores_df = pl.concat([
    slice.explode(['article_index', 'article']).with_columns(scores = np.dot(
        m[slice['article_index'].explode().to_numpy()], 
        m[history_m[key[0]]].T))\
    .group_by(['impression_id', 'user_id', 'user_index'])\
    .agg(pl.all())
    for key, slice in tqdm(df[:1000].partition_by(by=['user_index'], as_dict=True).items(), total=df['user_index'].n_unique())
]).drop('article_index')
scores_df

  3%|▎         | 466/15143 [00:01<00:48, 304.87it/s]


impression_id,user_id,user_index,article,scores
u32,u32,u32,list[i32],list[list[f32]]
149474,139836,11894,"[9778623, 9778682, … 9778728]","[[0.546279, 0.456443, … 0.068054], [0.108021, -0.134935, … 0.308029], … [0.789861, 0.399746, … -0.11546]]"
150528,143471,7016,"[9778718, 9778728, … 9778682]","[[0.0, 0.0, … 0.0], [0.766936, 0.0, … 0.0], … [0.028365, 0.0, … 0.0]]"
153070,151570,7074,"[9020783, 9778444, … 9778628]","[[0.552781, 0.0, … 0.519307], [0.0, 0.0, … 0.0], … [-0.077273, 0.0, … 0.123942]]"
153071,151570,7074,"[9777492, 9774568, … 9775990]","[[-0.130659, 0.0, … -0.122535], [0.438304, 0.0, … 0.340977], … [0.6444, 0.0, … 0.253928]]"
153078,151570,7074,"[9778021, 9778627, … 7213923]","[[0.427206, 0.0, … 0.154384], [0.239121, 0.0, … 0.156372], … [-0.068958, 0.0, … -0.38356]]"
…,…,…,…,…
2433256,1606050,14460,"[9483850, 9779648, … 9779777]","[[0.337382, -0.065383, … 0.307274], [0.0, 0.0, … 0.0], … [0.0, 0.0, … 0.0]]"
2433248,1606050,14460,"[9552181, 9779263, … 9547869]","[[0.0, 0.0, … 0.0], [0.37884, 0.271457, … -0.0275], … [0.62122, 0.070636, … 0.213934]]"
2435848,1692081,10750,"[9779263, 9779205, … 9779577]","[[-0.135781, 0.201273, … 0.452139], [0.197592, 0.566438, … 0.604189], … [0.629099, 0.220339, … 0.12199]]"
2435885,1695195,10254,"[9658252, 9569934, … 9775885]","[[0.59537, 0.322892, … 0.576517], [-0.082606, -0.137399, … -0.313506], … [0.004153, 0.27604, … 0.32076]]"


In [49]:
simple_agg_df = scores_df.with_columns(
    pl.col('scores').list.eval(pl.element().list.mean()).name.suffix('_mean'),
    pl.col('scores').list.eval(pl.element().list.max()).name.suffix('_max'),
    pl.col('scores').list.eval(pl.element().list.max()).name.suffix('_min'),
    pl.col('scores').list.eval(pl.element().list.std()).name.suffix('_std'),
)
simple_agg_df.head(2)

impression_id,user_id,user_index,article,scores,scores_mean,scores_max,scores_min,scores_std
u32,u32,u32,list[i32],list[list[f32]],list[f32],list[f32],list[f32],list[f32]
149474,139836,11894,"[9778623, 9778682, … 9778728]","[[0.546279, 0.456443, … 0.068054], [0.108021, -0.134935, … 0.308029], … [0.789861, 0.399746, … -0.11546]]","[0.108393, 0.125424, … 0.097202]","[0.619155, 0.729108, … 0.789861]","[0.619155, 0.729108, … 0.789861]","[0.241325, 0.239725, … 0.221128]"
150528,143471,7016,"[9778718, 9778728, … 9778682]","[[0.0, 0.0, … 0.0], [0.766936, 0.0, … 0.0], … [0.028365, 0.0, … 0.0]]","[0.0, 0.170874, … 0.119192]","[0.0, 0.880629, … 0.856975]","[0.0, 0.880629, … 0.856975]","[0.0, 0.240824, … 0.22315]"


In [50]:
explode_cols = ['article'] + [col for col in simple_agg_df.columns if col.startswith('scores_')]
res = simple_agg_df.drop('user_index', 'scores')\
    .explode(explode_cols)\
    .sort('user_id', 'impression_id', 'article')
res.head(2)

impression_id,user_id,article,scores_mean,scores_max,scores_min,scores_std
u32,u32,i32,f32,f32,f32,f32
2097252,63123,9761926,0.194606,0.911962,0.911962,0.247309
2097252,63123,9769370,0.0,0.0,0.0,0.0


# Weightening

### Scroll Percentage weight

In [51]:
history_w = history_train.select('user_id', 'scroll_percentage_fixed').with_columns(
    pl.col('scroll_percentage_fixed').list.eval(pl.element().fill_null(0.0))\
        .list.eval(pl.element().sqrt()).alias('scroll_percentage_fixed_norm'),
    pl.col('scroll_percentage_fixed').list.eval(pl.element().fill_null(0.0))\
        .list.eval((pl.element() - pl.element().min()).truediv(pl.element().max() - pl.element().min())).alias('scroll_percentage_fixed_mmnorm')
    ).with_columns(
        pl.col('scroll_percentage_fixed_norm').list.eval(pl.element().truediv(pl.element().sum())).alias('scroll_percentage_fixed_norm_l1_w'),
        pl.col('scroll_percentage_fixed_mmnorm').list.eval(pl.element().truediv(pl.element().sum())).alias('scroll_percentage_fixed_mmnorm_l1_w'),
    )
history_w.head(2)

user_id,scroll_percentage_fixed,scroll_percentage_fixed_norm,scroll_percentage_fixed_mmnorm,scroll_percentage_fixed_norm_l1_w,scroll_percentage_fixed_mmnorm_l1_w
u32,list[f32],list[f32],list[f32],list[f32],list[f32]
13538,"[100.0, 35.0, … 100.0]","[10.0, 5.91608, … 10.0]","[1.0, 0.35, … 1.0]","[0.003138, 0.001856, … 0.003138]","[0.004735, 0.001657, … 0.004735]"
14241,"[100.0, 46.0, … 100.0]","[10.0, 6.78233, … 10.0]","[1.0, 0.46, … 1.0]","[0.007106, 0.00482, … 0.007106]","[0.007959, 0.003661, … 0.007959]"


### Read time weight

In [52]:
history_w_articles = history_train.explode(pl.all().exclude('user_id')).join(
    articles.select('article_id', 
        (pl.col('body') + pl.col('title') + pl.col('subtitle')).str.len_chars().alias('article_id_fixed_article_len'),
        'last_modified_time', 'published_time'), left_on='article_id_fixed', right_on='article_id'
    )\
    .with_columns(
        (pl.col('impression_time_fixed') - pl.col('published_time')).alias('time_to_impression'),
    ).group_by('user_id').agg(pl.all())
history_w_articles.head(2)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed,article_id_fixed_article_len,last_modified_time,published_time,time_to_impression
u32,list[datetime[μs]],list[f32],list[i32],list[f32],list[u32],list[datetime[μs]],list[datetime[μs]],list[duration[μs]]
204117,"[2023-04-27 09:57:17, 2023-04-27 09:58:14, … 2023-05-16 11:20:20]","[100.0, 100.0, … 75.0]","[9738569, 9738557, … 9767649]","[45.0, 36.0, … 101.0]","[1608, 1441, … 3026]","[2023-06-29 06:48:22, 2023-06-29 06:48:22, … 2023-06-29 06:48:51]","[2023-04-27 09:33:16, 2023-04-27 09:07:42, … 2023-05-16 11:04:39]","[24m 1s, 50m 32s, … 15m 41s]"
2013813,"[2023-05-01 16:22:22, 2023-05-01 16:22:59, … 2023-05-18 05:21:20]","[null, null, … 100.0]","[9744745, 9743893, … 9770604]","[4.0, 2.0, … 45.0]","[1846, 1854, … 1536]","[2023-06-29 06:48:29, 2023-06-29 06:48:28, … 2023-06-29 06:48:53]","[2023-05-01 16:05:52, 2023-05-01 14:28:06, … 2023-05-17 17:57:43]","[16m 30s, 1h 54m 53s, … 11h 23m 37s]"


In [53]:
history_w = history_w_articles.select('user_id', 'read_time_fixed', 'article_id_fixed_article_len')\
    .explode(pl.all().exclude('user_id'))\
    .with_columns(
        pl.col('read_time_fixed').truediv('article_id_fixed_article_len').fill_nan(0.0).alias('read_time_fixed_article_len_ratio'),
    ).with_columns(
        pl.when(pl.col('read_time_fixed_article_len_ratio').is_infinite()).then(0.0).otherwise(pl.col('read_time_fixed_article_len_ratio')).alias('read_time_fixed_article_len_ratio')
    ).group_by('user_id').agg(pl.all())\
    .with_columns(
        pl.col('read_time_fixed_article_len_ratio').list.eval(pl.element().truediv(pl.element().sum())).alias('read_time_fixed_article_len_ratio_l1_w'),
    )
history_w.head(2)

user_id,read_time_fixed,article_id_fixed_article_len,read_time_fixed_article_len_ratio,read_time_fixed_article_len_ratio_l1_w
u32,list[f32],list[u32],list[f64],list[f64]
1628323,"[0.0, 1.0, … 12.0]","[1527, 1527, … 2634]","[0.0, 0.000655, … 0.004556]","[0.0, 0.013231, … 0.092041]"
1359712,"[22.0, 10.0, … 31.0]","[1424, 1283, … 190]","[0.015449, 0.007794, … 0.163158]","[0.0029, 0.001463, … 0.030631]"


### Impression time

In [54]:
offset = 1
history_w = history_w_articles.select('user_id', 'time_to_impression')\
    .explode(pl.all().exclude('user_id'))\
    .with_columns(
        pl.lit(1).truediv(pl.col('time_to_impression').dt.total_minutes().sqrt() + offset).alias('time_to_impression_inverse_sqrt'),
        pl.col('time_to_impression').dt.total_minutes().sqrt().alias('time_to_impression_sqrt')
    ).group_by('user_id').agg(pl.all())\
    .with_columns(
        pl.col('time_to_impression_inverse_sqrt').list.eval(pl.element().truediv(pl.element().sum())).alias('time_to_impression_l1_w'),
        pl.col('time_to_impression_sqrt').list.eval(pl.element().truediv(pl.element().sum())).alias('time_to_impression_sqrt_l1_w')
    )
history_w.head(2)

user_id,time_to_impression,time_to_impression_inverse_sqrt,time_to_impression_sqrt,time_to_impression_l1_w,time_to_impression_sqrt_l1_w
u32,list[duration[μs]],list[f64],list[f64],list[f64],list[f64]
602844,"[50m 1s, 1h 17m 4s, … 12m 27s]","[0.123899, 0.102302, … 0.224009]","[7.071068, 8.774964, … 3.464102]","[0.004114, 0.003397, … 0.007438]","[0.000796, 0.000988, … 0.00039]"
1647571,"[19h 15m 24s, 9h 19m 3s, … 466d 9h 36m 18s]","[0.028583, 0.040579, … 0.001219]","[33.985291, 23.643181, … 819.521812]","[0.103393, 0.146785, … 0.004408]","[0.026745, 0.018606, … 0.644937]"


### Last k

In [55]:
history_len = history_w_articles['read_time_fixed'].list.len().to_list()
history_w = history_w_articles.select('user_id').with_columns(
    *[pl.Series([[1] * min(k, l) + [0] * max(0, l - k) for l in history_len], dtype=pl.List(pl.Int8)).alias(f'mask_w_{k}') for k in [5, 10, 15]]
)
history_w.head(2)

user_id,mask_w_5,mask_w_10,mask_w_15
u32,list[i8],list[i8],list[i8]
204117,"[1, 1, … 0]","[1, 1, … 1]","[1, 1, … 1]"
2013813,"[1, 1, … 0]","[1, 1, … 0]","[1, 1, … 0]"


### Last k hours

In [56]:
behaviors_w = behaviors_train.select('impression_id', 'user_id', 'impression_time')\
    .join(history_w_articles.select('user_id', 'impression_time_fixed'), on='user_id')\
    .explode(pl.all().exclude('impression_id', 'user_id', 'impression_time'))\
    .with_columns(
        *[(pl.col('impression_time').sub(pl.col('impression_time_fixed')).dt.total_hours() <= k).cast(pl.Int8).alias(f'impression_time_last_{k}_hours_mask') for k in [24, 24*2, 24*3, 24*7, 24*14]]
    ).group_by('impression_id', 'user_id', 'impression_time').agg(pl.all())
behaviors_w.head(2)

impression_id,user_id,impression_time,impression_time_fixed,impression_time_last_24_hours_mask,impression_time_last_48_hours_mask,impression_time_last_72_hours_mask,impression_time_last_168_hours_mask,impression_time_last_336_hours_mask
u32,u32,datetime[μs],list[datetime[μs]],list[i8],list[i8],list[i8],list[i8],list[i8]
329297347,1299258,2023-05-18 14:49:37,"[2023-04-27 13:52:01, 2023-04-27 13:52:30, … 2023-05-17 20:28:53]","[0, 0, … 1]","[0, 0, … 1]","[0, 0, … 1]","[0, 0, … 1]","[0, 0, … 1]"
12668769,1117389,2023-05-22 08:21:14,"[2023-05-02 12:55:14, 2023-05-02 12:55:33, … 2023-05-17 12:11:37]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 0]","[0, 0, … 1]","[0, 0, … 1]"


In [57]:
history_w = history_w_articles.select('user_id', 'time_to_impression')\
    .explode(pl.all().exclude('user_id'))\
    .with_columns(
        pl.col('time_to_impression').dt.total_minutes().sqrt().alias('time_to_impression_sqrt'),
    ).group_by('user_id').agg(pl.all())\
    .with_columns(
        pl.col('time_to_impression_sqrt').list.eval(pl.element().truediv(pl.element().sum())).alias('time_to_impression_l1_w')
    )
history_w.head(2)

user_id,time_to_impression,time_to_impression_sqrt,time_to_impression_l1_w
u32,list[duration[μs]],list[f64],list[f64]
527239,"[14h 37m 53s, 4h 29m 40s, … 1h 49m 3s]","[29.614186, 16.401219, … 10.440307]","[0.03576, 0.019805, … 0.012607]"
15226,"[1913d 6h 4m 22s, 30m 3s, … 4h 49m 31s]","[1659.844571, 5.477226, … 17.0]","[0.149842, 0.000494, … 0.001535]"


In [58]:
history_w_articles

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed,article_id_fixed_article_len,last_modified_time,published_time,time_to_impression
u32,list[datetime[μs]],list[f32],list[i32],list[f32],list[u32],list[datetime[μs]],list[datetime[μs]],list[duration[μs]]
204117,"[2023-04-27 09:57:17, 2023-04-27 09:58:14, … 2023-05-16 11:20:20]","[100.0, 100.0, … 75.0]","[9738569, 9738557, … 9767649]","[45.0, 36.0, … 101.0]","[1608, 1441, … 3026]","[2023-06-29 06:48:22, 2023-06-29 06:48:22, … 2023-06-29 06:48:51]","[2023-04-27 09:33:16, 2023-04-27 09:07:42, … 2023-05-16 11:04:39]","[24m 1s, 50m 32s, … 15m 41s]"
2013813,"[2023-05-01 16:22:22, 2023-05-01 16:22:59, … 2023-05-18 05:21:20]","[null, null, … 100.0]","[9744745, 9743893, … 9770604]","[4.0, 2.0, … 45.0]","[1846, 1854, … 1536]","[2023-06-29 06:48:29, 2023-06-29 06:48:28, … 2023-06-29 06:48:53]","[2023-05-01 16:05:52, 2023-05-01 14:28:06, … 2023-05-17 17:57:43]","[16m 30s, 1h 54m 53s, … 11h 23m 37s]"
1490412,"[2023-04-27 07:24:15, 2023-04-27 07:37:52, … 2023-05-17 20:36:15]","[79.0, null, … null]","[9738216, 9738334, … 9770390]","[12.0, 31.0, … 0.0]","[959, 1251, … 1924]","[2023-06-29 06:48:22, 2023-06-29 06:48:22, … 2023-06-29 06:48:53]","[2023-04-27 07:12:45, 2023-04-27 07:28:13, … 2023-05-17 19:54:55]","[11m 30s, 9m 39s, … 41m 20s]"
232425,"[2023-04-27 09:31:32, 2023-04-27 09:32:13, … 2023-05-17 11:59:26]","[100.0, null, … null]","[9738528, 9738364, … 9769893]","[31.0, 8.0, … 0.0]","[922, 204, … 219]","[2023-06-29 06:48:22, 2023-06-29 06:48:22, … 2023-06-29 06:48:52]","[2023-04-27 08:59:04, 2023-04-27 09:17:20, … 2023-05-17 10:47:43]","[32m 28s, 14m 53s, … 1h 11m 43s]"
1746250,"[2023-04-27 08:49:48, 2023-04-27 08:50:56, … 2023-05-17 20:54:19]","[100.0, 17.0, … 100.0]","[9738452, 9737521, … 9770483]","[63.0, 5.0, … 135.0]","[1931, 2572, … 1972]","[2023-06-29 06:48:22, 2023-06-29 06:48:21, … 2023-06-29 06:48:53]","[2023-04-27 08:28:48, 2023-04-27 08:06:58, … 2023-05-17 16:55:21]","[21m, 43m 58s, … 3h 58m 58s]"
…,…,…,…,…,…,…,…,…
1567318,"[2023-04-27 08:24:11, 2023-04-27 08:55:09, … 2023-05-18 06:22:45]","[null, 46.0, … 15.0]","[9737501, 9738452, … 9770989]","[0.0, 8.0, … 3.0]","[2649, 1931, … 1056]","[2023-06-29 06:48:21, 2023-06-29 06:48:22, … 2023-06-29 06:48:54]","[2023-04-27 05:23:38, 2023-04-27 08:28:48, … 2023-05-18 05:36:48]","[3h 33s, 26m 21s, … 45m 57s]"
1027265,"[2023-05-05 09:07:02, 2023-05-05 14:02:26, … 2023-05-16 18:12:45]","[100.0, null, … 37.0]","[9750133, 9746105, … 9763559]","[88.0, 0.0, … 3.0]","[1801, 199, … 3463]","[2023-06-29 06:48:34, 2023-06-29 06:48:30, … 2023-06-29 06:48:47]","[2023-05-04 17:04:42, 2023-05-03 07:08:08, … 2023-05-14 09:42:44]","[16h 2m 20s, 2d 6h 54m 18s, … 2d 8h 30m 1s]"
1521009,"[2023-04-27 08:53:44, 2023-04-27 15:55:29, … 2023-05-16 20:50:11]","[30.0, 65.0, … 18.0]","[9737719, 9739065, … 9767765]","[12.0, 42.0, … 2.0]","[1769, 1512, … 2705]","[2023-06-29 06:48:21, 2023-06-29 06:48:22, … 2023-06-29 06:48:51]","[2023-04-26 19:51:35, 2023-04-27 14:25:57, … 2023-05-16 18:30:53]","[13h 2m 9s, 1h 29m 32s, … 2h 19m 18s]"
2211930,"[2023-04-27 12:09:22, 2023-04-27 12:09:38, … 2023-05-17 21:02:22]","[42.0, 90.0, … 35.0]","[9738533, 9738663, … 9770592]","[9.0, 17.0, … 5.0]","[1672, 2091, … 1066]","[2023-06-29 06:48:22, 2023-06-29 06:48:22, … 2023-06-29 06:48:53]","[2023-04-27 11:33:33, 2023-04-27 10:08:17, … 2023-05-17 17:50:17]","[35m 49s, 2h 1m 21s, … 3h 12m 5s]"


# Add all

In [59]:
history_w_articles = history_train.explode(pl.all().exclude('user_id')).join(
    articles.select('article_id', 
        (pl.col('body') + pl.col('title') + pl.col('subtitle')).str.len_chars().alias('article_id_fixed_article_len'),
        'last_modified_time', 'published_time'), left_on='article_id_fixed', right_on='article_id'
    )\
    .with_columns(
        (pl.col('impression_time_fixed') - pl.col('published_time')).alias('time_to_impression'),
    ).group_by('user_id').agg(pl.all())

In [60]:
history_w_articles.head(1)

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed,article_id_fixed_article_len,last_modified_time,published_time,time_to_impression
u32,list[datetime[μs]],list[f32],list[i32],list[f32],list[u32],list[datetime[μs]],list[datetime[μs]],list[duration[μs]]
527742,"[2023-04-27 14:19:16, 2023-04-27 20:14:19, … 2023-05-16 17:42:50]","[48.0, null, … 32.0]","[9738095, 9738095, … 9767722]","[0.0, 2.0, … 4.0]","[1700, 1700, … 2074]","[2023-06-29 06:48:22, 2023-06-29 06:48:22, … 2023-06-29 06:48:51]","[2023-04-27 04:25:55, 2023-04-27 04:25:55, … 2023-05-16 16:43:20]","[9h 53m 21s, 15h 48m 24s, … 59m 30s]"


In [61]:
history_all_w = history_w_articles.select('user_id', 'time_to_impression', 'impression_time_fixed', 'scroll_percentage_fixed', 'read_time_fixed', 'article_id_fixed_article_len')\
    .explode(pl.all().exclude('user_id'))\
    .with_columns(pl.col('scroll_percentage_fixed').fill_null(0.0))\
    .with_columns(
        pl.col('read_time_fixed').truediv('article_id_fixed_article_len').fill_nan(0.0).alias('read_time_fixed_article_len_ratio'),
        # scroll_percentage
        (pl.col('scroll_percentage_fixed') - pl.col('scroll_percentage_fixed').min()).truediv(pl.col('scroll_percentage_fixed').max() - pl.col('scroll_percentage_fixed').min()).over('user_id').alias('scroll_percentage_fixed_mmnorm'),
        # time_to_impression
        pl.col('time_to_impression').dt.total_minutes().sqrt().alias('time_to_impression_minutes_sqrt'),
        pl.lit(1).truediv(pl.col('time_to_impression').dt.total_minutes().sqrt() + 1).alias('time_to_impression_inverse_sqrt'),
    ).with_columns(
        pl.when(pl.col('read_time_fixed_article_len_ratio').is_infinite()).then(0.0).otherwise(pl.col('read_time_fixed_article_len_ratio')).alias('read_time_fixed_article_len_ratio')
    ).group_by('user_id').agg(pl.all())\
    .with_columns(
        pl.col('read_time_fixed_article_len_ratio').list.eval(pl.element().truediv(pl.element().sum())).alias('read_time_fixed_article_len_ratio_l1_w'),
        pl.col('scroll_percentage_fixed_mmnorm').list.eval(pl.element().truediv(pl.element().sum())).alias('scroll_percentage_fixed_mmnorm_l1_w'),
        pl.col('time_to_impression_minutes_sqrt').list.eval(pl.element().truediv(pl.element().sum())).alias('time_to_impression_minutes_sqrt_l1_w'),
        pl.col('time_to_impression_inverse_sqrt').list.eval(pl.element().truediv(pl.element().sum())).alias('time_to_impression_inverse_sqrt_l1_w'),
    )
l1_w_cols = [col for col in history_all_w.columns if col.endswith('_l1_w')]
history_all_w = history_all_w.select('user_id', *l1_w_cols)
history_all_w.head(1)

user_id,read_time_fixed_article_len_ratio_l1_w,scroll_percentage_fixed_mmnorm_l1_w,time_to_impression_minutes_sqrt_l1_w,time_to_impression_inverse_sqrt_l1_w
u32,list[f64],list[f32],list[f64],list[f64]
423165,"[0.069707, 0.037737, … 0.135944]","[0.101727, 0.074856, … 0.191939]","[0.012346, 0.012857, … 0.004516]","[0.112667, 0.108522, … 0.271558]"


# Multiple embeddings

In [62]:
emb_name_list = {'Ekstra_Bladet_contrastive_vector': 'contrastive_vector',
                 'FacebookAI_xlm_roberta_base': 'xlm_roberta_base',
                 'Ekstra_Bladet_image_embeddings': 'image_embeddings',
                 'google_bert_base_multilingual_cased': 'bert_base_multilingual_cased'}

In [97]:
def build_emb_scores(df: pl.DataFrame, history_m: np.ndarray, m_dict:dict[str, np.ndarray], last_k:list[int] = [1, 5]):
    df = reduce_polars_df_memory_size(df)
    print(f'Starting to build embeddings scores for {m_dict.keys()}...')
    df = pl.concat([
        slice.explode(['article_index', 'article']).with_columns(
            *[pl.lit(
                np.dot(
                    m[slice['article_index'].explode().to_numpy()], 
                    m[history_m[key[0]]].T)
                ).alias(f'{emb_name}_scores') for emb_name, m in m_dict.items()]
        ).group_by(['impression_id', 'user_id', 'user_index'])\
        .agg(pl.all())
        for key, slice in tqdm(df.partition_by(by=['user_index'], as_dict=True).items(), total=df['user_index'].n_unique()) # keep only 1000 for testing
    ]).drop('article_index', 'user_index')
    return df

def build_agg_scores(df: pl.DataFrame, agg_cols: list[str] = [], last_k: list[int] = []):
    df = df.with_columns(
        *[pl.col(col).list.eval(pl.element().list.mean()).name.suffix('_mean') for col in agg_cols],
        *[pl.col(col).list.eval(pl.element().list.max()).name.suffix('_max') for col in agg_cols],
        *[pl.col(col).list.eval(pl.element().list.min()).name.suffix('_min') for col in agg_cols],
        *[pl.col(col).list.eval(pl.element().list.std()).name.suffix('_std') for col in agg_cols],
        *[pl.col(col).list.eval(pl.element().list.median()).name.suffix('_median') for col in agg_cols],
    )
    
    if len(last_k) > 0:
        df = df.with_columns(
            *[pl.col(col).list.eval(pl.element().list.tail(k).list.mean()).name.suffix(f'_mean_tail_{k}') for col in agg_cols for k in last_k],
            *[pl.col(col).list.eval(pl.element().list.tail(k).list.max()).name.suffix(f'_max_tail_{k}') for col in agg_cols for k in last_k],
            *[pl.col(col).list.eval(pl.element().list.tail(k).list.min()).name.suffix(f'_min_tail_{k}') for col in agg_cols for k in last_k],
            *[pl.col(col).list.eval(pl.element().list.tail(k).list.std()).name.suffix(f'_std_tail_{k}') for col in agg_cols for k in last_k],
            *[pl.col(col).list.eval(pl.element().list.tail(k).list.median()).name.suffix(f'_median_tail_{k}') for col in agg_cols for k in last_k],
        )
    return df


In [72]:
norm_m_dict = {}
article_emb_mapping = articles.select('article_id').unique().with_row_index()
for dir, file_name in emb_name_list.items():
    print(f'Processing {file_name} embedding matrix...')
    emb_df = pl.read_parquet(emb_dir / dir / f'{file_name}.parquet')
    emb_df.columns = ['article_id', 'embedding']
    
    emb_size = len(emb_df['embedding'][0])
    missing_articles_in_embedding = list(set(articles['article_id'].to_numpy()) - set(emb_df['article_id'].to_numpy()))
    if len(missing_articles_in_embedding) > 0:
        print(f'[Warning... {len(missing_articles_in_embedding)} missing articles in embedding matrix]')
        null_vector = np.zeros(emb_size, dtype=np.float32)
        emb_df = emb_df.vstack(pl.DataFrame({'article_id': missing_articles_in_embedding, 'embedding': [null_vector] * len(missing_articles_in_embedding)}))
        
    emb_df = article_emb_mapping.join(emb_df, on='article_id', how='left')
    m = np.array([np.array(row) for row in emb_df['embedding'].to_numpy()])
    row_norms = np.linalg.norm(m, axis=1, keepdims=True)
    m = m / (row_norms + 1e-6)
    norm_m_dict[file_name] = m

Processing contrastive_vector embedding matrix...
Processing xlm_roberta_base embedding matrix...
Processing image_embeddings embedding matrix...
Processing bert_base_multilingual_cased embedding matrix...


In [78]:
history_m = history_train\
    .select('user_id', pl.col('article_id_fixed').list.eval(
                pl.element().replace(article_emb_mapping['article_id'], article_emb_mapping['index'], default=None)))\
    .with_row_index('user_index')

user_history_map = history_m.select('user_id', 'user_index')
history_m = history_m['article_id_fixed'].to_numpy()
train_ds = behaviors_train[:1000].select('impression_id', 'user_id', pl.col('article_ids_inview').alias('article'))\
    .join(user_history_map, on='user_id')\
    .with_columns(
        pl.col('article').list.eval(pl.element().replace(article_emb_mapping['article_id'], article_emb_mapping['index'], default=None)).name.suffix('_index'),
    ).drop('impression_time_fixed', 'scroll_percentage_fixed', 'read_time_fixed')

train_ds = build_emb_scores(train_ds, history_m, m_dict=norm_m_dict, last_k=[5, 10, 20])
emb_names = [col for col in train_ds.columns if 'scores' in col]
train_ds_agg = build_agg_scores(train_ds, emb_names)
# agg_scores_col = [col for col in train_ds.columns if '_scores_' in col]
# train_ds = train_ds.drop([f'{emb_name}_scores' for emb_name in list(norm_m_dict.keys())]).explode(['article'] + agg_scores_col)
train_ds.head()

Memory usage of dataframe is 0.11 MB
Memory usage after optimization is: 0.09 MB
Decreased by 20.6%
Starting to build embeddings scores for dict_keys(['contrastive_vector', 'xlm_roberta_base', 'image_embeddings', 'bert_base_multilingual_cased'])...


100%|██████████| 466/466 [00:04<00:00, 101.18it/s]


impression_id,user_id,article,contrastive_vector_scores,xlm_roberta_base_scores,image_embeddings_scores,bert_base_multilingual_cased_scores
u32,u32,list[i32],list[list[f32]],list[list[f32]],list[list[f32]],list[list[f32]]
2097255,63123,"[9771916, 9771938, … 9771855]","[[0.557273, 0.317587, … 0.145003], [0.182505, 0.066342, … 0.131874], … [0.116064, 0.052735, … 0.095286]]","[[0.999381, 0.998612, … 0.999212], [0.998721, 0.997995, … 0.998756], … [0.998922, 0.998184, … 0.998931]]","[[0.618652, 0.237785, … 0.0], [0.022949, -0.205564, … 0.0], … [0.395846, 0.083482, … 0.0]]","[[0.986466, 0.981622, … 0.985318], [0.969257, 0.968638, … 0.966765], … [0.977267, 0.980927, … 0.981054]]"
2097252,63123,"[9761926, 9771896, … 9769370]","[[0.107697, 0.240556, … 0.207324], [0.269079, 0.082241, … 0.149119], … [0.188112, 0.083346, … 0.402493]]","[[0.998851, 0.998298, … 0.998933], [0.999403, 0.999, … 0.999468], … [0.999225, 0.998372, … 0.999208]]","[[0.339977, 0.162691, … 0.0], [0.174351, 0.361142, … 0.0], … [0.0, 0.0, … 0.0]]","[[0.981906, 0.986131, … 0.984856], [0.985936, 0.985265, … 0.98395], … [0.985635, 0.984928, … 0.988126]]"
2099252,84383,"[9771187, 9771919, … 9769370]","[[0.335738, 0.274959, … 0.292364], [0.154554, 0.258899, … 0.204154], … [0.122367, 0.401185, … 0.173453]]","[[0.999364, 0.99929, … 0.999399], [0.999168, 0.999351, … 0.999357], … [0.999324, 0.999554, … 0.999526]]","[[0.439715, 0.233114, … 0.336275], [0.442196, -0.204846, … 0.343981], … [0.0, 0.0, … 0.0]]","[[0.951342, 0.942027, … 0.947812], [0.989138, 0.987074, … 0.988399], … [0.988274, 0.987833, … 0.985349]]"
2099253,84383,"[9771916, 9771187, … 9769348]","[[0.235649, 0.25459, … 0.14069], [0.335738, 0.274959, … 0.292364], … [0.133124, 0.523986, … 0.167435]]","[[0.999306, 0.999261, … 0.999192], [0.999364, 0.99929, … 0.999399], … [0.999006, 0.999353, … 0.999236]]","[[0.465422, 0.067244, … 0.059784], [0.439715, 0.233114, … 0.336275], … [0.350681, 0.199099, … 0.166773]]","[[0.98583, 0.978367, … 0.978577], [0.951342, 0.942027, … 0.947812], … [0.981203, 0.978487, … 0.983893]]"
2099250,84383,"[9686860, 9702964, … 9771919]","[[0.182729, 0.550608, … 0.136484], [0.168673, 0.700092, … 0.182869], … [0.154554, 0.258899, … 0.204154]]","[[0.999335, 0.999449, … 0.99938], [0.999257, 0.999564, … 0.999384], … [0.999168, 0.999351, … 0.999357]]","[[0.179768, 0.514603, … 0.060329], [0.092108, -0.099428, … 0.083508], … [0.442196, -0.204846, … 0.343981]]","[[0.988752, 0.983094, … 0.984562], [0.986144, 0.984657, … 0.983357], … [0.989138, 0.987074, … 0.988399]]"


## Apply weight

In [79]:
train_ds_w_base = train_ds.join(
    history_all_w, on='user_id', how='left'
)
train_ds_w_base.head(2)

impression_id,user_id,article,contrastive_vector_scores,xlm_roberta_base_scores,image_embeddings_scores,bert_base_multilingual_cased_scores,read_time_fixed_article_len_ratio_l1_w,scroll_percentage_fixed_mmnorm_l1_w,time_to_impression_minutes_sqrt_l1_w,time_to_impression_inverse_sqrt_l1_w
u32,u32,list[i32],list[list[f32]],list[list[f32]],list[list[f32]],list[list[f32]],list[f64],list[f32],list[f64],list[f64]
2097255,63123,"[9771916, 9771938, … 9771855]","[[0.557273, 0.317587, … 0.145003], [0.182505, 0.066342, … 0.131874], … [0.116064, 0.052735, … 0.095286]]","[[0.999381, 0.998612, … 0.999212], [0.998721, 0.997995, … 0.998756], … [0.998922, 0.998184, … 0.998931]]","[[0.618652, 0.237785, … 0.0], [0.022949, -0.205564, … 0.0], … [0.395846, 0.083482, … 0.0]]","[[0.986466, 0.981622, … 0.985318], [0.969257, 0.968638, … 0.966765], … [0.977267, 0.980927, … 0.981054]]","[0.00012, 0.000011, … 0.0]","[0.000822, 0.00069, … 0.0]","[0.000552, 0.000264, … 0.000329]","[0.001571, 0.002886, … 0.002432]"
2097252,63123,"[9761926, 9771896, … 9769370]","[[0.107697, 0.240556, … 0.207324], [0.269079, 0.082241, … 0.149119], … [0.188112, 0.083346, … 0.402493]]","[[0.998851, 0.998298, … 0.998933], [0.999403, 0.999, … 0.999468], … [0.999225, 0.998372, … 0.999208]]","[[0.339977, 0.162691, … 0.0], [0.174351, 0.361142, … 0.0], … [0.0, 0.0, … 0.0]]","[[0.981906, 0.986131, … 0.984856], [0.985936, 0.985265, … 0.98395], … [0.985635, 0.984928, … 0.988126]]","[0.00012, 0.000011, … 0.0]","[0.000822, 0.00069, … 0.0]","[0.000552, 0.000264, … 0.000329]","[0.001571, 0.002886, … 0.002432]"


In [81]:
l1_w_cols = [col for col in train_ds_w_base.columns if col.endswith('_l1_w')]
scores_cols = [col for col in train_ds_w_base.columns if col.endswith('_scores')]
train_ds_w = pl.concat([
    slice.explode(['article'] + scores_cols).with_columns(
        *[pl.lit(
            np.array([np.array(i) for i in slice[col_score].explode().to_numpy()]) * slice[col_w][0].to_numpy()
        ).alias(f'{col_score}_weighted_{col_w}')
        for col_w in l1_w_cols for col_score in scores_cols]
    ).drop(l1_w_cols).group_by('impression_id', 'user_id').agg(pl.all())
    for key, slice in tqdm(train_ds_w_base.partition_by(by=['user_id'], as_dict=True).items(), total=train_ds_w_base['user_id'].n_unique())    
])
train_ds_w.head(2)


100%|██████████| 466/466 [00:01<00:00, 448.61it/s]


impression_id,user_id,article,contrastive_vector_scores,xlm_roberta_base_scores,image_embeddings_scores,bert_base_multilingual_cased_scores,contrastive_vector_scores_weighted_read_time_fixed_article_len_ratio_l1_w,xlm_roberta_base_scores_weighted_read_time_fixed_article_len_ratio_l1_w,image_embeddings_scores_weighted_read_time_fixed_article_len_ratio_l1_w,bert_base_multilingual_cased_scores_weighted_read_time_fixed_article_len_ratio_l1_w,contrastive_vector_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w,xlm_roberta_base_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w,image_embeddings_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w,bert_base_multilingual_cased_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w,contrastive_vector_scores_weighted_time_to_impression_minutes_sqrt_l1_w,xlm_roberta_base_scores_weighted_time_to_impression_minutes_sqrt_l1_w,image_embeddings_scores_weighted_time_to_impression_minutes_sqrt_l1_w,bert_base_multilingual_cased_scores_weighted_time_to_impression_minutes_sqrt_l1_w,contrastive_vector_scores_weighted_time_to_impression_inverse_sqrt_l1_w,xlm_roberta_base_scores_weighted_time_to_impression_inverse_sqrt_l1_w,image_embeddings_scores_weighted_time_to_impression_inverse_sqrt_l1_w,bert_base_multilingual_cased_scores_weighted_time_to_impression_inverse_sqrt_l1_w
u32,u32,list[i32],list[list[f32]],list[list[f32]],list[list[f32]],list[list[f32]],list[list[f64]],list[list[f64]],list[list[f64]],list[list[f64]],list[list[f32]],list[list[f32]],list[list[f32]],list[list[f32]],list[list[f64]],list[list[f64]],list[list[f64]],list[list[f64]],list[list[f64]],list[list[f64]],list[list[f64]],list[list[f64]]
2097255,63123,"[9771916, 9771938, … 9771855]","[[0.557273, 0.317587, … 0.145003], [0.182505, 0.066342, … 0.131874], … [0.116064, 0.052735, … 0.095286]]","[[0.999381, 0.998612, … 0.999212], [0.998721, 0.997995, … 0.998756], … [0.998922, 0.998184, … 0.998931]]","[[0.618652, 0.237785, … 0.0], [0.022949, -0.205564, … 0.0], … [0.395846, 0.083482, … 0.0]]","[[0.986466, 0.981622, … 0.985318], [0.969257, 0.968638, … 0.966765], … [0.977267, 0.980927, … 0.981054]]","[[0.000067, 0.000003, … 0.0], [0.000022, 7.1257e-7, … 0.0], … [0.000014, 5.6641e-7, … 0.0]]","[[0.00012, 0.000011, … 0.0], [0.00012, 0.000011, … 0.0], … [0.00012, 0.000011, … 0.0]]","[[0.000074, 0.000003, … 0.0], [0.000003, -0.000002, … 0.0], … [0.000048, 8.9666e-7, … 0.0]]","[[0.000118, 0.000011, … 0.0], [0.000116, 0.00001, … 0.0], … [0.000117, 0.000011, … 0.0]]","[[0.000458, 0.000219, … 0.0], [0.00015, 0.000046, … 0.0], … [0.000095, 0.000036, … 0.0]]","[[0.000821, 0.000689, … 0.0], [0.000821, 0.000689, … 0.0], … [0.000821, 0.000689, … 0.0]]","[[0.000508, 0.000164, … 0.0], [0.000019, -0.000142, … 0.0], … [0.000325, 0.000058, … 0.0]]","[[0.000811, 0.000678, … 0.0], [0.000797, 0.000669, … 0.0], … [0.000803, 0.000677, … 0.0]]","[[0.000308, 0.000084, … 0.000048], [0.000101, 0.000018, … 0.000043], … [0.000064, 0.000014, … 0.000031]]","[[0.000552, 0.000264, … 0.000328], [0.000552, 0.000264, … 0.000328], … [0.000552, 0.000264, … 0.000328]]","[[0.000342, 0.000063, … 0.0], [0.000013, -0.000054, … 0.0], … [0.000219, 0.000022, … 0.0]]","[[0.000545, 0.00026, … 0.000324], [0.000535, 0.000256, … 0.000318], … [0.00054, 0.000259, … 0.000323]]","[[0.000876, 0.000917, … 0.000353], [0.000287, 0.000191, … 0.000321], … [0.000182, 0.000152, … 0.000232]]","[[0.00157, 0.002882, … 0.00243], [0.001569, 0.00288, … 0.002429], … [0.00157, 0.002881, … 0.002429]]","[[0.000972, 0.000686, … 0.0], [0.000036, -0.000593, … 0.0], … [0.000622, 0.000241, … 0.0]]","[[0.00155, 0.002833, … 0.002396], [0.001523, 0.002796, … 0.002351], … [0.001536, 0.002831, … 0.002386]]"
2097252,63123,"[9761926, 9771896, … 9769370]","[[0.107697, 0.240556, … 0.207324], [0.269079, 0.082241, … 0.149119], … [0.188112, 0.083346, … 0.402493]]","[[0.998851, 0.998298, … 0.998933], [0.999403, 0.999, … 0.999468], … [0.999225, 0.998372, … 0.999208]]","[[0.339977, 0.162691, … 0.0], [0.174351, 0.361142, … 0.0], … [0.0, 0.0, … 0.0]]","[[0.981906, 0.986131, … 0.984856], [0.985936, 0.985265, … 0.98395], … [0.985635, 0.984928, … 0.988126]]","[[0.000013, 0.000003, … 0.0], [0.000032, 8.8333e-7, … 0.0], … [0.000023, 8.9520e-7, … 0.0]]","[[0.00012, 0.000011, … 0.0], [0.00012, 0.000011, … 0.0], … [0.00012, 0.000011, … 0.0]]","[[0.000041, 0.000002, … 0.0], [0.000021, 0.000004, … 0.0], … [0.0, 0.0, … 0.0]]","[[0.000118, 0.000011, … 0.0], [0.000118, 0.000011, … 0.0], … [0.000118, 0.000011, … 0.0]]","[[0.000089, 0.000166, … 0.0], [0.000221, 0.000057, … 0.0], … [0.000155, 0.000058, … 0.0]]","[[0.000821, 0.000689, … 0.0], [0.000821, 0.00069, … 0.0], … [0.000821, 0.000689, … 0.0]]","[[0.000279, 0.000112, … 0.0], [0.000143, 0.000249, … 0.0], … [0.0, 0.0, … 0.0]]","[[0.000807, 0.000681, … 0.0], [0.00081, 0.00068, … 0.0], … [0.00081, 0.00068, … 0.0]]","[[0.000059, 0.000064, … 0.000068], [0.000149, 0.000022, … 0.000049], … [0.000104, 0.000022, … 0.000132]]","[[0.000552, 0.000264, … 0.000328], [0.000552, 0.000264, … 0.000329], … [0.000552, 0.000264, … 0.000328]]","[[0.000188, 0.000043, … 0.0], [0.000096, 0.000095, … 0.0], … [0.0, 0.0, … 0.0]]","[[0.000542, 0.000261, … 0.000324], [0.000545, 0.000261, … 0.000323], … [0.000544, 0.00026, … 0.000325]]","[[0.000169, 0.000694, … 0.000504], [0.000423, 0.000237, … 0.000363], … [0.000296, 0.000241, … 0.000979]]","[[0.00157, 0.002881, … 0.002429], [0.00157, 0.002883, … 0.00243], … [0.00157, 0.002881, … 0.00243]]","[[0.000534, 0.00047, … 0.0], [0.000274, 0.001042, … 0.0], … [0.0, 0.0, … 0.0]]","[[0.001543, 0.002846, … 0.002395], [0.001549, 0.002844, … 0.002393], … [0.001549, 0.002843, … 0.002403]]"


In [98]:
agg_cols = [col for col in train_ds_w.columns if '_scores' in col]
train_ds_w_aggs = build_agg_scores(train_ds_w, agg_cols=agg_cols, last_k=[5, 10, 15]).drop(agg_cols)
train_ds_w_aggs.head(2)

impression_id,user_id,article,contrastive_vector_scores_mean,xlm_roberta_base_scores_mean,image_embeddings_scores_mean,bert_base_multilingual_cased_scores_mean,contrastive_vector_scores_weighted_read_time_fixed_article_len_ratio_l1_w_mean,xlm_roberta_base_scores_weighted_read_time_fixed_article_len_ratio_l1_w_mean,image_embeddings_scores_weighted_read_time_fixed_article_len_ratio_l1_w_mean,bert_base_multilingual_cased_scores_weighted_read_time_fixed_article_len_ratio_l1_w_mean,contrastive_vector_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_mean,xlm_roberta_base_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_mean,image_embeddings_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_mean,bert_base_multilingual_cased_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_mean,contrastive_vector_scores_weighted_time_to_impression_minutes_sqrt_l1_w_mean,xlm_roberta_base_scores_weighted_time_to_impression_minutes_sqrt_l1_w_mean,image_embeddings_scores_weighted_time_to_impression_minutes_sqrt_l1_w_mean,bert_base_multilingual_cased_scores_weighted_time_to_impression_minutes_sqrt_l1_w_mean,contrastive_vector_scores_weighted_time_to_impression_inverse_sqrt_l1_w_mean,xlm_roberta_base_scores_weighted_time_to_impression_inverse_sqrt_l1_w_mean,image_embeddings_scores_weighted_time_to_impression_inverse_sqrt_l1_w_mean,bert_base_multilingual_cased_scores_weighted_time_to_impression_inverse_sqrt_l1_w_mean,contrastive_vector_scores_max,xlm_roberta_base_scores_max,image_embeddings_scores_max,bert_base_multilingual_cased_scores_max,contrastive_vector_scores_weighted_read_time_fixed_article_len_ratio_l1_w_max,xlm_roberta_base_scores_weighted_read_time_fixed_article_len_ratio_l1_w_max,image_embeddings_scores_weighted_read_time_fixed_article_len_ratio_l1_w_max,bert_base_multilingual_cased_scores_weighted_read_time_fixed_article_len_ratio_l1_w_max,contrastive_vector_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_max,xlm_roberta_base_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_max,image_embeddings_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_max,bert_base_multilingual_cased_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_max,contrastive_vector_scores_weighted_time_to_impression_minutes_sqrt_l1_w_max,xlm_roberta_base_scores_weighted_time_to_impression_minutes_sqrt_l1_w_max,…,bert_base_multilingual_cased_scores_weighted_read_time_fixed_article_len_ratio_l1_w_median_tail_15,contrastive_vector_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_5,contrastive_vector_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_10,contrastive_vector_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_15,xlm_roberta_base_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_5,xlm_roberta_base_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_10,xlm_roberta_base_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_15,image_embeddings_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_5,image_embeddings_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_10,image_embeddings_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_15,bert_base_multilingual_cased_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_5,bert_base_multilingual_cased_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_10,bert_base_multilingual_cased_scores_weighted_scroll_percentage_fixed_mmnorm_l1_w_median_tail_15,contrastive_vector_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_5,contrastive_vector_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_10,contrastive_vector_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_15,xlm_roberta_base_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_5,xlm_roberta_base_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_10,xlm_roberta_base_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_15,image_embeddings_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_5,image_embeddings_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_10,image_embeddings_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_15,bert_base_multilingual_cased_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_5,bert_base_multilingual_cased_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_10,bert_base_multilingual_cased_scores_weighted_time_to_impression_minutes_sqrt_l1_w_median_tail_15,contrastive_vector_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_5,contrastive_vector_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_10,contrastive_vector_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_15,xlm_roberta_base_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_5,xlm_roberta_base_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_10,xlm_roberta_base_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_15,image_embeddings_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_5,image_embeddings_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_10,image_embeddings_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_15,bert_base_multilingual_cased_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_5,bert_base_multilingual_cased_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_10,bert_base_multilingual_cased_scores_weighted_time_to_impression_inverse_sqrt_l1_w_median_tail_15
u32,u32,list[i32],list[f32],list[f32],list[f32],list[f32],list[f64],list[f64],list[f64],list[f64],list[f32],list[f32],list[f32],list[f32],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f32],list[f32],list[f32],list[f32],list[f64],list[f64],list[f64],list[f64],list[f32],list[f32],list[f32],list[f32],list[f64],list[f64],…,list[f64],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64],list[f64]
2097255,63123,"[9771916, 9771938, … 9771855]","[0.200907, 0.154063, … 0.251488]","[0.999051, 0.998986, … 0.999161]","[0.198911, 0.193671, … 0.162336]","[0.961375, 0.951313, … 0.965217]","[0.000126, 0.000107, … 0.000175]","[0.001376, 0.001376, … 0.001376]","[0.000226, 0.00029, … 0.000337]","[0.001262, 0.001249, … 0.001273]","[0.000277, 0.000225, … 0.000368]","[0.001376, 0.001376, … 0.001376]","[0.000282, 0.000275, … 0.00024]","[0.00133, 0.001316, … 0.001335]","[0.000268, 0.000226, … 0.000474]","[0.001376, 0.001376, … 0.001376]","[0.000345, 0.000259, … 0.000233]","[0.001327, 0.001313, … 0.001333]","[0.000279, 0.000202, … 0.000327]","[0.001376, 0.001376, … 0.001376]","[0.000258, 0.000259, … 0.000217]","[0.001327, 0.001313, … 0.001332]","[0.909731, 0.94831, … 0.87319]","[0.999661, 0.999873, … 0.999793]","[0.791619, 0.880052, … 0.867757]","[0.990057, 0.993238, … 0.99174]","[0.008613, 0.00695, … 0.012875]","[0.174697, 0.174734, … 0.174727]","[0.019199, 0.066946, … 0.088744]","[0.162056, 0.163121, … 0.16432]","[0.002857, 0.002695, … 0.00283]","[0.003286, 0.003286, … 0.003287]","[0.002522, 0.002729, … 0.002853]","[0.003238, 0.003249, … 0.003249]","[0.007098, 0.007963, … 0.021301]","[0.059936, 0.059946, … 0.05997]",…,"[0.000125, 0.000124, … 0.000126]","[0.000121, 0.000032, … 0.000101]","[0.000125, 0.000075, … 0.000171]","[0.000116, 0.000032, … 0.000101]","[0.000887, 0.000887, … 0.000887]","[0.000985, 0.000986, … 0.000985]","[0.000755, 0.000755, … 0.000756]","[0.000026, -0.0, … 0.0]","[0.000079, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[0.000861, 0.000877, … 0.00087]","[0.000948, 0.000952, … 0.00096]","[0.000743, 0.000737, … 0.000743]","[0.00023, 0.000054, … 0.000409]","[0.00023, 0.000069, … 0.000315]","[0.000054, 0.000078, … 0.000144]","[0.001933, 0.001935, … 0.001934]","[0.001884, 0.001883, … 0.001884]","[0.000647, 0.000647, … 0.000647]","[0.000044, -0.000004, … 0.0]","[0.000063, -0.000002, … 0.0]","[0.000082, 0.0, … 0.0]","[0.001871, 0.001828, … 0.001898]","[0.001772, 0.001749, … 0.001789]","[0.000618, 0.000613, … 0.000626]","[0.000067, 0.000031, … 0.000232]","[0.000139, 0.000073, … 0.000236]","[0.000181, 0.000082, … 0.000232]","[0.000492, 0.000493, … 0.000493]","[0.000505, 0.000505, … 0.000505]","[0.001364, 0.001364, … 0.001364]","[0.00001, -8.7387e-7, … 0.0]","[0.000052, -4.3694e-7, … 0.0]","[0.000094, 0.0, … 0.0]","[0.000454, 0.000447, … 0.000459]","[0.000495, 0.000489, … 0.000496]","[0.001304, 0.001292, … 0.00132]"
2097252,63123,"[9761926, 9771896, … 9769370]","[0.175282, 0.2935, … 0.229902]","[0.999008, 0.999115, … 0.99925]","[0.194606, 0.14915, … 0.0]","[0.966587, 0.965852, … 0.9677]","[0.00015, 0.000195, … 0.000106]","[0.001376, 0.001376, … 0.001376]","[0.00033, 0.000317, … 0.0]","[0.00127, 0.001265, … 0.001271]","[0.000255, 0.000428, … 0.000342]","[0.001376, 0.001376, … 0.001376]","[0.000278, 0.000226, … 0.0]","[0.001337, 0.001336, … 0.001338]","[0.000312, 0.000492, … 0.000326]","[0.001376, 0.001376, … 0.001376]","[0.000259, 0.00022, … 0.0]","[0.001336, 0.001335, … 0.001336]","[0.000237, 0.000393, … 0.000321]","[0.001376, 0.001376, … 0.001376]","[0.000275, 0.000205, … 0.0]","[0.001334, 0.001333, … 0.001336]","[0.797017, 0.771882, … 0.716651]","[0.999568, 0.999626, … 0.999731]","[0.911962, 0.80103, … 0.0]","[0.99214, 0.993042, … 0.993572]","[0.006662, 0.014161, … 0.005872]","[0.1747, 0.174706, … 0.174748]","[0.090131, 0.11578, … 0.0]","[0.163513, 0.161841, … 0.163951]","[0.002532, 0.002505, … 0.002356]","[0.003286, 0.003286, … 0.003286]","[0.002998, 0.002599, … 0.0]","[0.003258, 0.003264, … 0.003264]","[0.011884, 0.01939, … 0.00858]","[0.059949, 0.059941, … 0.059956]",…,"[0.000125, 0.000125, … 0.000126]","[0.000113, 0.000101, … 0.00015]","[0.000141, 0.000152, … 0.000207]","[0.000113, 0.000101, … 0.000197]","[0.000887, 0.000887, … 0.000887]","[0.000985, 0.000985, … 0.000986]","[0.000755, 0.000755, … 0.000756]","[-0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[0.00087, 0.000866, … 0.00087]","[0.000955, 0.000953, … 0.000959]","[0.000744, 0.000745, … 0.00075]","[0.000269, 0.000454, … 0.00036]","[0.000257, 0.000337, … 0.00036]","[0.000108, 0.000122, … 0.000284]","[0.001933, 0.001933, … 0.001934]","[0.001883, 0.001884, … 0.001884]","[0.000647, 0.000647, … 0.000647]","[-0.000037, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[0.001889, 0.001889, … 0.001875]","[0.001784, 0.001785, … 0.001779]","[0.000621, 0.000621, … 0.000625]","[0.000063, 0.000313, … 0.000084]","[0.000117, 0.000304, … 0.000187]","[0.000133, 0.000256, … 0.00025]","[0.000493, 0.000492, … 0.000493]","[0.000505, 0.000505, … 0.000505]","[0.001364, 0.001364, … 0.001365]","[-0.000009, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[0.0, 0.0, … 0.0]","[0.000457, 0.000455, … 0.000455]","[0.000497, 0.000498, … 0.000497]","[0.001308, 0.001308, … 0.001317]"


In [103]:
train_ds_w.select('user_id', 'impression_id', 'article', 'contrastive_vector_scores').with_columns(
    pl.col('contrastive_vector_scores').list.eval(
        pl.element().list.tail(2).list.mean()
        ).name.suffix('_mean_tail_2'),
).head(2)

user_id,impression_id,article,contrastive_vector_scores,contrastive_vector_scores_mean_tail_2
u32,u32,list[i32],list[list[f32]],list[f32]
63123,2097255,"[9771916, 9771938, … 9771855]","[[0.557273, 0.317587, … 0.145003], [0.182505, 0.066342, … 0.131874], … [0.116064, 0.052735, … 0.095286]]","[0.170951, 0.251488]"
63123,2097252,"[9761926, 9771896, … 9769370]","[[0.107697, 0.240556, … 0.207324], [0.269079, 0.082241, … 0.149119], … [0.188112, 0.083346, … 0.402493]]","[0.251488, 0.229902]"
