# Packages 

In [1]:
%load_ext autoreload
%autoreload 2
import sys
import logging
sys.path.append('../')
import os
import warnings
warnings.simplefilter('ignore')

import pickle
import gc
import re
import polars as pl
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
pd.set_option('display.max_columns', None)
from tqdm.auto import tqdm
import polars as pl
from utils import *
from src.eval import get_recall_at_k, pd_get_recall_at_k

# Config 

In [2]:
# Set `debug = True` to iterate on a 100k-row slice of each eval file;
# `read_nrows = None` means every row is read.
debug = False
read_nrows = 100_000 if debug else None

# Data 

In [3]:
# Lazily scan the two candidate-generation eval outputs (word2vec and
# next-item-counter); `read_nrows` caps the rows read when `debug` is on.
# NOTE(review): later cells pair these files row-by-row, so they are assumed to
# list the same sessions in the same order — confirm when regenerating them.
w2v_eval_pl = pl.scan_parquet(
    '../data/eval_data/w2v_train_eval_result_300k.parquet',
    n_rows=read_nrows,
)
nic_eval_pl = pl.scan_parquet(
    '../data/eval_data/next_item_counter_train_eval_300k.parquet',
    n_rows=read_nrows,
)

In [4]:
# ! cat '../data/eval_data/next_item_counter_train_eval_result.parquet' | wc -l

In [5]:
# Column names and dtypes of the lazily-scanned w2v eval file.
w2v_eval_pl.schema

{'prev_items': Utf8,
 'next_item': Utf8,
 'locale': Utf8,
 'next_item_prediction': List(Utf8),
 'len': Int64,
 'recall@20': Boolean,
 'recall@100': Boolean,
 '__index_level_0__': Int64}

In [6]:
# Column names and dtypes of the lazily-scanned next-item-counter eval file.
nic_eval_pl.schema

{'prev_items': Utf8,
 'next_item': Utf8,
 'locale': Utf8,
 'next_item_prediction': List(Utf8),
 'len': Int64,
 'recall@20': Boolean,
 'recall@100': Boolean,
 'last_item': Utf8,
 '__index_level_0__': Int64}

In [7]:
# (rows, columns) of the w2v eval scan. Rows are counted with a lazy `pl.count()`
# so polars does not have to materialize every column just to report the height.
(w2v_eval_pl.select(pl.count()).collect()[0, 0], len(w2v_eval_pl.columns))

(300000, 8)

In [8]:
# (rows, columns) of the next-item-counter eval scan, counted lazily (see w2v cell).
(nic_eval_pl.select(pl.count()).collect()[0, 0], len(nic_eval_pl.columns))

(300000, 9)

In [9]:
# The two eval frames are later combined with a *horizontal* concat, which pairs
# rows purely by position. Before the key columns are dropped from the
# next-item-counter frame, verify that both files describe the same sessions in
# the same order; otherwise candidates would silently be attached to the wrong
# session and every recall number below would be meaningless.
_alignment_cols = ['prev_items', 'locale', 'next_item']
assert w2v_eval_pl.select(_alignment_cols).collect().frame_equal(
    nic_eval_pl.select(_alignment_cols).collect()
), 'w2v and next-item-counter eval files are not row-aligned; join on keys instead of concatenating'

# Keep only the next-item-counter candidates, renamed so they do not collide with
# the w2v `next_item_prediction` column after the concat.
# NOTE: this rebinds `nic_eval_pl`, so re-running this cell alone fails (the
# original columns are gone) — re-run the data-loading cell first.
nic_eval_pl = nic_eval_pl.select(
    pl.col('next_item_prediction').alias('nic_next_item_prediction')
)

# Keep the session keys, the w2v candidates and the ground-truth next item.
w2v_eval_pl = w2v_eval_pl.select(
    'prev_items'
    , 'locale'
    , 'next_item_prediction'
    , 'next_item'
)


In [10]:
# w2v_eval_pl.collect()

In [11]:
# joined_pl = w2v_eval_pl.join(nic_eval_pl
#                                                           , how='left'
#                                                           , on=['prev_items', 'locale']
#                                                                       )#.collect()

In [12]:
# Materialize both narrowed eval frames and place them side by side (w2v columns
# first). Rows are paired purely by position, so this relies on both parquet
# files listing the same sessions in the same order.
joined_pl = pl.concat(
    [lazy_frame.collect() for lazy_frame in (w2v_eval_pl, nic_eval_pl)],
    how='horizontal',
)

In [13]:
# Columns after the horizontal concat: w2v keys, candidates and target, plus the
# renamed next-item-counter candidates.
joined_pl.schema

{'prev_items': Utf8,
 'locale': Utf8,
 'next_item_prediction': List(Utf8),
 'next_item': Utf8,
 'nic_next_item_prediction': List(Utf8)}

