# Prediction gap test

Głównym celem tego notatnika jest przetestowanie naszego algorytmu liczącego prediction gap.

Robimy to dla zadania regresji, ponieważ dla zadania binarnej regresji logistycznej jest taki problem, że na koniec do wyników drzew przykładana jest funkcja sigmoid, co uniemożliwia szybkie analityczne obliczenie wyniku. Przy zadaniu regresji ten problem nie występuje.


## Konfiguracja

In [1]:
import os
while "notebooks" in os.getcwd():
    os.chdir("../")


In [2]:
from pathlib import Path
import pandas as pd

from src.decision_tree.tree import load_trees
from src.decision_tree.prediction_gap import (
    NormalPredictionGap,
    prediction_gap_on_single_feature_perturbation,
    prediction_gap_by_random_sampling,
    prediction_gap_by_exact_calc
)


In [3]:
models_path = Path("models")
data_path = Path("data")
wine_model_name = "winequality_red"
wine_test_data_path = data_path / "wine_quality/test_winequality_red_scaled.csv"
housing_model_name = "housing"
housing_test_data_path = data_path / "housing_data/test_housing_scaled.csv"


In [4]:
stddev = 0.3


## Wczytanie danych i modelu

In [5]:
wine_trees = load_trees(models_path, wine_model_name)


In [6]:
wine_data = pd.read_csv(wine_test_data_path)
wine_data


Unnamed: 0,fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality
0,-0.413454,0.123905,-0.313113,-0.240375,-0.349975,-0.848716,-0.561586,-0.183745,-0.201591,-0.638220,-0.678644,5
1,1.310138,-0.937525,1.638205,-0.240375,1.371576,-0.944346,-0.865676,0.982285,-1.756618,2.312434,-0.960246,5
2,-1.160343,2.526090,-1.340122,-0.382271,-0.647527,-0.083669,-0.409542,-0.989366,1.871778,-1.169337,0.729364,6
3,-0.068735,0.598756,-0.877968,-0.311323,-0.307468,0.872638,0.411500,-0.194345,-0.136798,0.542042,0.447763,6
4,-1.217796,-0.490607,0.611196,-0.027532,-0.222453,-0.944346,-0.987312,-0.634256,1.288643,0.187963,0.541630,6
...,...,...,...,...,...,...,...,...,...,...,...,...
315,0.103624,-0.714066,0.662546,2.668484,-0.796303,-1.231239,-1.108948,-0.575955,-0.201591,-0.579207,1.480302,4
316,1.195232,0.626688,-0.159061,0.185312,0.372651,1.255161,0.198638,1.618302,-0.460762,0.069937,-0.490910,5
317,0.103624,-0.323013,-0.005010,-0.453218,-0.626274,0.203223,-0.257497,-0.830361,-0.979104,1.132173,0.635497,6
318,-0.585813,0.347364,-0.056360,-0.382271,-0.158692,0.107592,1.749495,-0.480552,-0.201591,-0.815259,-0.490910,5


In [7]:
housing_trees = load_trees(models_path, housing_model_name)


In [8]:
housing_data = pd.read_csv(housing_test_data_path)
housing_data


Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value
0,1.307575,-0.862335,-0.686477,-0.108529,0.019293,0.265378,0.111059,-1.045477,77700.0
1,-1.452618,0.987002,1.856182,-0.215333,-0.272609,-0.312139,-0.338825,0.137031,314300.0
2,-0.993418,1.581599,-0.527561,-0.341390,-0.469583,-0.355408,-0.532380,-0.600266,99100.0
3,-0.878618,1.370915,-0.924851,0.445661,0.342046,0.237120,0.435395,-0.294598,109400.0
4,-1.302879,0.982320,1.141059,-0.570583,-0.581123,-0.503761,-0.561151,-0.711385,76400.0
...,...,...,...,...,...,...,...,...,...
4123,-1.123192,0.809091,0.425936,-0.367518,-0.381775,0.021655,-0.398984,-0.079942,161500.0
4124,0.259400,-0.141327,-0.845393,4.616525,5.095534,4.930548,5.368429,-0.487884,87200.0
4125,0.638740,-0.768697,1.141059,-0.336348,-0.331939,-0.240611,-0.357134,-1.017579,112900.0
4126,0.598809,-0.675060,-0.765935,0.116539,0.671918,0.686594,0.537403,-0.632217,185100.0


