In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd

In [3]:
# Quick look at the raw Adult data; the session-setup cell further down reloads it.
df = pd.read_csv(filepath_or_buffer='../data/adult.csv')

In [4]:
# Per education level: mean of `high_income` (ratio; presumably the share of
# high earners if the column is 0/1 -- confirm) and row count (cnt),
# lowest ratio first.
(
    df.groupby('education')['high_income']
    .agg(ratio='mean', cnt='count')
    .sort_values('ratio')
)

Unnamed: 0_level_0,ratio,cnt
education,Unnamed: 1_level_1,Unnamed: 2_level_1
Preschool,0.012048,83
1st-4th,0.032389,247
11th,0.050773,1812
5th-6th,0.053045,509
9th,0.054233,756
10th,0.062635,1389
7th-8th,0.064921,955
12th,0.073059,657
HS-grad,0.158578,15784
Some-college,0.189649,10878


In [105]:
import sys
sys.path.append('../')

import pandas as pd
import pprint

from privex.components.basic import Schema, Dataset, GroupbyQuery, Question
from privex.components.utils import generate_explanation_predicates
from privex.framework.solution import ExplanationSession

import logging
logger = logging.getLogger(__name__)

In [106]:
# Session setup: load the data and its schema, build the candidate explanation
# predicates over the attributes below, and open an explanation session.
df = pd.read_csv('../data/adult.csv')
schema = Schema.from_json('../data/adult.json')
dataset = Dataset(df, schema)
gamma = 0.95  # confidence level of the reported intervals (printed as "95%" in Phase 2)
attributes = [
    'education', 'occupation', 'age', 'relationship',
    'race', 'workclass', 'sex', 'native-country',
]
# Only 1-way marginal predicates are used here (the Phase 3 log reports 103).
# An earlier experiment also appended strategy='2-way marginal' predicates for a
# larger candidate set; add them here if pairwise explanations are wanted.
predicates = generate_explanation_predicates(attributes, schema, strategy='1-way marginal')
es = ExplanationSession(dataset, gamma, predicates)

In [114]:
# Phase 1: answer a noisy GROUP BY query (AVG of high_income per marital status)
# and compare the noisy answers with the true, non-private group averages.
group_attr = 'marital-status'
groupby_query = GroupbyQuery(
    agg='AVG',
    attr_agg='high_income',
    predicate=None,
    attr_group=group_attr,
    schema=schema,
)
rho_query = 0.1  # rho budget spent on this query
es.phase_1_submit_query(groupby_query, rho_query, random_seed=152636)
print(f'submitted queries with rho = {rho_query}')

# Build a new frame instead of assigning into the one the session returns:
# it may be shared with `es` internally (not visible from here).
# Group keys appear to be one-element tuples (Phase 2 passes groups as e.g.
# ('Married-AF-spouse',)), so unwrap them to plain labels for the merge.
nr = es.phase_1_show_query_results()
nr = nr.assign(group=nr['group'].map(lambda key: key[0]))

# Ground truth from the raw data, for evaluation only -- it is not private,
# hence the "(hidden)" label. Naming the column up front avoids relying on
# pandas' default `_x`/`_y` merge suffixes.
gt = (
    df.groupby(group_attr)['high_income']
    .mean()
    .rename('truth (hidden)')
    .rename_axis('group')
    .reset_index()
)
print(nr.merge(gt, on='group').sort_values('truth (hidden)'))

submiited queries with rho = 0.1
                   group    answer  truth (hidden)
0          Never-married  0.045511        0.045480
1              Separated  0.064712        0.064706
2                Widowed  0.082854        0.084321
3  Married-spouse-absent  0.089988        0.092357
4               Divorced  0.101578        0.101161
6      Married-AF-spouse  0.463193        0.378378
5     Married-civ-spouse  0.446021        0.446133


In [117]:
# Phase 2: ask why one group's noisy average is >= another's, and report the
# noisy difference with its confidence interval.
# NOTE(review): in the Phase 1 output the noisy answer for Married-AF-spouse
# (0.463) exceeds Married-civ-spouse (0.446), while the hidden truth is the
# reverse (0.378 vs 0.446). This question is posed on the noisy ordering, and
# the resulting interval has a negative lower bound, i.e. it includes 0.
group_hi = ('Married-AF-spouse',)
group_lo = ('Married-civ-spouse',)
question = Question.from_group_comparison(groupby_query, group_hi, group_lo)
es.phase_2_submit_question(question)
es.phase_2_prepare_question_ci()
ci = es.phase_2_show_question_ci()
point = es.phase_2_show_question_point()
print('question: ', question.to_natural_language())
print('The noisy group difference is ', point)
print(f'The {gamma*100:.0f}% confidence interval of the difference is ', ci)

