In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr

# Make the parent directory importable before the imports below.
sys.path.append("..")

from analyst import Analyst
from config import ModelConfig, TrainerConfig
from dataset import load_dataset_manager

In [None]:
# Fix the torch and numpy RNG seeds before anything else is built.
torch.manual_seed(0)
np.random.seed(0)

# Training / persistence settings; cache locations are relative to this notebook.
# NOTE(review): ignore_saved_model=True and load_model=True are both set —
# confirm how the two flags interact.
trainer_config = TrainerConfig(
    dataset_name="toydata-paper",
    epochs=3,
    ignore_saved_model=True,
    load_model=True,
    batch_size=64,
    model_dir="../cache/model/",
    dataset_dir="../cache/dataset/",
)

# Model hyperparameters.
model_config = ModelConfig(
    d_model=128,
    lr=0.0001,
    init_embedding_std=0.2,
    normalize_embedding_weight=True,
    window_size=5,
)

In [None]:
# Build the dataset manager from the trainer/model configuration.
# load_dataset / save_dataset are not set explicitly when trainer_config is
# created, so they presumably come from TrainerConfig defaults — TODO confirm.
dataset_manager = load_dataset_manager(
    # which dataset, and where its raw files are read from
    dataset_name=trainer_config.dataset_name,
    data_dir="../data/",
    # dataset cache handling
    dataset_dir=trainer_config.dataset_dir,
    load_dataset=trainer_config.load_dataset,
    save_dataset=trainer_config.save_dataset,
    # must agree with the model's window size
    window_size=model_config.window_size,
)

In [None]:
# The Analyst bundles the dataset with both configs; later cells read its
# trainer's embeddings and call its similarity helper.
analyst = Analyst(dataset_manager=dataset_manager, trainer_config=trainer_config, model_config=model_config)

In [None]:
# Inner-product similarity between the sequence meta value gender=F and the
# item meta field "genre"; the cell's result is shown as its output.
analyst.similarity_between_seq_meta_and_item_meta(
    "gender", "F", "genre",
    method="inner-product",
)

In [None]:
# NOTE(review): this call has exactly the same arguments as the call in the
# preceding cell. It may have been meant to query a different value (for
# example gender "M") — confirm, or drop the duplicate cell.
analyst.similarity_between_seq_meta_and_item_meta(
    "gender",
    "F",
    "genre",
    method="inner-product",
)

In [None]:
# Grab the trained meta embeddings from the trainer. Both objects are indexed
# by "field:value" keys further down in this notebook.
seq_meta, item_meta = (
    analyst.trainer.seq_meta_embedding,
    analyst.trainer.item_meta_embedding,
)

In [None]:
# Rows of the heatmap: sequence-side metadata keys, in display order.
seq_meta_keys = [
    "gender:M", "gender:F",
    "age:20", "age:30", "age:40", "age:50", "age:60",
]

# Columns of the heatmap: item-side metadata keys, in display order.
item_meta_keys = [
    "genre:M", "genre:E", "genre:F",
    "year:1960", "year:1970", "year:1980", "year:1990", "year:2000",
]

In [None]:
# Inner-product matrix: one row per sequence meta key, one column per item
# meta key. Entry (row, col) is the dot product of the two embeddings.
data = np.zeros((len(seq_meta_keys), len(item_meta_keys)))
for row, seq_key in enumerate(seq_meta_keys):
    seq_vec = seq_meta[seq_key]  # looked up once per row
    for col, item_key in enumerate(item_meta_keys):
        data[row, col] = np.dot(seq_vec, item_meta[item_key])

In [None]:
# Global matplotlib settings for the figure below (seaborn is imported with
# the other dependencies at the top of the notebook).
plt.rcParams.update({
    # The tick labels are Japanese, so a font covering those glyphs is needed.
    # Alternatives noted by the author: "Osaka" or "Hiragino Mincho ProN".
    "font.family": "Hiragino Sans",
    "font.weight": "regular",
    # The x-axis labels were getting cut off in the saved figure, so the
    # bottom margin is enlarged.
    "figure.subplot.bottom": 0.30,
})

In [None]:
def _display_label(key, field_labels, value_labels=None):
    """Translate a "field:value" metadata key into its display label.

    Args:
        key: Metadata key such as "gender:M".
        field_labels: Maps a field name (the text before the first ":") to the
            label shown in the figure; fields without an entry are kept as is.
        value_labels: Optional mapping of field name -> {value: label}; values
            without an entry are kept as is.

    Returns:
        The translated "label:value" string, or ``key`` unchanged when it
        contains no ":" separator.
    """
    field, sep, value = key.partition(":")
    if not sep:
        return key
    value = (value_labels or {}).get(field, {}).get(value, value)
    return f"{field_labels.get(field, field)}:{value}"


# Japanese tick labels. Only the gender values are translated; matching on
# the exact field/value avoids rewriting every "M" or "F" that happens to
# appear inside a key.
display_seq_meta_keys = [
    _display_label(
        key,
        {"gender": "性別", "age": "年齢"},
        {"gender": {"M": "男性", "F": "女性"}},
    )
    for key in seq_meta_keys
]
display_item_meta_keys = [
    _display_label(key, {"genre": "ジャンル", "year": "発売年"})
    for key in item_meta_keys
]

# Draw on an explicit figure/axes pair so the layout and save calls below
# act on exactly this figure.
fig, ax = plt.subplots()
sns.heatmap(
    data,
    linewidths=0.2,  # documented seaborn parameter for the cell separator width
    annot=True,
    fmt="5.2f",
    yticklabels=display_seq_meta_keys,
    xticklabels=display_item_meta_keys,
    cmap="OrRd",
    cbar=False,
    ax=ax,
)
# Axis titles are currently disabled; to restore them call
# ax.set(xlabel=..., ylabel=...) with the item / customer side-information labels.

# Open question from the author: would tight_layout be enough here?
fig.subplots_adjust(left=0.12, right=1, bottom=0.3, top=1)

# NOTE(review): figures go to "data/" relative to the notebook, while the
# dataset is read from "../data/" — confirm this is the intended location.
output_dir = Path("data")
output_dir.mkdir(parents=True, exist_ok=True)  # savefig fails if the directory is missing
# An SVG export (fig_heatmap.svg) was previously written here and is disabled.
fig.savefig(output_dir / "fig_heatmap.png", format="png", dpi=300)
fig.savefig(output_dir / "fig_heatmap.pdf", format="pdf", dpi=300)