In [10]:
from pathlib import Path
import torch
import tqdm

from au2v.config import ModelConfig, TrainerConfig
from au2v.dataset_manager import load_dataset_manager
from au2v.trainer import PyTorchTrainer, calc_accuracy
from au2v.model import load_model, PyTorchModel

In [None]:
def calc_accuracy(
    model: PyTorchModel,
    num_item: int,
    test_dataset: list[tuple[int, list[int], list[int]]],
    top_k: list[int],
) -> dict[str, float]:
    hit_counts = {k: 0 for k in top_k}
    for seq_index, context_items, target_indices in tqdm.tqdm(test_dataset):
        rec_list = model.output_rec_lists(
            seq_index=torch.LongTensor([seq_index]),
            item_indices=torch.LongTensor([context_items]),
            cand_item_indices=torch.arange(num_item),
            k=max(top_k),
        )
        for k in top_k:
            hit_counts[k] += len(set(target_indices) & set(rec_list[0][:k]))

    total_rec = len(test_dataset)
    results = {
        f"Accuracy@{k}": hit_count / total_rec / k
        for k, hit_count in hit_counts.items()
    }
    return results

In [12]:
models = ["old-attentive", "doc2vec", "attentive"]

for model_name in models:
    model_config = ModelConfig(
        weight_decay=1e-8,
        max_embedding_norm=1.0,
        d_model=64,
        lr=1e-4 if model_name == "doc2vec" else 5e-5
    )
    trainer_config = TrainerConfig(
        dataset_name="movielens",
        model_name=model_name,
        load_dataset=False,
        save_dataset=False,
        load_model=False,
        ignore_saved_model=True,
        epochs=5,
    )

    print(model_config)
    print(trainer_config)

    dataset_manager = load_dataset_manager(
        dataset_name=trainer_config.dataset_name,
        dataset_dir=trainer_config.dataset_dir,
        data_dir="../data/",
        load_dataset=trainer_config.load_dataset,
        save_dataset=trainer_config.save_dataset,
        window_size=model_config.window_size,
    )
    model = load_model(
        dataset_manager=dataset_manager,
        trainer_config=trainer_config,
        model_config=model_config,
    )
    trainer = PyTorchTrainer(
        model=model,
        dataset_manager=dataset_manager,
        trainer_config=trainer_config,
        model_config=model_config,
    )
    trainer.fit()

ModelConfig(d_model=64, init_embedding_std=0.2, max_embedding_norm=1.0, window_size=5, negative_sample_size=5, lr=5e-05, weight_decay=1e-08)
TrainerConfig(model_name='old-attentive', dataset_name='movielens', epochs=5, batch_size=64, verbose=False, ignore_saved_model=True, load_model=False, save_model=True, load_dataset=False, save_dataset=False, model_dir='cache/model/', dataset_dir='cache/dataset/', device='cpu')
dataset_manager does not exist at: cache/dataset/movielens.pickle, create dataset
num_seq: 6040, num_item: 3706, num_item_meta: 28, num_seq_meta: 30, num_item_meta_types: 3, num_seq_meta_types: 3
to_sequential_data start
to_sequential_data end


100%|██████████| 9190/9190 [02:02<00:00, 75.06it/s]


train 0.006227718123920886


100%|██████████| 2298/2298 [00:09<00:00, 244.99it/s]


valid 0.005795715705987552
saved model to cache/model/movielens/old-attentive.pt


100%|██████████| 1591/1591 [00:22<00:00, 70.49it/s]


{'Accuracy@10': 0.15003142677561282, 'Accuracy@30': 0.13779593547035407, 'Accuracy@50': 0.1310747957259585}


100%|██████████| 9190/9190 [02:07<00:00, 71.94it/s]


train 0.005340048728330351


100%|██████████| 2298/2298 [00:09<00:00, 246.36it/s]


valid 0.0049767685672591705
saved model to cache/model/movielens/old-attentive.pt


100%|██████████| 1591/1591 [00:22<00:00, 71.31it/s]


{'Accuracy@10': 0.21125078566939032, 'Accuracy@30': 0.1906138696836371, 'Accuracy@50': 0.1779635449402891}


100%|██████████| 9190/9190 [02:15<00:00, 67.96it/s]


train 0.004740277529360728


100%|██████████| 2298/2298 [00:08<00:00, 265.48it/s]


