In [1]:
# %pip install -U lightgbm

In [1]:
import pandas as pd
import numpy as np

import pickle

from tqdm import tqdm
from pathlib import Path
import gc

In [2]:
import warnings
import sys
from IPython.core.interactiveshell import InteractiveShell

warnings.filterwarnings("ignore")
sys.path.append("../src/")
InteractiveShell.ast_node_interactivity = "all"

In [3]:
!pwd

/home/tarique/Misc/HnM/Tarique/notebooks


In [4]:
from data import DataHelper
from data.metrics import map_at_k, hr_at_k, recall_at_k

from retrieval.rules import (
    OrderHistory,
    OrderHistoryDecay,
    ItemPair,

    UserGroupTimeHistory,
    UserGroupSaleTrend,

    TimeHistory,
    TimeHistoryDecay,
    SaleTrend,

    OutOfStock,
)
from retrieval.collector import RuleCollector


In [5]:
!ls ../src/

__init__.py  data      models	  utils.py
__pycache__  features  retrieval  visualization


In [6]:
data_dir = Path("../src/data/")
dh = DataHelper(data_dir)

In [7]:
# data = dh.preprocess_data(save=True) # run only once

In [8]:
data = dh.load_data(name="encoded_full")

In [9]:
data.keys()

dict_keys(['user', 'item', 'inter'])

In [10]:
listBin = [-1, 24, 32, 40, 50, 65, 75, 119]
# listBin = [-1, 19, 29, 39, 49, 59, 69, 119]
data['user']['age_bins'] = pd.cut(data['user']['age'], listBin)

In [11]:
trans = data["inter"].merge(data['item'][['article_id','product_code']], on='article_id', how='left')

## Retrieval

In [12]:
pd.to_datetime('2020-09-16') - 5*pd.Timedelta(days=7)

Timestamp('2020-08-12 00:00:00')

Timestamp('2020-09-13 00:00:00')

In [13]:
trans = data["inter"]
train, valid = dh.split_data(trans, "2020-09-16", "2020-09-23")
customer_list = valid["customer_id"].values

last_week = train.loc[train.t_dat >= "2020-09-09"]
last_3days = train.loc[train.t_dat >= "2020-09-13"]
last_2week = train.loc[train.t_dat >= "2020-09-02"]
last_5week = train.loc[train.t_dat >= "2020-08-12"]

In [35]:
trans.t_dat.max()

'2020-09-22'

In [34]:
last_week.t_dat.max()

'2020-09-15'

In [14]:
train = train.merge(data['user'][['customer_id','age_bins']], on='customer_id', how='left')
last_week = last_week.merge(data['user'][['customer_id','age_bins']], on='customer_id', how='left')
last_3days = last_3days.merge(data['user'][['customer_id','age_bins']], on='customer_id', how='left')
last_2week = last_2week.merge(data['user'][['customer_id','age_bins']], on='customer_id', how='left')
last_5week = last_5week.merge(data['user'][['customer_id','age_bins']], on='customer_id', how='left')

In [15]:
# last_week = last_week.merge(data['item'][['article_id','perceived_colour_master_id','product_group_name']], on='article_id', how='left')

