In [1]:
from pathlib import Path
import torch
import json
import pandas as pd
import tqdm

from au2v.config import ModelConfig, TrainerConfig
from au2v.dataset_manager import load_dataset_manager
from au2v.trainer import PyTorchTrainer
from au2v.model import load_model, PyTorchModel

In [2]:
def calc_accuracy(
    model: PyTorchModel,
    num_item: int,
    test_dataset: list[tuple[int, list[int], list[int]]],
    top_k: list[int],
) -> dict[str, float]:
    """Score the model's top-k recommendations against held-out target items.

    For every ``(seq_index, context_items, target_indices)`` example the model is
    asked (via ``output_rec_lists``) for its top ``max(top_k)`` items among all
    ``num_item`` candidates. For each cutoff ``k`` we count how many of the
    target items appear in the first ``k`` recommendations.

    Despite the ``Accuracy@k`` key name, the reported value is a precision:
    total hits divided by ``len(test_dataset) * k``.

    Args:
        model: trained model exposing ``output_rec_lists``.
        num_item: number of items; every index in ``range(num_item)`` is a
            candidate for recommendation.
        test_dataset: ``(seq_index, context_items, target_indices)`` tuples.
        top_k: cutoffs to report.

    Returns:
        ``{"Accuracy@k": value}`` for every ``k`` in ``top_k``. Every value is
        0.0 when ``test_dataset`` is empty; the dict is empty when ``top_k`` is.
    """
    if not top_k:
        return {}

    hit_counts = {k: 0 for k in top_k}
    max_k = max(top_k)
    # The candidate set is identical for every example, so build it once.
    cand_item_indices = torch.arange(num_item)

    # Evaluation only: don't record an autograd graph for these forward passes.
    # NOTE(review): the model is not switched to eval mode here; if it contains
    # dropout or similar train-only layers, confirm output_rec_lists handles that.
    with torch.no_grad():
        for seq_index, context_items, target_indices in tqdm.tqdm(test_dataset):
            rec_list = model.output_rec_lists(
                seq_index=torch.LongTensor([seq_index]),
                item_indices=torch.LongTensor([context_items]),
                cand_item_indices=cand_item_indices,
                k=max_k,
            )
            targets = set(target_indices)
            top_recs = rec_list[0]  # single-sequence batch -> first (only) row
            for k in top_k:
                hit_counts[k] += len(targets & set(top_recs[:k]))

    total_rec = len(test_dataset)
    if total_rec == 0:
        # Nothing to evaluate; avoid ZeroDivisionError.
        return {f"Accuracy@{k}": 0.0 for k in top_k}
    return {
        f"Accuracy@{k}": hit_count / total_rec / k
        for k, hit_count in hit_counts.items()
    }

In [3]:
run_configs = [
    {
        "name": "User2Vec (d=16)",
        "model_name": "doc2vec",
        "d_model": 16,
        "epochs": 10,
        "use_weight_tying": True,
        "use_attention": True,
        "use_meta": True,
    },
    {
        "name": "User2Vec (d=32)",
        "model_name": "doc2vec",
        "d_model": 32,
        "epochs": 10,
        "use_weight_tying": True,
        "use_attention": True,
        "use_meta": True,
    },
    {
        "name": "User2Vec (d=64)",
        "model_name": "doc2vec",
        "d_model": 64,
        "epochs": 10,
        "use_weight_tying": True,
        "use_attention": True,
        "use_meta": True,
    },
    # {
    #     "name": "AU2V (d=16)",
    #     "model_name": "attentive",
    #     "d_model": 16,
    #     "epochs": 10,
    # },
    # {
    #     "name": "AU2V (d=32)",
    #     "model_name": "attentive",
    #     "d_model": 32,
    #     "epochs": 10,
    # },
    # {
    #     "name": "AU2V (d=64)",
    #     "model_name": "attentive",
    #     "d_model": 64,
    #     "epochs": 10,
    # },
    # {
    #     "name": "AU2V (wo weight-tying)",
    #     "model_name": "attentive",
    #     "d_model": 64,
    #     "epochs": 10,
    #     "use_weight_tying": False,
    #     "use_attention": True,
    #     "use_meta": True,
    # },
    # {
    #     "name": "AU2V (wo attention)",
    #     "model_name": "attentive",
    #     "d_model": 64,
    #     "epochs": 10,
    #     "use_weight_tying": True,
    #     "use_attention": False,
    #     "use_meta": True,
    # },
    # {
    #     "name": "AU2V (wo meta)",
    #     "model_name": "attentive",
    #     "d_model": 64,
    #     "epochs": 10,
    #     "use_weight_tying": True,
    #     "use_attention": True,
    #     "use_meta": False,
    # },
]

