Skip to content
This repository has been archived by the owner on Nov 15, 2020. It is now read-only.

Commit

Permalink
add implicit method to convert spark dataset to smile sparsedataset
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrenodet committed Jul 26, 2019
1 parent 5bd982b commit 0771da7
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 5 deletions.
15 changes: 15 additions & 0 deletions README.md
Expand Up @@ -47,6 +47,21 @@ val (x, y) = mushrooms.unzipInt
sparkgscv(x, y, 5, Seq(new Accuracy().asInstanceOf[ClassificationMeasure]): _*) { (x, y) => knn(x, y, 3) }
```

**From Spark Dataset to SMILE SparseDataset**

```scala
import org.apache.spark.smile.implicits._

val spark = SparkSession.builder().master("local[*]").getOrCreate()

val mushrooms = spark.read.format("libsvm").load("data/mushrooms.svm")

val (x,y) = mushrooms.toSmileDataset("features","label").unzipInt

val res = cv(x, y, 5, Seq(new Accuracy().asInstanceOf[ClassificationMeasure]): _*) { (x, y) => knn(x, y, 3) }
```


## Contributing

Feel free to open an issue or make a pull request to contribute to the repository.
Expand Down
58 changes: 58 additions & 0 deletions src/main/scala/org/apache/spark/smile/implicits/package.scala
@@ -0,0 +1,58 @@
package org.apache.spark.smile

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Row, Dataset => SparkDataset}
import smile.data.{NominalAttribute, NumericAttribute, SparseDataset}

package object implicits {

  /**
    * Adds `toSmileDataset` to any Spark `Dataset`, so that it can be handed to SMILE.
    *
    * @param dataset the Spark dataset to convert
    */
  implicit class BetterSmileDataset(dataset: SparkDataset[_]) {

    /**
      * Collects the whole dataset to the driver and copies it into a SMILE `SparseDataset`.
      *
      * The task type is guessed from the first 20 labels: if all of them are whole numbers the
      * dataset is treated as a classification problem and every label is shifted by the smallest
      * label so that classes start at 0; otherwise it is treated as a regression problem and
      * labels are stored unchanged.
      *
      * If the dataset has a column named `weightColName`, its values are stored as per-row
      * weights; otherwise no weight is set.
      *
      * Labels (and weights, when present) must be non-null `Double` columns and the features
      * column must hold `org.apache.spark.ml.linalg.Vector` values; any other row shape fails
      * with a `MatchError`.
      *
      * @param featuresColName name of the feature vector column
      * @param labelColName    name of the label / response column
      * @param weightColName   name of the optional row weight column
      * @return a driver-side `SparseDataset` with one row per dataset row, in `collect()` order
      */
    def toSmileDataset(
        featuresColName: String = "features",
        labelColName: String = "label",
        weightColName: String = "weight"): SparseDataset = {
      val labels = dataset.select(labelColName)

      // Heuristic: only a sample of 20 labels is inspected to decide classification vs regression.
      val classification =
        labels.take(20).forall { case Row(x: Double) => (x % 1) == 0 }

      val minClass =
        if (classification) {
          val minRow = labels.agg(min(labelColName)).head
          // min() over an empty dataset is null; fall back to 0 instead of failing on getDouble.
          if (minRow.isNullAt(0)) 0 else minRow.getDouble(0).toInt
        } else 0

      // A real-valued response is described by a numeric attribute, not a nominal one.
      val res =
        if (classification) new SparseDataset(new NominalAttribute("class"))
        else new SparseDataset(new NumericAttribute("response"))

      // Writes every component of `features` into row `i`, indexing the array directly
      // rather than building intermediate lists of (value, index) pairs.
      def setFeatures(i: Int, features: Vector): Unit = {
        val values = features.toArray
        values.indices.foreach(j => res.set(i, j, values(j)))
      }

      if (dataset.columns.contains(weightColName)) {
        dataset
          .select(Array(featuresColName, labelColName, weightColName).map(col): _*)
          .collect()
          .zipWithIndex
          .foreach {
            case (Row(features: Vector, label: Double, weight: Double), i) =>
              setFeatures(i, features)
              if (classification) {
                // `set(Int, Int, Double)` is the feature-cell setter used by setFeatures, so
                // passing (row, class, weight) to it would overwrite a feature value and leave
                // the label unset. Store the label and the weight separately instead.
                res.set(i, label.toInt - minClass)
                res.get(i).weight = weight
              } else res.set(i, label, weight)
          }
      } else {
        dataset
          .select(Array(featuresColName, labelColName).map(col): _*)
          .collect()
          .zipWithIndex
          .foreach {
            case (Row(features: Vector, label: Double), i) =>
              setFeatures(i, features)
              if (classification) res.set(i, label.toInt - minClass) else res.set(i, label)
          }
      }
      res
    }

  }

}
Expand Up @@ -8,7 +8,7 @@ import scala.reflect.ClassTag

/**
  * Wraps a `ClassificationMeasure`; the wrapped field is annotated `@transient`.
  *
  * NOTE(review): presumably this lets a measure be referenced from Spark closures without
  * being serialized itself — confirm against the call sites.
  */
case class SerializableClassificationMeasure(@transient measure: ClassificationMeasure)

trait SparkOperators {
trait Operators {

def sparkgscv[T <: Object: ClassTag](
x: Array[T],
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/smile/tuning/package.scala
@@ -1,3 +1,3 @@
package smile

package object tuning extends SparkOperators {}
package object tuning extends Operators {}
25 changes: 25 additions & 0 deletions src/test/scala/org/apache/spark/smile/ImplicitSuite.scala
@@ -0,0 +1,25 @@
package org.apache.spark.smile

import com.holdenkarau.spark.testing.DatasetSuiteBase
import org.scalatest.FunSuite
import smile.classification.knn
import smile.data._
import smile.validation.{Accuracy, ClassificationMeasure, _}

/** Checks the implicit Spark `Dataset` to SMILE `SparseDataset` conversion end to end. */
class ImplicitSuite extends FunSuite with DatasetSuiteBase {

  test("toSmileDataset") {

    import org.apache.spark.smile.implicits._

    // Load the libsvm file through Spark, then convert it with the default column names.
    val sparkFrame = spark.read.format("libsvm").load("data/mushrooms.svm")
    val smileData = sparkFrame.toSmileDataset()
    val (x, y) = smileData.unzipInt

    // 5-fold cross-validation of a 3-nearest-neighbour classifier, scored by accuracy.
    val measures: Seq[ClassificationMeasure] = Seq(new Accuracy())
    val res = cv(x, y, 5, measures: _*) { (trainX, trainY) =>
      knn(trainX, trainY, 3)
    }

    assert(res(0) == 1)

  }

}
Expand Up @@ -3,16 +3,16 @@ package smile.tuning
import com.holdenkarau.spark.testing.DatasetSuiteBase
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite
import smile.classification.{Classifier, _}
import smile.classification._
import smile.data._
import smile.read
import smile.validation._

class SparkOperatorsSuite extends FunSuite with DatasetSuiteBase {
class OperatorsSuite extends FunSuite with DatasetSuiteBase {

test("sparkgscv") {

implicit val sparkImplicit = spark
implicit val sparkImplicit: SparkSession = spark

val mushrooms = read.libsvm("data/mushrooms.svm")
val (x, y) = mushrooms.unzipInt
Expand Down

0 comments on commit 0771da7

Please sign in to comment.