In [16]:
candidates = RuleCollector().collect(
    # data=data,
    valid = valid,
    customer_list=customer_list,
    rules=[
        # OrderHistory(train, 7),
        # ItemPair(OrderHistory(train, 7).retrieve()),
        # UserGroupTimeHistory(data, customer_list, last_week, ['age_bins'], 24),
        # OrderHistoryDecay(train, 7),
        OrderHistory(train, 3),
        OrderHistory(train, 7),
        OrderHistory(train, 14),
        OrderHistoryDecay(train, 3, n=50),
        OrderHistoryDecay(train, 7, n=50),
        OrderHistoryDecay(train, 14, n=50),
        ItemPair(OrderHistory(train, 3).retrieve(), name='1'),
        ItemPair(OrderHistory(train, 7).retrieve(), name='2'),
        ItemPair(OrderHistory(train, 14).retrieve(), name='3'),
        ItemPair(OrderHistoryDecay(train, 3, n=50).retrieve(), name='4'),
        ItemPair(OrderHistoryDecay(train, 7, n=50).retrieve(), name='5'),
        ItemPair(OrderHistoryDecay(train, 14, n=50).retrieve(), name='6'),
        UserGroupTimeHistory(data, customer_list, last_week, ['age_bins'], n=50, name='1'),
        UserGroupTimeHistory(data, customer_list, last_3days, ['age_bins'], n=50, name='2'),
        UserGroupTimeHistory(data, customer_list, last_2week, ['age_bins'], n=50, name='3'),
        UserGroupSaleTrend(data, customer_list, train, ['age_bins'], 3, n=50),
        UserGroupSaleTrend(data, customer_list, train, ['age_bins'], 7, n=50),
        UserGroupSaleTrend(data, customer_list, train, ['age_bins'], 14, n=50),
        TimeHistory(customer_list, last_week, n=50, name='1'),
        TimeHistory(customer_list, last_3days, n=50, name='2'),
        TimeHistory(customer_list, last_2week, n=50, name='3'),
        TimeHistoryDecay(customer_list, train, 3, n=50),
        TimeHistoryDecay(customer_list, train, 7, n=50),
        TimeHistoryDecay(customer_list, train, 14, n=50),
        SaleTrend(customer_list, train, 3, n=50),
        SaleTrend(customer_list, train, 7, n=50),
        SaleTrend(customer_list, train, 14, n=50),
    ],
    filters=[OutOfStock(trans)],
    min_pos_rate=0.006,
    compress=False,
)

Retrieve items by rules:   4%|▎         | 1/27 [00:28<12:19, 28.43s/it]

Positive rate: 0.03038


Retrieve items by rules:   7%|▋         | 2/27 [00:57<11:55, 28.61s/it]

Positive rate: 0.02859
Positive rate: 0.02577


Retrieve items by rules:  11%|█         | 3/27 [01:27<11:46, 29.43s/it]

Positive rate: 0.01413


Retrieve items by rules:  15%|█▍        | 4/27 [02:32<16:39, 43.46s/it]

Positive rate: 0.01295


Retrieve items by rules:  19%|█▊        | 5/27 [03:40<19:08, 52.22s/it]

Positive rate: 0.01230


Retrieve items by rules:  22%|██▏       | 6/27 [04:48<20:13, 57.77s/it]

Positive rate: 0.01519


Retrieve items by rules:  26%|██▌       | 7/27 [05:13<15:37, 46.87s/it]

Positive rate: 0.01472


Retrieve items by rules:  30%|██▉       | 8/27 [05:40<12:52, 40.68s/it]

Positive rate: 0.01374


Retrieve items by rules:  33%|███▎      | 9/27 [06:12<11:21, 37.89s/it]

Positive rate: 0.00908


Retrieve items by rules:  37%|███▋      | 10/27 [06:52<10:55, 38.55s/it]

Positive rate: 0.00853


Retrieve items by rules:  41%|████      | 11/27 [07:43<11:18, 42.39s/it]

Positive rate: 0.00825


Retrieve items by rules:  44%|████▍     | 12/27 [08:41<11:45, 47.04s/it]

TOP14.0 Positive rate: 0.00607


Retrieve items by rules:  48%|████▊     | 13/27 [09:38<11:41, 50.09s/it]

TOP21.5 Positive rate: 0.00607


Retrieve items by rules:  52%|█████▏    | 14/27 [10:29<10:56, 50.47s/it]

TOP13.0 Positive rate: 0.00610


Retrieve items by rules:  59%|█████▉    | 16/27 [12:29<10:06, 55.18s/it]

skip


Retrieve items by rules:  63%|██████▎   | 17/27 [13:35<09:44, 58.49s/it]

skip


Retrieve items by rules:  67%|██████▋   | 18/27 [14:50<09:31, 63.54s/it]

skip
TOP9.0 Positive rate: 0.00631


Retrieve items by rules:  70%|███████   | 19/27 [15:38<07:50, 58.85s/it]

