In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')

import pandas as pd
import pprint

from privex.components.basic import Schema, Dataset, Predicate, GroupbyQuery, Question
from privex.components.utils import generate_explanation_predicates
from privex.framework.solution import ExplanationSession

import logging
logger = logging.getLogger(__name__)

In [5]:
# Analysis settings: the confidence level used for intervals (the Phase 2 cell
# renders it as a percentage) and the attributes explanations may range over.
gamma = 0.95
attributes = ['PU_Zone', 'PU_Hour', 'PU_WeekDay',
              'DO_Zone', 'DO_Hour', 'DO_WeekDay']

# Load the taxi trips together with their schema description.
df = pd.read_csv('../data/taxi/taxi_speed.csv')
schema = Schema.from_json('../data/taxi/taxi.json')
dataset = Dataset(df, schema)

# Candidate explanation predicates over `attributes`; presumably one equality
# predicate per attribute value under the '1-way marginal' strategy (the Phase 3
# log reports 547 of them) -- confirm against privex docs.
predicates = generate_explanation_predicates(attributes, schema, strategy='1-way marginal')
es = ExplanationSession(dataset, gamma, predicates)

In [6]:
# Phase 1: submit a group-by query -- AVG(trip_speed) per pickup zone, restricted
# to rows with PU_Borough == "Manhattan" -- and display every group's answer.
groupby_query = GroupbyQuery(
    agg='AVG',
    attr_agg='trip_speed',
    predicate=Predicate('PU_Borough == "Manhattan"'),
    attr_group='PU_Zone',
    schema=schema,
)
# NOTE(review): presumably the privacy budget charged for this query
# (rho, as in zCDP) -- confirm against the privex documentation.
rho_query = 0.1
es.phase_1_submit_query(groupby_query, rho_query)
print(f'submitted queries with rho = {rho_query}')

# Show all rows and full group labels instead of pandas' truncated view.
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.max_colwidth', 1000,
):
    print(es.phase_1_show_query_results())

submiited queries with rho = 0.1
                                                group         answer
80                            (Central Harlem North,)   -4319.896087
123                     (Spuyten Duyvil/Kingsbridge,)   -3480.387191
35                                  (West Concourse,)   -2150.345142
162                          (Soundview/Castle Hill,)   -1652.713143
252                                  (Arden Heights,)   -1527.558107
150                      (Springfield Gardens North,)   -1425.793523
233                             (Van Cortlandt Park,)   -1394.898351
31                               (LaGuardia Airport,)   -1232.355116
83                                (Prospect Heights,)   -1109.785122
38                                 (Lenox Hill West,)   -1032.422344
111                                   (North Corona,)    -647.273940
250                                    (Westerleigh,)    -549.418110
194                                     (Mount Hope,)    -530.148910
1

In [7]:
# Phase 2
def weight_mapping(group):
    """Return the question weight for one group key of the Phase 1 query.

    ('Financial District North',) is weighted +1 and ('SoHo',) is weighted -1,
    so the resulting question compares those two zones; any other key gets 0.
    Keys are compared with ``==``, so they must be tuples to match.
    """
    for target, weight in ((('SoHo',), -1), (('Financial District North',), 1)):
        if group == target:
            return weight
    return 0
# Weight every group of the Phase 1 query (groups not being compared get 0)
# and turn the weights into a comparison question.
group_weights = dict(
    (group, weight_mapping(group)) for group in groupby_query.groups
)
question = Question.from_group_weights(groupby_query, group_weights)

# Register the question with the session, then compute and fetch its
# gamma-level confidence interval.
es.phase_2_submit_question(question)
es.phase_2_prepare_question_ci()
ci = es.phase_2_show_question_ci()

print('question: ', question.to_natural_language())
print(f'The {gamma*100:.0f}% confidence interval of the difference is ', ci)

question:  Why AVG(trip_speed) WHERE (PU_Borough == "Manhattan") and (`PU_Zone` == "Financial District North") >= AVG(trip_speed) WHERE (PU_Borough == "Manhattan") and (`PU_Zone` == "SoHo")?
The 95% confidence interval of the difference is  (4.401676510348343, 4.426830840709455)


In [10]:
# Phase 3: request explanations for the Phase 2 question and display the
# top-k explanation predicates.
k = 5  # number of explanation predicates reported (5 rows in each table)
# Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
logger.debug('Length of predicates is %d', len(predicates))
# Total rho for the explanation step. In the recorded run the session logged a
# split of rho_topk=0.05, rho_ci=0.05 and rho_rank=1.9 out of this 2.0.
rho_expl = 2.0
es.phase_3_submit_explanation_request()
# Slow: roughly 1.5 minutes for 547 predicates in the recorded run (log timestamps).
es.phase_3_prepare_explanation(k, rho_expl)

