In [1]:
import os
import sys
from pathlib import Path
# Add "<current working directory>/src/" to the module search path.
# The membership check makes this cell idempotent: re-running it no longer
# appends another copy of the same entry to sys.path on every execution.
SRC_DIR = os.path.join(Path().resolve(), "src/")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from src.config import ModelConfig, TrainerConfig
from src.dataset import load_dataset_manager
from src.trainers import PyTorchTrainer
from src.analyst import Analyst

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Set up the configs.
# Both objects are built with no constructor arguments, so whatever defaults
# ModelConfig / TrainerConfig define are the settings every later cell uses.
model_config = ModelConfig()
trainer_config = TrainerConfig()

In [7]:
# Load the dataset.
# Every argument comes from the config objects: the dataset name, dataset
# directory and the load/save flags from trainer_config, and the window size
# from model_config.
# NOTE(review): the recorded run printed "load cached dataset_manager from: ...",
# so load_dataset presumably enables reading a pickled cache - confirm in src.dataset.
dataset_manager = load_dataset_manager(
    dataset_name=trainer_config.dataset_name,
    dataset_dir=trainer_config.dataset_dir,
    load_dataset=trainer_config.load_dataset,
    save_dataset=trainer_config.save_dataset,
    window_size=model_config.window_size,
)

load cached dataset_manager from: cache/dataset/toydata-small.pickle


In [8]:
# Create the Trainer from the dataset manager and the two config objects.
# NOTE(review): the recorded run logged "load_state_dict from: .../attentive.pt"
# at construction, so the trainer appears to start from previously saved weights
# rather than a fresh model - confirm in src.trainers.
trainer = PyTorchTrainer(
    dataset_manager=dataset_manager,
    trainer_config=trainer_config,
    model_config=model_config,
)

load_state_dict from: cache/model/toydata-small/attentive.pt


In [10]:
# Train the model.
# The call is the cell's last expression, so its return value is displayed; in
# the recorded run that is a dict holding per-epoch 'train' and 'test' loss lists.
trainer.fit()

train start


100%|██████████| 71/71 [00:00<00:00, 132.31it/s]
100%|██████████| 71/71 [00:00<00:00, 422.47it/s]


Epoch: 1, loss: 0.2926398196690519, test_loss: {'test': 0.28961342153414876}
saved best model to cache/model/toydata-small/best-attentive.pt


100%|██████████| 71/71 [00:00<00:00, 134.08it/s]
100%|██████████| 71/71 [00:00<00:00, 431.62it/s]


Epoch: 2, loss: 0.28176973907040875, test_loss: {'test': 0.2570957060850842}
saved best model to cache/model/toydata-small/best-attentive.pt


100%|██████████| 71/71 [00:00<00:00, 137.15it/s]
100%|██████████| 71/71 [00:00<00:00, 431.07it/s]

Epoch: 3, loss: 0.25607466277941854, test_loss: {'test': 0.25452997545960926}
saved best model to cache/model/toydata-small/best-attentive.pt
train end
saved model to cache/model/toydata-small/attentive.pt





{'train': [0.2926398196690519, 0.28176973907040875, 0.25607466277941854],
 'test': [0.28961342153414876, 0.2570957060850842, 0.25452997545960926]}

In [11]:
# Create an Analyst instance (the class that analyzes the model's outputs).
# It is given the trainer's model together with the dataset manager.
analyst = Analyst(trainer.model, dataset_manager)

In [17]:
# Output the relevance between a sequence and items (inner product by default).
# NOTE(review): the recorded table repeats some items (e.g. v_2_E_2000 appears
# twice) - presumably the sequence's recent items are listed with duplicates
# kept; confirm against Analyst.
analyst.similarity_between_seq_and_item(seq_index=0, num_recent_items=10)

Unnamed: 0,similarity,item
0,0.625977,v_1_M_2000
1,0.358295,v_1_E_2000
2,0.155408,v_2_E_2000
3,0.155408,v_2_E_2000
4,0.145529,v_2_F_1990
5,0.013705,v_2_M_1960
6,-0.137968,v_1_F_2000
7,-0.142515,v_1_M_1980
8,-0.281191,v_2_F_2000
9,-0.281191,v_2_F_2000


In [18]:
# Output the relevance between a sequence and the items' auxiliary information
# (inner product by default), here for the "genre" metadata.
# NOTE(review): all ten rows of the recorded output are the same "genre:F" entry
# with an identical score - check whether the result should be de-duplicated.
analyst.similarity_between_seq_and_item_meta(seq_index=0, item_meta_name="genre")

Unnamed: 0,similarity,item_meta
0,0.094838,genre:F
1,0.094838,genre:F
2,0.094838,genre:F
3,0.094838,genre:F
4,0.094838,genre:F
5,0.094838,genre:F
6,0.094838,genre:F
7,0.094838,genre:F
8,0.094838,genre:F
9,0.094838,genre:F


In [14]:
# Output the relevance between the sequences' auxiliary information and the
# items' auxiliary information (inner product by default).
# NOTE(review): all ten rows of the recorded output are the same "genre:E" entry
# with an identical score - check whether the result should be de-duplicated.
analyst.similarity_between_seq_meta_and_item_meta("gender", "M", "genre")

Unnamed: 0,similarity,item_meta
0,0.291772,genre:E
1,0.291772,genre:E
2,0.291772,genre:E
3,0.291772,genre:E
4,0.291772,genre:E
5,0.291772,genre:E
6,0.291772,genre:E
7,0.291772,genre:E
8,0.291772,genre:E
9,0.291772,genre:E


In [19]:
# Output the relevance between the intrinsic features and auxiliary information
# of the sequence and those of the items (inner product by default).
# In the recorded output each row pairs a sequence-side feature ("seq" column)
# with an item-side feature ("item" column) and a similarity score.
analyst.analyze_seq(seq_index=0)

Unnamed: 0,similarity,seq,item
0,1.676092,gender:F,genre:F
1,1.357353,age:20,genre:M
2,0.895531,gender:F,genre:E
3,0.794512,gender:F,year:1990
4,0.777306,age:20,year:2000
5,0.671562,gender:F,year:1980
6,0.551117,gender:F,year:1960
7,0.450346,gender:F,year:2000
8,0.412878,u_0_F_20_20_F1,year:1980
9,0.311633,age:20,genre:F
