In [1]:
import argparse
import copy
import json
import os

import numpy as np
import pytorch_lightning as pl
import torch
from peft import LoraConfig, get_peft_model, PeftModel
from tqdm import tqdm

from model_exp import SASRec_Exp
from utils import *
from utils import evaluate_user
from utils import evaluate_user_multiple_model

In [2]:
# Optional global seeding. Set SEED to an int (e.g. 3407) to seed python `random`,
# numpy and torch via Lightning. None keeps the notebook's original unseeded behaviour.
# NOTE(review): whether evaluate_user draws random numbers is not visible here — confirm
# in utils before relying on run-to-run reproducibility.
SEED = None
if SEED is not None:
    pl.seed_everything(SEED)

# Per-user group labels, keyed by the user id as a string. Each label is used below as the
# name of a LoRA checkpoint sub-folder.
# NOTE: 'temprature' is the file's actual on-disk name; do not correct the spelling here.
with open('checkpoints/ml1m_popularity.json', 'r') as json_file:
    user_popu = json.load(json_file)

with open('checkpoints/ml1m_temprature.json', 'r') as json_file:
    user_temp = json.load(json_file)


In [3]:
# Model hyper-parameters; they must match the ones used to train checkpoints/base.pth,
# otherwise the strict state-dict load below fails.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='ml-1m',)
parser.add_argument('--batch_size', default=1024, type=int)
parser.add_argument('--lr', default=1e-3, type=float)
parser.add_argument('--maxlen', default=200, type=int)
parser.add_argument('--hidden_units', default=50, type=int)
parser.add_argument('--num_blocks', default=2, type=int)
parser.add_argument('--num_epochs', default=25, type=int)
parser.add_argument('--num_heads', default=1, type=int)
parser.add_argument('--dropout_rate', default=0.2, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--device', default='cpu', type=str)
# parse_known_args ignores the extra argv Jupyter passes to the kernel.
args, _ = parser.parse_known_args()
u2i_index, i2u_index = build_index(args.dataset)
dataset = data_partition(args.dataset)
[user_train, user_valid, user_test, usernum, itemnum] = dataset

# Size the model from the partitioned data instead of hard-coding ml-1m's 6040 users /
# 3416 items. A mismatch with the checkpoint still fails loudly via strict=True.
model = SASRec_Exp(usernum, itemnum, args).to(args.device)
# Inference only: switch off dropout (dropout_rate=0.2) so baseline scores are computed under
# the same conditions as the merged LoRA models later on.
model.eval()
state_dict = torch.load('checkpoints/base.pth', map_location=torch.device(args.device))
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [4]:
# Fixed evaluation sample of 50 ml-1m user ids. The original random draw
# (np.random.randint(1, 6041, 50)) was immediately overwritten by this frozen list, so the
# dead draw is dropped and every run evaluates the same users.
test_users = np.array([1226,  699, 1001,  332, 5342, 5722, 4752, 4850, 2265, 5312, 3234,
        107, 1433, 2610, 4373, 6021, 5292, 5316, 1926,  938, 2274, 3189,
       3544, 4856, 4026, 3069, 5531, 2117, 6005, 3285, 3651, 4742, 5340,
       2001, 5649, 3033, 2218, 5963, 5752, 4942, 1093, 1539, 2256, 1201,
       4489, 3186, 2436, 5454,  456, 2995])
# Guard against the list drifting out of the dataset's user-id range.
assert 1 <= test_users.min() and test_users.max() <= usernum, 'test user id out of range'

In [5]:
# Baseline: score the un-adapted base model on every test user.
# Ensure dropout is off. This is idempotent and mirrors the eval-mode LoRA models below.
model.eval()
ndcg_ls = list()
ht_ls = list()
# Inference only: no autograd graph, matching the LoRA evaluation cells.
with torch.no_grad():
    for user in test_users:
        ndcg, ht = evaluate_user(model, dataset, args, test_user=user)
        ndcg_ls.append(ndcg)
        ht_ls.append(ht)

In [6]:
# Baseline (no LoRA adapter): mean NDCG, then mean HT, over test_users.
for baseline_scores in (ndcg_ls, ht_ls):
    print(np.mean(baseline_scores))

0.11560899612123408
0.22


In [7]:
### Per-group LoRA adapters: score the popularity- and temperature-group adapters separately.

def best_lora_checkpoint(ckpt_root, split):
    """Return the path of the highest-scoring checkpoint under ``ckpt_root/split``.

    Checkpoint sub-directories are expected to be named ``ndcg=<score>``. The entry with the
    largest parsed score is chosen, and its real directory name is returned. Rebuilding the
    name from ``str(float(score))`` breaks for names such as ``ndcg=0.1200``. Entries that do
    not start with ``ndcg=`` (e.g. ``.ipynb_checkpoints``) are ignored instead of crashing.

    Raises:
        FileNotFoundError: if the folder contains no ``ndcg=<score>`` entry.
    """
    split_dir = os.path.join(ckpt_root, split)
    candidates = [name for name in os.listdir(split_dir) if name.startswith('ndcg=')]
    if not candidates:
        raise FileNotFoundError(f'no ndcg=<score> checkpoints found in {split_dir}')
    best = max(candidates, key=lambda name: float(name.split('=', 1)[1]))
    return os.path.join(split_dir, best)


def load_best_lora(base_model, ckpt_root, split):
    """Load the best LoRA adapter for ``split`` onto a copy of ``base_model`` and merge it.

    ``base_model`` is deep-copied so the shared base weights are never touched. The merged
    plain model is switched to eval mode for inference.
    """
    ckpt = best_lora_checkpoint(ckpt_root, split)
    return PeftModel.from_pretrained(copy.deepcopy(base_model), ckpt).merge_and_unload().eval()


popu_ndcg_ls = []
temp_ndcg_ls = []
popu_ht_ls = []
temp_ht_ls = []
for user in test_users:
    # user_popu / user_temp map str(user id) -> checkpoint sub-folder name.
    lora_popu = load_best_lora(model, 'lora_checkpoint_popu', user_popu[str(user)])
    lora_temp = load_best_lora(model, 'lora_checkpoint_temp', user_temp[str(user)])

    with torch.no_grad():
        popu_ndcg, popu_ht = evaluate_user(lora_popu, dataset, args, test_user=user)
        temp_ndcg, temp_ht = evaluate_user(lora_temp, dataset, args, test_user=user)

    popu_ndcg_ls.append(popu_ndcg)
    temp_ndcg_ls.append(temp_ndcg)
    popu_ht_ls.append(popu_ht)
    temp_ht_ls.append(temp_ht)


In [8]:
# Mean NDCG and mean HT for each adapter family, printed under its label.
adapter_results = (
    ('POPU', popu_ndcg_ls, popu_ht_ls),
    ('TEMP', temp_ndcg_ls, temp_ht_ls),
)
for label, ndcg_values, hit_values in adapter_results:
    print(label)
    print(np.mean(ndcg_values))
    print(np.mean(hit_values))

POPU
0.11603785297076218
0.26
TEMP
0.1087449297755429
0.22


In [9]:
### Combine: score the popularity and temperature adapters jointly per user.
# ALPHA is forwarded to evaluate_user_multiple_model as its mixing weight for the
# [popularity, temperature] model pair (exact weighting defined in utils — confirm there).
ALPHA = 0.5
combined_ndcg_ls = list()
combined_ht_ls = list()
for user in test_users:
    lora_models = []
    # Order matters: popularity adapter first, temperature adapter second.
    for ckpt_root, user_split in (('lora_checkpoint_popu', user_popu[str(user)]),
                                  ('lora_checkpoint_temp', user_temp[str(user)])):
        split_dir = os.path.join(ckpt_root, user_split)
        # Pick the `ndcg=<score>` entry with the highest score and reuse its real name.
        # Rebuilding the name from str(float(score)) breaks for e.g. 'ndcg=0.1200'.
        # Unrelated entries such as '.ipynb_checkpoints' are skipped.
        candidates = [name for name in os.listdir(split_dir) if name.startswith('ndcg=')]
        if not candidates:
            raise FileNotFoundError(f'no ndcg=<score> checkpoints found in {split_dir}')
        best = max(candidates, key=lambda name: float(name.split('=', 1)[1]))
        # Deep-copy so the shared base model is never modified by the merge.
        merged = PeftModel.from_pretrained(
            copy.deepcopy(model), os.path.join(split_dir, best)
        ).merge_and_unload()
        lora_models.append(merged.eval())

    # Inference only, consistent with the single-adapter evaluation.
    with torch.no_grad():
        ndcg, ht = evaluate_user_multiple_model(lora_models, dataset, args, test_user=user, alpha=ALPHA)

    combined_ndcg_ls.append(ndcg)
    combined_ht_ls.append(ht)

print('Combined with Alpha = ', ALPHA)
print(np.mean(combined_ndcg_ls))
print(np.mean(combined_ht_ls))

Combined with Alpha =  0.5
0.12250258583351335
0.26
