<a href="https://colab.research.google.com/github/rurusasu/RecommendSystem/blob/main/RecBoleTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RecBoleを使って様々なレコメンドシステムをテストする

参考
* [RecBole を用いてクックパッドマートのデータに対する50以上のレコメンドモデルの実験をしてみた](https://techlife.cookpad.com/entry/2021/11/04/090000)
* [Atomic Files](https://recbole.io/docs/user_guide/data/atomic_files.html)
* [新しいデータセットの実行](https://recbole.io/docs/user_guide/usage/running_new_dataset.html#prepare-atomic-files)
* [RecBoleを使ってみよう3 Atomicファイルについて](https://zenn.dev/kentoo1/articles/d5aef1c67901a0)


In [None]:
# Mount Google Drive into the Colab filesystem
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Project root under the Drive mount point; the data, dataset and config paths used below are built from it
base_dir = "/content/drive/MyDrive/ColabNotebooks/RecBole"

In [None]:
!pip install  --upgrade -q recbole ray kmeans_pytorch

## ライブラリ読み込み

In [None]:
import argparse
import click
import os
import tempfile

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.decomposition import NMF
from sklearn.preprocessing import LabelEncoder, StandardScaler
from recbole.quick_start import run_recbole, run

# The bundled seaborn styles were renamed to "seaborn-v0_8-*" in matplotlib 3.6 and the
# old names deprecated. Prefer the new name and fall back to the legacy one so the
# cell works on both older and newer matplotlib versions.
_whitegrid_style = "seaborn-v0_8-whitegrid"
if _whitegrid_style not in plt.style.available:
    _whitegrid_style = "seaborn-whitegrid"
plt.style.use(_whitegrid_style)

  plt.style.use("seaborn-whitegrid")


## データの読み込みとAtomic file の作成

### データの読み込み

In [None]:
# Load the merged sample table from the Excel workbook stored under base_dir
excel_path = f"{base_dir}/data/sample_merged_full.xlsx"
data = pd.read_excel(excel_path)

# Preview the first two rows
data.head(2)

Unnamed: 0,user_id,target_id,rating,rating_conv,user_name_target,nickname_target,gender_target,location_target,age_range_target,height_range_target,...,body_type_user,personality_user,appearance_user,job_user,blood_type_user,car_user,interests_user,salary_user,plan_user,account_creation_timestamp_user
0,1,8627.0,0.0,1,原田遥,アオイ,女性,埼玉県伊奈町,45-49,150-154,...,スリム,元気,セクシー系,会社員,O型,有り,技術・プログラミング,8160000,option2,2024-01-14 00:11:34
1,1,18213.0,0.0,1,井上萌,ユイ,女性,福島県玉川村,30-34,150-154,...,スリム,元気,セクシー系,会社員,O型,有り,技術・プログラミング,8160000,option2,2024-01-14 00:11:34


### Atomic file 作成

In [None]:
dataset_dir = f"{base_dir}/dataset/profile"
# Create the output directory up front so the to_csv calls below do not fail
# when the notebook is run against a fresh Drive folder.
os.makedirs(dataset_dir, exist_ok=True)

# Drop every record that contains a NaN in any column
data = data.dropna()

# Cast every numeric column except `rating` to int, so IDs that were read as
# floats (e.g. 8627.0) are written as plain integers
numeric_columns = data.select_dtypes(include=['number']).columns.tolist()
numeric_columns.remove('rating')
data[numeric_columns] = data[numeric_columns].astype(int)

# Re-derive the numeric / non-numeric column lists after the cast
# (`rating` is part of numeric_columns again from here on)
numeric_columns = data.select_dtypes(include=['number']).columns.tolist()
categorical_columns = data.select_dtypes(exclude=['number']).columns.tolist()

# NOTE(review): numeric columns are tagged `:token` and all other columns
# `:token_seq`, as in the original cell. Single-valued categorical columns and
# continuous values such as salary may be better served by `:token` / `:float`
# -- confirm against the RecBole atomic-file docs and the model config.

# Build the user file.
# `data` holds one row per interaction, so the same user appears many times;
# keep a single feature row per user_id.
user_columns = [col for col in data.columns if '_user' in col] + ['user_id']
user_df = data[user_columns].drop_duplicates(subset='user_id')
user_df.columns = [
    f"{col}:token" if col in numeric_columns else f"{col}:token_seq" for col in user_df.columns
]
user_df.to_csv(f'{dataset_dir}/profile.user', index=False, sep='\t')

# Build the item file (same de-duplication, keyed on target_id)
item_columns = [col for col in data.columns if '_target' in col] + ['target_id']
item_df = data[item_columns].drop_duplicates(subset='target_id')
item_df.columns = [
    f"{col}:token" if col in numeric_columns else f"{col}:token_seq" for col in item_df.columns
]
item_df.to_csv(f'{dataset_dir}/profile.item', index=False, sep='\t')

# Build the interaction file: one row per (user, target, rating) record
inter_df = data[['user_id', 'target_id', 'rating']]
inter_df.columns = [
    'user_id:token', 'target_id:token', 'rating:float'
]
inter_df.to_csv(f'{dataset_dir}/profile.inter', index=False, sep='\t')

print("Atomic files have been created successfully.")

Atomic files have been created successfully.


# モデル設定
* [Model list](https://recbole.io/docs/user_guide/model_intro.html#context-aware-recommendation)

In [None]:
# Candidate RecBole model names, grouped by recommendation category.
_general_models = [
    'LDiffRec', 'DiffRec', 'Random', 'NCL', 'SimpleX', 'NCEPLRec', 'ADMMSLIM',
    'SGL', 'SLIMElastic', 'EASE', 'RecVAE', 'RaCT', 'NNCF', 'ENMF', 'CDAE',
    'MacridVAE', 'MultiDAE', 'MultiVAE', 'LINE', 'DGCF', 'LightGCN', 'NGCF',
    'GCMC', 'SpectralCF', 'NAIS', 'FISM', 'DMF', 'ConvNCF',
]
_context_aware_models = [
    'EulerNet', 'FiGNN', 'KD_DAGFM', 'AutoInt', 'DCNV2', 'DCN', 'DIEN', 'DIN',
    'WideDeep', 'DSSM', 'PNN', 'FNN', 'FwFM', 'FFM', 'AFM', 'xDeepFM', 'DeepFM',
]
_sequential_models = [
    'FEARec', 'CORE', 'SINE', 'LightSANs', 'NPE', 'HRM', 'HGN', 'RepeatNet',
    'SHAN', 'FOSSIL', 'KSR', 'GRU4RecKG', 'S3Rec', 'FDSA', 'SASRecF',
    'GRU4RecF', 'GCSAN', 'SRGNN', 'BERT4Rec', 'SASRec', 'TransRec',
    'NextItNet', 'Caser',
]

# Flat list in the order: general, context-aware, sequential
model_list = [*_general_models, *_context_aware_models, *_sequential_models]

# 実行

In [None]:
if __name__ == "__main__":
    dataset = 'profile'

    # Command-line options as (flags, keyword arguments) pairs, registered in order below
    cli_options = [
        (("--model", "-m"), dict(type=str, default="BPR", help="name of models")),
        (("--dataset", "-d"), dict(type=str, default=dataset, help="name of datasets")),
        (
            ("--config_files",),
            dict(type=str, default=f"{base_dir}/config/profile.yml", help="config files"),
        ),
        (("--nproc",), dict(type=int, default=1, help="the number of process in this group")),
        (("--ip",), dict(type=str, default="localhost", help="the ip of master node")),
        (("--port",), dict(type=str, default="5678", help="the port of master node")),
        (("--world_size",), dict(type=int, default=-1, help="total number of jobs")),
        (
            ("--group_offset",),
            dict(type=int, default=0, help="the global rank offset of this group"),
        ),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    # parse_known_args returns unrecognised arguments separately instead of erroring on them
    args, _ = parser.parse_known_args()

    # --config_files is a space-separated list of paths; an empty value means no config files
    if args.config_files:
        config_file_list = args.config_files.strip().split(" ")
    else:
        config_file_list = None

    # Hand the parsed settings to RecBole's `run` entry point
    run(
        args.model,
        args.dataset,
        config_file_list=config_file_list,
        nproc=args.nproc,
        world_size=args.world_size,
        ip=args.ip,
        port=args.port,
        group_offset=args.group_offset,
    )

Train     0: 100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00, 41.19it/s]
Evaluate   : 100%|█████████████████████████████████████████████████| 21/21 [00:00<00:00, 705.61it/s]
Train     1: 100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00, 32.19it/s]
Evaluate   : 100%|██████████████████████████████████████████████████| 21/21 [00:00<00:00, 35.18it/s]
Train     2: 100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.67it/s]
Evaluate   : 100%|██████████████████████████████████████████████████| 21/21 [00:00<00:00, 89.82it/s]
Train     3: 100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.81it/s]
Evaluate   : 100%|█████████████████████████████████████████████████| 21/21 [00:00<00:00, 239.48it/s]
Train     4: 100%|████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15.20it/s]
Evaluate   : 100%|█████████████████████████████████████████████████| 21/21 [00:00<00:00, 28

KeyboardInterrupt: 