In [1]:
import sys 
sys.path.append('..')

#Dependencies

import os
from tqdm import tqdm
import pandas as pd
import numpy as np

import json 
import pickle

from src.environment.ml_env import SimulatedEnv, SimulatedFairEnv
from src.model.recommender import DRRAgent, FairRecAgent

In [2]:
# The JSON manifest maps logical dataset names to pickle paths that are
# resolved relative to the repository root (one level above this notebook).
dataset_path = "../data/movie_lens_100k_output_path.json"
with open(dataset_path) as json_file:
    _dataset_path = json.load(json_file)

# Load every pickle the evaluation needs into a single dict keyed by the
# manifest entry name.
# NOTE(review): pickle.load can execute arbitrary code; only point the
# manifest at files you trust.
dataset = {}
for _key in (
    "eval_users_dict",
    "eval_users_dict_positive_items",
    "eval_users_history_lens",
    "users_history_lens",
    "movies_groups",
):
    with open(os.path.join("..", _dataset_path[_key]), "rb") as pkl_file:
        dataset[_key] = pickle.load(pkl_file)

In [3]:
# Dispatch tables: the algorithm name selects the environment class and the
# agent class used by the evaluation loop.
ENV = {
    "drr": SimulatedEnv,
    "fairrec": SimulatedFairEnv,
}
AGENT = {
    "drr": DRRAgent,
    "fairrec": FairRecAgent,
}

In [4]:
# Identifiers of the training runs whose checkpoints can be evaluated; the
# configuration cell picks one of them by index.
train_ids = [
    # long training
    "movie_lens_100k_2021-10-24_01-42-57",
    # long training
    "movie_lens_100k_fair_2021-10-24_01-41-02",
]

In [5]:
# --- Which trained run to evaluate ------------------------------------------
# `algorithm` is the key into the ENV / AGENT dispatch tables; `train_id`
# names the run directory that holds the checkpoints.
algorithm = "drr"
train_version = "movie_lens_100k"
train_id = train_ids[0]
output_path = "../model/{}/{}".format(train_version, train_id)

# --- Dataset dimensions (passed to the agent constructor) -------------------
users_num, items_num = 943, 1682

# --- State sizes (state_size is passed to both environment and agent) -------
state_size, srm_size = 5, 3

# --- Agent hyper-parameters (forwarded unchanged to the agent constructor) --
embedding_dim = 50
actor_hidden_dim = critic_hidden_dim = 512
actor_learning_rate = 1e-4
critic_learning_rate = 1e-3
discount_factor = 0.9
tau = 0.01
learning_starts = 100
replay_memory_size = 10 ** 6
batch_size = 1
emb_model = "user_movie"
embedding_network_weights = "../model/pmf/emb_50_ratio_0.800000_bs_1000_e_258_wd_0.100000_lr_0.000100_trained_pmf.pt"

# --- Fairness setup: one constraint weight per item group -------------------
n_groups = 10
fairness_constraints = [1.0] * n_groups


# --- Evaluation options -----------------------------------------------------
# top_k is handed to recommender.train(); done_count goes to the environment.
top_k = None
done_count = 10

In [6]:
def _latest_checkpoint(model_dir, prefix):
    """Return the highest checkpoint step stored in ``model_dir`` for ``prefix``.

    A checkpoint file is any directory entry whose name starts with ``prefix``
    (e.g. ``"actor_"``); its step is the text between the first ``"_"`` and the
    following ``"."`` (``actor_30000.h5`` -> ``30000``).

    Args:
        model_dir: Directory that contains the saved checkpoint files.
        prefix: File-name prefix identifying the network, including the
            trailing underscore.

    Returns:
        The largest step number found, as an ``int``.

    Raises:
        FileNotFoundError: If ``model_dir`` holds no file matching ``prefix``
            with a numeric step (or if the directory itself does not exist).
    """
    steps = []
    for file_name in os.listdir(model_dir):
        if not file_name.startswith(prefix):
            continue
        step_text = file_name.split("_")[1].split(".")[0]
        try:
            steps.append(int(step_text))
        except ValueError:
            # Ignore files such as "actor_best.h5" instead of aborting the scan.
            continue
    if not steps:
        raise FileNotFoundError(
            "no '{}<step>' checkpoint found in {}".format(prefix, model_dir)
        )
    return max(steps)