2022-04-25 21:38:27,706 INFO     [image.py:20] answers: [9982.926929253796, 22382.183312771773, 15.765188211211795, 34.03588831851936]
2022-04-25 21:38:27,707 INFO     [image.py:21] sigmas: [3.162277660168379, 3.162277660168379, 3.162277660168379, 3.162277660168379]
2022-04-25 21:38:27,708 INFO     [image.py:22] bounds: [(9975.028491030382, 9990.82536747721), (22374.284874548357, 22390.08175099519), (7.866749987797281, 23.66362643462631), (26.137450095104846, 41.93432654193388)]
2022-04-25 21:38:27,708 INFO     [image.py:23] manual ci: (-0.4598422551149035, 0.2589346986687564)
2022-04-25 21:38:27,708 INFO     [image.py:24] fun ci l: 0.4598422551149035
2022-04-25 21:38:27,709 INFO     [image.py:25] fun ci u: -0.2589346986687564
question:  Why AVG(high_income) WHERE `marital-status` == "Married-AF-spouse" >= AVG(high_income) WHERE `marital-status` == "Married-civ-spouse"?
The noisy group difference is  0.01717194739104899
The 95% confidence interval of the difference is  (-0.258934698668

In [118]:
# Phase 2: ask why one group's noisy average is >= another's, and report the
# noisy difference with its confidence interval. Submitting this question
# replaces the previous one; Phase 3 below explains this comparison.
group_hi = ('Married-civ-spouse',)
group_lo = ('Never-married',)
question = Question.from_group_comparison(groupby_query, group_hi, group_lo)
es.phase_2_submit_question(question)
es.phase_2_prepare_question_ci()
ci = es.phase_2_show_question_ci()
point = es.phase_2_show_question_point()
print('question: ', question.to_natural_language())
print('The noisy group difference is ', point)
print(f'The {gamma*100:.0f}% confidence interval of the difference is ', ci)

2022-04-25 21:39:27,605 INFO     [image.py:20] answers: [733.5750581061126, 16118.723030353407, 9982.926929253796, 22382.183312771773]
2022-04-25 21:39:27,606 INFO     [image.py:21] sigmas: [3.162277660168379, 3.162277660168379, 3.162277660168379, 3.162277660168379]
2022-04-25 21:39:27,606 INFO     [image.py:22] bounds: [(725.6766198826981, 741.473496329527), (16110.824592129993, 16126.621468576821), (9975.028491030382, 9990.82536747721), (22374.284874548357, 22390.08175099519)]
2022-04-25 21:39:27,607 INFO     [image.py:23] manual ci: (-0.4015329300800539, -0.39948772348837247)
2022-04-25 21:39:27,607 INFO     [image.py:24] fun ci l: 0.4015329300800539
2022-04-25 21:39:27,608 INFO     [image.py:25] fun ci u: 0.39948772348837247
question:  Why AVG(high_income) WHERE `marital-status` == "Married-civ-spouse" >= AVG(high_income) WHERE `marital-status` == "Never-married"?
The noisy group difference is  0.4005103977536667
The 95% confidence interval of the difference is  (0.3994877234883724

In [121]:
# Phase 3: request the top-k explanation predicates for the current Phase 2
# question, spending a rho budget of `rho_expl`, then show both result tables.
k = 5
rho_expl = 2.0
# Lazy %-style arguments: the message is only formatted if DEBUG is enabled.
logger.debug('Length of predicates is %d', len(predicates))
es.phase_3_submit_explanation_request()
# Other seeds tried during exploration: 12535, 12534.
es.phase_3_prepare_explanation(k, rho_expl, split_factor=0.9, random_seed=12532)
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.max_colwidth', 1000,
    'display.width', 1000,
):
    t1, t2 = es.phase_3_show_explanation_table()
    print(t1)
    print(t2)

2022-04-25 21:46:48,546 DEBUG    [<ipython-input-121-44759e5bb3a5>:3] Length of predicates is 103
(`marital-status` == "Never-married")or(`marital-status` == "Married-civ-spouse")
2022-04-25 21:46:48,553 INFO     [influence_function.py:24] Dataset relative to the question has length 38496
2022-04-25 21:46:51,921 DEBUG    [meta_explanation_session.py:169] 103 predicates and their influences & scores have been loaded.
2022-04-25 21:46:51,927 DEBUG    [meta_explanation_session.py:173] 
        influence       score
count  103.000000  103.000000
mean     0.000325    7.266483
std      0.006527  146.071226
min     -0.026839 -600.661181
25%     -0.000606  -13.551499
50%     -0.000017   -0.385618
75%      0.000368    8.235890
max      0.024788  554.765178
2022-04-25 21:46:51,927 DEBUG    [meta_explanation_session.py:228] total rho_expl is 2.0
2022-04-25 21:46:51,928 DEBUG    [meta_explanation_session.py:229] rho_topk is 0.05
2022-04-25 21:46:51,928 DEBUG    [meta_explanation_session.py:230] rh

  0%|          | 0/5 [00:00<?, ?it/s]

2022-04-25 21:46:52,14 INFO     [image.py:20] answers: [722.2537090314613, 16090.066967612667, 9983.092740040667, 22385.536862739926, 570.7425159572181, 14862.130647988655, 7533.290233347423, 18813.938333082, 14853.887366327654, 22378.394185048815]
2022-04-25 21:46:52,15 INFO     [image.py:21] sigmas: [22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898]
2022-04-25 21:46:52,15 INFO     [image.py:22] bounds: [(659.4865258199197, 785.0208922430029), (16027.299784401126, 16152.834150824208), (9920.325556829126, 10045.859923252208), (22322.769679528385, 22448.304045951467), (507.9753327456765, 633.5096991687597), (14799.363464777114, 14924.897831200196), (7470.523050135881, 7596.057416558965), (18751.171149870457, 18876.70551629354), (14791.120183116112, 14916.654549539195), (22315.627001837274, 22441.161368260357)]


 20%|██        | 1/5 [00:00<00:00,  7.37it/s]

2022-04-25 21:46:52,147 INFO     [image.py:20] answers: [714.952487135132, 16112.439500285365, 10006.407637080138, 22350.816775206546, 542.8807557459877, 15008.161092614846, 6586.3357587604505, 16156.944597402773, 14998.446838380078, 22368.766453377506]
2022-04-25 21:46:52,148 INFO     [image.py:21] sigmas: [22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898]
2022-04-25 21:46:52,148 INFO     [image.py:22] bounds: [(652.1853039235904, 777.7196703466735), (16049.672317073824, 16175.206683496906), (9943.640453868597, 10069.174820291679), (22288.049591995004, 22413.583958418087), (480.1135725344461, 605.6479389575293), (14945.393909403305, 15070.928275826387), (6523.5685755489085, 6649.102941971993), (16094.177414191232, 16219.711780614314), (14935.679655168537, 15061.214021591619), (22305.999270165965, 22431.533636589047)]


 40%|████      | 2/5 [00:00<00:00,  7.37it/s]

2022-04-25 21:46:52,282 INFO     [image.py:20] answers: [735.9024554116946, 16159.269727329565, 10010.713837861236, 22390.1349666724, 621.0875429000112, 9371.666433988665, 9977.673136607982, 22209.508422796538, 9408.004970579279, 22419.314631940488]
2022-04-25 21:46:52,282 INFO     [image.py:21] sigmas: [22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898]
2022-04-25 21:46:52,282 INFO     [image.py:22] bounds: [(673.135272200153, 798.6696386232362), (16096.502544118024, 16222.036910541106), (9947.946654649695, 10073.481021072777), (22327.36778346086, 22452.902149883943), (558.3203596884696, 683.8547261115527), (9308.899250777124, 9434.433617200206), (9914.905953396441, 10040.440319819523), (22146.741239584997, 22272.27560600808), (9345.237787367738, 9470.77215379082), (22356.547448728947, 22482.08181515203)]


 60%|██████    | 3/5 [00:00<00:00,  7.48it/s]

2022-04-25 21:46:52,409 INFO     [image.py:20] answers: [751.0709606117048, 16134.671690991414, 9993.622225730329, 22410.34973659599, 482.6007130785633, 14250.721422348328, 7716.128869683152, 19180.735739132382, 14263.85816210034, 22394.267978125274]
2022-04-25 21:46:52,410 INFO     [image.py:21] sigmas: [22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898]
2022-04-25 21:46:52,410 INFO     [image.py:22] bounds: [(688.3037774001632, 813.8381438232464), (16071.904507779873, 16197.438874202955), (9930.855042518788, 10056.38940894187), (22347.582553384447, 22473.11691980753), (419.8335298670217, 545.3678962901049), (14187.954239136787, 14313.488605559869), (7653.36168647161, 7778.896052894694), (19117.96855592084, 19243.502922343923), (14201.090978888798, 14326.62534531188), (22331.500794913733, 22457.035161336815)]


 80%|████████  | 4/5 [00:00<00:00,  7.21it/s]

2022-04-25 21:46:52,562 INFO     [image.py:20] answers: [719.5966867503388, 16133.512999445606, 9964.637048679951, 22388.82877164617, 475.26853064895295, 13425.161234748586, 7228.242476701742, 18229.164277836004, 13436.090666158836, 22363.02084102227]
2022-04-25 21:46:52,563 INFO     [image.py:21] sigmas: [22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898, 22.360679774997898]
2022-04-25 21:46:52,563 INFO     [image.py:22] bounds: [(656.8295035387972, 782.3638699618804), (16070.745816234064, 16196.280182657147), (9901.86986546841, 10027.404231891493), (22326.06158843463, 22451.595954857712), (412.50134743741137, 538.0357138604945), (13362.394051537045, 13487.928417960127), (7165.4752934902, 7291.009659913284), (18166.397094624463, 18291.931461047545), (13373.323482947295, 13498.857849370377), (22300.25365781073, 22425.78802423381)]


100%|██████████| 5/5 [00:00<00:00,  7.14it/s]

2022-04-25 21:46:52,633 DEBUG    [meta_explanation_session.py:269] [(0.01441871580997527, 0.037601467620713515), (0.00925922316268419, 0.03361094316693102), (-0.0003134454673210007, 0.01608551026730573), (0.008665998566900706, 0.030942732893514582), (0.012872808380316134, 0.034573826883309074)]
2022-04-25 21:46:52,633 INFO     [meta_explanation_session.py:275] computing rank ci



100%|██████████| 5/5 [00:00<00:00, 7601.13it/s]

2022-04-25 21:46:52,635 DEBUG    [meta_explanation_session.py:288] [(1, 6), (1, 9), (1, 92), (1, 16), (1, 8)]
                          predicates Rel Inf 90-CI L Rel Inf 90-CI R  Rnk 95-CI L  Rnk 95-CI R
0  `occupation` == "Exec-managerial"           3.60%           9.39%            1            6
1         `education` == "Bachelors"           3.21%           8.63%            1            8
2                `age` == "(40, 50]"           2.31%           8.39%            1            9
3   `occupation` == "Prof-specialty"           2.16%           7.73%            1           16
4      `relationship` == "Own-child"          -0.08%           4.02%            1           92
                          predicates  Inf 95-CI L  Inf 95-CI R  Rnk 95-CI L  Rnk 95-CI R
0  `occupation` == "Exec-managerial"     0.014419     0.037601            1            6
1         `education` == "Bachelors"     0.012873     0.034574            1            8
2                `age` == "(40, 50]"     0.009259    




In [122]:
# Re-display the two Phase 3 explanation tables with wide pandas display limits.
# NOTE(review): this call printed the rows in a different order than the first
# call in Phase 3 (e.g. `education == "Bachelors"` moved from row 1 to row 4),
# so the table ordering does not look stable across calls. Confirm in
# `phase_3_show_explanation_table` before relying on row position.
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.max_colwidth', 1000,
    'display.width', 1000,
):
    t1, t2 = es.phase_3_show_explanation_table()
    print(t1, t2, sep='\n')

                          predicates Rel Inf 90-CI L Rel Inf 90-CI R  Rnk 95-CI L  Rnk 95-CI R
