# Individual Shap Values

This notebook presentd the usage of the `estimate_individual_shapley_values()` function. It is based on the algorithm described in [interpretable-ml-book](https://christophm.github.io/interpretable-ml-book/shapley.html#estimating-the-shapley-value) and the implementation presented [here](https://medium.com/mlearning-ai/machine-learning-interpretability-shapley-values-with-pyspark-16ffd87227e3).

## Session Setup

In [1]:
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession

from pyspark_ds_toolbox.ml.shap_values import H2OWrapper

## lendo o dataset base

In [2]:
def read_data(file): 
    return pd.read_stata("https://raw.github.com/scunning1975/mixtape/master/" + file)

df = read_data('nsw_mixtape.dta')
df = pd.concat((df, read_data('cps_mixtape.dta')))
df['treat'] = df['treat'].astype(int)
#df.reset_index(level=0, inplace=True)
#df.rename(columns={'index':'id'}, inplace=True)
df['id'] = np.arange(len(df))

df.head()

Unnamed: 0,data_id,treat,age,educ,black,hisp,marr,nodegree,re74,re75,re78,id
0,Dehejia-Wahba Sample,1,37.0,11.0,1.0,0.0,1.0,1.0,0.0,0.0,9930.045898,0
1,Dehejia-Wahba Sample,1,22.0,9.0,0.0,1.0,0.0,1.0,0.0,0.0,3595.894043,1
2,Dehejia-Wahba Sample,1,30.0,12.0,1.0,0.0,0.0,0.0,0.0,0.0,24909.449219,2
3,Dehejia-Wahba Sample,1,27.0,11.0,1.0,0.0,0.0,1.0,0.0,0.0,7506.145996,3
4,Dehejia-Wahba Sample,1,33.0,8.0,1.0,0.0,0.0,1.0,0.0,0.0,289.789886,4


In [3]:
spark = SparkSession.builder\
                .appName('Spark-Toolbox') \
                .master('local[1]') \
                .config('spark.executor.memory', '3G') \
                .config('spark.driver.memory', '3G') \
                .config('spark.memory.offHeap.enabled', 'true') \
                .config('spark.memory.offHeap.size', '3G') \
                .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/12/13 15:10:37 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:



sdf = spark.createDataFrame(df)
sdf.show(3)



+--------------------+-----+----+----+-----+----+----+--------+----+----+--------+---+
|             data_id|treat| age|educ|black|hisp|marr|nodegree|re74|re75|    re78| id|
+--------------------+-----+----+----+-----+----+----+--------+----+----+--------+---+
|Dehejia-Wahba Sample|    1|37.0|11.0|  1.0| 0.0| 1.0|     1.0| 0.0| 0.0|9930.046|  0|
|Dehejia-Wahba Sample|    1|22.0| 9.0|  0.0| 1.0| 0.0|     1.0| 0.0| 0.0|3595.894|  1|
|Dehejia-Wahba Sample|    1|30.0|12.0|  1.0| 0.0| 0.0|     0.0| 0.0| 0.0|24909.45|  2|
+--------------------+-----+----+----+-----+----+----+--------+----+----+--------+---+
only showing top 3 rows





## Regression

In [5]:
model = H2OWrapper(
    max_mem_size='1G',
    df=df,
    id_col='id',
    target_col='re78',
    cat_features=['data_id'],
    sort_metric='rmse',
    problem_type='regression',
    max_models=8,
    max_runtime_secs=30,
    nfolds=5,
    seed=90
)


Checking whether there is an H2O instance running at http://localhost:54321 . connected.


0,1
H2O_cluster_uptime:,4 hours 29 mins
H2O_cluster_timezone:,America/Sao_Paulo
H2O_data_parsing_timezone:,UTC
H2O_cluster_version:,3.34.0.3
H2O_cluster_version_age:,2 months and 6 days
H2O_cluster_name:,H2O_from_python_CBSSDIGITAL_07_000504_cdq2v9
H2O_cluster_total_nodes:,1
H2O_cluster_free_memory:,734 Mb
H2O_cluster_total_cores:,2
H2O_cluster_allowed_cores:,2


Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%
AutoML progress: |███████████████████████████████████████████████████████████████| (done) 100%
contributions progress: |████████████████████████████████████████████████████████| (done) 100%


In [6]:
model.shap_values.sort_values(['id', 'variable'])

Unnamed: 0,id,variable,shap_value
32874,0,age,0.046105
65748,0,black,0.020949
0,0,data_id,-0.041034
49311,0,educ,0.002018
82185,0,hisp,0.000215
...,...,...,...
115058,16436,marr,0.281929
131495,16436,nodegree,0.072110
147932,16436,re74,0.194097
164369,16436,re75,0.173252


## Classification

In [11]:
model = H2OWrapper(
    max_mem_size='1G',
    df=df,
    id_col='index',
    cat_features=['data_id'],
    target_col='treat',
    sort_metric='aucpr',
    problem_type='classification',
    train_size=0.7,
    max_models=8,
    max_runtime_secs=30,
    nfolds=5,
    seed=90
)


Checking whether there is an H2O instance running at http://localhost:54321 . connected.


0,1
H2O_cluster_uptime:,17 mins 13 secs
H2O_cluster_timezone:,America/Sao_Paulo
H2O_data_parsing_timezone:,UTC
H2O_cluster_version:,3.34.0.3
H2O_cluster_version_age:,2 months and 5 days
H2O_cluster_name:,H2O_from_python_CBSSDIGITAL_07_000504_cdq2v9
H2O_cluster_total_nodes:,1
H2O_cluster_free_memory:,802 Mb
H2O_cluster_total_cores:,2
H2O_cluster_allowed_cores:,2


Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%
AutoML progress: |███████████████████████████████████████████████████████████████| (done) 100%
contributions progress: |████████████████████████████████████████████████████████| (done) 100%


In [12]:
model.shap_values.head()

Unnamed: 0,index,age,educ,black,hisp,marr,nodegree,re74,re75,re78,data_id_Dehejia-Wahba Sample,MeanDeviance,BiasTerm
0,0,0.0,-0.061009,0.0,0.028771,0.0,0.0,0.02608,0.041803,-0.036995,1.00135,2.883605,-4.499061
1,1,0.0,-0.013179,0.0,-0.234024,0.0,0.0,0.038064,-0.016943,-0.006626,1.232709,1.996,-4.499061
2,2,0.0,0.00693,0.0,0.002849,0.0,0.0,0.008442,0.111856,0.253269,0.616653,8.908318,-4.499061
3,3,0.0,-0.061009,0.0,0.028771,0.0,0.0,0.010704,0.026427,-0.020808,1.015916,2.883605,-4.499061
4,4,0.0,0.151695,0.0,0.019543,0.0,0.0,0.00886,0.066559,0.024747,0.728596,4.823765,-4.499061



## Rascunho

In [5]:
from pyspark.sql import functions as F, types as T, Window as W

In [13]:
sdf\
        .withColumn('rand', F.rand())\
        .withColumn('qcut', F.ntile(2).over(W.partitionBy().orderBy(F.col('rand'))))\
        .withColumn('qcut', F.col('qcut').cast(T.StringType())).printSchema()

root
 |-- data_id: string (nullable = true)
 |-- treat: long (nullable = true)
 |-- age: float (nullable = true)
 |-- educ: float (nullable = true)
 |-- black: float (nullable = true)
 |-- hisp: float (nullable = true)
 |-- marr: float (nullable = true)
 |-- nodegree: float (nullable = true)
 |-- re74: float (nullable = true)
 |-- re75: float (nullable = true)
 |-- re78: float (nullable = true)
 |-- id: long (nullable = true)
 |-- rand: double (nullable = false)
 |-- qcut: string (nullable = false)



In [29]:
sdf.schema['id'].dataType == T.LongType()

True

In [25]:
schema = T.StructType([
  T.StructField('id', sdf.schema['id'].dataType),
  T.StructField('variable', T.StringType()),
  T.StructField('shap_value', T.FloatType()),
])

def estimate_shap_values(
    sdf,
    max_mem_size='1G',
    id_col='id',
    target_col='re78',
    cat_features=['data_id'],
    sort_metric='rmse',
    problem_type='regression',
    max_models=8,
    max_runtime_secs=30,
    nfolds=5,
    seed=90
):
    sdf = sdf\
        .withColumn('rand', F.rand())\
        .withColumn('qcut', F.ntile(2).over(W.partitionBy().orderBy(F.col('rand'))))\
        .withColumn('qcut', F.col('qcut').cast(T.StringType()))

    #@F.pandas_udf(schema, F.PandasUDFType.GROUPED_MAP)
    def fn(pdf):
        model = H2OWrapper(
            df=pdf,
            max_mem_size=max_mem_size,
            id_col=id_col,
            target_col=target_col,
            cat_features=cat_features,
            sort_metric=sort_metric,
            problem_type=problem_type,
            max_models=max_models,
            max_runtime_secs=max_runtime_secs,
            nfolds=nfolds,
            seed=seed
        )
        return model.shap_values
    return sdf.groupBy('qcut').applyInPandas(fn, schema=schema)

In [30]:
tt = estimate_shap_values(
    sdf=sdf,
    max_mem_size='1G',
    id_col='id',
    target_col='re78',
    cat_features=['data_id'],
    sort_metric='rmse',
    problem_type='regression',
    max_models=8,
    max_runtime_secs=10,
    nfolds=5,
    seed=90
)

In [32]:
tt.orderBy('id', 'variable').show(20)

Checking whether there is an H2O instance running at http://localhost:54321 . connected.
--------------------------  ------------------------------------------------------------------
H2O_cluster_uptime:         5 hours 16 mins
H2O_cluster_timezone:       America/Sao_Paulo
H2O_data_parsing_timezone:  UTC
H2O_cluster_version:        3.34.0.3
H2O_cluster_version_age:    2 months and 6 days
H2O_cluster_name:           H2O_from_python_CBSSDIGITAL_07_000504_cdq2v9
H2O_cluster_total_nodes:    1
H2O_cluster_free_memory:    698 Mb
H2O_cluster_total_cores:    2
H2O_cluster_allowed_cores:  2
H2O_cluster_status:         locked, healthy
H2O_connection_url:         http://localhost:54321
H2O_connection_proxy:       {"http": null, "https": null}
H2O_internal_security:      False
H2O_API_Extensions:         Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4
Python_version:             3.7.10 final
--------------------------  ------------------------------------------------------------

+---+--------------------+-------------+
| id|            variable|   shap_value|
+---+--------------------+-------------+
|  0|                 age|   0.07760226|
|  0|               black|  0.025549281|
|  0|             data_id|         -0.0|
|  0|                educ| 5.4665393E-4|
|  0|                hisp| -3.008009E-5|
|  0|                marr| 0.0022601434|
|  0|            nodegree| -0.002892988|
|  0|                rand|  0.017775178|
|  0|                re74|    0.2688882|
|  0|                re75|    0.6103014|
|  0|               treat|         -0.0|
|  1|                 age| 0.0066296067|
|  1|               black| -0.017312199|
|  1|        data_id.CPS1|   0.04577717|
|  1|data_id.Dehejia-W...|  0.018195676|
|  1| data_id.missing(NA)|         -0.0|
|  1|                educ|   0.08427536|
|  1|                hisp| -0.005700647|
|  1|                marr|-0.0024202915|
|  1|            nodegree|  0.030252375|
+---+--------------------+-------------+
only showing top



In [5]:


schema = T.StructType([
  T.StructField('c1', T.StringType()),
  T.StructField('c2', T.StringType()),
  T.StructField('c3', T.IntegerType()),
])

def fn_wrapper(df, val):

  @F.pandas_udf(schema, F.PandasUDFType.GROUPED_MAP)
  def fn(pdf):
    pdf['c3'] = pdf.shape[0] + val
    return pdf

  return df.groupby('c1', 'c2').apply(fn)

fn_wrapper(test, 7).show()

Unnamed: 0,id
0,0
1,1
2,2
3,3
4,4
...,...
15987,16432
15988,16433
15989,16434
15990,16435


In [None]:
from pyspark.sql.functions import pandas_udf, PandasUDFType


def estima(df, by="id", column="v", value=1.0):
    schema = "{} long, {} double".format(by, column)

    @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
    def subtract_value(pdf):
        # pdf is a pandas.DataFrame
        v = pdf[column]
        g = pdf[by]
        return pdf.assign(v = v - g * value)

    return df.groupby(by).apply(subtract_value)

my_function(df, by="id", column="v", value=2.0).show()