# Pick the most recent actor / critic checkpoints of the selected run.
actor_checkpoint = _latest_checkpoint(output_path, "actor_")
critic_checkpoint = _latest_checkpoint(output_path, "critic_")

print(actor_checkpoint, critic_checkpoint)

30000 30000


In [7]:
# # with open(os.path.join("../model/movie_lens_100k/movie_lens_100k_2021-10-26_11-45-48/buffer.pkl"), "rb") as pkl_file:
# #     buffer_load = pickle.load(pkl_file)

# with open(os.path.join("../model/movie_lens_100k_fair/movie_lens_100k_fair_2021-10-26_11-45-51/buffer.pkl"), "rb") as pkl_file:
#     buffer_load = pickle.load(pkl_file)

In [8]:
# Evaluate the selected checkpoint over several independent runs and collect
# one value per run for each metric.
_precision = []
_propfair = []
_reward = []
_ufg = []

n_runs = 10

# Denominator of the per-run means.
# NOTE(review): the inner loop iterates over env.available_users while the
# means divide by the size of eval_users_dict; confirm both contain the same
# number of users, otherwise the means are biased.
n_eval_users = len(dataset["eval_users_dict"])

# Checkpoint files do not change between users or runs, so build the paths once.
actor_weights_path = os.path.join(
    output_path, "actor_{}.h5".format(actor_checkpoint)
)
critic_weights_path = os.path.join(
    output_path, "critic_{}.h5".format(critic_checkpoint)
)

for i in range(n_runs):
    sum_precision = 0
    sum_propfair = 0
    sum_reward = 0

    # This environment is only used to obtain the list of users to evaluate
    # (and to construct the agent); each user gets a dedicated one below.
    env = ENV[algorithm](
        users_dict=dataset["eval_users_dict"],
        users_history_lens=dataset["eval_users_history_lens"],
        n_groups=n_groups,
        movies_groups=dataset["movies_groups"],
        state_size=state_size,
        done_count=done_count,
        fairness_constraints=fairness_constraints,
    )
    available_users = env.available_users

    recommender = AGENT[algorithm](
        env=env,
        users_num=users_num,
        items_num=items_num,
        genres_num=0,
        movies_genres_id={},
        srm_size=srm_size,
        state_size=state_size,
        train_version=train_version,
        is_test=True,
        embedding_dim=embedding_dim,
        actor_hidden_dim=actor_hidden_dim,
        actor_learning_rate=actor_learning_rate,
        critic_hidden_dim=critic_hidden_dim,
        critic_learning_rate=critic_learning_rate,
        discount_factor=discount_factor,
        tau=tau,
        learning_starts=learning_starts,
        replay_memory_size=replay_memory_size,
        batch_size=batch_size,
        model_path=output_path,
        emb_model=emb_model,
        embedding_network_weights_path=embedding_network_weights,
        n_groups=n_groups,
        fairness_constraints=fairness_constraints,
    )

    for user_id in tqdm(available_users):
        # Weights are reloaded for every user, exactly as before.
        # NOTE(review): presumably this resets any state changed by train();
        # if train() leaves the weights untouched in test mode this call could
        # be hoisted out of the loop -- confirm before changing.
        recommender.load_model(actor_weights_path, critic_weights_path)
        # recommender.buffer = buffer_load

        # Environment pinned to the current user and sharing the agent's
        # reward model and device.
        eval_env = ENV[algorithm](
            users_dict=dataset["eval_users_dict"],
            users_history_lens=dataset["eval_users_history_lens"],
            n_groups=n_groups,
            movies_groups=dataset["movies_groups"],
            state_size=state_size,
            done_count=done_count,
            fairness_constraints=fairness_constraints,
            fix_user_id=user_id,
            reward_model=recommender.reward_model,
            device=recommender.device,
        )

        recommender.env = eval_env

        # With is_test=True a single "training" episode is used as the
        # evaluation pass for this user.
        precision, ndcg, propfair, reward = recommender.train(
            max_episode_num=1, top_k=top_k
        )

        sum_precision += precision
        sum_propfair += propfair
        sum_reward += reward

        del eval_env

    # Compute each per-run metric once and reuse it for storage and printing.
    mean_precision = sum_precision / n_eval_users
    mean_propfair = sum_propfair / n_eval_users
    # ufg = propfair / (1 - precision); the ratio diverges when precision is
    # exactly 1, so report infinity instead of raising ZeroDivisionError.
    if mean_precision == 1:
        ufg = float("inf")
    else:
        ufg = mean_propfair / (1 - mean_precision)

    _precision.append(mean_precision)
    _propfair.append(mean_propfair)
    _ufg.append(ufg)
    _reward.append(sum_reward)

    print("---------- Evaluation")
    print("- sum reward: ", round(sum_reward, 2))
    print("- precision@: ", round(mean_precision, 4))
    print("- propfair: ", round(mean_propfair, 4))
    print("- ufg: ", round(ufg, 4))
    print()


