In [1]:
from tqdm import tqdm
import pandas as pd
import random
import os
import numpy as np
import joblib
import logging
import coloredlogs
import time
import gc
import sys
import argparse
import itertools
from collections import Counter

In [2]:
# Load the data sets and the persisted model objects used by the cells below.
start_dt = 12  # lowest `dt` value kept when filtering the training data
# shop_tag values that chid2answer is allowed to return.
官方指認欄位 = [
    '2', '6', '10', '12', '13', '15', '18', '19',
    '21', '22', '25', '26', '36', '37', '39', '48',
]

# Objects restored with joblib; predict_function queries `nbrs.kneighbors`
# with rows taken from `X_pca`.
# NOTE(review): joblib.load unpickles arbitrary objects -- only load trusted files.
nbrs = joblib.load('../model/nbrs.pkl')
X_pca = joblib.load('../model/X_pca_for_knn.pkl')

df_groupby_chid_preprocessed = pd.read_feather('../data/df_groupby_chid_preprocessed.feather')
test_data = pd.read_feather('../data/需預測的顧客名單及提交檔案範例.feather')

# Keep only recent rows (older data is presumably less informative).
df = pd.read_feather('../data/2021玉山人工智慧公開挑戰賽冬季賽訓練資料集.feather').loc[
    lambda frame: frame.dt >= start_dt
]