In [4]:
# Train every configured run. After each epoch, evaluate on the test split, record
# the metrics and snapshot the model. The aggregate `results` dict is rewritten
# to RESULTS_PATH after every run, so a crash mid-sweep keeps completed runs.
SEED = 0
TOP_K = [10, 20, 30, 40, 50]
SNAPSHOT_DIR = Path("cache/model/movielens-2")
RESULTS_PATH = Path("result.json")
# Ablation switches are forwarded to ModelConfig only when a run sets them. A run
# that omits one then gets ModelConfig's own default instead of a KeyError.
ABLATION_KEYS = ("use_attention", "use_meta", "use_weight_tying")

# torch.save does not create missing directories.
SNAPSHOT_DIR.mkdir(parents=True, exist_ok=True)

results = {}

for run_config in run_configs:
    run_name = run_config["name"]
    results[run_name] = {"accuracy": []}
    # Reseed per run so each configuration starts from the same torch RNG state.
    # NOTE(review): the data pipeline may draw from other RNGs (random/numpy);
    # seed those as well if full determinism is required.
    torch.manual_seed(SEED)

    model_config = ModelConfig(
        weight_decay=1e-8,
        max_embedding_norm=1,
        d_model=run_config["d_model"],
        lr=5e-5,
        **{key: run_config[key] for key in ABLATION_KEYS if key in run_config},
    )
    trainer_config = TrainerConfig(
        dataset_name="movielens",
        model_name=run_config["model_name"],
        load_dataset=False,
        save_dataset=False,
        load_model=False,
        ignore_saved_model=True,
        epochs=run_config["epochs"],
    )

    print(model_config)
    print(trainer_config)

    # With load/save disabled the dataset is rebuilt from ../data/ for every run.
    dataset_manager = load_dataset_manager(
        dataset_name=trainer_config.dataset_name,
        dataset_dir=trainer_config.dataset_dir,
        data_dir="../data/",
        load_dataset=trainer_config.load_dataset,
        save_dataset=trainer_config.save_dataset,
        window_size=model_config.window_size,
    )
    print(
        "train:",
        len(dataset_manager.train_dataset),
        "valid:",
        len(dataset_manager.valid_dataset),
    )
    model = load_model(
        dataset_manager=dataset_manager,
        trainer_config=trainer_config,
        model_config=model_config,
    )
    trainer = PyTorchTrainer(
        model=model,
        dataset_manager=dataset_manager,
        trainer_config=trainer_config,
        model_config=model_config,
    )

    # Invoked by trainer.fit within this iteration, so the closed-over
    # model / dataset_manager / run_name are the current run's.
    def on_epoch_end(epoch: int):
        """Evaluate on the test split, record the metrics and snapshot the model."""
        result = calc_accuracy(
            model=model,
            num_item=dataset_manager.num_item,
            test_dataset=dataset_manager.test_datasets["test"],
            top_k=TOP_K,
        )
        print(epoch, result)
        results[run_name]["accuracy"].append(result)
        # Pickles the whole model object (loading needs the same class code).
        torch.save(model, SNAPSHOT_DIR / f"{run_name}-{epoch}.pt")

    losses = trainer.fit(on_epoch_end=on_epoch_end)
    results[run_name]["loss"] = losses

    with RESULTS_PATH.open("w") as f:
        json.dump(results, f, indent=2)