TOP16.0 Positive rate: 0.00601


Retrieve items by rules:  78%|███████▊  | 21/27 [17:12<05:15, 52.52s/it]

skip
TOP12.0 Positive rate: 0.00607


Retrieve items by rules:  81%|████████▏ | 22/27 [18:38<05:13, 62.72s/it]

TOP8.0 Positive rate: 0.00616


Retrieve items by rules:  85%|████████▌ | 23/27 [20:05<04:39, 69.87s/it]

TOP11.0 Positive rate: 0.00604


Retrieve items by rules:  93%|█████████▎| 25/27 [22:21<02:15, 67.55s/it]

skip
TOP2.0 Positive rate: 0.00758


Retrieve items by rules: 100%|██████████| 27/27 [24:07<00:00, 53.62s/it]

skip





In [17]:
candidates.groupby('method').article_id.nunique()

method
ItemPairRetrieve1             31864
ItemPairRetrieve2             31755
ItemPairRetrieve3             32788
ItemPairRetrieve4              3741
ItemPairRetrieve5              4986
ItemPairRetrieve6              6017
OrderHistoryDecay_14_top50    19232
OrderHistoryDecay_3_top50     11089
OrderHistoryDecay_7_top50     15404
OrderHistory_14               76479
OrderHistory_3                75492
OrderHistory_7                75980
SaleTrend_7_top50                 2
TimeHistoryDecay_14_top50        11
TimeHistoryDecay_3_top50         12
TimeHistoryDecay_7_top50          8
TimeHistory_50_1                  9
TimeHistory_50_2                 16
UGTimeHistory_age_bins_501       44
UGTimeHistory_age_bins_502       65
UGTimeHistory_age_bins_503       54
Name: article_id, dtype: int64

In [24]:
# low_candidates_grp = ['SaleTrend_7_top50','TimeHistoryDecay_14_top50','TimeHistoryDecay_3_top50',
#                      'TimeHistory_50_1', 'TimeHistory_50_2', 'TimeHistory_50_3', 'UGTimeHistory_age_bins_501',
#                      'UGTimeHistory_age_bins_502', 'UGTimeHistory_age_bins_503']

# lookback_mapping = {'SaleTrend_7_top50':7,'TimeHistoryDecay_14_top50':14,'TimeHistoryDecay_3_top50':3,
#                      'TimeHistory_50_1':7, 'TimeHistory_50_2':3, 'TimeHistory_50_3':14, 
#                     'UGTimeHistory_age_bins_501':7,
#                      'UGTimeHistory_age_bins_502':3, 'UGTimeHistory_age_bins_503':14}

# rare_candidates = candidates[candidates['method'].isin(low_candidates_grp)]
# # rare_candidates.to_csv('rare_candidates.csv', index=False)

In [None]:
rare_candidates = pd.read_csv('rare_candidates.csv')
rare_candidates['lookback'] = rare_candidates['method'].map(lookback_mapping)
rare_group = rare_candidates[['article_id','method','lookback']].drop_duplicates()

In [32]:
trans = trans.merge(data['item'][['article_id','product_code','colour_group_code','graphical_appearance_no']], 
                 on='article_id')

In [52]:
## From same channel find alternate item from same prod group with highest number of sale

