In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr
import matplotlib.pyplot as plt

from analyst import Analyst
from config import ModelConfig, TrainerConfig
from dataset import load_dataset_manager

In [None]:
# Seed NumPy and PyTorch with the same value so reruns start from the same RNG state.
np.random.seed(0)
torch.manual_seed(0)

# Settings for the training run.
# NOTE(review): ignore_saved_model=True / load_model=False presumably force a fresh
# training run instead of restoring a checkpoint — confirm in config.TrainerConfig.
trainer_config = TrainerConfig(
    dataset_name="toydata-hard",
    epochs=5,
    ignore_saved_model=True,
    load_model=False,
    batch_size=64,
)

# Model hyperparameters.
model_config = ModelConfig(
    d_model=128,
    lr=0.0001,
    init_embedding_std=0.2,
    normalize_embedding_weight=True,
)

In [None]:
# Echo the trainer configuration as the cell output so the run's settings are recorded.
trainer_config

In [None]:
# Echo the model configuration as the cell output so the hyperparameters are recorded.
model_config

In [None]:
# Build the dataset manager from values held by the trainer and model configs.
# The load_dataset / save_dataset flags are forwarded unchanged; presumably they
# control dataset caching under dataset_dir — TODO confirm in dataset.load_dataset_manager.
dataset_manager = load_dataset_manager(
    dataset_name=trainer_config.dataset_name,
    dataset_dir=trainer_config.dataset_dir,
    load_dataset=trainer_config.load_dataset,
    save_dataset=trainer_config.save_dataset,
    window_size=model_config.window_size,
)

In [None]:
# Create the Analyst used below both for training (analyst.fit) and for the
# embedding inspection calls made from on_epoch_start.
analyst = Analyst(
    dataset_manager=dataset_manager,
    trainer_config=trainer_config,
    model_config=model_config,
)

In [None]:
def calc_spearmanr(v, c):
    """Spearman rank correlation between an observed ranking and a reference order.

    Only labels that occur in both ``v`` and ``c`` are compared, so a reference
    label missing from ``v`` (or an extra entry in ``v``) no longer produces a
    length-mismatch error. When both hold exactly the same unique labels the
    result is the same as correlating each label's index in ``v`` with its
    index in ``c``.

    Args:
        v: Sequence of indexable records in their observed order; element
            ``[1]`` of each record is the label. Labels are assumed hashable
            (e.g. strings) — TODO confirm against the producer of ``v``.
        c: Iterable of labels in the reference order. A one-shot iterator such
            as ``reversed(labels)`` is fine; it is consumed once.

    Returns:
        The object returned by ``scipy.stats.spearmanr`` (correlation first,
        p-value second).
    """
    # First position of each label in the observed ranking, built in one pass
    # instead of rescanning ``v`` for every reference label.
    position_in_v = {}
    for idx, entry in enumerate(v):
        position_in_v.setdefault(entry[1], idx)

    observed, expected = [], []
    for rank, label in enumerate(c):
        if label in position_in_v:
            observed.append(position_in_v[label])
            expected.append(rank)
    return spearmanr(observed, expected)

# Per-epoch Spearman correlations for "M" and "F"; they are only appended to by the
# (currently commented-out) tracking code inside on_epoch_start.
m_s, f_s = [], []