ModelConfig(d_model=16, init_embedding_std=0.2, max_embedding_norm=1, window_size=5, negative_sample_size=5, lr=5e-05, weight_decay=1e-08, use_weight_tying=True, use_meta=True, use_attention=True)
TrainerConfig(model_name='doc2vec', dataset_name='movielens', epochs=10, batch_size=64, verbose=False, ignore_saved_model=True, load_model=False, save_model=True, load_dataset=False, save_dataset=False, model_dir='cache/model/', dataset_dir='cache/dataset/', device='cpu')
dataset_manager does not exist at: cache/dataset/movielens.pickle, create dataset
num_seq: 6040, num_item: 3706, num_item_meta: 28, num_seq_meta: 30, num_item_meta_types: 3, num_seq_meta_types: 3
to_sequential_data start
to_sequential_data end
train: 589192 valid: 110547


100%|██████████| 9207/9207 [00:34<00:00, 270.75it/s]


train 0.00649393937346458


100%|██████████| 1728/1728 [00:01<00:00, 1057.15it/s]


valid 0.006496575769230932


100%|██████████| 1591/1591 [00:07<00:00, 200.34it/s]


0 {'Accuracy@10': 0.05223130106851037, 'Accuracy@20': 0.05012570710245129, 'Accuracy@30': 0.04869055101613241, 'Accuracy@40': 0.0484443746071653, 'Accuracy@50': 0.04790697674418604}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:34<00:00, 266.71it/s]


train 0.006440958644123753


100%|██████████| 1728/1728 [00:01<00:00, 1032.41it/s]


valid 0.0064436110503948135


100%|██████████| 1591/1591 [00:07<00:00, 208.00it/s]


1 {'Accuracy@10': 0.09912005028284097, 'Accuracy@20': 0.09421747328724073, 'Accuracy@30': 0.08967106641525247, 'Accuracy@40': 0.08687932118164676, 'Accuracy@50': 0.08414833438089252}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:34<00:00, 268.34it/s]


train 0.0062689030609933485


100%|██████████| 1728/1728 [00:01<00:00, 1055.85it/s]


valid 0.006284045014497646


100%|██████████| 1591/1591 [00:08<00:00, 198.77it/s]


2 {'Accuracy@10': 0.15480829666876178, 'Accuracy@20': 0.15619107479572597, 'Accuracy@30': 0.15384454221663527, 'Accuracy@40': 0.15001571338780642, 'Accuracy@50': 0.14525455688246386}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:35<00:00, 259.51it/s]


train 0.005936785582163713


100%|██████████| 1728/1728 [00:01<00:00, 1055.93it/s]


valid 0.00601709129005283


100%|██████████| 1591/1591 [00:07<00:00, 205.19it/s]


3 {'Accuracy@10': 0.17127592708988057, 'Accuracy@20': 0.1697360150848523, 'Accuracy@30': 0.16777707940498637, 'Accuracy@40': 0.16588623507228156, 'Accuracy@50': 0.16316781898177246}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:35<00:00, 259.02it/s]


train 0.005575470878833977


100%|██████████| 1728/1728 [00:01<00:00, 1059.25it/s]


valid 0.005741339889813712


100%|██████████| 1591/1591 [00:07<00:00, 202.08it/s]


4 {'Accuracy@10': 0.1785669390320553, 'Accuracy@20': 0.18029541169076052, 'Accuracy@30': 0.17613660171799708, 'Accuracy@40': 0.17388434946574483, 'Accuracy@50': 0.17098680075424263}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:35<00:00, 256.39it/s]


train 0.005295273199634173


100%|██████████| 1728/1728 [00:01<00:00, 1054.07it/s]


valid 0.005515128145475105


100%|██████████| 1591/1591 [00:08<00:00, 197.49it/s]


5 {'Accuracy@10': 0.18629792583280955, 'Accuracy@20': 0.1843808925204274, 'Accuracy@30': 0.1802639849151477, 'Accuracy@40': 0.17815839094908864, 'Accuracy@50': 0.1753362664990572}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:35<00:00, 256.15it/s]


train 0.0051019563908842784


100%|██████████| 1728/1728 [00:01<00:00, 1053.57it/s]


valid 0.005342472890312782


100%|██████████| 1591/1591 [00:07<00:00, 203.71it/s]


