In [1]:
!which python

/Users/vini/Dev-Files/Poetry/virtualenvs/pyspark-ds-toolbox-H0pw_EKR-py3.8/bin/python


In [1]:
import pandas as pd
from pyspark.sql.window import Window
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import FloatType, StructField, StructType, StringType
from pyspark.ml.linalg import VectorUDT

import pyspark.ml.feature as FF
from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier


In [2]:
from pyspark_ds_toolbox.ml.eval import get_p1

from pyspark_ds_toolbox.ml.eval import calculate_shapley_values

In [3]:
spark = SparkSession.builder\
                .appName('Ml-Pipes') \
                .master('local[1]') \
                .config('spark.executor.memory', '3G') \
                .config('spark.driver.memory', '3G') \
                .config('spark.memory.offHeap.enabled', 'true') \
                .config('spark.memory.offHeap.size', '3G') \
                .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")

21/12/04 08:42:02 WARN Utils: Your hostname, matrix.local resolves to a loopback address: 127.0.0.1; using 10.0.0.105 instead (on interface en0)
21/12/04 08:42:02 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
21/12/04 08:42:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
def read_data(file): 
    return pd.read_stata("https://raw.github.com/scunning1975/mixtape/master/" + file)

df = read_data('nsw_mixtape.dta')
df = pd.concat((df, read_data('cps_mixtape.dta')))
df.reset_index(level=0, inplace=True)

df = spark.createDataFrame(df.drop(columns=['data_id']))\
    .withColumn('age2', F.col('age')**2)\
    .withColumn('age3', F.col('age')**3)\
    .withColumn('educ2', F.col('educ')**2)\
    .withColumn('educ_re74', F.col('educ')*F.col('re74'))\
    .withColumn('u74', F.when(F.col('re74')==0, 1).otherwise(0))\
    .withColumn('u75', F.when(F.col('re75')==0, 1).otherwise(0))

features=['age', 'age2', 'age3', 'educ', 'educ2', 'marr', 'nodegree', 'black', 'hisp', 're74', 're75', 'u74', 'u75', 'educ_re74']
assembler = FF.VectorAssembler(inputCols=features, outputCol='features')
pipeline = Pipeline(stages = [assembler])
df_assembled = pipeline.fit(df).transform(df)

In [5]:
train_size=0.8
train, test = df_assembled.randomSplit([train_size, (1-train_size)], seed=12345)

model = GBTClassifier(labelCol='treat')
p = Pipeline(stages=[model])
p_fitted = p.fit(train)

df_predicted = p_fitted.transform(test).withColumn('probability', get_p1(F.col('probability')))

df_predicted.printSchema()
v = df_predicted.filter(F.col('index')==3).select('probability').collect()[0][0]
m = df_predicted.select('probability').toPandas().probability.mean()



