In [3]:
%%configure -f
{
    "name": "synapseml",
    "conf": {
        "spark.jars.packages": "com.microsoft.azure:synapseml_2.12:0.9.5",
        "spark.jars.excludes": "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalactic:scalactic_2.12,org.scalatest:scalatest_2.12,com.azure:azure-core",
        "spark.yarn.user.classpath.first": "true"
    }
}

StatementMeta(, , , SessionStarting, )

In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
import pyspark.sql.functions as F
from pyspark.ml.evaluation import BinaryClassificationEvaluator

from synapse.ml.explainers import ICETransformer

import matplotlib.pyplot as plt

StatementMeta(, , , Waiting, )

In [None]:
# Load every column of the RAIDEMO table into a Spark DataFrame and preview it.
# NOTE(review): `spark` and `display` are not defined anywhere in this notebook;
# presumably they are provided by the notebook runtime - confirm.
df = spark.sql("SELECT * FROM RAIDEMO")
display(df)

StatementMeta(, , , Waiting, )

In [None]:
# Narrow the frame to the target (AmountToOrder) plus the columns that the
# feature lists and pipeline further down refer to.
df = df.select(
    'WarehouseID',
    'ItemID',
    'TotalYTDSales',
    'CurrentInventory',
    'AmountToOrder',
    'AverageWeekWastage',
    'IsPerishable',
    'HasFlatSellingRate',
    'IsSeasonSensitive',
    'MonthNumber',
    'WeekNumberMonth',
    'WeekNumberYear',
    'TotalMembersInCurrentCity',
    'AverageOnlineOrdAmtInPastWeeks',
)

StatementMeta(, , , Waiting, )

In [None]:
# Columns treated as categories (string-indexed and one-hot encoded by the pipeline).
categorical_features = [
    'WarehouseID',
    'ItemID',
    'IsPerishable',
    'HasFlatSellingRate',
    'IsSeasonSensitive',
    'MonthNumber',
    'WeekNumberMonth',
    'WeekNumberYear',
]

# Columns passed to the model as plain numeric values.
numeric_features = [
    'TotalYTDSales',
    'CurrentInventory',
    'AverageWeekWastage',
    'TotalMembersInCurrentCity',
    'AverageOnlineOrdAmtInPastWeeks',
]

StatementMeta(, , , Waiting, )

In [None]:
from synapse.ml.exploratory import AggregateBalanceMeasure, DistributionBalanceMeasure, FeatureBalanceMeasure
import pyspark.sql.functions as F
from pyspark.sql.types import StringType,IntegerType,BooleanType

# Normalise the three flag columns to the strings '1' / '0'.
#
# This cell used to start with a bare
#   df.withColumn("IsPerishable", df.IsPerishable.cast(StringType()))
# whose result was never assigned. withColumn returns a new DataFrame rather
# than modifying `df`, so that statement had no effect. It is removed instead
# of being assigned: the loop below already handles IsPerishable, and its
# `otherwise` branch performs the string cast for any value that is neither
# 'True' nor 'False'.
cols = ["HasFlatSellingRate", "IsSeasonSensitive", "IsPerishable"]
for col2 in cols:
    df = df.withColumn(
        col2,
        F.when(
            F.col(col2) == 'True',
            '1'
        ).when(
            F.col(col2) == 'False',
            '0'
        ).otherwise(F.col(col2).cast('string'))
    )


StatementMeta(, , , Waiting, )

In [None]:
from synapse.ml.exploratory import AggregateBalanceMeasure, DistributionBalanceMeasure, FeatureBalanceMeasure
import pyspark.sql.functions as F
from pyspark.sql.types import StringType,IntegerType,BooleanType

# Cast the target column to an integer type before it is thresholded.
# withColumn returns a new DataFrame, so the result has to be bound back to
# `df`; the previous bare call computed the cast and then discarded it.
df = df.withColumn("AmountToOrder", df.AmountToOrder.cast(IntegerType()))

StatementMeta(, , , Waiting, )

