In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
import random

import sys
sys.path.append("..")

from analyst import Analyst
from config import ModelConfig, TrainerConfig
from dataset import load_dataset_manager
from trainers import PyTorchTrainer

In [None]:
# Seed the global NumPy and PyTorch RNGs so repeated runs draw the same numbers.
np.random.seed(0)
torch.manual_seed(0)

# Trainer settings. epochs=0 combined with load_model=True presumably means
# "reuse a saved model from model_dir, do not train" — confirm in TrainerConfig.
trainer_config = TrainerConfig(
    dataset_name="movielens",
    dataset_dir="../cache/dataset/",
    model_dir="../cache/model/",
    load_model=True,
    epochs=0,
    batch_size=64,
)

# Model hyperparameters.
model_config = ModelConfig(
    d_model=128,
    window_size=5,
    init_embedding_std=0.2,
    lr=0.0001,
)

In [None]:
# Build the dataset manager from the trainer/model configuration.
# Arguments are grouped by where they come from: literal path, trainer
# settings, then model settings.
dataset_manager = load_dataset_manager(
    data_dir="../data/",
    dataset_name=trainer_config.dataset_name,
    dataset_dir=trainer_config.dataset_dir,
    load_dataset=trainer_config.load_dataset,
    save_dataset=trainer_config.save_dataset,
    window_size=model_config.window_size,
)

In [None]:
# The trainer is built only to obtain its model (trainer_config has epochs=0
# and load_model=True, so presumably no training happens here — TODO confirm).
trainer = PyTorchTrainer(
    model_config=model_config,
    trainer_config=trainer_config,
    dataset_manager=dataset_manager,
)

# The analyst pairs that model with the dataset manager for the analyses below.
analyst = Analyst(model=trainer.model, dataset_manager=dataset_manager)

In [None]:
from util import to_full_meta_value

# Metadata attributes compared in this heatmap: one sequence-side attribute
# against one item-side attribute.
seq_meta_name = "occupation"
item_meta_name = "genre"

# Expand every known value of each attribute into its full metadata label and
# sort the labels so the axis order is stable between runs.
seq_meta_names = sorted(
    to_full_meta_value(seq_meta_name, value)
    for value in dataset_manager.seq_meta_dict[seq_meta_name]
)
item_meta_names = sorted(
    to_full_meta_value(item_meta_name, value)
    for value in dataset_manager.item_meta_dict[item_meta_name]
)

fig, ax = analyst.visualize_similarity_heatmap(
    item_meta_names=item_meta_names,
    seq_meta_names=seq_meta_names,
)

# Tighten the layout before exporting the figure as a PDF.
fig.tight_layout()
fig.savefig(
    f"../data/fig_heatmap_{seq_meta_name}_{item_meta_name}.pdf",
    format="pdf",
)

In [None]:
# Display-only reference: age-bucket labels in ascending order. The result is
# not assigned to any name. Presumably these are the full metadata labels for
# the "age" attribute (the "25-34" bucket is used as a filter in the next
# cell) — TODO confirm against dataset_manager.
["age:Under 18", "age:18-24", "age:25-34", "age:35-44", "age:45-49", "age:50-55", "age:56+"]

In [None]:
import util

# Dedicated fixed-seed generator so the same users are sampled on every run,
# independent of the global `random` state.
rnd = random.Random(0)

# Ids of all sequences (users) whose metadata matches the target profile.
target_ids = [
    int(seq_id)
    for seq_id, meta in dataset_manager.seq_metadata.items()
    if meta["gender"] == "M"
    and meta["age"] == "25-34"
    and meta["occupation"] == "college/grad student"
]

# Shuffle in place with the seeded generator.
# `random.shuffle(x, random)` was deprecated in Python 3.9 and its second
# argument was removed in 3.11, so the original call raises TypeError on
# current interpreters. This loop is the same Fisher-Yates pass the legacy
# call performed (j = floor(random() * (i + 1))), so the selected users are
# unchanged from earlier runs.
for i in reversed(range(1, len(target_ids))):
    j = int(rnd.random() * (i + 1))
    target_ids[i], target_ids[j] = target_ids[j], target_ids[i]

# Keep the first five shuffled users, ordered by id for a stable row order.
target_ids = sorted(target_ids[:5])
item_keys = dataset_manager.item_meta_le.classes_
# Row labels: "顧客" is Japanese for "customer", followed by the user id.
seq_keys = ["顧客" + str(seq_id) for seq_id in target_ids]

# data[i, j] = similarity between the i-th sampled user and the j-th item key.
data = np.zeros((len(seq_keys), len(item_keys)))
for i, seq_id in enumerate(target_ids):
    seq_name = dataset_manager.seq_le.classes_[seq_id]
    df = analyst.analyze_seq(seq_id)
    # Keep only the rows that describe this user's own sequence.
    df = df[df.seq == seq_name]
    for j, item_key in enumerate(item_keys):
        # First matching row wins; raises IndexError if the key is absent.
        data[i, j] = df[df.item == item_key].similarity.values[0]

# A font with Japanese glyphs is needed to render the "顧客" row labels.
# NOTE(review): 'Hiragino Sans' is presumably only installed on macOS — confirm.
plt.rcParams['font.family'] = 'Hiragino Sans'
plt.rcParams['font.weight'] = 'regular'
fig, ax = util.visualize_heatmap(data, seq_keys, item_keys)
fig.tight_layout()
# Save next to the other figures; the notebook's data directory is "../data/"
# (the original "data/" path was relative to the notebook's own folder).
fig.savefig("../data/fig_heatmap_user_itemmeta.pdf")