def on_epoch_start(epoch: int):
    """Diagnostic hook handed to ``analyst.fit`` as ``on_epoch_start``.

    In order, it:
      1. queries the "gender" vs "genre" similarity for "M" and then "F";
      2. prints mean/std of the element and meta embedding weights of the
         sequence side ("seq:") and the item side ("item:");
      3. prints the cosine similarity between rows 6 and 7 of the sequence
         element embedding;
      4. calls ``analyst.visualize_meta_embedding`` for "age" / "year" with
         ``method="pca"``.

    Args:
        epoch: Value supplied by the caller. It is not used by the active code;
            only the disabled train-mode switch at the bottom refers to it.
    """
    # Similarity queries in the order M, F. With verbose=True the analyst
    # presumably prints the ranking itself — TODO confirm. The returned values
    # are only consumed by the disabled Spearman tracking below.
    genre_similarity = {
        gender: analyst.similarity_between_seq_meta_and_item_meta(
            "gender", gender, "genre", method="inner-product", num_top_values=30, verbose=True
        )
        for gender in ("M", "F")
    }

    # --- disabled: Spearman agreement with a hand-written genre ordering ---
    # f_c = [
    #     "Romance",
    #     "Musical",
    #     "Children's",
    #     "Animation",
    #     "Drama",
    #     "Comedy",
    #     "Mystery",
    #     "Documentary",
    #     "Fantasy",
    #     "Film-Noir",
    #     "Thriller",
    #     "Crime",
    #     "War",
    #     "Adventure",
    #     "Horror",
    #     "Action",
    #     "Sci-Fi",
    #     "Western"
    # ]
    # m_c = reversed(f_c)
    # m_spearman = calc_spearmanr(genre_similarity["M"], m_c)
    # f_spearman = calc_spearmanr(genre_similarity["F"], f_c)
    # print(f"M_spearman_result: {m_spearman}")
    # print(f"F_spearman_result: {f_spearman}")
    # m_s.append(m_spearman.correlation)
    # f_s.append(f_spearman.correlation)

    # --- disabled: dump of the sequence meta embedding ---
    # print(analyst.trainer.seq_meta_embedding)

    # --- disabled: age-vs-year similarity for three age values ---
    # for age in ("10", "30", "50"):
    #     analyst.similarity_between_seq_meta_and_item_meta(
    #         "age", age, "year", method="inner-product", num_top_values=30, verbose=True
    #     )

    # Weight statistics, one printed line per embedding side.
    model = analyst.trainer.model
    for tag, side in (("seq:", model.embedding_seq), ("item:", model.embedding_item)):
        element_w = side.embedding_element.weight.data
        meta_w = side.embedding_meta.weight.data
        print(tag, element_w.mean(), element_w.std(), meta_w.mean(), meta_w.std())

    # Cosine similarity between two fixed rows of the sequence element embedding.
    # NOTE(review): the meaning of rows 6 and 7 is not visible here — confirm.
    seq_element_w = model.embedding_seq.embedding_element.weight
    print(F.cosine_similarity(seq_element_w[6], seq_element_w[7], dim=0))

    analyst.visualize_meta_embedding("age", "year", method="pca")

    # --- disabled: switch the model's train mode at a given epoch ---
    # if epoch == 5:
    #     analyst.trainer.model.set_train_mode("seq")

In [None]:
# Train the model, passing on_epoch_start as the per-epoch diagnostic hook.
# NOTE(review): show_fig=False presumably suppresses the figure fit() would draw — confirm.
analyst.fit(show_fig=False, on_epoch_start=on_epoch_start)

In [None]:
# Plot the Spearman correlations collected per on_epoch_start call.
# m_s / f_s are only filled when the Spearman tracking inside on_epoch_start is
# enabled; while it is commented out the axes stay empty.
s = [m_corr + f_corr for m_corr, f_corr in zip(m_s, f_s)]  # element-wise M + F

fig, ax = plt.subplots()
ax.plot(m_s, label="M")
ax.plot(f_s, label="F")
ax.plot(s, label="M + F")
ax.set_xlabel("on_epoch_start call")
ax.set_ylabel("Spearman correlation")
ax.set_title("Gender vs. genre ranking agreement")
ax.legend();  # trailing ';' keeps the Legend repr out of the cell output

In [None]:
from sklearn.cluster import KMeans
from util import visualize_cluster

# Cluster the learned sequence embeddings into 10 groups.
# random_state is pinned (same value as the seeds at the top of the notebook) so the
# cluster ids do not depend on how much global RNG state earlier cells consumed.
kmeans = KMeans(n_clusters=10, random_state=0)
# One embedding per key of analyst.trainer.seq_embedding, in dict order.
h_seq = list(analyst.trainer.seq_embedding.values())
kmeans.fit(h_seq)

In [None]:
# Cluster id assigned by the fitted k-means model to each embedding it was fitted on,
# in the order the embeddings were passed to fit().
cluster_labels = kmeans.labels_

In [None]:
from collections import Counter  # NOTE(review): imported but not used in this cell

# Print each sequence key next to its cluster id. This relies on keys() iterating in
# the same order as the values() that were clustered.
seq_keys = analyst.trainer.seq_embedding.keys()
for seq_key, cluster_id in zip(seq_keys, cluster_labels):
    print(seq_key, cluster_id)