# Classification

In this example we will do some simple cell classification based on multiband imagery and a
target/label raster. As a part of the process we'll explore the cross-validation support in
SparkML.

## Setup

First some setup:

In [3]:
from pyrasterframes import *
from pyrasterframes.rasterfunctions import *
from pyrasterframes.types import NoDataFilter
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml import Pipeline
from pathlib import Path

spark = SparkSession.builder. \
    master("local[*]"). \
    appName("RasterFrames"). \
    config("spark.ui.enabled", "false"). \
    getOrCreate(). \
    withRasterFrames()

# Location of our test data set
resource_dir = Path('./samples').resolve()
# Utility for reading imagery from our test data set
filenamePattern = "L8-B{}-Elkton-VA.tiff"
bandNumbers = range(1, 8)
bandColNames = list(map(lambda n: 'band_{}'.format(n), bandNumbers))

def readTiff(name):
    return resource_dir.joinpath(filenamePattern.format(name)).as_uri()

## Loading Data

The first step is to load multiple bands of imagery and construct a single RasterFrame from them.
To do this we:

1. Identify the GeoTIFF filename.
2. Read the TIFF raster.
3. Convert to a raster frame of `tileSize` sized tiles, with an appropriate column name.
4. Use the RasterFrames `spatialJoin` function to create a new RasterFrame with a column for each band.

In [4]:
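from functools import reduce

# Pair each band number with the URI of its GeoTIFF
bandFiles = map(lambda b: (b, readTiff(b)), bandNumbers)
# Read each GeoTIFF and rename its `tile` column to `band_N`
bandRFs = map(lambda bf: spark.read.geotiff(bf[1]) \
                  .withColumnRenamed('tile', 'band_{}'.format(bf[0])),
              bandFiles)
# Fold the per-band frames into one RasterFrame, one spatial join at a time
joinedRF = reduce(lambda rf1, rf2: rf1.asRF().spatialJoin(rf2.drop('bounds').drop('metadata')),
                  bandRFs)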
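
The `reduce` call above folds the per-band frames pairwise into a single frame. As a rough, Spark-free sketch of that fold, the snippet below models each frame as a dict keyed by spatial key. `read_band` and `spatial_join` are hypothetical stand-ins, and the join is modelled as an inner join, which is a simplifying assumption rather than a statement about how RasterFrames implements `spatialJoin`.

```python
from functools import reduce

def read_band(n):
    # Hypothetical stand-in for reading one band: maps a spatial key
    # (col, row) to a dict holding a single named "tile" (a string here).
    name = "band_{}".format(n)
    return {(0, 0): {name: "tile{}a".format(n)},
            (1, 0): {name: "tile{}b".format(n)}}

def spatial_join(left, right):
    # Assumption: an inner join on the spatial key that merges the
    # columns of both sides.
    return {key: {**left[key], **right[key]}
            for key in left if key in right}

band_numbers = range(1, 8)
joined = reduce(spatial_join, map(read_band, band_numbers))

# Each spatial key now carries one column per band.
print(sorted(joined[(0, 0)]))
```

Each step of the fold widens the accumulated frame by one band column while leaving the number of spatial keys unchanged.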

We should see a single `spatial_key` column, the `bounds` and `metadata` columns, and 7 columns of tiles, one per band.

In [5]:
joinedRF.printSchema()

root
 |-- spatial_key: struct (nullable = false)
 |    |-- col: integer (nullable = false)
 |    |-- row: integer (nullable = false)
 |-- bounds: polygon (nullable = true)
 |-- metadata: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = false)
 |-- band_1: rf_tile (nullable = false)
 |-- band_2: rf_tile (nullable = false)
 |-- band_3: rf_tile (nullable = false)
 |-- band_4: rf_tile (nullable = false)
 |-- band_5: rf_tile (nullable = false)
 |-- band_6: rf_tile (nullable = false)
 |-- band_7: rf_tile (nullable = false)



Similarly, we pull in the target label data.

In [6]:
targetCol = "target"

target = spark.read.geotiff(resource_dir.joinpath("L8-Labels-Elkton-VA.tiff").as_uri()).withColumnRenamed('tile', targetCol)

Take a peek at what kind of label data we have to work with.

In [7]:
target.select(aggStats("target")).show(1, False)

+----------------------------------------------------+
|aggStats(target)                                    |
+----------------------------------------------------+
|[1478,0.0,2.0,0.8017591339648173,0.2780212626872066]|
+----------------------------------------------------+
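
Assuming the five values in that struct are, in order, the count of data cells, minimum, maximum, mean, and variance (an interpretation of the output above, not something this notebook confirms), the labels would cover 1478 cells with class values ranging from 0 to 2. A minimal plain-Python sketch of such an aggregate, with `NODATA` and `agg_stats` as hypothetical names and population variance as a further assumption:

```python
NODATA = None  # hypothetical NoData marker for this sketch

def agg_stats(tiles):
    """Aggregate statistics over all data cells of all tiles."""
    cells = [c for tile in tiles for row in tile for c in row
             if c is not NODATA]
    n = len(cells)
    mean = sum(cells) / n
    # Assumption: population variance (divide by n, not n - 1)
    variance = sum((c - mean) ** 2 for c in cells) / n
    return (n, float(min(cells)), float(max(cells)), mean, variance)

# Two tiny 2x2 "tiles" with a few NoData cells
tiles = [
    [[0, 1], [NODATA, 2]],
    [[1, 1], [NODATA, NODATA]],
]
stats = agg_stats(tiles)
print(stats)
```

NoData cells are skipped entirely, so they affect neither the count nor the mean.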



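Join the target label RasterFrame with the band tiles to create our analytics base table (`abt`). The extra `abt_1_2` frame shows how `localAdd` derives a new tile column, here `band_8` as the cell-wise sum of `band_1` and `band_2`.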

In [22]:
abt = joinedRF.spatialJoin(target).drop('bounds', 'metadata').asRF()
abt_1_2 = abt.withColumn("band_8", localAdd("band_1", "band_2"))
abt_1_2.printSchema()

root
 |-- spatial_key: struct (nullable = false)
 |    |-- col: integer (nullable = false)
 |    |-- row: integer (nullable = false)
 |-- band_1: rf_tile (nullable = false)
 |-- band_2: rf_tile (nullable = false)
 |-- band_3: rf_tile (nullable = false)
 |-- band_4: rf_tile (nullable = false)
 |-- band_5: rf_tile (nullable = false)
 |-- band_6: rf_tile (nullable = false)
 |-- band_7: rf_tile (nullable = false)
 |-- target: rf_tile (nullable = false)
 |-- band_8: rf_tile (nullable = true)



In [7]:
def append(lst, elem):
    return tuple(lst + [elem])

append(bandColNames, targetCol)

('band_1',
 'band_2',
 'band_3',
 'band_4',
 'band_5',
 'band_6',
 'band_7',
 'target')

## ML Pipeline

The data preparation and modeling pipeline is next. SparkML requires that each observation be in
its own row, and that the features of each observation be packed into a single `Vector`. The first step is
to "explode" the tiles into a single row per cell/pixel. Then we filter out any rows that
have `NoData` values (which would cause an error during training). Finally we use the
SparkML `VectorAssembler` to create that `Vector`.

In [27]:
exploder = TileExploder()

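def append(lst, elem):
    # Return a new list; `bandColNames` itself is left unchanged
    return lst + [elem]

# Drop exploded rows where any band or the target is NoData
noDataFilter = NoDataFilter()
noDataFilter.setInputCols(append(bandColNames, targetCol))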

assembler = VectorAssembler()
assembler.setInputCols(bandColNames). \
  setOutputCol("features")

VectorAssembler_46158468bf8cd2503df3
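
To make the three preparation steps concrete, here is a small plain-Python sketch of what they do to a single row of tiles. It is an illustration only, not the RasterFrames implementation: `explode`, `has_data`, `assemble`, and the `NODATA` marker are hypothetical names, while `column_index` and `row_index` follow the column names used later in this notebook.

```python
NODATA = None  # hypothetical NoData marker for this sketch

def explode(row, tile_cols):
    """Yield one record per cell position from a row of same-sized tiles."""
    first = row[tile_cols[0]]
    for r in range(len(first)):
        for c in range(len(first[0])):
            rec = {"spatial_key": row["spatial_key"],
                   "column_index": c, "row_index": r}
            for name in tile_cols:
                rec[name] = row[name][r][c]
            yield rec

def has_data(rec, cols):
    """Keep a record only if none of the given columns is NoData."""
    return all(rec[c] is not NODATA for c in cols)

def assemble(rec, feature_cols):
    """Pack the feature columns into a single list-valued 'features' field."""
    return dict(rec, features=[float(rec[c]) for c in feature_cols])

row = {"spatial_key": (0, 0),
       "band_1": [[10, 11], [12, 13]],
       "band_2": [[20, 21], [22, 23]],
       "target": [[0, NODATA], [1, 2]]}
feature_cols = ["band_1", "band_2"]
all_cols = feature_cols + ["target"]

cells = [assemble(rec, feature_cols)
         for rec in explode(row, all_cols)
         if has_data(rec, all_cols)]
for cell in cells:
    print(cell["column_index"], cell["row_index"], cell["features"], cell["target"])
```

One 2x2 tile row becomes four cell records, the record with a `NoData` target is dropped, and the remaining three each carry a two-element feature vector.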

We are going to use a decision tree for classification. You can swap in one of the other multi-class
classification algorithms if you like. With the algorithm selected we can assemble our modeling pipeline.

In [72]:
classifier = DecisionTreeClassifier()

classifier.setLabelCol("target"). \
  setFeaturesCol(assembler.getOutputCol())

pipeline = Pipeline() 
pipeline.setStages([exploder, noDataFilter, assembler, classifier])

Pipeline_422ab243e3a21c1c71bc

## Cross Validation

To extend the sophistication of the example we are going to use the SparkML support for
cross-validation and hyperparameter tuning. The first step is to configure how we're
going to evaluate our model's performance. Then we define the hyperparameter(s) we're going to
vary and evaluate. Finally we configure the cross validator.

In [73]:
evaluator = MulticlassClassificationEvaluator()
evaluator.setLabelCol("target"). \
  setPredictionCol("prediction"). \
  setMetricName("accuracy") 

paramGrid = ParamGridBuilder().addGrid(classifier.maxDepth, [2, 3, 4]).build()

trainer = CrossValidator()
trainer.setEstimator(pipeline). \
  setEvaluator(evaluator). \
  setEstimatorParamMaps(paramGrid). \
  setNumFolds(4)

CrossValidator_4a9791d71e79fc24daa4
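
With three candidate `maxDepth` values and four folds, this configuration implies 3 × 4 = 12 pipeline fits during the search, after which `CrossValidator` refits the best setting on the full data set. The plain-Python sketch below shows the shape of that search loop. It is not SparkML code: the toy k-nearest-neighbour estimator, its `k` hyperparameter (standing in for `maxDepth`), and the deterministic fold assignment are all assumptions made for illustration (SparkML assigns folds randomly).

```python
from collections import Counter

def fit_knn(train, k):
    """Toy estimator: 'fitting' k-NN just captures the training points."""
    def predict(x):
        nearest = sorted(train, key=lambda p: abs(p[0] - x))[:k]
        votes = Counter(label for _, label in nearest)
        return votes.most_common(1)[0][0]
    return predict

def accuracy(model, data):
    return sum(model(x) == y for x, y in data) / len(data)

def cross_validate(data, param_grid, num_folds):
    """Return the average metric per grid entry and the best parameter."""
    avg_metrics = []
    for k in param_grid:
        fold_scores = []
        for fold in range(num_folds):
            # Deterministic folds by index; a simplification for this sketch
            held_out = [p for i, p in enumerate(data) if i % num_folds == fold]
            train = [p for i, p in enumerate(data) if i % num_folds != fold]
            fold_scores.append(accuracy(fit_knn(train, k), held_out))
        avg_metrics.append(sum(fold_scores) / num_folds)
    best_index = max(range(len(param_grid)), key=lambda i: avg_metrics[i])
    return avg_metrics, param_grid[best_index]

# Two well-separated classes on a line
data = [(float(x), 0) for x in range(8)] + [(float(x), 1) for x in range(10, 18)]
avg_metrics, best_k = cross_validate(data, param_grid=[1, 3, 5], num_folds=4)
for k, metric in zip([1, 3, 5], avg_metrics):
    print("k = {}: accuracy = {:.3f}".format(k, metric))
print("best k:", best_k)
```

The list of per-parameter averages plays the same role as `avgMetrics` on the fitted cross-validation model used in the evaluation step.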

Push the "go" button:

In [74]:
model = trainer.fit(abt)

Py4JJavaError: An error occurred while calling o10370.fit.
: java.lang.NullPointerException: Value at index 0 is null
	at org.apache.spark.sql.Row$class.getAnyValAs(Row.scala:472)
	at org.apache.spark.sql.Row$class.getDouble(Row.scala:248)
	at org.apache.spark.sql.catalyst.expressions.GenericRow.getDouble(rows.scala:165)
	at org.apache.spark.ml.classification.Classifier.getNumClasses(Classifier.scala:115)
	at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:102)
	at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:45)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:118)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:82)
	at sun.reflect.GeneratedMethodAccessor122.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:748)


In [83]:
abt.show(0, False)

+-----------+------+------+------+------+------+------+------+---------+
|spatial_key|band_1|band_2|band_3|band_4|band_5|band_6|band_7|targetCol|
+-----------+------+------+------+------+------+------+------+---------+
+-----------+------+------+------+------+------+------+------+---------+
only showing top 0 rows



## Model Evaluation

To view the model's performance we format the `paramGrid` settings used for each model and 
render the parameter/performance association.

In [None]:
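# Format the parameter settings of each candidate model and pair them
# with the average metric across folds
metrics = [
    (", ".join("{} = {}".format(p.name, v) for p, v in params.items()), metric)
    for params, metric in zip(model.getEstimatorParamMaps(), model.avgMetrics)
]

In [None]:
spark.createDataFrame(metrics, ["params", "metric"]).show(truncate=False)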

Finally, we score the original data set (including the cells without target values) and 
add up class membership results.

In [None]:
scored = model.bestModel.transform(joinedRF)

scored.groupBy(scored["prediction"].alias("class")).count().show()

## Visualizing Results

The predictions are in a DataFrame with each row representing a separate pixel. 
To assemble a raster to visualize the class assignments, we have to go through a
multi-stage process to get the data back in tile form, and from there to combined
raster form.

First, we get the DataFrame back into RasterFrame form:

In [None]:
tlm = joinedRF.tileLayerMetadata.left.get

retiled = scored.groupBy($"spatial_key").agg(
  assembleTile(
    $"column_index", $"row_index", $"prediction",
    tlm.tileCols, tlm.tileRows, ByteConstantNoDataCellType
  )
)

rf = retiled.asRF($"spatial_key", tlm)
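
Conceptually, the aggregation above groups the per-cell rows by `spatial_key` and writes each prediction back to its (`column_index`, `row_index`) position within a tile. A minimal plain-Python sketch of that idea follows. `assemble_tiles` is a hypothetical helper, not the RasterFrames `assembleTile` function, and the `NODATA` value of -128 is an assumption about the byte cell type's NoData marker.

```python
from collections import defaultdict

NODATA = -128  # assumed NoData value for a byte cell type

def assemble_tiles(cells, tile_cols, tile_rows):
    """Group per-cell records by spatial key and rebuild 2-D tiles."""
    # Every tile starts out filled with NoData; cells with predictions
    # overwrite their own position.
    tiles = defaultdict(
        lambda: [[NODATA] * tile_cols for _ in range(tile_rows)])
    for cell in cells:
        tile = tiles[cell["spatial_key"]]
        tile[cell["row_index"]][cell["column_index"]] = cell["prediction"]
    return dict(tiles)

cells = [
    {"spatial_key": (0, 0), "column_index": 0, "row_index": 0, "prediction": 1},
    {"spatial_key": (0, 0), "column_index": 1, "row_index": 0, "prediction": 2},
    {"spatial_key": (0, 0), "column_index": 1, "row_index": 1, "prediction": 0},
    {"spatial_key": (1, 0), "column_index": 0, "row_index": 1, "prediction": 2},
]
tiles = assemble_tiles(cells, tile_cols=2, tile_rows=2)
for key in sorted(tiles):
    print(key, tiles[key])
```

Positions that never received a prediction stay at the NoData value, which is why the cell type passed to the real aggregation needs a NoData representation.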

To render our visualization, we first convert to a raster, then use an
`IndexedColorMap` to assign each discrete class a different color, and finally
render to a PNG file.

In [None]:
raster = rf.toRaster($"prediction", 186, 169)

clusterColors = IndexedColorMap.fromColorMap(
  ColorRamps.Viridis.toColorMap((0 until 3).toArray)
)

raster.tile.renderPng(clusterColors).write("target/scala-2.11/tut/ml/classified.png")

| Color Composite    | Target Labels          | Class Assignments   |
| ------------------ | ---------------------- | ------------------- |
| ![](L8-RGB-VA.png) | ![](target-labels.png) | ![](classified.png) |

In [None]:
raster = SinglebandGeoTiff("../core/src/test/resources/L8-Labels-Elkton-VA.tiff").raster

k = raster.findMinMax._2

clusterColors = IndexedColorMap.fromColorMap(
  ColorRamps.Viridis.toColorMap((0 to k).toArray)
)

raster.tile.renderPng(clusterColors).write("target/scala-2.11/tut/ml/target-labels.png")

In [1]:
spark.stop()

SyntaxError: invalid syntax (<ipython-input-1-4f87b3a3d579>, line 1)

In [20]:
noDataFilter = NoDataFilter()
noDataFilter.setInputCols(['val'])

In [68]:
vars(trainer)

{'_defaultParamMap': {Param(parent='CrossValidator_4d5dac9f1699990affc4', name='numFolds', doc='number of folds for cross validation'): 3,
  Param(parent='CrossValidator_4d5dac9f1699990affc4', name='seed', doc='random seed.'): 5002524630791503526},
 '_input_kwargs': {},
 '_paramMap': {Param(parent='CrossValidator_4d5dac9f1699990affc4', name='estimator', doc='estimator to be cross-validated'): Pipeline_48d7abb4dc822d1a4300,
  Param(parent='CrossValidator_4d5dac9f1699990affc4', name='estimatorParamMaps', doc='estimator param maps'): [{Param(parent='DecisionTreeClassifier_4c1fa0e265b3e1159afd', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 2},
   {Param(parent='DecisionTreeClassifier_4c1fa0e265b3e1159afd', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 3},
   {Param(parent='DecisionTreeClassifier_4c1fa0e265b3