In [1]:
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
import gc, os
import logging
import time
import lightgbm as lgb
from gensim.models import Word2Vec
from sklearn.preprocessing import MinMaxScaler
import warnings
warnings.filterwarnings('ignore')

In [3]:
data_path = '../data/raw/'
save_path = "../temp/"

In [4]:
def trn_val_split(all_click_df, sample_user_nums):
    all_click = all_click_df
    all_user_ids = all_click.user_id.unique()
    
    # replace=True表示可以重复抽样，反之不可以
    sample_user_ids = np.random.choice(all_user_ids, size=sample_user_nums, replace=False) 
    
    click_val = all_click[all_click['user_id'].isin(sample_user_ids)]
    click_trn = all_click[~all_click['user_id'].isin(sample_user_ids)]
    
    # 将验证集中的最后一次点击给抽取出来作为答案
    click_val = click_val.sort_values(['user_id', 'click_timestamp'])
    val_ans = click_val.groupby('user_id').tail(1)
    
    click_val = click_val.groupby('user_id').apply(lambda x: x[:-1]).reset_index(drop=True)
    
    # 去除val_ans中某些用户只有一个点击数据的情况，如果该用户只有一个点击数据，又被分到ans中，
    # 那么训练集中就没有这个用户的点击数据，出现用户冷启动问题，给自己模型验证带来麻烦
    val_ans = val_ans[val_ans.user_id.isin(click_val.user_id.unique())] # 保证答案中出现的用户再验证集中还有
    click_val = click_val[click_val.user_id.isin(val_ans.user_id.unique())]
    
    return click_trn, click_val, val_ans

In [5]:
def get_hist_and_last_click(all_click):
    all_click = all_click.sort_values(by=['user_id', 'click_timestamp'])
    click_last_df = all_click.groupby('user_id').tail(1)

    # 如果用户只有一个点击，hist为空了，会导致训练的时候这个用户不可见，此时默认泄露一下
    def hist_func(user_df):
        if len(user_df) == 1:
            return user_df
        else:
            return user_df[:-1]

    click_hist_df = all_click.groupby('user_id').apply(hist_func).reset_index(drop=True)

    return click_hist_df, click_last_df

In [2]:
emb = pd.read_csv('../data/raw/articles_emb.csv')
emb.head

<bound method NDFrame.head of         article_id     emb_0     emb_1     emb_2     emb_3     emb_4  \
0                0 -0.161183 -0.957233 -0.137944  0.050855  0.830055   
1                1 -0.523216 -0.974058  0.738608  0.155234  0.626294   
2                2 -0.619619 -0.972960 -0.207360 -0.128861  0.044748   
3                3 -0.740843 -0.975749  0.391698  0.641738 -0.268645   
4                4 -0.279052 -0.972315  0.685374  0.113056  0.238315   
...            ...       ...       ...       ...       ...       ...   
364042      364042 -0.055038 -0.962136  0.869436 -0.071523 -0.725294   
364043      364043 -0.136932 -0.995471  0.991298  0.031871 -0.915621   
364044      364044 -0.251390 -0.976243  0.586097  0.643631 -0.663359   
364045      364045  0.224342 -0.923288 -0.381742  0.687890 -0.773911   
364046      364046 -0.257134 -0.994631  0.983792 -0.190975 -0.953720   

           emb_5     emb_6     emb_7     emb_8  ...   emb_240   emb_241  \
0       0.901365 -0.335148 -0.

In [3]:
emb_cols = [f'emb_{i}' for i in range(250)]  # 根据实际列数调整

# 直接使用numpy数组构建字典(最快)
article_ids = emb['article_id'].values
embedding_matrix = emb[emb_cols].values

item_emb_dict = {
    article_id: embedding_matrix[idx] 
    for idx, article_id in enumerate(article_ids)
}

In [4]:
item_emb_dict[2]