valid 0.004572295020346841
saved model to cache/model/movielens/old-attentive.pt


100%|██████████| 1591/1591 [00:21<00:00, 72.63it/s]


{'Accuracy@10': 0.21533626649905718, 'Accuracy@30': 0.19765346742090928, 'Accuracy@50': 0.1834569453174104}


100%|██████████| 9190/9190 [02:08<00:00, 71.26it/s]


train 0.004460149996238026


100%|██████████| 2298/2298 [00:08<00:00, 269.57it/s]


valid 0.004390288854835
saved model to cache/model/movielens/old-attentive.pt


100%|██████████| 1591/1591 [00:19<00:00, 80.00it/s]


{'Accuracy@10': 0.20747957259585167, 'Accuracy@30': 0.18897967735177038, 'Accuracy@50': 0.1768698931489629}


100%|██████████| 9190/9190 [02:01<00:00, 75.34it/s]


train 0.0043293312624182


100%|██████████| 2298/2298 [00:07<00:00, 290.41it/s]


valid 0.004298849968445711
saved model to cache/model/movielens/old-attentive.pt


100%|██████████| 1591/1591 [00:18<00:00, 87.72it/s]


{'Accuracy@10': 0.177561282212445, 'Accuracy@30': 0.16614288707311964, 'Accuracy@50': 0.15803896920175992}
ModelConfig(d_model=64, init_embedding_std=0.2, max_embedding_norm=1.0, window_size=5, negative_sample_size=5, lr=0.0001, weight_decay=1e-08)
TrainerConfig(model_name='doc2vec', dataset_name='movielens', epochs=5, batch_size=64, verbose=False, ignore_saved_model=True, load_model=False, save_model=True, load_dataset=False, save_dataset=False, model_dir='cache/model/', dataset_dir='cache/dataset/', device='cpu')
dataset_manager does not exist at: cache/dataset/movielens.pickle, create dataset
num_seq: 6040, num_item: 3706, num_item_meta: 28, num_seq_meta: 30, num_item_meta_types: 3, num_seq_meta_types: 3
to_sequential_data start
to_sequential_data end


100%|██████████| 9190/9190 [01:13<00:00, 125.84it/s]


train 0.006309416829998543


100%|██████████| 2298/2298 [00:03<00:00, 695.02it/s]


valid 0.005998013916685358
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 1591/1591 [00:08<00:00, 190.86it/s]


{'Accuracy@10': 0.06373350094280326, 'Accuracy@30': 0.07454431175361408, 'Accuracy@50': 0.07235700817096166}


100%|██████████| 9190/9190 [00:59<00:00, 154.39it/s]


train 0.0055397002546317365


100%|██████████| 2298/2298 [00:03<00:00, 738.54it/s]


valid 0.005162167519854522
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 1591/1591 [00:08<00:00, 195.69it/s]


{'Accuracy@10': 0.1636706473915776, 'Accuracy@30': 0.16660381311544104, 'Accuracy@50': 0.15866750471401633}


100%|██████████| 9190/9190 [01:00<00:00, 152.72it/s]


train 0.00470211990603263


100%|██████████| 2298/2298 [00:03<00:00, 747.98it/s]


valid 0.004446926794330406
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 1591/1591 [00:08<00:00, 193.23it/s]


{'Accuracy@10': 0.1884978001257071, 'Accuracy@30': 0.18420280745862141, 'Accuracy@50': 0.1747705845380264}


100%|██████████| 9190/9190 [01:11<00:00, 128.37it/s]


train 0.004113741008052045


100%|██████████| 2298/2298 [00:03<00:00, 703.22it/s]


valid 0.003991150644352295
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 1591/1591 [00:08<00:00, 194.75it/s]


{'Accuracy@10': 0.18064110622250157, 'Accuracy@30': 0.18087156924366227, 'Accuracy@50': 0.17121307353865492}


100%|██████████| 9190/9190 [01:14<00:00, 122.98it/s]


train 0.0037406814447781466


100%|██████████| 2298/2298 [00:03<00:00, 609.36it/s]


valid 0.0036860691496218554
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 1591/1591 [00:08<00:00, 193.87it/s]


