In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os
import sys
from pathlib import Path

# Make the project modules in ../src importable for the local imports below.
# Path().resolve() is the kernel's current working directory, so the relative
# "../src/" is resolved against wherever the notebook is being run from.
sys.path.append(os.path.join(Path().resolve(), "../src/"))

from analyst import Analyst
from config import ModelConfig, TrainerConfig
from dataset_manager import load_dataset_manager
from trainers import PyTorchTrainer
from model import load_model

In [None]:
# Seed both random number generators up front so reruns are repeatable.
np.random.seed(0)
torch.manual_seed(0)

# NOTE(review): epochs=0 combined with load_model=True and ignore_saved_model=True
# looks like "analyse a stored model without training" — confirm how
# TrainerConfig resolves these two flags.
trainer_config = TrainerConfig(
    dataset_name="toydata-paper",
    dataset_dir="../cache/dataset/",
    model_dir="../cache/model/",
    batch_size=64,
    epochs=0,
    load_model=True,
    ignore_saved_model=True,
)

# Model hyper-parameters used for the analysis below.
model_config = ModelConfig(
    d_model=128,
    window_size=5,
    init_embedding_std=0.2,
    lr=0.0001,
)

In [None]:
# Build the dataset manager from the trainer configuration; the window size is
# taken from the model configuration.
dataset_manager = load_dataset_manager(
    # which dataset, and where its raw files live
    dataset_name=trainer_config.dataset_name,
    data_dir="../data/",
    # dataset directory and load/save switches, all taken from the trainer config
    dataset_dir=trainer_config.dataset_dir,
    load_dataset=trainer_config.load_dataset,
    save_dataset=trainer_config.save_dataset,
    # window size comes from the model config
    window_size=model_config.window_size,
)

In [None]:
# The model loader and the trainer take the same three keyword arguments,
# so they are spelled out once and shared.
common_kwargs = dict(
    dataset_manager=dataset_manager,
    trainer_config=trainer_config,
    model_config=model_config,
)

model = load_model(**common_kwargs)
trainer = PyTorchTrainer(model=model, **common_kwargs)

# The analyst is handed the model held by the trainer (trainer.model).
analyst = Analyst(dataset_manager=dataset_manager, model=trainer.model)

In [None]:
# Short aliases for the three embedding tables exposed by the analyst.
# Each one is indexed by a string key in the heatmap cells further down.
seq, seq_meta, item_meta = (
    analyst.seq_embedding,
    analyst.seq_meta_embedding,
    analyst.item_meta_embedding,
)

In [None]:
# Individual sequences shown in the per-customer heatmap.
seq_keys = [
    "u_4_M_20_50_M5",
    "u_1_M_20_50_M2",
    "u_38_M_20_50_M4",
]

# Sequence-side metadata: two gender values, then age in steps of ten (20..60).
seq_meta_keys = [f"gender:{gender}" for gender in ("M", "F")]
seq_meta_keys += [f"age:{age}" for age in range(20, 70, 10)]

# Item-side metadata: three genre codes, then year in steps of ten (1960..2000).
item_meta_keys = [f"genre:{genre}" for genre in ("M", "E", "F")]
item_meta_keys += [f"year:{year}" for year in range(1960, 2010, 10)]

In [None]:
# Global matplotlib settings shared by the figures below.
# Font families to enable when Japanese tick labels are wanted:
#   plt.rcParams["font.family"] = "Osaka"  # or "Hiragino Mincho ProN"
#   plt.rcParams["font.family"] = "Hiragino Sans"
#   plt.rcParams["font.size"] = 16
plt.rcParams.update({
    "font.weight": "regular",
    # The x-axis tick labels were getting cut off in the saved figure,
    # so enlarge the bottom margin.
    "figure.subplot.bottom": 0.30,
})

# Colormap used by every heatmap. A custom alternative would be:
#   matplotlib.colors.LinearSegmentedColormap.from_list("", ["cyan", "white", "hotpink"])
cmap = "OrRd"

In [None]:
# Dot product between every sequence-metadata embedding (rows) and every
# item-metadata embedding (columns).
data = np.zeros((len(seq_meta_keys), len(item_meta_keys)))
for i, seq_key in enumerate(seq_meta_keys):
    for j, item_key in enumerate(item_meta_keys):
        # .item() extracts a plain Python float, so the assignment no longer
        # relies on an implicit tensor -> NumPy conversion, which raises for
        # tensors that require grad or do not live on the CPU.
        data[i, j] = torch.dot(seq_meta[seq_key], item_meta[item_key]).item()

# Tick labels: the raw keys are used as-is. For Japanese labels, swap in:
# display_seq_meta_keys = list(map(lambda s: s.replace("gender:", "性別:").replace("M", "男性").replace("F", "女性").replace("age:", "年代:"), seq_meta_keys))
# display_item_meta_keys = list(map(lambda s: s.replace("genre:", "ジャンル:").replace("year:", "発売年:"), item_meta_keys))
display_seq_meta_keys = seq_meta_keys
display_item_meta_keys = item_meta_keys

# Explicit figure/axes so the heatmap and the saved file refer to the same
# figure regardless of what else is open in the pyplot state machine.
fig, ax = plt.subplots(figsize=(6, 7 / 2))
sns.heatmap(
    data,
    ax=ax,
    linewidth=0.2,
    annot=True,
    fmt="5.2f",
    yticklabels=display_seq_meta_keys,
    xticklabels=display_item_meta_keys,
    cmap=cmap,
    cbar=False,
)
# Keep the row labels horizontal.
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
# fig.subplots_adjust(left=0.12, right=1, bottom=0.33, top=1)
fig.savefig("../data/fig_heatmap.pdf", format="pdf", dpi=300)
# fig.savefig("../data/fig_heatmap.svg", format="svg", dpi=300)

In [None]:
# Dot product between each selected sequence embedding (rows) and every
# item-metadata embedding (columns).
data = np.zeros((len(seq_keys), len(item_meta_keys)))
for i, seq_key in enumerate(seq_keys):
    for j, item_key in enumerate(item_meta_keys):
        data[i, j] = np.dot(seq[seq_key], item_meta[item_key])

# Tick labels. The Japanese variants used to be assigned and then immediately
# overwritten; they are kept here as comments, like in the metadata heatmap cell.
# display_item_meta_keys = list(map(lambda s: s.replace("genre:", "ジャンル:").replace("year:", "発売年:"), item_meta_keys))
# display_seq_meta_keys = ["顧客A", "顧客B", "顧客C"]
display_item_meta_keys = item_meta_keys
# Despite the name, this holds the row labels for the individual sequences
# in seq_keys (one label per entry, in the same order).
display_seq_meta_keys = ["Customer A", "Customer B", "Customer C"]

# Explicit figure/axes so the layout tweak and the saved file target this figure.
fig, ax = plt.subplots(figsize=(6, 2))
sns.heatmap(
    data,
    ax=ax,
    linewidth=0.2,
    annot=True,
    fmt="5.2f",
    yticklabels=display_seq_meta_keys,
    xticklabels=display_item_meta_keys,
    cmap=cmap,
    cbar=False,
)
# Keep the row labels horizontal.
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
# Explicit margins for this short figure (these override the rcParams bottom margin).
fig.subplots_adjust(left=0.18, right=1, bottom=0.5, top=1)
fig.savefig("../data/fig_heatmap_user.pdf", format="pdf")
# fig.savefig("../data/fig_heatmap_user.svg", format="svg", dpi=300)