root
 |-- index: long (nullable = true)
 |-- treat: double (nullable = true)
 |-- age: double (nullable = true)
 |-- educ: double (nullable = true)
 |-- black: double (nullable = true)
 |-- hisp: double (nullable = true)
 |-- marr: double (nullable = true)
 |-- nodegree: double (nullable = true)
 |-- re74: double (nullable = true)
 |-- re75: double (nullable = true)
 |-- re78: double (nullable = true)
 |-- age2: double (nullable = true)
 |-- age3: double (nullable = true)
 |-- educ2: double (nullable = true)
 |-- educ_re74: double (nullable = true)
 |-- u74: integer (nullable = false)
 |-- u75: integer (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: float (nullable = true)
 |-- prediction: double (nullable = false)



In [7]:
import os
# from psutil import virtual_memory
from pyspark import SparkConf
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql import functions as F, SparkSession, types as T, Window

import operator

import pyspark
from pyspark.sql.types import StructType,StructField, StringType, IntegerType, FloatType

In [37]:
def calculate_shapley_values(
        spark,
        df,
        model,
        row_of_interest,
        feature_names,
        features_col='features',
        column_to_examine='anomalyScore',
        print_shap_values=False
):
    """
    # Based on the algorithm described here:
    # https://christophm.github.io/interpretable-ml-book/shapley.html#estimating-the-shapley-value
    # And on Baskerville's implementation for IForest/ AnomalyModel here:
    # https://github.com/equalitie/baskerville/blob/develop/src/baskerville/util/model_interpretation/helpers.py#L235
    """

    schema = StructType([
        StructField('id', IntegerType(), True),
        StructField('feature', StringType(), True),
        StructField('shap', FloatType(), True)
    ])

    results = spark.createDataFrame(spark.sparkContext.emptyRDD(),schema)

    features_perm_col = 'features_permutations'
    marginal_contribution_filter = F.avg('marginal_contribution').alias('shap_value')
    # broadcast the row of interest and ordered feature names
    ROW_OF_INTEREST_BROADCAST = spark.sparkContext.broadcast(row_of_interest)
    ORDERED_FEATURE_NAMES = spark.sparkContext.broadcast(feature_names)

    # persist before continuing with calculations
    if not df.is_cached:
        df = df.persist()

    # get permutations
    # Creates a column for the ordered features and then shuffles it.
    # The result is a dataframe with a column `output_col` that contains:
    # [feat2, feat4, feat3, feat1],
    # [feat3, feat4, feat2, feat1],
    # [feat1, feat2, feat4, feat3],
    # ...
    features_df = df.withColumn(
        'features_permutations',
        F.shuffle(
            F.array(*[F.lit(f) for f in feature_names])
        )
    )

    # set up the udf - x-j and x+j need to be calculated for every row
    def calculate_x(
            feature_j, z_features, curr_feature_perm
    ):
        """
        The instance  x+j is the instance of interest,
        but all values in the order before feature j are
        replaced by feature values from the sample z
        The instance  x−j is the same as  x+j, but in addition
        has feature j replaced by the value for feature j from the sample z
        """
        x_interest = ROW_OF_INTEREST_BROADCAST.value
        ordered_features = ORDERED_FEATURE_NAMES.value
        x_minus_j = list(z_features).copy()
        x_plus_j = list(z_features).copy()
        f_i = curr_feature_perm.index(feature_j)
        after_j = False
        for f in curr_feature_perm[f_i:]:
            # replace z feature values with x of interest feature values
            # iterate features in current permutation until one before j
            # x-j = [z1, z2, ... zj-1, xj, xj+1, ..., xN]
            # we already have zs because we go row by row with the udf,
            # so replace z_features with x of interest
            f_index = ordered_features.index(f)
            new_value = x_interest[f_index]
            x_plus_j[f_index] = new_value
            if after_j:
                x_minus_j[f_index] = new_value
            after_j = True

        # minus must be first because of lag
        return Vectors.dense(x_minus_j), Vectors.dense(x_plus_j)

    udf_calculate_x = F.udf(calculate_x, T.ArrayType(VectorUDT()))

    # persist before processing
    features_df = features_df.persist()

    for f in feature_names:
        # x column contains x-j and x+j in this order.
        # Because lag is calculated this way:
        # F.col('anomalyScore') - (F.col('anomalyScore') one row before)
        # x-j needs to be first in `x` column array so we should have:
        # id1, [x-j row i,  x+j row i]
        # ...
        # that with explode becomes:
        # id1, x-j row i
        # id1, x+j row i
        # ...
        # to give us (x+j - x-j) when we calculate marginal contribution
        # Note that with explode, x-j and x+j for the same row have the same id
        # This gives us the opportunity to use lag with
        # a window partitioned by id
        x_df = features_df.withColumn('x', udf_calculate_x(
            F.lit(f), features_col, features_perm_col
        )).persist()

        # Calculating SHAP values for f
        x_df = x_df.selectExpr(
            'id', f'explode(x) as {features_col}'
        ).cache()
        x_df = model.transform(x_df).withColumn('probability', get_p1(F.col('probability')))

        # marginal contribution is calculated using a window and a lag of 1.
        # the window is partitioned by id because x+j and x-j for the same row
        # will have the same id
        x_df = x_df.withColumn(
            'marginal_contribution',
            F.col(column_to_examine) - F.lag(F.col(column_to_examine), 1).over(Window.partitionBy('id').orderBy('id'))
        )
        # calculate the average
        x_df = x_df.filter(x_df.marginal_contribution.isNotNull())
        
        feat_shap_value = pd.DataFrame.from_dict({
            'id': [row_of_interest['id']],
            'feature': [f],
            'shap_value': [x_df.select(marginal_contribution_filter).first().shap_value]
        })
        feat_shap_value = spark.createDataFrame(feat_shap_value)
        if print_shap_values:
            print(f'Marginal Contribution for feature: {f} = {x_df.select(marginal_contribution_filter).first().shap_value}')
        
        results = results.union(feat_shap_value)
        break
        
    return (results, x_df)

In [7]:
a, r = calculate_shapley_values(
    spark=spark,
    df = df_predicted,
    id_column='index',
    model = p_fitted,
    row_of_interest = df_predicted.filter(F.col('index')==3).first(),
    feature_names = features,
    features_col='features',
    column_to_examine='probability',
    print_shap_values=True
)
type(a)



Marginal Contribution for feature: age = -0.003291975620736412




pyspark.sql.dataframe.DataFrame

In [8]:
r.show()

+-----+--------------------+--------------------+-----------+----------+---------------------+
|index|            features|       rawPrediction|probability|prediction|marginal_contribution|
+-----+--------------------+--------------------+-----------+----------+---------------------+
| 1697|[3.0,1156.0,48.0,...|[1.53137828714336...| 0.04466992|       0.0|         -0.012413949|
| 2250|[3.0,0.0,91125.0,...|[1.72242596217247...|0.030922757|       0.0|        -0.0135901775|
| 2509|[3.0,400.0,8000.0...|[1.55068795525105...|0.043050535|       0.0|        -2.1320954E-4|
| 2927|[3.0,0.0,48.0,6.0...|[1.82098708167501...|0.025531625|       0.0|         -0.004834337|
| 5556|[3.0,1156.0,39304...|[1.55689123243229...| 0.04254231|       0.0|         -0.002766002|
| 8484|[3.0,1521.0,48.0,...|[1.66132311714578...|0.034802407|       0.0|         -0.009710528|
|10959|[3.0,361.0,6859.0...|[1.55626474703821...|0.042593375|       0.0|        -0.0052779615|
|11625|[3.0,841.0,24389....|[1.14831830698900...| 

In [9]:
r.select('id').distinct().count()

AnalysisException: cannot resolve '`id`' given input columns: [features, index, marginal_contribution, prediction, probability, rawPrediction];
'Project ['id]
+- Filter isnotnull(marginal_contribution#1768)
   +- Project [index#0L, features#1369, rawPrediction#1729, probability#1762, prediction#1747, marginal_contribution#1768]
      +- Project [index#0L, features#1369, rawPrediction#1729, probability#1762, prediction#1747, _we0#1769, (probability#1762 - _we0#1769) AS marginal_contribution#1768]
         +- Window [lag(probability#1762, -1, null) windowspecdefinition(index#0L, index#0L ASC NULLS FIRST, specifiedwindowframe(RowFrame, -1, -1)) AS _we0#1769], [index#0L], [index#0L ASC NULLS FIRST]
            +- Project [index#0L, features#1369, rawPrediction#1729, probability#1762, prediction#1747]
               +- Project [index#0L, features#1369, rawPrediction#1729, <lambda>(probability#1736) AS probability#1762, prediction#1747]
                  +- Project [index#0L, features#1369, rawPrediction#1729, probability#1736, UDF(rawPrediction#1729) AS prediction#1747]
                     +- Project [index#0L, features#1369, rawPrediction#1729, UDF(rawPrediction#1729) AS probability#1736]
                        +- Project [index#0L, features#1369, UDF(features#1369) AS rawPrediction#1729]
                           +- Project [index#0L, features#1369]
                              +- Generate explode(x#898), false, [features#1369]
                                 +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, u75#97, features#118, rawPrediction#166, probability#263, prediction#216, features_permutations#449, calculate_x(age, features#118, features_permutations#449) AS x#898]
                                    +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, u75#97, features#118, rawPrediction#166, probability#263, prediction#216, shuffle(array(age, age2, age3, educ, educ2, marr, nodegree, black, hisp, re74, re75, u74, u75, educ_re74), Some(-2092057485210354459)) AS features_permutations#449]
                                       +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, u75#97, features#118, rawPrediction#166, <lambda>(probability#189) AS probability#263, prediction#216]
                                          +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, u75#97, features#118, rawPrediction#166, probability#189, UDF(rawPrediction#166) AS prediction#216]
                                             +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, u75#97, features#118, rawPrediction#166, UDF(rawPrediction#166) AS probability#189]
                                                +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, u75#97, features#118, UDF(features#118) AS rawPrediction#166]
                                                   +- Sample 0.8, 1.0, false, 12345
                                                      +- Sort [index#0L ASC NULLS FIRST, treat#1 ASC NULLS FIRST, age#2 ASC NULLS FIRST, educ#3 ASC NULLS FIRST, black#4 ASC NULLS FIRST, hisp#5 ASC NULLS FIRST, marr#6 ASC NULLS FIRST, nodegree#7 ASC NULLS FIRST, re74#8 ASC NULLS FIRST, re75#9 ASC NULLS FIRST, re78#10 ASC NULLS FIRST, age2#22 ASC NULLS FIRST, age3#35 ASC NULLS FIRST, educ2#49 ASC NULLS FIRST, educ_re74#64 ASC NULLS FIRST, u74#80 ASC NULLS FIRST, u75#97 ASC NULLS FIRST, features#118 ASC NULLS FIRST], false
                                                         +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, u75#97, UDF(struct(age, age#2, age2, age2#22, age3, age3#35, educ, educ#3, educ2, educ2#49, marr, marr#6, nodegree, nodegree#7, black, black#4, hisp, hisp#5, re74, re74#8, re75, re75#9, u74_double_VectorAssembler_e8cf33225a47, cast(u74#80 as double), ... 4 more fields)) AS features#118]
                                                            +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, u74#80, CASE WHEN (re75#9 = cast(0 as double)) THEN 1 ELSE 0 END AS u75#97]
                                                               +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, educ_re74#64, CASE WHEN (re74#8 = cast(0 as double)) THEN 1 ELSE 0 END AS u74#80]
                                                                  +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, educ2#49, (educ#3 * re74#8) AS educ_re74#64]
                                                                     +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, age3#35, POWER(educ#3, cast(2 as double)) AS educ2#49]
                                                                        +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, age2#22, POWER(age#2, cast(3 as double)) AS age3#35]
                                                                           +- Project [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10, POWER(age#2, cast(2 as double)) AS age2#22]
                                                                              +- LogicalRDD [index#0L, treat#1, age#2, educ#3, black#4, hisp#5, marr#6, nodegree#7, re74#8, re75#9, re78#10], false


In [None]:
r.filter('id=3').show()

+---+--------------------+--------------------+-----------+----------+---------------------+
| id|            features|       rawPrediction|probability|prediction|marginal_contribution|
+---+--------------------+--------------------+-----------+----------+---------------------+
|  3|[3.0,2304.0,11059...|[1.54350200272500...|0.043646522|       0.0|                  0.0|
+---+--------------------+--------------------+-----------+----------+---------------------+



In [36]:
df_predicted.select('index').distinct().count()

3314

In [42]:
a.select(F.sum('shap')).collect()[0][0]

0.0021847155001780854

In [19]:
print(df_predicted.select('probability').toPandas().probability.mean() + a.select(F.sum('shap')).collect()[0][0])

print(f'{v}')

0.041614942312058444
0.043646521866321564


In [None]:
print(df_predicted.select('probability').toPandas().probability.mean() + a.select(F.sum('shap')).collect()[0][0])

print(f'{v}')



0.04829068469926582
0.04364077374339104




In [None]:
print(df_predicted.select('probability').toPandas().probability.mean() + a.select(F.sum('shap')).collect()[0][0])

print(f'{v}')



0.052278824448785434
0.043646443635225296




In [84]:
v

0.043646443635225296