see paper: https://paperswithcode.com/paper/neural-news-recommendation-with-long-and

# LSTUR: Neural News Recommendation with Long- and Short-term User Representations

In [1]:
from pathlib import Path
import tensorflow as tf
import polars as pl
import numpy as np
import os

from ebrec.utils._constants import *
from ebrec.utils._behaviors import create_binary_labels_column, sampling_strategy_wu2019, add_known_user_column, add_prediction_scores, truncate_history
from ebrec.utils._articles import convert_text2encoding_with_transformers, create_article_id_to_value_mapping
from ebrec.utils._polars import concat_str_columns, slice_join_dataframes
from ebrec.utils._nlp import get_transformers_word_embeddings

from transformers import AutoTokenizer, AutoModel


"""
load data
"""

data_base = Path(os.getcwd()).parent / "data-merged" / "merged"
# train_val_base = data_base / "1-ebnerd_demo_(20MB)"
train_val_base = data_base / "2-ebnerd_small_(80MB)"
# train_val_base = data_base / "3-ebnerd_large_(3.0GB)"
test_base = data_base / "5-ebnerd_testset_(1.5GB)"
assert train_val_base.exists() and test_base.exists()

train_behaviors = pl.scan_parquet(train_val_base / "train" / "behaviors.parquet")
train_history = pl.scan_parquet(train_val_base / "train" / "history.parquet")

val_behavior = pl.scan_parquet(train_val_base / "validation" / "behaviors.parquet")
val_history = pl.scan_parquet(train_val_base / "validation" / "history.parquet")

test_behavior = pl.scan_parquet(test_base / "test" / "behaviors.parquet")
test_history = pl.scan_parquet(test_base / "test" / "history.parquet")

train_articles: pl.LazyFrame = pl.scan_parquet(train_val_base / "articles.parquet")
val_articles: pl.LazyFrame = train_articles
test_articles: pl.LazyFrame = pl.scan_parquet(test_base / "articles.parquet")

articles_word2vec: pl.LazyFrame = pl.scan_parquet(data_base / "7-Ekstra-Bladet-word2vec_(133MB)" / "document_vector.parquet")
articles_image_embeddings: pl.LazyFrame = pl.scan_parquet(data_base / "8-Ekstra_Bladet_image_embeddings_(372MB)" / "image_embeddings.parquet")
articles_contrastive_vector: pl.LazyFrame = pl.scan_parquet(data_base / "9-Ekstra-Bladet-contrastive_vector_(341MB)" / "contrastive_vector.parquet")
articles_bert_base_multilingual_cased: pl.LazyFrame = pl.scan_parquet(data_base / "10-google-bert-base-multilingual-cased_(344MB)" / "bert_base_multilingual_cased.parquet")
articles_xlm_roberta_base: pl.LazyFrame = pl.scan_parquet(data_base / "11-FacebookAI-xlm-roberta-base_(341MB)" / "xlm_roberta_base.parquet")


"""
preprocessing: truncate user history, select subset of columns, join behavior and history, sample based on Wu2019, add binary labels column
"""


def ebnerd_from_path(history: pl.LazyFrame, behaviors: pl.LazyFrame, history_size: int = 30) -> pl.DataFrame:
    df_history = history.select(DEFAULT_USER_COL, DEFAULT_HISTORY_ARTICLE_ID_COL).pipe(truncate_history, column=DEFAULT_HISTORY_ARTICLE_ID_COL, history_size=history_size, padding_value=0)
    df_behaviors = behaviors.collect().pipe(slice_join_dataframes, df2=df_history.collect(), on=DEFAULT_USER_COL, how="left")
    return df_behaviors


COLUMNS = [
    DEFAULT_USER_COL,
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_CLICKED_ARTICLES_COL,
]
HISTORY_SIZE = 30
N_SAMPLES = 100
df_train = (
    ebnerd_from_path(history=train_history, behaviors=train_behaviors, history_size=HISTORY_SIZE)
    .select(COLUMNS)
    .pipe(sampling_strategy_wu2019, npratio=4, shuffle=True, with_replacement=True, seed=123)
    .pipe(create_binary_labels_column)
    .sample(n=N_SAMPLES)
)
df_validation = ebnerd_from_path(history=val_history, behaviors=val_behavior, history_size=HISTORY_SIZE).select(COLUMNS).pipe(create_binary_labels_column).sample(n=N_SAMPLES)
df_test = (
    ebnerd_from_path(history=test_history, behaviors=val_behavior, history_size=HISTORY_SIZE)
    .with_columns(pl.Series(DEFAULT_CLICKED_ARTICLES_COL, [[]]))
    .select(COLUMNS)
    .pipe(create_binary_labels_column)
    .sample(n=N_SAMPLES)
)