{'Accuracy@10': 0.17165304839723444, 'Accuracy@30': 0.16647810601298973, 'Accuracy@50': 0.15751099937146448}
ModelConfig(d_model=64, init_embedding_std=0.2, max_embedding_norm=1.0, window_size=5, negative_sample_size=5, lr=5e-05, weight_decay=1e-08)
TrainerConfig(model_name='attentive', dataset_name='movielens', epochs=5, batch_size=64, verbose=False, ignore_saved_model=True, load_model=False, save_model=True, load_dataset=False, save_dataset=False, model_dir='cache/model/', dataset_dir='cache/dataset/', device='cpu')
dataset_manager does not exist at: cache/dataset/movielens.pickle, create dataset
num_seq: 6040, num_item: 3706, num_item_meta: 28, num_seq_meta: 30, num_item_meta_types: 3, num_seq_meta_types: 3
to_sequential_data start
to_sequential_data end


100%|██████████| 9190/9190 [02:04<00:00, 74.00it/s]


train 0.005644030542993539


100%|██████████| 2298/2298 [00:10<00:00, 224.11it/s]


valid 0.005331109099238408
saved model to cache/model/movielens/attentive.pt


100%|██████████| 1591/1591 [00:17<00:00, 88.49it/s]


{'Accuracy@10': 0.21125078566939032, 'Accuracy@30': 0.20081709616593338, 'Accuracy@50': 0.19030798240100566}


100%|██████████| 9190/9190 [02:05<00:00, 72.95it/s]


train 0.005239035004001999


100%|██████████| 2298/2298 [00:11<00:00, 191.72it/s]


valid 0.005189901123389607
saved model to cache/model/movielens/attentive.pt


100%|██████████| 1591/1591 [00:18<00:00, 86.91it/s]


{'Accuracy@10': 0.2231301068510371, 'Accuracy@30': 0.20955373978629793, 'Accuracy@50': 0.1962036455059711}


100%|██████████| 9190/9190 [02:14<00:00, 68.44it/s]


train 0.005159919626070379


100%|██████████| 2298/2298 [00:11<00:00, 203.49it/s]


valid 0.005148413195745291
saved model to cache/model/movielens/attentive.pt


100%|██████████| 1591/1591 [00:18<00:00, 87.34it/s]


{'Accuracy@10': 0.22338152105593964, 'Accuracy@30': 0.21022417766603813, 'Accuracy@50': 0.19744814582023884}


100%|██████████| 9190/9190 [02:12<00:00, 69.20it/s]


train 0.005133872842659385


100%|██████████| 2298/2298 [00:10<00:00, 216.85it/s]


valid 0.005132436792235003
saved model to cache/model/movielens/attentive.pt


100%|██████████| 1591/1591 [00:19<00:00, 81.81it/s]


{'Accuracy@10': 0.20408548082966688, 'Accuracy@30': 0.20299601927508906, 'Accuracy@50': 0.19454431175361406}


100%|██████████| 9190/9190 [02:28<00:00, 62.02it/s]


train 0.005122879080087255


100%|██████████| 2298/2298 [00:12<00:00, 190.35it/s]


valid 0.005124912388533432
saved model to cache/model/movielens/attentive.pt


100%|██████████| 1591/1591 [00:18<00:00, 85.23it/s]

{'Accuracy@10': 0.1744186046511628, 'Accuracy@30': 0.18013827781269642, 'Accuracy@50': 0.18154619736015085}





In [13]:
models = ["attentive", "doc2vec", "old-attentive"]

for model_name in models:
    model = torch.load(f"cache/model/movielens/{model_name}.pt")

    print(
        model_name,
        calc_accuracy(
            model=model,
            num_item=dataset_manager.num_item,
            test_dataset=dataset_manager.test_datasets["test"],
            top_k=[10, 30, 50],
        ),
    )

100%|██████████| 1591/1591 [00:18<00:00, 86.71it/s]


attentive {'Accuracy@10': 0.1744186046511628, 'Accuracy@30': 0.18013827781269642, 'Accuracy@50': 0.18154619736015085}


100%|██████████| 1591/1591 [00:08<00:00, 191.86it/s]


doc2vec {'Accuracy@10': 0.17165304839723444, 'Accuracy@30': 0.16647810601298973, 'Accuracy@50': 0.15751099937146448}


100%|██████████| 1591/1591 [00:18<00:00, 84.60it/s]

old-attentive {'Accuracy@10': 0.177561282212445, 'Accuracy@30': 0.16614288707311964, 'Accuracy@50': 0.15803896920175992}