# Show full tables without truncating rows, columns, cell text, or wrapping.
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.max_colwidth', 1000,
    'display.width', None,
    'display.expand_frame_repr', False,
):
    t1, t2 = es.phase_3_show_explanation_table()
    print(t1)
    print(t2)

2022-04-09 18:05:54,324 DEBUG    [<ipython-input-10-e94347c3c07d>:3] Length of predicates is 547
2022-04-09 18:07:09,490 DEBUG    [meta_explanation_session.py:143] 547 predicates and their influences & scores have been loaded.
2022-04-09 18:07:09,497 DEBUG    [meta_explanation_session.py:147] 
        influence         score
count  547.000000    547.000000
mean    -0.000037     -4.169963
std      0.028455   3167.496555
min     -0.246512 -27441.011196
25%      0.000000      0.000000
50%      0.000000      0.000000
75%      0.000278     30.905995
max      0.162620  18102.379998
2022-04-09 18:07:09,497 DEBUG    [meta_explanation_session.py:185] total rho_expl is 2.0
2022-04-09 18:07:09,497 DEBUG    [meta_explanation_session.py:186] rho_topk is 0.05
2022-04-09 18:07:09,498 DEBUG    [meta_explanation_session.py:187] rho_ci is 0.05
2022-04-09 18:07:09,498 DEBUG    [meta_explanation_session.py:188] rho_rank is 1.9
2022-04-09 18:07:09,498 INFO     [meta_explanation_session.py:193] computing to

100%|██████████| 5/5 [00:23<00:00,  4.61s/it]

2022-04-09 18:07:32,570 DEBUG    [meta_explanation_session.py:214] [(0.06439433691150634, 0.13460381909394473), (-0.04341607149777124, 0.018468472513169278), (-0.04257728064275706, 0.016147274849365397), (-0.018473810318142686, 0.048265919102153616), (0.15204410473907315, 0.20593061808693097)]
2022-04-09 18:07:32,571 INFO     [meta_explanation_session.py:217] computing rank ci



100%|██████████| 5/5 [00:00<00:00, 5253.39it/s]

2022-04-09 18:07:32,573 DEBUG    [meta_explanation_session.py:229] [(1, 513), (1, 547), (1, 547), (1, 547), (1, 539)]
                                        predicates Rel Inf 90-CI L Rel Inf 90-CI R  Rnk 95-CI L  Rnk 95-CI R
0                          `DO_Hour` == "[20, 24)"           3.45%           4.67%            1          539
1                       `DO_Zone` == "JFK Airport"           1.46%           3.05%            1          513
2  `PU_Zone` == "East Concourse/Concourse Village"          -0.42%           1.09%            1          547
3          `PU_Zone` == "Williamsbridge/Olinville"          -0.98%           0.42%            1          547
4                           `DO_Zone` == "Bayside"          -0.96%           0.37%            1          547
                                        predicates  Inf 95-CI L  Inf 95-CI R  Rnk 95-CI L  Rnk 95-CI R
0                          `DO_Hour` == "[20, 24)"     0.152044     0.205931            1          539
1                     




In [9]:
# Display the Phase 3 explanation tables again with untruncated pandas output.
# NOTE(review): this cell's execution count (9) is lower than the Phase 3 cell
# above (10), so its saved output comes from a different run and disagrees with
# the tables printed there -- rerun top to bottom or drop this duplicate cell.
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.max_colwidth', 1000,
    'display.width', None,
    'display.expand_frame_repr', False,
):
    t1, t2 = es.phase_3_show_explanation_table()
    print(t1, t2, sep='\n')

                               predicates Rel Inf 90-CI L Rel Inf 90-CI R  Rnk 95-CI L  Rnk 95-CI R
0             `DO_Zone` == "Borough Park"          -0.56%           1.46%            1          547
1             `PU_Zone` == "Forest Hills"          -0.49%           1.23%            1          547
2          `DO_Zone` == "Oakland Gardens"          -1.17%           0.78%            1          547
3             `PU_Zone` == "West Village"          -1.39%           0.64%            1          547
4  `PU_Zone` == "Bloomfield/Emerson Hill"          -2.38%          -0.36%            1          547
                               predicates  Inf 95-CI L  Inf 95-CI R  Rnk 95-CI L  Rnk 95-CI R
0             `DO_Zone` == "Borough Park"    -0.024636     0.064529            1          547
1             `PU_Zone` == "Forest Hills"    -0.021484     0.054467            1          547
2          `DO_Zone` == "Oakland Gardens"    -0.051708     0.034498            1          547
3             `PU_Zone` 