In [14]:
# joined_pl = joined_pl.collect()

In [15]:
# Sanity check: the concat must keep exactly one row per w2v eval row. The w2v
# row count is taken lazily instead of re-reading and materializing the whole scan.
assert joined_pl.height == w2v_eval_pl.select(pl.count()).collect()[0, 0], (
    'joined_pl row count differs from the w2v eval input'
)

In [16]:
# joined_pl.sample(20)

# Eval 

In [17]:
# Stack the two candidate lists per session (w2v first, then next-item-counter)
# into a single `combined_prediction` list column. Duplicates across the two
# sources are kept. NOTE: this rebinds `joined_pl` from an eager DataFrame to a
# LazyFrame, so later cells must `.collect()`.
joined_pl = joined_pl.lazy().with_columns(
    pl.concat_list(['next_item_prediction', 'nic_next_item_prediction'])
    .alias('combined_prediction')
)

In [18]:
# Schema of the lazy frame, now including the `combined_prediction` list column.
joined_pl.schema

{'prev_items': Utf8,
 'locale': Utf8,
 'next_item_prediction': List(Utf8),
 'next_item': Utf8,
 'nic_next_item_prediction': List(Utf8),
 'combined_prediction': List(Utf8)}

## Recall@200 of the combined candidates vs. recall of each source alone

In [19]:
# Peek at the first three sessions with all candidate columns.
joined_pl.limit(3).collect()

prev_items,locale,next_item_prediction,next_item,nic_next_item_prediction,combined_prediction
str,str,list[str],str,list[str],list[str]
"""['B09CTD61C6' …","""DE""","[""B08HS4DWHK"", ""B09CTD61C6"", … ""B09SM5PMMQ""]","""B08HS5CWD4""","[""B08HS5CWD4"", ""B075JM8R1Q"", … ""B009SK9CTS""]","[""B08HS4DWHK"", ""B09CTD61C6"", … ""B009SK9CTS""]"
"""['B0BHZLZ5V4' …","""UK""","[""B08KFF1XG9"", ""B09YVFM9CB"", … ""B09YRLT3VW""]","""B089DJM2BG""","[""B089DJM2BG"", ""B01MU5YPGI"", … ""B09TDTPT4M""]","[""B08KFF1XG9"", ""B09YVFM9CB"", … ""B09TDTPT4M""]"
"""['B08CHLHTFB' …","""UK""","[""B0748H6Q7L"", ""B0748HGSG8"", … ""B0765WZV1Z""]","""B0B7RC68CT""","[""B0748HGSG8"", ""B0748J3RQ5"", … ""B0BFXGM3GY""]","[""B0748H6Q7L"", ""B0748HGSG8"", … ""B0BFXGM3GY""]"


In [20]:
# Distribution of the combined candidate-list length per session. The native
# list-length expression replaces a per-row Python `len` callback via `.apply`.
joined_pl.select(
    pl.col('combined_prediction').arr.lengths()
).collect().describe()

describe,combined_prediction
str,f64
"""count""",300000.0
"""null_count""",0.0
"""mean""",200.0
"""std""",0.0
"""min""",200.0
"""max""",200.0
"""median""",200.0


In [29]:
# Recall@200 of the combined pool: the share of sessions whose true next item
# appears anywhere in the 200-item w2v + next-item-counter candidate list.
(
    joined_pl
    .select(
        pl.col('combined_prediction')
        .arr.contains(pl.col('next_item'))
        .mean()
        .alias('recall@200')
    )
    .collect()
)

recall@200
f64
0.46833


In [34]:
# Recall of the w2v candidates on their own. These lists are only one half of the
# 200-item combined list, so labelling the result `recall@200` was wrong.
# NOTE(review): assumes every w2v list has 100 items — consistent with the
# `recall@100` column in the source file; confirm with `.arr.lengths()`.
joined_pl.select(
    pl.col('next_item_prediction').arr.contains(pl.col('next_item')).mean().alias('w2v_recall@100')
).collect()

recall@200
f64
0.291


In [35]:
# Recall of the next-item-counter candidates on their own (the other half of the
# combined list), so this is not recall@200.
# NOTE(review): assumes every list has 100 items — consistent with the
# `recall@100` column in the source file; confirm with `.arr.lengths()`.
joined_pl.select(
    pl.col('nic_next_item_prediction').arr.contains(pl.col('next_item')).mean().alias('nic_recall@100')
).collect()

recall@200
f64
0.371717


In [37]:
# joined_pl.select(
#     # pl.col('next_item_prediction').
#     # .contains(pl.col('next_item')).mean().alias('recall@200')
# ).head(2).collect()