In [35]:
sample_housing_data = housing_data.sample(n=200, random_state=42)
sample_housing_data


Unnamed: 0,longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value
949,-0.065035,-0.567377,1.141059,-0.382186,0.045398,0.056977,0.100597,-0.629375,247900.0
3168,-0.883609,1.108731,0.267020,0.568967,0.353912,0.191202,0.385698,0.091447,129200.0
2080,-1.477574,1.075958,0.664310,-0.418857,-0.441105,-0.686553,-0.425140,0.133873,310300.0
1210,0.688653,-0.792107,1.299975,-0.350558,-0.396015,-0.097557,-0.307438,-0.265595,160800.0
2553,0.379192,-0.632923,-0.686477,0.477748,0.346792,0.403134,0.508632,0.152665,196800.0
...,...,...,...,...,...,...,...,...,...
3816,-0.119939,0.570316,-0.924851,-0.135574,-0.398388,-0.222067,-0.270819,0.189353,94400.0
1659,1.342514,-0.796789,0.664310,-0.522911,-0.258370,-0.494931,-0.412062,-1.395886,55000.0
3230,-1.168113,0.462633,1.856182,-0.648967,-0.642826,-0.827842,-0.613464,-0.065940,243800.0
4099,0.279366,0.205131,1.220517,-0.085151,0.088115,-0.084311,0.022129,-1.189809,50900.0


## Prediction gap dla każdego featura osobno

In [9]:
predgap = NormalPredictionGap(stddev)


In [10]:
%%time
single_abs_predgaps = prediction_gap_on_single_feature_perturbation(
    predgap, wine_trees, wine_data, squared=False)


Starting predgap calculation for citric_acid.
Starting predgap calculation for density.
Starting predgap calculation for total_sulfur_dioxide.
Starting predgap calculation for alcohol.
Starting predgap calculation for sulphates.
Starting predgap calculation for pH.
Starting predgap calculation for volatile_acidity.
Starting predgap calculation for residual_sugar.
Starting predgap calculation for fixed_acidity.
Starting predgap calculation for free_sulfur_dioxide.
Starting predgap calculation for chlorides.
CPU times: user 1min 37s, sys: 2.22 s, total: 1min 39s
Wall time: 1min 37s


In [11]:
single_abs_predgaps


Unnamed: 0,Feature,PredGap
10,chlorides,11.300655
8,fixed_acidity,11.295127
2,total_sulfur_dioxide,11.290612
5,pH,11.288976
9,free_sulfur_dioxide,11.288638
6,volatile_acidity,11.288001
0,citric_acid,11.287782
4,sulphates,11.285678
7,residual_sugar,11.284041
1,density,11.283049


In [12]:
%%time
single_sqr_predgaps = prediction_gap_on_single_feature_perturbation(
    predgap, wine_trees, wine_data, squared=True)


Starting predgap calculation for citric_acid.
Starting predgap calculation for density.
Starting predgap calculation for total_sulfur_dioxide.
Starting predgap calculation for alcohol.
Starting predgap calculation for sulphates.
Starting predgap calculation for pH.
Starting predgap calculation for volatile_acidity.
Starting predgap calculation for residual_sugar.
Starting predgap calculation for fixed_acidity.
Starting predgap calculation for free_sulfur_dioxide.
Starting predgap calculation for chlorides.
CPU times: user 1min 43s, sys: 1.76 s, total: 1min 45s
Wall time: 1min 43s


In [13]:
single_sqr_predgaps


Unnamed: 0,Feature,PredGap
10,chlorides,37.796323
8,fixed_acidity,37.712785
2,total_sulfur_dioxide,37.674693
6,volatile_acidity,37.654698
9,free_sulfur_dioxide,37.641536
5,pH,37.639914
0,citric_acid,37.629953
4,sulphates,37.629519
7,residual_sugar,37.60895
3,alcohol,37.590154


