-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_epsilon_utility_efficiency.py
102 lines (90 loc) · 4.05 KB
/
test_epsilon_utility_efficiency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import math
import pandas as pd
import numpy as np
from plotting_tools import draw_epsilon_MSE, draw_epsilon_construction_time, draw_epsilon_query_time
from PM import piecewise_l1_l2_kernel_kde
from GI import gi_l2kernel_kde
from RACE import count_race_l2
from DM import duchi_l1_l2_kernel_kde
from SW import square_wave_l1_l2_kernel_kde
from mLDP_KDE import mldp_kde_l2kernel_kde
from evaluation import MSE
from kde_tools import l2kernel_kde
from parameters import dataset_parameters
''' Select dataset '''
datasets = ['CodRNA', 'CovType', 'RCV1', 'Yelp', 'SYN']
# Index into `datasets` -> 0: CodRNA, 1: CovType, 2: RCV1, 3: Yelp, 4: SYN
selected_flag = 0
# Picks the `L_R_set_<nearest_flag>nearest` entry below; the mLDP-KDE section
# also uses log10 of this value to index `r_set`.
nearest_flag = 100

''' Initialize '''
# Per-dataset experiment settings, keyed by dataset name.
params = dataset_parameters[datasets[selected_flag]]
r_set, m = params['r_set'], params['m']
# NOTE(review): `n` is not read again in the visible part of this script;
# the reason for subtracting 100 is not shown here -- confirm before removing.
n = params['n'] - 100
omega = params['omega']
seed_l2lsh, seed_grr_rehash = params['seed_l2lsh'], params['seed_grr_rehash']
L_R_set = params[f'L_R_set_{nearest_flag}nearest']

# Header-less, comma-separated matrices: one file used to build the
# estimators ("const") and one holding the query points.
const_file = f"small_datasets/{datasets[selected_flag]}_const.csv"
query_file = f"small_datasets/{datasets[selected_flag]}_query.csv"
const_data = pd.read_csv(const_file, sep=',', lineterminator='\n', header=None).values
query_data = pd.read_csv(query_file, sep=',', lineterminator='\n', header=None).values

# Number of rows in the construction data.
N = const_data.shape[0]

# Epsilon values swept by every epsilon-dependent method below.
epsilon = [1, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20]

''' accurate kde values'''
# Reference KDE values that every estimator's output is compared against.
acc_kde_vals = l2kernel_kde(query_data, const_data, N, omega)
''' RACE '''
# Run RACE once per LSH seed, keeping each (estimate, construction time,
# query time) triple. RACE takes no epsilon argument here, so it is run
# outside the epsilon sweep.
# NOTE(review): the trailing positional arguments 1000 and 100 are defined by
# count_race_l2's signature, which is not visible here -- confirm their meaning.
race_runs = []
for lsh_seed in seed_l2lsh:
    estimate, build_time, query_time = count_race_l2(
        query_data, const_data, m, omega, N, lsh_seed, 1000, 100
    )
    race_runs.append((estimate, build_time, query_time))

race_kde_value_set = [run[0] for run in race_runs]
race_ctime_set = [run[1] for run in race_runs]
race_qtime_set = [run[2] for run in race_runs]

# One MSE per seed against the reference KDE, then average over seeds.
mse_set = [MSE(acc_kde_vals, seed_estimate) for seed_estimate in race_kde_value_set]
race_mse = np.average(mse_set)
race_ctime = np.average(race_ctime_set)
race_qtime = np.average(race_qtime_set)
''' DM-KDE, PM-KDE, SW-KDE, GI-KDE '''
def calc_kde_values(epsilon, kde_function, *args):
    """Run one KDE estimator for every value in ``epsilon`` and score it.

    Args:
        epsilon: Iterable of epsilon values; each is passed as the first
            positional argument to ``kde_function``.
        kde_function: Callable invoked as ``kde_function(e, *args)`` that
            returns a 3-tuple ``(kde_values, construction_time, query_time)``.
        *args: Extra positional arguments forwarded unchanged to
            ``kde_function``.

    Returns:
        A tuple ``(mse_vals, ctime_vals, qtime_vals)`` of three lists, each
        with one entry per epsilon value in iteration order. ``mse_vals``
        holds ``MSE(acc_kde_vals, kde_values)``, using the module-level
        ``acc_kde_vals`` as the reference.
    """
    # Lazy generator: each estimator run is scored right after it finishes,
    # so the kde_function/MSE calls stay interleaved per epsilon.
    raw_runs = (kde_function(eps, *args) for eps in epsilon)
    scored = [
        (MSE(acc_kde_vals, estimate), build_time, query_time)
        for estimate, build_time, query_time in raw_runs
    ]
    mse_vals = [row[0] for row in scored]
    ctime_vals = [row[1] for row in scored]
    qtime_vals = [row[2] for row in scored]
    return mse_vals, ctime_vals, qtime_vals
