In [None]:
import os

# Default Splice Machine JDBC host for this notebook.
# setdefault keeps a JDBC_HOST that is already exported in the environment, so
# the notebook can be pointed at another cluster without editing this cell.
os.environ.setdefault('JDBC_HOST', 'jrtest01-splice-hregion')

In [None]:
# setup-- 
import os
import pyspark
from splicemachine.spark.context import PySpliceContext
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
!pip install plotly

# make sure pyspark tells workers to use python3 not 2 if both are installed
os.environ['PYSPARK_PYTHON'] = '/usr/bin/python3'
jdbc_host = os.environ['JDBC_HOST']

conf = pyspark.SparkConf()
sc = pyspark.SparkContext(conf=conf)

spark = SparkSession.builder.config(conf=conf).getOrCreate()

splicejdbc=f"jdbc:splice://{jdbc_host}:1527/splicedb;user=splice;password=admin"

splice = PySpliceContext(spark, splicejdbc)


# Decision Trees
The [Decision Tree](https://spark.apache.org/docs/latest/mllib-decision-tree.html) is a greedy algorithm that performs a recursive binary partitioning of the feature space for predictive modeling. 
* Locally optimal decisions are made at each node in hopes of a globally optimal decision
* Because of its greedy nature, it cannot guarantee the globally optimal tree

At its core (and most simplified), a decision tree is simply a system of if-else statements that always takes the locally best answer at each step, resulting in (hopefully) the best overall decision. See [this example diagram](http://mines.humanoriented.com/classes/2010/fall/csci568/portfolio_exports/lguo/image/decisionTree/classification.jpg)

The example below demonstrates how to load a [LIBSVM](https://github.com/apache/spark/blob/master/data/mllib/sample_libsvm_data.txt) data file, parse it as an RDD of LabeledPoint and then perform classification using a decision tree with Gini impurity as an impurity measure and a maximum tree depth of 5. The test error is calculated to measure the algorithm's accuracy. For more information, check out Spark's Decision Tree [page](https://spark.apache.org/docs/latest/mllib-decision-tree.html)


In [None]:
%%scala
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils

// Location of the LIBSVM sample file; adjust to wherever the file lives in your environment.
val dataPath = "/Users/benepstein/Desktop/sample_libsvm_data.txt"
// Fixed seed so the train/test split is repeatable from run to run.
val splitSeed = 42L

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, dataPath)
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3), seed = splitSeed)
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error:
// pair every true label with the model's prediction for the same point.
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
// Test error = fraction of test points whose prediction differs from the label.
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification tree model:\n" + model.toDebugString)

// Save and load model
// model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
// val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")

## And in PySpark

In [None]:
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml import Pipeline
import plotly.express as px

# Load Plotly's iris sample into a Spark DataFrame, without the 'species_id' column
iris_pdf = px.data.iris()
data = spark.createDataFrame(iris_pdf).drop('species_id')

# Every column other than the 'species' label is a model input
cols = [name for name in data.columns if name != 'species']

# Stage 1: index the 'species' strings into the numeric label column 'species_vec'
si = StringIndexer(inputCol='species', outputCol='species_vec')
# Stage 2: pack the input columns into a single vector column 'features'
va = VectorAssembler(inputCols=cols, outputCol='features')

# Chain both stages into a Spark Pipeline, then fit it and apply it to the dataset
pipeline = Pipeline(stages=[si, va])
fitted_pipeline = pipeline.fit(data)
data = fitted_pipeline.transform(data)

# Show the final dataset, ordered by sepal width
data.orderBy('sepal_width').show()

## Let's visualize our data with [Plotly](https://plot.ly/python/):
* X axis will be petal_length
* Y axis will be sepal_width
* Z axis will be sepal_length
* Datapoint size will be petal_width
* Color will be species type (versicolor, virginica, setosa)

### In the next cell, change any of the variables in the plot function to see a new chart layout. Trying different combinations can give you new insight into the data!

In [None]:
# Hover over any datapoint to see its exact dimensions
iris_pandas = data.toPandas()
px.scatter_3d(
    iris_pandas,
    x='petal_length',
    y='sepal_width',
    z='sepal_length',
    size='petal_width',
    color='species',
)

## Now we can create our Decision Tree to predict a flower's species based on its sepal_length, sepal_width, petal_length, and petal_width

In [None]:
from pyspark.ml.classification import DecisionTreeClassifier
from splicemachine.ml.utilities import SpliceMultiClassificationEvaluator

# The data has already been preprocessed above into a feature vector called "features"
# Create the decision tree
dt = DecisionTreeClassifier(labelCol='species_vec', featuresCol='features', maxDepth=20)

# Split our dataset into training (80%) and testing (20%).
# A fixed seed keeps the split, and therefore the evaluation below, repeatable.
train, test = data.randomSplit([0.8, 0.2], seed=42)

# Train on our training data
model = dt.fit(train)
# Make predictions on the held-out rows
predictions = model.transform(test)

# Compare the true species / label index with the predicted label index
predictions.select(['features', 'species', 'species_vec', 'prediction']).show()

# Evaluate results
e = SpliceMultiClassificationEvaluator(spark, label_column='species_vec')
e.input(predictions)
results = e.get_results(dict=True)

# That's quite a good Decision Tree! 
### Let's see if we can get a better understanding of it

In [None]:
from splicemachine.ml.utilities import DecisionTreeVisualizer as dtv
import pprint

# Feature names, in the order they were assembled into the 'features' vector
feature_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# Read the class names from the indexed data, ordered by label index, instead of
# hard-coding an order: the index given to each species is chosen by the
# StringIndexer and a hand-written list can silently mislabel the tree.
# NOTE(review): assumes visualize() looks class names up by label index - confirm.
label_rows = data.select('species_vec', 'species').distinct().orderBy('species_vec').collect()
class_names = [row['species'] for row in label_rows]

# visual=False: print the returned tree description rather than rendering it
print(dtv.visualize(model, feature_names, class_names, 'First_Decision_Tree', visual=False))