6 {'Accuracy@10': 0.18246385920804525, 'Accuracy@20': 0.17922690131992458, 'Accuracy@30': 0.1779384035197989, 'Accuracy@40': 0.17704274041483342, 'Accuracy@50': 0.17509742300439973}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:36<00:00, 251.75it/s]


train 0.00497048651629419


100%|██████████| 1728/1728 [00:01<00:00, 1051.08it/s]


valid 0.005212828747774094


100%|██████████| 1591/1591 [00:07<00:00, 201.45it/s]


7 {'Accuracy@10': 0.17360150848522943, 'Accuracy@20': 0.17146448774355752, 'Accuracy@30': 0.1714854389272994, 'Accuracy@40': 0.17121307353865495, 'Accuracy@50': 0.17038340666247642}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:36<00:00, 254.60it/s]


train 0.004880592201238487


100%|██████████| 1728/1728 [00:01<00:00, 1057.90it/s]


valid 0.005114680140466985


100%|██████████| 1591/1591 [00:07<00:00, 201.24it/s]


8 {'Accuracy@10': 0.16398491514770586, 'Accuracy@20': 0.1656191074795726, 'Accuracy@30': 0.1651791326209931, 'Accuracy@40': 0.16425204274041483, 'Accuracy@50': 0.1619610307982401}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:36<00:00, 252.92it/s]


train 0.004816293888392039


100%|██████████| 1728/1728 [00:01<00:00, 1057.40it/s]


valid 0.005038207789327158


100%|██████████| 1591/1591 [00:07<00:00, 204.46it/s]


9 {'Accuracy@10': 0.1596480201131364, 'Accuracy@20': 0.15713387806411064, 'Accuracy@30': 0.1580557301487534, 'Accuracy@40': 0.15648962916404777, 'Accuracy@50': 0.1551602765556254}
saved model to cache/model/movielens/doc2vec.pt
ModelConfig(d_model=32, init_embedding_std=0.2, max_embedding_norm=1, window_size=5, negative_sample_size=5, lr=5e-05, weight_decay=1e-08, use_weight_tying=True, use_meta=True, use_attention=True)
TrainerConfig(model_name='doc2vec', dataset_name='movielens', epochs=10, batch_size=64, verbose=False, ignore_saved_model=True, load_model=False, save_model=True, load_dataset=False, save_dataset=False, model_dir='cache/model/', dataset_dir='cache/dataset/', device='cpu')
dataset_manager does not exist at: cache/dataset/movielens.pickle, create dataset
num_seq: 6040, num_item: 3706, num_item_meta: 28, num_seq_meta: 30, num_item_meta_types: 3, num_seq_meta_types: 3
to_sequential_data start
to_sequential_data end
train: 589192 valid: 110547


100%|██████████| 9207/9207 [00:42<00:00, 218.73it/s]


train 0.006483769257259272


100%|██████████| 1728/1728 [00:01<00:00, 1047.27it/s]


valid 0.006480335563660916


100%|██████████| 1591/1591 [00:07<00:00, 199.65it/s]


0 {'Accuracy@10': 0.056568196103079824, 'Accuracy@20': 0.05515399120050283, 'Accuracy@30': 0.053760737481667716, 'Accuracy@40': 0.05248271527341295, 'Accuracy@50': 0.05220615964802011}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:43<00:00, 213.76it/s]


train 0.0063689489210328135


100%|██████████| 1728/1728 [00:01<00:00, 1042.44it/s]


valid 0.006363773392584231


100%|██████████| 1591/1591 [00:07<00:00, 202.84it/s]


1 {'Accuracy@10': 0.1247014456316782, 'Accuracy@20': 0.12620993086109364, 'Accuracy@30': 0.12478525036664571, 'Accuracy@40': 0.12295725958516657, 'Accuracy@50': 0.11968573224387179}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:44<00:00, 207.70it/s]


train 0.0060613081873250605


100%|██████████| 1728/1728 [00:01<00:00, 1043.47it/s]


valid 0.006087788981950836


100%|██████████| 1591/1591 [00:07<00:00, 200.55it/s]


