# atmaCup16 with RECRUIT - Train - t4rec

## Preprocess - NVTabular workflow

参考: https://nvidia-merlin.github.io/Transformers4Rec/stable/examples/end-to-end-session-based/01-ETL-with-NVTabular.html

NVTabular: FeatureEngineeringと前処理を簡単かつ高速に行えるライブラリ

### Config

In [1]:
# _pad_across_processesがTrainerクラスのクラスメソッドとして定義されていない
# transformersをdowngradeしたら直った
# https://github.com/huggingface/transformers/issues/24589
!pip install transformers4rec[pytorch,nvtabular]
!pip install transformers==4.19.4

Collecting transformers4rec[nvtabular,pytorch]
  Downloading transformers4rec-23.8.0.tar.gz (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l- \ | / - done
[?25h  Getting requirements to build wheel ... [?25l- done
[?25h  Preparing metadata (pyproject.toml) ... [?25l- done
Collecting merlin-models>=23.4.0 (from transformers4rec[nvtabular,pytorch])
  Downloading merlin-models-23.8.1.tar.gz (485 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.8/485.8 kB[0m [31m37.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l- \ | / done
[?25h  Getting requirements to build wheel ... [?25l- \ done
[?25h  Preparing metadata (pyproject.toml) ... [?25l- \ done
Collecting nvtabular (from transformers4rec[nvtabular,pytorch])
  Obtaining dependency information for nvtabular from

In [2]:
!pip install cudf-cu11 dask-cudf-cu11==23.10.0 --extra-index-url=https://pypi.nvidia.com

Looking in indexes: https://pypi.org/simple, https://pypi.nvidia.com
Collecting cudf-cu11
  Downloading https://pypi.nvidia.com/cudf-cu11/cudf_cu11-23.12.1-cp310-cp310-manylinux_2_28_x86_64.whl (506.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m506.4/506.4 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dask-cudf-cu11==23.10.0
  Downloading https://pypi.nvidia.com/dask-cudf-cu11/dask_cudf_cu11-23.10.0-py3-none-any.whl (82 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.0/82.0 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cudf-cu11
  Downloading https://pypi.nvidia.com/cudf-cu11/cudf_cu11-23.10.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (502.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m502.6/502.6 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cupy-cuda11x>=12.0.0 (from dask-cudf-cu11==23.10.0)
  Obtaining dependency information for cupy-cuda

In [3]:
import os
import glob
import numpy as np
import pandas as pd
import gc
import calendar
import datetime

import cudf
import nvtabular as nvt
from merlin.dag import ColumnSelector
from merlin.schema import Schema, Tags
from merlin.io import Dataset

--------------------------------------------------------------------------------

  CuPy may not function correctly because multiple CuPy packages are installed
  in your environment:

    cupy, cupy-cuda11x

  Follow these steps to resolve this issue:

    1. For all packages listed above, run the following command to remove all
       existing CuPy installations:

         $ pip uninstall <package_name>

      If you previously installed CuPy via conda, also run the following:

         $ conda uninstall cupy

    2. Install the appropriate CuPy package.
       Refer to the Installation Guide for detailed instructions.

         https://docs.cupy.dev/en/stable/install.html

--------------------------------------------------------------------------------

  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


In [4]:
INPUT_FOLDER = "/kaggle/input/atmacup16-recruit"
OUTPUT_FOLDER = "/kaggle/working"

In [5]:
SESSIONS_MAX_LENGTH=100

### FeatureEngineering

In [6]:
# Categorical features
# @see: https://nvidia-merlin.github.io/core/v0.3.0/api/merlin.dag.html#merlin.dag.ColumnSelector
categorical_features = ColumnSelector([
    "yad_no",
    "yad_type",
    "wireless_lan_flg",
    "onsen_flg",
    "kd_stn_5min",
    "kd_bch_5min",
    "kd_slp_5min",
    "kd_conv_walk_5min",
    "wid_cd",
    "ken_cd",
    "lrg_cd",
    "sml_cd",
    "day_idx",
]) >> nvt.ops.Categorify()

# Continuous features
total_room_cnt = ColumnSelector([
    "total_room_cnt",
]) >> nvt.ops.LogOp() >> nvt.ops.Normalize(out_dtype=np.float32)

# text_vecs = ColumnSelector([
#     f"vec{v}" for v in range(TEXT_VECTOR_SIZE)
# ])

continuous_features = total_room_cnt

# 最終的な特徴量
features = (
    ColumnSelector(["session_id", "seq_no"]) +
    categorical_features +
    continuous_features
)

In [7]:
# Define groupby workflow
# Group interaction features by session
d = {
    "session_id": ["list"],
    "yad_no": ["list"],
    "yad_type": ["list"],
    "wireless_lan_flg": ["list"],
    "onsen_flg": ["list"],
    "kd_stn_5min": ["list"],
    "kd_bch_5min": ["list"],
    "kd_slp_5min": ["list"],
    "kd_conv_walk_5min": ["list"],
    "wid_cd": ["list"],
    "ken_cd": ["list"],
    "lrg_cd": ["list"],
    "sml_cd": ["list"],
    "total_room_cnt": ["list"],
    "day_idx": ["first"],
}
# for v in range(TEXT_VECTOR_SIZE):
#     d[f"vec{v}"] = ["list"]

groupby_features = features >> nvt.ops.Groupby(
    groupby_cols=["session_id"],
    sort_cols=["seq_no"],
    aggs=d,
    name_sep="-",
)

# 各種類の特徴量に対して、タグ付け
item_feature_list = (
    groupby_features["yad_no-list"]
    >> nvt.ops.TagAsItemID()
)
categorical_features_list = (
    groupby_features[
        "yad_type-list",
        "wireless_lan_flg-list",
        "onsen_flg-list",
        "kd_stn_5min-list",
        "kd_bch_5min-list",
        "kd_slp_5min-list",
        "kd_conv_walk_5min-list",
        "wid_cd-list",
        "ken_cd-list",
        "lrg_cd-list",
        "sml_cd-list",
    ]
    >> nvt.ops.TagAsItemFeatures() # nvt.op.Categorify()でMetadataは付与済み
)
continuous_features_list = (
    groupby_features["total_room_cnt-list"]
    >> nvt.ops.TagAsItemFeatures()
    >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])
)

groupby_features_list = (
    item_feature_list +
    categorical_features_list +
    continuous_features_list
)

# ユーザ行動履歴を最新からSESSION_MAX_LENGTHの長さでtruncated
groupby_features_truncated = groupby_features_list >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH)

# INDEXの代わり
session_id = groupby_features["session_id"] >> nvt.ops.AddMetadata(tags=[Tags.CATEGORICAL])

day_idx = (
    groupby_features["day_idx-first"] >>
    nvt.ops.Rename(f=lambda col: "day_idx")
)

# 学習のための特徴量
selected_features = (
    session_id +
    groupby_features_truncated +
    day_idx
)

### 前処理(pandas)

- partition_colの追加
- yado.csvの欠損値埋め
- 各CSVのmerge

In [8]:
# 読み込み
train_logs_df = pd.read_csv(os.path.join(INPUT_FOLDER, "train_log.csv"))
train_labels_df = pd.read_csv(os.path.join(INPUT_FOLDER, "train_label.csv"))
yados_df = pd.read_csv(os.path.join(INPUT_FOLDER, "yado.csv"))

In [9]:
# logとlabelをconcat
seq_no_labels_df = pd.DataFrame(train_logs_df.groupby("session_id")["seq_no"].max()+1)
train_labels_df = pd.merge(train_labels_df, seq_no_labels_df, on="session_id")

df = pd.concat([train_logs_df, train_labels_df], axis=0).reset_index(drop=True)

df

Unnamed: 0,session_id,seq_no,yad_no
0,000007603d533d30453cc45d0f3d119f,0,2395
1,0000ca043ed437a1472c9d1d154eb49b,0,13535
2,0000d4835cf113316fe447e2f80ba1c8,0,123
3,0000fcda1ae1b2f431e55a7075d1f500,0,8475
4,000104bdffaaad1a1e0a9ebacf585f33,0,96
...,...,...,...
707963,ffff2262d38abdeb247ebd591835dcc9,1,2259
707964,ffff2360540745117193ecadcdc06538,1,963
707965,ffff7fb4617164b2604aaf51c40bf82d,1,13719
707966,ffffcd5bc19d62cad5a3815c87818d83,3,10619


In [10]:
# yadoの欠損値埋め
yados_df = yados_df.fillna(
    {
        "total_room_cnt": yados_df["total_room_cnt"].mean(),
        "wireless_lan_flg": 0,
        "onsen_flg": 0,
        "kd_stn_5min": 0,
        "kd_bch_5min": 0,
        "kd_slp_5min": 0,
        "kd_conv_walk_5min": 0,
    }
).astype(
    {
        "yad_no": np.int32,
        "yad_type": np.int8,
        "total_room_cnt": np.int32, # TODO: NaNの扱い（中央値とする？）
        "wireless_lan_flg": np.int8,
        "onsen_flg": np.int8,
        "kd_stn_5min": np.int8,
        "kd_slp_5min": np.int8,
        "kd_bch_5min": np.int8,
        "kd_conv_walk_5min": np.int8,        
    }
)

In [11]:
yados_df

Unnamed: 0,yad_no,yad_type,total_room_cnt,wireless_lan_flg,onsen_flg,kd_stn_5min,kd_bch_5min,kd_slp_5min,kd_conv_walk_5min,wid_cd,ken_cd,lrg_cd,sml_cd
0,1,0,129,1,0,1,0,0,1,f0112abf369fb03cdc5f5309300913da,072c85e1653e10c9c7dd065ad007125a,449c52ef581d5f9ef311189469a0520e,677a32689cd1ad74e867f1fbe43a3e1c
1,2,0,23,1,0,0,0,0,0,d86102dd9c232bade9a97dccad40df48,b4d2fb4e51ea7bca80eb1270aa474a54,5c9a8f48e9df0234da012747a02d4b29,4ee16ee838dd2703cc9a1d5a535f0ced
2,3,0,167,1,1,1,0,0,1,46e33861f921c3e38b81998fbf283f01,572d60f0f5212aacda515ebf81fb0a3a,8a623b960557e87bd1f4edf71b6255be,ab9480fd72a44d51690ab16c4ad4d49c
3,4,0,144,1,0,1,0,0,1,46e33861f921c3e38b81998fbf283f01,107c7305a74c8dcc4f143de208bf7ec2,52c9ea83f2cfe92be54cb6bc961edf21,1cc3e1838bb0fd0fde0396130b1f82b9
4,5,0,41,1,1,0,0,0,0,43875109d1dab93592812c50d18270a7,75617bb07a2785a948ab1958909211f1,9ea5a911019b66ccd42f556c42a2fe2f,be1b876af18afc4deeb3081591d2a910
...,...,...,...,...,...,...,...,...,...,...,...,...,...
13801,13802,0,10,1,1,0,0,0,0,c312e07b7a5d456d53a5b00910a336e1,558ac1909f0318b82c621ab250329d6d,80fb3c5ad0c89931d0923e9f80885218,5eb30820716082c720836733d73c605e
13802,13803,0,87,0,0,1,0,0,1,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,e5cfcc0a43c82072aca11628ff0add53,20ad8785a30f125bee5a8a325782ab06
13803,13804,0,80,1,1,0,1,0,1,d86102dd9c232bade9a97dccad40df48,7d76599bd27ff9e7823b2b1323ca763e,c5fe8848b6ab39b040cdb3668aea9433,b3eab50ccf6ffb51c37d36ee384abfbf
13804,13805,0,8,1,1,0,0,0,1,3300cf6f774b7c6a5807110f244cbc21,689cf8289e7ea0b2eef1b017dcdfe8de,8b712435430a6875839a6c3b5a40b008,2b4165444a777465576b25f65697d739


In [12]:
df = pd.merge(df, yados_df, on="yad_no")

df

Unnamed: 0,session_id,seq_no,yad_no,yad_type,total_room_cnt,wireless_lan_flg,onsen_flg,kd_stn_5min,kd_bch_5min,kd_slp_5min,kd_conv_walk_5min,wid_cd,ken_cd,lrg_cd,sml_cd
0,000007603d533d30453cc45d0f3d119f,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343
1,05d87a854b34e30b25f07ac7c5b1dc2e,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343
2,189be6eb839900bf2035481d0db7a7f9,1,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343
3,3801fd3f98a4a62e31aa94e3ce156619,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343
4,7254bb04284937d96ef8309ccd62b058,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
707963,fd9a83ca399d8ca4ec886a546bcf6a85,1,10700,0,160,1,0,0,0,0,1,89e181a40914767dfee00fa2b7c2dcb5,2b99151dba9558109a35c75a3c05c38b,84fea75411a084611637f301b8970178,38d428257f5ec0e45e91fb74a2d926c1
707964,fda07b762c2ee3bb0147ae3a515d204f,1,9572,0,10,1,1,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,6920865be128aa14814810654738b159,828bd0261886a914435f0434dbfc2264,2eac3ef54f291530cfeae907b8823eaf
707965,fe6bff9642657a8a47b1f4e8a5165f0b,1,415,0,37,1,0,0,0,0,1,f0112abf369fb03cdc5f5309300913da,ce3aaf25e7e38a0c42d373fb148efc86,972e29ad914b6393f0ae1d369a3a22fd,5c4d53b9fd2c6f9faaa0d2cb77541c16
707966,fe8640584e5a182da211b7fabcf96011,1,5020,1,87,1,1,0,0,1,0,b07b75d367ebece55a23ceecc939fff4,0a66f6ab9c0507059da6f22a0e1f1690,4713062d683b3be22a00131d9546c66d,975a4a51b4386eec81f3b698d05bc475


In [13]:
# partition_colを無理やり追加
df["day_idx"] = 1

df

Unnamed: 0,session_id,seq_no,yad_no,yad_type,total_room_cnt,wireless_lan_flg,onsen_flg,kd_stn_5min,kd_bch_5min,kd_slp_5min,kd_conv_walk_5min,wid_cd,ken_cd,lrg_cd,sml_cd,day_idx
0,000007603d533d30453cc45d0f3d119f,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343,1
1,05d87a854b34e30b25f07ac7c5b1dc2e,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343,1
2,189be6eb839900bf2035481d0db7a7f9,1,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343,1
3,3801fd3f98a4a62e31aa94e3ce156619,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343,1
4,7254bb04284937d96ef8309ccd62b058,0,2395,0,113,1,0,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,d78f53d0856617bc782f02c3280dfef2,4fd631b15116098340cdb099c86a5a40,4044dac1931ddaa5a967e09506d76343,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
707963,fd9a83ca399d8ca4ec886a546bcf6a85,1,10700,0,160,1,0,0,0,0,1,89e181a40914767dfee00fa2b7c2dcb5,2b99151dba9558109a35c75a3c05c38b,84fea75411a084611637f301b8970178,38d428257f5ec0e45e91fb74a2d926c1,1
707964,fda07b762c2ee3bb0147ae3a515d204f,1,9572,0,10,1,1,0,0,0,0,dc414a17890cfc17d011d5038b88ca93,6920865be128aa14814810654738b159,828bd0261886a914435f0434dbfc2264,2eac3ef54f291530cfeae907b8823eaf,1
707965,fe6bff9642657a8a47b1f4e8a5165f0b,1,415,0,37,1,0,0,0,0,1,f0112abf369fb03cdc5f5309300913da,ce3aaf25e7e38a0c42d373fb148efc86,972e29ad914b6393f0ae1d369a3a22fd,5c4d53b9fd2c6f9faaa0d2cb77541c16,1
707966,fe8640584e5a182da211b7fabcf96011,1,5020,1,87,1,1,0,0,1,0,b07b75d367ebece55a23ceecc939fff4,0a66f6ab9c0507059da6f22a0e1f1690,4713062d683b3be22a00131d9546c66d,975a4a51b4386eec81f3b698d05bc475,1


### NVTabular workflowの実行

In [14]:
dataset = nvt.Dataset(df)
workflow = nvt.Workflow(selected_features)

# 前処理のワークフローをparquet形式へ変換
workflow.fit_transform(dataset).to_parquet(os.path.join(OUTPUT_FOLDER, "processed_nvt"))

### Export pre-processed data by day

- 時間的な学習と評価を行うため、split
    - train.parquet
    - valid.parquet
    - test.parquet
        - testは要らないかも

In [15]:
sessions_gdf = cudf.read_parquet(os.path.join(OUTPUT_FOLDER, "processed_nvt", "part_0.parquet"))

In [16]:
print(sessions_gdf.head())

                         session_id        yad_no-list yad_type-list  \
0  000007603d533d30453cc45d0f3d119f       [6358, 7528]        [3, 3]   
1  0000ca043ed437a1472c9d1d154eb49b       [4217, 1646]        [3, 3]   
2  0000d4835cf113316fe447e2f80ba1c8      [8998, 10862]        [3, 3]   
3  0000fcda1ae1b2f431e55a7075d1f500       [2335, 1065]        [3, 3]   
4  000104bdffaaad1a1e0a9ebacf585f33  [3064, 738, 3064]     [3, 3, 3]   

  wireless_lan_flg-list onsen_flg-list kd_stn_5min-list kd_bch_5min-list  \
0                [3, 4]         [3, 3]           [4, 4]           [3, 3]   
1                [3, 3]         [3, 3]           [3, 3]           [3, 3]   
2                [3, 3]         [3, 3]           [4, 4]           [3, 3]   
3                [3, 3]         [3, 3]           [3, 3]           [3, 3]   
4             [3, 3, 3]      [3, 3, 3]        [4, 4, 4]        [3, 3, 3]   

  kd_slp_5min-list kd_conv_walk_5min-list wid_cd-list   ken_cd-list  \
0           [3, 3]                 [4, 

In [17]:
from transformers4rec.utils.data_utils import save_time_based_splits

save_time_based_splits(
    data=nvt.Dataset(sessions_gdf),
    output_dir=os.path.join(OUTPUT_FOLDER, "preproc_sessions_by_day"),
    partition_col="day_idx",
    timestamp_col="session_id",
    test_size=0.0,
    val_size=0.2,
)

Creating time-based splits: 100%|██████████| 1/1 [00:01<00:00,  1.42s/it]


## Train - Transformers4Rec

- Transformers4RecはTransformersをwrapしているライブラリ

参考: https://nvidia-merlin.github.io/Transformers4Rec/stable/examples/end-to-end-session-based/01-ETL-with-NVTabular.html

### Get the schema

In [18]:
train = Dataset(os.path.join(OUTPUT_FOLDER, "processed_nvt/part_0.parquet"))
schema = train.schema

In [19]:
schema = schema.select_by_name(
    [
        "yad_no-list",
        "yad_type-list",
        "wireless_lan_flg-list",
        "kd_stn_5min-list",
        "kd_bch_5min-list",
        "kd_slp_5min-list",
        "kd_conv_walk_5min-list",
        "wid_cd-list",
        "ken_cd-list",
        "lrg_cd-list",
        "sml_cd-list",
        "total_room_cnt-list",
    ]
)

schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.freq_threshold,properties.num_buckets,properties.cat_path,properties.max_size,properties.embedding_sizes.dimension,properties.embedding_sizes.cardinality,properties.domain.min,properties.domain.max,properties.domain.name,properties.value_count.min,properties.value_count.max
0,yad_no-list,"(Tags.LIST, Tags.ID, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.yad_no.parquet,0.0,333.0,13809.0,0.0,13808.0,yad_no,0,100
1,yad_type-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.yad_type.parquet,0.0,16.0,5.0,0.0,4.0,yad_type,0,100
2,wireless_lan_flg-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.wireless_lan_flg.parquet,0.0,16.0,5.0,0.0,4.0,wireless_lan_flg,0,100
3,kd_stn_5min-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.kd_stn_5min.parquet,0.0,16.0,5.0,0.0,4.0,kd_stn_5min,0,100
4,kd_bch_5min-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.kd_bch_5min.parquet,0.0,16.0,5.0,0.0,4.0,kd_bch_5min,0,100
5,kd_slp_5min-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.kd_slp_5min.parquet,0.0,16.0,5.0,0.0,4.0,kd_slp_5min,0,100
6,kd_conv_walk_5min-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.kd_conv_walk_5min.parquet,0.0,16.0,5.0,0.0,4.0,kd_conv_walk_5min,0,100
7,wid_cd-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.wid_cd.parquet,0.0,16.0,15.0,0.0,14.0,wid_cd,0,100
8,ken_cd-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.ken_cd.parquet,0.0,16.0,50.0,0.0,49.0,ken_cd,0,100
9,lrg_cd-list,"(Tags.LIST, Tags.ITEM, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.lrg_cd.parquet,0.0,39.0,302.0,0.0,301.0,lrg_cd,0,100


### Define the Transformer-based recommendation model

In [20]:
from transformers4rec import torch
from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt

max_sequence_length, d_model, top_k = 50, 64, 10

# Define input module to process tabular input-features and to prepare masked inupts
input_module = torch.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=max_sequence_length,
    continuous_projection=64,
    aggregation="concat",
    d_output=d_model,
    masking="mlm",
)

# Define Next item prediction task
prediction_task = torch.NextItemPredictionTask(
    # TyingEmbeddings　テクニック
    # 入力Embedding行列のweightを出力射影層と結びつける
    weight_tying=True,
    metrics=[
        NDCGAt(top_ks=[top_k], labels_onehot=True),
        AvgPrecisionAt(top_ks=[top_k], labels_onehot=True),
        RecallAt(top_ks=[top_k], labels_onehot=True),
    ],
        
)

# Define the config of the XLNet Transformer architecture
transformer_config = torch.XLNetConfig.build(
    d_model=d_model,
    n_head=8,
    n_layer=4,
    total_seq_length=max_sequence_length,
)

model = transformer_config.to_torch_model(input_module, prediction_task)

In [21]:
BATCH_SIZE_TRAIN = int(os.environ.get("BATCH_SIZE_TRAIN", "2048"))
BATCH_SIZE_VALID = int(os.environ.get("BATCH_SIZE_VALID", "2048"))

training_args = torch.trainer.T4RecTrainingArguments(
    output_dir=OUTPUT_FOLDER,
    max_sequence_length=SESSIONS_MAX_LENGTH,
    data_loader_engine="merlin",
    num_train_epochs=30,
    dataloader_drop_last=False,
    predict_top_k=top_k,
    per_device_train_batch_size=BATCH_SIZE_TRAIN,
    per_device_eval_batch_size=BATCH_SIZE_VALID,
    learning_rate=5e-3,
    # 半浮動点小数にすることでメモリ使用量が削減できる
    fp16=True,
    report_to=[],
)

In [22]:
recsys_trainer = torch.Trainer(
    model=model,
    args=training_args,
    schema=schema,
    compute_metrics=True,
)

Using amp half precision backend


In [23]:
from transformers4rec.torch.utils.examples_utils import wipe_memory

recsys_trainer.train_dataset_or_path = os.path.join(
    OUTPUT_FOLDER,
    "preproc_sessions_by_day",
    "3",
    "train.parquet"
)
recsys_trainer.reset_lr_scheduler()
recsys_trainer.train()

wipe_memory()

***** Running training *****
  Num examples = 231424
  Num Epochs = 30
  Instantaneous batch size per device = 2048
  Total train batch size (w. parallel, distributed & accumulation) = 2048
  Gradient Accumulation steps = 1
  Total optimization steps = 3390


Step,Training Loss
500,7.6431
1000,6.0175
1500,5.2469
2000,4.875
2500,4.6389
3000,4.4842


Saving model checkpoint to /kaggle/working/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to /kaggle/working/checkpoint-1000
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to /kaggle/working/checkpoint-1500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to /kaggle/working/checkpoint-2000
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to /kaggle/working/checkpoint-2500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
Saving model checkpoint to /kaggle/working/checkpoint-3000
Trainer.model is not a `PreTrainedModel`, only saving its state dict.


Training completed. Do not forget to share your model on huggingface.co/models =)




In [24]:
indexed_by_time_metrics = {}

recsys_trainer.eval_dataset_or_path = os.path.join(
    OUTPUT_FOLDER,
    "preproc_sessions_by_day",
    "3",
    "valid.parquet"
)

eval_metrics = recsys_trainer.evaluate(metric_key_prefix="valid")

for key in sorted(eval_metrics.keys()):
    if "at_" in key:
        print(" %s = %s" % (key.replace("_at_", "@"), str(eval_metrics[key])))
        if "indexed_by_time_" + key.replace("_at_", "@") in indexed_by_time_metrics:
            indexed_by_time_metrics["indexed_by_time_" + key.replace("_at_", "@")] += [
                eval_metrics[key]
            ]
        else:
            indexed_by_time_metrics["indexed_by_time_" + key.replace("_at_", "@")] = [
                eval_metrics[key]
            ]

 valid_/next-item/avg_precision@10 = 0.09930777549743652
 valid_/next-item/ndcg@10 = 0.14374642074108124
 valid_/next-item/recall@10 = 0.2835339605808258


In [25]:
print(indexed_by_time_metrics)

{'indexed_by_time_valid_/next-item/avg_precision@10': [0.09930777549743652], 'indexed_by_time_valid_/next-item/ndcg@10': [0.14374642074108124], 'indexed_by_time_valid_/next-item/recall@10': [0.2835339605808258]}


In [26]:
recsys_trainer.model.save(path=OUTPUT_FOLDER, model_name="t4rec")