0  `occupation` == "Exec-managerial"           3.60%           9.39%            1            6
1                `age` == "(40, 50]"           2.31%           8.39%            1            9
2      `relationship` == "Own-child"          -0.08%           4.02%            1           92
3   `occupation` == "Prof-specialty"           2.16%           7.73%            1           16
4         `education` == "Bachelors"           3.21%           8.63%            1            8
                          predicates  Inf 95-CI L  Inf 95-CI R  Rnk 95-CI L  Rnk 95-CI R
0  `occupation` == "Exec-managerial"     0.014419     0.037601            1            6
1                `age` == "(40, 50]"     0.009259     0.033611            1            9
2      `relationship` == "Own-child"    -0.000313     0.016086            1           92
3   `occupation` == "Prof-specialty"     0.008666     0.030943            

In [104]:
# Reference top-k predicates by relative influence, presumably computed without
# noise, for comparison with the noisy Phase 3 explanation.
# NOTE(review): this cell's execution count (104) precedes the cells that import
# privex and build `es` (105, 106), so the saved output came from an earlier
# session object. Re-run the notebook in order to refresh it.
k = 5
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.max_colwidth', 1000,
    'display.width', 1000,
):
    print(es.phase_3_true_top_k(k))

Positive Influences:  13
                                topk rel-influence