def find_alternate_items_all(base_itm, bs_lookback):

    item_1 = base_itm
    prod_cd = data['item'][data['item'].article_id==item_1]['product_code'].iloc[0]
    tdt = '2020-09-16'
    n = bs_lookback
    
    item_2 = ''
    item_3 = ''
    item_4 = ''
    item_5 = ''
    
    trans.t_dat = pd.to_datetime(trans.t_dat)

    color_1 = data['item'][data['item'].article_id==item_1]['colour_group_code'].iloc[0]
    g_1 = data['item'][data['item'].article_id==item_1]['graphical_appearance_no'].iloc[0]

    p1 = trans[(trans.article_id==item_1) & 
                 (trans.product_code==prod_cd) &
                 (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n))]['price'].mean()

    filter_df = trans[(trans.article_id!=item_1) & 
                 (trans.product_code==prod_cd) & 
                 (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n))]

    if len(filter_df)>0:

        item_2 = filter_df.groupby('t_dat')['article_id'].value_counts().head(1).index.values[0][1]
        color_2 = data['item'][data['item'].article_id==item_2]['colour_group_code'].iloc[0]
        g_2 = data['item'][data['item'].article_id==item_2]['graphical_appearance_no'].iloc[0]

        p2 = trans[(trans.article_id==item_2) & 
                 (trans.product_code==prod_cd) & 
                 (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n))]['price'].mean()

        p_low = min(p1, p2)
        filter_df_2 = trans[~(trans.article_id.isin([item_1, item_2])) & 
                 (trans.product_code==prod_cd) & 
                 (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n)) & 
                 (trans.price < p_low)]

        if len(filter_df_2)>0:

            item_3 = filter_df_2.groupby('t_dat')['article_id'].value_counts().head(1).index.values[0][1]
            color_3 = data['item'][data['item'].article_id==item_3]['colour_group_code'].iloc[0]
            g_3 = data['item'][data['item'].article_id==item_3]['graphical_appearance_no'].iloc[0]

            filter_df_3 = trans[~(trans.article_id.isin([item_1, item_2, item_3])) & 
                     (trans.product_code==prod_cd) & 
                     (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n)) & 
                     ~ (trans.colour_group_code.isin([color_1, color_2, color_3]))]

            if len(filter_df_3)>0:
                item_4 = filter_df_3.groupby('t_dat')['article_id'].value_counts().head(1).index.values[0][1]
                color_4 = data['item'][data['item'].article_id==item_4]['colour_group_code'].iloc[0]
                g_4 = data['item'][data['item'].article_id==item_4]['graphical_appearance_no'].iloc[0]

                filter_df_4 = trans[~(trans.article_id.isin([item_1, item_2, item_3, item_4])) & 
                     (trans.product_code==prod_cd) & 
                     (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n)) & 
                     ~ (trans.graphical_appearance_no.isin([g_1, g_2, g_3, g_4]))]

                if len(filter_df_4)>0:
                    item_5 = filter_df_4.groupby('t_dat')['article_id'].value_counts().head(1).index.values[0][1]            

            else:
                filter_df_3_1 = trans[~(trans.article_id.isin([item_1, item_2, item_3])) & 
                 (trans.product_code==prod_cd) & 
                 (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n)) & 
                 ~ (trans.graphical_appearance_no.isin([g_1, g_2, g_3]))]

                if len(filter_df_3_1)>0:
                    item_4 = filter_df_3_1.groupby('t_dat')['article_id'].value_counts().head(1).index.values[0][1]

        else:
            filter_df_2_1 = trans[~(trans.article_id.isin([item_1, item_2])) & 
                 (trans.product_code==prod_cd) & 
                 (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n)) & 
                 ~ (trans.colour_group_code.isin([color_1, color_2]))]

            if len(filter_df_2_1)>0:
                item_3 = filter_df_2_1.groupby('t_dat')['article_id'].value_counts().head(1).index.values[0][1]
                color_3 = data['item'][data['item'].article_id==item_3]['colour_group_code'].iloc[0]
                g_3 = data['item'][data['item'].article_id==item_3]['graphical_appearance_no'].iloc[0]

            else:
                filter_df_2_2 = trans[~(trans.article_id.isin([item_1, item_2])) & 
                     (trans.product_code==prod_cd) & 
                     (trans.t_dat<pd.to_datetime(tdt)) & (trans.t_dat >= pd.to_datetime(tdt) - pd.Timedelta(days=n)) & 
                     ~ (trans.graphical_appearance_no.isin([g_1, g_2]))]

                if len(filter_df_2_2)>0:
                    item_3 = filter_df_2_2.groupby('t_dat')['article_id'].value_counts().head(1).index.values[0][1]

    return item_2, item_3, item_4, item_5


