# Clustering

In this example we cluster the individual cells (pixels) of multiband imagery with k-means.

## Setup 

First some setup:

In [None]:
from pyrasterframes import *
from pyrasterframes.rasterfunctions import *
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline
from pathlib import Path

spark = SparkSession.builder. \
    master("local[*]"). \
    appName("RasterFrames"). \
    config("spark.ui.enabled", "false"). \
    getOrCreate(). \
    withRasterFrames()
    
resource_dir = Path('./samples').resolve()
# Utility for reading imagery from our test data set
filenamePattern = "L8-B{}-Elkton-VA.tiff"
bandNumbers = range(1, 8)
bandColNames = ['band_{}'.format(n) for n in bandNumbers]

## Loading Data

The first step is to load multiple bands of imagery and construct a single RasterFrame from them.

In [None]:
from functools import reduce

def read_band(b):
    # Read one band's GeoTIFF and name its tile column after the band number
    uri = resource_dir.joinpath(filenamePattern.format(b)).as_uri()
    return spark.read.geotiff(uri).withColumnRenamed('tile', 'band_{}'.format(b))

# Spatially join the bands, dropping 'bounds' and 'metadata' from each right side
joinedRF = reduce(lambda rf1, rf2: rf1.asRF().spatialJoin(rf2.drop('bounds').drop('metadata')),
                  map(read_band, bandNumbers))

We should see a single `spatial_key` column along with one tile column per band (`band_1` through `band_7`).

In [None]:
joinedRF.printSchema()

## ML Pipeline 

SparkML requires that each observation be in its own row, and that the features of
each observation be packed into a single `Vector`. The first step is to
"explode" the tiles into a single row per cell/pixel.

In [None]:
exploder = TileExploder()
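
As a rough mental model of what exploding does (this is not the `TileExploder` implementation), the sketch below flattens small in-memory grids into one record per cell. The `explode_tiles` helper and the 2x2 sample tiles are hypothetical; the `column_index` and `row_index` names mirror the columns used later when the tiles are reassembled.

```python
# Pure-Python illustration of "exploding" tiles into one row per cell.
# Hypothetical helper; RasterFrames does this on Spark DataFrames instead.
def explode_tiles(tiles):
    """Turn {band_name: 2-D grid} into a list of per-cell dicts."""
    names = list(tiles)
    first = tiles[names[0]]
    n_rows, n_cols = len(first), len(first[0])
    rows = []
    for r in range(n_rows):
        for c in range(n_cols):
            # Each output row carries the cell position plus one value per band
            row = {"column_index": c, "row_index": r}
            for name in names:
                row[name] = tiles[name][r][c]
            rows.append(row)
    return rows

# Two made-up 2x2 tiles standing in for band columns
tiles = {
    "band_1": [[1, 2], [3, 4]],
    "band_2": [[10, 20], [30, 40]],
}
exploded = explode_tiles(tiles)
for row in exploded:
    print(row)
```

A 2x2 tile yields four rows, each holding the values of every band at that cell.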

To "vectorize" the band columns, as required by SparkML, we use the SparkML
`VectorAssembler`. We then configure our algorithm, create the transformation pipeline,
and train our model. (Note: the selected value of *K* below is arbitrary.)

In [None]:
assembler = VectorAssembler() \
    .setInputCols(bandColNames) \
    .setOutputCol("features")

# Configure our clustering algorithm
k = 5
kmeans = KMeans().setK(k)

# Combine the three stages
pipeline = Pipeline().setStages([exploder, assembler, kmeans])

# Compute clusters
model = pipeline.fit(joinedRF)
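
For intuition about what the clustering stage computes, here is a toy version of the classic Lloyd's iteration on plain tuples. It is only a conceptual sketch: Spark's `KMeans` is distributed and chooses its starting centers more carefully. The helper names, the sample points, and the starting centers are all made up for illustration.

```python
# Toy Lloyd's algorithm; a conceptual sketch, not Spark's implementation.
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def nearest(point, centers):
    """Index of the center closest to point."""
    return min(range(len(centers)),
               key=lambda i: squared_distance(point, centers[i]))

def lloyd(points, centers, iterations=10):
    for _ in range(iterations):
        # Assignment step: group points by their nearest center
        groups = [[] for _ in centers]
        for p in points:
            groups[nearest(p, centers)].append(p)
        # Update step: move each center to the mean of its group
        # (an empty group keeps its previous center)
        centers = [
            tuple(sum(dim) / len(g) for dim in zip(*g)) if g else c
            for g, c in zip(groups, centers)
        ]
    return centers

# Hypothetical 2-band "pixels" forming two obvious groups
points = [(1.0, 1.0), (1.0, 2.0), (9.0, 9.0), (10.0, 9.0)]
centers = lloyd(points, centers=[(0.0, 0.0), (5.0, 5.0)])
labels = [nearest(p, centers) for p in points]
print(centers, labels)
```

Each label plays the same role as the cluster ID that the fitted model assigns to a cell.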

## Model Evaluation

At this point the model can be saved off for later use, or used immediately on the same
data we used to compute the model. First we run the data through the model to assign 
cluster IDs to each cell.

In [None]:
clustered = model.transform(joinedRF)
clustered.show(8)

If we want to inspect the model statistics, the SparkML API requires us to go
through this unfortunate contortion:

In [None]:
clusterResults = list(filter(lambda x: str(x).startswith('KMeans'), model.stages))[0]

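Compute the sum of squared distances from each point to its nearest cluster center. Note that `computeCost` is a Spark 2.x API; it was deprecated in Spark 2.4 and removed in 3.0: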

In [None]:
metric = clusterResults.computeCost(clustered)
print("Within set sum of squared errors: %s" % metric)
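
The metric itself is simple to state: for every point, take the squared Euclidean distance to the closest center, then add those up. The sketch below computes it on a handful of made-up points and centers; `within_set_sse` is a hypothetical helper, not part of SparkML.

```python
# Within-set sum of squared errors on plain tuples (illustrative only).
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def within_set_sse(points, centers):
    """Sum over all points of the squared distance to the nearest center."""
    return sum(min(squared_distance(p, c) for c in centers) for p in points)

# Made-up points and cluster centers
points = [(1.0, 1.0), (1.0, 2.0), (9.0, 9.0), (10.0, 9.0)]
centers = [(1.0, 1.5), (9.5, 9.0)]
wssse = within_set_sse(points, centers)
print("Within set sum of squared errors: %s" % wssse)
```

Lower values mean the points sit closer to their assigned centers, though the value always shrinks as *K* grows, so it is most useful for comparing runs with the same *K*.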

## Visualizing Results

The predictions are in a DataFrame with each row representing a separate pixel. 
To assemble a raster to visualize the cluster assignments, we have to go through a
multi-stage process to get the data back in tile form, and from there to combined
raster form.

First, we get the DataFrame back into RasterFrame form:

In [None]:
tlm = joinedRF.tileLayerMetadata()
layout = tlm['layoutDefinition']['tileLayout']

retiled = clustered.groupBy('spatial_key').agg(
    assembleTile('column_index', 'row_index', 'prediction',
        layout['tileCols'], layout['tileRows'], 'int8')
)
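
Conceptually, the aggregation above is the inverse of exploding: every row carries a cell position and a value, and the values are written back into a grid of the tile's dimensions. The following pure-Python sketch shows the idea with a hypothetical `assemble_tile` helper and made-up prediction rows; it is not how RasterFrames implements `assembleTile`.

```python
# Rebuild a 2-D grid from per-cell rows (illustrative only).
def assemble_tile(rows, tile_cols, tile_rows, nodata=None):
    """Place each row's prediction at (row_index, column_index)."""
    tile = [[nodata] * tile_cols for _ in range(tile_rows)]
    for row in rows:
        tile[row["row_index"]][row["column_index"]] = row["prediction"]
    return tile

# Made-up cluster assignments for a 2x2 tile
rows = [
    {"column_index": 0, "row_index": 0, "prediction": 2},
    {"column_index": 1, "row_index": 0, "prediction": 2},
    {"column_index": 0, "row_index": 1, "prediction": 0},
    {"column_index": 1, "row_index": 1, "prediction": 4},
]
tile = assemble_tile(rows, tile_cols=2, tile_rows=2)
print(tile)
```

Cells with no matching row keep the `nodata` placeholder.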

In [None]:
rf = retiled.asRF('spatial_key', tlm)
rf.printSchema()
rf.show()

To render our visualization, we first convert to a raster, then use an
`IndexedColorMap` to assign each discrete cluster a different color, and finally
render the result to a PNG file.

In [None]:
# TODO: toRaster does not exist (or at least I can't find it)

raster = rf.toRaster('prediction', 186, 169)

clusterColors = IndexedColorMap.fromColorMap(
  ColorRamps.Viridis.toColorMap([range(0, k)])
)

raster.tile.renderPng(clusterColors).write("target/scala-2.11/tut/ml/clustered.png")

| Color Composite    | Cluster Assignments |
| ------------------ | ------------------- |
| ![](L8-RGB-VA.png) | ![](clustered.png)  |

In [None]:
spark.stop()