<h1>K-means with Rheem <div style="float:right; z-index:1"><img src="rheem.png" width="100px" /></div></h1>

This notebook demonstrates how to run k-means, a popular clustering algorithm. In particular, it shows how to use iterations (loops) and database access. To run this notebook, you will need the [Jupyter Scala kernel](https://github.com/alexarchambault/jupyter-scala).

<h3>Rheem plan (abstract)</h3>
<img src="img/k-means-abstract.png" style="float:center" width="500"/>

First, we initialize Rheem.

In [1]:
// Disable logging.
// Pull the `slf4j-nop` artifact into the kernel before any other dependency is loaded.
import $ivy.`org.slf4j:slf4j-nop:1.7.12`
// Request a logger right away.
// NOTE(review): presumably this makes SLF4J bind to the no-op implementation early — confirm.
org.slf4j.LoggerFactory.getLogger("root").info("Enforcing slf4j-nop...")

[32mimport [39m[36m$ivy.$                           
[39m

In [2]:
// Load dependencies into the kernel.
import $ivy.`org.qcri.rheem::rheem-api:0.3.0`,
    $ivy.`org.qcri.rheem:rheem-basic:0.3.0`,
    $ivy.`org.qcri.rheem:rheem-java:0.3.0`,
    $ivy.`org.qcri.rheem::rheem-spark:0.3.0`,
    $ivy.`org.qcri.rheem:rheem-sqlite3:0.3.0`,
    $ivy.`org.apache.spark::spark-core:1.6.0`,
    $ivy.`org.apache.spark::spark-graphx:1.6.0`,
    $ivy.`de.hpi.isg:profiledb-store:0.1.1`,
    $ivy.`com.github.sekruse::spark-summit-demo:1.0-SNAPSHOT`

// Do the relevant imports.
import org.qcri.rheem.api._
import org.qcri.rheem.core.api._
import org.qcri.rheem.core.function._
import org.qcri.rheem.core.optimizer.ProbabilisticDoubleInterval
import org.qcri.rheem.java.Java, org.qcri.rheem.spark.Spark, org.qcri.rheem.sqlite3.Sqlite3
import de.hpi.isg.profiledb.store.model._
import com.github.sekruse.spark_summit_demo._
import scala.collection.JavaConversions._

// Set up the Rheem configuration (the `RheemContext` itself is created from it in a later cell).
// Absolute path of the current working directory.
val localDir = new java.io.File(".").getAbsoluteFile
// Point the configuration at the `rheem.properties` file in that directory, addressed as a `file://` URL.
val config = new Configuration(s"file://$localDir/rheem.properties")

[32mimport [39m[36m$ivy.$                                ,
    $ivy.$                                 ,
    $ivy.$                                ,
    $ivy.$                                  ,
    $ivy.$                                   ,
    $ivy.$                                   ,
    $ivy.$                                     ,
    $ivy.$                                 ,
    $ivy.$                                                   

// Do the relevant imports.
[39m
[32mimport [39m[36morg.qcri.rheem.api._
[39m
[32mimport [39m[36morg.qcri.rheem.core.api._
[39m
[32mimport [39m[36morg.qcri.rheem.core.function._
[39m
[32mimport [39m[36morg.qcri.rheem.core.optimizer.ProbabilisticDoubleInterval
[39m
[32mimport [39m[36morg.qcri.rheem.java.Java, org.qcri.rheem.spark.Spark, org.qcri.rheem.sqlite3.Sqlite3
[39m
[32mimport [39m[36mde.hpi.isg.profiledb.store.model._
[39m
[32mimport [39m[36mcom.github.sekruse.spark_summit_demo._
[39m
[32mimport [39m[36mscala

Next, we generate an input database.

In [3]:
locally {
    // The SQLite database file that serves as input for k-means.
    val dbFile = new java.io.File("data/locations.db")

    // Only create the data when the file is not there yet, so re-running the cell keeps an existing database.
    if (!dbFile.exists) {
        // Make sure that the `data` directory exists before the database is written into it.
        dbFile.getParentFile.mkdirs()
        generateKMeansData(path = dbFile.getPath, k = 20, points = 10000)
    }

    // Let the SQLite3 platform of Rheem connect to this database.
    config.setProperty("rheem.sqlite3.jdbc.url", s"jdbc:sqlite:$dbFile")
}

If this notebook is run in an offline environment, run the `run-webserver.sh` script to provide the required JS libraries.

In [4]:
// Toggle for offline use: set to `false` if the local web server (`run-webserver.sh`) is not needed.
val offline = true
if (offline) {
    // Register the JS libraries under URLs served from localhost:8888.
    // NOTE(review): presumably `addModule` makes the notebook front end load these modules from the given URLs — confirm.
    addModule("plotly", "http://localhost:8888/files/js/plotly-latest.min")
    addModule("d3", "http://localhost:8888/files/js/d3.v4.min")
    // NOTE(review): presumably pins the Spark driver host so that no host name lookup is needed offline — confirm.
    config.setProperty("spark.driver.host", "localhost")
}

[36moffline[39m: [32mBoolean[39m = [32mtrue[39m

Now, we can run k-means.

In [10]:
locally {
    // Experiment that is attached to the plan below and finally handed to `plotExecutionPlan`.
    val experiment = new Experiment("my-exp", new Subject("k-means", "1.0"))
    // Number of centroids (clusters).
    val k = 15
    // Number of repetitions of the k-means loop body.
    val iterations = 20
    // Rheem context on top of the notebook-wide `config`, with the Java channel conversions,
    // the basic Spark plugin, and the SQLite3 plugin registered.
    val rheemCtx = new RheemContext(config)
        // NOTE(review): the basic Java plugin is commented out; presumably to force execution on Spark — confirm.
        //.withPlugin(Java.basicPlugin)
        .withPlugin(Java.channelConversionPlugin)
        .withPlugin(Spark.basicPlugin)
        .withPlugin(Sqlite3.plugin)
    
    // Define data types to handle k-means neatly.
    /** Common view on everything that has two `Double` coordinates. */
    trait PointLike {
      /** The first coordinate. */
      def x: Double
      /** The second coordinate. */
      def y: Double
    }

    /** A plain two-dimensional point without any centroid information. */
    case class Point(x: Double, y: Double) extends PointLike {

      /**
       * Calculates the Euclidean distance to another point.
       *
       * @param that the other point
       * @return the distance between this point and `that`
       */
      def distanceTo(that: PointLike): Double = {
        val deltaX = x - that.x
        val deltaY = y - that.y
        math.sqrt(deltaX * deltaX + deltaY * deltaY)
      }

      /** Renders the point as `(x, y)` with both coordinates formatted to two decimal places. */
      override def toString: String = f"($x%.2f, $y%.2f)"
    }

    /** A point that carries the ID of a centroid; also used to represent the centroids themselves. */
    case class TaggedPoint(x: Double, y: Double, centroidId: Int) extends PointLike {
      /** Drops the centroid ID and keeps only the coordinates. */
      def toPoint = Point(x, y)
    }

    /**
     * Accumulator for the points assigned to one centroid: `x` and `y` hold coordinate sums,
     * `count` holds the number of points that went into those sums.
     */
    case class TaggedPointCounter(x: Double, y: Double, centroidId: Int, count: Int = 1) extends PointLike {

      /** Creates a counter from the coordinates of `point`. */
      def this(point: PointLike, centroidId: Int, count: Int) = this(point.x, point.y, centroidId, count)

      /**
       * Adds up the coordinates and the counts of both counters.
       * The result keeps the `centroidId` of this (the left-hand) counter.
       */
      def +(that: TaggedPointCounter): TaggedPointCounter =
        copy(x = x + that.x, y = y + that.y, count = count + that.count)

      /** Divides the coordinate sums by the count, which yields the mean position as a [[TaggedPoint]]. */
      def average: TaggedPoint = {
        val n = count
        TaggedPoint(x / n, y / n, centroidId)
      }
    }

    // Set up a new plan named "k-means" that reports to `experiment`.
    // NOTE(review): `withUdfJarsOf(this.getClass)` presumably ships the JAR containing the notebook's classes for the UDFs — confirm.
    val planBuilder = new PlanBuilder(rheemCtx)
        .withJobName("k-means")
        .withUdfJarsOf(this.getClass)
        .withExperiment(experiment)
    
    // Read the input points from the `locations` table of the SQLite3 database,
    // keep only the `lat` and `lon` columns, and convert every record into a `Point`.
    import org.qcri.rheem.sqlite3.operators._
    val points = planBuilder
        .readTable(new Sqlite3TableSource("locations", "lat", "lon", "description"))
        .withName("Read table")
    
        .projectRecords(Seq("lat", "lon"))
        .withName("Project coordinates")
    
        // NOTE(review): assumes that after the projection field 0 is `lat` and field 1 is `lon` — confirm.
        .map(record => Point(record.getDouble(0), record.getDouble(1)))
        .withName("Create points")

    // Create initial centroids.
    /**
     * Creates `n` centroids with random coordinates.
     *
     * Both coordinates are drawn via `math.random * 180`, i.e., from `[0, 180)`.
     * The centroids are tagged with the consecutive IDs `0` to `n - 1`.
     *
     * @param n the number of centroids to create
     * @return the random centroids, ordered by their ID
     */
    def createRandomCentroids(n: Int): Seq[TaggedPoint] =
        // Iterate up to the parameter `n` (not the enclosing `k`), so that the method honors its argument.
        for (i <- 0 until n) yield TaggedPoint(math.random * 180, math.random * 180, i)

    // Feed the `k` random centroids into the plan as the starting point of the loop.
    val initialCentroids = planBuilder
        .loadCollection(createRandomCentroids(k))
        .withName("Load random centroids")

    // Do the k-means loop.
    /**
     * UDF that assigns a point to its nearest centroid.
     *
     * The centroids are taken from the broadcast named "centroids" when the function is opened.
     * For every point, `apply` emits a [[TaggedPointCounter]] that carries the point's coordinates,
     * the ID of the nearest centroid, and a count of 1.
     */
    class SelectNearestCentroid extends FunctionDescriptor.ExtendedSerializableFunction[Point, TaggedPointCounter] {

      // Convert the broadcast (a Java collection) explicitly via `asScala` rather than through the
      // deprecated implicit conversions of `scala.collection.JavaConversions`.
      import scala.collection.JavaConverters._

      /** The centroids of the current iteration; unset until `open` has been called. */
      var centroids: Iterable[TaggedPoint] = _

      /** Fetches the centroids from the "centroids" broadcast of the execution context. */
      override def open(executionCtx: ExecutionContext) = {
        centroids = executionCtx.getBroadcast[TaggedPoint]("centroids").asScala
      }

      /**
       * Finds the centroid with the smallest distance to `point`; on ties, the first such centroid wins.
       *
       * @param point the point to assign
       * @return a counter for `point` tagged with the nearest centroid's ID, or with `-1` if no
       *         centroid has a distance below positive infinity (e.g., if there are no centroids)
       */
      override def apply(point: Point): TaggedPointCounter = {
        // This runs once per point and iteration, so a plain loop with local vars is kept on purpose.
        var minDistance = Double.PositiveInfinity
        var nearestCentroidId = -1
        for (centroid <- centroids) {
          val distance = point.distanceTo(centroid)
          if (distance < minDistance) {
            minDistance = distance
            nearestCentroidId = centroid.centroidId
          }
        }
        new TaggedPointCounter(point, nearestCentroidId, 1)
      }
    }
    
    // Starting from the initial centroids, apply the loop body `iterations` times.
    // The body receives the current centroids and yields the centroids for the next round.
    val finalCentroids = initialCentroids.repeat(iterations, { currentCentroids =>
        points
            // Tag every point with its nearest centroid; the current centroids are passed as the broadcast "centroids".
            .mapJava(new SelectNearestCentroid)
            .withBroadcast(currentCentroids, "centroids").withName("Find nearest centroid")
        
            // Add up the coordinates and counts of all points that share a centroid ID.
            .reduceByKey(_.centroidId, _ + _).withName("Add up points")
            // NOTE(review): presumably tells the optimizer to expect `k` output elements — confirm.
            .withCardinalityEstimator(k)
                    
            // Turn every sum into the mean position, i.e., the new centroid.
            .map(_.average)
            .withName("Average points")
                                                                                                                    
            // Match the new centroids with the old ones by ID. If there is no new centroid for an ID
            // (`field0` is empty), keep the old centroid (`field1`) instead.
            .keyBy(_.centroidId).coGroup(currentCentroids.keyBy(_.centroidId))
            .withName("Co-group with old centroids")
            .map(coGroup => if (coGroup.field0.isEmpty) coGroup.field1.head else coGroup.field0.head)
            .withName("Re-insert lost centroids")
            
    }).withName("Loop")

    // Collect the result: strip the centroid IDs and fetch the final centroids into the notebook.
    val c = finalCentroids
      .map(_.toPoint).withName("Strip centroid names")
      .collect()
    // Show the first centroid and the number of remaining ones.
    // NOTE(review): `c.head` fails if no centroid was collected (e.g., with k = 0) — confirm that this cannot happen.
    publish.html(s"<h1>Result</h1> Collected ${c.head} and ${c.size - 1} more centroids.")
    
    // NOTE(review): presumably renders the execution plan that was recorded in `experiment` — confirm.
    publish.html("<h1>Execution plan</h1>")
    plotExecutionPlan(experiment)
}