# The four baselines below are called with the same positional arguments
# after epsilon, so bundle those once.
_baseline_args = (query_data, const_data, m, N, omega)

# DM-KDE
dm_mse, dm_ctime, dm_qtime = calc_kde_values(epsilon, duchi_l1_l2_kernel_kde, *_baseline_args)
# PM-KDE
pm_mse, pm_ctime, pm_qtime = calc_kde_values(epsilon, piecewise_l1_l2_kernel_kde, *_baseline_args)
# SW-KDE
sw_mse, sw_ctime, sw_qtime = calc_kde_values(epsilon, square_wave_l1_l2_kernel_kde, *_baseline_args)
# GI-KDE
gi_mse, gi_ctime, gi_qtime = calc_kde_values(epsilon, gi_l2kernel_kde, *_baseline_args)
''' mLDP-KDE '''
# For each epsilon, run mLDP-KDE once per (L2-LSH seed, GRR-rehash seed) pair
# and record the mean MSE, construction time and query time over those runs.
mldp_kde_mse = []
mldp_kde_ctime = []
mldp_kde_qtime = []

# Loop-invariant: r_set is indexed by log10(nearest_flag)
# (index 2 for nearest_flag = 100), so look it up once.
r = r_set[int(math.log10(nearest_flag))]

# zip() stops at the shorter of the two seed lists. Materialise the pairs so
# the averages below divide by the number of runs actually executed rather
# than by len(seed_l2lsh), which would be wrong if the lists differ in length.
seed_pairs = list(zip(seed_l2lsh, seed_grr_rehash))
num_runs = len(seed_pairs)

for index, e in enumerate(epsilon):
    # L_R_set holds one (L, R) setting per epsilon, in the same order as `epsilon`.
    L = L_R_set[index][0]
    R = L_R_set[index][1]
    mldp_kde_mse_sum = 0
    mldp_kde_ctime_sum = 0
    mldp_kde_qtime_sum = 0
    for temp_seed_l2lsh, temp_seed_grr_rehash in seed_pairs:
        # The fourth return value is not used by this experiment.
        l2lsh_race_kde, ctime, qtime, _ = mldp_kde_l2kernel_kde(
            query_data, e, const_data, L, R, m, omega, N, r,
            temp_seed_l2lsh, temp_seed_grr_rehash,
        )
        mldp_kde_mse_sum += MSE(acc_kde_vals, l2lsh_race_kde)
        mldp_kde_ctime_sum += ctime
        mldp_kde_qtime_sum += qtime
    mldp_kde_mse.append(mldp_kde_mse_sum / num_runs)
    mldp_kde_ctime.append(mldp_kde_ctime_sum / num_runs)
    mldp_kde_qtime.append(mldp_kde_qtime_sum / num_runs)
# Plot MSE, construction time and query time against epsilon for the
# selected dataset.
# NOTE(review): draw_epsilon_MSE is given the series in a different order
# (RACE, PM, DM, SW, GI, mLDP-KDE) than the two timing plots
# (mLDP-KDE, RACE, GI, PM, DM, SW) -- confirm against plotting_tools.
dataset_label = datasets[selected_flag]
draw_epsilon_MSE(
    epsilon, race_mse, pm_mse, dm_mse, sw_mse, gi_mse, mldp_kde_mse,
    dataset_label,
)
draw_epsilon_construction_time(
    epsilon, mldp_kde_ctime, race_ctime, gi_ctime, pm_ctime, dm_ctime, sw_ctime,
    dataset_label,
)
draw_epsilon_query_time(
    epsilon, mldp_kde_qtime, race_qtime, gi_qtime, pm_qtime, dm_qtime, sw_qtime,
    dataset_label,
)