# Classification

In this example we will do some simple cell classification based on multiband imagery and a
target/label raster. As a part of the process we'll explore the cross-validation support in
SparkML.

## Setup

First some setup:

In [18]:
from pyrasterframes import *
from pyrasterframes.rasterfunctions import *
import pyspark
from pyspark.sql import SparkSession
from pathlib import Path

spark = SparkSession.builder. \
    master("local[*]"). \
    appName("RasterFrames"). \
    config("spark.ui.enabled", "false"). \
    getOrCreate(). \
    withRasterFrames()

# Utility for reading imagery from our test data set
resource_dir = Path('./samples').resolve()
filenamePattern = "L8-B{}-Elkton-VA.tiff"
bandNumbers = range(1, 8)
bandColNames = list(map(lambda n: 'band_{}'.format(n), bandNumbers))

# Build the file URI for one band's GeoTIFF (this does not read the file itself)
def readTiff(name):
    return resource_dir.joinpath(filenamePattern.format(name)).as_uri()

## Loading Data

The first step is to load multiple bands of imagery and construct a single RasterFrame from them.
To do this we:

1. Build the URI of each band's GeoTIFF file.
2. Read the GeoTIFF with `spark.read.geotiff`.
3. Convert to a RasterFrame and rename the `tile` column to a band-specific name (`band_1`, `band_2`, ...).
4. Use the RasterFrames `spatialJoin` function to create a new RasterFrame with a column for each band.

In [8]:
from functools import reduce

# Read each band's GeoTIFF and rename its `tile` column to `band_N`
bandRFs = [spark.read.geotiff(readTiff(b)).withColumnRenamed('tile', 'band_{}'.format(b))
           for b in bandNumbers]

# Fold the per-band frames into one, dropping the duplicate `bounds` and `metadata` columns
joinedRF = reduce(lambda rf1, rf2: rf1.asRF().spatialJoin(rf2.drop('bounds').drop('metadata')),
                  bandRFs)
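
To make the fold above easier to follow, here is a minimal sketch of the same idea without Spark. Each "frame" is a plain dict keyed by spatial key, and `spatial_join` is a hypothetical stand-in that assumes inner-join semantics; it is not the RasterFrames implementation.

```python
from functools import reduce

# Hypothetical stand-ins for per-band frames: each maps a spatial key
# (col, row) to a dict holding that band's tile.
band_frames = [
    {(0, 0): {"band_{}".format(b): "tile{}a".format(b)},
     (1, 0): {"band_{}".format(b): "tile{}b".format(b)}}
    for b in range(1, 4)
]

def spatial_join(left, right):
    """Merge the columns of rows that share a spatial key (inner join)."""
    return {key: {**left[key], **right[key]} for key in left if key in right}

# reduce applies the join pairwise: ((band_1 join band_2) join band_3)
joined = reduce(spatial_join, band_frames)
print(sorted(joined[(0, 0)]))  # ['band_1', 'band_2', 'band_3']
```

Each pass through `reduce` adds one more band column to the accumulated frame, which is why the result has one row per spatial key and one column per band.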

We should see a single `spatial_key` column, the `bounds` and `metadata` columns, and 7 columns of tiles (one per band).

In [9]:
joinedRF.printSchema()

root
 |-- spatial_key: struct (nullable = false)
 |    |-- col: integer (nullable = false)
 |    |-- row: integer (nullable = false)
 |-- bounds: polygon (nullable = true)
 |-- metadata: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = false)
 |-- band_1: rf_tile (nullable = false)
 |-- band_2: rf_tile (nullable = false)
 |-- band_3: rf_tile (nullable = false)
 |-- band_4: rf_tile (nullable = false)
 |-- band_5: rf_tile (nullable = false)
 |-- band_6: rf_tile (nullable = false)
 |-- band_7: rf_tile (nullable = false)



Similarly, we pull in the target label data. When loading the target label raster, we have
to convert the cell type to `Double` to meet the expectations of SparkML.

In [19]:
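targetCol = "target"

# Read the label raster, rename its tile column, and convert the cells to doubles.
# Assumption: `convertCellType` is provided by pyrasterframes.rasterfunctions
# and accepts 'float64' as the cell type name.
target = spark.read.geotiff(resource_dir.joinpath("L8-Labels-Elkton-VA.tiff").as_uri()) \
    .withColumnRenamed('tile', targetCol) \
    .withColumn(targetCol, convertCellType(targetCol, 'float64'))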

Take a peek at what kind of label data we have to work with.

In [None]:
target.select(aggStats(target[targetCol])).show()

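Join the target label RasterFrame with the band tiles to create our analytics base table.

In [None]:
# Drop the duplicate `bounds` and `metadata` columns, as in the band join above
abt = joinedRF.spatialJoin(target.drop('bounds').drop('metadata'))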

## ML Pipeline

Next comes the data preparation and modeling pipeline. SparkML requires that each observation be in
its own row, and that the features of each observation be packed into a single `Vector`. The first step is
to "explode" the tiles into a single row per cell/pixel. Then we filter out any rows that
have `NoData` values (which would cause an error during training). Finally, we use the
SparkML `VectorAssembler` to create that `Vector`.

In [None]:
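from pyspark.ml.feature import VectorAssembler

# Assumption: TileExploder and NoDataFilter are exported by `from pyrasterframes import *`
exploder = TileExploder()

noDataFilter = NoDataFilter() \
    .setInputCols(bandColNames + [targetCol])

assembler = VectorAssembler(inputCols=bandColNames, outputCol="features")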
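
As a rough illustration of what the three stages above do to the data, the sketch below runs the same explode, filter, and assemble steps on plain Python lists. The 2x2 tiles, the use of NaN as the `NoData` marker, and the helper names are all assumptions made for the example; the real stages operate on Spark DataFrames.

```python
import math

# Hypothetical 2x2 tiles for two bands and a target; NaN stands in for NoData.
tiles = {
    "band_1": [[1.0, 2.0], [3.0, 4.0]],
    "band_2": [[5.0, float("nan")], [7.0, 8.0]],
    "target": [[0.0, 1.0], [float("nan"), 1.0]],
}

def explode(tiles):
    """Yield one row per cell, carrying the cell position and each column's value."""
    names = list(tiles)
    n_rows = len(tiles[names[0]])
    n_cols = len(tiles[names[0]][0])
    for r in range(n_rows):
        for c in range(n_cols):
            row = {"column_index": c, "row_index": r}
            row.update({n: tiles[n][r][c] for n in names})
            yield row

def has_no_data(row, columns):
    """True if any of the named columns holds the NoData marker."""
    return any(math.isnan(row[n]) for n in columns)

feature_cols = ["band_1", "band_2"]

# Drop rows with NoData, then pack the band values into one feature vector per row
observations = [
    {"features": [row[n] for n in feature_cols], "target": row["target"]}
    for row in explode(tiles)
    if not has_no_data(row, feature_cols + ["target"])
]
print(observations)
```

Of the four cells, two contain a NaN in either a band or the target, so only two observations survive the filter.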

We are going to use a decision tree for classification. You can swap in one of the other multi-class
classification algorithms if you like. With the algorithm selected we can assemble our modeling pipeline.

In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier

classifier = DecisionTreeClassifier(
    labelCol=targetCol,
    featuresCol=assembler.getOutputCol())

pipeline = Pipeline(stages=[exploder, noDataFilter, assembler, classifier])

## Cross Validation

To make the example more sophisticated, we are going to use the SparkML support for
cross-validation and hyperparameter tuning. The first step is to configure how we're
going to evaluate our model's performance. Then we define the hyperparameter(s) we're going to
vary and evaluate. Finally, we configure the cross validator.

In [None]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

evaluator = MulticlassClassificationEvaluator(
    labelCol=targetCol,
    predictionCol="prediction",
    metricName="accuracy")

# Try three tree depths
paramGrid = ParamGridBuilder() \
    .addGrid(classifier.maxDepth, [2, 3, 4]) \
    .build()

trainer = CrossValidator(
    estimator=pipeline,
    evaluator=evaluator,
    estimatorParamMaps=paramGrid,
    numFolds=4)
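
For intuition about what the cross validator does with this configuration, the following sketch spells out the loop in plain Python: every parameter combination is fit and scored once per fold, and the scores are averaged. The helper names and the `toy_score` function are hypothetical, and the folds here are assigned round-robin for simplicity, whereas SparkML assigns rows to folds randomly.

```python
from itertools import product

def param_grid(**grid):
    """Expand {name: [values]} into a list of {name: value} combinations."""
    names = sorted(grid)
    return [dict(zip(names, combo)) for combo in product(*(grid[n] for n in names))]

def k_fold_indices(n_rows, n_folds):
    """Yield (train, validation) index lists, one pair per fold."""
    folds = [list(range(i, n_rows, n_folds)) for i in range(n_folds)]
    for i, validation in enumerate(folds):
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield train, validation

def cross_validate(n_rows, n_folds, grid, fit_and_score):
    """Return (params, average validation score) for each parameter combination."""
    results = []
    for params in grid:
        scores = [fit_and_score(params, train, validation)
                  for train, validation in k_fold_indices(n_rows, n_folds)]
        results.append((params, sum(scores) / len(scores)))
    return results

# Hypothetical scorer standing in for "fit the pipeline, then evaluate accuracy".
def toy_score(params, train, validation):
    return 0.5 + 0.1 * params["maxDepth"]

grid = param_grid(maxDepth=[2, 3, 4])
results = cross_validate(n_rows=12, n_folds=4, grid=grid, fit_and_score=toy_score)
best_params, best_score = max(results, key=lambda r: r[1])
print(len(grid) * 4, best_params)  # 12 {'maxDepth': 4}
```

With three depths and four folds, the pipeline is fit 12 times before the best setting is chosen, which is worth keeping in mind when sizing the parameter grid.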

Push the "go" button:

In [None]:
model = trainer.fit(abt)

## Model Evaluation

To view the model's performance we format the `paramGrid` settings used for each model and 
render the parameter/performance association.

In [None]:
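# Pair each parameter setting in `paramGrid` with its average metric across folds
metrics = [
    (", ".join("{} = {}".format(p.name, v) for p, v in params.items()), float(metric))
    for params, metric in zip(paramGrid, model.avgMetrics)
]

In [None]:
spark.createDataFrame(metrics, ["params", "metric"]).show(truncate=False)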

Finally, we score the original data set (including the cells without target values) and 
add up class membership results.

In [None]:
scored = model.bestModel.transform(joinedRF)

scored.groupBy(scored["prediction"].alias("class")).count().show()

## Visualizing Results

The predictions are in a DataFrame with each row representing a separate pixel. 
To assemble a raster to visualize the class assignments, we have to go through a
multi-stage process to get the data back in tile form, and from there to combined
raster form.

First, we get the DataFrame back into RasterFrame form:

In [None]:
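# Assumption: the Python bindings return the tile layer metadata as a dict that
# mirrors the GeoTrellis JSON layout, and `asRF` accepts the key column and metadata.
tlm = joinedRF.tileLayerMetadata()
layout = tlm['layoutDefinition']['tileLayout']

retiled = scored.groupBy('spatial_key').agg(
    assembleTile(
        'column_index', 'row_index', 'prediction',
        layout['tileCols'], layout['tileRows'], 'int8'
    ).alias('prediction')
)

rf = retiled.asRF('spatial_key', tlm)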
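
The grouping and reassembly step above can be pictured with a small pure-Python sketch. The `assemble_tile` helper and the sample rows are hypothetical and only illustrate the idea: each scored row carries its cell position, so the rows for one spatial key can be written back into a 2-D grid.

```python
def assemble_tile(rows, tile_cols, tile_rows, no_data=None):
    """Rebuild a tile (a list of rows) from per-cell records."""
    tile = [[no_data] * tile_cols for _ in range(tile_rows)]
    for rec in rows:
        tile[rec["row_index"]][rec["column_index"]] = rec["prediction"]
    return tile

# Hypothetical scored rows covering two spatial keys
scored_rows = [
    {"spatial_key": (0, 0), "column_index": 0, "row_index": 0, "prediction": 1},
    {"spatial_key": (0, 0), "column_index": 1, "row_index": 1, "prediction": 2},
    {"spatial_key": (1, 0), "column_index": 1, "row_index": 0, "prediction": 0},
]

# Group the rows by spatial key, then assemble one tile per key
grouped = {}
for rec in scored_rows:
    grouped.setdefault(rec["spatial_key"], []).append(rec)

retiled = {key: assemble_tile(recs, 2, 2) for key, recs in grouped.items()}
print(retiled[(0, 0)])  # [[1, None], [None, 2]]
```

Cells with no scored row keep the `no_data` marker, which mirrors how filtered-out `NoData` cells end up empty in the reassembled tile.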

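To render our visualization, we first convert to a raster, then use an
`IndexedColorMap` to assign each discrete class a different color, and finally
render to a PNG file. The snippet below is Scala: `IndexedColorMap`, `ColorRamps`,
and `renderPng` come from the GeoTrellis Scala API.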

In [None]:
val raster = rf.toRaster($"prediction", 186, 169)

val clusterColors = IndexedColorMap.fromColorMap(
  ColorRamps.Viridis.toColorMap((0 until 3).toArray)
)

raster.tile.renderPng(clusterColors).write("target/scala-2.11/tut/ml/classified.png")

| Color Composite    | Target Labels          | Class Assignments   |
| ------------------ | ---------------------- | ------------------- |
| ![](L8-RGB-VA.png) | ![](target-labels.png) | ![](classified.png) |

In [1]:
spark.stop()