In [7]:
# 一些函數都放在這裡
def chid2answer(chid, method='median', data=None, allowed_tags=None):
    """Rank one customer's shop tags and return up to three eligible ones.

    Parameters
    ----------
    chid
        Customer id matched against the ``chid`` column.
    method : {'median', 'sum', 'mean', 'value_counts'}
        ``'sum'``/``'mean'``/``'median'`` rank tags by that aggregate of
        ``txn_amt`` (largest first); ``'value_counts'`` ranks them by how many
        rows carry the tag (most frequent first).
    data : pandas.DataFrame, optional
        Frame with ``chid``, ``shop_tag`` and ``txn_amt`` columns. Defaults to
        the notebook-level ``df``.
    allowed_tags : iterable, optional
        Tags that may appear in the result. Defaults to the notebook-level
        ``官方指認欄位``.

    Returns
    -------
    list
        At most three distinct tags, best-ranked first. Empty when the
        customer has no rows or none of its tags is allowed.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    amount_methods = ('sum', 'mean', 'median')
    supported = amount_methods + ('value_counts',)
    # Compare for equality: a substring test against 'value_counts' would also
    # accept inputs such as '' or 'value'.
    if method not in supported:
        raise ValueError(f"unknown method {method!r}; expected one of {supported}")

    # Fall back to the notebook globals only when the caller gave no override.
    if data is None:
        data = df
    if allowed_tags is None:
        allowed_tags = 官方指認欄位

    rows = data.loc[data.chid == chid]
    if rows.empty:
        return []

    if method in amount_methods:
        ranked = (
            rows[['shop_tag', 'txn_amt']]
            .groupby('shop_tag')
            .agg(method)
            .sort_values(by='txn_amt', ascending=False)
        )
    else:
        # value_counts() already orders tags from most to least frequent.
        ranked = rows['shop_tag'].value_counts().to_frame()

    # Keep the ranking order, drop tags that are not allowed, take the top three.
    eligible = ranked.index[ranked.index.isin(list(allowed_tags))]
    return eligible[:3].tolist()

In [8]:
def predict_function(chid, verbose=True):
    """Predict three distinct shop tags for one customer.

    The answer is built in up to three stages:

    1. Take the customer's own top tags from ``chid2answer``.
    2. If fewer than three were found, append tags from the customer's nearest
       neighbours (``nbrs.kneighbors`` on the customer's row of ``X_pca``),
       nearest first, skipping tags already chosen.
    3. If still short, fill the remainder with tags drawn at random, without
       repetition, from ``官方指認欄位``.

    Parameters
    ----------
    chid
        Customer id.
    verbose : bool, default True
        When True, print the partial and final answers and show a progress bar
        over the neighbours whenever stage 1 alone was not enough.

    Returns
    -------
    list
        Exactly three distinct shop tags.
    """
    n_wanted = 3
    answer = list(chid2answer(chid))[:n_wanted]

    if len(answer) < n_wanted:
        if verbose:
            print(chid, answer)

        # Stage 2: borrow tags from the nearest neighbours.
        # A chid missing from the neighbour table skips this stage instead of
        # raising IndexError on `.index[0]`.
        matches = df_groupby_chid_preprocessed.index[df_groupby_chid_preprocessed.chid == chid]
        if len(matches) > 0:
            # The index label is used as a row position in X_pca, and the
            # positions returned by kneighbors are used as labels below.
            # assumes the frame has a default RangeIndex -- TODO confirm
            idx = matches[0]
            _, indices = nbrs.kneighbors(X_pca[[idx]])
            # Drop the first hit, presumably the query row itself (nearest hits
            # come first). Slicing with [1:] also behaves when only one
            # neighbour is returned, where a negative-length slice would keep
            # the whole array.
            neighbor_idx = indices[0][1:]
            chid_list = df_groupby_chid_preprocessed.loc[neighbor_idx, 'chid'].tolist()

            for nb_chid in tqdm(chid_list, disable=not verbose):
                for tag in chid2answer(nb_chid):
                    if tag not in answer:
                        answer.append(tag)
                if len(answer) >= n_wanted:
                    break
            answer = answer[:n_wanted]

        # Stage 3: random fill from the allowed tags not yet used.
        missing = n_wanted - len(answer)
        if missing > 0:
            # An ordered list (not a set) keeps the draw reproducible under a
            # fixed numpy seed.
            candidates = [tag for tag in 官方指認欄位 if tag not in answer]
            picks = np.random.choice(candidates, size=missing, replace=False)
            # Convert numpy string scalars back to plain str.
            answer.extend(str(tag) for tag in picks)

        if verbose:
            print(chid, answer)

    assert isinstance(answer, list)  # callers expect a plain list
    assert len(set(answer)) == n_wanted  # the three shop tags must be distinct
    return answer

In [9]:
# Smoke test: run the predictor on a random sample of customers.
N_SMOKE_SAMPLES = 100
SMOKE_SEED = 0
# A local generator makes the sampled customer ids repeatable without touching
# numpy's global random state.
smoke_rng = np.random.default_rng(SMOKE_SEED)
# Ids are drawn from the transaction rows, so customers with more rows are
# more likely to be picked.
for chid in tqdm(smoke_rng.choice(df['chid'].values, size=N_SMOKE_SAMPLES)):
    predict_function(chid)

 27%|██▋       | 27/100 [00:00<00:01, 50.11it/s]

10384312.0 ['37']


100%|██████████| 4/4 [00:00<00:00, 41.35it/s]
 32%|███▏      | 32/100 [00:00<00:02, 27.97it/s]

10384312.0 ['37', '36', '21']
10417034.0 ['2']


  0%|          | 0/4 [00:00<?, ?it/s]
 42%|████▏     | 42/100 [00:01<00:02, 26.24it/s]

10417034.0 ['2', '25', '15']
10377741.0 ['37', '48']


  0%|          | 0/4 [00:00<?, ?it/s]
 54%|█████▍    | 54/100 [00:01<00:01, 27.03it/s]

10377741.0 ['37', '48', '36']
10082528.0 []


  0%|          | 0/4 [00:00<?, ?it/s]
 63%|██████▎   | 63/100 [00:02<00:01, 25.83it/s]

10082528.0 ['10', '6', '2']


  0%|          | 0/4 [00:00<?, ?it/s]
 80%|████████  | 80/100 [00:02<00:00, 28.57it/s]

10378989.0 ['6', '19']
10378989.0 ['6', '19', '2']


100%|██████████| 100/100 [00:03<00:00, 33.28it/s]