0  `occupation` == "Exec-managerial"         6.19%
1         `education` == "Bachelors"         6.11%
2                `age` == "(40, 50]"         5.59%
3   `occupation` == "Prof-specialty"         4.84%
4           `education` == "Masters"         2.81%


In [None]:
# For each selected top-k predicate, print its score rounded to an integer
# (trailing comma, for pasting into a list literal) and then its
# score-to-influence ratio.
for pred in es.topk_explanation_predicates:
    entry = es.predicates_with_influences_and_scores[pred]
    print(f"{entry['score']:.0f},")
    influence = entry['influence']
    # A predicate with zero influence would otherwise raise (Python float) or
    # yield inf with a warning (NumPy scalar); report NaN instead.
    print(entry['score'] / influence if influence != 0 else float('nan'))

In [138]:
# For each selected top-k predicate, print its influence and its 1-based rank
# in `es.sorted_predicates` (presumably ordered by influence -- the ranks
# printed below track the influence values).
ranked_predicates = es.sorted_predicates.tolist()  # converted once, not once per predicate
for p in es.topk_explanation_predicates:
    influ = es.predicates_with_influences_and_scores[p]['influence']
    rank = ranked_predicates.index(p) + 1
    print(influ, rank)

0.024788435119158524 1
0.02239452141248067 3
0.010037808294482172 6
0.019403818136007892 4
0.02445980262099675 2


