In [None]:
import random

import numpy as np
import torch
import torch.nn.functional as F

from analyst import Analyst
from config import ModelConfig, TrainerConfig
from dataset import load_dataset_manager

In [None]:
# Seed every RNG source the pipeline might draw from (torch, numpy and the stdlib
# `random` module) so that Restart Kernel -> Run All reproduces the same results.
SEED = 24
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included
np.random.seed(SEED)
random.seed(SEED)

trainer_config = TrainerConfig(
    dataset_name="movielens", epochs=10, ignore_saved_model=True, load_model=False, batch_size=64
)
model_config = ModelConfig(d_model=128, lr=0.001, init_embedding_std=5)

In [None]:
# Show the trainer configuration (the bare last expression is rendered as this cell's output).
trainer_config

In [None]:
# Show the model configuration (the bare last expression is rendered as this cell's output).
model_config

In [None]:
# Dataset manager for the configured dataset. The window size is a model
# hyper-parameter; every other setting comes from the trainer config.
# NOTE(review): load_dataset / save_dataset presumably toggle reading/writing a
# cached copy under dataset_dir — confirm in dataset.py.
dataset_manager = load_dataset_manager(
    window_size=model_config.window_size,
    dataset_name=trainer_config.dataset_name,
    dataset_dir=trainer_config.dataset_dir,
    load_dataset=trainer_config.load_dataset,
    save_dataset=trainer_config.save_dataset,
)

In [None]:
# Wire the data and both configs into the Analyst, which drives training (see the fit cell below).
analyst = Analyst(trainer_config=trainer_config, model_config=model_config, dataset_manager=dataset_manager)

In [None]:
def on_epoch_start(epoch: int):
    """Per-epoch diagnostics for the learned meta embeddings.

    Calls ``analyst.similarity_between_seq_meta_and_item_meta`` for the
    "gender" values "M" then "F" against "genre" (inner-product, top 30).
    Then calls ``analyst.visualize_meta_embedding`` for "gender"/"genre" with
    PCA. Finally, it prints the cosine similarity between two fixed rows of the
    model's ``embedding_seq_meta`` weight table.

    Args:
        epoch: Epoch number, presumably supplied by ``analyst.fit``. It is used
            only to label the printed similarity.
    """
    for gender in ("M", "F"):
        analyst.similarity_between_seq_meta_and_item_meta(
            "gender", gender, "genre", method="inner-product", num_top_values=30
        )
    analyst.visualize_meta_embedding("gender", "genre", method="pca")

    # NOTE(review): rows 6 and 7 are presumably the gender M/F entries of the
    # seq-meta vocabulary — confirm against the dataset manager's index mapping.
    first_idx, second_idx = 6, 7
    weight = analyst.trainer.model.embedding_seq_meta.weight
    # This is pure inspection: skip autograd bookkeeping, and use .item() to print a
    # plain float rather than a tensor repr (with grad_fn if the weight requires grad).
    with torch.no_grad():
        cos_sim = F.cosine_similarity(weight[first_idx], weight[second_idx], dim=0).item()
    print(
        f"epoch {epoch}: cosine similarity of seq-meta rows "
        f"{first_idx} and {second_idx}: {cos_sim:.4f}"
    )

In [None]:
# Train the model, running the diagnostic callback at the start of each epoch; show_fig=False is passed to fit.
analyst.fit(on_epoch_start=on_epoch_start, show_fig=False)