In [43]:
rare_group = rare_group.merge(candidates[candidates.article_id.isin(rare_group.article_id.unique())] \
                              .groupby(['article_id','method']).agg({'score':'mean', 'hit_rate':'mean'}), 
                              on=['article_id','method'])

In [53]:
alternate_best_sells = pd.DataFrame(columns = rare_group.columns)

lb_list = []
art_list = []
score_list = []
hr_list = []
method_list = []

for i in range(0, len(rare_group)):
    
    bs_itm = rare_group['article_id'][i]
    bs_lb = rare_group['lookback'][i]
    bs_score = rare_group['score'][i]
    bs_hr = rare_group['hit_rate'][i]
    bs_method = rare_group['method'][i]
    
    a,b,c,d = find_alternate_items_all(bs_itm, bs_lb)

    art_list.extend([a,b,c,d])
    method_list.extend([bs_method]*4)
    lb_list.extend([bs_lb]*4)
    score_list.extend([bs_score]*4)
    hr_list.extend([bs_hr]*4)

In [54]:
alternate_best_sells['article_id'] = art_list
alternate_best_sells['method'] = method_list
alternate_best_sells['lookback'] = lb_list
alternate_best_sells['score'] = score_list
alternate_best_sells['hit_rate'] = hr_list

alternate_best_sells = alternate_best_sells[alternate_best_sells.article_id!='']

In [60]:
additional_candidates = alternate_best_sells.drop('lookback', axis=1).merge(
    candidates[['article_id','method','customer_id']], on=['article_id','method'])

additional_candidates = additional_candidates[candidates.columns]

Unnamed: 0,customer_id,article_id,score,method,hit_rate
0,1371767,60325,5.199338,OrderHistory_3,1.0
1,66507,42631,5.199338,OrderHistory_3,1.0
2,77078,63326,5.199338,OrderHistory_3,1.0
3,73718,70920,5.199338,OrderHistory_3,1.0
4,73123,80609,5.199338,OrderHistory_3,1.0
...,...,...,...,...,...
49617858,191894,104555,1.887753,SaleTrend_7_top50,0.0
49617859,1198812,104555,1.887753,SaleTrend_7_top50,0.0
49617860,425976,104555,1.887753,SaleTrend_7_top50,0.0
49617861,361074,104555,1.887753,SaleTrend_7_top50,0.0


In [73]:
del alternate_best_sells, rare_group, art_list,method_list,lb_list,score_list,hr_list
del rare_candidates

del trans['product_code']
del trans['colour_group_code']
del trans['graphical_appearance_no']

In [70]:
# additional_candidates.to_parquet('additional_candidates.parquet')
additional_candidates = pd.read_parquet('additional_candidates.parquet')

In [69]:
candidates = pd.concat([candidates, additional_candidates], axis=0)

In [74]:
candidates.shape

(51441976, 5)

In [75]:
candidates = (
    pd.pivot_table(
        candidates,
        values="score",
        index=["customer_id", "article_id"],
        columns=["method"],
        aggfunc=np.sum,
    )
    .reset_index()
)

In [76]:
candidates.shape

# (15630864, 24)

(15709261, 23)

In [77]:
label = valid[["customer_id", "article_id"]]
label.columns = ["customer_id", "label_item"]
tmp_items = candidates.merge(label, on=["customer_id"], how="left")
tmp_items = tmp_items[tmp_items["label_item"].notnull()]
tmp_items["label"] = tmp_items.apply(lambda x: 1 if x["article_id"] in x["label_item"] else 0, axis=1)
pos_rate = tmp_items["label"].mean()
pos_rate
# 0.007686350632672472
# 0.0056985812619375735
# 0.0062356084250075995

# 0.0062356084250075995

0.006208012290656241

In [78]:
candidates = candidates.drop_duplicates(['customer_id','article_id'])

In [79]:
candidates = candidates.groupby('customer_id')['article_id'].apply(list).reset_index()

In [80]:
candidates.rename(columns={'article_id': 'prediction'}, inplace=True)
valid2 = pd.merge(valid, candidates, on="customer_id", how="left")

