In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr
import matplotlib.pyplot as plt

import sys
sys.path.append("..")

from analyst import Analyst
from config import ModelConfig, TrainerConfig
from dataset import load_dataset_manager
from trainers import PyTorchTrainer

In [None]:
# Seed NumPy and PyTorch with the same value so reruns of the notebook are repeatable.
np.random.seed(0)
torch.manual_seed(0)

# Run settings for the MovieLens experiment.
# NOTE(review): epochs=0 combined with load_model=True presumably means
# "reuse the model stored under model_dir, do not train" — confirm in the trainer.
trainer_config = TrainerConfig(
    dataset_name="movielens",
    dataset_dir="../cache/dataset/",
    model_dir="../cache/model/",
    load_model=True,
    epochs=0,
    batch_size=64,
)

# Model hyperparameters; window_size is also passed to the dataset loader in the next cell.
model_config = ModelConfig(
    d_model=128,
    window_size=5,
    init_embedding_std=0.2,
    lr=1e-4,
)

In [None]:
# Build the dataset manager. Raw data location and window size are given
# explicitly; the remaining four settings are forwarded unchanged from
# trainer_config under the same keyword names.
dataset_manager = load_dataset_manager(
    data_dir="../data/",
    window_size=model_config.window_size,
    **{
        field: getattr(trainer_config, field)
        for field in ("dataset_name", "dataset_dir", "load_dataset", "save_dataset")
    },
)

In [None]:
# Construct the trainer from both configs and the dataset manager.
trainer = PyTorchTrainer(
    model_config=model_config,
    trainer_config=trainer_config,
    dataset_manager=dataset_manager,
)

# Hand the trainer's model and the same dataset manager to the Analyst.
analyst = Analyst(model=trainer.model, dataset_manager=dataset_manager)

In [None]:
# Similarity heatmap for two age groups (sequence metadata) against two
# genres (item metadata). The returned pair is kept as (fig, ax) so the
# plot can be adjusted or saved afterwards.
fig, ax = analyst.visualize_similarity_heatmap(
    figsize=(12, 8),
    item_meta_names=["genre:Adventure", "genre:Comedy"],
    seq_meta_names=["age:Under 18", "age:18-24"],
)

In [None]:
# Show the "age" entries of seq_meta_dict in sorted order.
# sorted() takes any iterable and returns a new list, so no list() copy is needed.
sorted(dataset_manager.seq_meta_dict["age"])