In [36]:
%%time
single_abs_predgaps = prediction_gap_on_single_feature_perturbation(
    predgap, housing_trees, sample_housing_data, squared=False)


Starting predgap calculation for latitude.
Starting predgap calculation for total_rooms.
Starting predgap calculation for households.
Starting predgap calculation for total_bedrooms.
Starting predgap calculation for housing_median_age.
Starting predgap calculation for population.
Starting predgap calculation for median_income.
Starting predgap calculation for longitude.
CPU times: user 1min 28s, sys: 768 ms, total: 1min 29s
Wall time: 1min 28s


In [37]:
single_abs_predgaps


Unnamed: 0,Feature,PredGap
4,housing_median_age,411572.453911
6,median_income,410685.98412
5,population,409862.009713
2,households,409086.467455
1,total_rooms,408040.844113
0,latitude,402778.569877
7,longitude,402142.112315
3,total_bedrooms,


In [38]:
%%time
single_sqr_predgaps = prediction_gap_on_single_feature_perturbation(
    predgap, housing_trees, sample_housing_data, squared=True)


Starting predgap calculation for latitude.
Starting predgap calculation for total_rooms.
Starting predgap calculation for households.
Starting predgap calculation for total_bedrooms.
Starting predgap calculation for housing_median_age.
Starting predgap calculation for population.
Starting predgap calculation for median_income.
Starting predgap calculation for longitude.
CPU times: user 1min 28s, sys: 594 ms, total: 1min 29s
Wall time: 1min 28s


In [39]:
single_sqr_predgaps


Unnamed: 0,Feature,PredGap
0,latitude,43889830000.0
7,longitude,42391730000.0
6,median_income,42211010000.0
4,housing_median_age,42171200000.0
5,population,42097570000.0
2,households,41230780000.0
1,total_rooms,41157500000.0
3,total_bedrooms,


### Tu patrzymy na błędne wyniki

In [14]:
small_test_data = wine_data[:10]
small_test_data


Unnamed: 0,fixed_acidity,volatile_acidity,citric_acid,residual_sugar,chlorides,free_sulfur_dioxide,total_sulfur_dioxide,density,pH,sulphates,alcohol,quality
0,-0.413454,0.123905,-0.313113,-0.240375,-0.349975,-0.848716,-0.561586,-0.183745,-0.201591,-0.63822,-0.678644,5
1,1.310138,-0.937525,1.638205,-0.240375,1.371576,-0.944346,-0.865676,0.982285,-1.756618,2.312434,-0.960246,5
2,-1.160343,2.52609,-1.340122,-0.382271,-0.647527,-0.083669,-0.409542,-0.989366,1.871778,-1.169337,0.729364,6
3,-0.068735,0.598756,-0.877968,-0.311323,-0.307468,0.872638,0.4115,-0.194345,-0.136798,0.542042,0.447763,6
4,-1.217796,-0.490607,0.611196,-0.027532,-0.222453,-0.944346,-0.987312,-0.634256,1.288643,0.187963,0.54163,6
5,1.597403,-0.434742,2.357111,0.469103,-0.456244,-0.944346,-0.74404,0.982285,-0.914312,0.010924,0.729364,6
6,1.310138,-0.937525,1.535504,-0.169427,-0.009916,-0.944346,-0.804858,0.00706,-1.10869,0.365003,0.635497,7
7,0.161077,-0.434742,0.200392,0.043416,-0.031169,0.490115,0.107411,0.77028,0.381544,1.486251,-0.49091,6
8,0.390889,-0.714066,0.713897,-0.382271,-0.626274,-0.083669,-0.196679,-1.381576,-0.590348,-0.343154,1.668037,6
9,-0.643266,1.101539,-1.13472,1.497846,-0.031169,-1.135608,-1.078539,-0.289747,0.640715,-1.582429,0.447763,5


In [40]:
baseline_preds = wine_trees.eval_on_multiple_rows(small_test_data)
baseline_preds