"""
use huggingface transformers to convert article text to tokens, tokens to embeddings
"""

df_articles = train_articles.collect()
TRANSFORMER_MODEL_NAME = "bert-base-multilingual-cased"
TEXT_COLUMNS_TO_USE = [DEFAULT_SUBTITLE_COL, DEFAULT_TITLE_COL]
MAX_TITLE_LENGTH = 30

transformer_model = AutoModel.from_pretrained(TRANSFORMER_MODEL_NAME)
transformer_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)

word2vec_embedding = get_transformers_word_embeddings(transformer_model)
df_articles, cat_cal = concat_str_columns(df_articles, columns=TEXT_COLUMNS_TO_USE)
df_articles, token_col_title = convert_text2encoding_with_transformers(df_articles, transformer_tokenizer, cat_cal, max_length=MAX_TITLE_LENGTH)
article_mapping = create_article_id_to_value_mapping(df=df_articles, value_col=token_col_title)



In [2]:
"""
batch data
"""
from ebrec.models.newsrec.dataloader import LSTURDataLoader  # NPA and LSTUR share the same dataloader

user_id_mapping = {user_id: i for i, user_id in enumerate(df_train[DEFAULT_USER_COL].unique())}

train_dataloader = LSTURDataLoader(
    user_id_mapping=user_id_mapping,
    behaviors=df_train,
    article_dict=article_mapping,
    unknown_representation="zeros",
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    eval_mode=False,
    batch_size=64,
)
val_dataloader = LSTURDataLoader(
    user_id_mapping=user_id_mapping,
    behaviors=df_validation,
    article_dict=article_mapping,
    unknown_representation="zeros",
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    eval_mode=True,
    batch_size=32,
)
test_dataloader = LSTURDataLoader(
    user_id_mapping=user_id_mapping,
    behaviors=df_test,
    article_dict=article_mapping,
    unknown_representation="zeros",
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    eval_mode=True,
    batch_size=32,
)

In [3]:
"""
train model
"""
from ebrec.models.newsrec.model_config import hparams_npa
from ebrec.models.newsrec.npa import NPAModel

MODEL_NAME = "NPA"
LOG_DIR = f"./runs/{MODEL_NAME}"
MODEL_WEIGHTS = f"./runs/data/state_dict/{MODEL_NAME}/weights"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=1)
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)
modelcheckpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=MODEL_WEIGHTS, save_best_only=True, save_weights_only=True, verbose=1)

config = hparams_npa
model = NPAModel(hparams=config, word2vec_embedding=word2vec_embedding, seed=42)
model.model.summary()
hist = model.model.fit(
    train_dataloader,
    validation_data=val_dataloader,
    epochs=1,
    callbacks=[tensorboard_callback, early_stopping_callback, modelcheckpoint_callback],
)
model.model.load_weights(filepath=MODEL_WEIGHTS)


2024-05-11 20:42:41.102492: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2 Pro
2024-05-11 20:42:41.102515: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-05-11 20:42:41.102521: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-05-11 20:42:41.102553: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-05-11 20:42:41.102566: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, None, 30)]           0         []                            
                                                                                                  
 input_4 (InputLayer)        [(None, 1)]                  0         []                            
                                                                                                  
 tf.compat.v1.shape (TFOpLa  (3,)                         0         ['input_2[0][0]']             
 mbda)                                                                                            
                                                                                                  
 reshape_1 (Reshape)         (None, 1, 1)                 0         ['input_4[0][0]']       

2024-05-11 20:42:42.803401: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
2024-05-11 20:42:42.883565: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp_16.
2024-05-11 20:42:43.072837: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14097751946289750256
2024-05-11 20:42:43.072855: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 104762413353123931
2024-05-11 20:42:43.072866: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12793545091132215493
2024-05-11 20:42:43.072871: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15332597895229543239
2024-05-11 20:42:43.072876: I tensorflow/core/framewo

InvalidArgumentError: Graph execution error:

