# 准备

In [1]:
import torch
import csv

from model import SASRec
from utils import *

import pandas as pd
from tqdm import tqdm

from datetime import datetime

# 加载
请根据日志文件夹中的args.txt文件配置args，避免形状不匹配的情况

In [2]:
class args():
    def __init__(self):
        self.dataset = 'goodbooks'
        self.batch_size = 512
        self.lr = 0.001
        self.maxlen = 200
        self.hidden_units = 200
        self.num_blocks = 2
        self.num_epochs = 60
        self.num_heads  = 4
        self.dropout_rate = 0.2
        self.l2_emb = 0.0
        self.inference_only = False
        self.state_dict_path = None
        self.fixed_position_encode = True
        self.device = 'cuda'

args = args()

此处需要选择恰当的模型权重路径

In [3]:
model = SASRec(53424, 10000, args).to(args.device)
model.load_state_dict(torch.load('./goodbooks_2024-07-20_14-34/SASRec.lr=0.001.layer=2.head=4.hidden=200.maxlen=200.block=2.pth', map_location=torch.device(args.device)))
model.eval()

SASRec(
  (item_emb): Embedding(10001, 200, padding_idx=0)
  (pos_emb): Embedding(200, 200)
  (emb_dropout): Dropout(p=0.2, inplace=False)
  (attention_layernorms): ModuleList(
    (0-1): 2 x LayerNorm((200,), eps=1e-08, elementwise_affine=True)
  )
  (attention_layers): ModuleList(
    (0-1): 2 x MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=200, out_features=200, bias=True)
    )
  )
  (forward_layernorms): ModuleList(
    (0-1): 2 x LayerNorm((200,), eps=1e-08, elementwise_affine=True)
  )
  (forward_layers): ModuleList(
    (0-1): 2 x PointWiseFeedForward(
      (conv1): Conv1d(200, 200, kernel_size=(1,), stride=(1,))
      (dropout1): Dropout(p=0.2, inplace=False)
      (relu): ReLU()
      (conv2): Conv1d(200, 200, kernel_size=(1,), stride=(1,))
      (dropout2): Dropout(p=0.2, inplace=False)
    )
  )
  (last_layernorm): LayerNorm((200,), eps=1e-08, elementwise_affine=True)
)

In [4]:
dataset = data_partition(args.dataset)
t_test = evaluate(model, dataset, args)
print('test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))

....................................................................................................test (NDCG@10: 0.6373, HR@10: 0.8526)


In [5]:
usernum = 0
itemnum = 0
User = defaultdict(list) # User存储比赛给出的所有训练数据
user_train = {}
user_valid = {}
user_test = {}
# assume user/item index starting from 1
f = open('../data/%s.txt' % 'goodbooks', 'r')
for line in f:
    u, i = line.rstrip().split(',')
    u = int(u)
    i = int(i)
    usernum = max(u, usernum)
    itemnum = max(i, itemnum)
    User[u].append(i)

In [6]:
user_test = defaultdict(list)
for i in tqdm(range(1, usernum+1)):
    j = list(range(1, itemnum+1))
    # 将没有在训练数据User中出现的item加入到user_test中
    user_test[i].append(list(set(j).difference(User[i])))
    user_test[i] = user_test[i][0]

100%|██████████| 53424/53424 [00:46<00:00, 1145.33it/s]


In [7]:
with open('./submission'+datetime.now().strftime("%Y-%m-%d_%H-%M")+'.csv', 'ab') as f:
    f.write('user_id,item_id\n'.encode())
    # 训练集与测试集的最大长度
    maxlen_te = max(len(user_test[i]) for i in range(1, usernum+1))
    maxlen_tr = max(len(User[i]) for i in range(1, usernum+1))
    for i in tqdm(range(1, usernum+1)):
        seq = np.zeros([maxlen_tr], dtype=np.int32)
        idx = maxlen_tr - 1
        # 将给定的训练数据倒序存储到交互序列seq中
        for j in reversed(User[i]):
            seq[idx] = j
            idx -= 1
            if idx == -1: break
        # 待预测的item列表，存储在之前处理好的user_test列表中
        item_idx = user_test[i]
        p = [np.array(l) for l in [[i], [seq], item_idx]]
        # 计算所有待预测item的得分，取前10个item作为推荐结果
        predictions = -model.predict(*p)
        predictions = predictions[0]
        a = predictions.argsort()[:10]
        a = a.cpu().numpy()
        r = np.array(item_idx)
        s = r[a]
        u = np.full(shape=10, fill_value=i, dtype=np.int)
        # 输出，记得我们之前把用户和item的索引都加了1，现在要减回去
        pre = np.c_[u-1,s-1]
        np.savetxt(f, pre, delimiter=',', fmt='%i')
# df = pd.read_csv('./sub_f.csv',header=None,names=['user_id', 'item_id'])
# df.to_csv('./submission_1.csv',index=False)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  u = np.full(shape=10, fill_value=i, dtype=np.int)
100%|██████████| 53424/53424 [05:20<00:00, 166.71it/s]
