In [39]:
import pickle

import numpy as np
import polars as pl
import torch
from transformers import AutoTokenizer, AutoModel

from ebrec.evaluation import MetricEvaluator, AucScore, NdcgScore, MrrScore
from ebrec.utils._constants import (
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_CLICKED_ARTICLES_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_IMPRESSION_ID_COL,
    DEFAULT_SUBTITLE_COL,
    DEFAULT_LABELS_COL,
    DEFAULT_TITLE_COL,
    DEFAULT_USER_COL,
    DEFAULT_ARTICLE_MODIFIED_TIMESTAMP_COL,
    DEFAULT_IMPRESSION_TIMESTAMP_COL,
    DEFAULT_HISTORY_IMPRESSION_TIMESTAMP_COL
)
from ebrec.utils._nlp import get_transformers_word_embeddings

from src.model.dataloader import PPRecDataLoader
from src.model.model_config import hparams_pprec
from model.modules import PPRec

In [8]:
# Hugging Face checkpoint used for both the tokenizer and the embedding source.
TRANSFORMER_MODEL_NAME = "FacebookAI/xlm-roberta-base"

transformer_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)
transformer_model = AutoModel.from_pretrained(TRANSFORMER_MODEL_NAME)

# Embeddings returned by get_transformers_word_embeddings are handed to PPRec.
word2vec_embedding = get_transformers_word_embeddings(transformer_model)

# Freshly constructed PPRec; trained weights are restored separately via
# load_state_dict.
saved_model = PPRec(
    hparams_pprec,
    word2vec_embedding,
)

In [9]:
# Checkpoint file holding the trained PPRec state dict.
PATH = 'model_20240618_011912_0'

# map_location="cpu": the model in this notebook is never moved to another
# device, and without it a checkpoint saved from a GPU run cannot be
# deserialised on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True where the installed PyTorch
# supports it).
saved_model.load_state_dict(torch.load(PATH, map_location="cpu"))

<All keys matched successfully>

In [13]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file, so only artefacts produced by this project's own preprocessing
    should be loaded here.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# No empty-dict placeholders are bound up front: if a file is missing the cell
# fails and the name stays undefined, instead of later cells silently running
# against an empty mapping.
article_mapping_title = _load_pickle('article_mapping_title.pkl')
article_mapping_entity = _load_pickle('article_mapping_entity.pkl')
articles_ctr = _load_pickle('articles_ctr.pkl')
popularity_mapping = _load_pickle('popularity_mapping.pkl')

# Columns selected from the validation parquet file; all others are skipped.
COLUMNS = [
    'user_id',
    'article_id_fixed',
    'article_ids_inview',
    'article_ids_clicked',
    'impression_id',
    'labels',
    'recency_inview',
    'recency_hist',
]


# Lazily scan the parquet file so that only the COLUMNS subset is materialised.
df_validation = (
    pl.scan_parquet("small_demo_val_all_features_with_sampling.parquet")
    .select(COLUMNS)
    .collect()
)

# Validation loader: same keyword arguments as before, grouped by role.
val_dataloader = PPRecDataLoader(
    # impressions to score
    behaviors=df_validation,
    # column names inside `behaviors`
    history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    history_recency='recency_hist',
    inview_recency='recency_inview',
    # per-article lookup tables loaded from the pickle files
    article_dict=article_mapping_title,
    entity_mapping=article_mapping_entity,
    ctr_mapping=articles_ctr,
    popularity_mapping=popularity_mapping,
    unknown_representation="zeros",
    # batching / mode
    batch_size=4,
    eval_mode=True,
)

In [33]:
saved_model.eval()

# Per-batch scores are collected in a list and stacked once after the loop;
# concatenating inside the loop re-copies the whole growing array every batch.
#
# The first entry is a 4x5 placeholder block. Later code discards the leading
# placeholder rows of `predictions`, so the array layout is kept unchanged
# here. Zeros are used instead of uninitialised memory so that the placeholder
# is deterministic.
batch_scores = [np.zeros(shape=(4, 5))]