In [None]:
# Turn AmountToOrder into a binary target: 0 when the amount is below 6600,
# 1 otherwise.
# Caution: this overwrites AmountToOrder in place, so running the cell a second
# time sees only 0/1 values, all of which are below 6600, and maps every row to 0.
# NOTE(review): 'id' is not among the columns selected earlier in this
# notebook, so the drop looks like a no-op - confirm.
df = df.withColumn(
    "AmountToOrder",
    F.when(F.col("AmountToOrder") < 6600, F.lit(0)).otherwise(F.lit(1)),
).drop('id')
df.head(10)

StatementMeta(, , , Waiting, )

In [None]:
# Names of the intermediate columns produced for each categorical feature.
string_indexer_outputs = [f"{feature}_idx" for feature in categorical_features]
one_hot_encoder_outputs = [f"{feature}_enc" for feature in categorical_features]

# Stages, in order:
#   1. index the binary target into a "label" column
#   2. index every categorical feature
#   3. one-hot encode the indexed categorical features
#   4. assemble encoded + numeric columns into a single "features" vector
#   5. fit a gradient-boosted tree classifier (depth 7, 10 iterations)
#
# NOTE(review): the label indexer uses StringIndexer's default ordering, so
# which AmountToOrder value becomes label index 1 depends on that default.
# An earlier revision had stringOrderType "alphabetAsc" commented out here -
# confirm which ordering is intended.
pipeline = Pipeline(stages=[
    StringIndexer(inputCol='AmountToOrder', outputCol="label"),
    StringIndexer(inputCols=categorical_features, outputCols=string_indexer_outputs),
    OneHotEncoder(inputCols=string_indexer_outputs, outputCols=one_hot_encoder_outputs),
    VectorAssembler(inputCols=one_hot_encoder_outputs + numeric_features, outputCol="features"),
    GBTClassifier(maxDepth=7, maxIter=10),
])

model = pipeline.fit(df)

StatementMeta(, , , Waiting, )

In [None]:
# Score the same DataFrame the pipeline was fitted on, so everything shown here
# is an in-sample prediction, and preview the target next to the model's
# probability and prediction columns.
data = model.transform(df)
display(data.select('AmountToOrder', 'probability', 'prediction'))

StatementMeta(, , , Waiting, )

In [None]:
# Area under the ROC curve for the scored frame.
# The evaluator is pointed at the model's continuous `rawPrediction` output
# (the classifier's default raw-score column, since the pipeline does not
# rename it). Feeding it the hard 0/1 `prediction` column, as before, leaves
# only a single operating point to build the ROC curve from and misstates AUC.
# `data` was produced from the training frame, so this is an in-sample figure.
eval_auc = BinaryClassificationEvaluator(
    labelCol="label",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC",
)
eval_auc.evaluate(data)

StatementMeta(, , , Waiting, )

In [None]:
# Explainer over the fitted pipeline: reads the "probability" output for
# class index 1 and covers every categorical and numeric feature declared
# earlier in the notebook.
# NOTE(review): kind="average" is presumably the partial-dependence (averaged)
# mode of ICETransformer, and numSamples caps the rows used - confirm both
# against the SynapseML documentation.
pdp = ICETransformer(
    model=model,
    targetCol="probability",
    kind="average",
    targetClasses=[1],
    categoricalFeatures=categorical_features,
    numericFeatures=numeric_features,
    numSamples=500000,
)

StatementMeta(, , , Waiting, )

In [None]:
# Disable automatic broadcast hash joins by setting the size threshold to -1.
# NOTE(review): the previous comment justified this with Hyperspace indexes,
# which this notebook does not appear to use; presumably the setting is meant
# to keep the explainer's joins from being broadcast - confirm it is still needed.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

# Print the value back so the cell output shows the threshold now in effect.
print(spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))

StatementMeta(, , , Waiting, )

In [None]:
# Run the explainer and mark its output for caching *before* the first action.
# Previously cache() was called after display(), so the display computed the
# result without caching it and the cache was only populated by a later scan.
# With the cache registered first, display() and the plotting helpers below
# can reuse the same computed result.
output_pdp = pdp.transform(df).cache()
display(output_pdp)

StatementMeta(, , , Waiting, )

In [None]:
# Helper functions for visualization

