In [1]:
import sys
# Make the sibling implementation directory importable. The path is relative,
# so it resolves against the kernel's working directory.
# NOTE(review): presumably healey_adaboost_naive_bayes lives there — confirm.
sys.path.append('../implementation/')
import ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.special as sp
from tqdm import tqdm
import time
from healey_adaboost_naive_bayes import AdaBoostNB
import warnings
# NOTE(review): this suppresses every warning for the rest of the kernel
# session, including library deprecation warnings that would otherwise flag
# API calls about to stop working.
warnings.filterwarnings('ignore')

In [2]:
# Load the STL Crimes underlying data, indexed by its 'id' column.
# set_index is chained (no inplace=True) so re-running the cell always starts
# from the freshly read CSV.
underlying_data = pd.read_csv('../data/stl_crimes/dots.csv').set_index('id')

# Shift the 'type' codes from the range [1, n] down to [0, n-1].
# NOTE(review): the original author found this shift necessary but did not
# record why; presumably the model indexes categories from zero — confirm.
underlying_data['type'] = underlying_data['type'] - 1
output_file_path = '../output/stl/stl_map_results_ada_nb.pkl'

# Load the user interaction data. The two session columns are stored in the
# CSV as stringified Python literals, so each cell is parsed back into a Python
# object. Series.map parses the column element-wise, which avoids building a
# row Series per record and the list-result inference of
# DataFrame.apply(axis=1).
interaction_data = pd.read_csv('../data/stl_crimes/stl_combined_interactions.csv')
interaction_data['interaction_session'] = interaction_data['interaction_session'].map(ast.literal_eval)
interaction_data['interaction_type_session'] = interaction_data['interaction_type_session'].map(ast.literal_eval)

# Cut-offs k used for the top-k hit-rate columns (ncp-k) in the evaluation loop.
ks = [1, 5, 10, 20, 50, 100]

In [3]:
# Replay each recorded session through a fresh AdaBoostNB model. After every
# interaction the model is updated and asked to predict the next one; each
# prediction is scored by
#   (a) whether the true next point is among the k most probable points, for
#       every k in `ks`, and
#   (b) the 1-based rank of the true next point in the descending probability
#       ordering.
#
# Per-step and per-session records are collected in plain lists and converted
# to DataFrames once. DataFrame.append was removed in pandas 2.0 (so the old
# row-by-row growth fails there) and it copied the whole frame on every call.
session_summaries = []

for participant_index, row in interaction_data.iterrows():
    print(f'Processing user {row.user} task {row.task}')
    results = {'participant_id': row.user, 'task': row.task}
    model = AdaBoostNB(underlying_data, [['x','y']], ['type'])

    # Read the session from the row iterrows() already yielded. Looking it up
    # again via .iloc[participant_index] would treat an index *label* as a
    # position, which is only correct for a default RangeIndex.
    session = row.interaction_session
    top_k_hits = []
    rank_predicted = []
    for i in tqdm(range(len(session))):
        model.update(session[i])

        # The final interaction has no successor to score against.
        if i < len(session) - 1:
            probability_of_next_point = model.predict()
            next_point = session[i + 1]
            top_k_hits.append({
                k: (next_point in probability_of_next_point.nlargest(k).index.values)
                for k in ks
            })
            sorted_prob = probability_of_next_point.sort_values(ascending=False)
            # np.where returns the positions of next_point in the sorted index;
            # rank[0] raises IndexError if the point is absent from the prediction.
            rank, = np.where(sorted_prob.index.values == next_point)
            rank_predicted.append(rank[0] + 1)

    # Fraction of scored steps whose true next point fell inside the top k.
    # A session with fewer than two interactions yields an empty frame and
    # therefore contributes no ncp-* entries (NaN in the final table).
    predicted = pd.DataFrame(top_k_hits)
    ncp = predicted.sum()/len(predicted)
    for col in ncp.index:
        results[f'ncp-{col}'] = ncp[col]
    results['rank'] = rank_predicted
    session_summaries.append(results)

# One row per (participant, task) session, built in a single pass.
stl_map_results = pd.DataFrame(session_summaries)
stl_map_results.to_pickle(output_file_path)

Processing user 15 task geo-based


100%|██████████| 54/54 [00:04<00:00, 11.00it/s]


Processing user 14 task geo-based


100%|██████████| 46/46 [00:04<00:00,  9.77it/s]


Processing user 28 task geo-based


100%|██████████| 47/47 [00:04<00:00,  9.60it/s]


Processing user 16 task geo-based


100%|██████████| 44/44 [00:05<00:00,  8.63it/s]


Processing user 17 task geo-based


100%|██████████| 50/50 [00:05<00:00,  8.70it/s]


Processing user 13 task geo-based


100%|██████████| 31/31 [00:03<00:00,  8.12it/s]


Processing user 12 task geo-based