2 {'Accuracy@10': 0.18057825267127592, 'Accuracy@20': 0.18001257071024512, 'Accuracy@30': 0.17552901738948248, 'Accuracy@40': 0.1714487743557511, 'Accuracy@50': 0.16828409805153993}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:44<00:00, 204.72it/s]


train 0.0056395424929265095


100%|██████████| 1728/1728 [00:01<00:00, 1041.00it/s]


valid 0.005751325259562323


100%|██████████| 1591/1591 [00:07<00:00, 202.02it/s]


3 {'Accuracy@10': 0.20490257699560024, 'Accuracy@20': 0.20025141420490256, 'Accuracy@30': 0.1951393253718835, 'Accuracy@40': 0.1890634820867379, 'Accuracy@50': 0.18390949088623507}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:44<00:00, 206.50it/s]


train 0.005291729367934827


100%|██████████| 1728/1728 [00:01<00:00, 1040.05it/s]


valid 0.005471595327546454


100%|██████████| 1591/1591 [00:07<00:00, 202.53it/s]


4 {'Accuracy@10': 0.20980515399120053, 'Accuracy@20': 0.20421118793211815, 'Accuracy@30': 0.19861722187303585, 'Accuracy@40': 0.19248900062853552, 'Accuracy@50': 0.18740414833438088}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:45<00:00, 202.10it/s]


train 0.005062885929050376


100%|██████████| 1728/1728 [00:01<00:00, 1033.53it/s]


valid 0.005269989589240093


100%|██████████| 1591/1591 [00:07<00:00, 203.91it/s]


5 {'Accuracy@10': 0.20603394091766183, 'Accuracy@20': 0.19952859836580766, 'Accuracy@30': 0.19289754871150222, 'Accuracy@40': 0.1876807039597737, 'Accuracy@50': 0.1834443746071653}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:44<00:00, 205.58it/s]


train 0.004920227115360232


100%|██████████| 1728/1728 [00:01<00:00, 1033.79it/s]


valid 0.005128154123202689


100%|██████████| 1591/1591 [00:08<00:00, 198.04it/s]


6 {'Accuracy@10': 0.1935889377749843, 'Accuracy@20': 0.18799497171590196, 'Accuracy@30': 0.18328095537397862, 'Accuracy@40': 0.178441231929604, 'Accuracy@50': 0.17504714016341924}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:46<00:00, 199.08it/s]


train 0.00482914651833791


100%|██████████| 1728/1728 [00:01<00:00, 1040.78it/s]


valid 0.005027869139109714


100%|██████████| 1591/1591 [00:07<00:00, 203.35it/s]


7 {'Accuracy@10': 0.18227529855436833, 'Accuracy@20': 0.17771841609050912, 'Accuracy@30': 0.17314058244290803, 'Accuracy@40': 0.1695945945945946, 'Accuracy@50': 0.1668887492143306}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:44<00:00, 206.09it/s]


train 0.004769624077725196


100%|██████████| 1728/1728 [00:01<00:00, 1043.87it/s]


valid 0.0049549855487254135


100%|██████████| 1591/1591 [00:07<00:00, 199.86it/s]


8 {'Accuracy@10': 0.17636706473915775, 'Accuracy@20': 0.16923318667504714, 'Accuracy@30': 0.1643620364550597, 'Accuracy@40': 0.1612350722815839, 'Accuracy@50': 0.15844123192960402}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [00:44<00:00, 205.68it/s]


train 0.004728403922649465


100%|██████████| 1728/1728 [00:01<00:00, 1045.04it/s]


valid 0.00490222619827735


100%|██████████| 1591/1591 [00:07<00:00, 200.17it/s]