In [137]:
# Interval bounds copied from the first `image.py:22` "bounds:" log line of the
# Phase 3 run above, typeset as LaTeX for a write-up.
interval_bounds = [(659.4865258199197, 785.0208922430029), (16027.299784401126, 16152.834150824208), (9920.325556829126, 10045.859923252208), (22322.769679528385, 22448.304045951467), (507.9753327456765, 633.5096991687597), (14799.363464777114, 14924.897831200196), (7470.523050135881, 7596.057416558965), (18751.171149870457, 18876.70551629354), (14791.120183116112, 14916.654549539195), (22315.627001837274, 22441.161368260357)]


def format_latex_intervals(bounds):
    r"""Format ``(low, high)`` pairs as LaTeX snippets ``$\I_{n} = (low, high)$, ``.

    ``n`` counts from 1 and both bounds are rounded to whole numbers. A raw
    f-string keeps ``\I`` literal instead of relying on an invalid escape
    sequence, which newer Pythons warn about.
    """
    return [
        rf'$\I_{{{n}}} = ({low:.0f}, {high:.0f})$, '
        for n, (low, high) in enumerate(bounds, start=1)
    ]


for line in format_latex_intervals(interval_bounds):
    print(line)

$\I_{1} = (659, 785)$, 
$\I_{2} = (16027, 16153)$, 
$\I_{3} = (9920, 10046)$, 
$\I_{4} = (22323, 22448)$, 
$\I_{5} = (508, 634)$, 
$\I_{6} = (14799, 14925)$, 
$\I_{7} = (7471, 7596)$, 
$\I_{8} = (18751, 18877)$, 
$\I_{9} = (14791, 14917)$, 
$\I_{10} = (22316, 22441)$, 
