In [1]:
import sys
sys.path.append('../implementation/')
import ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.special as sp
from tqdm import tqdm
import time
from zhou_analytic_focus import AnalyticFocusModel
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Loading the STL Crimes underlying data and user interaction data
underlying_data = pd.read_csv('../data/stl_crimes/dots.csv')
underlying_data.set_index('id', drop=True, inplace=True)
output_file_path = '../output/stl/stl_map_results_af.pkl'

interaction_data = pd.read_csv('../data/stl_crimes/stl_combined_interactions.csv')
interaction_data['interaction_session'] = interaction_data.apply(lambda row: ast.literal_eval(row.interaction_session), axis=1)
interaction_data['interaction_type_session'] = interaction_data.apply(lambda row: ast.literal_eval(row.interaction_type_session), axis=1)
ks = [1, 5, 10, 20, 50, 100]

In [3]:
af_results = pd.DataFrame()

for participant_index, row in interaction_data.iterrows():
    print(f'Processing user {row.user} task {row.task}')
    results = {'participant_id': row.user, 'task': row.task}
    model = AnalyticFocusModel(underlying_data, [['x','y']], ['type'])
    predicted = pd.DataFrame()
    rank_predicted = []
    for i in tqdm(range(len(interaction_data.iloc[participant_index].interaction_session))):
        interaction = interaction_data.iloc[participant_index].interaction_session[i]
        model.update(interaction)

        if i < len(interaction_data.iloc[participant_index].interaction_session) - 1:
            probability_of_next_point = model.predict()
            next_point = interaction_data.iloc[participant_index].interaction_session[i+1]
            predicted_next_dict = {}
            for k in ks:
                predicted_next_dict[k] = (next_point in probability_of_next_point.nlargest(k).index.values)
            predicted = predicted.append(predicted_next_dict, ignore_index=True)
            sorted_prob = probability_of_next_point.sort_values(ascending=False)
            rank, = np.where(sorted_prob.index.values == next_point)
            rank_predicted.append(rank[0] + 1)
            
    ncp = predicted.sum()/len(predicted)
    for col in ncp.index:
        results[f'ncp-{col}'] = ncp[col]
    results['rank'] = rank_predicted    
    af_results = af_results.append(results, ignore_index=True)
    
af_results.to_pickle(output_file_path)

Processing user 15 task geo-based


100%|██████████| 54/54 [00:00<00:00, 67.53it/s]


Processing user 14 task geo-based


100%|██████████| 46/46 [00:00<00:00, 81.98it/s]


Processing user 28 task geo-based


100%|██████████| 47/47 [00:00<00:00, 79.51it/s]


Processing user 16 task geo-based


100%|██████████| 44/44 [00:00<00:00, 77.12it/s]


Processing user 17 task geo-based


100%|██████████| 50/50 [00:00<00:00, 84.63it/s]


Processing user 13 task geo-based


100%|██████████| 31/31 [00:00<00:00, 82.19it/s]


Processing user 12 task geo-based


100%|██████████| 34/34 [00:00<00:00, 85.06it/s]


Processing user 10 task geo-based


100%|██████████| 34/34 [00:00<00:00, 79.89it/s]


Processing user 11 task geo-based


100%|██████████| 28/28 [00:00<00:00, 77.69it/s]


Processing user 9 task geo-based


100%|██████████| 29/29 [00:00<00:00, 80.41it/s]


Processing user 8 task geo-based


100%|██████████| 41/41 [00:00<00:00, 79.05it/s]


Processing user 5 task geo-based


100%|██████████| 22/22 [00:00<00:00, 71.87it/s]


Processing user 4 task geo-based


100%|██████████| 32/32 [00:00<00:00, 71.92it/s]


Processing user 6 task geo-based


100%|██████████| 31/31 [00:00<00:00, 66.23it/s]


Processing user 7 task geo-based


100%|██████████| 35/35 [00:00<00:00, 78.62it/s]


Processing user 3 task geo-based


100%|██████████| 32/32 [00:00<00:00, 66.99it/s]


Processing user 2 task geo-based


100%|██████████| 33/33 [00:00<00:00, 68.65it/s]


Processing user 1 task geo-based


100%|██████████| 28/28 [00:00<00:00, 69.39it/s]


Processing user 20 task geo-based


100%|██████████| 46/46 [00:00<00:00, 73.97it/s]