Detected at node model_2/user_encoder/concatenate_3/concat defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>

  File "/Users/sueszli/.local/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/Users/sueszli/.local/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/asyncio/base_events.py", line 608, in run_forever

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/asyncio/events.py", line 84, in _run

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 534, in process_one

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 362, in execute_request

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 778, in execute_request

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 449, in do_execute

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code

  File "/var/folders/g6/xn7nvvtn4ng5pf9bhxfs7p2m0000gn/T/ipykernel_22665/3824227253.py", line 17, in <module>

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1807, in fit

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1401, in train_function

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1384, in step_function

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1373, in run_step

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1150, in train_step

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 515, in call

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 515, in call

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/layers/merging/base_merge.py", line 196, in call

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/layers/merging/concatenate.py", line 134, in _merge_function

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/backend.py", line 3580, in concatenate

Detected at node model_2/user_encoder/concatenate_3/concat defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>

  File "/Users/sueszli/.local/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start

  File "/Users/sueszli/.local/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/asyncio/base_events.py", line 608, in run_forever

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/asyncio/events.py", line 84, in _run

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 534, in process_one

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 362, in execute_request

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 778, in execute_request

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 449, in do_execute

  File "/Users/sueszli/.local/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code

  File "/var/folders/g6/xn7nvvtn4ng5pf9bhxfs7p2m0000gn/T/ipykernel_22665/3824227253.py", line 17, in <module>

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1807, in fit

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1401, in train_function

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1384, in step_function

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1373, in run_step

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 1150, in train_step

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 515, in call

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/training.py", line 590, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 515, in call

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/functional.py", line 672, in _run_internal_graph

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/layers/merging/base_merge.py", line 196, in call

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/layers/merging/concatenate.py", line 134, in _merge_function

  File "/Users/sueszli/.asdf/installs/python/3.11.9/lib/python3.11/site-packages/keras/src/backend.py", line 3580, in concatenate

2 root error(s) found.
  (0) INVALID_ARGUMENT:  ConcatOp : Dimension 1 in both shapes must be equal: shape[0] = [36,30,30] vs. shape[1] = [36,50,1]
	 [[{{node model_2/user_encoder/concatenate_3/concat}}]]
	 [[Adam/ReadVariableOp_22/_12]]
  (1) INVALID_ARGUMENT:  ConcatOp : Dimension 1 in both shapes must be equal: shape[0] = [36,30,30] vs. shape[1] = [36,50,1]
	 [[{{node model_2/user_encoder/concatenate_3/concat}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_4575]

In [None]:
"""
evaluate performance
"""

pred_validation = model.scorer.predict(val_dataloader)



In [None]:
df_validation = add_prediction_scores(df_validation, pred_validation.tolist()).pipe(
    add_known_user_column, known_users=df_train[DEFAULT_USER_COL]
)
df_validation.head(2)

user_id,article_id_fixed,article_ids_inview,article_ids_clicked,labels,scores,is_known_user
u32,list[i32],list[i32],list[i32],list[i8],list[f64],bool
2094919,"[9773351, 9774187, … 9779860]","[9780702, 9788188, … 9787499]",[9787499],"[0, 0, … 1]","[0.49845, 0.497381, … 0.497841]",False
2029201,"[9778326, 9777941, … 9779738]","[9783405, 9783852, … 8496358]",[9783852],"[0, 1, … 0]","[0.49801, 0.497147, … 0.498185]",False


In [None]:
from ebrec.evaluation import MetricEvaluator, AucScore, NdcgScore, MrrScore

metrics = MetricEvaluator(
    labels=df_validation["labels"].to_list(),
    predictions=df_validation["scores"].to_list(),
    metric_functions=[AucScore(), MrrScore(), NdcgScore(k=5), NdcgScore(k=10)],
)
metrics.evaluate()

<MetricEvaluator class>: 
 {
    "auc": 0.4832519055489523,
    "mrr": 0.30877717899621926,
    "ndcg@5": 0.32656298038911985,
    "ndcg@10": 0.4381692958258077
}

In [None]:
# pred_test = model.scorer.predict(test_dataloader) <--- breaks because of size mismatch
pred_test = model.scorer.predict(val_dataloader)

# store
submission_path = Path(os.getcwd()).parent / "submissions" / "ebnerd_lstur.txt"
with open(submission_path, "w") as f:
    for idx, row in enumerate(pred_test):
        f.write(f"{idx} {row}\n")
    f.close()

