diff --git a/README.md b/README.md
index 7a4806d..d4583ab 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,21 @@ val (x, y) = mushrooms.unzipInt
 
 sparkgscv(x, y, 5, Seq(new Accuracy().asInstanceOf[ClassificationMeasure]): _*) { (x, y) => knn(x, y, 3) }
 ```
+**From Spark Dataset to SMILE SparseDataset**
+
+```scala
+import org.apache.spark.smile.implicits._
+
+val spark = SparkSession.builder().master("local[*]").getOrCreate()
+
+val mushrooms = spark.read.format("libsvm").load("data/mushrooms.svm")
+
+val (x,y) = mushrooms.toSmileDataset("features","label").unzipInt
+
+val res = cv(x, y, 5, Seq(new Accuracy().asInstanceOf[ClassificationMeasure]): _*) { (x, y) => knn(x, y, 3) }
+```
+
+
 ## Contributing
 
 Feel free to open an issue or make a pull request to contribute to the repository.
diff --git a/src/main/scala/org/apache/spark/smile/implicits/package.scala b/src/main/scala/org/apache/spark/smile/implicits/package.scala
new file mode 100644
index 0000000..8fa641b
--- /dev/null
+++ b/src/main/scala/org/apache/spark/smile/implicits/package.scala
@@ -0,0 +1,57 @@
+package org.apache.spark.smile
+
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{Row, Dataset => SparkDataset}
+import smile.data.{NominalAttribute, NumericAttribute, SparseDataset}
+
+package object implicits {
+
+  /** Adds a conversion from a Spark dataset to a SMILE `SparseDataset`. */
+  implicit class BetterSmileDataset(dataset: SparkDataset[_]) {
+
+    /**
+     * Collects the dataset to the driver and copies it into a SMILE `SparseDataset`.
+     *
+     * The labels are treated as class labels when the first 20 of them are all whole
+     * numbers, in which case they are shifted so that the smallest label becomes 0.
+     * Otherwise they are stored as a real-valued response.
+     *
+     * @param featuresColName column holding the feature `Vector` of each row
+     * @param labelColName    column holding the label; it is cast to double
+     * @param weightColName   column holding the row weight; it is only read when the
+     *                        dataset has a column with that name
+     * @return the converted dataset; it is empty when the input has no rows
+     */
+    def toSmileDataset(
+        featuresColName: String = "features",
+        labelColName: String = "label",
+        weightColName: String = "weight"): SparseDataset = {
+      val hasWeight = dataset.columns.contains(weightColName)
+      val weightCol = if (hasWeight) Seq(col(weightColName).cast("double")) else Seq.empty
+      val selected = Seq(col(featuresColName), col(labelColName).cast("double")) ++ weightCol
+
+      // One collect replaces the separate take, min and collect jobs.
+      val rows: Array[Row] = dataset.select(selected: _*).collect()
+      val labels = rows.map(_.getDouble(1))
+
+      val classification = labels.take(20).forall(_ % 1 == 0)
+      // Guarded so that an empty dataset no longer fails while looking up the minimum.
+      val minClass = if (classification && labels.nonEmpty) labels.min.toInt else 0
+      val res =
+        if (classification) new SparseDataset(new NominalAttribute("class"))
+        else new SparseDataset(new NumericAttribute("response"))
+
+      rows.zipWithIndex.foreach { case (row, i) =>
+        row.getAs[Vector](0).toArray.zipWithIndex.foreach { case (x, j) => res.set(i, j, x) }
+        if (classification) res.set(i, labels(i).toInt - minClass) else res.set(i, labels(i))
+        // set(Int, Int, Double) is the matrix entry setter used above, so a class label
+        // cannot be stored together with its weight through it; set the weight directly.
+        if (hasWeight) res.get(i).weight = row.getDouble(2)
+      }
+      res
+    }
+
+  }
+
+}
diff --git a/src/main/scala/smile/tuning/SparkOperators.scala b/src/main/scala/smile/tuning/Operators.scala
similarity index 97%
rename from src/main/scala/smile/tuning/SparkOperators.scala
rename to src/main/scala/smile/tuning/Operators.scala
index bcdad50..8809f5e 100644
--- a/src/main/scala/smile/tuning/SparkOperators.scala
+++ b/src/main/scala/smile/tuning/Operators.scala
@@ -8,7 +8,7 @@ import scala.reflect.ClassTag
 
 case class SerializableClassificationMeasure(@transient measure: ClassificationMeasure)
 
-trait SparkOperators {
+trait Operators {
 
   def sparkgscv[T <: Object: ClassTag](
       x: Array[T],
diff --git a/src/main/scala/smile/tuning/package.scala b/src/main/scala/smile/tuning/package.scala
index 5d0bce0..4b20fdf 100644
--- a/src/main/scala/smile/tuning/package.scala
+++ b/src/main/scala/smile/tuning/package.scala
@@ -1,3 +1,3 @@
 package smile
 
-package object tuning extends SparkOperators {}
+package object tuning extends Operators {}
diff --git a/src/test/scala/org/apache/spark/smile/ImplicitSuite.scala b/src/test/scala/org/apache/spark/smile/ImplicitSuite.scala
new file mode 100644
index 0000000..555387b
--- /dev/null
+++ b/src/test/scala/org/apache/spark/smile/ImplicitSuite.scala
@@ -0,0 +1,25 @@
+package org.apache.spark.smile
+
+import com.holdenkarau.spark.testing.DatasetSuiteBase
+import org.scalatest.FunSuite
+import smile.classification.knn
+import smile.data._
+import smile.validation.{Accuracy, ClassificationMeasure, _}
+
+class ImplicitSuite extends FunSuite with DatasetSuiteBase {
+
+  test("toSmileDataset") {
+
+    import org.apache.spark.smile.implicits._
+
+    val mushrooms = spark.read.format("libsvm").load("data/mushrooms.svm")
+
+    val (x,y) = mushrooms.toSmileDataset().unzipInt
+
+    val res = cv(x, y, 5, Seq(new Accuracy().asInstanceOf[ClassificationMeasure]): _*) { (x, y) => knn(x, y, 3) }
+
+    assert(res(0) == 1)
+
+  }
+
+}
diff --git a/src/test/scala/smile/tuning/SparkOperatorsSuite.scala b/src/test/scala/smile/tuning/OperatorsSuite.scala
similarity index 76%
rename from src/test/scala/smile/tuning/SparkOperatorsSuite.scala
rename to src/test/scala/smile/tuning/OperatorsSuite.scala
index fb5180f..46b67dc 100644
--- a/src/test/scala/smile/tuning/SparkOperatorsSuite.scala
+++ b/src/test/scala/smile/tuning/OperatorsSuite.scala
@@ -3,16 +3,16 @@ package smile.tuning
 import com.holdenkarau.spark.testing.DatasetSuiteBase
 import org.apache.spark.sql.SparkSession
 import org.scalatest.FunSuite
-import smile.classification.{Classifier, _}
+import smile.classification._
 import smile.data._
 import smile.read
 import smile.validation._
 
-class SparkOperatorsSuite extends FunSuite with DatasetSuiteBase {
+class OperatorsSuite extends FunSuite with DatasetSuiteBase {
 
   test("sparkgscv") {
 
-    implicit val sparkImplicit = spark
+    implicit val sparkImplicit: SparkSession = spark
 
     val mushrooms = read.libsvm("data/mushrooms.svm")
     val (x, y) = mushrooms.unzipInt