In [184]:
from typing import Iterable
import numpy as np

def _ap_at_kk(actual, predicted, k=10):
    if len(predicted) > k:
        predicted = predicted[:k]

    score = 0.0
    num_hits = 0.0

    for i, p in enumerate(predicted):
        if p in actual and p not in predicted[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)

    if actual is None:
        return 0.0

    return score / min(len(actual), k)

def map_at_kk(actual: Iterable, predicted: Iterable, k: int = 12) -> float:
    return [_ap_at_kk(a,p,k) for a, p in zip(actual, predicted) if a is not None], \
            [list(set(a).intersection(p)) for a,p in zip(actual,predicted) if a is not None]

map_list, matched_articles = map_at_kk(valid2["article_id"], valid2["prediction"], k=12)
valid_df = pd.DataFrame({"article_id":valid2["article_id"], 'prediction':valid2["prediction"],
                         'MAPall':map_list, 'MatchedArticles':matched_articles})

valid_df.to_parquet('validation.parquet')

In [185]:
valid_df[valid_df.MAPall==valid_df.MAPall.max()]

Unnamed: 0,article_id,prediction,MAPall,MatchedArticles
226,[3092],"[3092, 42627, 46383, 46385, 53893, 53894, 5669...",1.0,[3092]
363,[3092],"[3092, 3521, 13122, 13125, 17164, 20922, 42627...",1.0,[3092]
1668,[75],"[75, 78, 1714, 1731, 1733, 2196, 2197, 2200, 2...",1.0,[75]
2340,[2253],"[2253, 2254, 2261, 2864, 3092, 3521, 8106, 171...",1.0,[2253]
2682,[75],"[75, 76, 3092, 42627, 53893, 56695, 67523, 675...",1.0,[75]
2895,[75],"[75, 76, 507, 620, 1068, 1069, 1715, 3092, 127...",1.0,[75]
3752,"[76, 75]","[75, 76, 78, 1714, 3092, 25191, 46352, 53893, ...",1.0,"[75, 76]"
3871,[2235],"[2235, 3092, 42627, 53893, 56606, 56608, 56695...",1.0,[2235]
4614,[1715],"[1715, 3092, 42627, 46352, 46371, 46383, 46385...",1.0,[1715]
4822,[1714],"[1714, 1715, 3092, 13339, 15989, 15990, 16004,...",1.0,[1714]


In [229]:
valid_df[(valid_df.MAPall==0) & (valid_df.MatchedArticles.astype(str)!='[]')]

Unnamed: 0,article_id,prediction,MAPall,MatchedArticles
6,"[76696, 103576, 100335, 103576, 100335, 96767,...","[7, 9, 54, 74, 122, 123, 124, 1714, 1715, 2433...",0.0,[104150]
9,"[54340, 54340, 96071]","[1483, 1715, 3092, 42627, 46383, 46385, 53893,...",0.0,[96071]
17,"[53894, 97850, 90580, 83497]","[1714, 1715, 3092, 3281, 17670, 24838, 25800, ...",0.0,[53894]
19,"[102446, 103105]","[910, 975, 3092, 3521, 9774, 15989, 16004, 171...",0.0,[102446]
23,[104073],"[899, 1301, 1393, 1714, 2030, 2060, 2913, 2914...",0.0,[104073]
...,...,...,...,...
68958,"[102183, 102183, 81558]","[3092, 6145, 42627, 50670, 53893, 55589, 56695...",0.0,[102183]
68966,"[104554, 97323, 103304]","[68, 1714, 1715, 3092, 6459, 42627, 53893, 566...",0.0,"[103304, 104554]"
68968,"[76004, 83764, 97252, 103705, 103794, 104932, ...","[1714, 2209, 2220, 2235, 3089, 3092, 3941, 395...",0.0,"[104073, 103794]"
68975,[104554],"[3092, 42627, 53893, 56695, 67523, 67540, 6754...",0.0,[104554]


In [195]:
item_focus = valid_df[(valid_df.MAPall==0) & 
                      (valid_df.MatchedArticles.astype(str)!='[]')]['MatchedArticles'].tolist()