array([5.0759096, 5.015626 , 4.4223237, 5.9840064, 5.9561796, 6.2821665,
       6.762588 , 5.60282  , 6.0557013, 4.9935913], dtype=float32)

In [41]:
feature = "fixed_acidity"


In [42]:
func = predgap.prediction_gap_fixed


In [44]:
        curr_feature_total = 0.0
        for i in range(len(small_test_data)):
            x = small_test_data.iloc[i, :-1]
            y = baseline_preds[i]
            curr_datapoint_predgap = func(wine_trees, x, {feature}, y)
            print(curr_datapoint_predgap)
            curr_feature_total += curr_datapoint_predgap
        curr_feature_total /= len(small_test_data)
        

37.69552255752826
36.409101263546205
38.710570152255805
37.73896477582433
35.57383783001329
36.676068296194344
38.03687954663548
36.73549678529932
37.26499490915275
38.82277417555595


In [45]:
print(curr_feature_total)


37.366421029200566


### Tu patrzymy ile mniej więcej powinno wyjść

In [49]:
import numpy as np


In [50]:
perturbed_features = {feature}


In [51]:
rng = np.random.default_rng()


In [53]:
perturbed_df = small_test_data.copy()
perturbed_df[list(perturbed_features)] += rng.normal(loc=0.0, scale=stddev,
                                                     size=(len(small_test_data), len(perturbed_features)))


In [55]:
perturbed_preds = wine_trees.eval_on_multiple_rows(perturbed_df)
perturbed_preds


array([5.0500135, 5.0315866, 4.4223237, 5.875661 , 6.5197372, 6.262988 ,
       6.710272 , 5.6118474, 6.071811 , 4.976256 ], dtype=float32)

In [56]:
(baseline_preds - perturbed_preds) ** 2


array([6.7060656e-04, 2.5474373e-04, 0.0000000e+00, 1.1738749e-02,
       3.1759721e-01, 3.6781066e-04, 2.7369836e-03, 8.1495411e-05,
       2.5953026e-04, 3.0051661e-04], dtype=float32)

In [57]:
np.mean((baseline_preds - perturbed_preds) ** 2)


0.033400767

### Czy nadal działa dobrze dla klasyfikacji?

Nie.

In [58]:
heloc_model_name = "heloc-scaled-gbdt"
heloc_test_data_path = data_path / "heloc-scaled-test.csv"


In [59]:
heloc_trees = load_trees(models_path, wine_model_name)


In [63]:
heloc_data = pd.read_csv(heloc_test_data_path)
sample_heloc_data = heloc_data[:10]
sample_heloc_data


