In [None]:
cd drive/MyDrive/JAIST/Research/RL/movie-lens-simulator

In [None]:
# Run the project's Makefile targets `style` and `setup` through the shell, in
# that order. What each target does is defined in the Makefile, not here.
# NOTE(review): presumably `setup` installs dependencies and `style` runs a
# formatter/linter; if `style` needs tools installed by `setup`, the order may
# need to be swapped -- confirm against the Makefile.
!make style
!make setup

In [None]:
from src.data.movielens1M import LoadMovielens1M, PreprocessMovielens1M
from src.dataloader.movielens import MovieLensDataset
from src.dataloader.dataloader import get_MovieLensDataloaders
from src.model.embedding import Simulator, InitialRecommender
from src.trainer.rl_trainer import SAC_trainer
from src.trainer.embedding_trainer import Embedding_Trainer
from src.recommend.recommend_handler import RecommendHandler
from src.evaluation.precision_oriented import PrecisionOrientedMetrics
from src.evaluation.diversity_oriented import DiversityOrientedMetrics

from configs.config import CFG_DICT
from src.utils.file_handler import File_Handler

from src.agent.actor_critic import ActorCriticAgent
from src.agent.soft_actor_critic import SAC


import numpy as np
import torch
import pandas as pd

# Directory name handed to File_Handler for this experiment run. The handler is
# reused later in the notebook to save model weights and the metrics table.
experiment_log_dir = 'SAC/exp3'
file_handler = File_Handler(log_dir=experiment_log_dir)

# Logger obtained from the handler; used for the per-epoch progress/metric lines.
logger = file_handler.make_logger_file('src')

In [None]:
# --- Load the raw MovieLens-1M files from the configured dataset path ---
dataset_path = CFG_DICT['DATASET']['PATH']
load_rawML1M = LoadMovielens1M(path=dataset_path)
raw_datas = load_rawML1M.load_data()

# --- Preprocess the raw data ---
# The result is unpacked below into four parts: ratings, users, movies, histories.
preprocessML1M = PreprocessMovielens1M(raw_datas)
preprocess_datas = preprocessML1M.apply_prepocess()
ratings, users, movies, histories = preprocess_datas

In [None]:
# Simulation inputs. For a quick smoke test, swap in a subset of users, e.g.:
#   preprocess_datas = ratings[ratings.UserID < 100], users, movies, histories
preprocess_datas = ratings, users, movies, histories

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine (or Colab runtime) without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
top_k = CFG_DICT['SIMULATION']['topK']

recommend_handler = RecommendHandler(
    initial_data=preprocess_datas,
    topK=top_k,
    device=device,
    file_handler=file_handler,
)

# Pre-trained user simulator weights.
user_w = './log/train_usr/random_split/user_simulator.pkl'
user_simulator = Simulator(emb_size=CFG_DICT['USER_SIMULATOR']['DIM_EMB']).to(device)
# map_location lets a checkpoint that was saved on a GPU be restored on whatever
# device was selected above.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
# NOTE(review): the simulator is left in its default (training) mode here; confirm
# whether RecommendHandler switches it to eval() before using it for inference.
user_simulator.load_state_dict(torch.load(user_w, map_location=device))

initial_recommender = InitialRecommender(emb_size=CFG_DICT['INITIAL_RECOMMENDER']['DIM_EMB']).to(device)

agent = SAC(device=device)

# One dict of metric values per simulated time step. DataFrame.append was removed
# in pandas 2.0, so the rows are collected in a list and the frame is rebuilt
# from it each time it is saved.
metric_records = []
log_metrics = pd.DataFrame()


def _compute_and_log(named_metrics):
    """Evaluate metrics in order, logging each value as soon as it is computed.

    Args:
        named_metrics: iterable of (label, zero-argument callable) pairs.

    Returns:
        dict mapping each label to the value returned by its callable.
    """
    values = {}
    for label, compute in named_metrics:
        values[label] = compute()
        logger.info("{} : {}".format(label, values[label]))
    return values


for time in np.arange(CFG_DICT['SIMULATION']['EPOCH']):

    ratings = preprocess_datas[0]

    logger.info('-'*50)
    logger.info('time : {}'.format(time))
    logger.info('data len : {}'.format(len(ratings)))

    ### recommend
    if time == 0:
        # First step: train the embedding-based initial recommender on the
        # original data and use it to produce the first recommendations.
        initial_recommender_datas = get_MovieLensDataloaders(
            model_type='INITIAL_RECOMMENDER',
            processed_data=preprocess_datas,
            data_split='random',
        )

        initial_recommender_trainer = Embedding_Trainer(
            model_type='INITIAL_RECOMMENDER',
            model=initial_recommender,
            loaders=initial_recommender_datas,
            device=device,
        )

        initial_recommender = initial_recommender_trainer.train()
        file_handler.save(initial_recommender.state_dict(), 'initial_recommender_{}.pkl'.format(time))

        recommend_handler.register_models((initial_recommender, user_simulator))
    else:
        # Later steps: train the SAC agent via SAC_trainer, then let the agent
        # act as the recommender.
        agent, recommend_handler = SAC_trainer(
            recommend_handler=recommend_handler,
            agent=agent,
            device=device,
            bs=CFG_DICT['REPLAY_BUFFER']['BS'],
            epoch=CFG_DICT['RL']['EPOCH'],
        )

        recommend_handler.register_models((agent, user_simulator))

    recommend_handler.recommend(time=time)

    # evaluation
    pom = PrecisionOrientedMetrics(recommend_handler.pred_ratings, topK=top_k)
    step_metrics = _compute_and_log([
        ('MAP', pom.MAP),
        ('MEAN_PRECISION', pom.mean_Precision),
        ('MRR', pom.mrr),
    ])

    dom = DiversityOrientedMetrics(recommend_handler.pred_ratings, recommend_handler.ratings, topK=top_k)
    step_metrics.update(_compute_and_log([
        ('DIVERSITY', dom.diversity_score),
        ('SERENDIPITY', dom.serendipity_score),
        ('NOVELTY', dom.novelty_score),
        ('UNIQUENESS', dom.uniqueness_score),
    ]))

    # Persist the metrics after every step so partial results survive an
    # interrupted run.
    metric_records.append(step_metrics)
    log_metrics = pd.DataFrame(metric_records)
    file_handler.save(log_metrics, 'log_metrics.csv')

    # update ratings
    recommend_handler.update_ratings()

    # The next step consumes the handler's updated ratings and history.
    preprocess_datas = (recommend_handler.ratings, users, movies, recommend_handler.history)