In [1]:
import pandas as pd
from tqdm.auto import tqdm
# Register the `progress_apply` family of methods on pandas objects.
tqdm.pandas()


def _collapse_group(group: pd.DataFrame) -> pd.Series:
    """Reduce all rows sharing one `id` to a single record.

    Keeps the `history` value of the group's first row and gathers every
    `ID` value of the group into a set stored under `positives`.
    """
    first_history = group['history'].iloc[0]
    positive_ids = set(group['ID'])
    return pd.Series({'history': first_history, 'positives': positive_ids})


# NOTE(review): unpickling can execute arbitrary code — only read pickle files
# that come from a trusted source.
test_df: pd.DataFrame = pd.read_pickle('data/test.pkl')

# One output row per `id`; the grouping column itself is not passed to the helper.
test_df = test_df.groupby('id').apply(_collapse_group, include_groups=False)
test_df

Unnamed: 0_level_0,history,positives
id,Unnamed: 1_level_1,Unnamed: 2_level_1
1005_2024-11-06,"[[26, 0], [5, 1], [16, 1]]",{6}
1008_2024-11-04,"[[2, 1], [28, 0], [18, 1]]",{6}
1008_2024-11-05,"[[2, 1], [28, 0], [18, 1], [28, 0], [6, 1], [1...","{10, 30, 14}"
1008_2024-11-06,"[[2, 1], [28, 0], [18, 1], [28, 0], [6, 1], [1...","{0, 18, 11}"
1008_2024-11-07,"[[28, 0], [6, 1], [19, 0], [30, 1], [14, 1], [...","{5, 15}"
...,...,...
957_2024-11-06,"[[1, 0], [9, 0], [19, 0], [22, 0], [0, 0], [7,...",{2}
957_2024-11-07,"[[22, 0], [0, 0], [7, 0], [21, 0], [15, 0], [2...",{2}
972_2024-11-01,"[[21, 0], [17, 0], [0, 1], [12, 0], [26, 0], [...",{1}
972_2024-11-05,"[[20, 0], [25, 0], [10, 0], [22, 0], [25, 0], ...","{2, 21}"


In [2]:
import torch
import numpy as np
from src.gru4rec import GRU4Rec, DEVICE

# The checkpoint holds a fully pickled model object (not just a state dict), so:
#   * `weights_only=False` is passed explicitly — recent PyTorch releases default
#     to `weights_only=True`, which refuses to unpickle arbitrary classes, and
#     older ones emit a FutureWarning when the argument is omitted;
#   * `map_location=DEVICE` places the parameters on the same device that
#     `rank()` moves its inputs to, even if the model was saved elsewhere.
# NOTE(review): full unpickling can execute arbitrary code — only load
# checkpoints that come from a trusted source.
model: GRU4Rec = torch.load('models/gru4rec.pth', map_location=DEVICE, weights_only=False)
# Switch to inference mode (affects layers such as dropout / batch norm).
model.eval()
# `len(missions)` is used by `random_rank` as the size of the random permutation.
missions = pd.read_csv('data/missions.csv')

def rank(x):
    """Rank all model outputs for one test row, best first.

    Args:
        x: A row exposing `x['history']` as a NumPy array; it is reshaped to
            `(1, -1, 2)`, i.e. a single batch of 2-element steps.

    Returns:
        NumPy array of output indices ordered by descending model score.
    """
    session = torch.from_numpy(x['history']).view(1, -1, 2)
    session = session.to(DEVICE)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(session)

    flat_scores = output.view(-1).detach().cpu().numpy()
    ascending = np.argsort(flat_scores)
    # Reverse the ascending order so the highest score comes first.
    return ascending[::-1]

def random_rank(_):
    """Baseline ranker: a uniformly random ordering of `range(len(missions))`.

    The row argument is ignored; it exists only so the function can be used
    with a row-wise `apply`.
    """
    n_candidates = len(missions)
    return np.random.permutation(n_candidates)

# Rank every test row with the trained model.
test_df['gru4rec'] = test_df.progress_apply(rank, axis=1)

# Seed NumPy's global RNG immediately before drawing the random baseline so the
# `random` column — and every metric computed from it — is identical on each
# re-run of the notebook instead of changing from run to run.
np.random.seed(42)
test_df['random'] = test_df.progress_apply(random_rank, axis=1)
test_df

  model: GRU4Rec = torch.load('models/gru4rec.pth')


  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