Unnamed: 0,ExternalRiskEstimate,MSinceOldestTradeOpen,MSinceMostRecentTradeOpen,AverageMInFile,NumSatisfactoryTrades,NumTrades60Ever2DerogPubRec,NumTrades90Ever2DerogPubRec,PercentTradesNeverDelq,MSinceMostRecentDelq,MaxDelq2PublicRecLast12M,...,MSinceMostRecentInqexcl7days,NumInqLast6M,NumInqLast6Mexcl7days,NetFractionRevolvingBurden,NetFractionInstallBurden,NumRevolvingTradesWBalance,NumInstallTradesWBalance,NumBank2NatlTradesWHighUtilization,PercentTradesWBalance,RiskPerformance
0,-0.2149,1.699313,-0.056723,2.459516,-0.363267,1.20102,1.712719,-1.564641,-1.46636,-3.498869,...,-0.48872,-0.211494,-0.187532,-1.005586,0.090931,-0.70225,-0.955119,-0.06251,-1.379531,0.0
1,1.716386,-0.361454,-0.280331,0.499039,-0.185539,-0.476838,-0.394914,0.646149,0.946849,0.75547,...,0.256357,-0.668264,-0.652592,-0.866162,-0.122189,-0.029287,-0.307994,-0.06251,0.391373,0.0
2,0.903213,0.638018,-0.578475,-0.027656,2.036062,-0.476838,-0.394914,0.646149,0.946849,0.75547,...,1.852951,1.615589,1.672709,-1.075297,1.598802,-0.365768,-0.307994,-0.740481,-1.879017,1.0
3,-0.011607,-0.309935,-0.503939,-1.139568,-0.896451,-0.476838,-0.394914,0.646149,0.946849,0.75547,...,-0.701599,0.245277,0.277528,0.946344,1.44695,0.307195,-0.307994,0.615461,0.981674,0.0
4,0.394979,0.586499,-0.503939,0.323474,1.680606,1.20102,0.658902,-0.119124,0.681985,0.147707,...,-0.062962,-0.668264,-0.652592,-0.238756,-0.096058,1.31664,-0.307994,-0.06251,-0.743822,1.0
5,-0.418194,0.48346,-0.280331,-0.203221,1.236286,0.362091,0.658902,0.391058,-1.260355,-3.498869,...,0.469236,-0.668264,-0.652592,-0.901018,-0.274042,-0.365768,0.339132,-0.06251,-1.515754,1.0
6,-0.51984,0.504068,-0.280331,0.703865,0.43651,-0.476838,-0.394914,-0.374216,-1.46636,-1.067818,...,-0.701599,-0.668264,-0.652592,1.538895,-0.425894,-0.029287,-0.307994,0.615461,0.391373,0.0
7,-0.316547,-0.474796,-0.503939,-0.788438,0.88083,-0.476838,-0.394914,0.646149,0.946849,0.75547,...,1.852951,0.245277,0.277528,0.56293,1.497567,-0.365768,0.339132,0.615461,-0.925453,0.0
8,-0.621487,0.916221,-0.280331,1.142778,1.680606,0.362091,-0.394914,-0.544276,-0.877773,-0.460055,...,-0.169401,-0.668264,-0.652592,1.538895,0.383984,2.326085,-0.307994,2.649374,0.709228,1.0
9,-0.926427,-0.392366,-0.503939,-1.578481,0.258782,-0.476838,-0.394914,0.646149,0.946849,0.75547,...,1.852951,-0.211494,-0.187532,2.305724,1.092628,0.643677,2.280509,1.971403,1.526568,0.0


In [64]:
%%time
single_sqr_predgaps = prediction_gap_on_single_feature_perturbation(
    predgap, heloc_trees, sample_heloc_data, squared=True)


Starting predgap calculation for NumInqLast6M.
Starting predgap calculation for PercentInstallTrades.
Starting predgap calculation for ExternalRiskEstimate.
Starting predgap calculation for NumTrades60Ever2DerogPubRec.
Starting predgap calculation for NumSatisfactoryTrades.
Starting predgap calculation for NumInstallTradesWBalance.
Starting predgap calculation for MaxDelqEver.
Starting predgap calculation for NumRevolvingTradesWBalance.
Starting predgap calculation for MSinceMostRecentInqexcl7days.
Starting predgap calculation for NumTradesOpeninLast12M.
Starting predgap calculation for MSinceOldestTradeOpen.
Starting predgap calculation for PercentTradesWBalance.
Starting predgap calculation for NumInqLast6Mexcl7days.
Starting predgap calculation for NumTotalTrades.
Starting predgap calculation for MaxDelq2PublicRecLast12M.
Starting predgap calculation for NumTrades90Ever2DerogPubRec.
Starting predgap calculation for NetFractionRevolvingBurden.
Starting predgap calculation for MSinceM

Wyniki powinny być z przedziału [0, 1].

In [65]:
single_sqr_predgaps


Unnamed: 0,Feature,PredGap
0,NumInqLast6M,37.386643
12,NumInqLast6Mexcl7days,37.386643
21,PercentTradesNeverDelq,37.386643
20,NumBank2NatlTradesWHighUtilization,37.386643
19,NetFractionInstallBurden,37.386643
18,AverageMInFile,37.386643
17,MSinceMostRecentTradeOpen,37.386643
16,NetFractionRevolvingBurden,37.386643
15,NumTrades90Ever2DerogPubRec,37.386643
14,MaxDelq2PublicRecLast12M,37.386643