9 {'Accuracy@10': 0.16857322438717787, 'Accuracy@20': 0.16046511627906976, 'Accuracy@30': 0.15675675675675677, 'Accuracy@40': 0.1535512256442489, 'Accuracy@50': 0.1501068510370836}
saved model to cache/model/movielens/doc2vec.pt
ModelConfig(d_model=64, init_embedding_std=0.2, max_embedding_norm=1, window_size=5, negative_sample_size=5, lr=5e-05, weight_decay=1e-08, use_weight_tying=True, use_meta=True, use_attention=True)
TrainerConfig(model_name='doc2vec', dataset_name='movielens', epochs=10, batch_size=64, verbose=False, ignore_saved_model=True, load_model=False, save_model=True, load_dataset=False, save_dataset=False, model_dir='cache/model/', dataset_dir='cache/dataset/', device='cpu')
dataset_manager does not exist at: cache/dataset/movielens.pickle, create dataset
num_seq: 6040, num_item: 3706, num_item_meta: 28, num_seq_meta: 30, num_item_meta_types: 3, num_seq_meta_types: 3
to_sequential_data start
to_sequential_data end
train: 589192 valid: 110547


100%|██████████| 9207/9207 [01:04<00:00, 142.89it/s]


train 0.006462391784688371


100%|██████████| 1728/1728 [00:02<00:00, 857.98it/s]


valid 0.006440928828261338


100%|██████████| 1591/1591 [00:08<00:00, 194.00it/s]


0 {'Accuracy@10': 0.06844751728472659, 'Accuracy@20': 0.0758642363293526, 'Accuracy@30': 0.07578043159438508, 'Accuracy@40': 0.0757071024512885, 'Accuracy@50': 0.07516027655562539}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:05<00:00, 140.09it/s]


train 0.006169865508208261


100%|██████████| 1728/1728 [00:02<00:00, 843.52it/s]


valid 0.006119930132098519


100%|██████████| 1591/1591 [00:08<00:00, 195.18it/s]


1 {'Accuracy@10': 0.1896291640477687, 'Accuracy@20': 0.18799497171590196, 'Accuracy@30': 0.18342761366017182, 'Accuracy@40': 0.17946260213702075, 'Accuracy@50': 0.17585166561910748}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:06<00:00, 138.24it/s]


train 0.005610314865149594


100%|██████████| 1728/1728 [00:02<00:00, 859.65it/s]


valid 0.005673075398639625


100%|██████████| 1591/1591 [00:07<00:00, 198.97it/s]


2 {'Accuracy@10': 0.21125078566939032, 'Accuracy@20': 0.2071024512884978, 'Accuracy@30': 0.2027027027027027, 'Accuracy@40': 0.19622878692646134, 'Accuracy@50': 0.19182903834066625}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:11<00:00, 128.20it/s]


train 0.005184951065867035


100%|██████████| 1728/1728 [00:02<00:00, 842.37it/s]


valid 0.005343797864482307


100%|██████████| 1591/1591 [00:08<00:00, 191.78it/s]


3 {'Accuracy@10': 0.20553111250785666, 'Accuracy@20': 0.20034569453174106, 'Accuracy@30': 0.19624973811020321, 'Accuracy@40': 0.19186046511627908, 'Accuracy@50': 0.18765556253928348}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:09<00:00, 132.27it/s]


train 0.004948209680372468


100%|██████████| 1728/1728 [00:02<00:00, 819.31it/s]


valid 0.005134582277639472


100%|██████████| 1591/1591 [00:08<00:00, 195.57it/s]


4 {'Accuracy@10': 0.19088623507228158, 'Accuracy@20': 0.1865179132620993, 'Accuracy@30': 0.18135344646972554, 'Accuracy@40': 0.17807982401005656, 'Accuracy@50': 0.17456945317410433}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:12<00:00, 126.67it/s]


train 0.004819949637358842


100%|██████████| 1728/1728 [00:02<00:00, 860.34it/s]


valid 0.005002682506873845


100%|██████████| 1591/1591 [00:08<00:00, 195.90it/s]


5 {'Accuracy@10': 0.1791326209930861, 'Accuracy@20': 0.17162162162162162, 'Accuracy@30': 0.16748376283260005, 'Accuracy@40': 0.16356065367693276, 'Accuracy@50': 0.16075424261470772}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:07<00:00, 135.67it/s]


train 0.004745135977791435


100%|██████████| 1728/1728 [00:02<00:00, 861.96it/s]


valid 0.004915661255039957


100%|██████████| 1591/1591 [00:08<00:00, 194.81it/s]