Processing user 21 task geo-based


100%|██████████| 50/50 [00:00<00:00, 79.88it/s]


Processing user 23 task geo-based


100%|██████████| 47/47 [00:00<00:00, 80.78it/s]


Processing user 22 task geo-based


100%|██████████| 49/49 [00:00<00:00, 58.18it/s]


Processing user 26 task geo-based


100%|██████████| 46/46 [00:00<00:00, 66.03it/s]


Processing user 27 task geo-based


100%|██████████| 46/46 [00:00<00:00, 68.78it/s]


Processing user 25 task geo-based


100%|██████████| 57/57 [00:00<00:00, 66.03it/s]


Processing user 19 task geo-based


100%|██████████| 47/47 [00:00<00:00, 73.11it/s]


Processing user 18 task geo-based


100%|██████████| 50/50 [00:00<00:00, 70.14it/s]


Processing user 24 task geo-based


100%|██████████| 48/48 [00:00<00:00, 64.52it/s]


Processing user 15 task mixed


100%|██████████| 34/34 [00:00<00:00, 51.16it/s]


Processing user 14 task mixed


100%|██████████| 39/39 [00:00<00:00, 68.83it/s]


Processing user 16 task mixed


100%|██████████| 38/38 [00:00<00:00, 67.58it/s]


Processing user 17 task mixed


100%|██████████| 35/35 [00:00<00:00, 69.89it/s]


Processing user 13 task mixed


100%|██████████| 57/57 [00:01<00:00, 56.09it/s]


Processing user 12 task mixed


100%|██████████| 121/121 [00:01<00:00, 70.19it/s]


Processing user 10 task mixed


100%|██████████| 99/99 [00:01<00:00, 60.36it/s]


Processing user 11 task mixed


100%|██████████| 96/96 [00:01<00:00, 68.55it/s]


Processing user 9 task mixed


100%|██████████| 89/89 [00:01<00:00, 60.51it/s]


Processing user 8 task mixed


100%|██████████| 94/94 [00:01<00:00, 70.15it/s]


Processing user 5 task mixed


100%|██████████| 101/101 [00:01<00:00, 71.64it/s]


Processing user 4 task mixed


100%|██████████| 118/118 [00:01<00:00, 70.07it/s]


Processing user 6 task mixed


100%|██████████| 99/99 [00:01<00:00, 64.75it/s]


Processing user 7 task mixed


100%|██████████| 108/108 [00:01<00:00, 71.07it/s]


Processing user 3 task mixed


100%|██████████| 87/87 [00:01<00:00, 62.88it/s]


Processing user 2 task mixed


100%|██████████| 90/90 [00:01<00:00, 65.52it/s]


Processing user 1 task mixed


100%|██████████| 75/75 [00:01<00:00, 66.78it/s]


Processing user 20 task mixed


100%|██████████| 47/47 [00:00<00:00, 68.40it/s]


Processing user 21 task mixed


100%|██████████| 43/43 [00:00<00:00, 65.89it/s]


Processing user 23 task mixed


100%|██████████| 17/17 [00:00<00:00, 59.74it/s]


Processing user 22 task mixed


100%|██████████| 37/37 [00:00<00:00, 65.61it/s]


Processing user 26 task mixed


100%|██████████| 38/38 [00:00<00:00, 61.92it/s]


Processing user 27 task mixed


100%|██████████| 43/43 [00:00<00:00, 66.57it/s]


Processing user 25 task mixed


100%|██████████| 39/39 [00:00<00:00, 66.41it/s]


Processing user 19 task mixed


100%|██████████| 38/38 [00:00<00:00, 69.04it/s]


Processing user 18 task mixed


100%|██████████| 40/40 [00:00<00:00, 63.98it/s]


Processing user 24 task mixed


100%|██████████| 48/48 [00:00<00:00, 68.53it/s]


Processing user 15 task type-based


100%|██████████| 10/10 [00:00<00:00, 63.27it/s]


Processing user 14 task type-based


100%|██████████| 6/6 [00:00<00:00, 69.58it/s]


Processing user 16 task type-based


100%|██████████| 19/19 [00:00<00:00, 62.81it/s]


Processing user 17 task type-based


100%|██████████| 11/11 [00:00<00:00, 48.89it/s]


Processing user 13 task type-based


