diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 60702824acb46..208813768e264 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1328,7 +1328,7 @@ setMethod("write.df", jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { - options[['path']] = path + options[['path']] <- path } callJMethod(df@sdf, "save", source, jmode, options) }) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 78c7a3037ffac..6f772158ddfe8 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -36,9 +36,9 @@ connectBackend <- function(hostname, port, timeout = 6000) { determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { - sparkSubmitBinName = "spark-submit" + sparkSubmitBinName <- "spark-submit" } else { - sparkSubmitBinName = "spark-submit.cmd" + sparkSubmitBinName <- "spark-submit.cmd" } sparkSubmitBinName } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d961bbc383688..7d1f6b0819ed0 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -23,6 +23,7 @@ # Int -> integer # String -> character # Boolean -> logical +# Float -> double # Double -> double # Long -> double # Array[Byte] -> raw diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 8f1c68f7c4d28..576ac72f40fc0 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -87,7 +87,7 @@ setMethod("count", setMethod("agg", signature(x = "GroupedData"), function(x, ...) { - cols = list(...) + cols <- list(...) stopifnot(length(cols) > 0) if (is.character(cols[[1]])) { cols <- varargsToEnv(...) @@ -97,7 +97,7 @@ setMethod("agg", if (!is.null(ns)) { for (n in ns) { if (n != "") { - cols[[n]] = alias(cols[[n]], n) + cols[[n]] <- alias(cols[[n]], n) } } } diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 15e2bdbd55d79..06df430687682 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -123,6 +123,7 @@ structField.character <- function(x, type, nullable = TRUE) { } options <- c("byte", "integer", + "float", "double", "numeric", "character", diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index ea629a64f7158..950ba74dbe017 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -41,8 +41,8 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, if (isInstanceOf(obj, "scala.Tuple2")) { # JavaPairRDD[Array[Byte], Array[Byte]]. 
- keyBytes = callJMethod(obj, "_1") - valBytes = callJMethod(obj, "_2") + keyBytes <- callJMethod(obj, "_1") + valBytes <- callJMethod(obj, "_2") res <- list(unserialize(keyBytes), unserialize(valBytes)) } else { diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index ccaea18ecab2a..f2452ed97d2ea 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -20,7 +20,7 @@ context("functions on binary files") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { fileName1 <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index 3be8c65a6c1a0..dca0657c57e0d 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -76,7 +76,7 @@ test_that("zipPartitions() on RDDs", { expect_equal(actual, list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index b79692873cec3..6c3aaab8c711e 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -447,7 +447,7 @@ test_that("zipRDD() on RDDs", { expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) @@ -483,7 +483,7 @@ test_that("cartesian() on RDDs", { actual <- collect(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) - mockFile = c("Spark is pretty.", "Spark is awesome.") + mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index b0ea38854304e..76f74f80834a9 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -108,6 +108,32 @@ test_that("create DataFrame from RDD", { expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + + df <- jsonFile(sqlContext, jsonPathNa) + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") + insertInto(df, "people") + expect_equal(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"), c(16)) + expect_equal(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"), c(176.5)) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + df2 <- createDataFrame(sqlContext, df.toRDD, schema) + expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + + localDF <- data.frame(name=c("John", "Smith", 
"Sarah"), age=c(19, 23, 18), height=c(164.10, 181.4, 173.7)) + df <- createDataFrame(sqlContext, localDF, schema) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) + expect_equal(columns(df), c("name", "age", "height")) + expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(collect(where(df, df$name == "John")), c("John", 19, 164.10)) }) test_that("convert NAs to null type in DataFrames", { diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 58318dfef71ab..a9cf83dbdbdb1 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -20,7 +20,7 @@ context("the textFile() function") # JavaSparkContext handle sc <- sparkR.init() -mockFile = c("Spark is pretty.", "Spark is awesome.") +mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { fileName <- tempfile(pattern="spark-test", fileext=".tmp") diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index aa0d2a66b9082..12df4cf4f65b7 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -119,7 +119,7 @@ test_that("cleanClosure on R functions", { # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L) - t = 4 # Override base::t in .GlobalEnv. + t <- 4 # Override base::t in .GlobalEnv. f <- function(x) { x > t } newF <- cleanClosure(f) env <- environment(newF) diff --git a/core/pom.xml b/core/pom.xml index 558cc3fb9f2f3..73f7a75cab9d3 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -372,6 +372,11 @@ junit-interface test + + org.apache.curator + curator-test + test + net.razorvine pyrolite diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 0c50b4002cf7b..648bcfe28cad2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.TimeUnit import scala.collection.mutable +import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} @@ -211,7 +212,16 @@ private[spark] class ExecutorAllocationManager( listenerBus.addListener(listener) val scheduleTask = new Runnable() { - override def run(): Unit = Utils.logUncaughtExceptions(schedule()) + override def run(): Unit = { + try { + schedule() + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logWarning(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + } + } } executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index d09e17dea0911..248339148d9b7 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -32,7 +32,20 @@ object TaskContext { */ def get(): TaskContext = taskContext.get - private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] + /** + * Returns the partition id of currently active TaskContext. It will return 0 + * if there is no active TaskContext for cases like local execution. 
+ */ + def getPartitionId(): Int = { + val tc = taskContext.get() + if (tc == null) { + 0 + } else { + tc.partitionId() + } + } + + private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext] // Note: protected[spark] instead of private[spark] to prevent the following two from // showing up in JavaDoc. diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 56adc857d4ce0..d5b4260bf4529 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -179,6 +179,7 @@ private[spark] object SerDe { // Int -> integer // String -> character // Boolean -> logical + // Float -> double // Double -> double // Long -> double // Array[Byte] -> raw @@ -215,6 +216,9 @@ private[spark] object SerDe { case "long" | "java.lang.Long" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Long].toDouble) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) case "double" | "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index f459ed5b3a1a1..aa379d4cd61e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -21,9 +21,8 @@ import java.io._ import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.Logging +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer} import org.apache.spark.util.Utils @@ -32,11 +31,11 @@ import org.apache.spark.util.Utils * Files are deleted when applications and workers are removed. * * @param dir Directory to store files. Created if non-existent (but not recursively). - * @param serialization Used to serialize our objects. + * @param serializer Used to serialize our objects. 
*/ private[master] class FileSystemPersistenceEngine( val dir: String, - val serialization: Serialization) + val serializer: Serializer) extends PersistenceEngine with Logging { new File(dir).mkdir() @@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine( private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) + val fileOut = new FileOutputStream(file) + var out: SerializationStream = null Utils.tryWithSafeFinally { - out.write(serialized) + out = serializer.newInstance().serializeStream(fileOut) + out.writeObject(value) } { - out.close() + fileOut.close() + if (out != null) { + out.close() + } } } private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { - val fileData = new Array[Byte](file.length().asInstanceOf[Int]) - val dis = new DataInputStream(new FileInputStream(file)) + val fileIn = new FileInputStream(file) + var in: DeserializationStream = null try { - dis.readFully(fileData) + in = serializer.newInstance().deserializeStream(fileIn) + in.readObject[T]() } finally { - dis.close() + fileIn.close() + if (in != null) { + in.close() + } } - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 245b047e7dfbd..4615febf17d24 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.language.postfixOps import scala.util.Random -import akka.serialization.Serialization -import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, @@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} @@ -58,9 +56,6 @@ private[master] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - // TODO Remove it once we don't use akka.serialization.Serialization - private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -161,20 +156,21 @@ private[master] class Master( masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler) + val serializer = new JavaSerializer(conf) val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" 
=> logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new ZooKeeperRecoveryModeFactory(conf, serializer) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) + new FileSystemRecoveryModeFactory(conf, serializer) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory")) - val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(actorSystem)) + val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer]) + .newInstance(conf, serializer) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -213,7 +209,7 @@ private[master] class Master( override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { - val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE } else { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index a03d460509e03..58a00bceee6af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEnv import scala.reflect.ClassTag @@ -80,8 +81,11 @@ abstract class PersistenceEngine { * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). 
*/ - final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + final def readPersistedData( + rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + rpcEnv.deserialize { () => + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } } def close() {} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala index 351db8fab2041..c4c3283fb73f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -17,10 +17,9 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer /** * ::DeveloperApi:: @@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi * */ @DeveloperApi -abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) { /** * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) @@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual * recovery is made by restoring from filesystem. */ -private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") @@ -64,7 +63,7 @@ private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: } } -private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) +private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer) extends StandaloneRecoveryModeFactory(conf, serializer) { def createPersistenceEngine(): PersistenceEngine = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 328d95a7a0c68..563831cc6b8dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -import akka.serialization.Serialization +import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.serializer.Serializer -private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) +private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer) extends PersistenceEngine with Logging { @@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat } 
private def serializeIntoFile(path: String, value: AnyRef) { - val serializer = serialization.findSerializerFor(value) - val serialized = serializer.toBinary(value) - zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) + val serialized = serializer.newInstance().serialize(value) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes) } private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) - val clazz = m.runtimeClass.asInstanceOf[Class[T]] - val serializer = serialization.serializerFor(clazz) try { - Some(serializer.fromBinary(fileData).asInstanceOf[T]) + Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { case e: Exception => { logWarning("Exception while reading persisted file, deleting", e) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index c9fcc7a36cc04..29debe8081308 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -139,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * creating it manually because different [[RpcEnv]] may have different formats. */ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String + + /** + * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object + * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. + */ + def deserialize[T](deserializationAction: () => T): T } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f2d87f68341af..fc17542abf81d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import com.google.common.util.concurrent.MoreExecutors +import akka.serialization.JavaSerializer import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ @@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] ( } override def toString: String = s"${getClass.getSimpleName}($actorSystem)" + + override def deserialize[T](deserializationAction: () => T): T = { + JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { + deserializationAction() + } + } } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef( override def toString: String = s"${getClass.getSimpleName}($actorRef)" + final override def equals(that: Any): Boolean = that match { + case other: AkkaRpcEndpointRef => actorRef == other.actorRef + case _ => false + } + + final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() } /** diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ff0a339a39c65..27b82aaddd2e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -692,7 +692,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val gettingResultTime = getGettingResultTime(info, currentTime) val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} + val accumulatorsReadable = maybeAccumulators.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") + } val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") diff --git a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala index f4e56632e426a..8c96b0e71dfdd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/CustomRecoveryModeFactory.scala @@ -19,18 +19,19 @@ // when they are outside of org.apache.spark. package other.supplier +import java.nio.ByteBuffer + import scala.collection.mutable import scala.reflect.ClassTag -import akka.serialization.Serialization - import org.apache.spark.SparkConf import org.apache.spark.deploy.master._ +import org.apache.spark.serializer.Serializer class CustomRecoveryModeFactory( conf: SparkConf, - serialization: Serialization -) extends StandaloneRecoveryModeFactory(conf, serialization) { + serializer: Serializer +) extends StandaloneRecoveryModeFactory(conf, serializer) { CustomRecoveryModeFactory.instantiationAttempts += 1 @@ -40,7 +41,7 @@ class CustomRecoveryModeFactory( * */ override def createPersistenceEngine(): PersistenceEngine = - new CustomPersistenceEngine(serialization) + new CustomPersistenceEngine(serializer) /** * Create an instance of LeaderAgent that decides who gets elected as master. 
@@ -53,7 +54,7 @@ object CustomRecoveryModeFactory { @volatile var instantiationAttempts = 0 } -class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine { +class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine { val data = mutable.HashMap[String, Array[Byte]]() CustomPersistenceEngine.lastInstance = Some(this) @@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def persist(name: String, obj: Object): Unit = { CustomPersistenceEngine.persistAttempts += 1 - serialization.serialize(obj) match { - case util.Success(bytes) => data += name -> bytes - case util.Failure(cause) => throw new RuntimeException(cause) - } + val serialized = serializer.newInstance().serialize(obj) + val bytes = new Array[Byte](serialized.remaining()) + serialized.get(bytes) + data += name -> bytes } /** @@ -84,15 +85,9 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE */ override def read[T: ClassTag](prefix: String): Seq[T] = { CustomPersistenceEngine.readAttempts += 1 - val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] val results = for ((name, bytes) <- data; if name.startsWith(prefix)) - yield serialization.deserialize(bytes, clazz) - - results.find(_.isFailure).foreach { - case util.Failure(cause) => throw new RuntimeException(cause) - } - - results.flatMap(_.toOption).toSeq + yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + results.toSeq } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 9cb6dd43bac47..a8fbaf1d9da0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { persistenceEngine.addDriver(driverToPersist) persistenceEngine.addWorker(workerToPersist) - val (apps, drivers, workers) = persistenceEngine.readPersistedData() + val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv) apps.map(_.id) should contain(appToPersist.id) drivers.map(_.id) should contain(driverToPersist.id) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala new file mode 100644 index 0000000000000..11e87bd1dd8eb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.deploy.master + +import java.net.ServerSocket + +import org.apache.commons.lang3.RandomUtils +import org.apache.curator.test.TestingServer + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} +import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.util.Utils + +class PersistenceEngineSuite extends SparkFunSuite { + + test("FileSystemPersistenceEngine") { + val dir = Utils.createTempDir() + try { + val conf = new SparkConf() + testPersistenceEngine(conf, serializer => + new FileSystemPersistenceEngine(dir.getAbsolutePath, serializer) + ) + } finally { + Utils.deleteRecursively(dir) + } + } + + test("ZooKeeperPersistenceEngine") { + val conf = new SparkConf() + // TestingServer logs the port conflict exception rather than throwing an exception. + // So we have to find a free port by ourselves. This approach cannot guarantee always starting + // zkTestServer successfully because there is a time gap between finding a free port and + // starting zkTestServer. But the failure possibility should be very low. + val zkTestServer = new TestingServer(findFreePort(conf)) + try { + testPersistenceEngine(conf, serializer => { + conf.set("spark.deploy.zookeeper.url", zkTestServer.getConnectString) + new ZooKeeperPersistenceEngine(conf, serializer) + }) + } finally { + zkTestServer.stop() + } + } + + private def testPersistenceEngine( + conf: SparkConf, persistenceEngineCreator: Serializer => PersistenceEngine): Unit = { + val serializer = new JavaSerializer(conf) + val persistenceEngine = persistenceEngineCreator(serializer) + persistenceEngine.persist("test_1", "test_1_value") + assert(Seq("test_1_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.persist("test_2", "test_2_value") + assert(Set("test_1_value", "test_2_value") === persistenceEngine.read[String]("test_").toSet) + persistenceEngine.unpersist("test_1") + assert(Seq("test_2_value") === persistenceEngine.read[String]("test_")) + persistenceEngine.unpersist("test_2") + assert(persistenceEngine.read[String]("test_").isEmpty) + + // Test deserializing objects that contain RpcEndpointRef + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + try { + // Create a real endpoint so that we can test RpcEndpointRef deserialization + val workerEndpoint = rpcEnv.setupEndpoint("worker", new RpcEndpoint { + override val rpcEnv: RpcEnv = rpcEnv + }) + + val workerToPersist = new WorkerInfo( + id = "test_worker", + host = "127.0.0.1", + port = 10000, + cores = 0, + memory = 0, + endpoint = workerEndpoint, + webUiPort = 0, + publicAddress = "" + ) + + persistenceEngine.addWorker(workerToPersist) + + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) + + assert(storedApps.isEmpty) + assert(storedDrivers.isEmpty) + + // Check deserializing WorkerInfo + assert(storedWorkers.size == 1) + val recoveryWorkerInfo = storedWorkers.head + assert(workerToPersist.id === recoveryWorkerInfo.id) + assert(workerToPersist.host === recoveryWorkerInfo.host) + assert(workerToPersist.port === recoveryWorkerInfo.port) + assert(workerToPersist.cores === recoveryWorkerInfo.cores) + assert(workerToPersist.memory === recoveryWorkerInfo.memory) + assert(workerToPersist.endpoint === recoveryWorkerInfo.endpoint) + assert(workerToPersist.webUiPort === recoveryWorkerInfo.webUiPort) + assert(workerToPersist.publicAddress === 
recoveryWorkerInfo.publicAddress) + } finally { + rpcEnv.shutdown() + rpcEnv.awaitTermination() + } + } + + private def findFreePort(conf: SparkConf): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, conf)._2 + } +} diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 5f2671a6e5053..e462302f28423 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -129,3 +129,12 @@ yongtang - Yong Tang ypcat - Pei-Lun Lee zhichao-li - Zhichao Li zzcclp - Zhichao Zhang +979969786 - Yuming Wang +Rosstin - Rosstin Murphy +ameyc - Amey Chaugule +animeshbaranawal - Animesh Baranawal +cafreeman - Chris Freeman +lee19 - Lee +lockwobr - Brian Lockwood +navis - Navis Ryu +pparkkin - Paavo Parkkinen diff --git a/dev/lint-python b/dev/lint-python index 0c3586462cb37..e02dff220eb87 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -21,12 +21,14 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" -PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" +PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" +PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" +PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" cd "$SPARK_ROOT_DIR" # compileall: https://docs.python.org/2/library/compileall.html -python -B -m compileall -q -l $PATHS_TO_CHECK > "$PYTHON_LINT_REPORT_PATH" +python -B -m compileall -q -l $PATHS_TO_CHECK > "$PEP8_REPORT_PATH" compile_status="${PIPESTATUS[0]}" # Get pep8 at runtime so that we don't rely on it being installed on the build server. @@ -47,11 +49,36 @@ if [ ! -e "$PEP8_SCRIPT_PATH" ]; then fi fi +# Easy install pylint in /dev/pylint. To easy_install into a directory, the PYTHONPATH should +# be set to the directory. +# dev/pylint should be appended to the PATH variable as well. +# Jenkins by default installs the pylint3 version, so for now this just checks the code quality +# of python3. +export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" +export "PYLINT_HOME=$PYTHONPATH" +export "PATH=$PYTHONPATH:$PATH" + +if [ ! -d "$PYLINT_HOME" ]; then + mkdir "$PYLINT_HOME" + # Redirect the annoying pylint installation output. + easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" + easy_install_status="$?" + + if [ "$easy_install_status" -ne 0 ]; then + echo "Unable to install pylint locally in \"$PYTHONPATH\"." + cat "$PYLINT_INSTALL_INFO" + exit "$easy_install_status" + fi + + rm "$PYLINT_INSTALL_INFO" + +fi + # There is no need to write this output to a file #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PYTHON_LINT_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" --ignore=E402,E731,E241,W503,E226 $PATHS_TO_CHECK >> "$PEP8_REPORT_PATH" pep8_status="${PIPESTATUS[0]}" if [ "$compile_status" -eq 0 -a "$pep8_status" -eq 0 ]; then @@ -61,13 +88,27 @@ else fi if [ "$lint_status" -ne 0 ]; then - echo "Python lint checks failed." - cat "$PYTHON_LINT_REPORT_PATH" + echo "PEP8 checks failed." + cat "$PEP8_REPORT_PATH" +else + echo "PEP8 checks passed." 
+fi + +rm "$PEP8_REPORT_PATH" + +for to_be_checked in "$PATHS_TO_CHECK" +do + pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +done + +if [ "${PIPESTATUS[0]}" -ne 0 ]; then + lint_status=1 + echo "Pylint checks failed." + cat "$PYLINT_REPORT_PATH" else - echo "Python lint checks passed." + echo "Pylint checks passed." fi -# rm "$PEP8_SCRIPT_PATH" -rm "$PYTHON_LINT_REPORT_PATH" +rm "$PYLINT_REPORT_PATH" exit "$lint_status" diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c74cb1f1ef8ea..8c46adf256a9a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,6 +3,24 @@ layout: global title: Spark ML Programming Guide --- +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. @@ -154,6 +172,19 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. +# Algorithm Guides + +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. + +**Pipelines API Algorithm Guides** + +* [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Ensembles](ml-ensembles.html) + +**Algorithms in `spark.ml`** + +* [Linear methods with elastic net regularization](ml-linear-methods.html) + # Code Examples This section gives code examples illustrating the functionality discussed above. diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md new file mode 100644 index 0000000000000..1ac83d94c9e81 --- /dev/null +++ b/docs/ml-linear-methods.md @@ -0,0 +1,129 @@ +--- +layout: global +title: Linear Methods - ML +displayTitle: ML - Linear Methods +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. 
In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: +`\[ +\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. +\]` +By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. + +**Examples** + +
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data
+val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+val lr = new LogisticRegression()
+  .setMaxIter(10)
+  .setRegParam(0.3)
+  .setElasticNetParam(0.8)
+
+// Fit the model
+val lrModel = lr.fit(training)
+
+// Print the weights and intercept for logistic regression
+println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
+
+{% endhighlight %}
+
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class LogisticRegressionWithElasticNetExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf()
+      .setAppName("Logistic Regression with Elastic Net Example");
+
+    SparkContext sc = new SparkContext(conf);
+    SQLContext sql = new SQLContext(sc);
+    String path = "sample_libsvm_data.txt";
+
+    // Load training data
+    DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class);
+
+    LogisticRegression lr = new LogisticRegression()
+      .setMaxIter(10)
+      .setRegParam(0.3)
+      .setElasticNetParam(0.8);
+
+    // Fit the model
+    LogisticRegressionModel lrModel = lr.fit(training);
+
+    // Print the weights and intercept for logistic regression
+    System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept());
+  }
+}
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python" markdown="1">
+
+{% highlight python %}
+
+from pyspark.ml.classification import LogisticRegression
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import MLUtils
+
+# Load training data
+training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
+
+lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8)
+
+# Fit the model
+lrModel = lr.fit(training)
+
+# Print the weights and intercept for logistic regression
+print("Weights: " + str(lrModel.weights))
+print("Intercept: " + str(lrModel.intercept))
+{% endhighlight %}
+
+
+</div>
+
+</div>
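The page states that elastic net regularization is implemented for linear regression as well, but only logistic regression is shown above. A minimal sketch of the linear regression counterpart, not part of this patch, reusing the `training` DataFrame from the Scala example and assuming `org.apache.spark.ml.regression.LinearRegression` exposes the same regularization parameters:

{% highlight scala %}
import org.apache.spark.ml.regression.LinearRegression

// Same mixing parameter as above: elasticNetParam = 0.8 weights the penalty toward L1.
val linReg = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model on the same libsvm-formatted training data
val linRegModel = linReg.fit(training)

// Print the weights and intercept for linear regression
println(s"Weights: ${linRegModel.weights} Intercept: ${linRegModel.intercept}")
{% endhighlight %}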
+ +### Optimization + +The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d72dc20a5ad6e..0fc7036bffeaf 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -33,6 +33,7 @@ guaranteed to find a globally optimal solution, and when run multiple times on a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. +* *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed. **Examples** diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3927d65fbf8fb..07655baa414b5 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -10,7 +10,7 @@ displayTitle: MLlib - Linear Methods `\[ \newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} +\newcommand{\E}{\mathbb{E}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\wv}{\mathbf{w}} @@ -18,10 +18,10 @@ displayTitle: MLlib - Linear Methods \newcommand{\bv}{\mathbf{b}} \newcommand{\N}{\mathbb{N}} \newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} \newcommand{\zero}{\mathbf{0}} \]` @@ -29,7 +29,7 @@ displayTitle: MLlib - Linear Methods Many standard *machine learning* methods can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function `$f$` that depends on a variable vector -`$\wv$` (called `weights` in the code), which has `$d$` entries. +`$\wv$` (called `weights` in the code), which has `$d$` entries. Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} @@ -39,7 +39,7 @@ the objective function is of the form \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and -`$y_i\in\R$` are their corresponding labels, which we want to predict. +`$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. @@ -99,6 +99,9 @@ regularizers in MLlib: L1$\|\wv\|_1$$\mathrm{sign}(\wv)$ + + elastic net$\alpha \|\wv\|_1 + (1-\alpha)\frac{1}{2}\|\wv\|_2^2$$\alpha \mathrm{sign}(\wv) + (1-\alpha) \wv$ + @@ -107,7 +110,7 @@ of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. -It is not recommended to train models without any regularization, +[Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization) is a combination of L1 and L2 regularization. 
It is not recommended to train models without any regularization, especially when the number of training examples is small. ### Optimization @@ -531,7 +534,7 @@ sameModel = LogisticRegressionModel.load(sc, "myModelPath") ### Linear least squares, Lasso, and ridge regression -Linear least squares is the most common formulation for regression problems. +Linear least squares is the most common formulation for regression problems. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the squared loss: `\[ @@ -539,8 +542,8 @@ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` Various related regression methods are derived by using different types of regularization: -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is @@ -552,7 +555,7 @@ known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_erro
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -614,7 +617,7 @@ public class LinearRegression { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("Linear Regression Example"); JavaSparkContext sc = new JavaSparkContext(conf); - + // Load and parse the data String path = "data/mllib/ridge-data/lpsa.data"; JavaRDD data = sc.textFile(path); @@ -634,7 +637,7 @@ public class LinearRegression { // Building the model int numIterations = 100; - final LinearRegressionModel model = + final LinearRegressionModel model = LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations); // Evaluate model on training examples and compute training error @@ -665,7 +668,7 @@ public class LinearRegression {
The following example demonstrate how to load training data, parse it as an RDD of LabeledPoint. -The example then uses LinearRegressionWithSGD to build a simple linear model to predict label +The example then uses LinearRegressionWithSGD to build a simple linear model to predict label values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). @@ -706,8 +709,8 @@ a dependency. ###Streaming linear regression -When data arrive in a streaming fashion, it is useful to fit regression models online, -updating the parameters of the model as new data arrives. MLlib currently supports +When data arrive in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports streaming linear regression using ordinary least squares. The fitting is similar to that performed offline, except fitting occurs on each batch of data, so that the model continually updates to reflect the data from the stream. @@ -722,7 +725,7 @@ online to the first stream, and make predictions on the second stream.
-First, we import the necessary classes for parsing our input data and creating the model. +First, we import the necessary classes for parsing our input data and creating the model. {% highlight scala %} @@ -734,7 +737,7 @@ import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD Then we make input streams for training and testing data. We assume a StreamingContext `ssc` has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) -for more info. For this example, we use labeled points in training and testing streams, +for more info. For this example, we use labeled points in training and testing streams, but in practice you will likely want to use unlabeled vectors for test data. {% highlight scala %} @@ -754,7 +757,7 @@ val model = new StreamingLinearRegressionWithSGD() {% endhighlight %} -Now we register the streams for training and testing and start the job. +Now we register the streams for training and testing and start the job. Printing predictions alongside true labels lets us easily see the result. {% highlight scala %} @@ -764,14 +767,14 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() - + {% endhighlight %} We can now save text files with data to the training or testing folders. -Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label -and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` -the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. -As you feed more data to the training directory, the predictions +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions will get better!
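To make the streaming input format described above concrete, here is a minimal sketch (not part of this patch) that writes one batch of labeled points in the `(y,[x1,x2,x3])` form into the monitored training directory; the file name and feature values are invented for illustration, and the directory is assumed to already exist:

{% highlight scala %}
import java.io.{File, PrintWriter}

// Each line is one labeled point, e.g. (1.0,[0.5,1.2,-0.3]),
// in the format the streaming regression example reads from /training/data/dir.
val points = Seq(
  (1.0, Seq(0.5, 1.2, -0.3)),
  (0.0, Seq(-0.1, 0.8, 0.7))
)

val out = new PrintWriter(new File("/training/data/dir/batch-0001.txt"))
try {
  points.foreach { case (label, features) =>
    out.println(s"($label,[${features.mkString(",")}])")
  }
} finally {
  out.close()
}
{% endhighlight %}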
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala new file mode 100644 index 0000000000000..d9a36bda386b3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the transforms required for fitting a dataset against an R model formula. Currently + * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula + * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +@Experimental +class RFormula(override val uid: String) + extends Transformer with HasFeaturesCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("rFormula")) + + /** + * R formula parameter. The formula is provided in string form. + * @group setParam + */ + val formula: Param[String] = new Param(this, "formula", "R model formula") + + private var parsedFormula: Option[ParsedRFormula] = None + + /** + * Sets the formula to use for this transformer. Must be called before use. + * @group setParam + * @param value an R formula in string form (e.g. 
"y ~ x + z") + */ + def setFormula(value: String): this.type = { + parsedFormula = Some(RFormulaParser.parse(value)) + set(formula, value) + this + } + + /** @group getParam */ + def getFormula: String = $(formula) + + /** @group getParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group getParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def transformSchema(schema: StructType): StructType = { + checkCanTransform(schema) + val withFeatures = transformFeatures.transformSchema(schema) + if (hasLabelCol(schema)) { + withFeatures + } else { + val nullable = schema(parsedFormula.get.label).dataType match { + case _: NumericType | BooleanType => false + case _ => true + } + StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } + } + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(transformFeatures.transform(dataset)) + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" + + private def transformLabel(dataset: DataFrame): DataFrame = { + if (hasLabelCol(dataset.schema)) { + dataset + } else { + val labelName = parsedFormula.get.label + dataset.schema(labelName).dataType match { + case _: NumericType | BooleanType => + dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) + // TODO(ekl) add support for string-type labels + case other => + throw new IllegalArgumentException("Unsupported type for label: " + other) + } + } + } + + private def transformFeatures: Transformer = { + // TODO(ekl) add support for non-numeric features and feature interactions + new VectorAssembler(uid) + .setInputCols(parsedFormula.get.terms.toArray) + .setOutputCol($(featuresCol)) + } + + private def checkCanTransform(schema: StructType) { + require(parsedFormula.isDefined, "Must call setFormula() first.") + val columnNames = schema.map(_.name) + require(!columnNames.contains($(featuresCol)), "Features column already exists.") + require( + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, + "Label column already exists and is not of type DoubleType.") + } + + private def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+'. 
+ */ +private[ml] object RFormulaParser extends RegexParsers { + def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r + + def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } + + def formula: Parser[ParsedRFormula] = + (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 9f83c2ee16178..086917fa680f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -116,7 +116,7 @@ class VectorAssembler(override val uid: String) if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") } - StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) + StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true)) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala new file mode 100644 index 0000000000000..5fdf878a3df72 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.{Accumulator, SparkContext} + +/** + * Abstract class for stopwatches. + */ +private[spark] abstract class Stopwatch extends Serializable { + + @transient private var running: Boolean = false + private var startTime: Long = _ + + /** + * Name of the stopwatch. + */ + val name: String + + /** + * Starts the stopwatch. + * Throws an exception if the stopwatch is already running. + */ + def start(): Unit = { + assume(!running, "start() called but the stopwatch is already running.") + running = true + startTime = now + } + + /** + * Stops the stopwatch and returns the duration of the last session in milliseconds. + * Throws an exception if the stopwatch is not running. + */ + def stop(): Long = { + assume(running, "stop() called but the stopwatch is not running.") + val duration = now - startTime + add(duration) + running = false + duration + } + + /** + * Checks whether the stopwatch is running. 
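To make the parser's output concrete, here is a small sketch of what RFormulaParser.parse returns, based on the cases exercised in RFormulaParserSuite later in this patch. The parser is private[ml], so this would only compile from code in the same package (as the test suite does); the object name is hypothetical.

    package org.apache.spark.ml.feature

    object RFormulaParserExample {
      def main(args: Array[String]): Unit = {
        // "resp ~ A_VAR + B + c123" parses into a label and an ordered list of terms.
        val parsed = RFormulaParser.parse("resp ~ A_VAR + B + c123")
        assert(parsed.label == "resp")
        assert(parsed.terms == Seq("A_VAR", "B", "c123"))

        // Operators outside the supported subset leave unconsumed input, so parse()
        // falls into the NoSuccess branch and throws IllegalArgumentException:
        // RFormulaParser.parse("y ~ x:z")  // would throw
      }
    }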
+ */ + def isRunning: Boolean = running + + /** + * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch + * is running. + */ + def elapsed(): Long + + /** + * Gets the current time in milliseconds. + */ + protected def now: Long = System.currentTimeMillis() + + /** + * Adds input duration to total elapsed time. + */ + protected def add(duration: Long): Unit +} + +/** + * A local [[Stopwatch]]. + */ +private[spark] class LocalStopwatch(override val name: String) extends Stopwatch { + + private var elapsedTime: Long = 0L + + override def elapsed(): Long = elapsedTime + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A distributed [[Stopwatch]] using Spark accumulator. + * @param sc SparkContext + */ +private[spark] class DistributedStopwatch( + sc: SparkContext, + override val name: String) extends Stopwatch { + + private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + + override def elapsed(): Long = elapsedTime.value + + override protected def add(duration: Long): Unit = { + elapsedTime += duration + } +} + +/** + * A multiple stopwatch that contains local and distributed stopwatches. + * @param sc SparkContext + */ +private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable { + + private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty + + /** + * Adds a local stopwatch. + * @param name stopwatch name + */ + def addLocal(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new LocalStopwatch(name) + this + } + + /** + * Adds a distributed stopwatch. + * @param name stopwatch name + */ + def addDistributed(name: String): this.type = { + require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.") + stopwatches(name) = new DistributedStopwatch(sc, name) + this + } + + /** + * Gets a stopwatch. 
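A quick sketch of the basic Stopwatch contract defined above (start, stop, elapsed), following the behavior asserted in StopwatchSuite later in this patch. The classes are private[spark], so this would live under the org.apache.spark package; the object name and timing are illustrative.

    package org.apache.spark.ml.util

    object StopwatchExample {
      def main(args: Array[String]): Unit = {
        val sw = new LocalStopwatch("init")
        sw.start()
        Thread.sleep(50)              // simulate some work
        val session = sw.stop()       // duration of this session, in milliseconds
        val total = sw.elapsed()      // total across completed sessions
        assert(total == session && !sw.isRunning)
      }
    }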
+ * @param name stopwatch name + */ + def apply(name: String): Stopwatch = stopwatches(name) + + override def toString: String = { + stopwatches.values.toArray.sortBy(_.name) + .map(c => s" ${c.name}: ${c.elapsed()}ms") + .mkString("{\n", ",\n", "\n}") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e628059c4af8e..c58a64001d9a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable { new MatrixFactorizationModelWrapper(model) } + /** + * Java stub for Python mllib LDA.run() + */ + def trainLDAModel( + data: JavaRDD[java.util.List[Any]], + k: Int, + maxIterations: Int, + docConcentration: Double, + topicConcentration: Double, + seed: java.lang.Long, + checkpointInterval: Int, + optimizer: String): LDAModel = { + val algo = new LDA() + .setK(k) + .setMaxIterations(maxIterations) + .setDocConcentration(docConcentration) + .setTopicConcentration(topicConcentration) + .setCheckpointInterval(checkpointInterval) + .setOptimizer(optimizer) + + if (seed != null) algo.setSeed(seed) + + val documents = data.rdd.map(_.asScala.toArray).map { r => + r(0) match { + case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector]) + case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector]) + case _ => throw new IllegalArgumentException("input values contains invalid type value.") + } + } + algo.run(documents) + } + + /** * Java stub for Python mllib FPGrowth.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 0f8d6a399682d..68297130a7b03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -156,6 +156,21 @@ class KMeans private ( this } + // Initial cluster centers can be provided as a KMeansModel object rather than using the + // random or k-means|| initializationMode + private var initialModel: Option[KMeansModel] = None + + /** + * Set the initial starting point, bypassing the random initialization or k-means|| + * The condition model.k == this.k must be met, failure results + * in an IllegalArgumentException. + */ + def setInitialModel(model: KMeansModel): this.type = { + require(model.k == k, "mismatched cluster count") + initialModel = Some(model) + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. 
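A sketch of the new KMeans initialization path added above, mirroring the test added to KMeansSuite later in this patch; the points and cluster count are illustrative, and `sc` is assumed to be an existing SparkContext.

    import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
    import org.apache.spark.mllib.linalg.Vectors

    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 0.0),
      Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 1.0)))

    // Seed k-means with explicit centers instead of random or k-means|| initialization.
    val initialModel = new KMeansModel(Array(Vectors.dense(0.0, 0.0), Vectors.dense(0.0, 1.0)))

    val model = new KMeans()
      .setK(2)                       // must equal initialModel.k, otherwise the require() fails
      .setMaxIterations(10)
      .setInitialModel(initialModel)
      .run(data)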
@@ -193,20 +208,34 @@ class KMeans private ( val initStartTime = System.nanoTime() - val centers = if (initializationMode == KMeans.RANDOM) { - initRandom(data) + // Only one run is allowed when initialModel is given + val numRuns = if (initialModel.nonEmpty) { + if (runs > 1) logWarning("Ignoring runs; one run is allowed when initialModel is given.") + 1 } else { - initKMeansParallel(data) + runs } + val centers = initialModel match { + case Some(kMeansCenters) => { + Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) + } + case None => { + if (initializationMode == KMeans.RANDOM) { + initRandom(data) + } else { + initKMeansParallel(data) + } + } + } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + " seconds.") - val active = Array.fill(runs)(true) - val costs = Array.fill(runs)(0.0) + val active = Array.fill(numRuns)(true) + val costs = Array.fill(numRuns)(0.0) - var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) + var activeRuns = new ArrayBuffer[Int] ++ (0 until numRuns) var iteration = 0 val iterationStartTime = System.nanoTime() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index e577bf87f885e..408847afa800d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -53,14 +53,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend ) summary } + private lazy val SSerr = math.pow(summary.normL2(1), 2) + private lazy val SStot = summary.variance(0) * (summary.count - 1) + private lazy val SSreg = { + val yMean = summary.mean(0) + predictionAndObservations.map { + case (prediction, _) => math.pow(prediction - yMean, 2) + }.sum() + } /** - * Returns the explained variance regression score. - * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) - * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + * Returns the variance explained by regression. + * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n + * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] */ def explainedVariance: Double = { - 1 - summary.variance(1) / summary.variance(0) + SSreg / summary.count } /** @@ -76,8 +84,7 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * expected value of the squared error loss or quadratic loss. */ def meanSquaredError: Double = { - val rmse = summary.normL2(1) / math.sqrt(summary.count) - rmse * rmse + SSerr / summary.count } /** @@ -85,14 +92,14 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * the mean squared error. */ def rootMeanSquaredError: Double = { - summary.normL2(1) / math.sqrt(summary.count) + math.sqrt(this.meanSquaredError) } /** - * Returns R^2^, the coefficient of determination. - * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * Returns R^2^, the unadjusted coefficient of determination. 
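To make the redefined metrics concrete, the sketch below uses the same prediction/observation pairs as the updated RegressionMetricsSuite and pyspark doctest later in this patch (a biased predictor); `sc` is assumed.

    import org.apache.spark.mllib.evaluation.RegressionMetrics

    val predictionAndObservations = sc.parallelize(Seq(
      (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))
    val metrics = new RegressionMetrics(predictionAndObservations)

    metrics.explainedVariance    // ~8.8594  (SSreg / n, no longer 1 - var(y - yhat) / var(y))
    metrics.meanSquaredError     // 0.375    (SSerr / n)
    metrics.rootMeanSquaredError // ~0.6124  (sqrt of the mean squared error)
    metrics.r2                   // ~0.9486  (1 - SSerr / SStot)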
+ * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] */ def r2: Double = { - 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) + 1 - SSerr / SStot } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 39c48b084e550..7ead6327486cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -17,58 +17,49 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable + import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental /** - * - * :: Experimental :: - * * Calculate all patterns of a projected database in local. */ -@Experimental private[fpm] object LocalPrefixSpan extends Logging with Serializable { /** * Calculate all patterns of a projected database. * @param minCount minimum count * @param maxPatternLength maximum pattern length - * @param prefix prefix - * @param projectedDatabase the projected dabase + * @param prefixes prefixes in reversed order + * @param database the projected database * @return a set of sequential pattern pairs, - * the key of pair is sequential pattern (a list of items), + * the key of pair is sequential pattern (a list of items in reversed order), * the value of pair is the pattern's count. */ def run( minCount: Long, maxPatternLength: Int, - prefix: Array[Int], - projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { - val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) - val frequentPatternAndCounts = frequentPrefixAndCounts - .map(x => (prefix ++ Array(x._1), x._2)) - val prefixProjectedDatabases = getPatternAndProjectedDatabase( - prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) - - val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength - if (continueProcess) { - val nextPatterns = prefixProjectedDatabases - .map(x => run(minCount, maxPatternLength, x._1, x._2)) - .reduce(_ ++ _) - frequentPatternAndCounts ++ nextPatterns - } else { - frequentPatternAndCounts + prefixes: List[Int], + database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { + if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty + val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) + val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) + frequentItemAndCounts.iterator.flatMap { case (item, count) => + val newPrefixes = item :: prefixes + val newProjected = project(filteredDatabase, item) + Iterator.single((newPrefixes, count)) ++ + run(minCount, maxPatternLength, newPrefixes, newProjected) } } /** - * calculate suffix sequence following a prefix in a sequence - * @param prefix prefix - * @param sequence sequence + * Calculate suffix sequence immediately after the first occurrence of an item. 
+ * @param item item to get suffix after + * @param sequence sequence to extract suffix from * @return suffix sequence */ - def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = { - val index = sequence.indexOf(prefix) + def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = { + val index = sequence.indexOf(item) if (index == -1) { Array() } else { @@ -76,38 +67,28 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } + def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { + database + .map(getSuffix(prefix, _)) + .filter(_.nonEmpty) + } + /** * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences sequences data - * @return array of item and count pair + * @param minCount the minimum count for an item to be frequent + * @param database database of sequences + * @return freq item to count map */ private def getFreqItemAndCounts( minCount: Long, - sequences: Array[Array[Int]]): Array[(Int, Long)] = { - sequences.flatMap(_.distinct) - .groupBy(x => x) - .mapValues(_.length.toLong) - .filter(_._2 >= minCount) - .toArray - } - - /** - * Get the frequent prefixes' projected database. - * @param prePrefix the frequent prefixes' prefix - * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database - */ - private def getPatternAndProjectedDatabase( - prePrefix: Array[Int], - frequentPrefixes: Array[Int], - sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { - val filteredProjectedDatabase = sequences - .map(x => x.filter(frequentPrefixes.contains(_))) - frequentPrefixes.map { x => - val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) - (prePrefix ++ Array(x), sub) - }.filter(x => x._2.nonEmpty) + database: Array[Array[Int]]): mutable.Map[Int, Long] = { + // TODO: use PrimitiveKeyOpenHashMap + val counts = mutable.Map[Int, Long]().withDefaultValue(0L) + database.foreach { sequence => + sequence.distinct.foreach { item => + counts(item) += 1L + } + } + counts.filter(_._2 >= minCount) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index aed7e30033b8a..139b2f6952fb8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -203,8 +203,12 @@ class PrefixSpan private ( private def getPatternsInLocal( minCount: Long, data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(ArrayBuffer[Int], Long)] = { - data - .flatMap { x => LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) } - .map { case (pattern, count) => (pattern.to[ArrayBuffer], count) } + data.flatMap { + case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) + .map { case (pattern: List[Int], count: Long) => + (pattern.toArray.reverse.to[ArrayBuffer], count) + } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala new file mode 100644 index 0000000000000..c8d065f37a605 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
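Before the new test suites, a note on the reversed-prefix convention used above: LocalPrefixSpan.run prepends each frequent item to the prefix list (so patterns are stored back-to-front), and PrefixSpan.getPatternsInLocal reverses them once when emitting results. A tiny stand-alone Scala illustration, independent of Spark:

    // Patterns are accumulated by prepending, which is O(1) on immutable lists.
    val empty = List.empty[Int]
    val afterA = 1 :: empty      // pattern <1>     stored as List(1)
    val afterB = 2 :: afterA     // pattern <1 2>   stored as List(2, 1)
    val afterC = 3 :: afterB     // pattern <1 2 3> stored as List(3, 2, 1)

    // getPatternsInLocal restores the original order once per emitted pattern:
    assert(afterC.toArray.reverse.toSeq == Seq(1, 2, 3))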
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite + +class RFormulaParserSuite extends SparkFunSuite { + private def checkParse(formula: String, label: String, terms: Seq[String]) { + val parsed = RFormulaParser.parse(formula) + assert(parsed.label == label) + assert(parsed.terms == terms) + } + + test("parse simple formulas") { + checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala new file mode 100644 index 0000000000000..fa8611b243a9f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RFormula()) + } + + test("transform numeric data") { + val formula = new RFormula().setFormula("id ~ v1 + v2") + val original = sqlContext.createDataFrame( + Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") + val result = formula.transform(original) + val resultSchema = formula.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0), + (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0)) + ).toDF("id", "v1", "v2", "features", "label") + // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expected.schema) + assert(result.collect().toSeq == expected.collect().toSeq) + } + + test("features column already exists") { + val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + + test("label column already exists") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + + test("label column already exists but is not double type") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") + val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + intercept[IllegalArgumentException] { + formula.transformSchema(original.schema) + } + intercept[IllegalArgumentException] { + formula.transform(original) + } + } + +// TODO(ekl) enable after we implement string label support +// test("transform string label") { +// val formula = new RFormula().setFormula("name ~ id") +// val original = sqlContext.createDataFrame( +// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name") +// val result = formula.transform(original) +// val resultSchema = formula.transformSchema(original.schema) +// val expected = sqlContext.createDataFrame( +// Seq( +// (1, "foo", Vectors.dense(Array(1.0)), 1.0), +// (2, "bar", Vectors.dense(Array(2.0)), 0.0), +// (3, "bar", Vectors.dense(Array(3.0)), 0.0)) +// ).toDF("id", "name", "features", "label") +// assert(result.schema.toString == resultSchema.toString) +// assert(result.collect().toSeq == expected.collect().toSeq) +// } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala new file mode 100644 index 0000000000000..8df6617fe0228 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def testStopwatchOnDriver(sw: Stopwatch): Unit = { + assert(sw.name === "sw") + assert(sw.elapsed() === 0L) + assert(!sw.isRunning) + intercept[AssertionError] { + sw.stop() + } + sw.start() + Thread.sleep(50) + val duration = sw.stop() + assert(duration >= 50 && duration < 100) // using a loose upper bound + val elapsed = sw.elapsed() + assert(elapsed === duration) + sw.start() + Thread.sleep(50) + val duration2 = sw.stop() + assert(duration2 >= 50 && duration2 < 100) + val elapsed2 = sw.elapsed() + assert(elapsed2 === duration + duration2) + sw.start() + assert(sw.isRunning) + intercept[AssertionError] { + sw.start() + } + } + + test("LocalStopwatch") { + val sw = new LocalStopwatch("sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on driver") { + val sw = new DistributedStopwatch(sc, "sw") + testStopwatchOnDriver(sw) + } + + test("DistributedStopwatch on executors") { + val sw = new DistributedStopwatch(sc, "sw") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw.start() + Thread.sleep(50) + sw.stop() + } + assert(!sw.isRunning) + val elapsed = sw.elapsed() + assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound + } + + test("MultiStopwatch") { + val sw = new MultiStopwatch(sc) + .addLocal("local") + .addDistributed("spark") + assert(sw("local").name === "local") + assert(sw("spark").name === "spark") + intercept[NoSuchElementException] { + sw("some") + } + assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}") + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("local").stop() + Thread.sleep(50) + sw("spark").stop() + val localElapsed = sw("local").elapsed() + val sparkElapsed = sw("spark").elapsed() + assert(localElapsed >= 50 && localElapsed < 100) + assert(sparkElapsed >= 100 && sparkElapsed < 200) + assert(sw.toString === + s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") + val rdd = sc.parallelize(0 until 4, 4) + rdd.foreach { i => + sw("local").start() + sw("spark").start() + Thread.sleep(50) + sw("spark").stop() + sw("local").stop() + } + val localElapsed2 = sw("local").elapsed() + assert(localElapsed2 === localElapsed) + val sparkElapsed2 = sw("spark").elapsed() + assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0dbbd7127444f..3003c62d9876c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -278,6 +278,28 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("Initialize using 
given cluster centers") { + val points = Seq( + Vectors.dense(0.0, 0.0), + Vectors.dense(1.0, 0.0), + Vectors.dense(0.0, 1.0), + Vectors.dense(1.0, 1.0) + ) + val rdd = sc.parallelize(points, 3) + // creating an initial model + val initialModel = new KMeansModel(Array(points(0), points(2))) + + val returnModel = new KMeans() + .setK(2) + .setMaxIterations(0) + .setInitialModel(initialModel) + .run(rdd) + // comparing the returned model and the initial model + assert(returnModel.clusterCenters(0) === initialModel.clusterCenters(0)) + assert(returnModel.clusterCenters(1) === initialModel.clusterCenters(1)) + } + } object KMeansSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 9de2bdb6d7246..4b7f1be58f99b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { - test("regression metrics") { + test("regression metrics for unbiased (includes intercept term) predictor") { + /* Verify results in R: + preds = c(2.25, -0.25, 1.75, 7.75) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) + + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.796875 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.3125 + rmse = sqrt(meanSquaredError) + rmse + > [1] 0.559017 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9571734 + */ + val predictionAndObservations = sc.parallelize( + Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + } + + test("regression metrics for biased (no intercept term) predictor") { + /* Verify results in R: + preds = c(2.5, 0.0, 2.0, 8.0) + obs = c(3.0, -0.5, 2.0, 7.0) + + SStot = sum((obs - mean(obs))^2) + SSreg = sum((preds - mean(obs))^2) + SSerr = sum((obs - preds)^2) + + explainedVariance = SSreg / length(obs) + explainedVariance + > [1] 8.859375 + meanAbsoluteError = mean(abs(preds - obs)) + meanAbsoluteError + > [1] 0.5 + meanSquaredError = mean((preds - obs)^2) + meanSquaredError + > [1] 0.375 + rmse = sqrt(meanSquaredError) + rmse + > [1] 0.6123724 + r2 = 1 - SSerr / SStot + r2 + > [1] 0.9486081 + */ val predictionAndObservations = sc.parallelize( Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") 
assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") } test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, + assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, "explained variance regression score mismatch") assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 413436d3db85f..9f107c89f6d80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.rdd.RDD -class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { +class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { test("PrefixSpan using Integer type") { @@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { def compareResult( expectedValue: Array[(Array[Int], Long)], actualValue: Array[(Array[Int], Long)]): Boolean = { - val sortedExpectedValue = expectedValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - val sortedActualValue = actualValue.sortWith{ (x, y) => - x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 - } - sortedExpectedValue.zip(sortedActualValue) - .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2) - .reduce(_&&_) + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } val prefixspan = new PrefixSpan() diff --git a/pom.xml b/pom.xml index 370c95dd03632..aa49e2ab7294b 100644 --- a/pom.xml +++ b/pom.xml @@ -748,6 +748,12 @@ curator-framework ${curator.version} + + org.apache.curator + curator-test + ${curator.version} + test + org.apache.hadoop hadoop-client diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000000000..061775960393b --- /dev/null +++ b/pylintrc @@ -0,0 +1,404 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=pyspark.heapq3 + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist= + +# Allow optimization of some AST trees. This will activate a peephole AST +# optimizer, which will apply various small optimizations. For instance, it can +# be used to obtain the result of joining multiple strings with the addition +# operator. Joining a lot of strings can lead to a maximum recursion error in +# Pylint and this flag can prevent that. It has one side effect, the resulting +# AST will be different than the one from reality. +optimize-ast=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" + +# These errors are arranged in order of number of warning given in pylint. +# If you would like to improve the code quality of pyspark, remove any of these disabled errors +# run ./dev/lint-python and see if the errors raised by pylint can be fixed. 
+ +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=no + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME,XXX,TODO + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions= + +# Good variable names which should always be accepted, separated by a comma +good-names=i,j,k,ex,Run,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names=baz,toto,tutu,tata + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. 
+name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Regular expression matching correct function names +function-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for function names +function-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct variable names +variable-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for variable names +variable-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct attribute names +attr-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for attribute names +attr-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct argument names +argument-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for argument names +argument-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zA-Z0-9]+$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zA-Z0-9]+$ + +# Regular expression matching correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Naming hint for module names +module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression matching correct method names +method-rgx=[a-z_][a-z0-9_]{2,30}$ + +# Naming hint for method names +method-name-hint=[a-z_][a-z0-9_]{2,30}$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=__.*__ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=100 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + +# List of optional constructs for which whitespace checking is disabled +no-space-check=trailing-comma,dict-separator + +# Maximum number of lines in a module +max-module-lines=1000 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). 
+dummy-variables-rgx=_$|dummy + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis +ignored-modules= + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. This is used for +# instance to not check methods defines in Zope's Interface base class. +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec + +# Create a graph of every (i.e. 
internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index bc088e4c29e26..595124726366d 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -444,7 +444,7 @@ class DecisionTreeParams(Params): minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -460,7 +460,7 @@ def __init__(self): self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. 
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index ed4d78a2c6788..8a92f6911c24b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -31,13 +31,15 @@ from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector +from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.stat.distribution import MultivariateGaussian from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable from pyspark.streaming import DStream __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', 'PowerIterationClusteringModel', 'PowerIterationClustering', - 'StreamingKMeans', 'StreamingKMeansModel'] + 'StreamingKMeans', 'StreamingKMeansModel', + 'LDA', 'LDAModel'] @inherit_doc @@ -563,6 +565,68 @@ def predictOnValues(self, dstream): return dstream.mapValues(lambda x: self._model.predict(x)) +class LDAModel(JavaModelWrapper): + + """ A clustering model derived from the LDA method. + + Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + Terminology + - "word" = "term": an element of the vocabulary + - "token": instance of a term appearing in a document + - "topic": multinomial distribution over words representing some concept + References: + - Original LDA paper (journal version): + Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + + >>> from pyspark.mllib.linalg import Vectors + >>> from numpy.testing import assert_almost_equal + >>> data = [ + ... [1, Vectors.dense([0.0, 1.0])], + ... [2, SparseVector(2, {0: 1.0})], + ... ] + >>> rdd = sc.parallelize(data) + >>> model = LDA.train(rdd, k=2) + >>> model.vocabSize() + 2 + >>> topics = model.topicsMatrix() + >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) + >>> assert_almost_equal(topics, topics_expect, 1) + """ + + def topicsMatrix(self): + """Inferred topics, where each topic is represented by a distribution over terms.""" + return self.call("topicsMatrix").toArray() + + def vocabSize(self): + """Vocabulary size (number of terms or terms in the vocabulary)""" + return self.call("vocabSize") + + +class LDA(object): + + @classmethod + def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0, + topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"): + """Train a LDA model. + + :param rdd: RDD of data points + :param k: Number of clusters you want + :param maxIterations: Number of iterations. Default to 20 + :param docConcentration: Concentration parameter (commonly named "alpha") + for the prior placed on documents' distributions over topics ("theta"). + :param topicConcentration: Concentration parameter (commonly named "beta" or "eta") + for the prior placed on topics' distributions over terms. + :param seed: Random Seed + :param checkpointInterval: Period (in iterations) between checkpoints. + :param optimizer: LDAOptimizer used to perform the actual calculation. + Currently "em", "online" are supported. Default to "em". 
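For reference, the Scala calls that this Python wrapper ultimately makes through the trainLDAModel stub added above look roughly like the sketch below; the three-term, two-document corpus is illustrative and `sc` is assumed.

    import org.apache.spark.mllib.clustering.LDA
    import org.apache.spark.mllib.linalg.Vectors

    // A corpus is an RDD of (document id, term-count vector) pairs.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0)),
      (1L, Vectors.dense(0.0, 1.0, 3.0))))

    val ldaModel = new LDA()
      .setK(2)
      .setMaxIterations(20)
      .setOptimizer("em")     // "em" or "online", matching the Python wrapper's default
      .run(corpus)

    ldaModel.topicsMatrix     // inferred topics as a terms-by-topics matrix
    ldaModel.vocabSize        // 3 for this toy corpus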
+ """ + model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations, + docConcentration, topicConcentration, seed, + checkpointInterval, optimizer) + return LDAModel(model) + + def _test(): import doctest import pyspark.mllib.clustering diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index f21403707e12a..4398ca86f2ec2 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -82,7 +82,7 @@ class RegressionMetrics(JavaModelWrapper): ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) >>> metrics.explainedVariance - 0.95... + 8.859... >>> metrics.meanAbsoluteError 0.5... >>> metrics.meanSquaredError diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dca39fa833435..e0816b3e654bc 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -39,6 +39,8 @@ 'coalesce', 'countDistinct', 'explode', + 'format_number', + 'length', 'log2', 'md5', 'monotonicallyIncreasingId', @@ -47,7 +49,6 @@ 'sha1', 'sha2', 'sparkPartitionId', - 'strlen', 'struct', 'udf', 'when'] @@ -506,14 +507,28 @@ def sparkPartitionId(): @ignore_unicode_prefix @since(1.5) -def strlen(col): - """Calculates the length of a string expression. +def length(col): + """Calculates the length of a string or binary expression. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() [Row(length=3)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.strlen(_to_java_column(col))) + return Column(sc._jvm.functions.length(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(1.5) +def format_number(col, d): + """Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + and returns the result as a string. 
+ :param col: the column name of the numeric value to be formatted + :param d: the N decimal places + >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + [Row(v=u'5.0000')] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.format_number(_to_java_column(col), d)) @ignore_unicode_prefix diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c5c0add49d02c..21225016805bc 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -893,7 +893,8 @@ def test_pipe_functions(self): self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect) result = rdd.pipe('cat').collect() result.sort() - [self.assertEqual(x, y) for x, y in zip(data, result)] + for x, y in zip(data, result): + self.assertEqual(x, y) self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed69c42dcb825..e0beafe710079 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -114,8 +115,10 @@ object FunctionRegistry { expression[Log2]("log2"), expression[Pow]("pow"), expression[Pow]("power"), + expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), + expression[Round]("round"), expression[ShiftLeft]("shiftleft"), expression[ShiftRight]("shiftright"), expression[ShiftRightUnsigned]("shiftrightunsigned"), @@ -149,11 +152,12 @@ object FunctionRegistry { expression[Base64]("base64"), expression[Encode]("encode"), expression[Decode]("decode"), - expression[StringInstr]("instr"), + expression[FormatNumber]("format_number"), expression[Lower]("lcase"), expression[Lower]("lower"), - expression[StringLength]("length"), + expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8cb71995eb818..50db7d21f01ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -214,19 +214,6 @@ object HiveTypeCoercion { } Union(newLeft, newRight) - - // Also widen types for BinaryOperator. - case q: LogicalPlan => q transformExpressions { - // Skip nodes who's children have not been resolved yet. 
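The FunctionRegistry additions above (pmod, round, format_number, length) are easiest to exercise from the DataFrame expression API; a rough Scala sketch, with sqlContext assumed. The length and format_number results match the pyspark doctests above; pmod and round are the positive-modulus and rounding functions now resolvable by name.

    // Column names "s" and "x" are illustrative.
    val df = sqlContext.createDataFrame(Seq(("ABC", 5.0))).toDF("s", "x")
    df.selectExpr("length(s)", "format_number(x, 4)", "pmod(-7, 3)", "round(x)").show()
    // length('ABC') is 3 and format_number(5, 4) is '5.0000', as in the doctests above.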
- case e if !e.childrenResolved => e - - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - b.makeCopy(Array(newLeft, newRight)) - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. - } } } @@ -439,6 +426,12 @@ object HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + // When we compare 2 decimal types with different precisions, cast them to the smallest // common precision. case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), @@ -672,20 +665,44 @@ object HiveTypeCoercion { } /** - * Casts types according to the expected input types for Expressions that have the trait - * [[ExpectsInputTypes]]. + * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => + findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + + case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => // If we cannot do the implicit cast, just use the original input. implicitCast(in, expected).getOrElse(in) } e.withNewChildren(children) + + case e: ExpectsInputTypes if e.inputTypes.nonEmpty => + // Convert NullType into some specific target type for ExpectsInputTypes that don't do + // general implicit casting. + val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => + if (in.dataType == NullType && !expected.acceptsType(NullType)) { + Literal.create(null, expected.defaultConcreteType) + } else { + in + } + } + e.withNewChildren(children) } /** @@ -702,27 +719,22 @@ object HiveTypeCoercion { @Nullable val ret: Expression = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. - case _ if expectedType.isSameType(inType) => e + case _ if expectedType.acceptsType(inType) => e // Cast null type (usually from null literals) into target types case (NullType, target) => Cast(e, target.defaultConcreteType) - // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is - // already a number, leave it as is. 
- case (_: NumericType, NumericType) => e - // If the function accepts any numeric type and the input is a string, we follow the hive // convention and cast that input into a double case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) - // Implicit cast among numeric types + // Implicit cast among numeric types. When we reach here, input type is not acceptable. + // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to unlimited precision decimal. - case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => - Cast(e, DecimalType.Unlimited) + case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited) // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long - case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) - case (_: NumericType, target: NumericType) => e + case (_: NumericType, target: NumericType) => Cast(e, target) // Implicit cast between date time types case (DateType, TimestampType) => Cast(e, TimestampType) @@ -736,15 +748,9 @@ object HiveTypeCoercion { case (StringType, BinaryType) => Cast(e, BinaryType) case (any, StringType) if any != StringType => Cast(e, StringType) - // Type collection. - // First see if we can find our input type in the type collection. If we can, then just - // use the current expression; otherwise, find the first one we can implicitly cast. - case (_, TypeCollection(types)) => - if (types.exists(_.isSameType(inType))) { - e - } else { - types.flatMap(implicitCast(e, _)).headOption.orNull - } + // When we reach here, input type is not acceptable for any types in this type collection, + // try to find the first one we can implicitly cast. + case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull // Else, just return the same input expression case _ => null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 3eb0eb195c80d..ded89e85dea79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -19,10 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType - +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts /** * An trait that gets mixin to define the expected input types of an expression. + * + * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define + * expected input types without any implicit casting. + * + * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. */ trait ExpectsInputTypes { self: Expression => @@ -40,7 +45,7 @@ trait ExpectsInputTypes { self: Expression => val mismatches = children.zip(inputTypes).zipWithIndex.collect { case ((child, expected), idx) if !expected.acceptsType(child.dataType) => s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " + - s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." + s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}." 
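Taken together, the ImplicitTypeCasts changes above move BinaryOperator widening into this rule: find the tightest common type of the two children and cast both sides to it, but only when the operator's declared `inputType` accepts that type; otherwise the expression is left alone for checkInputDataTypes to reject. A rough standalone sketch of that decision over a toy type model follows; `tightestCommonType`, `widen`, and the `*T` types are illustrative stand-ins, not Catalyst APIs.

    object WideningSketch {
      sealed trait DType
      case object IntT extends DType
      case object LongT extends DType
      case object DoubleT extends DType
      case object StringT extends DType

      // Toy "tightest common type": promote along Int -> Long -> Double; none involving String.
      def tightestCommonType(a: DType, b: DType): Option[DType] = (a, b) match {
        case _ if a == b                         => Some(a)
        case (IntT, LongT) | (LongT, IntT)       => Some(LongT)
        case (IntT, DoubleT) | (DoubleT, IntT)   => Some(DoubleT)
        case (LongT, DoubleT) | (DoubleT, LongT) => Some(DoubleT)
        case _                                   => None
      }

      // Mirrors the rule: cast both children to the common type only if the operator accepts it.
      def widen(left: DType, right: DType, accepts: DType => Boolean): Option[(DType, DType)] =
        if (left == right) Some((left, right))
        else tightestCommonType(left, right) match {
          case Some(t) if accepts(t) => Some((t, t)) // i.e. wrap both sides in a cast to t
          case _                     => None         // no applicable conversion: leave unchanged
        }

      def main(args: Array[String]): Unit = {
        val numeric: DType => Boolean = Set[DType](IntT, LongT, DoubleT)
        println(widen(IntT, DoubleT, numeric)) // Some((DoubleT,DoubleT)): both sides cast up
        println(widen(IntT, StringT, numeric)) // None: operator left for type checking to reject
      }
    }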
} if (mismatches.isEmpty) { @@ -50,3 +55,11 @@ trait ExpectsInputTypes { self: Expression => } } } + + +/** + * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]]. + */ +trait ImplicitCastInputTypes extends ExpectsInputTypes { self: Expression => + // No other methods +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 54ec10444c4f3..a655cc8e48ae1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -24,8 +24,20 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines the basic expression abstract classes in Catalyst, including: +// Expression: the base expression abstract class +// LeafExpression +// UnaryExpression +// BinaryExpression +// BinaryOperator +// +// For details, see their classdocs. +//////////////////////////////////////////////////////////////////////////////////////////////////// /** + * An expression in Catalyst. + * * If an expression wants to be exposed in the function registry (so users can call it with * "name(arguments...)", the concrete implementation must be a case class whose constructor * arguments are all Expressions types. @@ -49,9 +61,15 @@ abstract class Expression extends TreeNode[Expression] { def foldable: Boolean = false /** - * Returns true when the current expression always return the same result for fixed input values. + * Returns true when the current expression always return the same result for fixed inputs from + * children. + * + * Note that this means that an expression should be considered as non-deterministic if: + * - if it relies on some mutable internal state, or + * - if it relies on some implicit input that is not part of the children expression list. + * + * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. */ - // TODO: Need to define explicit input values vs implicit input values. def deterministic: Boolean = true def nullable: Boolean @@ -169,8 +187,10 @@ abstract class Expression extends TreeNode[Expression] { /** * A leaf expression, i.e. one without any child expressions. */ -abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { +abstract class LeafExpression extends Expression { self: Product => + + def children: Seq[Expression] = Nil } @@ -178,9 +198,13 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] * An expression with one input and one output. The output is by default evaluated to null * if the input is evaluated to null. */ -abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { +abstract class UnaryExpression extends Expression { self: Product => + def child: Expression + + override def children: Seq[Expression] = child :: Nil + override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable @@ -253,9 +277,14 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio * An expression with two inputs and one output. The output is by default evaluated to null * if any input is evaluated to null. 
*/ -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { +abstract class BinaryExpression extends Expression { self: Product => + def left: Expression + def right: Expression + + override def children: Seq[Expression] = Seq(left, right) + override def foldable: Boolean = left.foldable && right.foldable override def nullable: Boolean = left.nullable || right.nullable @@ -335,15 +364,39 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express /** - * An expression that has two inputs that are expected to the be same type. If the two inputs have - * different types, the analyzer will find the tightest common type and do the proper type casting. + * A [[BinaryExpression]] that is an operator, with two properties: + * + * 1. The string representation is "x symbol y", rather than "funcName(x, y)". + * 2. Two inputs are expected to the be same type. If the two inputs have different types, + * the analyzer will find the tightest common type and do the proper type casting. */ -abstract class BinaryOperator extends BinaryExpression { +abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { self: Product => + /** + * Expected input type from both left/right child expressions, similar to the + * [[ImplicitCastInputTypes]] trait. + */ + def inputType: AbstractDataType + def symbol: String override def toString: String = s"($left $symbol $right)" + + override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) + + override def checkInputDataTypes(): TypeCheckResult = { + // First check whether left and right have the same type, then check if the type is acceptable. + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") + } else if (!inputType.acceptsType(left.dataType)) { + TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," + + s" not ${left.dataType.simpleString}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 6fb3343bb63f2..22687acd68a97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes { + inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8476af4a5d8d6..382cbe3b84a07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,23 +18,19 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import 
org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -abstract class UnaryArithmetic extends UnaryExpression { - self: Product => + +case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def dataType: DataType = child.dataType -} -case class UnaryMinus(child: Expression) extends UnaryArithmetic { override def toString: String = s"-$child" - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "operator -") - private lazy val numeric = TypeUtils.getNumeric(dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { @@ -45,9 +41,13 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { protected override def nullSafeEval(input: Any): Any = numeric.negate(input) } -case class UnaryPositive(child: Expression) extends UnaryArithmetic { +case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def prettyName: String = "positive" + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def dataType: DataType = child.dataType + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -57,9 +57,11 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { /** * A function that get the absolute value of the numeric value. */ -case class Abs(child: Expression) extends UnaryArithmetic { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function abs") +case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def dataType: DataType = child.dataType private lazy val numeric = TypeUtils.getNumeric(dataType) @@ -71,18 +73,6 @@ abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult - /** Name of the function for this expression on a [[Decimal]] type. 
*/ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -104,62 +94,61 @@ private[sql] object BinaryArithmetic { } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "+" override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "-" override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "*" override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "/" override def decimalMethod: String = "$div" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot @@ -215,17 +204,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { + + override def inputType: AbstractDataType = NumericType + override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForNumericExpr(t, "operator " + symbol) - private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] @@ -281,10 +269,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } 
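The `Pmod` expression added to this file (and registered as `pmod` in FunctionRegistry above) implements the positive modulus: compute `a % n` and, when the remainder comes out negative, fold it back into `[0, n)` by adding `n` and taking `% n` once more. A tiny standalone check of that rule, showing only the Int case and assuming a positive `n`:

    object PmodSketch {
      // Positive modulus: unlike %, the result is non-negative when n is positive.
      def pmod(a: Int, n: Int): Int = {
        val r = a % n
        if (r < 0) (r + n) % n else r
      }

      def main(args: Array[String]): Unit = {
        println(-7 % 3)       // -1: % keeps the sign of the dividend
        println(pmod(-7, 3))  //  2: pmod folds it back into [0, n)
        println(pmod(7, 3))   //  1: non-negative remainders are unchanged
      }
    }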
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function maxOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) @@ -331,14 +320,14 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "max" - override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { - override def nullable: Boolean = left.nullable && right.nullable + // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least. - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(t, "function minOf") + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def nullable: Boolean = left.nullable && right.nullable private lazy val ordering = TypeUtils.getOrdering(dataType) @@ -385,5 +374,98 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } override def symbol: String = "min" - override def prettyName: String = symbol +} + +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { + + override def toString: String = s"pmod($left, $right)" + + override def symbol: String = "pmod" + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "pmod") + + override def inputType: AbstractDataType = NumericType + + protected override def nullSafeEval(left: Any, right: Any) = + dataType match { + case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int]) + case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long]) + case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short]) + case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte]) + case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float]) + case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double]) + case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + dataType match { + case dt: DecimalType => + val decimalAdd = "$plus" + s""" + ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); + if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { + ${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2); + } else { + ${ev.primitive} = r; + } + """ + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + s""" + ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); + if (r < 0) { + ${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); + } else { + ${ev.primitive} = r; + } + """ + case _ => + s""" + ${ctx.javaType(dataType)} r = $eval1 % $eval2; + if (r < 0) { + ${ev.primitive} = (r + $eval2) % $eval2; + } else { + ${ev.primitive} = r; + } + """ + } + }) + } + + private def pmod(a: Int, n: Int): Int = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Long, n: Long): Long = { + val r = a % n + if (r < 0) {(r + n) % n} 
else r + } + + private def pmod(a: Byte, n: Byte): Byte = { + val r = a % n + if (r < 0) {((r + n) % n).toByte} else r.toByte + } + + private def pmod(a: Double, n: Double): Double = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Short, n: Short): Short = { + val r = a % n + if (r < 0) {((r + n) % n).toShort} else r.toShort + } + + private def pmod(a: Float, n: Float): Float = { + val r = a % n + if (r < 0) {(r + n) % n} else r + } + + private def pmod(a: Decimal, n: Decimal): Decimal = { + val r = a % n + if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index 2d47124d247e7..a1e48c4210877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -29,10 +27,10 @@ import org.apache.spark.sql.types._ * Code generation inherited from BinaryArithmetic. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "&" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "&" private lazy val and: (Any, Any) => Any = dataType match { case ByteType => @@ -54,10 +52,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * Code generation inherited from BinaryArithmetic. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "|" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "|" private lazy val or: (Any, Any) => Any = dataType match { case ByteType => @@ -79,10 +77,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * Code generation inherited from BinaryArithmetic. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - override def symbol: String = "^" - protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + override def inputType: AbstractDataType = IntegralType + + override def symbol: String = "^" private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => @@ -101,11 +99,13 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. 
*/ -case class BitwiseNot(child: Expression) extends UnaryArithmetic { - override def toString: String = s"~$child" +case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) + + override def dataType: DataType = child.dataType + + override def toString: String = s"~$child" private lazy val not: (Any) => Any = dataType match { case ByteType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9f6329bbda4ec..328d635de8743 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -56,6 +56,18 @@ class CodeGenContext { */ val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + /** + * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a + * 3-tuple: java type, variable name, code to init it. + * They will be kept as member variables in generated classes like `SpecificProjection`. + */ + val mutableStates: mutable.ArrayBuffer[(String, String, String)] = + mutable.ArrayBuffer.empty[(String, String, String)] + + def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = { + mutableStates += ((javaType, variableName, initialValue)) + } + val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName @@ -203,7 +215,10 @@ class CodeGenContext { def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) } - +/** + * A wrapper for generated class, defines a `generate` method so that we can pass extra objects + * into generated class. 
+ */ abstract class GeneratedClass { def generate(expressions: Array[Expression]): Any } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index addb8023d9c0b..71e47d4f9b620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -46,6 +46,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); @@ -55,6 +58,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu private $exprType[] expressions = null; private $mutableRowType mutableRow = null; + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index d05dfc108e63a..856ff9f1f96f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -46,30 +46,47 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { val ctx = newCodeGenContext() - val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = order.child.gen(ctx) - val evalB = order.child.gen(ctx) + val comparisons = ordering.map { order => + val eval = order.child.gen(ctx) val asc = order.direction == Ascending + val isNullA = ctx.freshName("isNullA") + val primitiveA = ctx.freshName("primitiveA") + val isNullB = ctx.freshName("isNullB") + val primitiveB = ctx.freshName("primitiveB") s""" i = a; - ${evalA.code} + boolean $isNullA; + ${ctx.javaType(order.child.dataType)} $primitiveA; + { + ${eval.code} + $isNullA = ${eval.isNull}; + $primitiveA = ${eval.primitive}; + } i = b; - ${evalB.code} - if (${evalA.isNull} && ${evalB.isNull}) { + boolean $isNullB; + ${ctx.javaType(order.child.dataType)} $primitiveB; + { + ${eval.code} + $isNullB = ${eval.isNull}; + $primitiveB = ${eval.primitive}; + } + if ($isNullA && $isNullB) { // Nothing - } else if (${evalA.isNull}) { + } else if ($isNullA) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.isNull}) { + } else if ($isNullB) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { - int comp = ${ctx.genComp(order.child.dataType, evalA.primitive, evalB.primitive)}; + int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)}; if (comp != 0) { return ${if (asc) "comp" else "-comp"}; } } """ }.mkString("\n") - + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" 
public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); @@ -78,6 +95,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private $exprType[] expressions = null; + $mutableStates public SpecificOrdering($exprType[] expr) { expressions = expr; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 274a42cb69087..9e5a745d512e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,6 +40,9 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") val code = s""" public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); @@ -47,6 +50,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate extends ${classOf[Predicate].getName} { private final $exprType[] expressions; + $mutableStates public SpecificPredicate($exprType[] expr) { expressions = expr; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3c7ee9cc16599..3e5ca308dc31d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,6 +151,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") + val code = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); @@ -158,6 +162,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; + $mutableStates public SpecificProjection($exprType[] expr) { expressions = expr; @@ -165,65 +170,65 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { - return new SpecificRow(expressions, (InternalRow) r); + return new SpecificRow((InternalRow) r); } - } - final class SpecificRow extends ${classOf[MutableRow].getName} { + final class SpecificRow extends ${classOf[MutableRow].getName} { - $columns + $columns - public SpecificRow($exprType[] expressions, InternalRow i) { - $initColumns - } + public SpecificRow(InternalRow i) { + $initColumns + } - public int length() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { 
nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } + public int length() { return ${expressions.length};} + protected boolean[] nullBits = new boolean[${expressions.length}]; + public void setNullAt(int i) { nullBits[i] = true; } + public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases + public Object get(int i) { + if (isNullAt(i)) return null; + switch (i) { + $getCases + } + return null; } - return null; - } - public void update(int i, Object value) { - if (value == null) { - setNullAt(i); - return; + public void update(int i, Object value) { + if (value == null) { + setNullAt(i); + return; + } + nullBits[i] = false; + switch (i) { + $updateCases + } } - nullBits[i] = false; - switch (i) { - $updateCases + $specificAccessorFunctions + $specificMutatorFunctions + + @Override + public int hashCode() { + int result = 37; + $hashUpdates + return result; } - } - $specificAccessorFunctions - $specificMutatorFunctions - - @Override - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - @Override - public boolean equals(Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; + @Override + public boolean equals(Object other) { + if (other instanceof SpecificRow) { + SpecificRow row = (SpecificRow) other; + $columnChecks + return true; + } + return super.equals(other); } - return super.equals(other); - } - @Override - public InternalRow copy() { - Object[] arr = new Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${classOf[GenericInternalRow].getName}(arr); + } } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index c7f039ede26b3..9162b73fe56eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + + s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).") } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index c31890e27fb54..a7ad452ef4943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -55,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends UnaryExpression with Serializable with ImplicitCastInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -89,7 +91,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with ImplicitCastInputTypes { self: Product => override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -174,7 +176,7 @@ object Factorial { ) } -case class Factorial(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -251,7 +253,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends UnaryExpression with Serializable with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -285,7 +287,7 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = @@ -329,7 +331,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -416,7 +418,7 @@ case class Pow(left: Expression, right: Expression) * @param right number of bits to left shift. */ case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -442,7 +444,7 @@ case class ShiftLeft(left: Expression, right: Expression) * @param right number of bits to left shift. 
*/ case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -468,7 +470,7 @@ case class ShiftRight(left: Expression, right: Expression) * @param right the number of bits to right shift. */ case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression) """ } } + +/** + * Round the `child`'s result to `scale` decimal place when `scale` >= 0 + * or round at integral part when `scale` < 0. + * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * + * Child of IntegralType would eval to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * + * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], + * which leads to scale update in DecimalType's [[PrecisionInfo]] + * + * @param child expr to be round, all [[NumericType]] is allowed as Input + * @param scale new scale to be round to, this should be a constant int at runtime + */ +case class Round(child: Expression, scale: Expression) + extends BinaryExpression with ExpectsInputTypes { + + import BigDecimal.RoundingMode.HALF_UP + + def this(child: Expression) = this(child, Literal(0)) + + override def left: Expression = child + override def right: Expression = scale + + // round of Decimal would eval to null if it fails to `changePrecision` + override def nullable: Boolean = true + + override def foldable: Boolean = child.foldable + + override lazy val dataType: DataType = child.dataType match { + // if the new scale is bigger which means we are scaling up, + // keep the original scale as `Decimal` does + case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) + case t => t + } + + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess + } else { + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + } + case f => f + } + } + + // Avoid repeated evaluation since `scale` is a constant int, + // avoid unnecessary `child` evaluation in both codegen and non-codegen eval + // by checking if scaleV == null as well. 
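The `Round` expression introduced above rounds HALF_UP at `scale` decimal places, rounds the integral part when `scale` is negative, and passes NaN/Infinity through untouched. Its interpreted path does this via `BigDecimal.setScale`, and the same behaviour can be checked standalone; `round` below is an illustrative helper over doubles only, not the expression itself:

    import scala.math.BigDecimal.RoundingMode.HALF_UP

    object RoundSketch {
      // Round x to `scale` decimal places; a negative scale rounds digits left of the point.
      def round(x: Double, scale: Int): Double =
        if (x.isNaN || x.isInfinity) x                      // NaN/Infinity pass through, as in Round
        else BigDecimal(x).setScale(scale, HALF_UP).toDouble

      def main(args: Array[String]): Unit = {
        println(round(2.5, 0))     // 3.0  -- HALF_UP rounds .5 away from zero
        println(round(0.125, 2))   // 0.13
        println(round(31.415, -1)) // 30.0 -- negative scale rounds the tens digit
      }
    }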
+ private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] + + override def eval(input: InternalRow): Any = { + if (scaleV == null) { // if scale is null, no need to eval its child at all + null + } else { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + nullSafeEval(evalE) + } + } + } + + // not overriding since _scale is a constant int at runtime + def nullSafeEval(input1: Any): Any = { + child.dataType match { + case _: DecimalType => + val decimal = input1.asInstanceOf[Decimal] + if (decimal.changePrecision(decimal.precision, _scale)) decimal else null + case ByteType => + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte + case ShortType => + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort + case IntegerType => + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt + case LongType => + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong + case FloatType => + val f = input1.asInstanceOf[Float] + if (f.isNaN || f.isInfinite) { + f + } else { + BigDecimal(f).setScale(_scale, HALF_UP).toFloat + } + case DoubleType => + val d = input1.asInstanceOf[Double] + if (d.isNaN || d.isInfinite) { + d + } else { + BigDecimal(d).setScale(_scale, HALF_UP).toDouble + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val ce = child.gen(ctx) + + val evaluationCode = child.dataType match { + case _: DecimalType => + s""" + if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) { + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.isNull} = true; + }""" + case ByteType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case ShortType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case IntegerType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case LongType => + if (_scale < 0) { + s""" + ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" + } else { + s"${ev.primitive} = ${ce.primitive};" + } + case FloatType => // if child eval to NaN or Infinity, just return it. + if (_scale == 0) { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + }""" + } + case DoubleType => // if child eval to NaN or Infinity, just return it. 
+ if (_scale == 0) { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + }""" + } else { + s""" + if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ + ${ev.primitive} = ${ce.primitive}; + } else { + ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}). + setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + }""" + } + } + + if (scaleV == null) { // if scale is null, no need to eval its child at all + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + s""" + ${ce.code} + boolean ${ev.isNull} = ${ce.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluationCode + } + """ + } + } + + override def prettyName: String = "round" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 3b59cd431b871..a269ec4a1e6dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ -case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -55,7 +55,7 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes * the hash length is not one of the permitted values, the return value is NULL. 
*/ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ExpectsInputTypes { + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -118,7 +118,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -138,7 +138,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ -case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f74fd04619714..aa6c30e2f79f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -33,12 +33,17 @@ object InterpretedPredicate { } } + +/** + * An [[Expression]] that returns a boolean value. + */ trait Predicate extends Expression { self: Product => override def dataType: DataType = BooleanType } + trait PredicateHelper { protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = { condition match { @@ -70,7 +75,10 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { + +case class Not(child: Expression) + extends UnaryExpression with Predicate with ImplicitCastInputTypes { + override def toString: String = s"NOT $child" override def inputTypes: Seq[DataType] = Seq(BooleanType) @@ -82,6 +90,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } + /** * Evaluates to `true` if `list` contains `value`. */ @@ -97,6 +106,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } + /** * Optimized version of In clause, when all filter values of In clause are * static. 
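The predicate rewrites below (And/Or becoming BinaryOperator, the comparisons declaring an `inputType`) mean their type errors now come from the shared two-step check added in Expression.scala: both children must have one common type, and that type must be accepted by the operator (BooleanType for `&&`/`||`, the ordered-type collection for `<`, `<=`, and friends). A simplified standalone sketch of that check; the toy types and the `check`/`accepts` names are illustrative, not Catalyst's, and the messages only approximate the real ones.

    object OperatorTypeCheckSketch {
      sealed trait DType
      case object BooleanT extends DType
      case object IntT extends DType
      case object StringT extends DType

      // Step 1: both children share one type. Step 2: the operator accepts that type.
      def check(op: String, left: DType, right: DType, accepts: DType => Boolean): Either[String, Unit] =
        if (left != right)
          Left(s"differing types in '$op' ($left and $right).")
        else if (!accepts(left))
          Left(s"'$op' does not accept $left")
        else
          Right(())

      def main(args: Array[String]): Unit = {
        val booleanOnly: DType => Boolean = _ == BooleanT         // And / Or
        val ordered: DType => Boolean = Set[DType](IntT, StringT) // comparisons (toy "Ordered" set)
        println(check("(a && b)", BooleanT, BooleanT, booleanOnly)) // Right(())
        println(check("(a && b)", BooleanT, IntT, booleanOnly))     // Left(differing types ...)
        println(check("(a < b)", IntT, IntT, ordered))              // Right(())
      }
    }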
@@ -112,12 +122,12 @@ case class InSet(child: Expression, hset: Set[Any]) } } -case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { - override def toString: String = s"($left && $right)" +case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { + + override def inputType: AbstractDataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def symbol: String = "&&" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -161,12 +171,12 @@ case class And(left: Expression, right: Expression) } } -case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { - override def toString: String = s"($left || $right)" +case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { - override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputType: AbstractDataType = BooleanType + + override def symbol: String = "||" override def eval(input: InternalRow): Any = { val input1 = left.eval(input) @@ -210,21 +220,10 @@ case class Or(left: Expression, right: Expression) } } + abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => - override def checkInputDataTypes(): TypeCheckResult = { - if (left.dataType != right.dataType) { - TypeCheckResult.TypeCheckFailure( - s"differing types in ${this.getClass.getSimpleName} " + - s"(${left.dataType} and ${right.dataType}).") - } else { - checkTypesInternal(dataType) - } - } - - protected def checkTypesInternal(t: DataType): TypeCheckResult - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isPrimitiveType(left.dataType)) { // faster version @@ -235,10 +234,12 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } } + private[sql] object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } + /** An extractor that matches both standard 3VL equality and null-safe equality. 
*/ private[sql] object Equality { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { @@ -248,10 +249,12 @@ private[sql] object Equality { } } + case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + override def inputType: AbstractDataType = AnyDataType + + override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { if (left.dataType != BinaryType) input1 == input2 @@ -263,13 +266,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } + case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { + + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "<=>" override def nullable: Boolean = false - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - override def eval(input: InternalRow): Any = { val input1 = left.eval(input) val input2 = right.eval(input) @@ -298,44 +303,48 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } + case class LessThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } + case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "<=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = "<=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } + case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">" private lazy val ordering = TypeUtils.getOrdering(left.dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } + case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = ">=" - override protected def checkTypesInternal(t: DataType) = - TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) + override def inputType: AbstractDataType = TypeCollection.Ordered + + override def symbol: String = ">=" private lazy val ordering = TypeUtils.getOrdering(left.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 6cdc3000382e2..e10ba55396664 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -38,11 +39,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize it. */ - @transient protected lazy val partitionId = TaskContext.get() match { - case null => 0 - case _ => TaskContext.get().partitionId() - } - @transient protected lazy val rng = new XORShiftRandom(seed + partitionId) + @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) override def deterministic: Boolean = false @@ -61,6 +58,17 @@ case class Rand(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); + """ + } } /** Generate a random column with i.i.d. gaussian random distribution. 
*/ @@ -73,4 +81,15 @@ case class Randn(seed: Long) extends RDG(seed) { case IntegerLiteral(s) => s case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rngTerm = ctx.freshName("rng") + val className = classOf[XORShiftRandom].getCanonicalName + ctx.addMutableState(className, rngTerm, + s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index f64899c1ed84c..c64afe7b3f19a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import java.text.DecimalFormat import java.util.Locale import java.util.regex.Pattern -import org.apache.commons.lang3.StringUtils - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -29,7 +28,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ExpectsInputTypes { +trait StringRegexExpression extends ImplicitCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -105,7 +104,7 @@ case class RLike(left: Expression, right: Expression) override def toString: String = s"$left RLIKE $right" } -trait String2StringExpression extends ExpectsInputTypes { +trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -142,7 +141,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends ImplicitCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -241,7 +240,7 @@ case class StringTrimRight(child: Expression) * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = substr @@ -265,7 +264,7 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) @@ -306,7 +305,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. 
*/ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -344,7 +343,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. */ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -413,7 +412,7 @@ case class StringFormat(children: Expression*) extends Expression { * Returns the string which repeat the given string value n times. */ case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = times @@ -447,7 +446,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ -case class StringSpace(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -467,7 +466,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ExpectsIn * Splits str around pat (pattern is a regular expression). */ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = pattern @@ -488,7 +487,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } /** - * A function that return the length of the given string expression. + * A function that return the length of the given string or binary expression. 
*/ -case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType - override def inputTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) - protected override def nullSafeEval(string: Any): Any = - string.asInstanceOf[UTF8String].numChars + protected override def nullSafeEval(value: Any): Any = child.dataType match { + case StringType => value.asInstanceOf[UTF8String].numChars + case BinaryType => value.asInstanceOf[Array[Byte]].length + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"($c).numChars()") + child.dataType match { + case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") + case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") + } } override def prettyName: String = "length" @@ -573,7 +577,7 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI * A function that return the Levenshtein distance between the two given strings. */ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ExpectsInputTypes { + with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -591,7 +595,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * Returns the numeric value of the first character of str. */ -case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -608,7 +612,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTyp /** * Converts the argument from binary to a base 64 string. */ -case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -622,7 +626,7 @@ case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTy /** * Converts the argument from a base 64 string to BINARY. */ -case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -636,7 +640,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput * If either argument is null, the result will also be null. */ case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = bin override def right: Expression = charset @@ -655,7 +659,7 @@ case class Decode(bin: Expression, charset: Expression) * If either argument is null, the result will also be null. 
*/ case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = value override def right: Expression = charset @@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression) } } +/** + * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, + * and returns the result as a string. If D is 0, the result has no decimal point or + * fractional part. + */ +case class FormatNumber(x: Expression, d: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = x + override def right: Expression = d + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + + // The last d value used to build the pattern; the pattern (DecimalFormat) is + // rebuilt whenever an incoming d value differs from the previous one. + @transient + private var lastDValue: Int = -100 + + // A cached DecimalFormat, kept for performance; it is updated + // only when the d value changes. + @transient + private val pattern: StringBuffer = new StringBuffer() + + @transient + private val numberFormat: DecimalFormat = new DecimalFormat("") + + override def eval(input: InternalRow): Any = { + val xObject = x.eval(input) + if (xObject == null) { + return null + } + + val dObject = d.eval(input) + + if (dObject == null || dObject.asInstanceOf[Int] < 0) { + return null + } + val dValue = dObject.asInstanceOf[Int] + + if (dValue != lastDValue) { + // construct a new DecimalFormat only when a new dValue arrives + pattern.delete(0, pattern.length()) + pattern.append("#,###,###,###,###,###,##0") + + // decimal places + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + val dFormat = new DecimalFormat(pattern.toString()) + lastDValue = dValue; + numberFormat.applyPattern(dFormat.toPattern()) + } + + x.dataType match { + case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte])) + case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short])) + case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float])) + case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int])) + case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long])) + case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double])) + case _: DecimalType => + UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal)) + } + } + + override def prettyName: String = "format_number" +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e911b907e8536..d7077a0ec907a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -291,6 +291,11 @@ abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] { /** * A logical plan node with a left and right child.
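To make the pattern-building in the FormatNumber expression above easier to follow, here is a minimal standalone sketch using plain java.text.DecimalFormat (illustrative only, not part of the patch):
import java.text.DecimalFormat
// "#,###,###,###,###,###,##0" gives comma grouping; appending "." plus d zeros fixes d decimal places.
def formatNumberSketch(x: Double, d: Int): String = {
  val pattern = new StringBuilder("#,###,###,###,###,###,##0")
  if (d > 0) pattern.append(".").append("0" * d)
  new DecimalFormat(pattern.toString).format(x)
}
// formatNumberSketch(12831273.23481, 3) == "12,831,273.235"
// formatNumberSketch(12831273.83421, 0) == "12,831,274"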
*/ -abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] { +abstract class BinaryNode extends LogicalPlan { self: Product => + + def left: LogicalPlan + def right: LogicalPlan + + override def children: Seq[LogicalPlan] = Seq(left, right) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 09f6c6b0ec423..16844b2f4b680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -453,15 +453,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } } -/** - * A [[TreeNode]] that has two children, [[left]] and [[right]]. - */ -trait BinaryNode[BaseType <: TreeNode[BaseType]] { - def left: BaseType - def right: BaseType - - def children: Seq[BaseType] = Seq(left, right) -} /** * A [[TreeNode]] with no children. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 3148309a2166f..0103ddcf9cfb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -32,14 +32,6 @@ object TypeUtils { } } - def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { - if (t.isInstanceOf[IntegralType] || t == NullType) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") - } - } - def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[AtomicType] || t == NullType) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 32f87440b4e37..076d7b5a5118d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType { private[sql] def defaultConcreteType: DataType /** - * Returns true if this data type is the same type as `other`. This is different that equality - * as equality will also consider data type parametrization, such as decimal precision. + * Returns true if `other` is an acceptable input type for a function that expects this, + * possibly abstract DataType. * * {{{ * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) - * - * // this should return false - * NumericType.isSameType(DecimalType(10, 2)) - * }}} - */ - private[sql] def isSameType(other: DataType): Boolean - - /** - * Returns true if `other` is an acceptable input type for a function that expectes this, - * possibly abstract, DataType. - * - * {{{ - * // this should return true - * DecimalType.isSameType(DecimalType(10, 2)) + * DecimalType.acceptsType(DecimalType(10, 2)) * * // this should return true as well * NumericType.acceptsType(DecimalType(10, 2)) * }}} */ - private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) + private[sql] def acceptsType(other: DataType): Boolean /** Readable string representation for the type. 
*/ private[sql] def simpleString: String @@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = - types.exists(_.isSameType(other)) + types.exists(_.acceptsType(other)) override private[sql] def simpleString: String = { types.map(_.simpleString).mkString("(", " or ", ")") @@ -96,6 +80,17 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) private[sql] object TypeCollection { + /** + * Types that can be ordered/compared. In the long run we should probably make this a trait + * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. + */ + val Ordered = TypeCollection( + BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType, + TimestampType, DateType, + StringType, BinaryType) + def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { @@ -105,6 +100,21 @@ private[sql] object TypeCollection { } +/** + * An [[AbstractDataType]] that matches any concrete data types. + */ +protected[sql] object AnyDataType extends AbstractDataType { + + // Note that since AnyDataType matches any concrete types, defaultConcreteType should never + // be invoked. + override private[sql] def defaultConcreteType: DataType = throw new UnsupportedOperationException + + override private[sql] def simpleString: String = "any" + + override private[sql] def acceptsType(other: DataType): Boolean = true +} + + /** * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. 
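The intended contract of the reworked acceptsType above can be summarised with a few illustrative checks (a sketch only; these methods are private[sql], so this is not callable from user code):
assert(DecimalType.acceptsType(DecimalType(10, 2)))                     // abstract DecimalType accepts any decimal
assert(NumericType.acceptsType(DecimalType(10, 2)))                     // decimals are numeric
assert(!IntegerType.acceptsType(LongType))                              // concrete types accept only an equal type
assert(TypeCollection(StringType, BinaryType).acceptsType(BinaryType))  // collections accept any of their members
assert(AnyDataType.acceptsType(MapType(StringType, IntegerType)))       // AnyDataType accepts every concrete type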
*/ @@ -148,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType { override private[sql] def simpleString: String = "numeric" - override private[sql] def isSameType(other: DataType): Boolean = false - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] } -private[sql] object IntegralType { +private[sql] object IntegralType extends AbstractDataType { /** * Enables matching against IntegralType for expressions: * {{{ @@ -163,6 +171,12 @@ private[sql] object IntegralType { * }}} */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] + + override private[sql] def defaultConcreteType: DataType = IntegerType + + override private[sql] def simpleString: String = "integral" + + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 76ca7a84c1d1a..5094058164b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[ArrayType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index da83a7f0ba379..2d133eea19fe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = this - override private[sql] def isSameType(other: DataType): Boolean = this == other + override private[sql] def acceptsType(other: DataType): Boolean = this == other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index f5bd068d60dc4..a85af9e04aedb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.math.{MathContext, RoundingMode} - import org.apache.spark.annotation.DeveloperApi /** @@ -138,14 +136,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { } def toBigDecimal: BigDecimal = { - if (decimalVal.ne(null)) { - decimalVal(MathContext.UNLIMITED) - } else { - BigDecimal(longVal, _scale)(MathContext.UNLIMITED) - } - } - - def toLimitedBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { decimalVal } else { @@ -273,15 +263,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) - def / (that: Decimal): Decimal = { - if (that.isZero) { - null - } else { - // To avoid non-terminating decimal expansion problem, we get scala's BigDecimal with limited - // precision and scala. 
- Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal) - } - } + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) def % (that: Decimal): Decimal = if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index a1cafeab1704d..377c75f6e85a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = Unlimited - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[DecimalType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ddead10bc2171..ac34b642827ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -71,7 +71,7 @@ object MapType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType) - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[MapType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b8097403ec3cc..2ef97a427c37e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -307,7 +307,7 @@ object StructType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = new StructType - override private[sql] def isSameType(other: DataType): Boolean = { + override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[StructType] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 9d0c69a2451d1..f0f17103991ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ case class TestFunction( children: Seq[Expression], - inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes { + inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def dataType: DataType = StringType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8e0551b23eea6..ed0d20e7de80e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{TypeCollection, StringType} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -49,13 +49,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { def assertErrorForDifferingTypes(expr: Expression): Unit = { assertError(expr, - s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + s"differing types in '${expr.prettyString}'") } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertError(Abs('stringField), "function abs accepts numeric type") - assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + assertError(UnaryMinus('stringField), "expected to be of type numeric") + assertError(Abs('stringField), "expected to be of type numeric") + assertError(BitwiseNot('stringField), "expected to be of type integral") } test("check types for binary arithmetic") { @@ -78,18 +78,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") - assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") - assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") - assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") - assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + assertError(Add('booleanField, 'booleanField), "accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") - assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") - assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") - assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type") - assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") - assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + assertError(MaxOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(MinOf('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") } test("check types for predicates") { @@ -105,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertError(EqualTo('intField, 'complexField), 
"differing types") - assertError(EqualNullSafe('intField, 'complexField), "differing types") - + assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError( - LessThan('complexField, 'complexField), "operator < accepts non-complex type") - assertError( - LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") - assertError( - GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") - assertError( - GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + assertError(LessThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(LessThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThan('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") + assertError(GreaterThanOrEqual('complexField, 'complexField), + s"accepts ${TypeCollection.Ordered.simpleString} type") - assertError( - If('intField, 'stringField, 'stringField), + assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) @@ -171,4 +171,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), "Odd position only allow foldable and not-null StringType expressions") } + + test("check types for ROUND") { + assertSuccess(Round(Literal(null), Literal(null))) + assertSuccess(Round('intField, Literal(1))) + + assertError(Round('intField, 'intField), "Only foldable Expression is allowed") + assertError(Round('intField, 'booleanField), "expected to be of type int") + assertError(Round('intField, 'complexField), "expected to be of type int") + assertError(Round('booleanField, 'intField), "expected to be of type numeric") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index acb9a433de903..d0fd033b981c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -194,6 +194,30 @@ class HiveTypeCoercionSuite extends PlanTest { Project(Seq(Alias(transformed, "a")()), testRelation)) } + test("cast NullType for expresions that implement ExpectsInputTypes") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeUnaryExpression(Literal.create(null, NullType)), + AnyTypeUnaryExpression(Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeUnaryExpression(Literal.create(null, NullType)), + NumericTypeUnaryExpression(Literal.create(null, DoubleType))) + } + + test("cast NullType for binary operators") { + import HiveTypeCoercionSuite._ + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + 
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) + + ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), + NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) + } + test("coalesce casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) @@ -302,3 +326,33 @@ class HiveTypeCoercionSuite extends PlanTest { ) } } + + +object HiveTypeCoercionSuite { + + case class AnyTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def dataType: DataType = NullType + } + + case class NumericTypeUnaryExpression(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + override def dataType: DataType = NullType + } + + case class AnyTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = AnyDataType + override def symbol: String = "anytype" + } + + case class NumericTypeBinaryOperator(left: Expression, right: Expression) + extends BinaryOperator { + override def dataType: DataType = NullType + override def inputType: AbstractDataType = NumericType + override def symbol: String = "numerictype" + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6c93698f8017b..e7e5231d32c9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.Decimal - class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { /** @@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), Array(1.toByte, 2.toByte)) } + + test("pmod") { + testNumericDataTypes { convert => + val left = Literal(convert(7)) + val right = Literal(convert(3)) + checkEvaluation(Pmod(left, right), convert(1)) + checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null) + checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + checkEvaluation(Pmod(-7, 3), 2) + checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) + checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) + checkEvaluation(Pmod(2L, Long.MaxValue), 2) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 7ca9e30b2bcd5..52a874a9d89ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions 
+import scala.math.BigDecimal.RoundingMode + import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite @@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { null, create_row(null)) } + + test("round") { + val domain = -6 to 6 + val doublePi: Double = math.Pi + val shortPi: Short = 31415 + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + + domain.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + } + + // round_scale > current_scale would result in precision increase + // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + } + (8 to 10).foreach { scale => + checkEvaluation(Round(bdPi, scale), null, EmptyRow) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index b19f4ee37a109..5d7763bedf6bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} +import org.apache.spark.sql.types._ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - test("length for string") { - val a = 'a.string.at(0) - checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) - checkEvaluation(StringLength(a), 5, create_row("abdef")) - checkEvaluation(StringLength(a), 0, create_row("")) - checkEvaluation(StringLength(a), null, create_row(null)) - checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) - } - test("ascii for string") { val a = 'a.string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) @@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) } + + test("length for string / binary") { + val a = 'a.string.at(0) + val b = 
'b.binary.at(0) + val bytes = Array[Byte](1, 2, 3, 1, 2) + val string = "abdef" + + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(Length(Literal("a花花c")), 4, create_row(string)) + // scalastyle:on + checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]())) + + checkEvaluation(Length(a), 5, create_row(string)) + checkEvaluation(Length(b), 5, create_row(bytes)) + + checkEvaluation(Length(a), 0, create_row("")) + checkEvaluation(Length(b), 0, create_row(Array[Byte]())) + + checkEvaluation(Length(a), null, create_row(null)) + checkEvaluation(Length(b), null, create_row(null)) + + checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string)) + checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) + } + + test("number format") { + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000") + checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null) + checkEvaluation( + FormatNumber( + Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), + "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) + checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index 030bb6d21b18b..1d297beb3868d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester import scala.language.postfixOps class DecimalSuite extends SparkFunSuite with PrivateMethodTester { - test("creating decimals") { - /** Check that a Decimal has the given string representation, precision and scale */ - def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { - assert(d.toString === string) - assert(d.precision === precision) - assert(d.scale === scale) - } + /** Check that a Decimal has the given string representation, precision and scale */ + private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) @@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) } + test("creating decimals with negative scale") { + checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3) + checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2) + checkDecimal(Decimal(BigDecimal(1.579e12), 
4, -9), "1.579E+12", 4, -9) + checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10) + checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10) + checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10) + } + test("double and long values") { /** Check that a Decimal converts to the given double and long values */ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { @@ -162,22 +171,4 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } - - test("accurate precision after multiplication") { - val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal - assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") - } - - test("fix non-terminating decimal expansion problem") { - val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) - // The difference between decimal should not be more than 0.001. - assert(decimal.toDouble - 0.333 < 0.001) - } - - test("fix loss of precision/scale when doing division operation") { - val a = Decimal(2) / Decimal(3) - assert(a.toDouble < 1.0 && a.toDouble > 0.6) - val b = Decimal(1) / Decimal(8) - assert(b.toDouble === 0.125) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 43b62f0e822f8..92861ab038f19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -47,6 +47,7 @@ private[r] object SQLUtils { dataType match { case "byte" => org.apache.spark.sql.types.ByteType case "integer" => org.apache.spark.sql.types.IntegerType + case "float" => org.apache.spark.sql.types.FloatType case "double" => org.apache.spark.sql.types.DoubleType case "numeric" => org.apache.spark.sql.types.DoubleType case "character" => org.apache.spark.sql.types.StringType @@ -68,7 +69,7 @@ private[r] object SQLUtils { def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { val num = schema.fields.size - val rowRDD = rdd.map(bytesToRow) + val rowRDD = rdd.map(bytesToRow(_, schema)) sqlContext.createDataFrame(rowRDD, schema) } @@ -76,12 +77,20 @@ private[r] object SQLUtils { df.map(r => rowToRBytes(r)) } - private[this] def bytesToRow(bytes: Array[Byte]): Row = { + private[this] def doConversion(data: Object, dataType: DataType): Object = { + data match { + case d: java.lang.Double if dataType == FloatType => + new java.lang.Float(d) + case _ => data + } + } + + private[this] def bytesToRow(bytes: Array[Byte], schema: StructType): Row = { val bis = new ByteArrayInputStream(bytes) val dis = new DataInputStream(bis) val num = SerDe.readInt(dis) Row.fromSeq((0 until num).map { i => - SerDe.readObject(dis) + doConversion(SerDe.readObject(dis), schema.fields(i).dataType) }.toSeq) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4d7d8626a0ecc..9dc7879fa4a1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -247,6 +247,11 @@ private[sql] trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { override def outputPartitioning: Partitioning = 
child.outputPartitioning } -private[sql] trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { +private[sql] trait BinaryNode extends SparkPlan { self: Product => + + def left: SparkPlan + def right: SparkPlan + + override def children: Seq[SparkPlan] = Seq(left, right) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 437d143e53f3f..fec403fe2d348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} /** @@ -40,6 +41,10 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L + @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + + override def deterministic: Boolean = false + override def nullable: Boolean = false override def dataType: DataType = LongType @@ -47,6 +52,20 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def eval(input: InternalRow): Long = { val currentCount = count count += 1 - (TaskContext.get().partitionId().toLong << 33) + currentCount + partitionMask + currentCount + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val countTerm = ctx.freshName("count") + val partitionMaskTerm = ctx.freshName("partitionMask") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, "0L") + ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, + "((long) org.apache.spark.TaskContext.getPartitionId()) << 33") + + ev.isNull = "false" + s""" + final ${ctx.javaType(dataType)} ${ev.primitive} = $partitionMaskTerm + $countTerm; + $countTerm++; + """ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 822d3d8c9108d..7c790c549a5d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -28,9 +29,20 @@ import org.apache.spark.sql.types.{IntegerType, DataType} */ private[sql] case object SparkPartitionID extends LeafExpression { + override def deterministic: Boolean = false + override def nullable: Boolean = false override def dataType: DataType = IntegerType - override def eval(input: InternalRow): Int = TaskContext.get().partitionId() + @transient private lazy val partitionId = TaskContext.getPartitionId() + + override def eval(input: 
InternalRow): Int = partitionId + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val idTerm = ctx.freshName("partitionId") + ctx.addMutableState(ctx.JAVA_INT, idTerm, "org.apache.spark.TaskContext.getPartitionId()") + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0d4e160ed8057..d6da284a4c788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1371,6 +1371,23 @@ object functions { */ def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + + /** + * Returns the positive value of dividend mod divisor. + * + * @group math_funcs + * @since 1.5.0 + */ + def pmod(dividendColName: String, divisorColName: String): Column = + pmod(Column(dividendColName), Column(divisorColName)) + /** * Returns the double value that is closest in value to the argument and * is equal to a mathematical integer. @@ -1389,6 +1406,38 @@ object functions { */ def rint(columnName: String): Column = rint(Column(columnName)) + /** + * Returns the value of the column `e` rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column): Column = round(e.expr, 0) + + /** + * Returns the value of the given column rounded to 0 decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String): Column = round(Column(columnName), 0) + + /** + * Returns the value of `e` rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + + /** + * Returns the value of the given column rounded to `scale` decimal places. + * + * @group math_funcs + * @since 1.5.0 + */ + def round(columnName: String, scale: Int): Column = round(Column(columnName), scale) + /** * Shift the the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. @@ -1636,20 +1685,44 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Computes the length of a given string value. + * Computes the length of a given string / binary value. + * + * @group string_funcs + * @since 1.5.0 + */ + def length(e: Column): Column = Length(e.expr) + + /** + * Computes the length of a given string / binary column. + * + * @group string_funcs + * @since 1.5.0 + */ + def length(columnName: String): Column = length(Column(columnName)) + + /** + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. + * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. * * @group string_funcs * @since 1.5.0 */ - def strlen(e: Column): Column = StringLength(e.expr) + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) /** - * Computes the length of a given string column. + * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, + * and returns the result as a string. 
+ * If d is 0, the result has no decimal point or fractional part. + * If d < 0, the result will be null. * * @group string_funcs * @since 1.5.0 */ - def strlen(columnName: String): Column = strlen(Column(columnName)) + def format_number(columnXName: String, d: Int): Column = { + format_number(Column(columnXName), d) + } /** * Computes the Levenshtein distance of the two given strings. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6cebec95d2850..6dccdd857b453 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest { Row(2743272264L, 2180413220L)) } - test("string length function") { - val df = Seq(("abc", "")).toDF("a", "b") - checkAnswer( - df.select(strlen($"a"), strlen("b")), - Row(3, 0)) - - checkAnswer( - df.selectExpr("length(a)", "length(b)"), - Row(3, 0)) - } - test("Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) @@ -403,4 +392,121 @@ class DataFrameFunctionsSuite extends QueryTest { Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } + + test("pmod") { + val intData = Seq((7, 3), (-7, 3)).toDF("a", "b") + checkAnswer( + intData.select(pmod('a, 'b)), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod('a, lit(3))), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.select(pmod(lit(-7), 'b)), + Seq(Row(2), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, b)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(a, 3)"), + Seq(Row(1), Row(2)) + ) + checkAnswer( + intData.selectExpr("pmod(-7, b)"), + Seq(Row(2), Row(2)) + ) + val doubleData = Seq((7.2, 4.1)).toDF("a", "b") + checkAnswer( + doubleData.select(pmod('a, 'b)), + Seq(Row(3.1000000000000005)) // same as hive + ) + checkAnswer( + doubleData.select(pmod(lit(2), lit(Int.MaxValue))), + Seq(Row(2)) + ) + } + + test("string / binary length function") { + val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") + checkAnswer( + df.select(length($"a"), length("a"), length($"b"), length("b")), + Row(3, 3, 4, 4)) + + checkAnswer( + df.selectExpr("length(a)", "length(b)"), + Row(3, 4)) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("length(c)"), // int type of the argument is unacceptable + Row("5.0000")) + } + } + + test("number format function") { + val tuple = + ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], + 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) + val df = + Seq(tuple) + .toDF( + "a", // string "aa" + "b", // byte 1 + "c", // short 2 + "d", // float 3.13223f + "e", // integer 4 + "f", // long 5L + "g", // double 6.48173d + "h") // decimal 7.128381 + + checkAnswer( + df.select( + format_number($"f", 4), + format_number("f", 4)), + Row("5.0000", "5.0000")) + + checkAnswer( + df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + Row("1.0000")) + + checkAnswer( + df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + Row("2.0000")) + + checkAnswer( + df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + Row("3.1322")) + + checkAnswer( + df.selectExpr("format_number(e, e)"), // not convert anything + Row("4.0000")) + + checkAnswer( + df.selectExpr("format_number(f, 
e)"), // not convert anything + Row("5.0000")) + + checkAnswer( + df.selectExpr("format_number(g, e)"), // not convert anything + Row("6.4817")) + + checkAnswer( + df.selectExpr("format_number(h, e)"), // not convert anything + Row("7.1284")) + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable + Row("5.0000")) + } + + intercept[AnalysisException] { + checkAnswer( + df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable + Row("5.0000")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 24bef21b999ea..087126bb2e513 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(rint, math.rint) } + test("round") { + val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") + checkAnswer( + df.select(round('a), round('a, -1), round('a, -2)), + Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) + + val pi = 3.1415 + checkAnswer( + ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), + Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } @@ -375,6 +390,5 @@ class MathExpressionsSuite extends QueryTest { val df = Seq((1, -1, "abc")).toDF("a", "b", "c") checkAnswer(df.selectExpr("positive(a)"), Row(1)) checkAnswer(df.selectExpr("positive(b)"), Row(-1)) - checkAnswer(df.selectExpr("positive(c)"), Row("abc")) } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index c884c399281a8..4ada64bc21966 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_when", "udf_case", - // Needs constant object inspectors - "udf_round", - // the table src(key INT, value STRING) is not the same as HIVE unittest. In Hive // is src(key STRING, value STRING), and in the reflect.q, it failed in // Integer.valueOf, which expect the first argument passed as STRING type not INT. 
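
Editor's note — a minimal usage sketch of the new DataFrame entry points added to functions.scala above (pmod, round, length, format_number). It assumes a Spark 1.5-era SQLContext named `sqlContext` with its implicits imported; the sample data and column names are illustrative only, and the expected values in the comments are the ones exercised by the test suites in this patch.

    import org.apache.spark.sql.functions.{format_number, length, pmod, round}
    import sqlContext.implicits._

    val df = Seq((-7, 3, 3.1415, "123", 5L)).toDF("a", "b", "pi", "s", "f")
    df.select(
      pmod($"a", $"b"),        // 2  -- always non-negative, whereas -7 % 3 == -1
      round($"pi", 2),         // 3.14
      round($"pi"),            // 3.0 -- single-argument overload rounds to 0 decimal places
      length($"s"),            // 3  -- replaces the removed strlen()
      format_number($"f", 4)   // "5.0000", returned as a formatted string
    ).show()
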
@@ -918,8 +915,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_regexp_replace", "udf_repeat", "udf_rlike", - "udf_round", - // "udf_round_3", TODO: FIX THIS failed due to cast exception + // "udf_round", turn this on after we figure out null vs nan vs infinity + "udf_round_3", "udf_rpad", "udf_rtrim", "udf_second", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5bdf68c83fca7..4b7a782c805a0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -301,9 +301,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) - // We're converting the entire table into ParquetRelation, so predicates to Hive metastore - // are empty. - val partitions = metastoreRelation.getHiveQlPartitions().map { p => + val partitions = metastoreRelation.hiveQlPartitions.map { p => val location = p.getLocation val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) @@ -646,6 +644,32 @@ private[hive] case class MetastoreRelation new Table(tTable) } + @transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p => + val tPartition = new org.apache.hadoop.hive.metastore.api.Partition + tPartition.setDbName(databaseName) + tPartition.setTableName(tableName) + tPartition.setValues(p.values) + + val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() + tPartition.setSd(sd) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + + sd.setLocation(p.storage.location) + sd.setInputFormat(p.storage.inputFormat) + sd.setOutputFormat(p.storage.outputFormat) + + val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo + sd.setSerdeInfo(serdeInfo) + serdeInfo.setSerializationLib(p.storage.serde) + + val serdeParameters = new java.util.HashMap[String, String]() + serdeInfo.setParameters(serdeParameters) + table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } + + new Partition(hiveQlTable, tPartition) + } + @transient override lazy val statistics: Statistics = Statistics( sizeInBytes = { val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE) @@ -666,34 +690,6 @@ private[hive] case class MetastoreRelation } ) - def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - table.getPartitions(predicates).map { p => - val tPartition = new org.apache.hadoop.hive.metastore.api.Partition - tPartition.setDbName(databaseName) - tPartition.setTableName(tableName) - tPartition.setValues(p.values) - - val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() - tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) - - sd.setLocation(p.storage.location) - sd.setInputFormat(p.storage.inputFormat) - sd.setOutputFormat(p.storage.outputFormat) - - val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo - sd.setSerdeInfo(serdeInfo) - 
serdeInfo.setSerializationLib(p.storage.serde) - - val serdeParameters = new java.util.HashMap[String, String]() - serdeInfo.setParameters(serdeParameters) - table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) } - - new Partition(hiveQlTable, tPartition) - } - } - /** Only compare database and tablename, not alias. */ override def sameResult(plan: LogicalPlan): Boolean = { plan match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index a357bb39ca7fd..d08c594151654 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -27,7 +27,6 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9638a8201e190..ed359620a5f7f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -125,7 +125,7 @@ private[hive] trait HiveStrategies { InterpretedPredicate.create(castedPredicate) } - val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part => + val partitions = relation.hiveQlPartitions.filter { part => val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { @@ -213,7 +213,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 1656587d14835..0a1d761a52f88 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -21,7 +21,6 @@ import java.io.PrintStream import java.util.{Map => JMap} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException} -import org.apache.spark.sql.catalyst.expressions.Expression private[hive] case class HiveDatabase( name: String, @@ -72,12 +71,7 @@ private[hive] case class HiveTable( def isPartitioned: Boolean = partitionColumns.nonEmpty - def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = { - predicates match { - case Nil => client.getAllPartitions(this) - case _ => client.getPartitionsByFilter(this, predicates) - } - } + def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) // Hive does not support backticks when passing names to the client. def qualifiedName: String = s"$database.$name" @@ -138,9 +132,6 @@ private[hive] trait ClientInterface { /** Returns all partitions for the given table. */ def getAllPartitions(hTable: HiveTable): Seq[HivePartition] - /** Returns partitions filtered by predicates for the given table. 
*/ - def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition] - /** Loads a static partition into an existing table. */ def loadPartition( loadPath: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 8adda54754230..53f457ad4f3cc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -17,21 +17,25 @@ package org.apache.spark.sql.hive.client -import java.io.{File, PrintStream} -import java.util.{Map => JMap} +import java.io.{BufferedReader, InputStreamReader, File, PrintStream} +import java.net.URI +import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConversions._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.metastore.api.Database import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} import org.apache.hadoop.hive.metastore.{TableType => HTableType} +import org.apache.hadoop.hive.metastore.api +import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.hadoop.hive.ql.metadata import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Driver, metadata} +import org.apache.hadoop.hive.ql.processors._ +import org.apache.hadoop.hive.ql.Driver import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.Expression @@ -312,13 +316,6 @@ private[hive] class ClientWrapper( shim.getAllPartitions(client, qlTable).map(toHivePartition) } - override def getPartitionsByFilter( - hTable: HiveTable, - predicates: Seq[Expression]): Seq[HivePartition] = withHiveState { - val qlTable = toQlTable(hTable) - shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition) - } - override def listTables(dbName: String): Seq[String] = withHiveState { client.getAllTables(dbName) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index d12778c7583df..1fa9d278e2a57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -31,11 +31,6 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.serde.serdeConstants - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegralType} /** * A shim that defines the interface between ClientWrapper and the underlying Hive library used to @@ -66,8 +61,6 @@ private[client] sealed abstract class Shim { def getAllPartitions(hive: Hive, table: Table): Seq[Partition] - def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition] - def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor def getDriverResults(driver: Driver): Seq[String] @@ -116,7 +109,7 
@@ private[client] sealed abstract class Shim { } -private[client] class Shim_v0_12 extends Shim with Logging { +private[client] class Shim_v0_12 extends Shim { private lazy val startMethod = findStaticMethod( @@ -203,17 +196,6 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq - override def getPartitionsByFilter( - hive: Hive, - table: Table, - predicates: Seq[Expression]): Seq[Partition] = { - // getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12. - // See HIVE-4888. - logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " + - "Please use Hive 0.13 or higher.") - getAllPartitions(hive, table) - } - override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor] @@ -285,12 +267,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { classOf[Hive], "getAllPartitionsOf", classOf[Table]) - private lazy val getPartitionsByFilterMethod = - findMethod( - classOf[Hive], - "getPartitionsByFilter", - classOf[Table], - classOf[String]) private lazy val getCommandProcessorMethod = findStaticMethod( classOf[CommandProcessorFactory], @@ -312,52 +288,6 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq - /** - * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. - * a string that represents partition predicates like "str_key=\"value\" and int_key=1 ...". - * - * Unsupported predicates are skipped. - */ - def convertFilters(table: Table, filters: Seq[Expression]): String = { - // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. - val varcharKeys = table.getPartitionKeys - .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) - .map(col => col.getName).toSet - - filters.collect { - case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => - s"${a.name} ${op.symbol} $v" - case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => - s"$v ${op.symbol} ${a.name}" - - case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) - if !varcharKeys.contains(a.name) => - s"""${a.name} ${op.symbol} "$v"""" - case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) - if !varcharKeys.contains(a.name) => - s""""$v" ${op.symbol} ${a.name}""" - }.mkString(" and ") - } - - override def getPartitionsByFilter( - hive: Hive, - table: Table, - predicates: Seq[Expression]): Seq[Partition] = { - - // Hive getPartitionsByFilter() takes a string that represents partition - // predicates like "str_key=\"value\" and int_key=1 ..." 
- val filter = convertFilters(table, predicates) - val partitions = - if (filter.isEmpty) { - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]] - } else { - logDebug(s"Hive metastore filter is '$filter'.") - getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] - } - - partitions.toSeq - } - override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ba7eb15a1c0c6..d33da8242cc1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -44,7 +44,7 @@ private[hive] case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, - partitionPruningPred: Seq[Expression])( + partitionPruningPred: Option[Expression])( @transient val context: HiveContext) extends LeafNode { @@ -56,7 +56,7 @@ case class HiveTableScan( // Bind all partition key attribute references in the partition pruning predicate for later // evaluation. - private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred => + private[this] val boundPruningPred = partitionPruningPred.map { pred => require( pred.dataType == BooleanType, s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") @@ -133,8 +133,7 @@ case class HiveTableScan( protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { hadoopReader.makeRDDForTable(relation.hiveQlTable) } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index c4828c4717643..741a3cd31c603 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -61,7 +61,9 @@ public void setUp() throws IOException { @After public void tearDown() throws IOException { // Clean up tables. - hc.sql("DROP TABLE IF EXISTS window_table"); + if (hc != null) { + hc.sql("DROP TABLE IF EXISTS window_table"); + } } @Test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala deleted file mode 100644 index 0efcf80bd4ea7..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.client - -import scala.collection.JavaConversions._ - -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.serde.serdeConstants - -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * A set of tests for the filter conversion logic used when pushing partition pruning into the - * metastore - */ -class FiltersSuite extends SparkFunSuite with Logging { - private val shim = new Shim_v0_13 - - private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") - private val varCharCol = new FieldSchema() - varCharCol.setName("varchar") - varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) - testTable.setPartCols(varCharCol :: Nil) - - filterTest("string filter", - (a("stringcol", StringType) > Literal("test")) :: Nil, - "stringcol > \"test\"") - - filterTest("string filter backwards", - (Literal("test") > a("stringcol", StringType)) :: Nil, - "\"test\" > stringcol") - - filterTest("int filter", - (a("intcol", IntegerType) === Literal(1)) :: Nil, - "intcol = 1") - - filterTest("int filter backwards", - (Literal(1) === a("intcol", IntegerType)) :: Nil, - "1 = intcol") - - filterTest("int and string filter", - (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil, - "1 = intcol and \"a\" = strcol") - - filterTest("skip varchar", - (Literal("") === a("varchar", StringType)) :: Nil, - "") - - private def filterTest(name: String, filters: Seq[Expression], result: String) = { - test(name){ - val converted = shim.convertFilters(testTable, filters) - if (converted != result) { - fail( - s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") - } - } - } - - private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)() -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 3eb127e23d486..d52e162acbd04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -20,9 +20,7 @@ package org.apache.spark.sql.hive.client import java.io.File import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils /** @@ -153,12 +151,6 @@ class VersionsSuite extends SparkFunSuite with Logging { client.getAllPartitions(client.getTable("default", "src_part")) } - test(s"$version: getPartitionsByFilter") { - client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo( - AttributeReference("key", IntegerType, false)(NamedExpression.newExprId), - Literal(1)))) - } - test(s"$version: 
loadPartition") { client.loadPartition( emptyDir, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index e83a7dc77e329..de6a41ce5bfcb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) val partValues = if (relation.table.isPartitioned) { - p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) + p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) } else { Seq.empty } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java index eb7475e9df869..905ea0b7b878c 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java @@ -62,6 +62,7 @@ public static Interval fromString(String s) { if (s == null) { return null; } + s = s.trim(); Matcher m = p.matcher(s); if (!m.matches() || s.equals("interval")) { return null; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index 44a949a371f2b..1832d0bc65551 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -75,6 +75,12 @@ public void fromStringTest() { Interval result = new Interval(-5 * 12 + 23, 0); assertEquals(Interval.fromString(input), result); + input = "interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + + input = " interval -5 years 23 month "; + assertEquals(Interval.fromString(input), result); + // Error cases input = "interval 3month 1 hour"; assertEquals(Interval.fromString(input), null);
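
Editor's note on the Interval.fromString change above: the parser relies on Matcher.matches(), which only succeeds when the pattern covers the entire input, so untrimmed leading or trailing whitespace previously made fromString return null. The Scala sketch below illustrates that behaviour with a simplified stand-in regex (the real pattern in Interval.java also covers months, weeks, days, hours, minutes and so on; the pattern here is illustrative only).

    import java.util.regex.Pattern

    // Simplified stand-in for the interval pattern in Interval.java.
    val p = Pattern.compile("interval\\s+(-?\\d+)\\s+years?(\\s+(-?\\d+)\\s+months?)?")
    def parses(s: String): Boolean = p.matcher(s).matches()

    parses("interval -5 years 23 month")           // true
    parses("  interval -5 years 23 month  ")       // false: matches() must span the whole input
    parses("  interval -5 years 23 month  ".trim)  // true again, which is what the added s.trim() restores
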