100%|██████████| 34/34 [00:04<00:00,  8.23it/s]


Processing user 10 task geo-based


100%|██████████| 34/34 [00:04<00:00,  8.03it/s]


Processing user 11 task geo-based


100%|██████████| 28/28 [00:03<00:00,  8.40it/s]


Processing user 9 task geo-based


100%|██████████| 29/29 [00:03<00:00,  8.29it/s]


Processing user 8 task geo-based


100%|██████████| 41/41 [00:04<00:00,  8.64it/s]


Processing user 5 task geo-based


100%|██████████| 22/22 [00:02<00:00,  9.02it/s]


Processing user 4 task geo-based


100%|██████████| 32/32 [00:03<00:00,  8.66it/s]


Processing user 6 task geo-based


100%|██████████| 31/31 [00:03<00:00,  8.17it/s]


Processing user 7 task geo-based


100%|██████████| 35/35 [00:04<00:00,  7.87it/s]


Processing user 3 task geo-based


100%|██████████| 32/32 [00:03<00:00,  8.10it/s]


Processing user 2 task geo-based


100%|██████████| 33/33 [00:04<00:00,  7.88it/s]


Processing user 1 task geo-based


100%|██████████| 28/28 [00:04<00:00,  6.95it/s]


Processing user 20 task geo-based


100%|██████████| 46/46 [00:06<00:00,  7.02it/s]


Processing user 21 task geo-based


100%|██████████| 50/50 [00:06<00:00,  8.29it/s]


Processing user 23 task geo-based


100%|██████████| 47/47 [00:04<00:00, 10.18it/s]


Processing user 22 task geo-based


100%|██████████| 49/49 [00:04<00:00, 10.45it/s]


Processing user 26 task geo-based


100%|██████████| 46/46 [00:04<00:00, 10.44it/s]


Processing user 27 task geo-based


100%|██████████| 46/46 [00:04<00:00, 10.48it/s]


Processing user 25 task geo-based


100%|██████████| 57/57 [00:05<00:00, 10.40it/s]


Processing user 19 task geo-based


100%|██████████| 47/47 [00:04<00:00, 10.38it/s]


Processing user 18 task geo-based


100%|██████████| 50/50 [00:04<00:00, 10.44it/s]


Processing user 24 task geo-based


100%|██████████| 48/48 [00:04<00:00, 10.34it/s]


Processing user 15 task mixed


100%|██████████| 34/34 [00:03<00:00, 10.14it/s]


Processing user 14 task mixed


100%|██████████| 39/39 [00:03<00:00, 10.36it/s]


Processing user 16 task mixed


100%|██████████| 38/38 [00:03<00:00, 10.17it/s]


Processing user 17 task mixed


100%|██████████| 35/35 [00:03<00:00,  9.46it/s]


Processing user 13 task mixed


100%|██████████| 57/57 [00:05<00:00,  9.88it/s]


Processing user 12 task mixed


100%|██████████| 121/121 [00:11<00:00, 10.17it/s]


Processing user 10 task mixed


100%|██████████| 99/99 [00:09<00:00, 10.06it/s]


Processing user 11 task mixed


100%|██████████| 96/96 [00:08<00:00, 10.72it/s]


Processing user 9 task mixed


100%|██████████| 89/89 [00:08<00:00, 10.88it/s]


Processing user 8 task mixed


100%|██████████| 94/94 [00:08<00:00, 11.00it/s]


Processing user 5 task mixed


100%|██████████| 101/101 [00:09<00:00, 10.97it/s]


Processing user 4 task mixed


100%|██████████| 118/118 [00:10<00:00, 11.05it/s]


Processing user 6 task mixed


100%|██████████| 99/99 [00:08<00:00, 11.05it/s]


Processing user 7 task mixed


100%|██████████| 108/108 [00:09<00:00, 11.07it/s]


Processing user 3 task mixed


100%|██████████| 87/87 [00:07<00:00, 11.07it/s]


Processing user 2 task mixed


100%|██████████| 90/90 [00:08<00:00, 11.08it/s]


Processing user 1 task mixed


100%|██████████| 75/75 [00:06<00:00, 10.92it/s]


Processing user 20 task mixed


100%|██████████| 47/47 [00:04<00:00, 11.05it/s]


Processing user 21 task mixed


100%|██████████| 43/43 [00:03<00:00, 11.13it/s]


Processing user 23 task mixed


100%|██████████| 17/17 [00:01<00:00, 11.11it/s]


Processing user 22 task mixed


100%|██████████| 37/37 [00:03<00:00, 11.10it/s]


Processing user 26 task mixed


100%|██████████| 38/38 [00:03<00:00, 11.12it/s]


Processing user 27 task mixed


100%|██████████| 43/43 [00:03<00:00, 11.21it/s]


Processing user 25 task mixed


100%|██████████| 39/39 [00:03<00:00, 11.42it/s]