def get_pandas_df_from_column(df, col_name):
  """Spread a map-typed Spark column into a pandas DataFrame, one column per key.

  The distinct keys of ``col_name`` are collected to the driver, then each key
  is projected with ``getItem`` into its own column named ``str(key)``.

  Args:
    df: Spark DataFrame containing ``col_name``.
    col_name: name of the map column to expand.

  Returns:
    pandas DataFrame with one column per distinct map key, in the order the
    keys came back from ``collect()``.
  """
  key_rows = df.select(F.explode(F.map_keys(F.col(col_name)))).distinct().collect()
  projections = [
      F.col(col_name).getItem(row[0]).alias(str(row[0]))
      for row in key_rows
  ]
  return df.select(projections).toPandas()

def _sorted_dependence(df, col_int):
  """Collect one dependence value per column and order the result by column name.

  The value taken for each column is ``df[column][0].toArray()[0]``: the first
  component of the vector stored under row label 0.

  Args:
    df: DataFrame indexable as ``df[column][0]`` (in this notebook, the pandas
      frame returned by ``get_pandas_df_from_column``).
    col_int: when true, column names are ordered by ``int(name)``; otherwise
      they are ordered as plain strings.

  Returns:
    dict mapping column name to its value, in sorted order.

  Raises:
    ValueError: if ``df`` has no columns, or (from ``int``) if ``col_int`` is
      true and a column name is not an integer literal.
  """
  col_names = list(df.columns)
  if not col_names:
    # Previously this case surfaced later as an unbound-variable error.
    raise ValueError("Cannot plot dependence: the DataFrame has no columns.")

  dict_values = {name: df[name][0].toArray()[0] for name in col_names}

  # Sort once, after every value has been collected, instead of re-sorting
  # the growing dict on each loop iteration.
  if col_int:
    marklist = sorted(dict_values.items(), key=lambda item: int(item[0]))
  else:
    marklist = sorted(dict_values.items(), key=lambda item: item[0])
  return dict(marklist)


def plot_dependence_for_categorical(df, col, col_int=True, figsize=(20, 5)):
  """Draw a bar chart of the dependence values, one bar per column of ``df``.

  Args:
    df: frame with one column per feature value (see ``_sorted_dependence``).
    col: feature name, used as the x-axis label.
    col_int: order the bars by ``int(column name)`` when true, by string otherwise.
    figsize: figure size passed to ``plt.figure``.

  Raises:
    ValueError: if ``df`` has no columns.
  """
  sortdict = _sorted_dependence(df, col_int)

  plt.figure(figsize=figsize)
  plt.bar(list(sortdict.keys()), list(sortdict.values()))

  plt.xlabel(col, size=13)
  plt.ylabel("Dependence")
  plt.show()


def plot_dependence_for_numeric(df, col, col_int=True, figsize=(20, 5)):
  """Draw a line plot of the dependence values, one point per column of ``df``.

  Args:
    df: frame with one column per feature value (see ``_sorted_dependence``).
    col: feature name, used as the x-axis label.
    col_int: order the points by ``int(column name)`` when true; when false the
      names are ordered as plain strings (so '10' sorts before '2').
    figsize: figure size passed to ``plt.figure``.

  Raises:
    ValueError: if ``df`` has no columns.
  """
  sortdict = _sorted_dependence(df, col_int)

  plt.figure(figsize=figsize)
  plt.plot(list(sortdict.keys()), list(sortdict.values()))

  plt.xlabel(col, size=13)
  plt.ylabel("Dependence")
  # Anchor the y-axis at zero.
  plt.ylim(0.0)
  plt.show()
  

StatementMeta(, , , Waiting, )

In [None]:
# Expand the 'WarehouseID_dependence' map column into a pandas frame and plot it.
# NOTE(review): `df_education_num` looks like a name carried over from another
# example; here it holds the WarehouseID dependence values.
# NOTE(review): WarehouseID is listed in categorical_features, yet the line-plot
# helper is used - confirm plot_dependence_for_categorical is not what was intended.
df_education_num = get_pandas_df_from_column(output_pdp, 'WarehouseID_dependence')
plot_dependence_for_numeric(df_education_num, 'WarehouseID')