with torch.no_grad():
    # Labels are yielded by the loader but are not needed for scoring.
    for vinputs, _ in val_dataloader:
        # Positions in the loader's input tuple, named after the model
        # arguments they feed (0 and 2: history inputs, 5-8: title, entities,
        # ctr and recency inputs).
        vhist_title = torch.from_numpy(vinputs[0])
        vhist_popularity = torch.from_numpy(vinputs[2])
        vtitle = torch.from_numpy(vinputs[5])
        ventities = torch.from_numpy(vinputs[6])
        vctr = torch.from_numpy(vinputs[7])
        vrecency = torch.from_numpy(vinputs[8])

        outputs = saved_model(
            vtitle, ventities, vctr, vrecency, vhist_title, vhist_popularity
        )
        batch_scores.append(outputs.cpu().detach().numpy())

predictions = np.concatenate(batch_scores, axis=0)
            
           
            

Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here
Reached here

In [34]:
# Keep only the trailing rows that line up with the validation frame. This
# removes the leading placeholder rows and, unlike a fixed `[4:]` slice, is a
# no-op when the cell is re-run after the lengths already match.
assert len(predictions) >= df_validation.height, "fewer prediction rows than validation rows"
predictions = predictions[len(predictions) - df_validation.height:]
predictions.shape

(25, 5)

In [35]:
# (rows, columns) of the validation frame, for comparison with predictions.shape.
(df_validation.height, df_validation.width)

(25, 8)

In [36]:
# Attach one list of scores per impression. The 2-D array is converted to
# nested Python lists first so the column is a list[f64] regardless of how the
# installed Polars version maps multidimensional NumPy arrays.
# Re-running is safe: with_columns replaces an existing "predicted_scores".
df_validation = df_validation.with_columns(
    pl.Series(name="predicted_scores", values=predictions.tolist())
)

In [38]:
# Preview the first five impressions together with their predicted scores.
df_validation.head(5)

user_id,article_id_fixed,article_ids_inview,article_ids_clicked,impression_id,labels,recency_inview,recency_hist,predicted_scores
u32,list[i32],list[i64],list[i64],u32,list[i8],list[i64],list[i64],list[f64]
2052088,"[9778952, 9777636, … 9779577]","[9777705, 9786209, … 9786230]",[9786172],69798187,"[0, 0, … 0]","[3, 0, … 1]","[1, 0, … 1]","[0.402318, 0.095339, … 0.352166]"
1622906,"[9779285, 9779181, … 9779748]","[9789745, 9789702, … 9789676]",[9789702],571539786,"[0, 1, … 0]","[0, 1, … 1]","[0, 1, … 1]","[0.398743, 0.097287, … 0.351678]"
667805,"[9779427, 9780195, … 9780195]","[9783213, 9783213, … 9726237]",[9782092],396061188,"[0, 0, … 0]","[1, 1, … 0]","[8, 0, … 0]","[0.397379, 0.092174, … 0.350979]"
1887792,"[9778369, 9778381, … 9778971]","[9782879, 9780968, … 9782695]",[9782695],369793739,"[0, 0, … 1]","[3, 3, … 4]","[4, 10, … 7]","[0.407408, 0.088169, … 0.358567]"
1216284,"[9778769, 9778745, … 9778827]","[7213923, 9052240, … 9780651]",[9780651],142351970,"[0, 0, … 1]","[42295, 12288, … 0]","[0, 0, … 1]","[0.40595, 0.09089, … 0.347896]"


In [41]:
# Score the per-impression predictions against the click labels.
metrics = MetricEvaluator(
    metric_functions=[
        AucScore(),
        MrrScore(),
        NdcgScore(k=5),
        NdcgScore(k=10),
    ],
    labels=df_validation.get_column("labels").to_list(),
    predictions=df_validation.get_column("predicted_scores").to_list(),
)
metrics.evaluate()

<MetricEvaluator class>: 
 {
    "auc": 0.57,
    "mrr": 0.5053333333333333,
    "ndcg@5": 0.6274650294492777,
    "ndcg@10": 0.6274650294492777
}