In [197]:
flat_list = [item for sublist in item_focus for item in sublist]
flat_list = list(set(flat_list))

In [204]:
item_index2id = pickle.load(open(data_dir/"index_id_map/item_index2id.pkl", "rb"))

In [215]:
with open('bring_to_top12.pkl', 'wb') as f:
    pickle.dump(flat_list, f)

In [212]:
flat_list_id = [item_index2id[i] for i in flat_list]

In [232]:
for p in valid2["prediction"][6:7]:
    k = len(p)
    print(p, k)

    print(list(set(flat_list).intersection(p)))
    
    print(p[:12])

[7, 9, 54, 74, 122, 123, 124, 1714, 1715, 2433, 3092, 3521, 12727, 12757, 14241, 14254, 17044, 17045, 17164, 24838, 25069, 25086, 25114, 25129, 40284, 40289, 42627, 46814, 53893, 53894, 55589, 56695, 57064, 57130, 57659, 57660, 60764, 62704, 62708, 62709, 64157, 67052, 67053, 67523, 67544, 70195, 70556, 71102, 71104, 71107, 71108, 72966, 72967, 77474, 77886, 79488, 79489, 80057, 80058, 80247, 81609, 82629, 82632, 83673, 85193, 85194, 85195, 86193, 86640, 87468, 87471, 88141, 88142, 88246, 88270, 89960, 89961, 90083, 90084, 91738, 92136, 93826, 93827, 94675, 94697, 95001, 95003, 95218, 95252, 95790, 99395, 99397, 99399, 99400, 99499, 100938, 101352, 101471, 101993, 101994, 102084, 102233, 102241, 102290, 102444, 103109, 103110, 103130, 103305, 103390, 103704, 103794, 103796, 103797, 103798, 104046, 104073, 104074, 104149, 104150, 104276, 104277, 104334, 104528, 104554, 104555, 104643, 104758, 104759, 104840, 104841, 104948, 105078, 105079] 134
[3092, 95252, 95790, 70195, 99395, 99397, 9

In [233]:
for p in valid2["prediction"][226:227]:
    k = len(p)
    print(p, k)

    print(list(set(flat_list).intersection(p)))
    
    print(p[:12])

[3092, 42627, 46383, 46385, 53893, 53894, 56695, 57064, 67523, 67544, 70696, 71108, 75439, 82629, 82632, 91738, 95218, 103109, 103794, 103797, 103798, 104046, 104073, 104074, 104149, 104158, 104528, 104554, 104555] 29
[42627, 53893, 53894, 104073, 104074, 3092, 70696, 46383, 75439, 46385, 67523, 71108, 82629, 103109, 82632, 104528, 104149, 67544, 91738, 104158, 57064, 104554, 104555, 104046, 95218, 103794, 103797, 103798, 56695]
[3092, 42627, 46383, 46385, 53893, 53894, 56695, 57064, 67523, 67544, 70696, 71108]


In [110]:
map_at_k(valid2["article_id"], valid2["prediction"], k=12)
hr_at_k(valid2["article_id"], valid2["prediction"], k=12)
recall_at_k(valid2["article_id"], valid2["prediction"], k=12)
# 0.025620866741013788

# 0.007122429128294375
# 0.06400034790676098
# 0.028341529076267406

0.007019627177185514

0.06349298388032007

0.027707398480994898

In [82]:
valid2['prediction'].apply(len).mean()
# 31.335150179751828 0.09236951948783895
# 56.68628957439406 0.1210664822292757

# 49.73995419227647 0.11656190014664647

50.87640612315899

In [111]:
recall_at_k(valid2["article_id"], valid2["prediction"], k=200)

# 0.11656190014664647

0.11855617258364239

In [112]:
recall_at_k(valid2["article_id"], valid2["prediction"], k=200) / valid2['prediction'].apply(len).mean()
# 0.002504861238505555
# 0.0029477924617551813

# 0.0023434259648905343

0.002330278052593725