In [10]:
import json
import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader
import numpy as np
from models.sasrec_base import SASRec
from utils import *
from tqdm import tqdm
import argparse

def test_variance(model, train_ds, test_ds, split_dict):
    all_splits = set(list(split_dict.values()))
    ndcg_list = []
    ht_list = []
    for split in all_splits:
        _, lora_test = get_lora_train_test_ds(split_dict, train_ds, test_ds, split)
        test_loader  = DataLoader(lora_test, batch_size = 384, shuffle = False, collate_fn = collate_test)
        model.eval()
        ndcg, ht = 0, 0
        with torch.no_grad():
            for test_batch in test_loader:
                u, seq, pos, test_items, mask = test_batch
                batch_ndcg, batch_ht = eval_step(model, u, seq, pos, test_items, mask, topK = 10)
                ndcg += batch_ndcg
                ht += batch_ht
        ndcg /= len(test_ds)
        ht /= len(test_ds)
        ndcg_list.append(ndcg)
        ht_list.append(ht)
    return np.std(ndcg_list) * 1000



# Dataset identifier shared by the data helpers and the checkpoint file name.
DATASET = 'ml-1m'

num_u, num_i = get_usr_itm_num(DATASET)
train, test = load_train_test_data_num(load_txt_file(DATASET), num_i)
model = SASRec(user_num = num_u, item_num = num_i, maxlen = 200, num_blocks = 2, num_heads = 1, hidden_units = 50, dropout_rate = 0.2, device = 'cpu')
# The previous f-string nested single quotes inside single quotes, which only
# parses on Python 3.12+; interpolating the constant works on every version.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(f'checkpoints/{DATASET}-base.pth', map_location=torch.device('cpu')))
model = model.to('cpu')

# Context managers close the config files instead of leaking the handles.
with open('config/popularity.json') as config_file:
    split_dict_pop = json.load(config_file)
with open('config/temperature.json') as config_file:
    split_dict_tmp = json.load(config_file)


Processing Users: 100%|██████████| 6040/6040 [00:46<00:00, 130.28it/s]


**Group by Popularity**

In [11]:
# Metric spread across the groups loaded from config/popularity.json.
test_variance(model=model, train_ds=train, test_ds=test, split_dict=split_dict_pop)

2.291597902089782

**Group by Sequence Length**

In [12]:
# Metric spread across the groups loaded from config/temperature.json.
test_variance(model=model, train_ds=train, test_ds=test, split_dict=split_dict_tmp)

3.2364433656873937

**Group by Random Assignment**

In [30]:
# Random baseline: users 1..6040 are assigned to ten equally sized groups
# (split_1 .. split_10) by a random permutation of the group labels.
n_users, n_groups = 6040, 10
# Fixed seed so the baseline number is reproducible across kernel restarts;
# the previous version drew from the unseeded global NumPy RNG.
rng = np.random.default_rng(0)
group_ids = rng.permutation(np.arange(n_users) % n_groups + 1)
# NOTE(review): keys are NumPy integers here, whereas the JSON-loaded split
# dicts and intention_shift_dict use string keys -- confirm which form
# get_lora_train_test_ds expects.
random_split_dict = dict(zip(np.arange(1, n_users + 1), ['split_' + str(g) for g in group_ids]))
test_variance(model, train, test, random_split_dict)

0.8288052669234953

**Group by Intention Shift**

In [69]:
# One row per (userID, itemID) pair, built from the mapping returned by
# load_txt_file (presumably user -> list of items; verify against utils).
df = (
    pd.DataFrame(load_txt_file('ml-1m').items(), columns=["userID", "itemID"])
    .explode("itemID")
    .reset_index(drop=True)
)

# Occurrence count per item, then bucket the counts into three tiers labelled
# -1 / 0 / 1, with cut points at the 33% and 66% quantiles of the counts.
item_pop = df['itemID'].value_counts()
low, mid, high = item_pop.quantile([0.33, 0.66, 1.0])
labels = [-1, 0, 1]
item_pop = pd.cut(item_pop, bins=[-1, low, mid, high], labels=labels, include_lowest=True)

# Per user: np.var of the tier labels of that user's items, stored as "pop".
df = df.groupby('userID')['itemID'].agg(
    pop=lambda items: np.var(np.array([item_pop[item] for item in items]))
)

# Cut the per-user variance into ten quantile groups split_1 .. split_10.
df["split"] = pd.qcut(df["pop"], q=10, labels=["split_" + str(rank) for rank in range(1, 11)])

# String user ids as keys, split label as value.
intention_shift_dict = {str(user): split for user, split in df["split"].to_dict().items()}

In [71]:
# Metric spread across the ten intention-shift groups built in the cell above.
test_variance(model=model, train_ds=train, test_ds=test, split_dict=intention_shift_dict)

2.0848147916681903