6 {'Accuracy@10': 0.16939032055311126, 'Accuracy@20': 0.16040226272784414, 'Accuracy@30': 0.1545568824638592, 'Accuracy@40': 0.15158705216844753, 'Accuracy@50': 0.14913890634820867}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:09<00:00, 133.26it/s]


train 0.004698908680539521


100%|██████████| 1728/1728 [00:02<00:00, 844.37it/s]


valid 0.004856131985044907


100%|██████████| 1591/1591 [00:08<00:00, 195.84it/s]


7 {'Accuracy@10': 0.15719673161533626, 'Accuracy@20': 0.15012570710245127, 'Accuracy@30': 0.1446260213702074, 'Accuracy@40': 0.14156191074795726, 'Accuracy@50': 0.13885606536769327}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:06<00:00, 138.11it/s]


train 0.00466903728774145


100%|██████████| 1728/1728 [00:02<00:00, 807.40it/s]


valid 0.004814594496087705


100%|██████████| 1591/1591 [00:08<00:00, 194.86it/s]


8 {'Accuracy@10': 0.14776869893148964, 'Accuracy@20': 0.141169076052797, 'Accuracy@30': 0.1357636706473916, 'Accuracy@40': 0.13328095537397863, 'Accuracy@50': 0.13057196731615336}
saved model to cache/model/movielens/doc2vec.pt


100%|██████████| 9207/9207 [01:08<00:00, 135.26it/s]


train 0.004648101541862188


100%|██████████| 1728/1728 [00:02<00:00, 845.04it/s]


valid 0.0047853116543910175


100%|██████████| 1591/1591 [00:08<00:00, 193.84it/s]

9 {'Accuracy@10': 0.13708359522313013, 'Accuracy@20': 0.13199245757385292, 'Accuracy@30': 0.12880787764508694, 'Accuracy@40': 0.12575424261470775, 'Accuracy@50': 0.12359522313010686}
saved model to cache/model/movielens/doc2vec.pt





In [7]:
# Summarise each run as its best test score at every cutoff k, over all epochs.
# NOTE(review): the max is taken independently per k and on the test split. A
# row may therefore combine different epochs, and the scores are optimistically
# biased (epoch selection on test data). Consider picking one epoch per run
# (e.g. by validation loss) and reporting every cutoff from that epoch.
top_k = list(range(10, 51, 10))  # must match the cutoffs passed to calc_accuracy


def best_accuracy_per_k(run_result: dict, ks: list[int]) -> list[float]:
    """Return the best recorded ``Accuracy@k`` for each k in ``ks``.

    A run with no completed epochs yields NaN instead of raising. This happens
    when training was interrupted after the run's entry in ``results`` was
    created.
    """
    epoch_metrics = run_result["accuracy"]
    return [
        max((metrics[f"Accuracy@{k}"] for metrics in epoch_metrics), default=float("nan"))
        for k in ks
    ]


data = {method: best_accuracy_per_k(result, top_k) for method, result in results.items()}

In [10]:
# One row per run, one column per cutoff k. Built in a single step instead of
# transposing and then overwriting the column labels in place. This also keeps
# an empty `data` from raising a length-mismatch error on the column assignment.
df = pd.DataFrame.from_dict(
    data,
    orient="index",
    columns=[f"Accuracy@{k}" for k in top_k],
)
df

Unnamed: 0,Accuracy@10,Accuracy@20,Accuracy@30,Accuracy@40,Accuracy@50
User2Vec (d=16),0.186298,0.184381,0.180264,0.178158,0.175336
User2Vec (d=32),0.209805,0.204211,0.198617,0.192489,0.187404
User2Vec (d=64),0.211251,0.207102,0.202703,0.196229,0.191829
AU2V (d=16),0.215713,0.209585,0.20616,0.200629,0.19506
AU2V (d=32),0.226776,0.217788,0.209428,0.202263,0.195789
AU2V (d=64),0.229038,0.22115,0.212717,0.204824,0.198366
AU2V (wo weight-tying),0.226776,0.221307,0.210455,0.202923,0.197285
AU2V (wo attention),0.229101,0.221999,0.211544,0.202216,0.19726
AU2V (wo meta),0.22797,0.219233,0.209638,0.202043,0.195223