Unnamed: 0_level_0,history,positives,gru4rec,random
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1005_2024-11-06,"[[26, 0], [5, 1], [16, 1]]",{6},"[30, 2, 4, 0, 31, 24, 25, 5, 12, 26, 6, 7, 13,...","[1, 8, 16, 9, 10, 17, 22, 29, 30, 24, 26, 21, ..."
1008_2024-11-04,"[[2, 1], [28, 0], [18, 1]]",{6},"[30, 2, 4, 0, 25, 31, 24, 5, 12, 26, 13, 14, 7...","[19, 25, 30, 20, 23, 8, 7, 16, 24, 12, 11, 2, ..."
1008_2024-11-05,"[[2, 1], [28, 0], [18, 1], [28, 0], [6, 1], [1...","{10, 30, 14}","[30, 2, 4, 24, 0, 5, 25, 31, 12, 26, 6, 13, 14...","[7, 30, 13, 29, 28, 9, 6, 20, 27, 15, 21, 12, ..."
1008_2024-11-06,"[[2, 1], [28, 0], [18, 1], [28, 0], [6, 1], [1...","{0, 18, 11}","[30, 2, 31, 4, 25, 5, 0, 24, 12, 13, 6, 7, 26,...","[2, 28, 16, 9, 13, 7, 27, 18, 22, 8, 29, 10, 6..."
1008_2024-11-07,"[[28, 0], [6, 1], [19, 0], [30, 1], [14, 1], [...","{5, 15}","[30, 2, 25, 31, 4, 24, 5, 0, 12, 13, 14, 26, 1...","[7, 20, 11, 4, 1, 21, 8, 22, 16, 30, 23, 29, 2..."
...,...,...,...,...
957_2024-11-06,"[[1, 0], [9, 0], [19, 0], [22, 0], [0, 0], [7,...",{2},"[2, 30, 4, 0, 5, 12, 31, 24, 25, 6, 26, 7, 13,...","[14, 20, 22, 5, 17, 6, 2, 29, 8, 26, 18, 1, 15..."
957_2024-11-07,"[[22, 0], [0, 0], [7, 0], [21, 0], [15, 0], [2...",{2},"[2, 30, 4, 0, 5, 12, 31, 24, 25, 7, 26, 6, 13,...","[28, 14, 10, 12, 17, 13, 22, 29, 26, 5, 21, 24..."
972_2024-11-01,"[[21, 0], [17, 0], [0, 1], [12, 0], [26, 0], [...",{1},"[2, 30, 4, 0, 31, 5, 12, 25, 24, 6, 7, 26, 13,...","[9, 24, 28, 1, 7, 2, 22, 4, 8, 6, 19, 20, 29, ..."
972_2024-11-05,"[[20, 0], [25, 0], [10, 0], [22, 0], [25, 0], ...","{2, 21}","[2, 30, 4, 0, 24, 12, 5, 31, 25, 26, 6, 13, 7,...","[12, 14, 10, 21, 29, 5, 30, 17, 3, 7, 0, 25, 2..."


In [3]:
def recall(x, model, k=5):
    """Fraction of a row's positives that appear among the top-k ranked items.

    Args:
        x: Row (or mapping) with a `'positives'` set and a ranking stored
            under the key given by `model`.
        model: Key of the ranking to evaluate, e.g. `'gru4rec'`.
        k: Number of leading ranked items to consider.

    Returns:
        `|positives ∩ top-k| / |positives|` as a float.
    """
    relevant = x['positives']
    recommended = set(x[model][:k])
    found = relevant.intersection(recommended)
    return len(found) / len(relevant)

def hit_rate(x, model, k=5):
    """Whether at least one of a row's positives appears in the top-k ranking.

    Args:
        x: Row (or mapping) with a `'positives'` set and a ranking stored
            under the key given by `model`.
        model: Key of the ranking to evaluate, e.g. `'gru4rec'`.
        k: Number of leading ranked items to consider.

    Returns:
        `True` if the top-k items and the positives share any element,
        otherwise `False`.
    """
    recommended = set(x[model][:k])
    # A "hit" is simply a non-empty overlap between the two sets.
    return not x['positives'].isdisjoint(recommended)

# Metric name -> scoring function, listed in the row order of the results table.
metric_fns = {'recall': recall, 'hit_rate': hit_rate}
cutoffs = (5, 10)

# One column per ranker, one row per "<metric>@<k>"; each cell is the mean of
# the metric over all rows of `test_df`.
pd.DataFrame({
    ranker: {
        f'{metric_name}@{k}': test_df.progress_apply(
            lambda x: metric_fn(x, ranker, k=k), axis=1
        ).mean()
        for metric_name, metric_fn in metric_fns.items()
        for k in cutoffs
    }
    for ranker in ('gru4rec', 'random')
})

  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

  0%|          | 0/2074 [00:00<?, ?it/s]

Unnamed: 0,gru4rec,random
recall@5,0.34116,0.158052
recall@10,0.550852,0.320869
hit_rate@5,0.552555,0.296046
hit_rate@10,0.842334,0.535198