100%|██████████| 14/14 [00:00<00:00, 58.79it/s]


Processing user 12 task type-based


100%|██████████| 11/11 [00:00<00:00, 54.39it/s]


Processing user 10 task type-based


100%|██████████| 11/11 [00:00<00:00, 56.92it/s]


Processing user 11 task type-based


100%|██████████| 5/5 [00:00<00:00, 64.91it/s]


Processing user 9 task type-based


100%|██████████| 8/8 [00:00<00:00, 58.89it/s]


Processing user 8 task type-based


100%|██████████| 8/8 [00:00<00:00, 58.28it/s]


Processing user 5 task type-based


100%|██████████| 8/8 [00:00<00:00, 61.50it/s]


Processing user 4 task type-based


100%|██████████| 6/6 [00:00<00:00, 62.63it/s]


Processing user 6 task type-based


100%|██████████| 12/12 [00:00<00:00, 60.90it/s]


Processing user 7 task type-based


100%|██████████| 6/6 [00:00<00:00, 61.49it/s]


Processing user 3 task type-based


100%|██████████| 8/8 [00:00<00:00, 58.18it/s]


Processing user 2 task type-based


100%|██████████| 12/12 [00:00<00:00, 60.73it/s]


Processing user 1 task type-based


100%|██████████| 6/6 [00:00<00:00, 67.49it/s]


Processing user 20 task type-based


100%|██████████| 13/13 [00:00<00:00, 60.54it/s]


Processing user 21 task type-based


100%|██████████| 13/13 [00:00<00:00, 56.35it/s]


Processing user 23 task type-based


100%|██████████| 16/16 [00:00<00:00, 64.35it/s]


Processing user 22 task type-based


100%|██████████| 21/21 [00:00<00:00, 64.85it/s]


Processing user 19 task type-based


100%|██████████| 11/11 [00:00<00:00, 60.49it/s]


Processing user 18 task type-based


100%|██████████| 12/12 [00:00<00:00, 60.49it/s]


In [11]:
af_results

Unnamed: 0,participant_id,task,ncp-1,ncp-5,ncp-10,ncp-20,ncp-50,ncp-100,rank
0,15,geo-based,0.037736,0.132075,0.301887,0.415094,0.584906,0.698113,"[1293, 1170, 10, 1260, 8, 8, 24, 14, 11, 0, 8,..."
1,14,geo-based,0.044444,0.155556,0.311111,0.355556,0.533333,0.733333,"[11, 1512, 3, 12, 0, 11, 1807, 1260, 12, 366, ..."
2,28,geo-based,0.021739,0.152174,0.304348,0.456522,0.608696,0.630435,"[1294, 1173, 2, 2, 5, 1812, 1264, 5, 519, 20, ..."
3,16,geo-based,0.069767,0.162791,0.255814,0.395349,0.627907,0.720930,"[1512, 0, 1808, 14, 0, 11, 1261, 14, 9, 512, 2..."
4,17,geo-based,0.020408,0.102041,0.244898,0.469388,0.673469,0.714286,"[1294, 1173, 1263, 13, 1813, 6, 10, 378, 29, 3..."
...,...,...,...,...,...,...,...,...,...
73,21,type-based,0.083333,0.250000,0.416667,0.500000,0.500000,0.500000,"[861, 1471, 599, 1866, 0, 1304, 3, 1149, 0, 6,..."
74,23,type-based,0.200000,0.533333,0.600000,0.600000,0.600000,0.600000,"[861, 599, 1471, 1866, 4, 1, 0, 0, 1, 4, 472, ..."
75,22,type-based,0.150000,0.450000,0.650000,0.750000,0.750000,0.750000,"[466, 2, 591, 0, 1822, 0, 1473, 3, 861, 9, 4, ..."
76,19,type-based,0.100000,0.300000,0.500000,0.700000,0.700000,0.700000,"[1, 1302, 0, 1147, 1473, 7, 861, 1309, 8, 2]"


In [8]:
# success rate for predicting next click in the top-k
ncp = predicted.sum()/len(predicted)
ncp

1      0.000000
5      0.022222
10     0.111111
20     0.155556
50     0.222222
100    0.400000
dtype: float64

In [None]:
model.normalizer

In [None]:
model.normalizer.fit_transform(model.underlying_data_w_probability['x_disc__importance'].to_numpy().reshape(-1, 1)).ravel()