In [1]:
import sys
import collections
import torch
from dqn_agent_pytorch import DQNAgent
import numpy as np
import os
import random
import time
from copy import deepcopy
import logging
import torch.nn as nn
from env.hgnn import hgnn_env

def seed(random_seed):
    torch.manual_seed(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    print(random_seed)

def get_logger(logger_name, log_file, level=logging.INFO):
    l = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s : %(message)s', "%Y-%m-%d %H:%M:%S")
    fileHandler = logging.FileHandler(log_file, mode='a')
    fileHandler.setFormatter(formatter)
    l.setLevel(level)
    l.addHandler(fileHandler)
    return logging.getLogger(logger_name)

def use_pretrain(env, dataset='yelp_data'):
    if dataset == 'yelp_data':
        print('./data/yelp_data/embedding/user.embedding_' + str(env.data.entity_dim))
        fr1 = open('./data/yelp_data/embedding/user.embedding_' + str(env.data.entity_dim), 'r')
        fr2 = open('./data/yelp_data/embedding/item.embedding_' + str(env.data.entity_dim), 'r')
    elif dataset == 'douban_movie':
        print('./data/douban_movie/embedding/user.embedding_' + str(env.data.entity_dim))
        fr1 = open('./data/douban_movie/embedding/user.embedding_' + str(env.data.entity_dim), 'r')
        fr2 = open('./data/douban_movie/embedding/item.embedding_' + str(env.data.entity_dim), 'r')
    else:
        print('./data/' + dataset + '/embedding/user.embedding_' + str(env.data.entity_dim))
        fr1 = open('./data/' + dataset + '/embedding/user.embedding_' + str(env.data.entity_dim), 'r')
        fr2 = open('./data/' + dataset + '/embedding/item.embedding_' + str(env.data.entity_dim), 'r')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    emb = env.train_data.x
    emb.requires_grad = False

    for line in fr1.readlines():
        embeddings = line.strip().split()
        id, embedding = int(embeddings[0]), embeddings[1:]
        embedding = list(map(float, embedding))
        emb[id] = torch.tensor(embedding)

    for line in fr2.readlines():
        embeddings = line.strip().split()
        id, embedding = int(embeddings[0]), embeddings[1:]
        embedding = list(map(float, embedding))
        emb[id] = torch.tensor(embedding)

    env.train_data.x = emb.to(device)
    

In [2]:
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Run HGNN.")

    parser.add_argument('--local_rank', type=int, default=0,
                        help='Local rank for using multi GPUs.')

    parser.add_argument('--seed', type=int, default=123,
                        help='Random seed.')

    parser.add_argument('--task', nargs='?', default='rec',
                        help='Choose a task from {rec, herec, mcrec, classification}')

    parser.add_argument('--data_name', nargs='?', default='TCL',
                        help='Choose a dataset from {yelp_data, douban_movie, TCL}')
    parser.add_argument('--data_dir', nargs='?', default='data/',
                        help='Input data path.')

    parser.add_argument('--use_pretrain', type=int, default=0,
                        help='0: No pretrain, 1: Pretrain with the learned embeddings, 2: Pretrain with stored model.')
    parser.add_argument('--pretrain_embedding_dir', nargs='?', default='datasets/pretrain/',
                        help='Path of learned embeddings.')
    parser.add_argument('--pretrain_model_path', nargs='?', default='trained_model/model.pth',
                        help='Path of stored model.')

    parser.add_argument('--cf_batch_size', type=int, default=90000,
                        help='CF batch size.')
    parser.add_argument('--kg_batch_size', type=int, default=10000,
                        help='KG batch size.')
    parser.add_argument('--nd_batch_size', type=int, default=5000,
                        help='node sampling batch size.')
    parser.add_argument('--rl_batch_size', type=int, default=1,
                        help='RL training batch size.')
    parser.add_argument('--train_batch_size', type=int, default=2000,
                        help='Eval batch size (the user number to test every batch).')
    parser.add_argument('--test_batch_size', type=int, default=20000,
                        help='Test batch size (the user number to test every batch).')

    parser.add_argument('--entity_dim', type=int, default=64,
                        help='User / entity Embedding size.')
    parser.add_argument('--relation_dim', type=int, default=32,
                        help='Relation Embedding size.')

    parser.add_argument('--aggregation_type', nargs='?', default='bi-interaction',
                        help='Specify the type of the aggregation layer from {gcn, graphsage, bi-interaction}.')
    parser.add_argument('--conv_dim_list', nargs='?', default='[64, 32, 16]',
                        help='Output sizes of every aggregation layer.')
    parser.add_argument('--mess_dropout', nargs='?', default='[0.1, 0.1, 0.1]',
                        help='Dropout probability w.r.t. message dropout for each deep layer. 0: no dropout.')

    parser.add_argument('--kg_l2loss_lambda', type=float, default=1e-5,
                        help='Lambda when calculating KG l2 loss.')
    parser.add_argument('--cf_l2loss_lambda', type=float, default=1e-5,
                        help='Lambda when calculating CF l2 loss.')

    parser.add_argument('--lr', type=float, default=0.01,
                        help='Learning rate.')
    parser.add_argument('--n_epoch', type=int, default=1000,
                        help='Number of epoch.')
    parser.add_argument('--stopping_steps', type=int, default=10,
                        help='Number of epoch for early stopping')

    parser.add_argument('--limit', type=float, default=1000,
                        help='Time Limit.')

    parser.add_argument('--cf_print_every', type=int, default=1,
                        help='Iter interval of printing CF loss.')
    parser.add_argument('--kg_print_every', type=int, default=1,
                        help='Iter interval of printing KG loss.')
    parser.add_argument('--evaluate_every', type=int, default=1,
                        help='Epoch interval of evaluating CF.')

    parser.add_argument('--K', type=int, default=20,
                        help='Calculate metric@K when evaluating.')
    parser.add_argument('--episode', type=int, default=20,
                        help='episode')
    parser.add_argument('--feats-type', type=int, default=2,
                        help='Type of the node features used. ' +
                             '0 - loaded features; ' +
                             '1 - only target node features (zero vec for others); ' +
                             '2 - only target node features (id vec for others); ' +
                             '3 - all id vec. Default is 2.')
    parser.add_argument('--layers', type=int, default=2, help='Number of layers. Default is 2.')
    parser.add_argument('--hidden-dim', type=int, default=64, help='Dimension of the node hidden state. Default is 64.')
    parser.add_argument('--num-heads', type=list, default=[4], help='Number of the attention heads. Default is 8.')
    parser.add_argument('--attn-vec-dim', type=int, default=128,
                        help='Dimension of the attention vector. Default is 128.')
    parser.add_argument('--rnn-type', default='RotatE0', help='Type of the aggregator. Default is RotatE0.')
    parser.add_argument('--epoch', type=int, default=100, help='Number of epochs. Default is 100.')
    parser.add_argument('--patience', type=int, default=10, help='Patience. Default is 10.')
    parser.add_argument('--repeat', type=int, default=1,
                        help='Repeat the training and testing for N times. Default is 1.')
    parser.add_argument('--log', default='',
                        help='Name in log')
    parser.add_argument('--mpset', default="[[['2', '1']], [['1', '2']]]",
                        help='Meta-path Set.')
    parser.add_argument('--init', default="RL",
                        help='Meta-path Set initialization method.')

    args = parser.parse_args()

    save_dir = 'trained_model/HGNN/{}/entitydim{}_relationdim{}_{}_{}_lr{}_pretrain{}/'.format(
        args.data_name, args.entity_dim, args.relation_dim, args.aggregation_type,
        '-'.join([str(i) for i in eval(args.conv_dim_list)]), args.lr, args.use_pretrain)
    args.save_dir = save_dir

    return args

In [3]:
def train_agent(args):
    seed(0)
    print("Seed set to 0\n")
    tim1 = time.time()
    torch.backends.cudnn.deterministic = True

    dataset = args.data_name
    max_timesteps = 4 if args.task == 'rec' else 3
    if args.task == 'mcrec':
        max_timesteps = 5

    infor = 'rl_' + str(args.data_name) + '_' + str(args.task) + '_' + str(args.log)
    model_name = 'model/' + infor + '.pth'

    episode = int(args.episode)
    u_max_episodes = episode
    i_max_episodes = episode

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logger1 = get_logger('log', 'log/logger_' + infor + '.log')
    logger2 = get_logger('log2', 'log/logger2_' + infor + '.log')

    env = hgnn_env(logger1, logger2, model_name, args, dataset=dataset)
    use_pretrain(env, dataset)

    user_agent = DQNAgent(scope='dqn',
                          action_num=env.action_num,
                          replay_memory_size=int(1e4),
                          replay_memory_init_size=500,
                          norm_step=1,
                          batch_size=1,
                          state_shape=env.obs.shape,
                          mlp_layers=[32, 64, 32],
                          learning_rate=0.001,
                          device=torch.device(device)
                          )
    env.user_policy = user_agent
    best_user_val = 0.0
    best_user_i = 0

    for i_episode in range(1, u_max_episodes + 1):
        env.reset_past_performance()
        loss, reward, (val_acc, reward) = user_agent.user_learn(logger1, logger2, env, max_timesteps)
        logger2.info("Generated meta-path set: %s" % str(env.etypes_lists))
        print("Generated meta-path set: %s\n" % str(env.etypes_lists))
        if val_acc > best_user_val:
            best_user_policy = deepcopy(user_agent)
            best_user_val = val_acc
            best_user_i = i_episode
        logger2.info("Training User Meta-policy: %d    Val_Acc: %.5f    Avg_reward: %.5f    Best_Acc:  %.5f    Best_i: %d "
                     % (i_episode, val_acc, reward, best_user_val, best_user_i))
        logger2.info("")  # 添加空行
        print("")  # 添加空行
        for i in range(4):
            user_agent.train()

    tim_1 = time.time()
    for i in range(50):
        user_agent.train()

    if args.task != 'mcrec':
        item_agent = DQNAgent(scope='dqn',
                              action_num=env.action_num,
                              replay_memory_size=int(1e4),
                              replay_memory_init_size=500,
                              norm_step=1,
                              batch_size=1,
                              state_shape=env.obs.shape,
                              mlp_layers=[32, 64, 32],
                              learning_rate=0.001,
                              device=torch.device(device)
                              )
        env.item_policy = item_agent
        best_item_val = 0.0
        best_item_i = 0

        logger2.info("Training Meta-policy on Validation Set")

        for i_episode in range(1, i_max_episodes + 1):
            env.reset_past_performance()
            loss, reward, (val_acc, reward) = item_agent.item_learn(logger1, logger2, env, max_timesteps)
            logger2.info("Generated meta-path set: %s" % str(env.etypes_lists))
            print("Generated meta-path set: %s\n" % str(env.etypes_lists))
            if val_acc > best_item_val:
                best_item_policy = deepcopy(item_agent)
                best_item_val = val_acc
                best_item_i = i_episode
            logger2.info("Training Item Meta-policy: %d    Val_Acc: %.5f    Avg_reward: %.5f    Best_Acc:  %.5f    Best_i: %d "
                         % (i_episode, val_acc, reward, best_item_val, best_item_i))
            logger2.info("")  # 添加空行
            print("")  # 添加空行
            for i in range(4):
                item_agent.train()
        for i in range(50):
            item_agent.train()

    print('Reinforced training time: ', time.time() - tim_1, 's\n')

    tim2 = time.time()

    print("RL agent training time: ", (tim2 - tim1) / 60, "min\n")

    os.makedirs('model', exist_ok=True)
    torch.save({'q_estimator_qnet_state_dict': env.user_policy.q_estimator.qnet.state_dict(),
                'target_estimator_qnet_state_dict': env.user_policy.target_estimator.qnet.state_dict(),
                'Val': best_user_val},
               'model/a-best-user-' + str(best_user_val) + '-' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '.pth.tar')

    if args.task != 'mcrec':
        os.makedirs('model', exist_ok=True)
        torch.save({'q_estimator_qnet_state_dict': env.item_policy.qnet.state_dict(),
                    'target_estimator_qnet_state_dict': env.item_policy.target_estimator.qnet.state_dict(),
                    'Val': best_item_val},
                   'model/a-best-item-' + str(best_item_val) + '-' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '.pth.tar')


In [4]:
def test_model(args):
    # 这里实现测试逻辑，加载模型并在测试数据集上进行评估
    pass


In [5]:
if __name__ == '__main__':
    try:
        args = parse_args()
        best_user_val, best_item_val = train_model(args)
        print(f"Best User Val: {best_user_val}, Best Item Val: {best_item_val}")
    except Exception as e:
        print(e)


usage: ipykernel_launcher.py [-h] [--local_rank LOCAL_RANK] [--seed SEED]
                             [--task [TASK]] [--data_name [DATA_NAME]]
                             [--data_dir [DATA_DIR]]
                             [--use_pretrain USE_PRETRAIN]
                             [--pretrain_embedding_dir [PRETRAIN_EMBEDDING_DIR]]
                             [--pretrain_model_path [PRETRAIN_MODEL_PATH]]
                             [--cf_batch_size CF_BATCH_SIZE]
                             [--kg_batch_size KG_BATCH_SIZE]
                             [--nd_batch_size ND_BATCH_SIZE]
                             [--rl_batch_size RL_BATCH_SIZE]
                             [--train_batch_size TRAIN_BATCH_SIZE]
                             [--test_batch_size TEST_BATCH_SIZE]
                             [--entity_dim ENTITY_DIM]
                             [--relation_dim RELATION_DIM]
                             [--aggregation_type [AGGREGATION_TYPE]]
                            

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
 ################################################################
    #max_timesteps = 4 if args.task == 'rec' else 3
   # if args.task == 'mcrec':
     #   max_timesteps = 5
        
        
    # 调试阶段，使用较小的时间步数和训练周期数
    max_timesteps = 2 if args.task == 'rec' else 1
    if args.task == 'mcrec':
        max_timesteps = 3

    episode = int(args.episode) // 10  # 将训练周期数减少10倍
    u_max_episodes = episode
    i_max_episodes = episode