array([-0.61961854, -0.9729604 , -0.20736018, -0.12886102,  0.04474759,
       -0.387535  , -0.73047674, -0.06612612, -0.75489885, -0.24200428,
        0.670484  , -0.2803883 , -0.557285  , -0.08414505,  0.02778196,
        0.29407424,  0.36269727, -0.3685494 ,  0.14796   , -0.01175088,
        0.03020873,  0.10631693,  0.6280128 ,  0.388849  ,  0.6159109 ,
       -0.44511306,  0.10602808,  0.13710949, -0.09553552,  0.3425321 ,
        0.5926465 , -0.26179096,  0.34212252,  0.7045392 , -0.43306684,
        0.1041543 ,  0.7859709 ,  0.5886402 , -0.62768734, -0.14329416,
        0.39983153, -0.70823455, -0.73296404, -0.95824176, -0.629325  ,
       -0.28223997,  0.0551875 , -0.70930463,  0.5806534 , -0.5183282 ,
        0.0590419 ,  0.66433567,  0.37024036, -0.22426963, -0.22767073,
        0.6944705 ,  0.16796917,  0.10058454,  0.9468768 , -0.47480643,
        0.91217107, -0.43829462, -0.04617592,  0.80739474, -0.2778143 ,
       -0.6002078 , -0.5066402 , -0.00820139, -0.8228875 ,  0.20

In [5]:
click = pd.read_csv('../data/raw/train_click_log.csv')
click.head()

Unnamed: 0,user_id,click_article_id,click_timestamp,click_environment,click_deviceGroup,click_os,click_country,click_region,click_referrer_type
0,199999,160417,1507029570190,4,1,17,1,13,1
1,199999,5408,1507029571478,4,1,17,1,13,1
2,199999,50823,1507029601478,4,1,17,1,13,1
3,199998,157770,1507029532200,4,1,17,1,25,5
4,199998,96613,1507029671831,4,1,17,1,25,5


In [6]:
articles = pd.read_csv('../data/raw/articles.csv')
articles.head()

Unnamed: 0,article_id,category_id,created_at_ts,words_count
0,0,0,1513144419000,168
1,1,1,1405341936000,189
2,2,1,1408667706000,250
3,3,1,1408468313000,230
4,4,1,1407071171000,162


In [3]:
import pickle
import os

In [6]:
path = "../temp/all_recall_results.pkl"
with open(path, 'rb') as f:
    all_recall_results = pickle.load(f)

In [12]:
all_recall_results[131078]

[(124352, 0.9627451735014284),
 (124228, 0.9252055640072008),
 (124177, 0.9221347840781757),
 (123289, 0.920215561294154),
 (124350, 0.9160100563142084),
 (124194, 0.914803620585804),
 (123909, 0.9125198739466032),
 (158046, 0.911674797610559),
 (20249, 0.9116732013313792),
 (140646, 0.9116731574053989),
 (30064, 0.9116702021379017),
 (109812, 0.9116677594911227),
 (57771, 0.9116624471645811),
 (84020, 0.9116552238314753),
 (76002, 0.9116539948097181),
 (216448, 0.911653145725048),
 (158023, 0.9116511880754326),
 (136599, 0.9116497244798838),
 (160565, 0.9116487092010409),
 (158047, 0.911646830799328),
 (129520, 0.911646629736078),
 (140627, 0.9116449840967761),
 (13635, 0.9116435934092977),
 (58265, 0.9116427447774726),
 (158850, 0.9116414754524964),
 (201797, 0.9116395757670611),
 (106313, 0.9116386551328561),
 (123368, 0.9116195179612896),
 (140676, 0.9115379769795942),
 (124176, 0.9106024230639465)]

In [13]:
all_recall_results[163862]

[(30064, 0.911680424664485),
 (140646, 0.9116803259442406),
 (20249, 0.911680301037757),
 (158046, 0.9116664032198762),
 (158023, 0.9116624349377619),
 (140627, 0.9116582107981336),
 (57771, 0.9116569387560864),
 (84020, 0.9116539028821512),
 (76002, 0.9116536719311209),
 (109812, 0.9116535954002892),
 (160565, 0.91165017370228),
 (129520, 0.9116456769498669),
 (136599, 0.9116447748823133),
 (158850, 0.9116444914012448),
 (87212, 0.9116418526725125),
 (13635, 0.911641201028331),
 (106313, 0.9116378024254247),
 (216448, 0.9116343146120224),
 (174691, 0.9116321654089053),
 (187642, 0.9116311125439144),
 (214800, 0.910226777439227),
 (111210, 0.908954388983774),
 (71076, 0.9088510757017482),
 (160132, 0.9082284493551545),
 (233478, 0.9081793381521457),
 (293301, 0.9080102784012332),
 (40969, 0.9077640172247813),
 (257291, 0.9075642500453636),
 (297639, 0.9074363038684117),
 (159762, 0.9071887493317825)]