In [3]:
import argparse
import gc
import glob
import os
import pickle

import lightgbm as lgb
import pandas as pd
from sklearn.model_selection import GroupKFold
from wandb.lightgbm import log_summary, wandb_callback

import wandb


class CFG:
    wandb = True
    num_iterations = 200
    n_folds = 5
    dtypes = {
        "session": "int32",
        "aid": "int32",
        "session_clicks_cnt": "int16",
        "session_carts_cnt": "int16",
        "session_orders_cnt": "int16",
        "session_aid_clicks_cnt": "int16",
        "session_aid_carts_cnt": "int16",
        "session_aid_orders_cnt": "int16",
        "clicks_rank": "int32",
        "carts_rank": "int32",
        "orders_rank": "int32",
        "session_clicks_unique_aid": "int16",
        "session_carts_unique_aid": "int16",
        "session_orders_unique_aid": "int16",
        "clicks_uu_rank": "int32",
        "carts_uu_rank": "int32",
        "orders_uu_rank": "int32",
    }
    float_cols = [
        "avg_action_num_reverse_chrono",
        "min_action_num_reverse_chrono",
        "max_action_num_reverse_chrono",
        "avg_sec_since_session_start",
        "min_sec_since_session_start",
        "max_sec_since_session_start",
        "avg_sec_to_session_end",
        "min_sec_to_session_end",
        "max_sec_to_session_end",
        "avg_log_recency_score",
        "min_log_recency_score",
        "max_log_recency_score",
        "avg_type_weighted_log_recency_score",
        "min_type_weighted_log_recency_score",
        "max_type_weighted_log_recency_score",
        "covisit_clicks_candidate_num",
        "covisit_carts_candidate_num",
        "covisit_orders_candidate_num",
        "w2v_candidate_num",
        "gru4rec_candidate_num",
        "narm_candidate_num",
        "sasrec_candidate_num",
        "session_clicks_carts_ratio",
        "session_carts_orders_ratio",
        "session_clicks_orders_ratio",
        "avg_sec_clicks_carts",
        "min_sec_clicks_carts",
        "max_sec_clicks_carts",
        "avg_sec_carts_orders",
        "min_sec_carts_orders",
        "max_sec_carts_orders",
        "avg_clicks_cnt",
        "avg_carts_cnt",
        "avg_orders_cnt",
        "clicks_carts_ratio",
        "carts_orders_ratio",
        "clicks_orders_ratio",
        "avg_sec_clicks_carts",
        "min_sec_clicks_carts",
        "max_sec_clicks_carts",
        "avg_sec_carts_orders",
        "min_sec_carts_orders",
        "max_sec_carts_orders",
        "avg_sec_session_clicks_carts",
        "min_sec_session_clicks_carts",
        "max_sec_session_clicks_carts",
        "avg_sec_session_carts_orders",
        "min_sec_session_carts_orders",
        "max_sec_session_carts_orders",
    ]


def read_files(path):
    dfs = []

    for file in glob.glob(path):
        df = pd.read_parquet(file)
        for col, dtype in CFG.dtypes.items():
            df[col] = df[col].astype(dtype)
        for col in CFG.float_cols:
            df[col] = df[col].astype("float16")
        dfs.append(df)
    return pd.concat(dfs).reset_index(drop=True)


def read_train_labels():
    train_labels = pd.read_parquet("./input/otto-validation/test_labels.parquet")
    train_labels = train_labels.explode("ground_truth")
    train_labels["aid"] = train_labels["ground_truth"]
    train_labels = train_labels[["session", "type", "aid"]]
    train_labels["aid"] = train_labels["aid"].astype("int32")
    train_labels["session"] = train_labels["session"].astype("int32")
    return train_labels


def dump_pickle(path, o):
    with open(path, "wb") as f:
        pickle.dump(o, f)

def cast_cols(df):
    for col, dtype in CFG.dtypes.items():
        df[col] = df[col].astype(dtype)
    for col in CFG.float_cols:
        df[col] = df[col].astype("float16")
    return df


def split_list(l, n):
    for idx in range(0, len(l), n):
        yield l[idx : idx + n]

In [5]:
path = "./input/lgbm_dataset_test/*"
files = glob.glob(path)
preds = []
files_list = split_list(files, 50)
files = next(iter(files_list))

In [6]:
dfs = []
for file in files:
    df = pd.read_parquet(file)
    df = cast_cols(df)
    dfs.append(df)
    break
test = pd.concat(dfs)
del dfs
gc.collect()
feature_cols = test.drop(columns=["session"]).columns.tolist()

In [7]:
type = "clicks"

In [8]:
pred_folds = []
fold = 0

In [9]:
output_dir = "output/lgbm/woven-elevator-414.bak"
ranker = pickle.load(open(os.path.join(output_dir, f"ranker_{type}_fold{fold}.pkl"), "rb"))

In [10]:
ranker

<lightgbm.basic.Booster at 0x7f61c996cd60>

In [16]:
feature_cols[:10]

['aid',
 'session_clicks_cnt',
 'session_carts_cnt',
 'session_orders_cnt',
 'session_clicks_unique_aid',
 'session_carts_unique_aid',
 'session_orders_unique_aid',
 'session_clicks_carts_ratio',
 'session_carts_orders_ratio',
 'session_clicks_orders_ratio']

In [18]:
hoge = ranker.predict(test[feature_cols[:10]])

[LightGBM] [Fatal] The number of features in data (10) is not the same as it was in training data (59).
You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.


LightGBMError: The number of features in data (10) is not the same as it was in training data (59).
You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.

In [14]:
test[feature_cols]

Unnamed: 0,aid,session_clicks_cnt,session_carts_cnt,session_orders_cnt,session_clicks_unique_aid,session_carts_unique_aid,session_orders_unique_aid,session_clicks_carts_ratio,session_carts_orders_ratio,session_clicks_orders_ratio,...,avg_orders_cnt,clicks_carts_ratio,carts_orders_ratio,clicks_orders_ratio,avg_sec_clicks_carts,min_sec_clicks_carts,max_sec_clicks_carts,avg_sec_carts_orders,min_sec_carts_orders,max_sec_carts_orders
0,519205,1,0,0,1,0,0,0.000,,0.0,...,0.000000,,,,,,,,,
1,359954,3,0,0,1,0,0,0.000,,0.0,...,1.000000,0.571289,0.250000,0.142822,64704.0,47200.0,inf,inf,inf,inf
2,1613713,2,0,0,2,0,0,0.000,,0.0,...,1.000000,0.600098,0.333252,0.199951,18064.0,3604.0,inf,inf,inf,inf
3,332813,3,0,0,3,0,0,0.000,,0.0,...,0.000000,0.000000,,0.000000,,,,,,
4,832192,2,0,0,2,0,0,0.000,,0.0,...,1.085938,0.119141,0.136719,0.016281,inf,2828.0,inf,inf,57920.0,inf
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1255327,1191966,1,0,0,1,0,0,0.000,,0.0,...,0.000000,0.030304,0.000000,0.000000,26128.0,5272.0,47008.0,,,
1255328,86039,10,1,0,8,1,0,0.125,0.0,0.0,...,1.000000,0.138916,0.300049,0.041656,41216.0,7976.0,inf,inf,inf,inf
1255329,592866,2,0,0,2,0,0,0.000,,0.0,...,0.000000,0.500000,0.000000,0.000000,18384.0,18384.0,18384.0,,,
1255330,832159,5,0,0,4,0,0,0.000,,0.0,...,0.000000,0.000000,,0.000000,,,,,,
