# Individual SHAP Values

This notebook demonstrates the usage of the `estimate_individual_shapley_values()` function. It is based on the algorithm described in the [interpretable-ml-book](https://christophm.github.io/interpretable-ml-book/shapley.html#estimating-the-shapley-value) and on the implementation presented in [this Medium article](https://medium.com/mlearning-ai/machine-learning-interpretability-shapley-values-with-pyspark-16ffd87227e3).
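
To make the idea concrete, the sampling estimator described in the linked book chapter can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the toolbox's implementation: `monte_carlo_shapley`, `predict`, `background`, and the toy `linear_model` are hypothetical names introduced only for this example.

```python
import random
from typing import Callable, List, Sequence


def monte_carlo_shapley(
    predict: Callable[[Sequence[float]], float],
    x: Sequence[float],
    background: Sequence[Sequence[float]],
    feature: int,
    n_iter: int = 1000,
    seed: int = 0,
) -> float:
    """Approximate the Shapley value of one feature for one instance.

    Each iteration draws a random background row and a random feature
    ordering, then measures how much the prediction changes when the
    feature of interest takes its value from x instead of the background row.
    """
    rng = random.Random(seed)
    n_features = len(x)
    total = 0.0
    for _ in range(n_iter):
        z = rng.choice(background)        # random instance from the data
        order = list(range(n_features))
        rng.shuffle(order)                # random feature ordering
        position = order.index(feature)
        preceding = set(order[:position])  # features ordered before `feature`

        # x_plus: preceding features and `feature` come from x, the rest from z.
        x_plus: List[float] = [
            x[k] if (k in preceding or k == feature) else z[k]
            for k in range(n_features)
        ]
        # x_minus: identical, except `feature` also comes from z.
        x_minus: List[float] = [
            x[k] if k in preceding else z[k] for k in range(n_features)
        ]
        total += predict(x_plus) - predict(x_minus)
    return total / n_iter


def linear_model(v: Sequence[float]) -> float:
    # Hypothetical toy model, used only to exercise the estimator.
    return 2.0 * v[0] + 3.0 * v[1] - v[2]


phi = monte_carlo_shapley(linear_model, x=[1.0, 2.0, 3.0],
                          background=[[0.0, 0.0, 0.0]], feature=1)
```

For a linear model the marginal contribution does not depend on the feature ordering, so the toy case is easy to check by hand: the estimate equals the feature's weight times the gap between its value in `x` and its average value in `background`.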

## Session Setup

In [1]:
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession

from pyspark_ds_toolbox.ml.shap_values import estimate_shap_values



## Reading the Base Dataset

In [2]:
def read_data(file):
    """Load a Stata file from the mixtape repository into a pandas DataFrame."""
    return pd.read_stata("https://raw.github.com/scunning1975/mixtape/master/" + file)

# Stack the two mixtape samples into a single DataFrame.
df = read_data('nsw_mixtape.dta')
df = pd.concat((df, read_data('cps_mixtape.dta')))
df['treat'] = df['treat'].astype(int)
# Add a unique row identifier, since the concatenated index contains duplicates.
df['id'] = np.arange(len(df))

df.head()

| | data_id | treat | age | educ | black | hisp | marr | nodegree | re74 | re75 | re78 | id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Dehejia-Wahba Sample | 1 | 37.0 | 11.0 | 1.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 9930.045898 | 0 |
| 1 | Dehejia-Wahba Sample | 1 | 22.0 | 9.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 3595.894043 | 1 |
| 2 | Dehejia-Wahba Sample | 1 | 30.0 | 12.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 24909.449219 | 2 |
| 3 | Dehejia-Wahba Sample | 1 | 27.0 | 11.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 7506.145996 | 3 |
| 4 | Dehejia-Wahba Sample | 1 | 33.0 | 8.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 289.789886 | 4 |


In [3]:
spark = SparkSession.builder\
                .appName('Spark-Toolbox') \
                .master('local[1]') \
                .config('spark.executor.memory', '3G') \
                .config('spark.driver.memory', '3G') \
                .config('spark.memory.offHeap.enabled', 'true') \
                .config('spark.memory.offHeap.size', '3G') \
                .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")
spark.conf.set("spark.sql.execution.arrow.enabled", "true")



In [4]:
sdf = spark.createDataFrame(df)
sdf.show(3)



+--------------------+-----+----+----+-----+----+----+--------+----+----+--------+---+
|             data_id|treat| age|educ|black|hisp|marr|nodegree|re74|re75|    re78| id|
+--------------------+-----+----+----+-----+----+----+--------+----+----+--------+---+
|Dehejia-Wahba Sample|    1|37.0|11.0|  1.0| 0.0| 1.0|     1.0| 0.0| 0.0|9930.046|  0|
|Dehejia-Wahba Sample|    1|22.0| 9.0|  0.0| 1.0| 0.0|     1.0| 0.0| 0.0|3595.894|  1|
|Dehejia-Wahba Sample|    1|30.0|12.0|  1.0| 0.0| 0.0|     0.0| 0.0| 0.0|24909.45|  2|
+--------------------+-----+----+----+-----+----+----+--------+----+----+--------+---+
only showing top 3 rows





## Regression

In [5]:
shap_values = estimate_shap_values(
    sdf=sdf,
    id_col='id',
    target_col='re78',
    cat_features=['data_id'],
    sort_metric='rmse',
    problem_type='regression',
    subset_size=1000,
    max_mem_size='2G',
    max_models=8,
    max_runtime_secs=15,
    nfolds=5,
    seed=90
)

In [6]:
shap_values.show(20)

Checking whether there is an H2O instance running at http://localhost:54321 . connected.
--------------------------  ------------------------------------------------------------------
H2O_cluster_uptime:         1 day 4 hours 46 mins
H2O_cluster_timezone:       America/Sao_Paulo
H2O_data_parsing_timezone:  UTC
H2O_cluster_version:        3.34.0.3
H2O_cluster_version_age:    2 months and 7 days
H2O_cluster_name:           H2O_from_python_CBSSDIGITAL_07_000504_cdq2v9
H2O_cluster_total_nodes:    1
H2O_cluster_free_memory:    731 Mb
H2O_cluster_total_cores:    2
H2O_cluster_allowed_cores:  2
H2O_cluster_status:         locked, healthy
H2O_connection_url:         http://localhost:54321
H2O_connection_proxy:       {"http": null, "https": null}
H2O_internal_security:      False
H2O_API_Extensions:         Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4
Python_version:             3.7.10 final
--------------------------  ------------------------------------------------------

+-----+--------+----------+
|   id|variable|shap_value|
+-----+--------+----------+
| 9867| data_id|       0.0|
| 4726| data_id|       0.0|
| 5161| data_id|      -0.0|
|11229| data_id|       0.0|
| 7084| data_id|      -0.0|
| 8062| data_id|       0.0|
| 1931| data_id|      -0.0|
| 2968| data_id|       0.0|
|14679| data_id|       0.0|
| 3297| data_id|       0.0|
|10615| data_id|       0.0|
|   12| data_id|      -0.0|
|  668| data_id|       0.0|
| 8825| data_id|      -0.0|
|13684| data_id|       0.0|
|10169| data_id|       0.0|
| 7341| data_id|       0.0|
| 5887| data_id|      -0.0|
| 6552| data_id|      -0.0|
|12638| data_id|      -0.0|
+-----+--------+----------+
only showing top 20 rows
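
The result is in long format, with one row per `id` and `variable`. One common way to read such a table is to average the absolute SHAP value per variable as a rough global importance ranking. The sketch below does this in plain Python on rows shaped like `(id, variable, shap_value)`; `mean_abs_shap` and the toy rows are hypothetical and are not taken from the output above.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def mean_abs_shap(rows: Iterable[Tuple[int, str, float]]) -> List[Tuple[str, float]]:
    """Return (variable, mean absolute SHAP value) pairs, largest first."""
    sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for _id, variable, value in rows:
        sums[variable] += abs(value)
        counts[variable] += 1
    means = {variable: sums[variable] / counts[variable] for variable in sums}
    # Sort by importance, breaking ties alphabetically for a stable result.
    return sorted(means.items(), key=lambda item: (-item[1], item[0]))


# Made-up rows for illustration only.
toy_rows = [(1, "age", 0.5), (2, "age", -1.5), (1, "educ", 0.0), (2, "educ", 0.0)]
ranking = mean_abs_shap(toy_rows)  # → [('age', 1.0), ('educ', 0.0)]
```

On the full Spark DataFrame the same aggregation is usually better done without collecting the rows, for example by grouping on `variable` and taking the mean of the absolute `shap_value`.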




## Classification

In [8]:
shap_values_classification = estimate_shap_values(
    sdf=sdf,
    id_col='id',
    target_col='treat',
    cat_features=['data_id'],
    sort_metric='aucpr',
    problem_type='classification',
    subset_size=1000,
    max_mem_size='2G',
    max_models=8,
    max_runtime_secs=15,
    nfolds=5,
    seed=90
)

In [10]:
shap_values_classification.show(20)


+-----+--------+-----------+
|   id|variable| shap_value|
+-----+--------+-----------+
| 1512| data_id| 0.27528387|
|11332| data_id| 0.27509025|
| 4248| data_id| 0.27486047|
| 4676| data_id| 0.27495986|
|10536| data_id| 0.27442735|
|12749| data_id|  0.2752797|
|12705| data_id|   0.274982|
| 6138| data_id| 0.57456595|
| 9056| data_id| 0.38522583|
| 2410| data_id| 0.37190032|
| 3088| data_id| 0.27481568|
| 4666| data_id|  0.2752847|
|12412| data_id|  0.2742324|
| 7740| data_id| 0.27516872|
|  824| data_id|  0.2751758|
|  274| data_id|  1.1417135|
| 3164| data_id| 0.27519616|
| 6195| data_id| 0.27522728|
|11439| data_id|  0.5480533|
| 1792| data_id| 0.23286474|
| 1139| data_id| 0.27528587|
| 8773| data_id|  0.2756709|
|  345| data_id|  1.0180364|
| 6134| data_id| 0.27568513|
| 5064| data_id| 0.27498662|
|13394| data_id| 0.27528393|
|12820| data_id| 0.26185644|
| 1304| data_id| 0.27517202|
|12305| data_id| 0.27295837|
| 8457| data_id| 0.27570117|
|10965| data_id|  0.2752859|
| 7330| data_i


In [15]:
shap_values.count()

180807