Processing user 19 task mixed


100%|██████████| 38/38 [00:03<00:00, 11.60it/s]


Processing user 18 task mixed


100%|██████████| 40/40 [00:03<00:00, 11.56it/s]


Processing user 24 task mixed


100%|██████████| 48/48 [00:04<00:00, 11.57it/s]


Processing user 15 task type-based


100%|██████████| 10/10 [00:00<00:00, 11.66it/s]


Processing user 14 task type-based


100%|██████████| 6/6 [00:00<00:00, 11.68it/s]


Processing user 16 task type-based


100%|██████████| 19/19 [00:01<00:00, 11.44it/s]


Processing user 17 task type-based


100%|██████████| 11/11 [00:01<00:00, 10.99it/s]


Processing user 13 task type-based


100%|██████████| 14/14 [00:01<00:00, 11.38it/s]


Processing user 12 task type-based


100%|██████████| 11/11 [00:00<00:00, 11.36it/s]


Processing user 10 task type-based


100%|██████████| 11/11 [00:00<00:00, 11.45it/s]


Processing user 11 task type-based


100%|██████████| 5/5 [00:00<00:00, 11.81it/s]


Processing user 9 task type-based


100%|██████████| 8/8 [00:00<00:00, 11.65it/s]


Processing user 8 task type-based


100%|██████████| 8/8 [00:00<00:00, 11.69it/s]


Processing user 5 task type-based


100%|██████████| 8/8 [00:00<00:00, 11.70it/s]


Processing user 4 task type-based


100%|██████████| 6/6 [00:00<00:00, 11.55it/s]


Processing user 6 task type-based


100%|██████████| 12/12 [00:01<00:00, 11.49it/s]


Processing user 7 task type-based


100%|██████████| 6/6 [00:00<00:00, 11.72it/s]


Processing user 3 task type-based


100%|██████████| 8/8 [00:00<00:00, 11.51it/s]


Processing user 2 task type-based


100%|██████████| 12/12 [00:01<00:00, 11.59it/s]


Processing user 1 task type-based


100%|██████████| 6/6 [00:00<00:00, 11.54it/s]


Processing user 20 task type-based


100%|██████████| 13/13 [00:01<00:00, 11.44it/s]


Processing user 21 task type-based


100%|██████████| 13/13 [00:01<00:00, 11.67it/s]


Processing user 23 task type-based


100%|██████████| 16/16 [00:01<00:00, 11.59it/s]


Processing user 22 task type-based


100%|██████████| 21/21 [00:01<00:00, 11.65it/s]


Processing user 19 task type-based


100%|██████████| 11/11 [00:00<00:00, 11.68it/s]


Processing user 18 task type-based


100%|██████████| 12/12 [00:01<00:00, 11.56it/s]


In [11]:
# Display the per-session results table: participant_id, task, the ncp-k
# top-k hit rates, and the list of per-step ranks.
stl_map_results

Unnamed: 0,participant_id,task,ncp-1,ncp-5,ncp-10,ncp-20,ncp-50,ncp-100,rank
0,15,geo-based,0.0,0.037736,0.150943,0.188679,0.245283,0.471698,"[105, 166, 32, 225, 10, 11, 35, 28, 11, 7, 10,..."
1,14,geo-based,0.0,0.000000,0.044444,0.177778,0.288889,0.400000,"[7, 311, 20, 15, 13, 11, 188, 27, 11, 128, 31,..."
2,28,geo-based,0.0,0.021739,0.043478,0.217391,0.369565,0.586957,"[4, 89, 10, 15, 14, 182, 21, 1, 50, 56, 57, 51..."
3,16,geo-based,0.0,0.046512,0.093023,0.093023,0.186047,0.395349,"[37, 4, 491, 31, 23, 9, 41, 37, 10, 230, 48, 4..."
4,17,geo-based,0.0,0.020408,0.061224,0.102041,0.306122,0.489796,"[61, 249, 184, 0, 26, 4, 1, 154, 54, 48, 48, 2..."
...,...,...,...,...,...,...,...,...,...
73,21,type-based,0.0,0.166667,0.583333,0.916667,0.916667,1.000000,"[7, 53, 3, 9, 4, 13, 7, 9, 8, 10, 13, 12]"
74,23,type-based,0.0,0.333333,0.733333,0.866667,0.933333,1.000000,"[3, 85, 3, 5, 4, 11, 6, 4, 2, 9, 2, 8, 13, 12, 8]"
75,22,type-based,0.1,0.350000,0.750000,0.950000,0.950000,0.950000,"[10, 1, 4, 4, 5, 5, 148, 8, 12, 10, 2, 2, 1, 5..."
76,19,type-based,0.1,0.400000,0.600000,0.900000,0.900000,0.900000,"[1, 2, 3, 4, 161, 6, 11, 11, 10, 6]"