# Averages over all runs.
print(round(np.mean(_propfair), 4))
print(round(np.mean(_precision), 4))
print(round(np.mean(_reward), 4))
print(round(np.mean(_ufg), 4))

100%|██████████| 1/1 [00:00<00:00,  8.62it/s]
100%|██████████| 1/1 [00:00<00:00, 46.69it/s]
100%|██████████| 1/1 [00:00<00:00, 29.09it/s]
100%|██████████| 1/1 [00:00<00:00, 36.20it/s]
100%|██████████| 1/1 [00:00<00:00, 29.86it/s]
100%|██████████| 1/1 [00:00<00:00, 34.33it/s]
100%|██████████| 1/1 [00:00<00:00, 62.33it/s]
100%|██████████| 1/1 [00:00<00:00, 29.32it/s]
100%|██████████| 1/1 [00:00<00:00, 48.02it/s]
100%|██████████| 1/1 [00:00<00:00, 10.52it/s]
100%|██████████| 1/1 [00:00<00:00,  9.81it/s]
100%|██████████| 1/1 [00:00<00:00, 10.88it/s]
100%|██████████| 1/1 [00:00<00:00,  9.46it/s]
100%|██████████| 1/1 [00:00<00:00, 20.46it/s]
100%|██████████| 1/1 [00:00<00:00,  9.41it/s]
100%|██████████| 1/1 [00:00<00:00, 10.02it/s]
100%|██████████| 1/1 [00:00<00:00, 10.06it/s]
100%|██████████| 1/1 [00:00<00:00, 10.89it/s]
100%|██████████| 1/1 [00:00<00:00,  9.47it/s]
100%|██████████| 1/1 [00:00<00:00, 10.87it/s]
100%|██████████| 1/1 [00:00<00:00,  8.09it/s]
100%|██████████| 1/1 [00:00<00:00,

---------- Evaluation
- sum reward:  480.1
- precision@:  0.8161
- propfair:  0.8535
- ufg:  4.6397



100%|██████████| 1/1 [00:00<00:00, 45.82it/s]
100%|██████████| 1/1 [00:00<00:00, 42.83it/s]
100%|██████████| 1/1 [00:00<00:00, 31.19it/s]
100%|██████████| 1/1 [00:00<00:00, 33.60it/s]
100%|██████████| 1/1 [00:00<00:00, 40.35it/s]
100%|██████████| 1/1 [00:00<00:00, 37.54it/s]
100%|██████████| 1/1 [00:00<00:00, 80.61it/s]
100%|██████████| 1/1 [00:00<00:00, 45.04it/s]
100%|██████████| 1/1 [00:00<00:00, 36.55it/s]
100%|██████████| 1/1 [00:00<00:00, 18.68it/s]
100%|██████████| 1/1 [00:00<00:00, 10.25it/s]
100%|██████████| 1/1 [00:00<00:00, 11.98it/s]
100%|██████████| 1/1 [00:00<00:00, 10.46it/s]
100%|██████████| 1/1 [00:00<00:00, 21.63it/s]
100%|██████████| 1/1 [00:00<00:00, 10.77it/s]
100%|██████████| 1/1 [00:00<00:00, 10.22it/s]
100%|██████████| 1/1 [00:00<00:00, 11.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.39it/s]
100%|██████████| 1/1 [00:00<00:00,  7.82it/s]
100%|██████████| 1/1 [00:00<00:00,  8.56it/s]
100%|██████████| 1/1 [00:00<00:00, 10.70it/s]
100%|██████████| 1/1 [00:00<00:00,

---------- Evaluation
- sum reward:  492.33
- precision@:  0.8124
- propfair:  0.8598
- ufg:  4.5836



100%|██████████| 1/1 [00:00<00:00, 42.85it/s]
100%|██████████| 1/1 [00:00<00:00, 40.13it/s]
100%|██████████| 1/1 [00:00<00:00, 33.79it/s]
100%|██████████| 1/1 [00:00<00:00, 35.71it/s]
100%|██████████| 1/1 [00:00<00:00, 38.24it/s]
100%|██████████| 1/1 [00:00<00:00, 50.78it/s]
100%|██████████| 1/1 [00:00<00:00, 81.34it/s]
100%|██████████| 1/1 [00:00<00:00, 48.03it/s]
100%|██████████| 1/1 [00:00<00:00, 38.87it/s]
100%|██████████| 1/1 [00:00<00:00, 19.51it/s]
100%|██████████| 1/1 [00:00<00:00, 10.44it/s]
100%|██████████| 1/1 [00:00<00:00, 11.17it/s]
100%|██████████| 1/1 [00:00<00:00, 10.36it/s]
100%|██████████| 1/1 [00:00<00:00, 19.00it/s]
100%|██████████| 1/1 [00:00<00:00, 11.84it/s]
100%|██████████| 1/1 [00:00<00:00, 10.29it/s]
100%|██████████| 1/1 [00:00<00:00, 10.52it/s]
100%|██████████| 1/1 [00:00<00:00, 11.83it/s]
100%|██████████| 1/1 [00:00<00:00, 10.52it/s]
100%|██████████| 1/1 [00:00<00:00, 10.56it/s]
100%|██████████| 1/1 [00:00<00:00, 10.49it/s]
100%|██████████| 1/1 [00:00<00:00,

---------- Evaluation
- sum reward:  465.56
- precision@:  0.7974
- propfair:  0.855
- ufg:  4.2199



100%|██████████| 1/1 [00:00<00:00, 38.03it/s]
100%|██████████| 1/1 [00:00<00:00, 41.18it/s]
100%|██████████| 1/1 [00:00<00:00, 36.07it/s]
100%|██████████| 1/1 [00:00<00:00, 36.55it/s]
100%|██████████| 1/1 [00:00<00:00, 39.62it/s]
100%|██████████| 1/1 [00:00<00:00, 38.89it/s]
100%|██████████| 1/1 [00:00<00:00, 109.43it/s]
100%|██████████| 1/1 [00:00<00:00, 46.11it/s]
100%|██████████| 1/1 [00:00<00:00, 47.68it/s]
100%|██████████| 1/1 [00:00<00:00, 17.84it/s]
100%|██████████| 1/1 [00:00<00:00,  8.09it/s]
100%|██████████| 1/1 [00:00<00:00, 11.55it/s]
100%|██████████| 1/1 [00:00<00:00, 10.28it/s]
100%|██████████| 1/1 [00:00<00:00, 20.29it/s]
100%|██████████| 1/1 [00:00<00:00, 11.38it/s]
100%|██████████| 1/1 [00:00<00:00, 10.52it/s]
100%|██████████| 1/1 [00:00<00:00, 10.65it/s]
100%|██████████| 1/1 [00:00<00:00, 10.02it/s]
100%|██████████| 1/1 [00:00<00:00, 10.25it/s]
100%|██████████| 1/1 [00:00<00:00, 10.24it/s]
100%|██████████| 1/1 [00:00<00:00,  8.43it/s]
100%|██████████| 1/1 [00:00<00:00

---------- Evaluation
- sum reward:  470.71
- precision@:  0.8142
- propfair:  0.8497
- ufg:  4.5724



100%|██████████| 1/1 [00:00<00:00, 49.21it/s]
100%|██████████| 1/1 [00:00<00:00, 48.03it/s]
100%|██████████| 1/1 [00:00<00:00, 43.98it/s]
100%|██████████| 1/1 [00:00<00:00, 43.96it/s]
100%|██████████| 1/1 [00:00<00:00, 41.26it/s]
100%|██████████| 1/1 [00:00<00:00, 32.97it/s]
100%|██████████| 1/1 [00:00<00:00, 72.24it/s]
100%|██████████| 1/1 [00:00<00:00, 34.08it/s]
100%|██████████| 1/1 [00:00<00:00, 34.04it/s]
100%|██████████| 1/1 [00:00<00:00, 18.30it/s]
100%|██████████| 1/1 [00:00<00:00, 10.58it/s]
100%|██████████| 1/1 [00:00<00:00, 10.46it/s]
100%|██████████| 1/1 [00:00<00:00, 11.14it/s]
100%|██████████| 1/1 [00:00<00:00, 18.90it/s]
100%|██████████| 1/1 [00:00<00:00,  9.70it/s]
100%|██████████| 1/1 [00:00<00:00,  9.77it/s]
100%|██████████| 1/1 [00:00<00:00, 10.27it/s]
100%|██████████| 1/1 [00:00<00:00, 11.12it/s]
100%|██████████| 1/1 [00:00<00:00,  9.05it/s]
100%|██████████| 1/1 [00:00<00:00, 10.06it/s]
100%|██████████| 1/1 [00:00<00:00, 10.54it/s]
100%|██████████| 1/1 [00:00<00:00,

---------- Evaluation
- sum reward:  486.26
- precision@:  0.8096
- propfair:  0.8573
- ufg:  4.5036



100%|██████████| 1/1 [00:00<00:00, 49.65it/s]
100%|██████████| 1/1 [00:00<00:00, 51.55it/s]
100%|██████████| 1/1 [00:00<00:00, 45.12it/s]
100%|██████████| 1/1 [00:00<00:00, 45.11it/s]
100%|██████████| 1/1 [00:00<00:00, 44.09it/s]
100%|██████████| 1/1 [00:00<00:00, 39.66it/s]
100%|██████████| 1/1 [00:00<00:00, 107.33it/s]
100%|██████████| 1/1 [00:00<00:00, 36.33it/s]
100%|██████████| 1/1 [00:00<00:00, 43.66it/s]
100%|██████████| 1/1 [00:00<00:00, 17.97it/s]
100%|██████████| 1/1 [00:00<00:00,  9.48it/s]
100%|██████████| 1/1 [00:00<00:00, 10.59it/s]
100%|██████████| 1/1 [00:00<00:00, 10.67it/s]
100%|██████████| 1/1 [00:00<00:00, 22.24it/s]
100%|██████████| 1/1 [00:00<00:00, 10.02it/s]
100%|██████████| 1/1 [00:00<00:00, 10.49it/s]
100%|██████████| 1/1 [00:00<00:00, 11.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.96it/s]
100%|██████████| 1/1 [00:00<00:00, 10.39it/s]
100%|██████████| 1/1 [00:00<00:00, 10.40it/s]
100%|██████████| 1/1 [00:00<00:00, 11.96it/s]
100%|██████████| 1/1 [00:00<00:00

---------- Evaluation
- sum reward:  477.16
- precision@:  0.8179
- propfair:  0.8552
- ufg:  4.6956



100%|██████████| 1/1 [00:00<00:00, 54.01it/s]
100%|██████████| 1/1 [00:00<00:00, 44.96it/s]
100%|██████████| 1/1 [00:00<00:00, 58.29it/s]
100%|██████████| 1/1 [00:00<00:00, 47.70it/s]
100%|██████████| 1/1 [00:00<00:00, 36.37it/s]
100%|██████████| 1/1 [00:00<00:00, 43.57it/s]
100%|██████████| 1/1 [00:00<00:00, 89.23it/s]
100%|██████████| 1/1 [00:00<00:00, 44.84it/s]
100%|██████████| 1/1 [00:00<00:00, 38.95it/s]
100%|██████████| 1/1 [00:00<00:00, 17.97it/s]
100%|██████████| 1/1 [00:00<00:00,  9.07it/s]
100%|██████████| 1/1 [00:00<00:00,  9.67it/s]
100%|██████████| 1/1 [00:00<00:00, 10.60it/s]
100%|██████████| 1/1 [00:00<00:00, 24.01it/s]
100%|██████████| 1/1 [00:00<00:00, 10.30it/s]
100%|██████████| 1/1 [00:00<00:00, 10.09it/s]
100%|██████████| 1/1 [00:00<00:00, 11.50it/s]
100%|██████████| 1/1 [00:00<00:00,  9.57it/s]
100%|██████████| 1/1 [00:00<00:00, 10.31it/s]
100%|██████████| 1/1 [00:00<00:00, 10.30it/s]
100%|██████████| 1/1 [00:00<00:00, 11.59it/s]
100%|██████████| 1/1 [00:00<00:00,

---------- Evaluation
- sum reward:  486.48
- precision@:  0.8174
- propfair:  0.8551
- ufg:  4.6817



100%|██████████| 1/1 [00:00<00:00, 39.50it/s]
100%|██████████| 1/1 [00:00<00:00, 40.26it/s]
100%|██████████| 1/1 [00:00<00:00, 37.48it/s]
100%|██████████| 1/1 [00:00<00:00, 35.88it/s]
100%|██████████| 1/1 [00:00<00:00, 34.60it/s]
100%|██████████| 1/1 [00:00<00:00, 43.60it/s]
100%|██████████| 1/1 [00:00<00:00, 80.53it/s]
100%|██████████| 1/1 [00:00<00:00, 40.49it/s]
100%|██████████| 1/1 [00:00<00:00, 40.23it/s]
100%|██████████| 1/1 [00:00<00:00, 17.91it/s]
100%|██████████| 1/1 [00:00<00:00,  8.59it/s]
100%|██████████| 1/1 [00:00<00:00, 11.24it/s]
100%|██████████| 1/1 [00:00<00:00, 10.57it/s]
100%|██████████| 1/1 [00:00<00:00, 20.61it/s]
100%|██████████| 1/1 [00:00<00:00, 11.50it/s]
100%|██████████| 1/1 [00:00<00:00, 10.20it/s]
100%|██████████| 1/1 [00:00<00:00, 10.35it/s]
100%|██████████| 1/1 [00:00<00:00, 10.37it/s]
100%|██████████| 1/1 [00:00<00:00, 11.47it/s]
100%|██████████| 1/1 [00:00<00:00, 10.29it/s]
100%|██████████| 1/1 [00:00<00:00, 10.15it/s]
100%|██████████| 1/1 [00:00<00:00,

---------- Evaluation
- sum reward:  505.26
- precision@:  0.8198
- propfair:  0.8604
- ufg:  4.7754



100%|██████████| 1/1 [00:00<00:00, 52.77it/s]
100%|██████████| 1/1 [00:00<00:00, 61.60it/s]
100%|██████████| 1/1 [00:00<00:00, 61.88it/s]
100%|██████████| 1/1 [00:00<00:00, 56.71it/s]
100%|██████████| 1/1 [00:00<00:00, 57.37it/s]
100%|██████████| 1/1 [00:00<00:00, 60.48it/s]
100%|██████████| 1/1 [00:00<00:00, 120.61it/s]
100%|██████████| 1/1 [00:00<00:00, 44.74it/s]
100%|██████████| 1/1 [00:00<00:00, 32.32it/s]
100%|██████████| 1/1 [00:00<00:00, 19.03it/s]
100%|██████████| 1/1 [00:00<00:00,  9.16it/s]
100%|██████████| 1/1 [00:00<00:00,  9.65it/s]
100%|██████████| 1/1 [00:00<00:00,  9.41it/s]
100%|██████████| 1/1 [00:00<00:00, 21.95it/s]
100%|██████████| 1/1 [00:00<00:00, 11.16it/s]
100%|██████████| 1/1 [00:00<00:00,  8.68it/s]
100%|██████████| 1/1 [00:00<00:00, 10.41it/s]
100%|██████████| 1/1 [00:00<00:00, 11.24it/s]
100%|██████████| 1/1 [00:00<00:00, 10.48it/s]
100%|██████████| 1/1 [00:00<00:00, 11.60it/s]
100%|██████████| 1/1 [00:00<00:00, 10.31it/s]
100%|██████████| 1/1 [00:00<00:00

---------- Evaluation
- sum reward:  468.2
- precision@:  0.8013
- propfair:  0.8493
- ufg:  4.2745



100%|██████████| 1/1 [00:00<00:00, 42.20it/s]
100%|██████████| 1/1 [00:00<00:00, 57.70it/s]
100%|██████████| 1/1 [00:00<00:00, 47.31it/s]
100%|██████████| 1/1 [00:00<00:00, 47.13it/s]
100%|██████████| 1/1 [00:00<00:00, 43.72it/s]
100%|██████████| 1/1 [00:00<00:00, 48.37it/s]
100%|██████████| 1/1 [00:00<00:00, 83.35it/s]
100%|██████████| 1/1 [00:00<00:00, 50.02it/s]
100%|██████████| 1/1 [00:00<00:00, 47.72it/s]
100%|██████████| 1/1 [00:00<00:00, 18.52it/s]
100%|██████████| 1/1 [00:00<00:00,  9.81it/s]
100%|██████████| 1/1 [00:00<00:00, 10.51it/s]
100%|██████████| 1/1 [00:00<00:00, 10.59it/s]
100%|██████████| 1/1 [00:00<00:00, 19.68it/s]
100%|██████████| 1/1 [00:00<00:00, 10.25it/s]
100%|██████████| 1/1 [00:00<00:00, 10.02it/s]
100%|██████████| 1/1 [00:00<00:00, 11.94it/s]
100%|██████████| 1/1 [00:00<00:00, 10.52it/s]
100%|██████████| 1/1 [00:00<00:00, 11.51it/s]
100%|██████████| 1/1 [00:00<00:00,  8.28it/s]
100%|██████████| 1/1 [00:00<00:00, 10.73it/s]
100%|██████████| 1/1 [00:00<00:00,

---------- Evaluation
- sum reward:  479.03
- precision@:  0.8103
- propfair:  0.8553
- ufg:  4.5076

0.855
0.8116
481.109
4.5454





In [None]:
# Averages over all evaluation runs, one per line in the order:
# propfair, precision, reward, ufg.
print(
    *(
        round(np.mean(per_run_values), 4)
        for per_run_values in (_propfair, _precision, _reward, _ufg)
    ),
    sep="\n",
)

In [9]:
# Display the per-run mean propfair values collected by the evaluation loop.
_propfair

[0.8534648878974421,
 0.8597976954985628,
 0.8550059617766729,
 0.8496948042935616,
 0.8573013600447137,
 0.8551669763910876,
 0.8550675940525352,
 0.8603760957875722,
 0.8493064771195712,
 0.8553115818598706]

In [None]:
# Display the per-run mean precision values collected by the evaluation loop.
_precision

In [None]:
# Display the per-run summed rewards collected by the evaluation loop.
_reward

In [None]:
# Display the per-run ufg values (propfair / (1 - precision)) collected by the evaluation loop.
_ufg

In [12]:
# Metrics of the run stored at index 3, one per line in the order:
# propfair, precision, reward, ufg.
print(
    *(
        round(per_run_values[3], 4)
        for per_run_values in (_propfair, _precision, _reward, _ufg)
    ),
    sep="\n",
)

0.8497
0.8142
470.7087
4.5724
