diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f7a8a2e4de24..a329e14f25aeb 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -12,7 +12,8 @@ export("print.jobj") # MLlib integration exportMethods("glm", - "predict") + "predict", + "summary") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 06dd6b75dff3d..f4c93d3c7dd67 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1566,7 +1566,7 @@ setMethod("fillna", #' @return a local R data.frame representing the contingency table. The first column of each row #' will be the distinct values of `col1` and the column names will be the distinct values #' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no -#' occurrences will have `null` as their counts. +#' occurrences will have zero as their counts. #' #' @rdname statfunctions #' @export diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 2fb6fae55f28c..49162838b8d1a 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) { # TODO: check the status code to output error information returnStatus <- readInt(conn) - stopifnot(returnStatus == 0) + if (returnStatus != 0) { + stop(readString(conn)) + } readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 6f772158ddfe8..c811d1dac3bd5 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack jars <- paste("--jars", jars) } - if (packages != "") { + if (!identical(packages, "")) { packages <- paste("--packages", packages) } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 7d1f6b0819ed0..6d364f77be7ee 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -102,11 +102,11 @@ readList <- function(con) { readRaw <- function(con) { dataLen <- readInt(con) - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readRawLen <- function(con, dataLen) { - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readDeserialize <- function(con) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 836e0175c391f..a3a121058e165 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -254,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") # @rdname intersection # @export -setGeneric("intersection", function(x, other, numPartitions = 1) { - standardGeneric("intersection") }) +setGeneric("intersection", + function(x, other, numPartitions = 1) { + standardGeneric("intersection") + }) # @rdname keys # @export @@ -489,9 +491,7 @@ setGeneric("sample", #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname saveAsParquetFile #' @export @@ -553,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn #' @rdname withColumnRenamed #' @export -setGeneric("withColumnRenamed", function(x, existingCol, newCol) { - standardGeneric("withColumnRenamed") }) +setGeneric("withColumnRenamed", + function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) ###################### Column Methods ########################## diff --git 
a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 258e354081fc1..efddcc1d8d71c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~' and '+'. +#' operators are supported, including '~', '+', '-', and '.'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter @@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"), function(object, newData) { return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param model A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "PipelineModel"), + function(object) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", object@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", object@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index ebc6ff65e9d0f..83801d3209700 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -202,8 +202,8 @@ setMethod("partitionBy", packageNamesArr <- serialize(.sparkREnv$.packages, connection = NULL) - broadcastArr <- lapply(ls(.broadcastNames), function(name) { - get(name, .broadcastNames) }) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) jrdd <- getJRDD(x) # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 79b79d70943cb..e83104f116422 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -22,7 +22,8 @@ connExists <- function(env) { tryCatch({ exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) - }, error = function(err) { + }, + error = function(err) { return(FALSE) }) } @@ -104,16 +105,13 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "1024m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows # URI needs four /// as from http://stackoverflow.com/a/18522792 if (.Platform$OS.type == "unix") { - collapseChar <- ":" uriSep <- "//" } else { - collapseChar <- ";" uriSep <- "////" } @@ -156,7 +154,8 @@ sparkR.init <- function( .sparkREnv$backendPort <- backendPort tryCatch({ connectBackend("localhost", backendPort) - }, error = function(err) { + }, + error = function(err) { stop("Failed to connect JVM\n") }) @@ -267,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { stop("Spark SQL is not built with Hive support") }) 
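The new summary() method above fetches its data through two JVM callbacks, "getModelFeatures" and "getModelWeights", on org.apache.spark.ml.api.r.SparkRWrappers. That wrapper is not part of this hunk; the following is only a hypothetical Scala sketch of the shape such a helper could take for the linear-regression case, assuming the fitted PipelineModel's last stage is a LinearRegressionModel (the RModelHelper name is made up and is not Spark API).

    import org.apache.spark.ml.PipelineModel
    import org.apache.spark.ml.regression.LinearRegressionModel

    // Hypothetical sketch, not the SparkRWrappers source: expose the fitted weights,
    // intercept first, matching the coefficient ordering R's coef() would produce.
    object RModelHelper {
      def getModelWeights(model: PipelineModel): Array[Double] = {
        model.stages.last match {
          case m: LinearRegressionModel => m.intercept +: m.weights.toArray
          case other =>
            throw new UnsupportedOperationException(s"Unsupported model: ${other.getClass}")
        }
      }
    }

A companion getModelFeatures would return the matching feature names (again intercept first), which the R code above installs as row names of the coefficient matrix; deriving those names needs the attribute metadata recorded during fitting, so it is omitted from this sketch.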
diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R index 30b05c1a2afcd..8a20991f89af8 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/test_client.R @@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", { expect_equal(gsub("[[:space:]]", "", args), "") }) + +test_that("multiple packages don't produce a warning", { + expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) +}) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index a492763344ae6..f272de78ad4a6 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -35,8 +35,27 @@ test_that("glm and predict", { test_that("predictions match with native glm", { training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ Sepal_Length, data = training) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 62fe48a5d6c7b..61c8a7ec7d837 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -112,7 +112,8 @@ test_that("create DataFrame from RDD", { df <- jsonFile(sqlContext, jsonPathNa) hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") @@ -602,7 +603,8 @@ test_that("write.df() as parquet file", { test_that("test HiveContext", { hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") @@ -1000,6 +1002,11 @@ test_that("crosstab() on a DataFrame", { expect_identical(expected, ordered) }) +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..18a1e13bdc655 100755 --- 
a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/bin/pyspark b/bin/pyspark index f9dbddfa53560..8f2a3b5a7717b 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -82,4 +82,4 @@ fi export PYSPARK_DRIVER_PYTHON export PYSPARK_DRIVER_PYTHON_OPTS -exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main "$@" +exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 45e9e3def5121..3c6169983e76b 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -35,4 +35,4 @@ set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main %* +call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %* diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 3a2a88219818f..27006e45e932b 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/pom.xml b/core/pom.xml index 95f36eb348698..202678779150b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,11 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + com.google.guava guava @@ -281,7 +286,7 @@ org.tachyonproject tachyon-client - 0.6.4 + 0.7.0 org.apache.hadoop @@ -292,36 +297,12 @@ curator-recipes - org.eclipse.jetty - jetty-jsp - - - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlet + org.tachyonproject + tachyon-underfs-glusterfs - junit - junit - - - org.powermock - powermock-module-junit4 - - - org.powermock - powermock-api-mockito - - - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 1d460432be9ff..1aa6ba4201261 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -59,14 +59,14 @@ final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); - private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 
* 1024; - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; private final int initialSize; private final int numPartitions; + private final int pageSizeBytes; + @VisibleForTesting + final int maxRecordSizeBytes; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; @@ -109,7 +109,10 @@ public UnsafeShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - + this.pageSizeBytes = (int) Math.min( + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, + conf.getSizeAsBytes("spark.buffer.pageSize", "64m")); + this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); } @@ -272,7 +275,11 @@ void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; } private long freeMemory() { @@ -346,23 +353,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. - if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = memoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 764578b181422..d47d6fc9c2ac4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -129,6 +129,11 @@ public UnsafeShuffleWriter( open(); } + @VisibleForTesting + public int maxRecordSizeBytes() { + return sorter.maxRecordSizeBytes; + } + /** * This convenience method should only be called in test code. 
*/ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index bf1bc5dffba78..4d7e5b3dfba6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -17,9 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; -import com.google.common.base.Charsets; -import com.google.common.primitives.Longs; -import com.google.common.primitives.UnsignedBytes; +import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; @@ -30,81 +28,67 @@ public class PrefixComparators { private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); public static final class StringPrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - // TODO: can done more efficiently - byte[] a = Longs.toByteArray(aPrefix); - byte[] b = Longs.toByteArray(bPrefix); - for (int i = 0; i < 8; i++) { - int c = UnsignedBytes.compare(a[i], b[i]); - if (c != 0) return c; - } - return 0; + return UnsignedLongs.compare(aPrefix, bPrefix); } - public long computePrefix(byte[] bytes) { - if (bytes == null) { - return 0L; - } else { - byte[] padded = new byte[8]; - System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); - return Longs.fromByteArray(padded); - } - } - - public long computePrefix(String value) { - return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + public static long computePrefix(UTF8String value) { + return value == null ? 0L : value.getPrefix(); } + } - public long computePrefix(UTF8String value) { - return value == null ? 0L : computePrefix(value.getBytes()); + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); } } - /** - * Prefix comparator for all integral types (boolean, byte, short, int, long). - */ - public static final class IntegralPrefixComparator extends PrefixComparator { + public static final class LongPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + } - public final long NULL_PREFIX = Long.MIN_VALUE; + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } } - public static final class FloatPrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparatorDesc extends PrefixComparator { @Override - public int compare(long aPrefix, long bPrefix) { + public int compare(long bPrefix, long aPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(double value) { + public static long computePrefix(double value) { return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4d6731ee60af3..866e0b4151577 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.util.LinkedList; +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,10 +44,7 @@ public final class UnsafeExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); - private static final int PAGE_SIZE = 1 << 27; // 128 megabytes - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; - + private final long pageSizeBytes; private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; private final int initialSize; @@ -91,7 +91,19 @@ public UnsafeExternalSorter( this.initialSize = initialSize; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); initializeForWriting(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). 
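The PrefixComparators rewrite above replaces the old byte-at-a-time string comparison with a single UnsignedLongs.compare on the 8-byte prefix. A minimal, self-contained sketch of why that is equivalent, assuming the prefix packs the first eight UTF-8 bytes big-endian and zero-pads short strings, as the removed Longs.toByteArray code did (UTF8String.getPrefix is the real source of the prefix and may differ in detail):

    import java.nio.charset.StandardCharsets
    import com.google.common.primitives.UnsignedLongs

    object PrefixSketch {
      // Pack the first 8 bytes big-endian, zero-padding short strings.
      def computePrefix(s: String): Long = {
        val bytes = s.getBytes(StandardCharsets.UTF_8)
        var prefix = 0L
        var i = 0
        while (i < 8) {
          val b = if (i < bytes.length) bytes(i) & 0xffL else 0L
          prefix = (prefix << 8) | b
          i += 1
        }
        prefix
      }

      def main(args: Array[String]): Unit = {
        // Unsigned comparison of the packed longs orders exactly like a byte-wise
        // unsigned comparison of the first 8 bytes, so prefix order agrees with
        // binary string order up to the prefix length.
        val prefixes = Seq("apple", "applesauce", "banana").map(computePrefix)
        assert(UnsignedLongs.compare(prefixes(0), prefixes(1)) < 0)
        assert(UnsignedLongs.compare(prefixes(1), prefixes(2)) < 0)
      }
    }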
+ taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + freeMemory(); + return null; + } + }); } // TODO: metrics tracking + integration with shuffle write metrics @@ -147,7 +159,16 @@ public void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; + } + + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); } public long freeMemory() { @@ -209,23 +230,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. - if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = memoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } @@ -257,7 +278,7 @@ public void insertRecord( currentPagePosition, lengthInBytes); currentPagePosition += lengthInBytes; - + freeSpaceInCurrentPage -= totalSpaceRequired; sorter.insertRecord(recordAddress, prefix); } diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index b146f8a784127..689afea64f8db 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 
3a2a88219818f..27006e45e932b 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 2f4fcac890eef..eb75f26718e19 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -341,7 +341,4 @@ private[spark] object Accumulators extends Logging { } } - def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue) - - def stringifyValue(value: Any): String = "%s".format(value) } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 443830f8d03b6..842bfdbadc948 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -24,11 +24,23 @@ package org.apache.spark private[spark] trait ExecutorAllocationClient { /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to like to run on that host. + * This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ - private[spark] def requestTotalExecutors(numExecutors: Int): Boolean + private[spark] def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean /** * Request an additional number of executors from the cluster manager. diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 648bcfe28cad2..1877aaf2cac55 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -161,6 +161,12 @@ private[spark] class ExecutorAllocationManager( // (2) an executor idle timeout has elapsed. @volatile private var initializing: Boolean = true + // Number of locality aware tasks, used for executor placement. 
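Both allocateSpaceForRecord methods touched by this patch (UnsafeShuffleExternalSorter earlier and UnsafeExternalSorter just above) now allocate pages of the configurable spark.buffer.pageSize instead of a hard-coded constant, using the same reserve/spill/retry sequence against the ShuffleMemoryManager. A condensed sketch of that pattern, with a hypothetical PageAllocator trait standing in for the real ShuffleMemoryManager plus the sorter's own spill hook:

    import java.io.IOException

    // Hypothetical interface mirroring only the calls used in the patch.
    trait PageAllocator {
      def tryToAcquire(numBytes: Long): Long  // may grant less than requested
      def release(numBytes: Long): Unit
      def spill(): Unit                       // free memory by spilling to disk
    }

    def acquireFullPage(mm: PageAllocator, pageSizeBytes: Long): Unit = {
      val acquired = mm.tryToAcquire(pageSizeBytes)
      if (acquired < pageSizeBytes) {
        // Give back the partial grant, spill to free memory, then retry exactly once.
        mm.release(acquired)
        mm.spill()
        val retried = mm.tryToAcquire(pageSizeBytes)
        if (retried != pageSizeBytes) {
          mm.release(retried)
          throw new IOException(s"Unable to acquire $pageSizeBytes bytes of memory")
        }
      }
      // The caller then calls memoryManager.allocatePage(pageSizeBytes) and tracks the page.
    }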
+ private var localityAwareTasks = 0 + + // Host to possible task running on it, used for executor placement. + private var hostToLocalTaskCount: Map[String, Int] = Map.empty + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -295,7 +301,7 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { - client.requestTotalExecutors(numExecutorsTarget) + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") } @@ -349,7 +355,8 @@ private[spark] class ExecutorAllocationManager( return 0 } - val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) + val addRequestAcknowledged = testing || + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) if (addRequestAcknowledged) { val executorsString = "executor" + { if (delta > 1) "s" else "" } logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + @@ -519,6 +526,12 @@ private[spark] class ExecutorAllocationManager( // Number of tasks currently running on the cluster. Should be 0 when no stages are active. private var numRunningTasks: Int = _ + // stageId to tuple (the number of task with locality preferences, a map where each pair is a + // node and the number of tasks that would like to be scheduled on that node) map, + // maintain the executor placement hints for each stage Id used by resource framework to better + // place the executors. 
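To make the new bookkeeping concrete, here is a purely illustrative example of the data held in the per-stage placement hints described in the comment above and declared just below, together with the aggregate that updateExecutorPlacementHints (further below) produces; the stage ids and host names are made up.

    import scala.collection.mutable

    // Per-stage hints: (tasks with locality preferences, host -> desired task count)
    val stageIdToExecutorPlacementHints = mutable.HashMap(
      1 -> (3, Map("host-a" -> 2, "host-b" -> 1)),
      2 -> (2, Map("host-b" -> 2)))

    var localityAwareTasks = 0
    val localityToCount = mutable.HashMap[String, Int]()
    stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) =>
      localityAwareTasks += numTasksPending
      localities.foreach { case (host, count) =>
        localityToCount(host) = localityToCount.getOrElse(host, 0) + count
      }
    }
    // localityAwareTasks == 5; localityToCount == Map("host-a" -> 2, "host-b" -> 3)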
+ private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])] + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { initializing = false val stageId = stageSubmitted.stageInfo.stageId @@ -526,6 +539,24 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks allocationManager.onSchedulerBacklogged() + + // Compute the number of tasks requested by the stage on each host + var numTasksPending = 0 + val hostToLocalTaskCountPerStage = new mutable.HashMap[String, Int]() + stageSubmitted.stageInfo.taskLocalityPreferences.foreach { locality => + if (!locality.isEmpty) { + numTasksPending += 1 + locality.foreach { location => + val count = hostToLocalTaskCountPerStage.getOrElse(location.host, 0) + 1 + hostToLocalTaskCountPerStage(location.host) = count + } + } + } + stageIdToExecutorPlacementHints.put(stageId, + (numTasksPending, hostToLocalTaskCountPerStage.toMap)) + + // Update the executor placement hints + updateExecutorPlacementHints() } } @@ -534,6 +565,10 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks -= stageId stageIdToTaskIndices -= stageId + stageIdToExecutorPlacementHints -= stageId + + // Update the executor placement hints + updateExecutorPlacementHints() // If this is the last stage with pending tasks, mark the scheduler queue as empty // This is needed in case the stage is aborted for any reason @@ -637,6 +672,29 @@ private[spark] class ExecutorAllocationManager( def isExecutorIdle(executorId: String): Boolean = { !executorIdToTaskIds.contains(executorId) } + + /** + * Update the Executor placement hints (the number of tasks with locality preferences, + * a map where each pair is a node and the number of tasks that would like to be scheduled + * on that node). + * + * These hints are updated when stages arrive and complete, so are not up-to-date at task + * granularity within stages. 
+ */ + def updateExecutorPlacementHints(): Unit = { + var localityAwareTasks = 0 + val localityToCount = new mutable.HashMap[String, Int]() + stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) => + localityAwareTasks += numTasksPending + localities.foreach { case (hostname, count) => + val updatedCount = localityToCount.getOrElse(hostname, 0) + count + localityToCount(hostname) = updatedCount + } + } + + allocationManager.localityAwareTasks = localityAwareTasks + allocationManager.hostToLocalTaskCount = localityToCount.toMap + } } /** diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index ad68512dccb79..4b9d59975bdc2 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -56,7 +56,7 @@ object Partitioner { */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse - for (r <- bySize if r.partitioner.isDefined) { + for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) { return r.partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 6cf36fbbd6254..4161792976c7b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,11 +18,12 @@ package org.apache.spark import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet +import org.apache.avro.{SchemaNormalization, Schema} + import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils @@ -161,6 +162,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + private final val avroNamespace = "avro.schema." + + /** + * Use Kryo serialization and register the given set of Avro schemas so that the generic + * record serializer can decrease network IO + */ + def registerAvroSchemas(schemas: Schema*): SparkConf = { + for (schema <- schemas) { + set(avroNamespace + SchemaNormalization.parsingFingerprint64(schema), schema.toString) + } + this + } + + /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */ + def getAvroSchema: Map[Long, String] = { + getAll.filter { case (k, v) => k.startsWith(avroNamespace) } + .map { case (k, v) => (k.substring(avroNamespace.length).toLong, v) } + .toMap + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a6b94a271cfc..ac6ac6c216767 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1382,16 +1382,29 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. - * This is currently only supported in YARN mode. Return whether the request is received. 
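The new SparkConf.registerAvroSchemas / getAvroSchema pair above stores each schema under a key derived from its 64-bit parsing fingerprint, so a Kryo-based generic Avro record serializer can look schemas up by fingerprint instead of shipping them with every record. A usage sketch; the record schema here is made up:

    import org.apache.avro.Schema
    import org.apache.spark.SparkConf

    val userSchema = new Schema.Parser().parse(
      """{"type":"record","name":"User","fields":[
        |  {"name":"name","type":"string"},
        |  {"name":"age","type":"int"}
        |]}""".stripMargin)

    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerAvroSchemas(userSchema)

    // Fingerprint -> schema JSON, as read back by the generic Avro record serializer.
    val registered: Map[Long, String] = conf.getAvroSchema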
- */ - private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] + ): Boolean = { assert(supportDynamicAllocation, "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.requestTotalExecutors(numExecutors) + b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount) case _ => logWarning("Requesting executors is only supported in coarse-grained mode") false diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 598953ac3bcc8..55e563ee968be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -207,6 +207,7 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { + TaskContext.setTaskContext(context) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index @@ -263,11 +264,6 @@ private[spark] class PythonRDD( if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index a5de10fe89c42..14dac4ed28ce3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -69,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend) case e: Exception => logError(s"Removing $objId failed", e) writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") } - case _ => dos.writeInt(-1) + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") } } else { handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) @@ -146,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed", e) + logError(s"$methodName on $objId failed") writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. 
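The PythonRDD change above (and the matching RRDD and PipedRDD changes later in this patch) calls TaskContext.setTaskContext(context) at the top of the writer thread it spawns, so code running on that helper thread sees the same task context as the task's main thread. A sketch of the pattern; TaskContext.setTaskContext is spark-internal, so this compiles only inside the org.apache.spark package and is illustrative rather than user-facing API:

    import org.apache.spark.TaskContext

    // Run `body` on a helper thread that shares the parent task's TaskContext.
    def startWriterThread(context: TaskContext, name: String)(body: => Unit): Thread = {
      val thread = new Thread(name) {
        override def run(): Unit = {
          // Make the parent task's context visible to TaskContext.get() on this thread.
          TaskContext.setTaskContext(context)
          body
        }
      }
      thread.setDaemon(true)
      thread.start()
      thread
    }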
+ writeString(dos, Utils.exceptionString(e.getCause)) } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 23a470d6afcae..1cf2824f862ee 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -112,6 +112,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( partition: Int): Unit = { val env = SparkEnv.get + val taskContext = TaskContext.get() val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val stream = new BufferedOutputStream(output, bufferSize) @@ -119,6 +120,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( override def run(): Unit = { try { SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) val dataOut = new DataOutputStream(stream) dataOut.writeInt(partition) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 4615febf17d24..51b3f0dead73e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -541,6 +541,7 @@ private[master] class Master( /** * Schedule executors to be launched on the workers. + * Returns an array containing number of cores assigned to each worker. * * There are two modes of launching executors. The first attempts to spread out an application's * executors on as many workers as possible, while the second does the opposite (i.e. launch them @@ -551,39 +552,77 @@ private[master] class Master( * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the * worker by default, in which case only one executor may be launched on each worker. + * + * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core + * at a time). Consider the following example: cluster has 4 workers with 16 cores each. + * User requests 3 executors (spark.cores.max = 48, spark.executor.cores = 16). If 1 core is + * allocated at a time, 12 cores from each worker would be assigned to each executor. + * Since 12 < 16, no executors would launch [SPARK-8881]. */ - private def startExecutorsOnWorkers(): Unit = { - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app - // in the queue, then the second app, etc. 
- if (spreadOutApps) { - // Try to spread out each app among all the workers, until it has all its cores - for (app <- waitingApps if app.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val numUsable = usableWorkers.length - val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) - var pos = 0 - while (toAssign > 0) { - if (usableWorkers(pos).coresFree - assigned(pos) > 0) { - toAssign -= 1 - assigned(pos) += 1 + private def scheduleExecutorsOnWorkers( + app: ApplicationInfo, + usableWorkers: Array[WorkerInfo], + spreadOutApps: Boolean): Array[Int] = { + // If the number of cores per executor is not specified, then we can just schedule + // 1 core at a time since we expect a single executor to be launched on each worker + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1) + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val numUsable = usableWorkers.length + val assignedCores = new Array[Int](numUsable) // Number of cores to give to each worker + val assignedMemory = new Array[Int](numUsable) // Amount of memory to give to each worker + var coresToAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) + var freeWorkers = (0 until numUsable).toIndexedSeq + + def canLaunchExecutor(pos: Int): Boolean = { + usableWorkers(pos).coresFree - assignedCores(pos) >= coresPerExecutor && + usableWorkers(pos).memoryFree - assignedMemory(pos) >= memoryPerExecutor + } + + while (coresToAssign >= coresPerExecutor && freeWorkers.nonEmpty) { + freeWorkers = freeWorkers.filter(canLaunchExecutor) + freeWorkers.foreach { pos => + var keepScheduling = true + while (keepScheduling && canLaunchExecutor(pos) && coresToAssign >= coresPerExecutor) { + coresToAssign -= coresPerExecutor + assignedCores(pos) += coresPerExecutor + // If cores per executor is not set, we are assigning 1 core at a time + // without actually meaning to launch 1 executor for each core assigned + if (app.desc.coresPerExecutor.isDefined) { + assignedMemory(pos) += memoryPerExecutor + } + + // Spreading out an application means spreading out its executors across as + // many workers as possible. If we are not spreading out, then we should keep + // scheduling executors on this worker until we use all of its resources. + // Otherwise, just move on to the next worker. + if (spreadOutApps) { + keepScheduling = false } - pos = (pos + 1) % numUsable - } - // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable if assigned(pos) > 0) { - allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } - } else { - // Pack each app into as few workers as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { - for (app <- waitingApps if app.coresLeft > 0) { - allocateWorkerResourceToExecutors(app, app.coresLeft, worker) - } + } + assignedCores + } + + /** + * Schedule and launch executors on workers + */ + private def startExecutorsOnWorkers(): Unit = { + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app + // in the queue, then the second app, etc. 
+ for (app <- waitingApps if app.coresLeft > 0) { + val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) } } } @@ -591,19 +630,22 @@ private[master] class Master( /** * Allocate a worker's resources to one or more executors. * @param app the info of the application which the executors belong to - * @param coresToAllocate cores on this worker to be allocated to this application + * @param assignedCores number of cores on this worker for this application + * @param coresPerExecutor number of cores per executor * @param worker the worker info */ private def allocateWorkerResourceToExecutors( app: ApplicationInfo, - coresToAllocate: Int, + assignedCores: Int, + coresPerExecutor: Option[Int], worker: WorkerInfo): Unit = { - val memoryPerExecutor = app.desc.memoryPerExecutorMB - val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) - var coresLeft = coresToAllocate - while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { - val exec = app.addExecutor(worker, coresPerExecutor) - coresLeft -= coresPerExecutor + // If the number of cores per executor is specified, we divide the cores assigned + // to this worker evenly among the executors with no remainder. + // Otherwise, we launch a single executor that grabs all the assignedCores on this worker. + val numExecutors = coresPerExecutor.map { assignedCores / _ }.getOrElse(1) + val coresToAssign = coresPerExecutor.getOrElse(assignedCores) + for (i <- 1 to numExecutors) { + val exec = app.addExecutor(worker, coresToAssign) launchExecutor(worker, exec) app.state = ApplicationState.RUNNING } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 581b40003c6c4..7bc7fce7ae8dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -209,16 +209,19 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
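The scheduleExecutorsOnWorkers rewrite above hands out coresPerExecutor cores at a time precisely because of the SPARK-8881 scenario described in its comment. A small standalone sketch of that arithmetic with the same numbers (4 workers with 16 free cores each, spark.cores.max = 48, spark.executor.cores = 16); the loop here is a simplification that ignores memory and the spreadOut flag:

    val coresPerExecutor = 16
    val workerFreeCores = Array(16, 16, 16, 16)
    var coresToAssign = math.min(48, workerFreeCores.sum)   // app.coresLeft vs. free cores
    val assignedCores = new Array[Int](workerFreeCores.length)

    // Old behaviour: hand out 1 core at a time round-robin -> 12 cores per worker,
    // and since 12 < 16 no executor can actually be launched on any worker.
    // New behaviour: hand out whole executors (16 cores) at a time:
    var pos = 0
    while (coresToAssign >= coresPerExecutor && pos < workerFreeCores.length) {
      if (workerFreeCores(pos) - assignedCores(pos) >= coresPerExecutor) {
        assignedCores(pos) += coresPerExecutor
        coresToAssign -= coresPerExecutor
      }
      pos += 1
    }
    // assignedCores == Array(16, 16, 16, 0): exactly 3 executors get launched.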
taskStart = System.currentTimeMillis() + var threwException = true val (value, accumUpdates) = try { - task.run( + val res = task.run( taskAttemptId = taskId, attemptNumber = attemptNumber, metricsSystem = env.metricsSystem) + threwException = false + res } finally { val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" - if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { throw new SparkException(errMsg) } else { logError(errMsg) @@ -310,10 +313,6 @@ private[spark] class Executor( } } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f827270ee6a44..f83a051f5da11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -128,7 +128,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -141,6 +141,12 @@ class NewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } havePair = !finished } !finished @@ -159,18 +165,23 @@ class NewHadoopRDD[K, V]( private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + // Close reader and release it + reader.close() + reader = null + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. 
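The NewHadoopRDD change above (mirrored in SqlNewHadoopRDD below) turns the record reader into a var so it can be closed and dropped as soon as the last key/value has been read, instead of waiting for the task-completion callback; for tasks that read many files this releases file handles and buffers much earlier. A generic sketch of that "close on exhaustion, keep close idempotent" pattern, with names that are illustrative only:

    import java.io.Closeable

    // Wrap an iterator over a Closeable source and release it as soon as it is drained.
    class EagerlyClosingIterator[T](private var source: Closeable, underlying: Iterator[T])
      extends Iterator[T] {

      override def hasNext: Boolean = {
        val more = underlying.hasNext
        if (!more) close()   // release resources early; the later cleanup call becomes a no-op
        more
      }

      override def next(): T = underlying.next()

      def close(): Unit = {
        if (source != null) {
          source.close()
          source = null      // guard so a second close (e.g. at task completion) does nothing
        }
      }
    }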
+ try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index defdabf95ac4b..3bb9998e1db44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -133,6 +133,7 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { + TaskContext.setTaskContext(context) val out = new PrintWriter(proc.getOutputStream) // scalastyle:off println diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 394c6686cbabd..6d61d227382d7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1082,7 +1082,9 @@ abstract class RDD[T: ClassTag]( val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce // the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { + + // Don't trigger TreeAggregation when it doesn't save wall-clock time + while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) { numPartitions /= scale val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala similarity index 81% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index e1c1a6c06268f..35e44cb59c1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -15,12 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import org.apache.spark.{Partition => SparkPartition, _} +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -62,7 +63,7 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be * folded into core. 
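The treeAggregate condition change above swaps integer division for math.ceil so the loop stops as soon as an extra level would not reduce the sequential combine work (roughly scale + ceil(numPartitions/scale) tasks versus numPartitions). For example, with numPartitions = 5 and scale = 2 the old test keeps aggregating while the new one stops:

    val numPartitions = 5
    val scale = 2
    val oldCondition = numPartitions > scale + numPartitions / scale                     // 5 > 4
    val newCondition = numPartitions > scale + math.ceil(numPartitions.toDouble / scale) // 5 > 5.0
    // oldCondition == true (adds another level), newCondition == false (stops here)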
*/ -private[sql] class SqlNewHadoopRDD[K, V]( +private[spark] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], @@ -128,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -147,7 +154,7 @@ private[sql] class SqlNewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -160,6 +167,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } havePair = !finished } !finished @@ -178,18 +191,24 @@ private[sql] class SqlNewHadoopRDD[K, V]( private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + reader.close() + reader = null + + SqlNewHadoopRDD.unsetInputFileName() + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { @@ -240,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V]( } private[spark] object SqlNewHadoopRDD { + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. 
+ */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 552dabcfa5139..c4fa277c21254 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -790,8 +790,28 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - stage.makeNewStageAttempt(partitionsToCompute.size) outputCommitCoordinator.stageStart(stage.id) + val taskIdToLocations = try { + stage match { + case s: ShuffleMapStage => + partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap + case s: ResultStage => + val job = s.resultOfJob.get + partitionsToCompute.map { id => + val p = job.partitions(id) + (id, getPreferredLocs(stage.rdd, p)) + }.toMap + } + } catch { + case NonFatal(e) => + stage.makeNewStageAttempt(partitionsToCompute.size) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + + stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -830,7 +850,7 @@ class DAGScheduler( stage match { case stage: ShuffleMapStage => partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) + val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) } @@ -840,7 +860,7 @@ class DAGScheduler( partitionsToCompute.map { id => val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) + val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) } } @@ -896,11 +916,9 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && partialValue != acc.zero) { val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}") event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}") } } } catch { @@ -927,7 +945,7 @@ class DAGScheduler( // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. 
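The SqlNewHadoopRDD companion object earlier in this hunk holds the current file name in a thread-local so that code running on the same task thread (Spark SQL's InputFileName) can read it without explicit plumbing. A minimal stand-alone sketch of that pattern, with hypothetical names:

    object CurrentInputFile {
      // Each task thread records the file it is currently reading.
      private val current = new ThreadLocal[String] {
        override protected def initialValue(): String = ""
      }
      def set(path: String): Unit = current.set(path)
      def unset(): Unit = current.remove()
      def get(): String = current.get()
    }

    // Reader side, per split (sketch):
    //   CurrentInputFile.set(split.getPath.toString)
    //   try { /* iterate records */ } finally { CurrentInputFile.unset() }
    // Consumer side, on the same thread:
    //   val file = CurrentInputFile.get()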
if (event.reason != Success) { - val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val attemptId = task.stageAttemptId listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b86724de2cb73..40a333a3e06b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -77,8 +77,11 @@ private[spark] abstract class Stage( private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ - def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = { - _latestInfo = StageInfo.fromStage(this, nextAttemptId, Some(numPartitionsToCompute)) + def makeNewStageAttempt( + numPartitionsToCompute: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { + _latestInfo = StageInfo.fromStage( + this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) nextAttemptId += 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 5d2abbc67e9d9..24796c14300b1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -34,7 +34,8 @@ class StageInfo( val numTasks: Int, val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], - val details: String) { + val details: String, + private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -70,7 +71,12 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. 
*/ - def fromStage(stage: Stage, attemptId: Int, numTasks: Option[Int] = None): StageInfo = { + def fromStage( + stage: Stage, + attemptId: Int, + numTasks: Option[Int] = None, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos new StageInfo( @@ -80,6 +86,7 @@ private[spark] object StageInfo { numTasks.getOrElse(stage.numTasks), rddInfos, stage.parents.map(_.id), - stage.details) + stage.details, + taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d11a00956a9a9..1978305cfefbd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{TaskContextImpl, TaskContext} +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -86,7 +86,18 @@ private[spark] abstract class Task[T]( (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for shuffles + SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() + } + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + } + } finally { + TaskContext.unset() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 4be1eda2e9291..06f5438433b6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -86,7 +86,11 @@ private[spark] object CoarseGrainedClusterMessages { // Request executors by specifying the new total number of executors desired // This includes executors already pending or running - case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage + case class RequestExecutors( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]) + extends CoarseGrainedClusterMessage case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c65b3e517773e..bd89160af4ffa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -66,6 +66,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] + // A map to store hostname with its possible task number running 
on it + protected var hostToLocalTaskCount: Map[String, Int] = Map.empty + + // The number of pending tasks which is locality required + protected var localityAwareTasks = 0 + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -235,7 +241,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.executorLost(executorId, SlaveLost(reason)) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) - case None => logError(s"Asked to remove non-existent executor $executorId") + case None => logInfo(s"Asked to remove non-existent executor $executorId") } } @@ -339,6 +345,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") + numPendingExecutors += numAdditionalExecutors // Account for executors pending to be added or removed val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size @@ -346,16 +353,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } /** - * Express a preference to the cluster manager for a given total number of executors. This can - * result in canceling pending requests or filing additional requests. - * @return whether the request is acknowledged. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have a locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. */ - final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized { + final override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int] + ): Boolean = synchronized { if (numExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of executor(s) " + s"$numExecutors from the cluster manager. Please specify a positive number!") } + + this.localityAwareTasks = localityAwareTasks + this.hostToLocalTaskCount = hostToLocalTaskCount + numPendingExecutors = math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) doRequestTotalExecutors(numExecutors) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bc67abb5df446..044f6288fabdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -53,7 +53,8 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. 
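For the new three-argument requestTotalExecutors above, a hedged sketch of a caller forwarding locality hints; the host names and counts are invented, and the call mirrors the one exercised by HeartbeatReceiverSuite later in this patch (the method may only be visible to Spark-internal code, as in that suite):

    import org.apache.spark.SparkContext

    // Hypothetical helper: ask for 6 executors in total and report that 8 pending
    // tasks expressed a preference for these hosts.
    def requestWithLocalityHints(sc: SparkContext): Boolean = {
      val hostToLocalTaskCount = Map("host1" -> 5, "host2" -> 3) // preferred host -> task count
      val localityAwareTasks = 8                                 // tasks with any locality preference
      sc.requestTotalExecutors(6, localityAwareTasks, hostToLocalTaskCount)
    }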
*/ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) + yarnSchedulerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } /** @@ -108,6 +109,8 @@ private[spark] abstract class YarnSchedulerBackend( case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala new file mode 100644 index 0000000000000..62f8aae7f2126 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import scala.collection.mutable + +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.io._ +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.io.CompressionCodec + +/** + * Custom serializer used for generic Avro records. If the user registers the schemas + * ahead of time, then the schema's fingerprint will be sent with each message instead of the actual + * schema, as to reduce network IO. + * Actions like parsing or compressing schemas are computationally expensive so the serializer + * caches all previously seen values as to reduce the amount of work needed to do. + * @param schemas a map where the keys are unique IDs for Avro schemas and the values are the + * string representation of the Avro schema, used to decrease the amount of data + * that needs to be serialized. 
+ */ +private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) + extends KSerializer[GenericRecord] { + + /** Used to reduce the amount of effort to compress the schema */ + private val compressCache = new mutable.HashMap[Schema, Array[Byte]]() + private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]() + + /** Reuses the same datum reader/writer since the same schema will be used many times */ + private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]() + private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]() + + /** Fingerprinting is very expensive so this alleviates most of the work */ + private val fingerprintCache = new mutable.HashMap[Schema, Long]() + private val schemaCache = new mutable.HashMap[Long, Schema]() + + // GenericAvroSerializer can't take a SparkConf in the constructor b/c then it would become + // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make + // the codec lazy here just b/c in some unit tests, we use a KryoSerializer w/out having + // the SparkEnv set (note those tests would fail if they tried to serialize avro data). + private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + + /** + * Used to compress Schemas when they are being sent over the wire. + * The compression results are memoized to reduce the compression time since the + * same schema is compressed many times over + */ + def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { + val bos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(bos) + out.write(schema.toString.getBytes("UTF-8")) + out.close() + bos.toByteArray + }) + + /** + * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already + * seen values so to limit the number of times that decompression has to be done. + */ + def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { + val bis = new ByteArrayInputStream(schemaBytes.array()) + val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) + new Schema.Parser().parse(new String(bytes, "UTF-8")) + }) + + /** + * Serializes a record to the given output stream. It caches a lot of the internal data as + * to not redo work + */ + def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = { + val encoder = EncoderFactory.get.binaryEncoder(output, null) + val schema = datum.getSchema + val fingerprint = fingerprintCache.getOrElseUpdate(schema, { + SchemaNormalization.parsingFingerprint64(schema) + }) + schemas.get(fingerprint) match { + case Some(_) => + output.writeBoolean(true) + output.writeLong(fingerprint) + case None => + output.writeBoolean(false) + val compressedSchema = compress(schema) + output.writeInt(compressedSchema.length) + output.writeBytes(compressedSchema) + } + + writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) + .asInstanceOf[DatumWriter[R]] + .write(datum, encoder) + encoder.flush() + } + + /** + * Deserializes generic records into their in-memory form. There is internal + * state to keep a cache of already seen schemas and datum readers. 
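A small, self-contained illustration (for example in a Scala REPL with Avro on the classpath) of the fingerprinting relied on above: a schema that was registered ahead of time can be identified on the wire by a 64-bit fingerprint instead of shipping the schema JSON, even compressed, with every record. The example schema is invented.

    import org.apache.avro.{Schema, SchemaNormalization}

    val schema: Schema = new Schema.Parser().parse(
      """{"type":"record","name":"Example","fields":[{"name":"id","type":"long"}]}""")
    // Same call the serializer uses to key its fingerprint cache.
    val fingerprint: Long = SchemaNormalization.parsingFingerprint64(schema)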
+ */ + def deserializeDatum(input: KryoInput): GenericRecord = { + val schema = { + if (input.readBoolean()) { + val fingerprint = input.readLong() + schemaCache.getOrElseUpdate(fingerprint, { + schemas.get(fingerprint) match { + case Some(s) => new Schema.Parser().parse(s) + case None => + throw new SparkException( + "Error reading attempting to read avro data -- encountered an unknown " + + s"fingerprint: $fingerprint, not sure what schema to use. This could happen " + + "if you registered additional schemas after starting your spark context.") + } + }) + } else { + val length = input.readInt() + decompress(ByteBuffer.wrap(input.readBytes(length))) + } + } + val decoder = DecoderFactory.get.directBinaryDecoder(input, null) + readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema)) + .asInstanceOf[DatumReader[GenericRecord]] + .read(null, decoder) + } + + override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit = + serializeDatum(datum, output) + + override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord = + deserializeDatum(input) +} diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7cb6e080533ad..0ff7562e912ca 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,6 +27,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ @@ -73,6 +74,8 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) + private val avroSchemas = conf.getAvroSchema + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { @@ -101,6 +104,9 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) + kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) + try { // scalastyle:off classforname // Use the default classloader when calling the user registrator. diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 3bcc7178a3d8b..f038b722957b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,95 +19,101 @@ package org.apache.spark.shuffle import scala.collection.mutable -import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** - * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * Allocates a pool of memory to tasks for use in shuffle operations. 
Each disk-spilling * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory * from this pool and release it as it spills data out. When a task ends, all its memory will be * released by the Executor. * - * This class tries to ensure that each thread gets a reasonable share of memory, instead of some - * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * This class tries to ensure that each task gets a reasonable share of memory, instead of some + * task ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. */ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * Try to acquire up to numBytes memory for the current task, and return the number of bytes * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active threads) before it is forced to spill. This can - * happen if the number of threads increases but an older thread had a lot of memory already. + * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active tasks) before it is forced to spill. This can + * happen if the number of tasks increases but an older task had a lot of memory already. 
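To make the 1 / 2N and 1 / N bounds described above concrete, a small numeric sketch (the pool size and task counts are invented):

    val maxMemory = 1000L
    def minShare(numActiveTasks: Int): Long = maxMemory / (2 * numActiveTasks) // guaranteed before a forced spill
    def maxShare(numActiveTasks: Int): Long = maxMemory / numActiveTasks       // hard cap per task
    // With 4 active tasks, each one can ramp up to at least 125 bytes before it can be
    // forced to spill and is capped at 250; once a fifth task registers, waiting tasks
    // recompute the bounds as 100 and 200 after notifyAll().
    assert(minShare(4) == 125 && maxShare(4) == 250)
    assert(minShare(5) == 100 && maxShare(5) == 200)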
*/ def tryToAcquire(numBytes: Long): Long = synchronized { - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - // Add this thread to the threadMemory map just so we can keep an accurate count of the number - // of active threads, to let other threads ramp down their memory in calls to tryToAcquire - if (!threadMemory.contains(threadId)) { - threadMemory(threadId) = 0L - notifyAll() // Will later cause waiting threads to wake up and check numThreads again + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) = 0L + notifyAll() // Will later cause waiting tasks to wake up and check numThreads again } // Keep looping until we're either sure that we don't want to grant this request (because this - // thread would have more than 1 / numActiveThreads of the memory) or we have enough free - // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). while (true) { - val numActiveThreads = threadMemory.keys.size - val curMem = threadMemory(threadId) - val freeMemory = maxMemory - threadMemory.values.sum + val numActiveTasks = taskMemory.keys.size + val curMem = taskMemory(taskAttemptId) + val freeMemory = maxMemory - taskMemory.values.sum - // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads; + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem)) + val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - if (curMem < maxMemory / (2 * numActiveThreads)) { - // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; - // if we can't give it this much now, wait for other threads to free up memory - // (this happens if older threads allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + if (curMem < maxMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo( + s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } } 0L // Never reached } - /** Release numBytes bytes for the current thread. */ + /** Release numBytes bytes for the current task. 
*/ def release(numBytes: Long): Unit = synchronized { - val threadId = Thread.currentThread().getId - val curMem = threadMemory.getOrElse(threadId, 0L) + val taskAttemptId = currentTaskAttemptId() + val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") } - threadMemory(threadId) -= numBytes + taskMemory(taskAttemptId) -= numBytes notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } - /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisThread(): Unit = synchronized { - val threadId = Thread.currentThread().getId - threadMemory.remove(threadId) + /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisTask(): Unit = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index ed609772e6979..6f27f00307f8c 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.TaskContext import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -43,11 +44,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object - // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a thread + // Pending unroll memory refers to the intermediate memory occupied by a task // after the unroll but before the actual putting of the block in the cache. // This chunk of memory is expected to be released *as soon as* we finish // caching the corresponding block as opposed to until after the task finishes. @@ -250,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var elementsUnrolled = 0 // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true - // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. + // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. 
val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 - // Memory currently reserved by this thread for this particular unrolling operation + // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this thread, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisThread + // Previous unroll memory held by this task, for releasing later (only at the very end) + val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { - if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + if (!reserveUnrollMemoryForThisTask(amountToRequest)) { // If the first request is not granted, try again after ensuring free space // If there is still not enough space, give up and drop the partition val spaceToEnsure = maxUnrollMemory - currentUnrollMemory @@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val result = ensureFreeSpace(blockId, spaceToEnsure) droppedBlocks ++= result.droppedBlocks } - keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) } } // New threshold is currentSize * memoryGrowthFactor @@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // later when the task finishes. if (keepUnrolling) { accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) - reservePendingUnrollMemoryForThisThread(amountToRelease) + val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved + releaseUnrollMemoryForThisTask(amountToRelease) + reservePendingUnrollMemoryForThisTask(amountToRelease) } } } @@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisThread() + releasePendingUnrollMemoryForThisTask() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -427,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Take into account the amount of memory currently occupied by unrolling blocks // and minus the pending unroll memory for that block on current thread. 
- val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(threadId, 0L) + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -455,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping + // This should never be null as only one task should be dropping // blocks and removing entries. However the check is still here for // future safety. if (entry != null) { @@ -482,79 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) entries.synchronized { entries.containsKey(blockId) } } + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Reserve additional memory for unrolling blocks used by this thread. + * Reserve additional memory for unrolling blocks used by this task. * Return whether the request is granted. */ - def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { - val threadId = Thread.currentThread().getId - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + val taskAttemptId = currentTaskAttemptId() + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } granted } } /** - * Release memory used by this thread for unrolling blocks. - * If the amount is not specified, remove the current thread's allocation altogether. + * Release memory used by this task for unrolling blocks. + * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { - val threadId = Thread.currentThread().getId + def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { if (memory < 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap.remove(taskAttemptId) } else { - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory - // If this thread claims no more unroll memory, release it completely - if (unrollMemoryMap(threadId) <= 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory + // If this task claims no more unroll memory, release it completely + if (unrollMemoryMap(taskAttemptId) <= 0) { + unrollMemoryMap.remove(taskAttemptId) } } } } /** - * Reserve the unroll memory of current unroll successful block used by this thread + * Reserve the unroll memory of current unroll successful block used by this task * until actually put the block into memory entry. 
*/ - def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = { - val threadId = Thread.currentThread().getId + def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } } /** - * Release pending unroll memory of current unroll successful block used by this thread + * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisThread(): Unit = { - val threadId = Thread.currentThread().getId + def releasePendingUnrollMemoryForThisTask(): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap.remove(threadId) + pendingUnrollMemoryMap.remove(taskAttemptId) } } /** - * Return the amount of memory currently occupied for unrolling blocks across all threads. + * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** - * Return the amount of memory currently occupied for unrolling blocks by this thread. + * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { - unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** - * Return the number of threads currently unrolling blocks. + * Return the number of tasks currently unrolling blocks. */ - def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. @@ -566,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo( s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." ) } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 2c84e4485996e..61449847add3d 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -107,6 +107,25 @@ private[spark] abstract class WebUI( } } + /** + * Add a handler for static content. + * + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. + */ + def addStaticHandler(resourceBase: String, path: String): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + } + + /** + * Remove a static content handler. + * + * @param path Path in UI to unmount. + */ + def removeStaticHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) + } + /** Initialize all components of the server. 
*/ def initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 2ce670ad02e97..e72547df7254b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -79,6 +79,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { case JobExecutionStatus.SUCCEEDED => "succeeded" case JobExecutionStatus.FAILED => "failed" case JobExecutionStatus.RUNNING => "running" + case JobExecutionStatus.UNKNOWN => "unknown" } // The timeline library treats contents as HTML, so we have to escape them; for the diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 7d84468f62ab1..14b1f2a17e707 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -217,10 +217,10 @@ object SizeEstimator extends Logging { var arrSize: Long = alignSize(objectSize + INT_SIZE) if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) + arrSize += alignSize(length.toLong * primitiveSize(elementClass)) state.size += arrSize } else { - arrSize += alignSize(length * pointerSize) + arrSize += alignSize(length.toLong * pointerSize) state.size += arrSize if (length <= ARRAY_SIZE_FOR_SAMPLING) { @@ -336,7 +336,7 @@ object SizeEstimator extends Logging { // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp var alignedSize = shellSize for (size <- fieldSizes if sizeCount(size) > 0) { - val count = sizeCount(size) + val count = sizeCount(size).toLong // If there are internal gaps, smaller field can fit in. alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count) shellSize += size * count diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c5816949cd360..c4012d0e83f7d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -443,11 +443,11 @@ private[spark] object Utils extends Logging { val lockFileName = s"${url.hashCode}${timestamp}_lock" val localDir = new File(getLocalDir(conf)) val lockFile = new File(localDir, lockFileName) - val raf = new RandomAccessFile(lockFile, "rw") + val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() // Only one executor entry. // The FileLock is only used to control synchronization for executors download file, // it's always safe regardless of lock type (mandatory or advisory). 
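The change above keeps a handle on the FileChannel so the file descriptor can be closed once the advisory lock is released (previously the channel behind the lock was never closed). A minimal sketch of that acquire/release/close shape; the helper name is invented:

    import java.io.{File, RandomAccessFile}

    def withFileLock[T](lockFile: File)(body: => T): T = {
      val channel = new RandomAccessFile(lockFile, "rw").getChannel()
      val lock = channel.lock()
      try {
        body
      } finally {
        lock.release()
        channel.close()   // release the descriptor as well as the lock
      }
    }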
- val lock = raf.getChannel().lock() + val lock = lockFileChannel.lock() val cachedFile = new File(localDir, cachedFileName) try { if (!cachedFile.exists()) { @@ -455,6 +455,7 @@ private[spark] object Utils extends Logging { } } finally { lock.release() + lockFileChannel.close() } copyFile( url, diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 10c3eedbf4b46..04fc09b323dbb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -111,7 +111,7 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf(); + conf = new SparkConf().set("spark.buffer.pageSize", "128m"); taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); @@ -512,12 +512,12 @@ public void close() { } writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); writer.forceSorterToSpill(); // We should be able to write a record that's right _at_ the max record size - final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; new Random(42).nextBytes(atMaxRecordSize); writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); writer.forceSorterToSpill(); // Inserting a record that's larger than the max record size should fail: - final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; new Random(42).nextBytes(exceedsMaxRecordSize); Product2 hugeRecord = new Tuple2(new byte[0], exceedsMaxRecordSize); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index ea8755e21eb68..0e391b751226d 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -199,4 +199,23 @@ public void testSortingEmptyArrays() throws Exception { } } + @Test + public void testFillingPage() throws Exception { + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + byte[] record = new byte[16]; + while (sorter.getNumberOfAllocatedPages() < 2) { + sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + } + sorter.freeMemory(); + } + } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 803e1831bb269..34caca892891c 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -751,6 +751,42 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 2) } + test("get pending task number and related locality preference") { + sc = createSparkContext(2, 5, 3) + val 
manager = sc.executorAllocationManager.get + + val localityPreferences1 = Seq( + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host3")), + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host4")), + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host4")), + Seq.empty, + Seq.empty + ) + val stageInfo1 = createStageInfo(1, 5, localityPreferences1) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + + assert(localityAwareTasks(manager) === 3) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 3, "host3" -> 2, "host4" -> 2)) + + val localityPreferences2 = Seq( + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host5")), + Seq(TaskLocation("host3"), TaskLocation("host4"), TaskLocation("host5")), + Seq.empty + ) + val stageInfo2 = createStageInfo(2, 3, localityPreferences2) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2)) + + assert(localityAwareTasks(manager) === 5) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) + + sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1)) + assert(localityAwareTasks(manager) === 2) + assert(hostToLocalTaskCount(manager) === + Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -784,8 +820,13 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val sustainedSchedulerBacklogTimeout = 2L private val executorIdleTimeout = 3L - private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { - new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details") + private def createStageInfo( + stageId: Int, + numTasks: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { + new StageInfo( + stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { @@ -815,6 +856,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _onSchedulerQueueEmpty = PrivateMethod[Unit]('onSchedulerQueueEmpty) private val _onExecutorIdle = PrivateMethod[Unit]('onExecutorIdle) private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy) + private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) + private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate _numExecutorsToAdd() @@ -885,4 +928,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private def onExecutorBusy(manager: ExecutorAllocationManager, id: String): Unit = { manager invokePrivate _onExecutorBusy(id) } + + private def localityAwareTasks(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _localityAwareTasks() + } + + private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = { + manager invokePrivate _hostToLocalTaskCount() + } } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index b099cd3fb7965..69cb4b44cf7ef 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -141,5 +141,30 @@ class FailureSuite extends 
SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + test("managed memory leak error should not mask other failures (SPARK-9266") { + val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true") + sc = new SparkContext("local[1,1]", "test", conf) + + // If a task leaks memory but fails due to some other cause, then make sure that the original + // cause is preserved + val thrownDueToTaskFailure = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + throw new Exception("intentional task failure") + iter + }.count() + } + assert(thrownDueToTaskFailure.getMessage.contains("intentional task failure")) + + // If the task succeeded but memory was leaked, then the task should fail due to that leak + val thrownDueToMemoryLeak = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + iter + }.count() + } + assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) + } + // TODO: Need to add tests with shuffle fetch failures. } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 5a2670e4d1cf0..139b8dc25f4b4 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -182,7 +182,7 @@ class HeartbeatReceiverSuite // Adjust the target number of executors on the cluster manager side assert(fakeClusterManager.getTargetNumExecutors === 0) - sc.requestTotalExecutors(2) + sc.requestTotalExecutors(2, 0, Map.empty) assert(fakeClusterManager.getTargetNumExecutors === 2) assert(fakeClusterManager.getExecutorIdsToKill.isEmpty) @@ -241,7 +241,8 @@ private class FakeSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - clusterManagerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) + clusterManagerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { @@ -260,7 +261,7 @@ private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoin def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal) => + case RequestExecutors(requestedTotal, _, _) => targetNumExecutors = requestedTotal context.reply(true) case KillExecutors(executorIds) => diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 3316f561a4949..aa8028792cb41 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -91,13 +91,13 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. 
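For context on the renamed fixture below: RangePartitioner only needs an implicit Ordering for its key type, so a plain case class works even though it does not extend Comparable. A tiny self-contained sketch mirroring the test's Item:

    case class Item(value: Int)
    implicit object ItemOrdering extends Ordering[Item] {
      // Subtraction mirrors the test; Integer.compare(x.value, y.value) is the
      // overflow-safe choice in general.
      override def compare(x: Item, y: Item): Int = x.value - y.value
    }
    // With the implicit in scope, ordinary Ordering-based operations work:
    assert(Seq(Item(3), Item(1), Item(2)).sorted == Seq(Item(1), Item(2), Item(3)))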
- implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row): Int = x.value - y.value + implicit object RowOrdering extends Ordering[Item] { + override def compare(x: Item, y: Item): Int = x.value - y.value } - val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) + val rdd = sc.parallelize(1 to 4500).map(x => (Item(x), Item(x))) val partitioner = new RangePartitioner(1500, rdd) - partitioner.getPartition(Row(100)) + partitioner.getPartition(Item(100)) } test("RangPartitioner.sketch") { @@ -252,4 +252,4 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva } -private sealed case class Row(value: Int) +private sealed case class Item(value: Int) diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index a8fbaf1d9da0a..4d7016d1e594b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -25,14 +25,15 @@ import scala.language.postfixOps import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.Matchers +import org.scalatest.{Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ +import org.apache.spark.rpc.RpcEnv -class MasterSuite extends SparkFunSuite with Matchers with Eventually { +class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester { test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) @@ -142,4 +143,196 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { } } + test("basic scheduling - spread out") { + testBasicScheduling(spreadOut = true) + } + + test("basic scheduling - no spread out") { + testBasicScheduling(spreadOut = false) + } + + test("scheduling with max cores - spread out") { + testSchedulingWithMaxCores(spreadOut = true) + } + + test("scheduling with max cores - no spread out") { + testSchedulingWithMaxCores(spreadOut = false) + } + + test("scheduling with cores per executor - spread out") { + testSchedulingWithCoresPerExecutor(spreadOut = true) + } + + test("scheduling with cores per executor - no spread out") { + testSchedulingWithCoresPerExecutor(spreadOut = false) + } + + test("scheduling with cores per executor AND max cores - spread out") { + testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut = true) + } + + test("scheduling with cores per executor AND max cores - no spread out") { + testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut = false) + } + + private def testBasicScheduling(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(1024) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + val scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 10) + assert(scheduledCores(2) === 10) + } + + private def testSchedulingWithMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(1024, maxCores = Some(8)) + val appInfo2 = makeAppInfo(1024, 
maxCores = Some(16)) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + var scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + // With spreading out, each worker should be assigned a few cores + if (spreadOut) { + assert(scheduledCores(0) === 3) + assert(scheduledCores(1) === 3) + assert(scheduledCores(2) === 2) + } else { + // Without spreading out, the cores should be concentrated on the first worker + assert(scheduledCores(0) === 8) + assert(scheduledCores(1) === 0) + assert(scheduledCores(2) === 0) + } + // Now test the same thing with max cores > cores per worker + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 6) + assert(scheduledCores(1) === 5) + assert(scheduledCores(2) === 5) + } else { + // Without spreading out, the first worker should be fully booked, + // and the leftover cores should spill over to the second worker only. + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 6) + assert(scheduledCores(2) === 0) + } + } + + private def testSchedulingWithCoresPerExecutor(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(1024, coresPerExecutor = Some(2)) + val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2)) + val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3)) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + // Each worker should end up with 4 executors with 2 cores each + // This should be 4 because of the memory restriction on each worker + var scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 8) + assert(scheduledCores(1) === 8) + assert(scheduledCores(2) === 8) + // Now test the same thing without running into the worker memory limit + // Each worker should now end up with 5 executors with 2 cores each + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 10) + assert(scheduledCores(2) === 10) + // Now test the same thing with a cores per executor that 10 is not divisible by + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo3, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 9) + assert(scheduledCores(1) === 9) + assert(scheduledCores(2) === 9) + } + + // Sorry for the long method name! 
+ private def testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(4)) + val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(20)) + val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3), maxCores = Some(20)) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + // We should only launch two executors, each with exactly 2 cores + var scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 2) + assert(scheduledCores(1) === 2) + assert(scheduledCores(2) === 0) + } else { + assert(scheduledCores(0) === 4) + assert(scheduledCores(1) === 0) + assert(scheduledCores(2) === 0) + } + // Test max cores > number of cores per worker + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 8) + assert(scheduledCores(1) === 6) + assert(scheduledCores(2) === 6) + } else { + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 10) + assert(scheduledCores(2) === 0) + } + // Test max cores > number of cores per worker AND + // a cores per executor that is 10 is not divisible by + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo3, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 6) + assert(scheduledCores(1) === 6) + assert(scheduledCores(2) === 6) + } else { + assert(scheduledCores(0) === 9) + assert(scheduledCores(1) === 9) + assert(scheduledCores(2) === 0) + } + } + + // =============================== + // | Utility methods for testing | + // =============================== + + private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers) + + private def makeMaster(conf: SparkConf = new SparkConf): Master = { + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 7077, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 8080, securityMgr, conf) + master + } + + private def makeAppInfo( + memoryPerExecutorMb: Int, + coresPerExecutor: Option[Int] = None, + maxCores: Option[Int] = None): ApplicationInfo = { + val desc = new ApplicationDescription( + "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor) + val appId = System.currentTimeMillis.toString + new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue) + } + + private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { + val workerId = System.currentTimeMillis.toString + new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, 101, "address") + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index dfa102f432a02..1321ec84735b5 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -282,6 +282,29 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { )) } + // See SPARK-9326 + test("cogroup with empty RDD") { + import scala.reflect.classTag + val intPairCT = classTag[(Int, Int)] + + val rdd1 = 
sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.emptyRDD[(Int, Int)](intPairCT) + + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + + // See SPARK-9326 + test("cogroup with groupByed RDD having 0 partitions") { + import scala.reflect.classTag + val intCT = classTag[Int] + + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.emptyRDD[Int](intCT).groupBy((x) => 5) + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala new file mode 100644 index 0000000000000..bc9f3708ed69d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.io.{Output, Input} +import org.apache.avro.{SchemaBuilder, Schema} +import org.apache.avro.generic.GenericData.Record + +import org.apache.spark.{SparkFunSuite, SharedSparkContext} + +class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + val schema : Schema = SchemaBuilder + .record("testRecord").fields() + .requiredString("data") + .endRecord() + val record = new Record(schema) + record.put("data", "test data") + + test("schema compression and decompression") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema)))) + } + + test("record serialization and deserialization") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + + val outputStream = new ByteArrayOutputStream() + val output = new Output(outputStream) + genericSer.serializeDatum(record, output) + output.flush() + output.close() + + val input = new Input(new ByteArrayInputStream(outputStream.toByteArray)) + assert(genericSer.deserializeDatum(input) === record) + } + + test("uses schema fingerprint to decrease message size") { + val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema) + + val output = new Output(new ByteArrayOutputStream()) + + val beginningNormalPosition = output.total() + genericSerFull.serializeDatum(record, output) + output.flush() + val normalLength = output.total - beginningNormalPosition + + conf.registerAvroSchemas(schema) + val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema) + val beginningFingerprintPosition = output.total() + genericSerFinger.serializeDatum(record, output) + val fingerprintLength = output.total - beginningFingerprintPosition + + assert(fingerprintLength < normalLength) + } + + test("caches previously seen schemas") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + val compressedSchema = genericSer.compress(schema) + val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + + assert(compressedSchema.eq(genericSer.compress(schema))) + assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 96778c9ebafb1..f495b6a037958 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,26 +17,39 @@ package org.apache.spark.shuffle +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.CountDownLatch -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { + + val nextTaskAttemptId = new AtomicInteger() + /** Launch a thread with the given body block and return it. 
*/ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { override def run() { - body + try { + val taskAttemptId = nextTaskAttemptId.getAndIncrement + val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) + when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) + TaskContext.setTaskContext(mockTaskContext) + body + } finally { + TaskContext.unset() + } } } thread.start() thread } - test("single thread requesting memory") { + test("single task requesting memory") { val manager = new ShuffleMemoryManager(1000L) assert(manager.tryToAcquire(100L) === 100L) @@ -50,7 +63,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(manager.tryToAcquire(300L) === 300L) assert(manager.tryToAcquire(300L) === 200L) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() assert(manager.tryToAcquire(1000L) === 1000L) assert(manager.tryToAcquire(100L) === 0L) } @@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } - test("threads cannot grow past 1 / N") { - // Two threads request 250 bytes first, wait for each other to get it, and then request + test("tasks cannot grow past 1 / N") { + // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request val manager = new ShuffleMemoryManager(1000L) @@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(state.t2Result2 === 250L) } - test("threads can block to get at least 1 / 2N memory") { + test("tasks can block to get at least 1 / 2N memory") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. @@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("releaseMemoryForThisThread") { + test("releaseMemoryForThisTask") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. 
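A minimal sketch of the 1 / N and 1 / 2N arithmetic these renamed tests assert, assuming a pool of `maxMemory` bytes shared by `numActiveTasks` running tasks (the helper below is illustrative only, not part of ShuffleMemoryManager's API):

```scala
// Illustrative bounds only: each task is capped at 1 / N of the pool
// ("tasks cannot grow past 1 / N") and may block until it holds at least
// 1 / 2N ("tasks can block to get at least 1 / 2N memory").
object ShuffleMemoryBoundsSketch {
  def bounds(maxMemory: Long, numActiveTasks: Int): (Long, Long) = {
    val floor = maxMemory / (2 * numActiveTasks)
    val cap = maxMemory / numActiveTasks
    (floor, cap)
  }

  def main(args: Array[String]): Unit = {
    // A 1000-byte pool shared by two tasks gives a 250-byte floor and a
    // 500-byte cap, matching the 250/500 figures used throughout the suite.
    println(bounds(1000L, 2)) // (250,500)
  }
}
```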
@@ -251,9 +264,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise + // sure the other task blocks for some time otherwise Thread.sleep(300) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() } val t2 = startThread("t2") { @@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { t2.join() } - // Both threads should've been able to acquire their memory; the second one will have waited + // Both tasks should've been able to acquire their memory; the second one will have waited // until the first one acquired 1000 bytes and then released all of it state.synchronized { assert(state.t1Result === 1000L, "t1 could not allocate memory") @@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("threads should not be granted a negative size") { + test("tasks should not be granted a negative size") { val manager = new ShuffleMemoryManager(1000L) manager.tryToAcquire(700L) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bcee901f5dd5f..f480fd107a0c2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1004,32 +1004,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Reserve - memoryStore.reserveUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 100) - memoryStore.reserveUnrollMemoryForThisThread(200) - assert(memoryStore.currentUnrollMemoryForThisThread === 300) - memoryStore.reserveUnrollMemoryForThisThread(500) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) - memoryStore.reserveUnrollMemoryForThisThread(1000000) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted + memoryStore.reserveUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + memoryStore.reserveUnrollMemoryForThisTask(200) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + memoryStore.reserveUnrollMemoryForThisTask(500) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 700) - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 600) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisThread(4400) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) - memoryStore.reserveUnrollMemoryForThisThread(20000) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted + 
memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again - memoryStore.releaseUnrollMemoryForThisThread(1000) - assert(memoryStore.currentUnrollMemoryForThisThread === 4000) - memoryStore.releaseUnrollMemoryForThisThread() // release all - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) } /** @@ -1060,24 +1060,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) val memoryStore = store.memoryStore val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with all the space in the world. This should succeed and return an array. var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) - memoryStore.releasePendingUnrollMemoryForThisThread() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) droppedBlocks.clear() - memoryStore.releasePendingUnrollMemoryForThisThread() + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. 
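A minimal sketch of the per-task bookkeeping these assertions exercise, assuming a single fixed budget and a ledger keyed by task attempt id (names and the budget below are illustrative; the real MemoryStore tracks more state than this):

```scala
import scala.collection.mutable

// Toy unroll-memory ledger keyed by task attempt id; illustrative only.
class ToyUnrollLedger(maxUnrollMemory: Long) {
  private val reserved = mutable.Map.empty[Long, Long].withDefaultValue(0L)

  // Grant the reservation only if the total across tasks stays within budget
  // (mirrors the "not granted" cases asserted above).
  def reserveForTask(taskAttemptId: Long, bytes: Long): Boolean = synchronized {
    if (reserved.values.sum + bytes <= maxUnrollMemory) {
      reserved(taskAttemptId) += bytes
      true
    } else {
      false
    }
  }

  // Release a specific amount, or everything held by the task when no amount is given.
  def releaseForTask(taskAttemptId: Long, bytes: Option[Long] = None): Unit = synchronized {
    bytes match {
      case Some(b) => reserved(taskAttemptId) = math.max(0L, reserved(taskAttemptId) - b)
      case None => reserved.remove(taskAttemptId)
    }
  }

  def currentForTask(taskAttemptId: Long): Long = synchronized { reserved(taskAttemptId) }
}
```

With a 5000-byte budget, reserving 100, 200, and then 500 bytes for one task leaves 800 bytes reserved, and an oversized request is simply refused rather than partially granted, which is the shape of the first block of assertions above.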
@@ -1085,7 +1085,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) droppedBlocks.clear() @@ -1099,7 +1099,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with plenty of space. This should succeed and cache both blocks. val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) @@ -1110,7 +1110,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result2.size > 0) assert(result1.data.isLeft) // unroll did not drop this block to disk assert(result2.data.isLeft) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Re-put these two blocks so block manager knows about them too. Otherwise, block manager // would not know how to drop them from memory later. @@ -1126,7 +1126,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b1")) assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") store.putIterator("b3", smallIterator, memOnly) @@ -1138,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } /** @@ -1153,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) store.putIterator("b1", smallIterator, memAndDisk) store.putIterator("b2", smallIterator, memAndDisk) @@ -1170,7 +1170,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b3")) memoryStore.remove("b3") store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk // directly in addition to kicking out b2 in the process. 
Memory store should contain only @@ -1186,7 +1186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(diskStore.contains("b2")) assert(!diskStore.contains("b3")) assert(diskStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } test("multiple unrolls by the same thread") { @@ -1195,32 +1195,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // All unroll memory used is released because unrollSafely returned an array memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll memory is not released because unrollSafely returned an iterator // that still depends on the underlying vector used in the process memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB3 > 0) // The unroll memory owned by this thread builds on top of its value after the previous unrolls memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index dc03e374b51db..26a2e96edaaa2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -17,22 +17,29 @@ package org.apache.spark.util.collection.unsafe.sort +import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks - import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { test("String prefix comparator") { def testPrefixComparison(s1: String, s2: String): Unit = { - val s1Prefix = PrefixComparators.STRING.computePrefix(s1) - val s2Prefix = PrefixComparators.STRING.computePrefix(s2) + val utf8string1 = UTF8String.fromString(s1) + val utf8string2 = UTF8String.fromString(s2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + + val cmp = UnsignedBytes.lexicographicalComparator().compare( + utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) + assert( - (prefixComparisonResult == 0) || - (prefixComparisonResult < 0 && s1 < s2) || - (prefixComparisonResult > 0 && s1 > s2)) + (prefixComparisonResult == 0 && cmp == 0) || + (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || + (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) } // scalastyle:off @@ -48,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = 
java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) - val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) - val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh index b81c00c9d6d9d..d7975dfb6475c 100755 --- a/dev/change-scala-version.sh +++ b/dev/change-scala-version.sh @@ -19,19 +19,23 @@ set -e +VALID_VERSIONS=( 2.10 2.11 ) + usage() { - echo "Usage: $(basename $0) " 1>&2 + echo "Usage: $(basename $0) [-h|--help] +where : + -h| --help Display this help text + valid version values : ${VALID_VERSIONS[*]} +" 1>&2 exit 1 } -if [ $# -ne 1 ]; then +if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then usage fi TO_VERSION=$1 -VALID_VERSIONS=( 2.10 2.11 ) - check_scala_version() { for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2 diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh new file mode 100755 index 0000000000000..0962d34c52f28 --- /dev/null +++ b/dev/change-version-to-2.10.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script exists for backwards compability. Use change-scala-version.sh instead. +echo "This script is deprecated. Please instead run: change-scala-version.sh 2.10" + +$(dirname $0)/change-scala-version.sh 2.10 diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh new file mode 100755 index 0000000000000..4ccfeef09fd04 --- /dev/null +++ b/dev/change-version-to-2.11.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script exists for backwards compability. Use change-scala-version.sh instead. +echo "This script is deprecated. Please instead run: change-scala-version.sh 2.11" + +$(dirname $0)/change-scala-version.sh 2.11 diff --git a/dev/lint-python b/dev/lint-python index e02dff220eb87..575dbb0ae321b 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -58,21 +58,21 @@ export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" export "PYLINT_HOME=$PYTHONPATH" export "PATH=$PYTHONPATH:$PATH" -if [ ! -d "$PYLINT_HOME" ]; then - mkdir "$PYLINT_HOME" - # Redirect the annoying pylint installation output. - easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" - easy_install_status="$?" - - if [ "$easy_install_status" -ne 0 ]; then - echo "Unable to install pylint locally in \"$PYTHONPATH\"." - cat "$PYLINT_INSTALL_INFO" - exit "$easy_install_status" - fi - - rm "$PYLINT_INSTALL_INFO" - -fi +# if [ ! -d "$PYLINT_HOME" ]; then +# mkdir "$PYLINT_HOME" +# # Redirect the annoying pylint installation output. +# easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" +# easy_install_status="$?" +# +# if [ "$easy_install_status" -ne 0 ]; then +# echo "Unable to install pylint locally in \"$PYTHONPATH\"." +# cat "$PYLINT_INSTALL_INFO" +# exit "$easy_install_status" +# fi +# +# rm "$PYLINT_INSTALL_INFO" +# +# fi # There is no need to write this output to a file #+ first, but we do so so that the check status can @@ -96,19 +96,19 @@ fi rm "$PEP8_REPORT_PATH" -for to_be_checked in "$PATHS_TO_CHECK" -do - pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" -done +# for to_be_checked in "$PATHS_TO_CHECK" +# do +# pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +# done -if [ "${PIPESTATUS[0]}" -ne 0 ]; then - lint_status=1 - echo "Pylint checks failed." - cat "$PYLINT_REPORT_PATH" -else - echo "Pylint checks passed." -fi +# if [ "${PIPESTATUS[0]}" -ne 0 ]; then +# lint_status=1 +# echo "Pylint checks failed." +# cat "$PYLINT_REPORT_PATH" +# else +# echo "Pylint checks passed." +# fi -rm "$PYLINT_REPORT_PATH" +# rm "$PYLINT_REPORT_PATH" exit "$lint_status" diff --git a/dev/run-tests.py b/dev/run-tests.py index 1f0d218514f92..29420da9aa956 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -85,6 +85,13 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe return [f for f in raw_output.split('\n') if f] +def setup_test_environ(environ): + print("[info] Setup the following environment variables for tests: ") + for (k, v) in environ.items(): + print("%s=%s" % (k, v)) + os.environ[k] = v + + def determine_modules_to_test(changed_modules): """ Given a set of modules that have changed, compute the transitive closure of those modules' @@ -455,6 +462,15 @@ def main(): print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) + # setup environment variables + # note - the 'root' module doesn't collect environment variables for all modules. Because the + # environment variables should not be set if a module is not changed, even if running the 'root' + # module. So here we should use changed_modules rather than test_modules. 
+ test_environ = {} + for m in changed_modules: + test_environ.update(m.environ) + setup_test_environ(test_environ) + test_modules = determine_modules_to_test(changed_modules) # license checks diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 3073d489bad4a..44600cb9523c1 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -29,7 +29,7 @@ class Module(object): changed. """ - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), should_run_r_tests=False): """ @@ -43,6 +43,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= filename strings. :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in order to build and test this module (e.g. '-PprofileName'). + :param environ: A dict of environment variables that should be set when files in this + module are changed. :param sbt_test_goals: A set of SBT test goals for testing this module. :param python_test_goals: A set of Python test goals for testing this module. :param blacklisted_python_implementations: A set of Python implementations that are not @@ -55,6 +57,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.source_file_prefixes = source_file_regexes self.sbt_test_goals = sbt_test_goals self.build_profile_flags = build_profile_flags + self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations self.should_run_r_tests = should_run_r_tests @@ -126,15 +129,22 @@ def contains_file(self, filename): ) +# Don't set the dependencies because changes in other modules should not trigger Kinesis tests. +# Kinesis tests depends on external Amazon kinesis service. We should run these tests only when +# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't +# fail other PRs. streaming_kinesis_asl = Module( name="kinesis-asl", - dependencies=[streaming], + dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", ], build_profile_flags=[ "-Pkinesis-asl", ], + environ={ + "ENABLE_KINESIS_TESTS": "1" + }, sbt_test_goals=[ "kinesis-asl/test", ] @@ -313,7 +323,7 @@ def contains_file(self, filename): "pyspark.mllib.evaluation", "pyspark.mllib.feature", "pyspark.mllib.fpm", - "pyspark.mllib.linalg", + "pyspark.mllib.linalg.__init__", "pyspark.mllib.random", "pyspark.mllib.recommendation", "pyspark.mllib.regression", diff --git a/docs/configuration.md b/docs/configuration.md index 200f3cd212e46..fd236137cb96e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -203,7 +203,7 @@ Apart from these, the following properties are also available, and may be useful spark.driver.extraClassPath (none) - Extra classpath entries to append to the classpath of the driver. + Extra classpath entries to prepend to the classpath of the driver.
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -250,7 +250,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily for + Extra classpath entries to prepend to the classpath of executors. This exists primarily for backwards-compatibility with older versions of Spark. Users typically should not need to set this option. diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 3f10cb2dc3d2a..99f8c827f767f 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -800,7 +800,7 @@ import org.apache.spark.graphx._ // Import random graph generation library import org.apache.spark.graphx.util.GraphGenerators // A graph with edge attributes containing distances -val graph: Graph[Int, Double] = +val graph: Graph[Long, Double] = GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble) val sourceId: VertexId = 42 // The ultimate source // Initialize the graph such that all vertices except the root have distance infinity. diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md new file mode 100644 index 0000000000000..4ca0bb06b26a6 --- /dev/null +++ b/docs/mllib-evaluation-metrics.md @@ -0,0 +1,1497 @@ +--- +layout: global +title: Evaluation Metrics - MLlib +displayTitle: MLlib - Evaluation Metrics +--- + +* Table of contents +{:toc} + +Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions +on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance +of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a +suite of metrics for the purpose of evaluating the performance of machine learning models. + +Specific machine learning algorithms fall under broader types of machine learning applications like classification, +regression, clustering, etc. Each of these types have well established metrics for performance evaluation and those +metrics that are currently available in Spark's MLlib are detailed in this section. + +## Classification model evaluation + +While there are many different types of classification algorithms, the evaluation of classification models all share +similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification), +there exists a true output and a model-generated predicted output for each data point. For this reason, the results for +each data point can be assigned to one of four categories: + +* True Positive (TP) - label is positive and prediction is also positive +* True Negative (TN) - label is negative and prediction is also negative +* False Positive (FP) - label is negative but prediction is positive +* False Negative (FN) - label is positive but prediction is negative + +These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering +classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric. The +reason for this is because a dataset may be highly unbalanced. 
For example, if a model is designed to predict fraud from +a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier +that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like +[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into +account the *type* of error. In most applications there is some desired balance between precision and recall, which can +be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score). + +### Binary classification + +[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given +dataset into one of two possible groups (e.g. fraud or not fraud) and is a special case of multiclass classification. +Most binary classification metrics can be generalized to multiclass classification metrics. + +#### Threshold tuning + +It is import to understand that many classification models actually output a "score" (often times a probability) for +each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for +each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be some cases where +the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a +credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold* +which determines what the predicted class will be based on the probabilities that the model outputs. + +Tuning the prediction threshold will change the precision and recall of the model and is an important part of model +optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold it is +common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision, +recall) points for different threshold values, while a +[receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve +plots (recall, false positive rate) points. + +**Available metrics** + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+      <tr><th>Metric</th><th>Definition</th></tr>
+      <tr>
+        <td>Precision (Positive Predictive Value)</td>
+        <td>$PPV=\frac{TP}{TP + FP}$</td>
+      </tr>
+      <tr>
+        <td>Recall (True Positive Rate)</td>
+        <td>$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$</td>
+      </tr>
+      <tr>
+        <td>F-measure</td>
+        <td>$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}
+          {\beta^2 \cdot PPV + TPR}\right)$</td>
+      </tr>
+      <tr>
+        <td>Receiver Operating Characteristic (ROC)</td>
+        <td>$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$</td>
+      </tr>
+      <tr>
+        <td>Area Under ROC Curve</td>
+        <td>$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$</td>
+      </tr>
+      <tr>
+        <td>Area Under Precision-Recall Curve</td>
+        <td>$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$</td>
+      </tr>
+ + +**Examples** + +
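Before the full examples, a quick hand computation of the table's formulas for one hypothetical set of counts (a minimal sketch, independent of MLlib; the counts are made up):

```scala
// Hand computation of precision, recall, and F-measure from hypothetical counts.
object BinaryMetricsByHand {
  def main(args: Array[String]): Unit = {
    val (tp, fp, fn) = (8.0, 2.0, 4.0)  // hypothetical TP, FP, FN counts
    val precision = tp / (tp + fp)      // PPV = TP / (TP + FP) = 0.8
    val recall = tp / (tp + fn)         // TPR = TP / (TP + FN) ~= 0.667
    val beta = 1.0
    val fMeasure = (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
    println(s"precision = $precision, recall = $recall, F1 = $fMeasure")  // F1 ~= 0.727
  }
}
```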
+The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the +data, and evaluate the performance of the algorithm by several binary evaluation metrics. + +
+
+{% highlight scala %}
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data in LIBSVM format
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
+
+// Split data into training (60%) and test (40%)
+val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+training.cache()
+
+// Run training algorithm to build the model
+val model = new LogisticRegressionWithLBFGS()
+  .setNumClasses(2)
+  .run(training)
+
+// Clear the prediction threshold so the model will return probabilities
+model.clearThreshold
+
+// Compute raw scores on the test set
+val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+  val prediction = model.predict(features)
+  (prediction, label)
+}
+
+// Instantiate metrics object
+val metrics = new BinaryClassificationMetrics(predictionAndLabels)
+
+// Precision by threshold
+val precision = metrics.precisionByThreshold
+precision.foreach { case (t, p) =>
+  println(s"Threshold: $t, Precision: $p")
+}
+
+// Recall by threshold
+val recall = metrics.recallByThreshold
+recall.foreach { case (t, r) =>
+  println(s"Threshold: $t, Recall: $r")
+}
+
+// Precision-Recall Curve
+val PRC = metrics.pr
+
+// F-measure
+val f1Score = metrics.fMeasureByThreshold
+f1Score.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 1")
+}
+
+val beta = 0.5
+val fScore = metrics.fMeasureByThreshold(beta)
+fScore.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 0.5")
+}
+
+// AUPRC
+val auPRC = metrics.areaUnderPR
+println("Area under precision-recall curve = " + auPRC)
+
+// Compute thresholds used in ROC and PR curves
+val thresholds = precision.map(_._1)
+
+// ROC Curve
+val roc = metrics.roc
+
+// AUROC
+val auROC = metrics.areaUnderROC
+println("Area under ROC = " + auROC)
+
+{% endhighlight %}
+
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class BinaryClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics");
+    SparkContext sc = new SparkContext(conf);
+    String path = "data/mllib/sample_binary_classification_data.txt";
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
+
+    // Split initial RDD into two... [60% training data, 40% testing data].
+    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
+    JavaRDD<LabeledPoint> training = splits[0].cache();
+    JavaRDD<LabeledPoint> test = splits[1];
+
+    // Run training algorithm to build the model.
+    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+      .setNumClasses(2)
+      .run(training.rdd());
+
+    // Clear the prediction threshold so the model will return probabilities
+    model.clearThreshold();
+
+    // Compute raw scores on the test set.
+    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint p) {
+          Double prediction = model.predict(p.features());
+          return new Tuple2<Object, Object>(prediction, p.label());
+        }
+      }
+    );
+
+    // Get evaluation metrics.
+    BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());
+
+    // Precision by threshold
+    JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
+    System.out.println("Precision by threshold: " + precision.toArray());
+
+    // Recall by threshold
+    JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
+    System.out.println("Recall by threshold: " + recall.toArray());
+
+    // F Score by threshold
+    JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
+    System.out.println("F1 Score by threshold: " + f1Score.toArray());
+
+    JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
+    System.out.println("F2 Score by threshold: " + f2Score.toArray());
+
+    // Precision-recall curve
+    JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
+    System.out.println("Precision-recall curve: " + prc.toArray());
+
+    // Thresholds
+    JavaRDD<Double> thresholds = precision.map(
+      new Function<Tuple2<Object, Object>, Double>() {
+        public Double call (Tuple2<Object, Object> t) {
+          return new Double(t._1().toString());
+        }
+      }
+    );
+
+    // ROC Curve
+    JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
+    System.out.println("ROC curve: " + roc.toArray());
+
+    // AUPRC
+    System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());
+
+    // AUROC
+    System.out.println("Area under ROC = " + metrics.areaUnderROC());
+
+    // Save and load model
+    model.save(sc, "myModelPath");
+    LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Several of the methods available in scala are currently missing from pyspark + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = BinaryClassificationMetrics(predictionAndLabels) + +# Area under precision-recall curve +print "Area under PR = %s" % metrics.areaUnderPR + +# Area under ROC curve +print "Area under ROC = %s" % metrics.areaUnderROC + +{% endhighlight %} + +
+
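The ROC curve defined in the table above is just the set of (FPR, TPR) points obtained by sweeping the prediction threshold described under "Threshold tuning". A minimal sketch of that sweep, using made-up scores and labels rather than model output:

```scala
// Threshold sweep over hypothetical (score, label) pairs, producing ROC points.
object RocSweepSketch {
  def main(args: Array[String]): Unit = {
    val scored = Seq((0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0), (0.4, 0.0), (0.2, 0.0))
    val positives = scored.count(_._2 == 1.0).toDouble
    val negatives = scored.size - positives

    // For each candidate threshold, predict positive when score >= threshold.
    val rocPoints = scored.map(_._1).distinct.sorted.reverse.map { t =>
      val predictedPositive = scored.filter(_._1 >= t)
      val tp = predictedPositive.count(_._2 == 1.0).toDouble
      val fp = predictedPositive.size - tp
      (fp / negatives, tp / positives) // (false positive rate, true positive rate)
    }
    rocPoints.foreach(println)
  }
}
```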
+ + +### Multiclass classification + +A [multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) describes a classification +problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary +classification problem). For example, classifying handwriting samples to the digits 0 to 9, having 10 possible classes. + +For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still +be positive or negative, but they must be considered under the context of a particular class. Each label and prediction +take on the value of one of the multiple classes and so they are said to be positive for their particular class and negative +for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative +occurs when neither the prediction nor the label take on the value of a given class. By this convention, there can be +multiple true negatives for a given data sample. The extension of false negatives and false positives from the former +definitions of positive and negative labels is straightforward. + +#### Label based metrics + +Opposed to binary classification where there are only two possible labels, multiclass classification problems have many +possible labels and so the concept of label-based metrics is introduced. Overall precision measures precision across all +labels - the number of times any class was predicted correctly (true positives) normalized by the number of data +points. Precision by label considers only one class, and measures the number of time a specific label was predicted +correctly normalized by the number of times that label appears in the output. + +**Available metrics** + +Define the class, or label, set as + +$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$ + +The true output vector $\mathbf{y}$ consists of $N$ elements + +$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$ + +A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements + +$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$ + +For this section, a modified delta function $\hat{\delta}(x)$ will prove useful + +$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Confusion Matrix</td>
+      <td>
+        $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\
+         \left( \begin{array}{ccc}
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots &
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) \\
+          \vdots & \ddots & \vdots \\
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots &
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1})
+         \end{array} \right)$
+      </td>
+    </tr>
+    <tr>
+      <td>Overall Precision</td>
+      <td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
+        \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall Recall</td>
+      <td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
+        \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall F1-measure</td>
+      <td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}{PPV + TPR}\right)$</td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell) = \frac{TP}{TP + FP} =
+        \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
+        {\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P} =
+        \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
+        {\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>F-measure by label</td>
+      <td>$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}
+        {\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Weighted precision</td>
+      <td>$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell)
+        \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted recall</td>
+      <td>$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell)
+        \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted F-measure</td>
+      <td>$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell)
+        \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+  </tbody>
+</table>
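+
+To make the label-based definitions above concrete, the following standalone Python sketch (no Spark required; the
+label and prediction vectors are made up purely for illustration) evaluates a few of the formulas directly:
+
+{% highlight python %}
+from collections import Counter
+
+# Illustrative true labels and predictions for a 3-class problem
+y_true = [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0]
+y_pred = [0.0, 2.0, 2.0, 0.0, 0.0, 2.0, 1.0]
+N = len(y_true)
+labels = sorted(set(y_true))
+
+# Overall precision/recall: fraction of samples predicted correctly
+overall = sum(1.0 for t, p in zip(y_true, y_pred) if t == p) / N
+
+# Precision by label: TP / (TP + FP), with a guard for labels that are never predicted
+def precision(l):
+    predicted_l = [t for t, p in zip(y_true, y_pred) if p == l]
+    return sum(1.0 for t in predicted_l if t == l) / max(len(predicted_l), 1)
+
+# Weighted precision: precision(l) weighted by the frequency of l among the true labels
+counts = Counter(y_true)
+weighted_precision = sum(precision(l) * counts[l] for l in labels) / N
+
+print "Overall precision = %s" % overall
+print "Precision(1.0) = %s" % precision(1.0)
+print "Weighted precision = %s" % weighted_precision
+{% endhighlight %}
+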
+ +**Examples** + +
+The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on +the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new MulticlassMetrics(predictionAndLabels) + +// Confusion matrix +println("Confusion matrix:") +println(metrics.confusionMatrix) + +// Overall Statistics +val precision = metrics.precision +val recall = metrics.recall // same as true positive rate +val f1Score = metrics.fMeasure +println("Summary Statistics") +println(s"Precision = $precision") +println(s"Recall = $recall") +println(s"F1 Score = $f1Score") + +// Precision by label +val labels = metrics.labels +labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) +} + +// Recall by label +labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) +} + +// False positive rate by label +labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) +} + +// F-measure by label +labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) +} + +// Weighted stats +println(s"Weighted precision: ${metrics.weightedPrecision}") +println(s"Weighted recall: ${metrics.weightedRecall}") +println(s"Weighted F1 score: ${metrics.weightedFMeasure}") +println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.MulticlassMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class MulticlassClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics");
+    SparkContext sc = new SparkContext(conf);
+    String path = "data/mllib/sample_multiclass_classification_data.txt";
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
+
+    // Split initial RDD into two... [60% training data, 40% testing data].
+    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
+    JavaRDD<LabeledPoint> training = splits[0].cache();
+    JavaRDD<LabeledPoint> test = splits[1];
+
+    // Run training algorithm to build the model.
+    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+      .setNumClasses(3)
+      .run(training.rdd());
+
+    // Compute raw scores on the test set.
+    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint p) {
+          Double prediction = model.predict(p.features());
+          return new Tuple2<Object, Object>(prediction, p.label());
+        }
+      }
+    );
+
+    // Get evaluation metrics.
+    MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
+
+    // Confusion matrix
+    Matrix confusion = metrics.confusionMatrix();
+    System.out.println("Confusion matrix: \n" + confusion);
+
+    // Overall statistics
+    System.out.println("Precision = " + metrics.precision());
+    System.out.println("Recall = " + metrics.recall());
+    System.out.println("F1 Score = " + metrics.fMeasure());
+
+    // Stats by labels
+    for (int i = 0; i < metrics.labels().length; i++) {
+      System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+      System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+      System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i]));
+    }
+
+    // Weighted stats
+    System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
+    System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
+    System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
+    System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());
+
+    // Save and load model
+    model.save(sc, "myModelPath");
+    LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = MulticlassMetrics(predictionAndLabels) + +# Overall statistics +precision = metrics.precision() +recall = metrics.recall() +f1Score = metrics.fMeasure() +print "Summary Stats" +print "Precision = %s" % precision +print "Recall = %s" % recall +print "F1 Score = %s" % f1Score + +# Statistics by class +labels = data.map(lambda lp: lp.label).distinct().collect() +for label in sorted(labels): + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + +# Weighted stats +print "Weighted recall = %s" % metrics.weightedRecall +print "Weighted precision = %s" % metrics.weightedPrecision +print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() +print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) +print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +{% endhighlight %} + +
+
+
+
+### Multilabel classification
+
+A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping
+each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not
+mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both
+science and politics.
+
+Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather
+than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to
+operations on sets. For example, a true positive for a given class now occurs when, for a specific data point, that
+class exists in both the predicted set and the true label set.
+
+**Available metrics**
+
+Here we define a set $D$ of $N$ documents
+
+$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$
+to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that
+correspond to document $d_i$.
+
+The set of all unique labels is given by
+
+$$L = \bigcup_{k=0}^{N-1} L_k$$
+
+The following definition of indicator function $I_A(x)$ on a set $A$ will be necessary
+
+$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Recall</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Accuracy</td>
+      <td>
+        $\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|}
+        {\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$
+      </td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell)=\frac{TP}{TP + FP}=
+        \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}
+        {\sum_{i=0}^{N-1} I_{P_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P}=
+        \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}
+        {\sum_{i=0}^{N-1} I_{L_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>F1-measure by label</td>
+      <td>$F1(\ell) = 2
+        \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}
+        {PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Hamming Loss</td>
+      <td>
+        $\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left|L_i\right| + \left|P_i\right| - 2\left|L_i
+        \cap P_i\right|$
+      </td>
+    </tr>
+    <tr>
+      <td>Subset Accuracy</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$</td>
+    </tr>
+    <tr>
+      <td>F1 Measure</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} 2 \frac{\left|P_i \cap L_i\right|}{\left|P_i\right| + \left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro precision</td>
+      <td>$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
+        {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro recall</td>
+      <td>$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
+        {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro F1 Measure</td>
+      <td>
+        $2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot
+        \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1}
+        \left|P_i - L_i\right|}$
+      </td>
+    </tr>
+  </tbody>
+</table>
+
+**Examples**
+
+The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples
+use the fake prediction and label data for multilabel classification that is shown below; a short hand-computed
+check of a few of these metrics follows the lists.
+
+Document predictions:
+
+* doc 0 - predict 0, 1 - class 0, 2
+* doc 1 - predict 0, 2 - class 0, 1
+* doc 2 - predict none - class 0
+* doc 3 - predict 2 - class 2
+* doc 4 - predict 2, 0 - class 2, 0
+* doc 5 - predict 0, 1, 2 - class 0, 1
+* doc 6 - predict 1 - class 1, 2
+
+Predicted classes:
+
+* class 0 - doc 0, 1, 4, 5 (total 4)
+* class 1 - doc 0, 5, 6 (total 3)
+* class 2 - doc 1, 3, 4, 5 (total 4)
+
+True classes:
+
+* class 0 - doc 0, 1, 2, 4, 5 (total 5)
+* class 1 - doc 1, 5, 6 (total 3)
+* class 2 - doc 0, 3, 4, 6 (total 4)
+
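+As a Spark-independent sanity check, the set-based formulas above can be evaluated by hand on this toy data. The plain
+Python sketch below does so; the handling of the empty prediction set (doc 2) is one reasonable convention chosen here
+for illustration, not a statement about how MultilabelMetrics resolves it internally.
+
+{% highlight python %}
+# The toy predictions and true label sets listed above, as Python sets
+predictions = [{0, 1}, {0, 2}, set(), {2}, {2, 0}, {0, 1, 2}, {1}]
+labels      = [{0, 2}, {0, 1}, {0},   {2}, {2, 0}, {0, 1},    {1, 2}]
+N = len(labels)
+all_labels = set().union(*labels)
+
+# Document-averaged precision and recall; empty prediction sets contribute 0 to precision here
+precision = sum(len(p & l) / float(len(p)) for p, l in zip(predictions, labels) if p) / N
+recall = sum(len(p & l) / float(len(l)) for p, l in zip(predictions, labels)) / N
+
+# Hamming loss: size of the symmetric difference, normalized by N * |L|
+hamming = sum(len(l) + len(p) - 2 * len(p & l)
+              for p, l in zip(predictions, labels)) / float(N * len(all_labels))
+
+# Subset accuracy: exact set match
+subset_accuracy = sum(1.0 for p, l in zip(predictions, labels) if p == l) / N
+
+print "Precision = %s" % precision
+print "Recall = %s" % recall
+print "Hamming loss = %s" % hamming
+print "Subset accuracy = %s" % subset_accuracy
+{% endhighlight %}
+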
+ +
+ +{% highlight scala %} +import org.apache.spark.mllib.evaluation.MultilabelMetrics +import org.apache.spark.rdd.RDD; + +val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + +// Instantiate metrics object +val metrics = new MultilabelMetrics(scoreAndLabels) + +// Summary stats +println(s"Recall = ${metrics.recall}") +println(s"Precision = ${metrics.precision}") +println(s"F1 measure = ${metrics.f1Measure}") +println(s"Accuracy = ${metrics.accuracy}") + +// Individual label stats +metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}")) +metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}")) +metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}")) + +// Micro stats +println(s"Micro recall = ${metrics.microRecall}") +println(s"Micro precision = ${metrics.microPrecision}") +println(s"Micro F1 measure = ${metrics.microF1Measure}") + +// Hamming loss +println(s"Hamming loss = ${metrics.hammingLoss}") + +// Subset accuracy +println(s"Subset accuracy = ${metrics.subsetAccuracy}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.evaluation.MultilabelMetrics;
+import org.apache.spark.SparkConf;
+import java.util.Arrays;
+import java.util.List;
+
+public class MultilabelClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    List<Tuple2<double[], double[]>> data = Arrays.asList(
+      new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
+      new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
+      new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
+      new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
+      new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
+      new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
+      new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
+    );
+    JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
+
+    // Instantiate metrics object
+    MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());
+
+    // Summary stats
+    System.out.format("Recall = %f\n", metrics.recall());
+    System.out.format("Precision = %f\n", metrics.precision());
+    System.out.format("F1 measure = %f\n", metrics.f1Measure());
+    System.out.format("Accuracy = %f\n", metrics.accuracy());
+
+    // Stats by labels
+    for (int i = 0; i < metrics.labels().length; i++) {
+      System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+      System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+      System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i]));
+    }
+
+    // Micro stats
+    System.out.format("Micro recall = %f\n", metrics.microRecall());
+    System.out.format("Micro precision = %f\n", metrics.microPrecision());
+    System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());
+
+    // Hamming loss
+    System.out.format("Hamming loss = %f\n", metrics.hammingLoss());
+
+    // Subset accuracy
+    System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
+
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.evaluation import MultilabelMetrics + +scoreAndLabels = sc.parallelize([ + ([0.0, 1.0], [0.0, 2.0]), + ([0.0, 2.0], [0.0, 1.0]), + ([], [0.0]), + ([2.0], [2.0]), + ([2.0, 0.0], [2.0, 0.0]), + ([0.0, 1.0, 2.0], [0.0, 1.0]), + ([1.0], [1.0, 2.0])]) + +# Instantiate metrics object +metrics = MultilabelMetrics(scoreAndLabels) + +# Summary stats +print "Recall = %s" % metrics.recall() +print "Precision = %s" % metrics.precision() +print "F1 measure = %s" % metrics.f1Measure() +print "Accuracy = %s" % metrics.accuracy + +# Individual label stats +labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() +for label in labels: + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + +# Micro stats +print "Micro precision = %s" % metrics.microPrecision +print "Micro recall = %s" % metrics.microRecall +print "Micro F1 measure = %s" % metrics.microF1Measure + +# Hamming loss +print "Hamming loss = %s" % metrics.hammingLoss + +# Subset accuracy +print "Subset accuracy = %s" % metrics.subsetAccuracy + +{% endhighlight %} + +
+
+
+### Ranking systems
+
+The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system))
+is to return to the user a set of relevant items or documents based on some training data. The definition of relevance
+may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these
+rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth
+set of relevant documents, while other metrics may incorporate numerical ratings explicitly.
+
+**Available metrics**
+
+A ranking system usually deals with a set of $M$ users
+
+$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$
+
+Each user ($u_i$) has a set of $N$ ground truth relevant documents
+
+$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+and a list of $Q$ recommended documents, in order of decreasing relevance
+
+$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$
+
+The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the
+sets and the effectiveness of the algorithms can be measured using the metrics listed below.
+
+It is necessary to define a function which, given a recommended document and a set of ground truth relevant
+documents, returns a relevance score for the recommended document.
+
+$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th><th>Notes</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision at k</td>
+      <td>
+        $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|D_i\right|, k) - 1} rel_{D_i}(R_i(j))}$
+      </td>
+      <td>
+        Precision at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents averaged across all users. In this metric, the order of the recommendations is not taken into account.
+      </td>
+    </tr>
+    <tr>
+      <td>Mean Average Precision</td>
+      <td>
+        $MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$
+      </td>
+      <td>
+        MAP is a measure of how many of the recommended documents are in the set of true relevant documents, where the
+        order of the recommendations is taken into account (i.e. errors on highly ranked documents are penalized more
+        heavily).
+      </td>
+    </tr>
+    <tr>
+      <td>Normalized Discounted Cumulative Gain</td>
+      <td>
+        $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1}
+          \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\
+        \text{Where} \\
+        \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\
+        \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$
+      </td>
+      <td>
+        NDCG at k is a measure of how many of the first k recommended documents are in the set of true relevant
+        documents averaged across all users. In contrast to precision at k, this metric takes into account the order
+        of the recommendations (documents are assumed to be in order of decreasing relevance).
+      </td>
+    </tr>
+  </tbody>
+</table>
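+
+As a small, Spark-independent illustration of these definitions, the sketch below evaluates precision at k and NDCG at
+k for a single made-up user, using the $\text{ln}(j+2)$ discount from the table above; the recommendation list and the
+relevant-document set are purely illustrative.
+
+{% highlight python %}
+import math
+
+recommended = [12, 4, 7, 31, 9]   # ordered by decreasing predicted relevance (illustrative)
+relevant = {4, 9, 31, 17}         # ground-truth relevant documents (illustrative)
+
+def precision_at_k(rec, rel, k):
+    # Sum rel(R(j)) over j = 0 .. min(|D|, k) - 1, then divide by k
+    upper = min(len(rel), k)
+    return sum(1.0 for j in range(upper) if j < len(rec) and rec[j] in rel) / k
+
+def ndcg_at_k(rec, rel, k):
+    # n = min(max(|R|, |D|), k); discount each hit at position j by 1 / ln(j + 2)
+    n = min(max(len(rec), len(rel)), k)
+    dcg = sum(1.0 / math.log(j + 2) for j in range(n) if j < len(rec) and rec[j] in rel)
+    idcg = sum(1.0 / math.log(j + 2) for j in range(min(len(rel), k)))
+    return dcg / idcg
+
+print "Precision at 3 = %s" % precision_at_k(recommended, relevant, 3)
+print "NDCG at 3 = %s" % ndcg_at_k(recommended, relevant, 3)
+{% endhighlight %}
+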
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation
+model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the
+methodology is provided below.
+
+MovieLens ratings are on a scale of 1-5:
+
+ * 5: Must see
+ * 4: Will enjoy
+ * 3: It's okay
+ * 2: Fairly bad
+ * 1: Awful
+
+So we should not recommend a movie if the predicted rating is less than 3.
+To map ratings to confidence scores, we use:
+
+ * 5 -> 2.5
+ * 4 -> 1.5
+ * 3 -> 0.5
+ * 2 -> -0.5
+ * 1 -> -1.5.
+
+This mapping means unobserved entries are generally between "It's okay" and "Fairly bad". The semantics of 0 in this
+expanded world of non-positive weights are "the same as never having interacted at all."
+
+ +
+ +{% highlight scala %} +import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.recommendation.{ALS, Rating} + +// Read in the ratings data +val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val fields = line.split("::") + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) +}.cache() + +// Map ratings to 1 or 0, 1 indicating a movie that should be recommended +val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache() + +// Summarize ratings +val numRatings = ratings.count() +val numUsers = ratings.map(_.user).distinct().count() +val numMovies = ratings.map(_.product).distinct().count() +println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.") + +// Build the model +val numIterations = 10 +val rank = 10 +val lambda = 0.01 +val model = ALS.train(ratings, rank, numIterations, lambda) + +// Define a function to scale ratings from 0 to 1 +def scaledRating(r: Rating): Rating = { + val scaledRating = math.max(math.min(r.rating, 1.0), 0.0) + Rating(r.user, r.product, scaledRating) +} + +// Get sorted top ten predictions for each user and then scale from [0, 1] +val userRecommended = model.recommendProductsForUsers(10).map{ case (user, recs) => + (user, recs.map(scaledRating)) +} + +// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document +// Compare with top ten most relevant documents +val userMovies = binarizedRatings.groupBy(_.user) +val relevantDocuments = userMovies.join(userRecommended).map{ case (user, (actual, predictions)) => + (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray) +} + +// Instantiate metrics object +val metrics = new RankingMetrics(relevantDocuments) + +// Precision at K +Array(1, 3, 5).foreach{ k => + println(s"Precision at $k = ${metrics.precisionAt(k)}") +} + +// Mean average precision +println(s"Mean average precision = ${metrics.meanAveragePrecision}") + +// Normalized discounted cumulative gain +Array(1, 3, 5).foreach{ k => + println(s"NDCG at $k = ${metrics.ndcgAt(k)}") +} + +// Get predictions for each data point +val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating)) +val allRatings = ratings.map(r => ((r.user, r.product), r.rating)) +val predictionsAndLabels = allPredictions.join(allRatings).map{ case ((user, product), (predicted, actual)) => + (predicted, actual) +} + +// Get the RMSE using regression metrics +val regressionMetrics = new RegressionMetrics(predictionsAndLabels) +println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${regressionMetrics.r2}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.Function;
+import java.util.*;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.mllib.evaluation.RankingMetrics;
+import org.apache.spark.mllib.recommendation.ALS;
+import org.apache.spark.mllib.recommendation.Rating;
+
+public class Ranking {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Ranking Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Read in the ratings data
+    String path = "data/mllib/sample_movielens_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Rating> ratings = data.map(
+      new Function<String, Rating>() {
+        public Rating call(String line) {
+          String[] parts = line.split("::");
+          return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]),
+            Double.parseDouble(parts[2]) - 2.5);
+        }
+      }
+    );
+    ratings.cache();
+
+    // Train an ALS model
+    final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01);
+
+    // Get top 10 recommendations for every user and scale ratings from 0 to 1
+    JavaRDD<Tuple2<Object, Rating[]>> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
+    JavaRDD<Tuple2<Object, Rating[]>> userRecsScaled = userRecs.map(
+      new Function<Tuple2<Object, Rating[]>, Tuple2<Object, Rating[]>>() {
+        public Tuple2<Object, Rating[]> call(Tuple2<Object, Rating[]> t) {
+          Rating[] scaledRatings = new Rating[t._2().length];
+          for (int i = 0; i < scaledRatings.length; i++) {
+            double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
+            scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
+          }
+          return new Tuple2<Object, Rating[]>(t._1(), scaledRatings);
+        }
+      }
+    );
+    JavaPairRDD<Object, Rating[]> userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);
+
+    // Map ratings to 1 or 0, 1 indicating a movie that should be recommended
+    JavaRDD<Rating> binarizedRatings = ratings.map(
+      new Function<Rating, Rating>() {
+        public Rating call(Rating r) {
+          double binaryRating;
+          if (r.rating() > 0.0) {
+            binaryRating = 1.0;
+          } else {
+            binaryRating = 0.0;
+          }
+          return new Rating(r.user(), r.product(), binaryRating);
+        }
+      }
+    );
+
+    // Group ratings by common user
+    JavaPairRDD<Object, Iterable<Rating>> userMovies = binarizedRatings.groupBy(
+      new Function<Rating, Object>() {
+        public Object call(Rating r) {
+          return r.user();
+        }
+      }
+    );
+
+    // Get true relevant documents from all user ratings
+    JavaPairRDD<Object, List<Integer>> userMoviesList = userMovies.mapValues(
+      new Function<Iterable<Rating>, List<Integer>>() {
+        public List<Integer> call(Iterable<Rating> docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            if (r.rating() > 0.0) {
+              products.add(r.product());
+            }
+          }
+          return products;
+        }
+      }
+    );
+
+    // Extract the product id from each recommendation
+    JavaPairRDD<Object, List<Integer>> userRecommendedList = userRecommended.mapValues(
+      new Function<Rating[], List<Integer>>() {
+        public List<Integer> call(Rating[] docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            products.add(r.product());
+          }
+          return products;
+        }
+      }
+    );
+    JavaRDD<Tuple2<List<Integer>, List<Integer>>> relevantDocs =
+      userMoviesList.join(userRecommendedList).values();
+
+    // Instantiate the metrics object
+    RankingMetrics metrics = RankingMetrics.of(relevantDocs);
+
+    // Precision and NDCG at k
+    Integer[] kVector = {1, 3, 5};
+    for (Integer k : kVector) {
+      System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
+      System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+    }
+
+    // Mean average precision
+    System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision());
+
+    // Evaluate the model using numerical ratings and regression metrics
+    JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
+      new Function<Rating, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(Rating r) {
+          return new Tuple2<Object, Object>(r.user(), r.product());
+        }
+      }
+    );
+    JavaPairRDD<Tuple2<Integer, Integer>, Object> predictions = JavaPairRDD.fromJavaRDD(
+      model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r) {
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      ));
+    JavaRDD<Tuple2<Object, Object>> ratesAndPreds =
+      JavaPairRDD.fromJavaRDD(ratings.map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r) {
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      )).join(predictions).values();
+
+    // Create regression metrics object
+    RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
+
+    // Root mean squared error
+    System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R-squared = %f\n", regressionMetrics.r2());
+  }
+}
+
+{% endhighlight %}
+
+ +
+
+{% highlight python %}
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+
+# Read in the ratings data
+lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+def parseLine(line):
+    fields = line.split("::")
+    return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+ratings = lines.map(lambda r: parseLine(r))
+
+# Train a model to predict user-product ratings
+model = ALS.train(ratings, 10, 10, 0.01)
+
+# Get predicted ratings on all existing user-product pairs
+testData = ratings.map(lambda p: (p.user, p.product))
+predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+# Instantiate regression metrics to compare predicted and actual ratings
+metrics = RegressionMetrics(scoreAndLabels)
+
+# Root mean squared error
+print "RMSE = %s" % metrics.rootMeanSquaredError
+
+# R-squared
+print "R-squared = %s" % metrics.r2
+
+{% endhighlight %}
+
+
+
+## Regression model evaluation
+
+[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output
+variable from a number of independent variables.
+
+**Available metrics**
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Mean Squared Error (MSE)</td>
+      <td>$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$</td>
+    </tr>
+    <tr>
+      <td>Root Mean Squared Error (RMSE)</td>
+      <td>$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$</td>
+    </tr>
+    <tr>
+      <td>Mean Absolute Error (MAE)</td>
+      <td>$MAE = \frac{1}{N}\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$</td>
+    </tr>
+    <tr>
+      <td>Coefficient of Determination $(R^2)$</td>
+      <td>$R^2 = 1-\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}
+        {\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$</td>
+    </tr>
+    <tr>
+      <td>Explained Variance</td>
+      <td>$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$</td>
+    </tr>
+  </tbody>
+</table>
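+
+The formulas above can be checked by hand on a tiny, made-up dataset; the following standalone Python sketch
+(no Spark required, values chosen purely for illustration) evaluates each of them directly:
+
+{% highlight python %}
+import math
+
+y     = [3.0, -0.5, 2.0, 7.0]   # true values (illustrative only)
+y_hat = [2.5, 0.0, 2.0, 8.0]    # predicted values (illustrative only)
+
+N = len(y)
+errors = [yi - pi for yi, pi in zip(y, y_hat)]
+
+mse = sum(e * e for e in errors) / N
+rmse = math.sqrt(mse)
+mae = sum(abs(e) for e in errors) / N
+
+# R-squared: 1 - residual sum of squares / total sum of squares
+y_mean = sum(y) / N
+ss_tot = sum((yi - y_mean) ** 2 for yi in y)
+r2 = 1 - sum(e * e for e in errors) / ss_tot
+
+# Explained variance: 1 - Var(y - y_hat) / Var(y), using population variances
+mean_err = sum(errors) / N
+var_err = sum((e - mean_err) ** 2 for e in errors) / N
+explained_variance = 1 - var_err / (ss_tot / N)
+
+print "MSE = %s, RMSE = %s, MAE = %s" % (mse, rmse, mae)
+print "R-squared = %s, Explained variance = %s" % (r2, explained_variance)
+{% endhighlight %}
+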
+ +**Examples** + +
+The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data, +and evaluate the performance of the algorithm by several regression metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.util.MLUtils + +// Load the data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + +// Build the model +val numIterations = 100 +val model = LinearRegressionWithSGD.train(data, numIterations) + +// Get predictions +val valuesAndPreds = data.map{ point => + val prediction = model.predict(point.features) + (prediction, point.label) +} + +// Instantiate metrics object +val metrics = new RegressionMetrics(valuesAndPreds) + +// Squared error +println(s"MSE = ${metrics.meanSquaredError}") +println(s"RMSE = ${metrics.rootMeanSquaredError}") + +// R-squared +println(s"R-squared = ${metrics.r2}") + +// Mean absolute error +println(s"MAE = ${metrics.meanAbsoluteError}") + +// Explained variance +println(s"Explained variance = ${metrics.explainedVariance}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.regression.LinearRegressionModel;
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.SparkConf;
+
+public class LinearRegression {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/sample_linear_regression_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<LabeledPoint> parsedData = data.map(
+      new Function<String, LabeledPoint>() {
+        public LabeledPoint call(String line) {
+          String[] parts = line.split(" ");
+          double[] v = new double[parts.length - 1];
+          for (int i = 1; i < parts.length; i++)
+            v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
+          return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+        }
+      }
+    );
+    parsedData.cache();
+
+    // Building the model
+    int numIterations = 100;
+    final LinearRegressionModel model =
+      LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
+
+    // Evaluate model on training examples and compute training error
+    JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint point) {
+          double prediction = model.predict(point.features());
+          return new Tuple2<Object, Object>(prediction, point.label());
+        }
+      }
+    );
+
+    // Instantiate metrics object
+    RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());
+
+    // Squared error
+    System.out.format("MSE = %f\n", metrics.meanSquaredError());
+    System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R Squared = %f\n", metrics.r2());
+
+    // Mean absolute error
+    System.out.format("MAE = %f\n", metrics.meanAbsoluteError());
+
+    // Explained variance
+    System.out.format("Explained Variance = %f\n", metrics.explainedVariance());
+
+    // Save and load model
+    model.save(sc.sc(), "myModelPath");
+    LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.evaluation import RegressionMetrics +from pyspark.mllib.linalg import DenseVector + +# Load and parse the data +def parsePoint(line): + values = line.split() + return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]])) + +data = sc.textFile("data/mllib/sample_linear_regression_data.txt") +parsedData = data.map(parsePoint) + +# Build the model +model = LinearRegressionWithSGD.train(parsedData) + +# Get predictions +valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label)) + +# Instantiate metrics object +metrics = RegressionMetrics(valuesAndPreds) + +# Squared Error +print "MSE = %s" % metrics.meanSquaredError +print "RMSE = %s" % metrics.rootMeanSquaredError + +# R-squared +print "R-squared = %s" % metrics.r2 + +# Mean absolute error +print "MAE = %s" % metrics.meanAbsoluteError + +# Explained variance +print "Explained variance = %s" % metrics.explainedVariance + +{% endhighlight %} + +
+
\ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index d2d1cc93fe006..eea864eacf7c4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -48,6 +48,7 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) * FP-growth +* [Evaluation Metrics](mllib-evaluation-metrics.html) * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index de22ab557cacf..cac08a91b97d9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -68,9 +68,9 @@ In YARN terminology, executors and application masters run inside "containers". yarn logs -applicationId -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5838bc172fe86..95945eb7fc8a0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1332,13 +1332,8 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.filterPushdown - false - - Turn on Parquet filter pushdown optimization. 
This feature is turned off by default because of a known - bug in Parquet 1.6.0rc3 (PARQUET-136). - However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn - this feature on. - + true + Enables Parquet filter push-down optimization when set to true. spark.sql.hive.convertMetastoreParquet diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 7c83d68e7993e..ccf922d9371fb 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -242,7 +242,7 @@ def parse_args(): help="Number of EBS volumes to attach to each node as /vol[x]. " + "The volumes will be deleted when the instances terminate. " + "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0." + + "EBS volumes are only attached if --ebs-vol-size > 0. " + "Only support up to 8 EBS volumes.") parser.add_option( "--placement-group", type="string", default=None, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala new file mode 100644 index 0000000000000..8f144a4d974a8 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConversions._ +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark._ +import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.NextIterator + + +/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. 
*/ +private[kinesis] +case class SequenceNumberRange( + streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + +/** Class representing an array of Kinesis sequence number ranges */ +private[kinesis] +case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) { + def isEmpty(): Boolean = ranges.isEmpty + def nonEmpty(): Boolean = ranges.nonEmpty + override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")") +} + +private[kinesis] +object SequenceNumberRanges { + def apply(range: SequenceNumberRange): SequenceNumberRanges = { + new SequenceNumberRanges(Array(range)) + } +} + + +/** Partition storing the information of the ranges of Kinesis sequence numbers to read */ +private[kinesis] +class KinesisBackedBlockRDDPartition( + idx: Int, + blockId: BlockId, + val isBlockIdValid: Boolean, + val seqNumberRanges: SequenceNumberRanges + ) extends BlockRDDPartition(blockId, idx) + +/** + * A BlockRDD where the block data is backed by Kinesis, which can accessed using the + * sequence numbers of the corresponding blocks. + */ +private[kinesis] +class KinesisBackedBlockRDD( + sc: SparkContext, + regionId: String, + endpointUrl: String, + @transient blockIds: Array[BlockId], + @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges], + @transient isBlockIdValid: Array[Boolean] = Array.empty, + retryTimeoutMs: Int = 10000, + awsCredentialsOption: Option[SerializableAWSCredentials] = None + ) extends BlockRDD[Array[Byte]](sc, blockIds) { + + require(blockIds.length == arrayOfseqNumberRanges.length, + "Number of blockIds is not equal to the number of sequence number ranges") + + override def isValid(): Boolean = true + + override def getPartitions: Array[Partition] = { + Array.tabulate(blockIds.length) { i => + val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) + new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val blockManager = SparkEnv.get.blockManager + val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition] + val blockId = partition.blockId + + def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = { + logDebug(s"Read partition data of $this from block manager, block $blockId") + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]]) + } + + def getBlockFromKinesis(): Iterator[Array[Byte]] = { + val credenentials = awsCredentialsOption.getOrElse { + new DefaultAWSCredentialsProviderChain().getCredentials() + } + partition.seqNumberRanges.ranges.iterator.flatMap { range => + new KinesisSequenceRangeIterator( + credenentials, endpointUrl, regionId, range, retryTimeoutMs) + } + } + if (partition.isBlockIdValid) { + getBlockFromBlockManager().getOrElse { getBlockFromKinesis() } + } else { + getBlockFromKinesis() + } + } +} + + +/** + * An iterator that return the Kinesis data based on the given range of sequence numbers. + * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber, + * until the endSequenceNumber is reached. 
+ */ +private[kinesis] +class KinesisSequenceRangeIterator( + credentials: AWSCredentials, + endpointUrl: String, + regionId: String, + range: SequenceNumberRange, + retryTimeoutMs: Int + ) extends NextIterator[Array[Byte]] with Logging { + + private val client = new AmazonKinesisClient(credentials) + private val streamName = range.streamName + private val shardId = range.shardId + + private var toSeqNumberReceived = false + private var lastSeqNumber: String = null + private var internalIterator: Iterator[Record] = null + + client.setEndpoint(endpointUrl, "kinesis", regionId) + + override protected def getNext(): Array[Byte] = { + var nextBytes: Array[Byte] = null + if (toSeqNumberReceived) { + finished = true + } else { + + if (internalIterator == null) { + + // If the internal iterator has not been initialized, + // then fetch records from starting sequence number + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + } else if (!internalIterator.hasNext) { + + // If the internal iterator does not have any more records, + // then fetch more records after the last consumed sequence number + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + } + + if (!internalIterator.hasNext) { + + // If the internal iterator still does not have any data, then throw exception + // and terminate this iterator + finished = true + throw new SparkException( + s"Could not read until the end sequence number of the range: $range") + } else { + + // Get the record, copy the data into a byte array and remember its sequence number + val nextRecord: Record = internalIterator.next() + val byteBuffer = nextRecord.getData() + nextBytes = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(nextBytes) + lastSeqNumber = nextRecord.getSequenceNumber() + + // If the this record's sequence number matches the stopping sequence number, then make sure + // the iterator is marked finished next time getNext() is called + if (nextRecord.getSequenceNumber == range.toSeqNumber) { + toSeqNumberReceived = true + } + } + + } + nextBytes + } + + override protected def close(): Unit = { + client.shutdown() + } + + /** + * Get records starting from or after the given sequence number. + */ + private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + val shardIterator = getKinesisIterator(iteratorType, seqNum) + val result = getRecordsAndNextKinesisIterator(shardIterator) + result._1 + } + + /** + * Get the records starting from using a Kinesis shard iterator (which is a progress handle + * to get records from Kinesis), and get the next shard iterator for next consumption. + */ + private def getRecordsAndNextKinesisIterator( + shardIterator: String): (Iterator[Record], String) = { + val getRecordsRequest = new GetRecordsRequest + getRecordsRequest.setRequestCredentials(credentials) + getRecordsRequest.setShardIterator(shardIterator) + val getRecordsResult = retryOrTimeout[GetRecordsResult]( + s"getting records using shard iterator") { + client.getRecords(getRecordsRequest) + } + (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator) + } + + /** + * Get the Kinesis shard iterator for getting records starting from or after the given + * sequence number. 
+ */ + private def getKinesisIterator( + iteratorType: ShardIteratorType, + sequenceNumber: String): String = { + val getShardIteratorRequest = new GetShardIteratorRequest + getShardIteratorRequest.setRequestCredentials(credentials) + getShardIteratorRequest.setStreamName(streamName) + getShardIteratorRequest.setShardId(shardId) + getShardIteratorRequest.setShardIteratorType(iteratorType.toString) + getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) + val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult]( + s"getting shard iterator from sequence number $sequenceNumber") { + client.getShardIterator(getShardIteratorRequest) + } + getShardIteratorResult.getShardIterator + } + + /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ + private def retryOrTimeout[T](message: String)(body: => T): T = { + import KinesisSequenceRangeIterator._ + + var startTimeMs = System.currentTimeMillis() + var retryCount = 0 + var waitTimeMs = MIN_RETRY_WAIT_TIME_MS + var result: Option[T] = None + var lastError: Throwable = null + + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs + def isMaxRetryDone = retryCount >= MAX_RETRIES + + while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { + if (retryCount > 0) { // wait only if this is a retry + Thread.sleep(waitTimeMs) + waitTimeMs *= 2 // if you have waited, then double wait time for next round + } + try { + result = Some(body) + } catch { + case NonFatal(t) => + lastError = t + t match { + case ptee: ProvisionedThroughputExceededException => + logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) + case e: Throwable => + throw new SparkException(s"Error while $message", e) + } + } + retryCount += 1 + } + result.getOrElse { + if (isTimedOut) { + throw new SparkException( + s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + } else { + throw new SparkException( + s"Gave up after $retryCount retries while $message, last exception: ", lastError) + } + } + } +} + +private[streaming] +object KinesisSequenceRangeIterator { + val MAX_RETRIES = 3 + val MIN_RETRY_WAIT_TIME_MS = 100 +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index f6bf552e6bb8e..ca39358b75cb6 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -53,6 +53,8 @@ private class KinesisTestUtils( @volatile private var streamCreated = false + + @volatile private var _streamName: String = _ private lazy val kinesisClient = { @@ -115,21 +117,9 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } - def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = { - try { - val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) - val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() - Some(desc) - } catch { - case rnfe: ResourceNotFoundException => - None - } - } - def deleteStream(): Unit = { try { - if (describeStream().nonEmpty) { - val deleteStreamRequest = new DeleteStreamRequest() + if (streamCreated) { kinesisClient.deleteStream(streamName) } } catch { @@ -149,6 +139,17 @@ private class KinesisTestUtils( } } + private def 
describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + private def findNonExistentStreamName(): String = { var testStreamName: String = null do { @@ -177,7 +178,7 @@ private class KinesisTestUtils( private[kinesis] object KinesisTestUtils { - val envVarName = "RUN_KINESIS_TESTS" + val envVarName = "ENABLE_KINESIS_TESTS" val shouldRunTests = sys.env.get(envVarName) == Some("1") diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala new file mode 100644 index 0000000000000..e81fb11e5959f --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kinesis + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.{SparkConf, SparkContext, SparkException} + +class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { + + private val regionId = "us-east-1" + private val endpointUrl = "https://kinesis.us-east-1.amazonaws.com" + private val testData = 1 to 8 + + private var testUtils: KinesisTestUtils = null + private var shardIds: Seq[String] = null + private var shardIdToData: Map[String, Seq[Int]] = null + private var shardIdToSeqNumbers: Map[String, Seq[String]] = null + private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null + private var shardIdToRange: Map[String, SequenceNumberRange] = null + private var allRanges: Seq[SequenceNumberRange] = null + + private var sc: SparkContext = null + private var blockManager: BlockManager = null + + + override def beforeAll(): Unit = { + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils(endpointUrl) + testUtils.createStream() + + shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") + + shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq + shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }} + shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} + shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => + val seqNumRange = SequenceNumberRange( + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + (shardId, seqNumRange) + } + allRanges = shardIdToRange.values.toSeq + + val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") + sc = new SparkContext(conf) + blockManager = sc.env.blockManager + } + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.deleteStream() + } + if (sc != null) { + sc.stop() + } + } + + testIfEnabled("Basic reading from Kinesis") { + // Verify all data using multiple ranges in a single RDD partition + val receivedData1 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(1), + Array(SequenceNumberRanges(allRanges.toArray)) + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData1.toSet === testData.toSet) + + // Verify all data using one range in each of the multiple RDD partitions + val receivedData2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData2.toSet === testData.toSet) + + // Verify ordering within each partition + val receivedData3 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collectPartitions() + assert(receivedData3.length === allRanges.size) + for (i <- 0 until allRanges.size) { + assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId)) + } + } + + testIfEnabled("Read data available in both block manager and Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available only in block manager, not in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, 
numPartitionsInKinesis = 0) + } + + testIfEnabled("Read data available only in Kinesis, not in block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available partially in block manager, rest in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1) + } + + testIfEnabled("Test isBlockValid skips block fetching from block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0, + testIsBlockValid = true) + } + + testIfEnabled("Test whether RDD is valid after removing blocks from block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2, + testBlockRemove = true) + } + + /** + * Test the KinesisBackedBlockRDD by writing some partitions of the data to the block manager + * and the rest to Kinesis, and then reading it all back using the RDD. + * It can also test whether the RDD remains readable after its blocks are removed from the + * block manager. + * + * @param numPartitions Number of partitions in RDD + * @param numPartitionsInBM Number of partitions to write to the BlockManager. + * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager + * @param numPartitionsInKinesis Number of partitions to write to Kinesis. + * Partitions (numPartitions - 1 - numPartitionsInKinesis) to + * (numPartitions - 1) will be written to Kinesis + * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching + * @param testBlockRemove Test whether calling rdd.removeBlocks() makes the RDD still usable with + * reads falling back to Kinesis + * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInKinesis = 4 + * + * numPartitionsInBM = 3 + * |------------------| + * | | + * 0 1 2 3 4 + * | | + * |-------------------------| + * numPartitionsInKinesis = 4 + */ + private def testRDD( + numPartitions: Int, + numPartitionsInBM: Int, + numPartitionsInKinesis: Int, + testIsBlockValid: Boolean = false, + testBlockRemove: Boolean = false + ): Unit = { + require(shardIds.size > 1, "Need at least 2 shards to test") + require(numPartitionsInBM <= shardIds.size, + "Number of partitions in BlockManager cannot be more than the Kinesis test shards available") + require(numPartitionsInKinesis <= shardIds.size, + "Number of partitions in Kinesis cannot be more than the Kinesis test shards available") + require(numPartitionsInBM <= numPartitions, + "Number of partitions in BlockManager cannot be more than that in RDD") + require(numPartitionsInKinesis <= numPartitions, + "Number of partitions in Kinesis cannot be more than that in RDD") + + // Put necessary blocks in the block manager + val blockIds = fakeBlockIds(numPartitions) + blockIds.foreach(blockManager.removeBlock(_)) + (0 until numPartitionsInBM).foreach { i => + val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() } + blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY) + } + + // Create the necessary ranges to use in the RDD + val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + val realRanges = Array.tabulate(numPartitionsInKinesis) { i => + val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) + SequenceNumberRanges(Array(range)) + } + val ranges = (fakeRanges ++ realRanges) + + // Make sure that the first
`numPartitionsInBM` blocks are in block manager, and others are not + require( + blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty), + "Expected blocks not in BlockManager" + ) + + require( + blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty), + "Unexpected blocks in BlockManager" + ) + + // Make sure that the last `numPartitionsInKinesis` ranges point to the real stream, and others do not + require( + ranges.takeRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName == testUtils.streamName } + }, "Incorrect configuration of RDD, expected ranges not set" + ) + + require( + ranges.dropRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName != testUtils.streamName } + }, "Incorrect configuration of RDD, unexpected ranges set" + ) + + val rdd = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds, ranges) + val collectedData = rdd.map { bytes => + new String(bytes).toInt + }.collect() + assert(collectedData.toSet === testData.toSet) + + // Verify that block fetching is skipped when isBlockValid is set to false. + // This is done by using an RDD whose data is only in memory but which is set to skip block + // fetching. Using that RDD will throw an exception, as it skips block fetching even if the + // blocks are in the BlockManager. + if (testIsBlockValid) { + require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") + require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") + val rdd2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds.toArray, + ranges, isBlockIdValid = Array.fill(blockIds.length)(false)) + intercept[SparkException] { + rdd2.collect() + } + } + + // Verify that the RDD is still valid after the blocks are removed and can still read data + // from Kinesis + if (testBlockRemove) { + require(numPartitions === numPartitionsInKinesis, + "All partitions must be in Kinesis for this test") + require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test") + rdd.removeBlocks() + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet) + } + } + + /** Generate fake block ids */ + private def fakeBlockIds(num: Int): Array[BlockId] = { + Array.tabulate(num) { i => new StreamBlockId(0, i) } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala index 6d011f295e7f7..8373138785a89 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -23,15 +23,24 @@ import org.apache.spark.SparkFunSuite * Helper trait that runs Kinesis real data transfer tests or * ignores them based on whether the environment variable is set.
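 *
 * A hypothetical usage sketch (the suite and test names are illustrative only):
 * {{{
 *   class MyKinesisSuite extends KinesisFunSuite {
 *     testIfEnabled("push and read one record") {
 *       // runs only when the environment variable ENABLE_KINESIS_TESTS=1 is set
 *     }
 *   }
 * }}}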
*/ -trait KinesisSuiteHelper { self: SparkFunSuite => +trait KinesisFunSuite extends SparkFunSuite { import KinesisTestUtils._ /** Run the test if the environment variable is set, otherwise ignore it */ - def testOrIgnore(testName: String)(testBody: => Unit) { + def testIfEnabled(testName: String)(testBody: => Unit) { if (shouldRunTests) { test(testName)(testBody) } else { ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody) } } + + /** Run the given body of code only if Kinesis tests are enabled */ + def runIfTestsEnabled(message: String)(body: => Unit): Unit = { + if (shouldRunTests) { + body + } else { + ignore(s"$message [enable by setting env var $envVarName=1]")() + } + } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 50f71413abf37..b88c9c6478d56 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper +class KinesisStreamSuite extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL uses to save metadata to DynamoDB @@ -83,16 +83,16 @@ class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper * you must have AWS credentials available through the default AWS provider chain, * and you have to set the system environment variable ENABLE_KINESIS_TESTS=1. */ - testOrIgnore("basic operation") { + testIfEnabled("basic operation") { val kinesisTestUtils = new KinesisTestUtils() try { kinesisTestUtils.createStream() ssc = new StreamingContext(sc, Seconds(1)) - val aWSCredentials = KinesisTestUtils.getAWSCredentials() + val awsCredentials = KinesisTestUtils.getAWSCredentials() val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey) + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
Vertices that didn't get any messages don't appear in newVerts, so don't - // get to send messages. We must cache messages so it can be materialized on the next line, - // allowing us to uncache the previous iteration. - messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache() - // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This - // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the - // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g). + // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages + // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages + // and the vertices of g). activeMessages = messages.count() logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking = false) - newVerts.unpersist(blocking = false) prevG.unpersistVertices(blocking = false) prevG.edges.unpersist(blocking = false) // count the iteration diff --git a/make-distribution.sh b/make-distribution.sh index cac7032bb2e87..4789b0e09cc8a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.6.4" +TACHYON_VERSION="0.7.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 85c097bc64a4f..581d8fa7749be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -156,5 +156,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ - protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax + protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index ea757c5e40c76..1741f19dc911c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams { /** * param for the base binary classifier that we reduce multiclass classification into. + * The base classifier input and output columns are ignored in favor of + * the ones specified in [[OneVsRest]]. 
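+ * For example (a hypothetical sketch; the column names are illustrative only):
+ * {{{
+ *   new OneVsRest()
+ *     .setClassifier(new LogisticRegression().setLabelCol("ignored"))
+ *     .setLabelCol("target")  // the base classifier is trained against "target", not "ignored"
+ * }}}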
* @group param */ val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier") @@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String) set(classifier, value.asInstanceOf[ClassifierType]) } + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } @@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String) val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier - classifier.fit(trainingDataset, classifier.labelCol -> labelColName) + val paramMap = new ParamMap() + paramMap.put(classifier.labelCol -> labelColName) + paramMap.put(classifier.featuresCol -> getFeaturesCol) + paramMap.put(classifier.predictionCol -> getPredictionCol) + classifier.fit(trainingDataset, paramMap) }.toArray[ClassificationModel[_, _]] if (handlePersistence) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 38e832372698c..dad451108626d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -173,5 +173,5 @@ private[spark] abstract class ProbabilisticClassificationModel[ * This may be overridden to support thresholds which favor particular labels. 
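 * For example, probability2prediction(Vectors.dense(0.2, 0.5, 0.3)) returns 1.0, the index of
 * the largest probability.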
* @return predicted label */ - protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax + protected def probability2prediction(probability: Vector): Double = probability.argmax } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index fc0693f67cc2e..bc19bd6df894f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType */ @Experimental final class RandomForestClassifier(override val uid: String) - extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] + extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("rfc")) @@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - new RandomForestClassificationModel(trees) + new RandomForestClassificationModel(trees, numClasses) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -125,8 +125,9 @@ object RandomForestClassifier { @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeClassificationModel]) - extends PredictionModel[Vector, RandomForestClassificationModel] + private val _trees: Array[DecisionTreeClassificationModel], + override val numClasses: Int) + extends ClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel]) = - this(Identifiable.randomUID("rfc"), trees) + def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override protected def predictRaw(features: Vector): Vector = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. 
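    // For example, with numClasses = 3 and tree predictions (0, 2, 2), the returned raw
    // prediction vector is [1.0, 0.0, 2.0].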
// Ignore the weights since all are 1.0 for now. - val votes = mutable.Map.empty[Int, Double] + val votes = new Array[Double](numClasses) _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt - votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } - votes.maxBy(_._2)._1 + Vectors.dense(votes) } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } override def toString: String = { @@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, - categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + categoricalFeatures: Map[Int, Int], + numClasses: Int): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees) + new RandomForestClassificationModel(uid, newTrees, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 3825942795645..9c60d4084ec46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => if (nominal.values.isDefined) { - nominal.values.map(_.map(v => inputColName + is + v)) + nominal.values } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) } else { None } case binary: BinaryAttribute => if (binary.values.isDefined) { - binary.values.map(_.map(v => inputColName + is + v)) + binary.values } else { - Some(Array.tabulate(2)(i => inputColName + is + i)) + Some(Array.tabulate(2)(_.toString)) } case _: NumericAttribute => throw new RuntimeException( @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer override def transform(dataset: DataFrame): DataFrame = { // schema transformation - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) val shouldDropLast = $(dropLast) @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer math.max(m0, m1) } ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) val filtered = if (shouldDropLast) 
outputAttrNames.dropRight(1) else outputAttrNames val outputAttrs: Array[Attribute] = filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f7b46efa10e90..d1726917e4517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,17 +17,35 @@ package org.apache.spark.ml.feature +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.Transformer +import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +/** + * Base trait for [[RFormula]] and [[RFormulaModel]]. + */ +private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { + /** @group getParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group getParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + protected def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently @@ -35,8 +53,7 @@ import org.apache.spark.sql.types._ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental -class RFormula(override val uid: String) - extends Transformer with HasFeaturesCol with HasLabelCol { +class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { def this() = this(Identifiable.randomUID("rFormula")) @@ -62,19 +79,90 @@ class RFormula(override val uid: String) /** @group getParam */ def getFormula: String = $(formula) - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + /** Whether the formula specifies fitting an intercept. */ + private[ml] def hasIntercept: Boolean = { + require(parsedFormula.isDefined, "Must call setFormula() first.") + parsedFormula.get.hasIntercept + } - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) + override def fit(dataset: DataFrame): RFormulaModel = { + require(parsedFormula.isDefined, "Must call setFormula() first.") + val resolvedFormula = parsedFormula.get.resolve(dataset.schema) + // StringType terms and terms representing interactions need to be encoded before assembly. 
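+    // For example, a StringType column "species" is indexed into a temporary "species_idx_<uid>"
+    // column by StringIndexer, one-hot encoded by OneHotEncoder, and the encoded column is fed
+    // to the VectorAssembler; the temporary columns are dropped afterwards by ColumnPruner.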
+ // TODO(ekl) add support for feature interactions + val encoderStages = ArrayBuffer[PipelineStage]() + val tempColumns = ArrayBuffer[String]() + val takenNames = mutable.Set(dataset.columns: _*) + val encodedTerms = resolvedFormula.terms.map { term => + dataset.schema(term) match { + case column if column.dataType == StringType => + val indexCol = term + "_idx_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) + encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) + encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) + tempColumns += indexCol + tempColumns += encodedCol + encodedCol + case _ => + term + } + } + encoderStages += new VectorAssembler(uid) + .setInputCols(encodedTerms.toArray) + .setOutputCol($(featuresCol)) + encoderStages += new ColumnPruner(tempColumns.toSet) + val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) + copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) + } + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + if (hasLabelCol(schema)) { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+ + StructField($(labelCol), DoubleType, true)) + } + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" +} + +/** + * :: Experimental :: + * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. + * @param resolvedFormula the fitted R formula. + * @param pipelineModel the fitted feature model, including factor to index mappings. 
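+ *
+ * A hypothetical usage sketch (the formula and column names are illustrative only):
+ * {{{
+ *   val model = new RFormula().setFormula("y ~ x + species").fit(df)
+ *   val prepared = model.transform(df)  // appends the assembled "features" and "label" columns
+ * }}}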
+ */ +@Experimental +class RFormulaModel private[feature]( + override val uid: String, + resolvedFormula: ResolvedRFormula, + pipelineModel: PipelineModel) + extends Model[RFormulaModel] with RFormulaBase { + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(pipelineModel.transform(dataset)) + } override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) - val withFeatures = transformFeatures.transformSchema(schema) + val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else if (schema.exists(_.name == parsedFormula.get.label)) { - val nullable = schema(parsedFormula.get.label).dataType match { + } else if (schema.exists(_.name == resolvedFormula.label)) { + val nullable = schema(resolvedFormula.label).dataType match { case _: NumericType | BooleanType => false case _ => true } @@ -86,24 +174,19 @@ class RFormula(override val uid: String) } } - override def transform(dataset: DataFrame): DataFrame = { - checkCanTransform(dataset.schema) - transformLabel(transformFeatures.transform(dataset)) - } - - override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + override def copy(extra: ParamMap): RFormulaModel = copyValues( + new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormula(${get(formula)})" + override def toString: String = s"RFormulaModel(${resolvedFormula})" private def transformLabel(dataset: DataFrame): DataFrame = { - val labelName = parsedFormula.get.label + val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { dataset } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) - // TODO(ekl) add support for string-type labels case other => throw new IllegalArgumentException("Unsupported type for label: " + other) } @@ -114,46 +197,30 @@ class RFormula(override val uid: String) } } - private def transformFeatures: Transformer = { - // TODO(ekl) add support for non-numeric features and feature interactions - new VectorAssembler(uid) - .setInputCols(parsedFormula.get.terms.toArray) - .setOutputCol($(featuresCol)) - } - private def checkCanTransform(schema: StructType) { - require(parsedFormula.isDefined, "Must call setFormula() first.") val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, "Label column already exists and is not of type DoubleType.") } - - private def hasLabelCol(schema: StructType): Boolean = { - schema.map(_.name).contains($(labelCol)) - } } /** - * Represents a parsed R formula. + * Utility transformer for removing temporary columns from a DataFrame. + * TODO(ekl) make this a public transformer */ -private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) +private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { + override val uid = Identifiable.randomUID("columnPruner") -/** - * Limited implementation of R formula parsing. Currently supports: '~', '+'. 
- */ -private[ml] object RFormulaParser extends RegexParsers { - def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r - - def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } - - def formula: Parser[ParsedRFormula] = - (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + override def transform(dataset: DataFrame): DataFrame = { + val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) + dataset.select(columnsToKeep.map(dataset.col) : _*) + } - def parse(value: String): ParsedRFormula = parseAll(formula, value) match { - case Success(result, _) => result - case failure: NoSuccess => throw new IllegalArgumentException( - "Could not parse formula: " + value) + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name))) } + + override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala new file mode 100644 index 0000000000000..1ca3b92a7d92a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types._ + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { + /** + * Resolves formula terms into column names. A schema is necessary for inferring the meaning + * of the special '.' term. Duplicate terms will be removed during resolution. + */ + def resolve(schema: StructType): ResolvedRFormula = { + var includedTerms = Seq[String]() + terms.foreach { + case Dot => + includedTerms ++= simpleTypes(schema).filter(_ != label.value) + case ColumnRef(value) => + includedTerms :+= value + case Deletion(term: Term) => + term match { + case ColumnRef(value) => + includedTerms = includedTerms.filter(_ != value) + case Dot => + // e.g. "- .", which removes all first-order terms + val fromSchema = simpleTypes(schema) + includedTerms = includedTerms.filter(fromSchema.contains(_)) + case _: Deletion => + assert(false, "Deletion terms cannot be nested") + case _: Intercept => + } + case _: Intercept => + } + ResolvedRFormula(label.value, includedTerms.distinct) + } + + /** Whether this formula specifies fitting with an intercept term. 
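+ * For example (a sketch using the parser defined below):
+ *   RFormulaParser.parse("y ~ x").hasIntercept      // true (the default)
+ *   RFormulaParser.parse("y ~ x + 0").hasIntercept  // false ("+ 0" disables the intercept)
+ *   RFormulaParser.parse("y ~ x - 1").hasIntercept  // false ("- 1" also disables it)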
*/ + def hasIntercept: Boolean = { + var intercept = true + terms.foreach { + case Intercept(enabled) => + intercept = enabled + case Deletion(Intercept(enabled)) => + intercept = !enabled + case _ => + } + intercept + } + + // the dot operator excludes complex column types + private def simpleTypes(schema: StructType): Seq[String] = { + schema.fields.filter(_.dataType match { + case _: NumericType | StringType | BooleanType | _: VectorUDT => true + case _ => false + }).map(_.name) + } +} + +/** + * Represents a fully evaluated and simplified R formula. + */ +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) + +/** + * R formula terms. See the R formula docs here for more information: + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +private[ml] sealed trait Term + +/* R formula reference to all available columns, e.g. "." in a formula */ +private[ml] case object Dot extends Term + +/* R formula reference to a column, e.g. "+ Species" in a formula */ +private[ml] case class ColumnRef(value: String) extends Term + +/* R formula intercept toggle, e.g. "+ 0" in a formula */ +private[ml] case class Intercept(enabled: Boolean) extends Term + +/* R formula deletion of a variable, e.g. "- Species" in a formula */ +private[ml] case class Deletion(term: Term) extends Term + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'. + */ +private[ml] object RFormulaParser extends RegexParsers { + def intercept: Parser[Intercept] = + "([01])".r ^^ { case a => Intercept(a == "1") } + + def columnRef: Parser[ColumnRef] = + "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } + + def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + + def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { + case op ~ list => list.foldLeft(List(op)) { + case (left, "+" ~ right) => left ++ Seq(right) + case (left, "-" ~ right) => left ++ Seq(Deletion(right)) + } + } + + def formula: Parser[ParsedRFormula] = + (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0b3af4747e693..248288ca73e99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -50,7 +50,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S /** * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split - * the text (default) or repeatedly matching the regex (if `gaps` is true). + * the text (default) or repeatedly matching the regex (if `gaps` is false). * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. 
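 * For example (a sketch, assuming the default "\\s+" pattern): with `gaps` = true the input
 * "a  b c" is split into Seq("a", "b", "c"); with `gaps` = false and pattern "\\w+" the same
 * tokens are produced by matching instead of splitting.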
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 824efa5ed4b28..954aa17e26a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -476,11 +476,14 @@ trait Params extends Identifiable with Serializable { /** * Sets default values for a list of params. * + * Note: Java developers should use the single-parameter [[setDefault()]]. + * Annotating this with varargs can cause compilation failures due to a Scala compiler bug. + * See SPARK-9268. + * * @param paramPairs a list of param pairs that specify params and their default values to set * respectively. Make sure that the params are initialized before this method * gets called. */ - @varargs protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => setDefault(p.param.asInstanceOf[Param[Any]], p.value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 1ee080641e3e3..f5a022c31ed90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame @@ -32,10 +33,38 @@ private[r] object SparkRWrappers { alpha: Double): PipelineModel = { val formula = new RFormula().setFormula(value) val estimator = family match { - case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) - case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "gaussian" => new LinearRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) + case "binomial" => new LogisticRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) } val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } + + def getModelWeights(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.weights.toArray + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No weights available for LogisticRegressionModel") // SPARK-9492 + } + } + + def getModelFeatures(model: PipelineModel): Array[String] = { + model.stages.last match { + case m: LinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No features names available for LogisticRegressionModel") // SPARK-9492 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala new file mode 100644 index 0000000000000..4ece8cf8cf0b6 --- 
/dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DoubleType, DataType} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.storage.StorageLevel + +/** + * Params for isotonic regression. + */ +private[regression] trait IsotonicRegressionParams extends PredictorParams { + + /** + * Param for weight column name. + * TODO: Move weightCol to sharedParams. + * + * @group param + */ + final val weightCol: Param[String] = + new Param[String](this, "weightCol", "weight column name") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) + + /** + * Param for isotonic parameter. + * Isotonic (increasing) or antitonic (decreasing) sequence. + * @group param + */ + final val isotonic: BooleanParam = + new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + + /** @group getParam */ + final def getIsotonicParam: Boolean = $(isotonic) +} + +/** + * :: Experimental :: + * Isotonic regression. + * + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +@Experimental +class IsotonicRegression(override val uid: String) + extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] + with IsotonicRegressionParams { + + def this() = this(Identifiable.randomUID("isoReg")) + + /** + * Set the isotonic parameter. + * Default is true. + * @group setParam + */ + def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) + setDefault(isotonic -> true) + + /** + * Set weight column param. + * Default is weight. 
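+ * A hypothetical call (the column name is illustrative only):
+ * {{{
+ *   new IsotonicRegression().setWeightParam("obsWeight")
+ * }}}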
+ * @group setParam + */ + def setWeightParam(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "weight") + + override private[ml] def featuresDataType: DataType = DoubleType + + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + private[this] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + .map { case Row(label: Double, features: Double, weights: Double) => + (label, features, weights) + } + } + + override protected def train(dataset: DataFrame): IsotonicRegressionModel = { + SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val instances = extractWeightedLabeledPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) + val parentModel = isotonicRegression.run(instances) + + new IsotonicRegressionModel(uid, parentModel) + } +} + +/** + * :: Experimental :: + * Model fitted by IsotonicRegression. + * Predicts using a piecewise linear function. + * + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. + * + * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +class IsotonicRegressionModel private[ml] ( + override val uid: String, + private[ml] val parentModel: MLlibIsotonicRegressionModel) + extends RegressionModel[Double, IsotonicRegressionModel] + with IsotonicRegressionParams { + + override def featuresDataType: DataType = DoubleType + + override protected def predict(features: Double): Double = { + parentModel.predict(features) + } + + override def copy(extra: ParamMap): IsotonicRegressionModel = { + copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 89718e0f3e15a..3b85ba001b128 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructField import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -146,9 +147,10 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } @@ -221,9 +223,10 @@ class LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), 
$(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + val featuresCol: String, val objectiveHistory: Array[Double]) extends LinearRegressionSummary(predictions, predictionCol, labelCol) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala new file mode 100644 index 0000000000000..0ec88ef77d695 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import java.util.{List => JList} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} +import org.apache.spark.mllib.clustering.GaussianMixtureModel + +/** + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ +private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { + val weights: Vector = Vectors.dense(model.weights) + val k: Int = weights.size + + /** + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ + val gaussians: JList[Object] = { + val modelGaussians = model.gaussians + var i = 0 + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + while (i < k) { + mu += modelGaussians(i).mu + sigma += modelGaussians(i).sigma + i += 1 + } + List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fda8d5a0b048f..6f080d32bbf4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable { seed: java.lang.Long, initialModelWeights: java.util.ArrayList[Double], initialModelMu: java.util.ArrayList[Vector], - initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = { + initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) @@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != 
null) gmmAlg.setSeed(seed) try { - val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) - var wt = ArrayBuffer.empty[Double] - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - for (i <- 0 until model.k) { - wt += model.weights(i) - mu += model.gaussians(i).mu - sigma += model.gaussians(i).sigma - } - List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))) } finally { data.rdd.unpersist(blocking = false) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 920b57756b625..6cfad3fbbdb87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV} - +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} +import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -28,14 +27,13 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector} -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.BoundedPriorityQueue - /** * :: Experimental :: * @@ -53,6 +51,31 @@ abstract class LDAModel private[clustering] extends Saveable { /** Vocabulary size (number of terms or terms in the vocabulary) */ def vocabSize: Int + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a Dirichlet distribution. + */ + def docConcentration: Vector + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + */ + def topicConcentration: Double + + /** + * Shape parameter for random initialization of variational parameter gamma. + * Used for variational inference for perplexity and other test-time computations. + */ + protected def gammaShape: Double + /** * Inferred topics, where each topic is represented by a distribution over terms. * This is a matrix of size vocabSize x k, where each column is a topic. @@ -163,12 +186,14 @@ abstract class LDAModel private[clustering] extends Saveable { * This model stores only the inferred topics. 
* It may be used for computing topics for new documents, but it may give less accurate answers * than the [[DistributedLDAModel]]. - * * @param topics Inferred topics (vocabSize x k matrix). */ @Experimental class LocalLDAModel private[clustering] ( - private val topics: Matrix) extends LDAModel with Serializable { + val topics: Matrix, + override val docConcentration: Vector, + override val topicConcentration: Double, + override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { override def k: Int = topics.numCols @@ -189,16 +214,122 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" override def save(sc: SparkContext, path: String): Unit = { - LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix) + LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, + gammaShape) } // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? - // TODO: - // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + /** + * Calculate the log variational bound on perplexity. See Equation (16) in original Online + * LDA paper. + * @param documents test corpus to use for calculating perplexity + * @return the log perplexity per word + */ + def logPerplexity(documents: RDD[(Long, Vector)]): Double = { + val corpusWords = documents + .map { case (_, termCounts) => termCounts.toArray.sum } + .sum() + val batchVariationalBound = bound(documents, docConcentration, + topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) + val perWordBound = batchVariationalBound / corpusWords + + perWordBound + } + + /** + * Estimate the variational likelihood bound of from `documents`: + * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] + * This bound is derived by decomposing the LDA model to: + * log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p) + * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in original Online LDA paper. 
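+ * (Concretely, as computed below, the bound is the sum over documents of
+ * E_q[log p(w_d | theta_d, beta)] + E_q[log p(theta_d | alpha) - log q(theta_d | gamma_d)],
+ * plus the corpus-level term E_q[log p(beta | eta) - log q(beta | lambda)].)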
+ * @param documents a subset of the test corpus + * @param alpha document-topic Dirichlet prior parameters + * @param eta topic-word Dirichlet prior parameters + * @param lambda parameters for variational q(beta | lambda) topic-word distributions + * @param gammaShape shape parameter for random initialization of variational q(theta | gamma) + * topic mixture distributions + * @param k number of topics + * @param vocabSize number of unique terms in the entire test corpus + */ + private def bound( + documents: RDD[(Long, Vector)], + alpha: Vector, + eta: Double, + lambda: BDM[Double], + gammaShape: Double, + k: Int, + vocabSize: Long): Double = { + val brzAlpha = alpha.toBreeze.toDenseVector + // transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t + + var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) => + var docScore = 0.0D + val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, exp(Elogbeta), brzAlpha, gammaShape, k) + val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) + + // E[log p(doc | theta, beta)] + termCounts.foreachActive { case (idx, count) => + docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) + } + // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector + docScore += sum((brzAlpha - gammad) :* Elogthetad) + docScore += sum(lgamma(gammad) - lgamma(brzAlpha)) + docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) + + docScore + }.sum() + + // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar + score += sum((eta - lambda) :* Elogbeta) + score += sum(lgamma(lambda) - lgamma(eta)) + + val sumEta = eta * vocabSize + score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) + + score + } + + /** + * Predicts the topic mixture distribution for each document (often called "theta" in the + * literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * @param documents documents to predict topic mixture distributions for + * @return An RDD of (document ID, topic mixture distribution for document) + */ + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { + // Double transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + documents.map { case (id: Long, termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + (id, Vectors.zeros(k)) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbeta, + docConcentrationBrz, + gammaShape, + k) + (id, Vectors.dense(normalize(gamma, 1.0).toArray)) + } + } + } } + @Experimental object LocalLDAModel extends Loader[LocalLDAModel] { @@ -212,14 +343,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] { // as a Row in data. 
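    // (Each column of the vocabSize x k topicsMatrix is saved as one Data(topic, index) row.)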
case class Data(topic: Vector, index: Int) - def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = { + def save( + sc: SparkContext, + path: String, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows))) + ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix @@ -229,7 +369,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] { sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path)) } - def load(sc: SparkContext, path: String): LocalLDAModel = { + def load( + sc: SparkContext, + path: String, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): LocalLDAModel = { val dataPath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) @@ -243,7 +388,10 @@ object LocalLDAModel extends Loader[LocalLDAModel] { topics.foreach { case Row(vec: Vector, ind: Int) => brzTopics(::, ind) := vec.toBreeze } - new LocalLDAModel(Matrices.fromBreeze(brzTopics)) + val topicsMat = Matrices.fromBreeze(brzTopics) + + // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940 + new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) } } @@ -252,15 +400,19 @@ object LocalLDAModel extends Loader[LocalLDAModel] { implicit val formats = DefaultFormats val expectedK = (metadata \ "k").extract[Int] val expectedVocabSize = (metadata \ "vocabSize").extract[Int] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) + val topicConcentration = (metadata \ "topicConcentration").extract[Double] + val gammaShape = (metadata \ "gammaShape").extract[Double] val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => - SaveLoadV1_0.load(sc, path) + SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape) case _ => throw new Exception( s"LocalLDAModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $loadedVersion). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") } val topicsMatrix = model.topicsMatrix @@ -268,7 +420,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics") require(expectedVocabSize == topicsMatrix.numRows, s"LocalLDAModel requires $expectedVocabSize terms for each topic, " + - s"but got ${topicsMatrix.numRows}") + s"but got ${topicsMatrix.numRows}") model } } @@ -282,28 +434,25 @@ object LocalLDAModel extends Loader[LocalLDAModel] { * than the [[LocalLDAModel]]. 
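The save path above now records docConcentration, topicConcentration, and gammaShape in the JSON metadata, so a persisted model can be reconstructed with its inference hyperparameters intact. A minimal round-trip sketch, assuming an existing `sc: SparkContext`, a trained `localModel: LocalLDAModel`, and a purely illustrative output path:
{{{
import org.apache.spark.mllib.clustering.LocalLDAModel

// Hedged sketch: the path below is hypothetical.
val path = "/tmp/lda/localModel"
localModel.save(sc, path)

// Reads the Parquet topics matrix plus the JSON metadata written above.
val restored = LocalLDAModel.load(sc, path)
assert(restored.k == localModel.k && restored.vocabSize == localModel.vocabSize)
}}}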
*/ @Experimental -class DistributedLDAModel private ( - private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], - private val globalTopicTotals: LDA.TopicCounts, +class DistributedLDAModel private[clustering] ( + private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private[clustering] val globalTopicTotals: LDA.TopicCounts, val k: Int, val vocabSize: Int, - private val docConcentration: Double, - private val topicConcentration: Double, + override val docConcentration: Vector, + override val topicConcentration: Double, + override protected[clustering] val gammaShape: Double, private[spark] val iterationTimes: Array[Double]) extends LDAModel { import LDA._ - private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { - this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, - state.topicConcentration, iterationTimes) - } - /** * Convert model to a local model. * The local model stores the inferred topics but not the topic distributions for training * documents. */ - def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) /** * Inferred topics, where each topic is represented by a distribution over terms. @@ -375,8 +524,9 @@ class DistributedLDAModel private ( * hyperparameters. */ lazy val logLikelihood: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration assert(eta > 1.0) assert(alpha > 1.0) val N_k = globalTopicTotals @@ -400,8 +550,9 @@ class DistributedLDAModel private ( * log P(topics, topic distributions for docs | alpha, eta) */ lazy val logPrior: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration // Term vertices: Compute phi_{wk}. Use to compute prior log probability. // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. val N_k = globalTopicTotals @@ -412,12 +563,12 @@ class DistributedLDAModel private ( val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - (eta - 1.0) * brzSum(phi_wk.map(math.log)) + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + (alpha - 1.0) * sum(theta_kj.map(math.log)) } } graph.vertices.aggregate(0.0)(seqOp, _ + _) @@ -448,7 +599,7 @@ class DistributedLDAModel private ( override def save(sc: SparkContext, path: String): Unit = { DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, - iterationTimes) + iterationTimes, gammaShape) } } @@ -460,7 +611,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val thisFormatVersion = "1.0" - val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel" + val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel" // Store globalTopicTotals as a Vector. 
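Since the EM-trained model now carries the same hyperparameters as the online one, converting it with toLocal and querying the training-corpus diagnostics can be sketched as below. The corpus and the concentration values are assumptions; note that logLikelihood and logPrior assert alpha > 1 and eta > 1, hence the explicit settings.
{{{
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}

// Hedged sketch: `corpus: RDD[(Long, Vector)]` is an assumed training set.
val distModel = new LDA()
  .setK(5)
  .setOptimizer("em")
  .setDocConcentration(2.0)     // alpha > 1, required by logLikelihood/logPrior
  .setTopicConcentration(1.1)   // eta > 1
  .run(corpus)
  .asInstanceOf[DistributedLDAModel]

val ll = distModel.logLikelihood  // likelihood of the training tokens given inferred topics/mixtures
val lp = distModel.logPrior       // prior probability of the inferred topics and topic mixtures

// Drop the per-document state and keep only the topics (plus alpha, eta, gammaShape).
val localFromEm = distModel.toLocal
}}}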
case class Data(globalTopicTotals: Vector) @@ -478,17 +629,20 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { globalTopicTotals: LDA.TopicCounts, k: Int, vocabSize: Int, - docConcentration: Double, + docConcentration: Vector, topicConcentration: Double, - iterationTimes: Array[Double]): Unit = { + iterationTimes: Array[Double], + gammaShape: Double): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~ - ("topicConcentration" -> topicConcentration) ~ - ("iterationTimes" -> iterationTimes.toSeq))) + (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("iterationTimes" -> iterationTimes.toSeq) ~ + ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString @@ -510,9 +664,10 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { sc: SparkContext, path: String, vocabSize: Int, - docConcentration: Double, + docConcentration: Vector, topicConcentration: Double, - iterationTimes: Array[Double]): DistributedLDAModel = { + iterationTimes: Array[Double], + gammaShape: Double): DistributedLDAModel = { val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString @@ -536,7 +691,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, topicConcentration, iterationTimes) + docConcentration, topicConcentration, gammaShape, iterationTimes) } } @@ -546,32 +701,35 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { implicit val formats = DefaultFormats val expectedK = (metadata \ "k").extract[Int] val vocabSize = (metadata \ "vocabSize").extract[Int] - val docConcentration = (metadata \ "docConcentration").extract[Double] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) val topicConcentration = (metadata \ "topicConcentration").extract[Double] val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]] - val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + val gammaShape = (metadata \ "gammaShape").extract[Double] + val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => { - DistributedLDAModel.SaveLoadV1_0.load( - sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray) + DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, + topicConcentration, iterationTimes.toArray, gammaShape) } case _ => throw new Exception( s"DistributedLDAModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") + s"($loadedClassName, $loadedVersion). 
Supported: ($classNameV1_0, 1.0)") } require(model.vocabSize == vocabSize, s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize") require(model.docConcentration == docConcentration, s"DistributedLDAModel requires $docConcentration docConcentration, " + - s"got ${model.docConcentration} docConcentration") + s"got ${model.docConcentration} docConcentration") require(model.topicConcentration == topicConcentration, s"DistributedLDAModel requires $topicConcentration docConcentration, " + - s"got ${model.topicConcentration} docConcentration") + s"got ${model.topicConcentration} docConcentration") require(expectedK == model.k, s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics") model } } + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index f4170a3d98dd8..d6f8b29a43dfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import java.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} -import breeze.numerics.{abs, digamma, exp} +import breeze.numerics.{abs, exp} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.DeveloperApi @@ -142,8 +142,9 @@ final class EMLDAOptimizer extends LDAOptimizer { this.k = k this.vocabSize = docs.take(1).head._2.size this.checkpointInterval = lda.getCheckpointInterval - this.graphCheckpointer = new - PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) this.globalTopicTotals = computeGlobalTopicTotals() this } @@ -188,7 +189,7 @@ final class EMLDAOptimizer extends LDAOptimizer { // Update the vertex descriptors with the new counts. 
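The optimizer above now builds its graph checkpointer with the (interval, SparkContext) constructor and primes it with update() before materializing, instead of passing the initial graph to the constructor. A minimal sketch of that call pattern; this is a private[mllib] helper, so the sketch only compiles inside the mllib package, and `graph: Graph[Double, Double]` is an assumed value.
{{{
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer

// Hedged sketch of the new wiring used by EMLDAOptimizer.initialize above.
val checkpointer = new PeriodicGraphCheckpointer[Double, Double](10, graph.vertices.sparkContext)
checkpointer.update(graph)  // register before materializing
graph.vertices.count()      // materializing triggers persist (and checkpoint on interval)
}}}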
val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) graph = newGraph - graphCheckpointer.updateGraph(newGraph) + graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() this } @@ -208,7 +209,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(this, iterationTimes) + // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal + // conversion + new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, + Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, + 100, iterationTimes) } } @@ -385,71 +390,52 @@ final class OnlineLDAOptimizer extends LDAOptimizer { iteration += 1 val k = this.k val vocabSize = this.vocabSize - val Elogbeta = dirichletExpectation(lambda).t - val expElogbeta = exp(Elogbeta) + val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t val alpha = this.alpha.toBreeze val gammaShape = this.gammaShape - val stats: RDD[BDM[Double]] = batch.mapPartitions { docs => - val stat = BDM.zeros[Double](k, vocabSize) - docs.foreach { doc => - val termCounts = doc._2 - val (ids: List[Int], cts: Array[Double]) = termCounts match { - case v: DenseVector => ((0 until v.size).toList, v.values) - case v: SparseVector => (v.indices.toList, v.values) - case v => throw new IllegalArgumentException("Online LDA does not support vector type " - + v.getClass) - } - if (!ids.isEmpty) { - - // Initialize the variational distribution q(theta|gamma) for the mini-batch - val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K - val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K - val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K - - val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids - var meanchange = 1D - val ctsVector = new BDV[Double](cts) // ids - - // Iterate between gamma and phi until convergence - while (meanchange > 1e-3) { - val lastgamma = gammad.copy - // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha - expElogthetad := exp(digamma(gammad) - digamma(sum(gammad))) - phinorm := expElogbetad * expElogthetad :+ 1e-100 - meanchange = sum(abs(gammad - lastgamma)) / k - } + val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix + val stat = BDM.zeros[Double](k, vocabSize) + var gammaPart = List[BDV[Double]]() + nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) => + val ids: List[Int] = termCounts match { + case v: DenseVector => (0 until v.size).toList + case v: SparseVector => v.indices.toList } + val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbeta, alpha, gammaShape, k) + stat(::, ids) := stat(::, ids).toDenseMatrix + sstats + gammaPart = gammad :: gammaPart } - Iterator(stat) + Iterator((stat, gammaPart)) } - - val statsSum: BDM[Double] = stats.reduce(_ += _) + val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( + stats.map(_._2).reduce(_ ++ 
_).map(_.toDenseMatrix): _*) val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count - update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt) + updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) this } - override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose) - } - /** * Update lambda based on the batch submitted. batchSize can be different for each iteration. */ - private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = { + private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = { // weight of the mini-batch. - val weight = math.pow(getTau0 + iter, -getKappa) + val weight = rho() // Update lambda based on documents. - lambda = lambda * (1 - weight) + - (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight + lambda := (1 - weight) * lambda + + weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) + } + + /** Calculates learning rate rho, which decays as a function of [[iteration]] */ + private def rho(): Double = { + math.pow(getTau0 + this.iteration, -getKappa) } /** @@ -463,15 +449,57 @@ final class OnlineLDAOptimizer extends LDAOptimizer { new BDM[Double](col, row, temp).t } + override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + } + +} + +/** + * Serializable companion object containing helper methods and shared code for + * [[OnlineLDAOptimizer]] and [[LocalLDAModel]]. + */ +private[clustering] object OnlineLDAOptimizer { /** - * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation - * uses digamma which is accurate but expensive. + * Uses variational inference to infer the topic distribution `gammad` given the term counts + * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will + * throw a BLAS error. + * + * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001) + * avoids explicit computation of variational parameter `phi`. 
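The variational update above leans on two numerical identities that the new LDAUtils helper (introduced further down) packages up: the Dirichlet expectation E[log theta_k] = digamma(alpha_k) - digamma(sum(alpha)), and an overflow-safe log-sum-exp. A small standalone Breeze sketch of both, with toy values:
{{{
import breeze.linalg.{DenseVector => BDV, max, sum}
import breeze.numerics.{digamma, exp, log}

// E[log theta] for theta ~ Dirichlet(alpha).
val alpha = BDV(0.5, 1.0, 2.0)
val eLogTheta: BDV[Double] = digamma(alpha) - digamma(sum(alpha))

// Stable log-sum-exp: shift by the maximum before exponentiating.
val x = BDV(-1001.0, -1000.0, -999.0)
val a = max(x)
val logSumExp = a + log(sum(exp(x :- a)))
}}}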
+ * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]] */ - private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { - val rowSum = sum(alpha(breeze.linalg.*, ::)) - val digAlpha = digamma(alpha) - val digRowSum = digamma(rowSum) - val result = digAlpha(::, breeze.linalg.*) - digRowSum - result + private[clustering] def variationalTopicInference( + termCounts: Vector, + expElogbeta: BDM[Double], + alpha: breeze.linalg.Vector[Double], + gammaShape: Double, + k: Int): (BDV[Double], BDM[Double]) = { + val (ids: List[Int], cts: Array[Double]) = termCounts match { + case v: DenseVector => ((0 until v.size).toList, v.values) + case v: SparseVector => (v.indices.toList, v.values) + } + // Initialize the variational distribution q(theta|gamma) for the mini-batch + val gammad: BDV[Double] = + new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K + val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K + + val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + var meanchange = 1D + val ctsVector = new BDV[Double](cts) // ids + + // Iterate between gamma and phi until convergence + while (meanchange > 1e-3) { + val lastgamma = gammad.copy + // K K * ids ids + gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha + expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) + phinorm := expElogbetad * expElogthetad :+ 1e-100 + meanchange = sum(abs(gammad - lastgamma)) / k + } + + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix + (gammad, sstatsd) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala new file mode 100644 index 0000000000000..f7e5ce1665fe6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.numerics._ + +/** + * Utility methods for LDA. + */ +object LDAUtils { + /** + * Log Sum Exp with overflow protection using the identity: + * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} + */ + private[clustering] def logSumExp(x: BDV[Double]): Double = { + val a = max(x) + a + log(sum(exp(x :- a))) + } + + /** + * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation + * uses [[breeze.numerics.digamma]] which is accurate but expensive. 
+ */ + private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = { + digamma(alpha) - digamma(sum(alpha)) + } + + /** + * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha are + * Dirichlet parameters. + */ + private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { + val rowSum = sum(alpha(breeze.linalg.*, ::)) + val digAlpha = digamma(alpha) + val digRowSum = digamma(rowSum) + val result = digAlpha(::, breeze.linalg.*) - digRowSum + result + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index e7a243f854e33..407e43a024a2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -153,6 +153,27 @@ class PowerIterationClustering private[clustering] ( this } + /** + * Run the PIC algorithm on Graph. + * + * @param graph an affinity matrix represented as graph, which is the matrix A in the PIC paper. + * The similarity s,,ij,, represented as the edge between vertices (i, j) must + * be nonnegative. This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For + * any (i, j) with nonzero similarity, there should be either (i, j, s,,ij,,) + * or (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we + * assume s,,ij,, = 0.0. + * + * @return a [[PowerIterationClusteringModel]] that contains the clustering result + */ + def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = { + val w = normalize(graph) + val w0 = initMode match { + case "random" => randomInit(w) + case "degree" => initDegreeVector(w) + } + pic(w0) + } + /** * Run the PIC algorithm. * @@ -212,6 +233,31 @@ object PowerIterationClustering extends Logging { @Experimental case class Assignment(id: Long, cluster: Int) + /** + * Normalizes the affinity graph (A) and returns the normalized affinity matrix (W). + */ + private[clustering] + def normalize(graph: Graph[Double, Double]): Graph[Double, Double] = { + val vD = graph.aggregateMessages[Double]( + sendMsg = ctx => { + val i = ctx.srcId + val j = ctx.dstId + val s = ctx.attr + if (s < 0.0) { + throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + } + if (s > 0.0) { + ctx.sendToSrc(s) + } + }, + mergeMsg = _ + _, + TripletFields.EdgeOnly) + GraphImpl.fromExistingRDDs(vD, graph.edges) + .mapTriplets( + e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), + TripletFields.Src) + } + /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). 
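The new graph-based entry point above takes the affinity matrix directly as a Graph[Double, Double], normalizes it, and runs PIC. A hedged usage sketch, assuming an existing SparkContext `sc` and a tiny symmetric affinity graph built from nonnegative similarities:
{{{
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.clustering.PowerIterationClustering

// Toy affinities s_ij >= 0; only one direction per (i, j) pair is required.
val similarities = sc.parallelize(Seq(
  Edge(0L, 1L, 1.0),
  Edge(1L, 2L, 0.9),
  Edge(0L, 3L, 0.1)))
val affinityGraph: Graph[Double, Double] = Graph.fromEdges(similarities, defaultValue = 0.0)

val model = new PowerIterationClustering()
  .setK(2)
  .setMaxIterations(20)
  .run(affinityGraph)
model.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))
}}}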
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f087d06d2a46a..cbbd2b0c8d060 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging { } newSentences.unpersist() - val word2VecMap = mutable.HashMap.empty[String, Array[Float]] - var i = 0 - while (i < vocabSize) { - val word = bcVocab.value(i).word - val vector = new Array[Float](vectorSize) - Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) - word2VecMap += word -> vector - i += 1 - } - - new Word2VecModel(word2VecMap.toMap) + val wordArray = vocab.map(_.word) + new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) } /** @@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging { /** * :: Experimental :: * Word2Vec model + * @param wordIndex maps each word to an index, which can retrieve the corresponding + * vector from wordVectors + * @param wordVectors array of length numWords * vectorSize, vector corresponding + * to the word mapped with index i can be retrieved by the slice + * (i * vectorSize, i * vectorSize + vectorSize) */ @Experimental -class Word2VecModel private[spark] ( - model: Map[String, Array[Float]]) extends Serializable with Saveable { - - // wordList: Ordered list of words obtained from model. - private val wordList: Array[String] = model.keys.toArray - - // wordIndex: Maps each word to an index, which can retrieve the corresponding - // vector from wordVectors (see below). - private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap +class Word2VecModel private[mllib] ( + private val wordIndex: Map[String, Int], + private val wordVectors: Array[Float]) extends Serializable with Saveable { - // vectorSize: Dimension of each word's vector. - private val vectorSize = model.head._2.size private val numWords = wordIndex.size + // vectorSize: Dimension of each word's vector. + private val vectorSize = wordVectors.length / numWords + + // wordList: Ordered list of words obtained from wordIndex. + private val wordList: Array[String] = { + val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip + wl.toArray + } - // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word - // mapped with index i can be retrieved by the slice - // (ind * vectorSize, ind * vectorSize + vectorSize) // wordVecNorms: Array of length numWords, each value being the Euclidean norm // of the wordVector. 
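The refactor above replaces the per-word Map[String, Array[Float]] with a word-to-index map plus one flat Array[Float], so the vector for index i lives in the slice [i * vectorSize, (i + 1) * vectorSize). A toy, self-contained sketch of that layout (the words and numbers are made up):
{{{
// Toy illustration of the flat wordVectors layout used by Word2VecModel.
val vectorSize = 3
val wordIndex = Map("spark" -> 0, "mllib" -> 1)
val wordVectors = Array[Float](0.1f, 0.2f, 0.3f,   // "spark"
                               0.4f, 0.5f, 0.6f)   // "mllib"

def vectorOf(word: String): Array[Float] = {
  val i = wordIndex(word)
  wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
}
// vectorOf("mllib") == Array(0.4f, 0.5f, 0.6f)
}}}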
- private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = { - val wordVectors = new Array[Float](vectorSize * numWords) + private val wordVecNorms: Array[Double] = { val wordVecNorms = new Array[Double](numWords) var i = 0 while (i < numWords) { - val vec = model.get(wordList(i)).get - Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize) + val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) i += 1 } - (wordVectors, wordVecNorms) + wordVecNorms + } + + def this(model: Map[String, Array[Float]]) = { + this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { @@ -484,8 +479,9 @@ class Word2VecModel private[spark] ( * @return vector representation of word */ def transform(word: String): Vector = { - model.get(word) match { - case Some(vec) => + wordIndex.get(word) match { + case Some(ind) => + val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize) Vectors.dense(vec.map(_.toDouble)) case None => throw new IllegalStateException(s"$word not in vocabulary") @@ -511,7 +507,7 @@ class Word2VecModel private[spark] ( */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - + // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) val cosineVec = Array.fill[Float](numWords)(0) val alpha: Float = 1 @@ -521,13 +517,13 @@ class Word2VecModel private[spark] ( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) // Need not divide with the norm of the given vector since it is constant. - val updatedCosines = new Array[Double](numWords) + val cosVec = cosineVec.map(_.toDouble) var ind = 0 while (ind < numWords) { - updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind) + cosVec(ind) /= wordVecNorms(ind) ind += 1 } - wordList.zip(updatedCosines) + wordList.zip(cosVec) .toSeq .sortBy(- _._2) .take(num + 1) @@ -548,6 +544,23 @@ class Word2VecModel private[spark] ( @Experimental object Word2VecModel extends Loader[Word2VecModel] { + private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = { + model.keys.zipWithIndex.toMap + } + + private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = { + require(model.nonEmpty, "Word2VecMap should be non-empty") + val (vectorSize, numWords) = (model.head._2.size, model.size) + val wordList = model.keys.toArray + val wordVectors = new Array[Float](vectorSize * numWords) + var i = 0 + while (i < numWords) { + Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize) + i += 1 + } + wordVectors + } + private object SaveLoadV1_0 { val formatVersionV1_0 = "1.0" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 7ead6327486cc..0ea792081086d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, maxPatternLength: Int, prefixes: List[Int], - database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { + database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = { if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty val 
frequentItemAndCounts = getFreqItemAndCounts(minCount, database) val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) @@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } - def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { + def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { database .map(getSuffix(prefix, _)) .filter(_.nonEmpty) @@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Array[Array[Int]]): mutable.Map[Int, Long] = { + database: Iterable[Array[Int]]): mutable.Map[Int, Long] = { // TODO: use PrimitiveKeyOpenHashMap val counts = mutable.Map[Int, Long]().withDefaultValue(0L) database.foreach { sequence => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 6f52db7b073ae..e6752332cdeeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD @@ -43,28 +45,45 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { + /** + * The maximum number of items allowed in a projected database before local processing. If a + * projected database exceeds this size, another iteration of distributed PrefixSpan is run. + */ + // TODO: make configurable with a better default value, 10000 may be too small + private val maxLocalProjDBSize: Long = 10000 + /** * Constructs a default instance with default parameters * {minSupport: `0.1`, maxPatternLength: `10`}. */ def this() = this(0.1, 10) + /** + * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered + * frequent). + */ + def getMinSupport: Double = this.minSupport + /** * Sets the minimal support level (default: `0.1`). */ def setMinSupport(minSupport: Double): this.type = { - require(minSupport >= 0 && minSupport <= 1, - "The minimum support value must be between 0 and 1, including 0 and 1.") + require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].") this.minSupport = minSupport this } + /** + * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. + */ + def getMaxPatternLength: Double = this.maxPatternLength + /** * Sets maximal pattern length (default: `10`). */ def setMaxPatternLength(maxPatternLength: Int): this.type = { - require(maxPatternLength >= 1, - "The maximum pattern length value must be greater than 0.") + // TODO: support unbounded pattern length when maxPatternLength = 0 + require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.") this.maxPatternLength = maxPatternLength this } @@ -78,81 +97,153 @@ class PrefixSpan private ( * the value of pair is the pattern's count. 
*/ def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + val sc = sequences.sparkContext + if (sequences.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } - val minCount = getMinCount(sequences) - val lengthOnePatternsAndCounts = - getFreqItemAndCounts(minCount, sequences).collect() - val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( - lengthOnePatternsAndCounts.map(_._1), sequences) - val groupedProjectedDatabase = prefixAndProjectedDatabase - .map(x => (x._1.toSeq, x._2)) - .groupByKey() - .map(x => (x._1.toArray, x._2.toArray)) - val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase) - val lengthOnePatternsAndCountsRdd = - sequences.sparkContext.parallelize( - lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) - val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns - allPatterns + + // Convert min support to a min number of transactions for this dataset + val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + + // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold + val freqItemCounts = sequences + .flatMap(seq => seq.distinct.map(item => (item, 1L))) + .reduceByKey(_ + _) + .filter(_._2 >= minCount) + .collect() + + // Pairs of (length 1 prefix, suffix consisting of frequent items) + val itemSuffixPairs = { + val freqItems = freqItemCounts.map(_._1).toSet + sequences.flatMap { seq => + val filteredSeq = seq.filter(freqItems.contains(_)) + freqItems.flatMap { item => + val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq) + candidateSuffix match { + case suffix if !suffix.isEmpty => Some((List(item), suffix)) + case _ => None + } + } + } + } + + // Accumulator for the computed results to be returned, initialized to the frequent items (i.e. + // frequent length-one prefixes) + var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2)) + + // Remaining work to be locally and distributively processed respectfully + var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs) + + // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have + // projected database sizes <= `maxLocalProjDBSize`) + while (pairsForDistributed.count() != 0) { + val (nextPatternAndCounts, nextPrefixSuffixPairs) = + extendPrefixes(minCount, pairsForDistributed) + pairsForDistributed.unpersist() + val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs) + pairsForDistributed = largerPairsPart + pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) + pairsForLocal ++= smallerPairsPart + resultsAccumulator ++= nextPatternAndCounts.collect() + } + + // Process the small projected databases locally + val remainingResults = getPatternsInLocal( + minCount, sc.parallelize(pairsForLocal, 1).groupByKey()) + + (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) + .map { case (pattern, count) => (pattern.toArray, count) } } + /** - * Get the minimum count (sequences count * minSupport). - * @param sequences input data set, contains a set of sequences, - * @return minimum count, + * Partitions the prefix-suffix pairs by projected database size. 
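A hedged end-to-end sketch of the rewritten run() above; the input sequences and the SparkContext `sc` are assumptions, and each sequence is an Array[Int] of item ids:
{{{
import org.apache.spark.mllib.fpm.PrefixSpan

// Hedged usage sketch; cache the input, as the warning in run() suggests.
val sequences = sc.parallelize(Seq(
  Array(1, 3, 2, 1),
  Array(1, 2, 3),
  Array(3, 1, 2),
  Array(2, 1)), 2).cache()

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)      // pattern must occur in at least half of the sequences
  .setMaxPatternLength(5)
val frequentPatterns = prefixSpan.run(sequences)
frequentPatterns.collect().foreach { case (pattern, count) =>
  println(s"${pattern.mkString("[", ",", "]")} : $count")
}
}}}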
+ * @param prefixSuffixPairs prefix (length n) and suffix pairs, + * @return prefix-suffix pairs partitioned by whether their projected database size is <= or + * greater than [[maxLocalProjDBSize]] */ - private def getMinCount(sequences: RDD[Array[Int]]): Long = { - if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = { + val prefixToSuffixSize = prefixSuffixPairs + .aggregateByKey(0)( + seqOp = { case (count, suffix) => count + suffix.length }, + combOp = { _ + _ }) + val smallPrefixes = prefixToSuffixSize + .filter(_._2 <= maxLocalProjDBSize) + .keys + .collect() + .toSet + val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } + val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } + (small.collect(), large) } /** - * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences original sequences data - * @return array of item and count pair + * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes + * and remaining work. + * @param minCount minimum count + * @param prefixSuffixPairs prefix (length N) and suffix pairs, + * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended + * prefix, corresponding suffix) pairs. */ - private def getFreqItemAndCounts( + private def extendPrefixes( minCount: Long, - sequences: RDD[Array[Int]]): RDD[(Int, Long)] = { - sequences.flatMap(_.distinct.map((_, 1L))) + prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = { + + // (length N prefix, item from suffix) pairs and their corresponding number of occurrences + // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` + val prefixItemPairAndCounts = prefixSuffixPairs + .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) } .reduceByKey(_ + _) .filter(_._2 >= minCount) - } - /** - * Get the frequent prefixes' projected database. 
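The size-based split implemented above can be pictured on plain Scala collections: sum the suffix lengths per prefix and compare against the threshold; small prefixes go to local mining, the rest stay distributed. The data and threshold below are made up:
{{{
// Toy sketch of partitionByProjDBSize's bookkeeping (local collections instead of RDDs).
val maxLocalProjDBSize = 4L
val prefixSuffixPairs = Seq(
  (List(1), Array(2, 3)),
  (List(1), Array(3)),
  (List(2), Array(1, 2, 3, 1, 2)))

val projDBSizes: Map[List[Int], Long] = prefixSuffixPairs
  .groupBy(_._1)
  .mapValues(_.map(_._2.length.toLong).sum)
  .toMap
val (smallPrefixes, largePrefixes) = projDBSizes.partition(_._2 <= maxLocalProjDBSize)
// smallPrefixes == Map(List(1) -> 3); largePrefixes == Map(List(2) -> 5)
}}}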
- * @param frequentPrefixes frequent prefixes - * @param sequences sequences data - * @return prefixes and projected database - */ - private def getPrefixAndProjectedDatabase( - frequentPrefixes: Array[Int], - sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { - val filteredSequences = sequences.map { p => - p.filter (frequentPrefixes.contains(_) ) - } - filteredSequences.flatMap { x => - frequentPrefixes.map { y => - val sub = LocalPrefixSpan.getSuffix(y, x) - (Array(y), sub) - }.filter(_._2.nonEmpty) - } + // Map from prefix to set of possible next items from suffix + val prefixToNextItems = prefixItemPairAndCounts + .keys + .groupByKey() + .mapValues(_.toSet) + .collect() + .toMap + + + // Frequent patterns with length N+1 and their corresponding counts + val extendedPrefixAndCounts = prefixItemPairAndCounts + .map { case ((prefix, item), count) => (item :: prefix, count) } + + // Remaining work, all prefixes will have length N+1 + val extendedPrefixAndSuffix = prefixSuffixPairs + .filter(x => prefixToNextItems.contains(x._1)) + .flatMap { case (prefix, suffix) => + val frequentNextItems = prefixToNextItems(prefix) + val filteredSuffix = suffix.filter(frequentNextItems.contains(_)) + frequentNextItems.flatMap { item => + LocalPrefixSpan.getSuffix(item, filteredSuffix) match { + case suffix if !suffix.isEmpty => Some(item :: prefix, suffix) + case _ => None + } + } + } + + (extendedPrefixAndCounts, extendedPrefixAndSuffix) } /** - * calculate the patterns in local. + * Calculate the patterns in local. * @param minCount the absolute minimum count - * @param data patterns and projected sequences data data + * @param data prefixes and projected sequences data data * @return patterns */ private def getPatternsInLocal( minCount: Long, - data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { - data.flatMap { case (prefix, projDB) => - LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB) - .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) } + data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = { + data.flatMap { + case (prefix, projDB) => + LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB) + .map { case (pattern: List[Int], count: Long) => + (pattern.reverse, count) + } } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala new file mode 100644 index 0000000000000..72d3aabc9b1f4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.storage.StorageLevel + + +/** + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs + * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to + * the distributed data type (RDD, Graph, etc.). + * + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, + * as well as unpersisting and removing checkpoint files. + * + * Users should call update() when a new Dataset has been created, + * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. + * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Datasets should be + * checkpointed). + * - This class removes checkpoint files once later Datasets have been checkpointed. + * However, references to the older Datasets will still return isCheckpointed = true. + * + * @param checkpointInterval Datasets will be checkpointed at this interval + * @param sc SparkContext for the Datasets given to this checkpointer + * @tparam T Dataset type, such as RDD[Double] + */ +private[mllib] abstract class PeriodicCheckpointer[T]( + val checkpointInterval: Int, + val sc: SparkContext) extends Logging { + + /** FIFO queue of past checkpointed Datasets */ + private val checkpointQueue = mutable.Queue[T]() + + /** FIFO queue of past persisted Datasets */ + private val persistedQueue = mutable.Queue[T]() + + /** Number of times [[update()]] has been called */ + private var updateCount = 0 + + /** + * Update with a new Dataset. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the Dataset + * has been materialized. + * + * @param newData New Dataset created from previous Datasets in the lineage. + */ + def update(newData: T): Unit = { + persist(newData) + persistedQueue.enqueue(newData) + // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: + // Users should call [[update()]] when a new Dataset has been created, + // before the Dataset has been materialized. + while (persistedQueue.size > 3) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + checkpoint(newData) + checkpointQueue.enqueue(newData) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. 
+ if (isCheckpointed(checkpointQueue.head)) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** Checkpoint the Dataset */ + protected def checkpoint(data: T): Unit + + /** Return true iff the Dataset is checkpointed */ + protected def isCheckpointed(data: T): Boolean + + /** + * Persist the Dataset. + * Note: This should handle checking the current [[StorageLevel]] of the Dataset. + */ + protected def persist(data: T): Unit + + /** Unpersist the Dataset */ + protected def unpersist(data: T): Unit + + /** Get list of checkpoint files for this given Dataset */ + protected def getCheckpointFiles(data: T): Iterable[String] + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + getCheckpointFiles(old).foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 6e5dd119dd653..11a059536c50c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -17,11 +17,7 @@ package org.apache.spark.mllib.impl -import scala.collection.mutable - -import org.apache.hadoop.fs.{Path, FileSystem} - -import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel @@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as * unpersisting and removing checkpoint files. * - * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * Users should call update() when a new graph has been created, * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are * responsible for materializing the graph to ensure that persisting and checkpointing actually * occur. * - * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * When update() is called, this does the following: * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. * - Unpersist graphs from queue until there are at most 3 persisted graphs. * - If using checkpointing and the checkpoint interval has been reached, @@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel * Example usage: * {{{ * val (graph1, graph2, graph3, ...) = ... 
- * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * val cp = new PeriodicGraphCheckpointer(2, sc) * graph1.vertices.count(); graph1.edges.count() * // persisted: graph1 * cp.updateGraph(graph2) @@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param currentGraph Initial graph * @param checkpointInterval Graphs will be checkpointed at this interval * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + * TODO: Move this out of MLlib? */ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( - var currentGraph: Graph[VD, ED], - val checkpointInterval: Int) extends Logging { - - /** FIFO queue of past checkpointed RDDs */ - private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() - - /** FIFO queue of past persisted RDDs */ - private val persistedQueue = mutable.Queue[Graph[VD, ED]]() - - /** Number of times [[updateGraph()]] has been called */ - private var updateCount = 0 - - /** - * Spark Context for the Graphs given to this checkpointer. - * NOTE: This code assumes that only one SparkContext is used for the given graphs. - */ - private val sc = currentGraph.vertices.sparkContext + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { - updateGraph(currentGraph) + override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() - /** - * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. - * Since this handles persistence and checkpointing, this should be called before the graph - * has been materialized. - * - * @param newGraph New graph created from previous graphs in the lineage. - */ - def updateGraph(newGraph: Graph[VD, ED]): Unit = { - if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { - newGraph.persist() - } - persistedQueue.enqueue(newGraph) - // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: - // Users should call [[updateGraph()]] when a new graph has been created, - // before the graph has been materialized. - while (persistedQueue.size > 3) { - val graphToUnpersist = persistedQueue.dequeue() - graphToUnpersist.unpersist(blocking = false) - } - updateCount += 1 + override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed - // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { - // Add new checkpoint before removing old checkpoints. - newGraph.checkpoint() - checkpointQueue.enqueue(newGraph) - // Remove checkpoints before the latest one. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // Delete the oldest checkpoint only if the next checkpoint exists. - if (checkpointQueue.get(1).get.isCheckpointed) { - removeCheckpointFile() - } else { - canDelete = false - } - } + override protected def persist(data: Graph[VD, ED]): Unit = { + if (data.vertices.getStorageLevel == StorageLevel.NONE) { + data.persist() } } - /** - * Call this at the end to delete any remaining checkpoint files. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.size > 0) { - removeCheckpointFile() - } - } + override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) - /** - * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. 
- * This prints a warning but does not fail if the files cannot be removed. - */ - private def removeCheckpointFile(): Unit = { - val old = checkpointQueue.dequeue() - // Since the old checkpoint is not deleted by Spark, we manually delete it. - val fs = FileSystem.get(sc.hadoopConfiguration) - old.getCheckpointFiles.foreach { checkpointFile => - try { - fs.delete(new Path(checkpointFile), true) - } catch { - case e: Exception => - logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } + override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { + data.getCheckpointFiles } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala new file mode 100644 index 0000000000000..f31ed2aa90a64 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing RDDs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call update() when a new RDD has been created, + * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are + * responsible for materializing the RDD to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. + * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which RDDs should be + * checkpointed). + * - This class removes checkpoint files once later RDDs have been checkpointed. + * However, references to the older RDDs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (rdd1, rdd2, rdd3, ...) = ... 
+ * val cp = new PeriodicRDDCheckpointer(2, sc) + * rdd1.count(); + * // persisted: rdd1 + * cp.update(rdd2) + * rdd2.count(); + * // persisted: rdd1, rdd2 + * // checkpointed: rdd2 + * cp.update(rdd3) + * rdd3.count(); + * // persisted: rdd1, rdd2, rdd3 + * // checkpointed: rdd2 + * cp.update(rdd4) + * rdd4.count(); + * // persisted: rdd2, rdd3, rdd4 + * // checkpointed: rdd4 + * cp.update(rdd5) + * rdd5.count(); + * // persisted: rdd3, rdd4, rdd5 + * // checkpointed: rdd4 + * }}} + * + * @param checkpointInterval RDDs will be checkpointed at this interval + * @tparam T RDD element type + * + * TODO: Move this out of MLlib? + */ +private[mllib] class PeriodicRDDCheckpointer[T]( + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { + + override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() + + override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed + + override protected def persist(data: RDD[T]): Unit = { + if (data.getStorageLevel == StorageLevel.NONE) { + data.persist() + } + } + + override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) + + override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { + data.getCheckpointFile.map(x => x) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 55da0e094d132..88914fa875990 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setByte(0, 0) row.setInt(1, sm.numRows) row.setInt(2, sm.numCols) - row.update(3, sm.colPtrs.toSeq) - row.update(4, sm.rowIndices.toSeq) - row.update(5, sm.values.toSeq) + row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) + row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) + row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, sm.isTransposed) case dm: DenseMatrix => @@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setInt(2, dm.numCols) row.setNullAt(3) row.setNullAt(4) - row.update(5, dm.values.toSeq) + row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, dm.isTransposed) } row @@ -174,17 +174,17 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def deserialize(datum: Any): Matrix = { datum match { case row: InternalRow => - require(row.length == 7, - s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") + require(row.numFields == 7, + s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7") val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getAs[Iterable[Double]](5).toArray + val values = row.getArray(5).toArray.map(_.asInstanceOf[Double]) val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = row.getAs[Iterable[Int]](3).toArray - val rowIndices = row.getAs[Iterable[Int]](4).toArray + val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int]) + val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int]) new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) diff 
--git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 9669c364bad8f..b416d50a5631e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -25,3 +25,11 @@ import org.apache.spark.annotation.Experimental */ @Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) + +/** + * :: Experimental :: + * Represents QR factors. + */ +@Experimental +case class QRDecomposition[UType, VType](Q: UType, R: VType) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 9067b3ba9a7bb..89a1818db0d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, indices.toSeq) - row.update(3, values.toSeq) + row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row case DenseVector(values) => val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, values.toSeq) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row } } @@ -203,17 +203,17 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def deserialize(datum: Any): Vector = { datum match { case row: InternalRow => - require(row.length == 4, - s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + require(row.numFields == 4, + s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4") val tpe = row.getByte(0) tpe match { case 0 => val size = row.getInt(1) - val indices = row.getAs[Iterable[Int]](2).toArray - val values = row.getAs[Iterable[Double]](3).toArray + val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int]) + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new SparseVector(size, indices, values) case 1 => - val values = row.getAs[Iterable[Double]](3).toArray + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new DenseVector(values) } } @@ -634,6 +634,8 @@ class SparseVector( require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + s" ${values.length} values.") + require(indices.length <= size, s"You provided ${indices.length} indices and values, " + + s"which exceeds the specified vector size ${size}.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 1626da9c3d2ee..bfc90c9ef8527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -22,7 +22,7 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd} + svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -497,6 +497,50 @@ class RowMatrix( columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) } + /** + * Compute QR decomposition for [[RowMatrix]]. The implementation is designed to optimize the QR + * decomposition (factorization) for the [[RowMatrix]] of a tall and skinny shape. + * Reference: + * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce + * architectures" ([[http://dx.doi.org/10.1145/1996092.1996103]]) + * + * @param computeQ whether to computeQ + * @return QRDecomposition(Q, R), Q = null if computeQ = false. + */ + def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { + val col = numCols().toInt + // split rows horizontally into smaller matrices, and compute QR for each of them + val blockQRs = rows.glom().map { partRows => + val bdm = BDM.zeros[Double](partRows.length, col) + var i = 0 + partRows.foreach { row => + bdm(i, ::) := row.toBreeze.t + i += 1 + } + breeze.linalg.qr.reduced(bdm).r + } + + // combine the R part from previous results vertically into a tall matrix + val combinedR = blockQRs.treeReduce{ (r1, r2) => + val stackedR = BDM.vertcat(r1, r2) + breeze.linalg.qr.reduced(stackedR).r + } + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) + val finalQ = if (computeQ) { + try { + val invR = inv(combinedR) + this.multiply(Matrices.fromBreeze(invR)) + } catch { + case err: MatrixSingularException => + logWarning("R is not invertible and return Q as null") + null + } + } else { + null + } + QRDecomposition(finalQ, finalR) + } + /** * Find all similar columns using the DIMSUM sampling algorithm, described in two papers * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 93290e6508529..56c549ef99cb7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -26,6 +26,7 @@ import org.apache.spark.storage.StorageLevel /** * A more compact class to represent a rating than Tuple3[Int, Int, Double]. + * @since 0.8.0 */ case class Rating(user: Int, product: Int, rating: Double) @@ -254,6 +255,7 @@ class ALS private ( /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. 
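 * A minimal usage sketch (assuming an existing `ratings: RDD[Rating]`; the rank, iteration
 * count, and lambda below are illustrative only):
 * {{{
 *   import org.apache.spark.mllib.recommendation.ALS
 *   val model = ALS.train(ratings, 10, 10, 0.01)
 *   val predictions = model.predict(ratings.map(r => (r.user, r.product)))
 * }}}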
+ * @since 0.8.0 */ object ALS { /** @@ -269,6 +271,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param seed random seed + * @since 0.9.1 */ def train( ratings: RDD[Rating], @@ -293,6 +296,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into + * @since 0.8.0 */ def train( ratings: RDD[Rating], @@ -315,6 +319,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) + * @since 0.8.0 */ def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) : MatrixFactorizationModel = { @@ -331,6 +336,7 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) + * @since 0.8.0 */ def train(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { @@ -351,6 +357,7 @@ object ALS { * @param blocks level of parallelism to split computation into * @param alpha confidence parameter * @param seed random seed + * @since 0.8.1 */ def trainImplicit( ratings: RDD[Rating], @@ -377,6 +384,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param alpha confidence parameter + * @since 0.8.1 */ def trainImplicit( ratings: RDD[Rating], @@ -401,6 +409,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param alpha confidence parameter + * @since 0.8.1 */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { @@ -418,6 +427,7 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) + * @since 0.8.1 */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 43d219a49cf4e..261ca9cef0c5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. + * @since 0.8.0 */ class MatrixFactorizationModel( val rank: Int, @@ -73,7 +74,9 @@ class MatrixFactorizationModel( } } - /** Predict the rating of one user for one product. */ + /** Predict the rating of one user for one product. + * @since 0.8.0 + */ def predict(user: Int, product: Int): Double = { val userVector = userFeatures.lookup(user).head val productVector = productFeatures.lookup(product).head @@ -111,6 +114,7 @@ class MatrixFactorizationModel( * * @param usersProducts RDD of (user, product) pairs. * @return RDD of Ratings. 
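 * For example (a sketch, assuming `model: MatrixFactorizationModel` and `ratings: RDD[Rating]`):
 * {{{
 *   val userProducts = ratings.map(r => (r.user, r.product))
 *   val predicted = model.predict(userProducts).map(r => ((r.user, r.product), r.rating))
 *   val observedAndPredicted = ratings.map(r => ((r.user, r.product), r.rating)).join(predicted)
 * }}}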
+ * @since 0.9.0 */ def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { // Previously the partitions of ratings are only based on the given products. @@ -142,6 +146,7 @@ class MatrixFactorizationModel( /** * Java-friendly version of [[MatrixFactorizationModel.predict]]. + * @since 1.2.0 */ def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() @@ -157,6 +162,7 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the user. The score is an opaque value that indicates how strongly * recommended the product is. + * @since 1.1.0 */ def recommendProducts(user: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) @@ -173,6 +179,7 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the product. The score is an opaque value that indicates how strongly * recommended the user is. + * @since 1.1.0 */ def recommendUsers(product: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) @@ -180,6 +187,20 @@ class MatrixFactorizationModel( protected override val formatVersion: String = "1.0" + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + * @since 1.3.0 + */ override def save(sc: SparkContext, path: String): Unit = { MatrixFactorizationModel.SaveLoadV1_0.save(this, path) } @@ -191,6 +212,7 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of * rating objects which contains the same userId, recommended productID and a "score" in the * rating field. Semantics of score is same as recommendProducts API + * @since 1.4.0 */ def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map { @@ -208,6 +230,7 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array * of rating objects which contains the recommended userId, same productID and a "score" in the * rating field. Semantics of score is same as recommendUsers API + * @since 1.4.0 */ def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map { @@ -218,6 +241,9 @@ class MatrixFactorizationModel( } } +/** + * @since 1.3.0 + */ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { import org.apache.spark.mllib.util.Loader._ @@ -292,6 +318,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { } } + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. 
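 * For instance (a sketch, assuming the model was previously written with `model.save(sc, path)`
 * and `userId` is a known user):
 * {{{
 *   val sameModel = MatrixFactorizationModel.load(sc, path)
 *   val topTen = sameModel.recommendProducts(userId, 10)
 * }}}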
+ * @return Model instance + * @since 1.3.0 + */ override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 58a50f9c19f14..93a6753efd4d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD * .setBandwidth(3.0) * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) * }}} + * @since 1.4.0 */ @Experimental class KernelDensity extends Serializable { @@ -51,6 +52,7 @@ class KernelDensity extends Serializable { /** * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). + * @since 1.4.0 */ def setBandwidth(bandwidth: Double): this.type = { require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") @@ -60,6 +62,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation. + * @since 1.4.0 */ def setSample(sample: RDD[Double]): this.type = { this.sample = sample @@ -68,6 +71,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation (for Java users). + * @since 1.4.0 */ def setSample(sample: JavaRDD[java.lang.Double]): this.type = { this.sample = sample.rdd.asInstanceOf[RDD[Double]] @@ -76,6 +80,7 @@ class KernelDensity extends Serializable { /** * Estimates probability density function at the given array of points. + * @since 1.4.0 */ def estimate(points: Array[Double]): Array[Double] = { val sample = this.sample diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index d321cc554c1cc..62da9f2ef22a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -33,6 +33,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * @since 1.1.0 */ @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -52,6 +53,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. + * @since 1.1.0 */ def add(sample: Vector): this.type = { if (n == 0) { @@ -107,6 +109,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. 
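 * A common pattern (a sketch, assuming `data: RDD[Vector]`) builds per-partition summaries with
 * add() and combines them with merge():
 * {{{
 *   val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
 *     (summarizer, sample) => summarizer.add(sample),
 *     (s1, s2) => s1.merge(s2))
 *   println(s"mean = ${summary.mean}, variance = ${summary.variance}")
 * }}}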
+ * @since 1.1.0 */ def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalCnt != 0 && other.totalCnt != 0) { @@ -149,6 +152,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this } + /** + * @since 1.1.0 + */ override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -161,6 +167,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMean) } + /** + * @since 1.1.0 + */ override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -183,14 +192,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realVariance) } + /** + * @since 1.1.0 + */ override def count: Long = totalCnt + /** + * @since 1.1.0 + */ override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } + /** + * @since 1.1.0 + */ override def max: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -202,6 +220,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMax) } + /** + * @since 1.1.0 + */ override def min: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -213,6 +234,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMin) } + /** + * @since 1.2.0 + */ override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -227,6 +251,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMagnitude) } + /** + * @since 1.2.0 + */ override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 6a364c93284af..3bb49f12289e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -21,46 +21,55 @@ import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. + * @since 1.0.0 */ trait MultivariateStatisticalSummary { /** * Sample mean vector. + * @since 1.0.0 */ def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. + * @since 1.0.0 */ def variance: Vector /** * Sample size. + * @since 1.0.0 */ def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. + * @since 1.0.0 */ def numNonzeros: Vector /** * Maximum value of each column. + * @since 1.0.0 */ def max: Vector /** * Minimum value of each column. 
+ * @since 1.0.0 */ def min: Vector /** * Euclidean magnitude of each column + * @since 1.2.0 */ def normL2: Vector /** * L1 norm of each column + * @since 1.2.0 */ def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 90332028cfb3a..f84502919e381 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD /** * :: Experimental :: * API for statistical functions in MLlib. + * @since 1.1.0 */ @Experimental object Statistics { @@ -41,6 +42,7 @@ object Statistics { * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + * @since 1.1.0 */ def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() @@ -52,6 +54,7 @@ object Statistics { * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) @@ -68,6 +71,7 @@ object Statistics { * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -81,10 +85,14 @@ object Statistics { * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) @@ -101,10 +109,14 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) @@ -121,6 +133,7 @@ object Statistics { * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) @@ -135,6 +148,7 @@ object Statistics { * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. 
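 * For example (a sketch; the counts below are made up):
 * {{{
 *   import org.apache.spark.mllib.linalg.Vectors
 *   import org.apache.spark.mllib.stat.Statistics
 *   val observed = Vectors.dense(21.0, 38.0, 43.0, 80.0)
 *   // goodness-of-fit test against the uniform distribution
 *   val testResult = Statistics.chiSqTest(observed)
 *   println(testResult.pValue)
 * }}}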
+ * @since 1.1.0 */ def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) @@ -145,6 +159,7 @@ object Statistics { * @param observed The contingency matrix (containing either counts or relative frequencies). * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) @@ -157,6 +172,7 @@ object Statistics { * Real-valued features will be treated as categorical for each distinct value. * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. + * @since 1.1.0 */ def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cf51b24ff777f..9aa7763d7890d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution + * @since 1.3.0 */ @DeveloperApi class MultivariateGaussian ( @@ -60,12 +61,16 @@ class MultivariateGaussian ( */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - /** Returns density of this multivariate Gaussian at given point, x */ + /** Returns density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - /** Returns the log-density of this multivariate Gaussian at given point, x */ + /** Returns the log-density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index a835f96d5d0e3..9ce6faa137c41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging { false } + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + timer.stop("init") logDebug("##########") logDebug("Building tree 0") logDebug("##########") - var data = input // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val 
firstTreeModel = new DecisionTree(treeStrategy).run(input) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction @@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging { var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestM = 1 - // pseudo-residual for second iteration - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } - var m = 1 - while (m < numIterations) { + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") - // Create partial model + // Update partial model baseLearners(m) = model // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. 
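      // The new tree joins the ensemble with a constant weight equal to the learning rate;
      // the running prediction error is then updated incrementally and re-registered with
      // predErrorCheckpointer so the RDD lineage stays bounded across iterations.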
baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1)) predError = GradientBoostedTreesModel.updatePredictionError( input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) if (validate) { @@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging { validatePredError = GradientBoostedTreesModel.updatePredictionError( validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() if (bestValidateError - currentValidateError < validationTol) { - return new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) + doneLearning = true } else if (currentValidateError < bestValidateError) { - bestValidateError = currentValidateError - bestM = m + 1 + bestValidateError = currentValidateError + bestM = m + 1 } } - // Update data with pseudo-residuals - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } m += 1 } @@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() if (persistedInput) input.unpersist() if (validate) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 2d6b01524ff3d..9fd30c9b56319 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * learning rate should be between in the interval (0, 1] * @param validationTol Useful when runWithValidation is used. If the error rate on the * validation input between two iterations is less than the validationTol - * then stop. Ignored when [[run]] is used. + * then stop. Ignored when + * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Experimental case class BoostingStrategy( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 380291ac22bd3..9fe264656ede7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -128,9 +128,13 @@ private[spark] object DecisionTreeMetadata extends Logging { // based on the number of training examples. 
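    // A categorical feature with k distinct values needs at least k bins, so the largest
    // feature arity must not exceed maxPossibleBins (maxBins capped by the number of examples).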
if (strategy.categoricalFeaturesInfo.nonEmpty) { val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + - s"in categorical features (= $maxCategoriesPerFeature)") + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " + + "features with a large number of values, or adding more training examples.") } val unorderedFeatures = new mutable.HashSet[Int]() diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 3ae09d39ef500..dc6ce8061f62b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -96,11 +96,8 @@ private void init() { new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); setDefault(myIntParam(), 1); - setDefault(myIntParam().w(1)); setDefault(myDoubleParam(), 0.5); - setDefault(myIntParam().w(1), myDoubleParam().w(0.5)); setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); - setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } @Override diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index b48f190f599a2..d272a42c8576f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; import scala.Tuple2; @@ -59,7 +60,10 @@ public void tearDown() { @Test public void localLDAModel() { - LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + Matrix topics = LDASuite$.MODULE$.tinyTopics(); + double[] topicConcentration = new double[topics.numRows()]; + Arrays.fill(topicConcentration, 1.0D / topics.numRows()); + LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); // Check: basic parameters assertEquals(model.k(), tinyK); diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index c5fd2f9d5a22a..6355e0f179496 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite { // Attribute.fromStructField should accept any NumericType, not just DoubleType val longFldWithMeta = new StructField("x", LongType, false, metadata) assert(Attribute.fromStructField(longFldWithMeta).isNumeric) - val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata) + val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata) assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 82c345491bb3c..a7bc77965fefd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 75cf5bd4ead4f..3775292f6dca7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { ova.fit(datasetWithLabelMetadata) } + test("SPARK-8092: ensure label features and prediction cols are configurable") { + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexed") + + val indexedDataset = labelIndexer + .fit(dataset) + .transform(dataset) + .drop("label") + .withColumnRenamed("features", "f") + + val ova = new OneVsRest() + ova.setClassifier(new LogisticRegression()) + .setLabelCol(labelIndexer.getOutputCol) + .setFeaturesCol("f") + .setPredictionCol("p") + + val ovaModel = ova.fit(indexedDataset) + val transformedDataset = ovaModel.transform(indexedDataset) + val outputFields = transformedDataset.schema.fieldNames.toSet + assert(outputFields.contains("p")) + } + test("SPARK-8049: OneVsRest shouldn't output temp columns") { val logReg = new LogisticRegression() .setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 1b6b69c7dc71e..ab711c8e4b215 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import 
org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) ParamsSuite.checkParams(model) } @@ -167,9 +167,19 @@ private object RandomForestClassifierSuite { val newModel = rf.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures, + numClasses) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) + assert(newModel.numClasses == numClasses) + val results = newModel.transform(newData) + results.select("rawPrediction", "prediction").collect().foreach { + case Row(raw: Vector, prediction: Double) => { + assert(raw.size == numClasses) + val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2 + assert(predFromRaw == prediction) + } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 65846a846b7b4..321eeb843941c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } test("input column without ML attribute") { @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } } diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index c8d065f37a605..436e66bab09b0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -18,17 +18,65 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class RFormulaParserSuite extends SparkFunSuite { - private def checkParse(formula: String, label: String, terms: Seq[String]) { - val parsed = RFormulaParser.parse(formula) - assert(parsed.label == label) - assert(parsed.terms == terms) + private def checkParse( + formula: String, + label: String, + terms: Seq[String], + schema: StructType = null) { + val resolved = RFormulaParser.parse(formula).resolve(schema) + assert(resolved.label == label) + assert(resolved.terms == terms) } test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ x + x", "y", Seq("x")) checkParse("y ~ ._foo ", "y", Seq("._foo")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } + + test("parse dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ .", "a", Seq("b", "c"), schema) + } + + test("parse deletion") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ c - b", "a", Seq("c"), schema) + } + + test("parse additions and deletions in order") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ . - b + . 
- c", "a", Seq("b"), schema) + } + + test("dot ignores complex column types") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "tinyint", false) + .add("c", "map", true) + checkParse("a ~ .", "a", Seq("b"), schema) + } + + test("parse intercept") { + assert(RFormulaParser.parse("a ~ b").hasIntercept) + assert(RFormulaParser.parse("a ~ b + 1").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 0").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept) + assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 79c4ccf02d4e0..6aed3243afce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -31,72 +32,95 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val formula = new RFormula().setFormula("id ~ v1 + v2") val original = sqlContext.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") - val result = formula.transform(original) - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) val expected = sqlContext.createDataFrame( Seq( - (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0), - (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0)) + (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), + (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) ).toDF("id", "v1", "v2", "features", "label") // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) + assert(result.collect() === expected.collect()) } test("features column already exists") { val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") intercept[IllegalArgumentException] { - formula.transformSchema(original.schema) + formula.fit(original) } intercept[IllegalArgumentException] { - formula.transform(original) + formula.fit(original) } } test("label column already exists") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == formula.transform(original).schema.toString) + assert(resultSchema.toString == model.transform(original).schema.toString) } test("label column already exists but is not double type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val model = formula.fit(original) 
intercept[IllegalArgumentException] { - formula.transformSchema(original.schema) + model.transformSchema(original.schema) } intercept[IllegalArgumentException] { - formula.transform(original) + model.transform(original) } } test("allow missing label column for test datasets") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == formula.transform(original).schema.toString) + assert(resultSchema.toString == model.transform(original).schema.toString) } -// TODO(ekl) enable after we implement string label support -// test("transform string label") { -// val formula = new RFormula().setFormula("name ~ id") -// val original = sqlContext.createDataFrame( -// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name") -// val result = formula.transform(original) -// val resultSchema = formula.transformSchema(original.schema) -// val expected = sqlContext.createDataFrame( -// Seq( -// (1, "foo", Vectors.dense(Array(1.0)), 1.0), -// (2, "bar", Vectors.dense(Array(2.0)), 0.0), -// (3, "bar", Vectors.dense(Array(3.0)), 0.0)) -// ).toDF("id", "name", "features", "label") -// assert(result.schema.toString == resultSchema.toString) -// assert(result.collect().toSeq == expected.collect().toSeq) -// } + test("encodes string terms") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 9682edcd9ba84..dbdce0c9dea54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import 
org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions.min() < -1) } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val df = sqlContext.createDataFrame(data) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala new file mode 100644 index 0000000000000..66e4b170bae80 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + private val schema = StructType( + Array( + StructField("label", DoubleType), + StructField("features", DoubleType), + StructField("weight", DoubleType))) + + private val predictionSchema = StructType(Array(StructField("features", DoubleType))) + + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { + val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) + val parallelData = sc.parallelize(data) + + sqlContext.createDataFrame(parallelData, schema) + } + + private def generatePredictionInput(features: Seq[Double]): DataFrame = { + val data = Seq.tabulate(features.size)(i => Row(features(i))) + + val parallelData = sc.parallelize(data) + sqlContext.createDataFrame(parallelData, predictionSchema) + } + + test("isotonic regression predictions") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + val trainer = new IsotonicRegression().setIsotonicParam(true) + + val model = trainer.fit(dataset) + + val predictions = model + .transform(dataset) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.parentModel.isotonic) + } + + test("antitonic regression predictions") { + val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) + val trainer = new IsotonicRegression().setIsotonicParam(false) + + val model = trainer.fit(dataset) + val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1)) + } + + test("params validation") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression + ParamsSuite.checkParams(ir) + val model = ir.fit(dataset) + ParamsSuite.checkParams(model) + } + + test("default params") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression() + assert(ir.getLabelCol === "label") + assert(ir.getFeaturesCol === "features") + assert(ir.getWeightCol === "weight") + assert(ir.getPredictionCol === "prediction") + assert(ir.getIsotonicParam === true) + + val model = ir.fit(dataset) + model.transform(dataset) + .select("label", "features", "prediction", "weight") + .collect() + + assert(model.getLabelCol === "label") + assert(model.getFeaturesCol === "features") + assert(model.getWeightCol === "weight") + assert(model.getPredictionCol === "prediction") + assert(model.getIsotonicParam === true) + assert(model.hasParent) + } + + test("set parameters") { + val isotonicRegression = new IsotonicRegression() + .setIsotonicParam(false) + .setWeightParam("w") + .setFeaturesCol("f") + .setLabelCol("l") + .setPredictionCol("p") + + assert(isotonicRegression.getIsotonicParam === false) + assert(isotonicRegression.getWeightCol === "w") + assert(isotonicRegression.getFeaturesCol === "f") + assert(isotonicRegression.getLabelCol === "l") + 
assert(isotonicRegression.getPredictionCol === "p") + } + + test("missing column") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + + intercept[IllegalArgumentException] { + new IsotonicRegression().setWeightParam("w").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setFeaturesCol("f").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setLabelCol("l").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index fd653296c9d97..d7b291d5a6330 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) inputDStream.count() @@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // train and predict - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase .setNumIterations(10) val numBatches = 10 val emptyInput = 
Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index da70d9bd7c790..c43e1e575c09c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, max, argmax} import org.apache.spark.SparkFunSuite +import org.apache.spark.graphx.Edge import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -30,7 +31,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ test("LocalLDAModel") { - val model = new LocalLDAModel(tinyTopics) + val model = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) // Check: basic parameters assert(model.k === tinyK) @@ -81,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.topicsMatrix === localModel.topicsMatrix) // Check: topic summaries - // The odd decimal formatting and sorting is a hack to do a robust comparison. - val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => - assert(t1 === t2) + val topicSummary = model.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) => + assert(topics ~== topicsLocal absTol 0.01) } // Check: per-doc topic distributions @@ -195,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // verify the result, Note this generate the identical result as // [[https://github.com/Blei-Lab/onlineldavb]] - val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1) - assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2) + val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t) + val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t) + val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950) + val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050) + 
assert(topic1 ~== expectedTopic1 absTol 0.01) + assert(topic2 ~== expectedTopic2 absTol 0.01) } test("OnlineLDAOptimizer with toy data") { @@ -234,6 +231,114 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("LocalLDAModel logPerplexity") { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + // obtained from LDA model trained in gensim, see below + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val docs = sc.parallelize(toydata) + + + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(lda.log_perplexity(corpus)) + > -3.69051285096 + */ + + assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) + } + + test("LocalLDAModel predict") { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + // obtained from LDA model trained in gensim, see below + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val docs = sc.parallelize(toydata) + + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(list(lda.get_document_topics(corpus))) + > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)], + > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)], + > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]] + */ + + val expectedPredictions = List( + (0, 0.99504), (0, 0.99504), + (0, 0.99504), (1, 0.99504), + (1, 0.99504), (1, 0.99504)) + + val actualPredictions = 
ldaModel.topicDistributions(docs).map { case (id, topics) => + // convert results to expectedPredictions format, which only has highest probability topic + val topicsBz = topics.toBreeze.toDenseVector + (id, (argmax(topicsBz), max(topicsBz))) + }.sortByKey() + .values + .collect() + + assert(expectedPredictions.zip(actualPredictions).forall { case (expected, actual) => + expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D) + }) + } + test("OnlineLDAOptimizer with asymmetric prior") { def toydata: Array[(Long, Vector)] = Array( Vectors.sparse(6, Array(0, 1), Array(1, 1)), @@ -286,7 +391,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("model save/load") { // Test for LocalLDAModel. - val localModel = new LocalLDAModel(tinyTopics) + val localModel = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D) val tempDir1 = Utils.createTempDir() val path1 = tempDir1.toURI.toString @@ -312,18 +418,76 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) assert(samelocalModel.k === localModel.k) assert(samelocalModel.vocabSize === localModel.vocabSize) + assert(samelocalModel.docConcentration === localModel.docConcentration) + assert(samelocalModel.topicConcentration === localModel.topicConcentration) + assert(samelocalModel.gammaShape === localModel.gammaShape) val sameDistributedModel = DistributedLDAModel.load(sc, path2) assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) assert(distributedModel.k === sameDistributedModel.k) assert(distributedModel.vocabSize === sameDistributedModel.vocabSize) assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) + assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) + assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.gammaShape === sameDistributedModel.gammaShape) + assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) + + val graph = distributedModel.graph + val sameGraph = sameDistributedModel.graph + assert(graph.vertices.sortByKey().collect() === sameGraph.vertices.sortByKey().collect()) + val edge = graph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + val sameEdge = sameGraph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + assert(edge === sameEdge) } finally { Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2) } } + test("EMLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs,
2) + + val op = new OnlineLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + } private[clustering] object LDASuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 19e65f1b53ab5..189000512155f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -68,6 +68,54 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) } + test("power iteration clustering on graph") { + /* + We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for + edge (3, 4). + + 15-14 -13 -12 + | | + 4 . 3 - 2 11 + | | x | | + 5 0 - 1 10 + | | + 6 - 7 - 8 - 9 + */ + + val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), + (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge + (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), + (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + + val edges = similarities.flatMap { case (i, j, s) => + if (i != j) { + Seq(Edge(i, j, s), Edge(j, i, s)) + } else { + None + } + } + val graph = Graph.fromEdges(sc.parallelize(edges, 2), 0.0) + + val model = new PowerIterationClustering() + .setK(2) + .run(graph) + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id + } + assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + + val model2 = new PowerIterationClustering() + .setK(2) + .setInitializationMode("degree") + .run(sc.parallelize(similarities, 2)) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id + } + assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + } + test("normalize and powerIter") { /* Test normalize() with the following graph: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index ac01622b8a089..3645d29dccdb2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom @@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + test("accuracy for single center and equivalence to grand average") { // set parameters val 
numBatches = 10 @@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) @@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index b6818369208d7..a864eec460f2b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -37,6 +37,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(syms.length == 2) assert(syms(0)._1 == "b") assert(syms(1)._1 == "c") + + // Test that a model built using Word2Vec, i.e. wordVectors and wordIndex, + // and a model built from a Word2VecMap give the same values.
+ val word2VecMap = model.getVectors + val newModel = new Word2VecModel(word2VecMap) + assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq)) + } + + test("Word2Vec throws exception when vocabulary is empty") { + intercept[IllegalArgumentException] { + val sentence = "a b c" + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + new Word2Vec().setMinCount(10).fit(doc) + } } test("Word2VecModel") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 9f107c89f6d80..6dd2dc926acc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(sequences, 2).cache() - def compareResult( - expectedValue: Array[(Array[Int], Long)], - actualValue: Array[(Array[Int], Long)]): Boolean = { - expectedValue.map(x => (x._1.toSeq, x._2)).toSet == - actualValue.map(x => (x._1.toSeq, x._2)).toSet - } - val prefixspan = new PrefixSpan() .setMinSupport(0.33) .setMaxPatternLength(50) @@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue1, result1.collect())) + assert(compareResults(expectedValue1, result1.collect())) prefixspan.setMinSupport(0.5).setMaxPatternLength(50) val result2 = prefixspan.run(rdd) @@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4), 4L), (Array(5), 3L) ) - assert(compareResult(expectedValue2, result2.collect())) + assert(compareResults(expectedValue2, result2.collect())) prefixspan.setMinSupport(0.33).setMaxPatternLength(2) val result3 = prefixspan.run(rdd) @@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue3, result3.collect())) + assert(compareResults(expectedValue3, result3.collect())) + } + + private def compareResults( + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Boolean = { + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index d34888af2d73b..e331c75989187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo import PeriodicGraphCheckpointerSuite._ - // TODO: Do I need to call count() on the graphs' RDDs? 
- test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) checkPersistence(graphsToCheck, iteration) iteration += 1 @@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var graphsToCheck = Seq.empty[GraphToCheck] sc.setCheckpointDir(path) val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) @@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graph.vertices.count() graph.edges.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) @@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite { } else { // Graph should never be checkpointed assert(!graph.isCheckpointed, "Graph should never have been checkpointed") - assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files") } } catch { case e: AssertionError => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala new file mode 100644 index 0000000000000..b2a459a68b5fa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.impl + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PeriodicRDDCheckpointerSuite._ + + test("Persisting") { + var rddsToCheck = Seq.empty[RDDToCheck] + + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) + checkpointer.update(rdd1) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkPersistence(rddsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkPersistence(rddsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var rddsToCheck = Seq.empty[RDDToCheck] + sc.setCheckpointDir(path) + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) + rdd1.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkCheckpoint(rddsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rdd.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkCheckpoint(rddsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + rddsToCheck.foreach { rdd => + confirmCheckpointRemoved(rdd.rdd) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicRDDCheckpointerSuite { + + case class RDDToCheck(rdd: RDD[Double], gIndex: Int) + + def createRDD(sc: SparkContext): RDD[Double] = { + sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0)) + } + + def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = { + rdds.foreach { g => + checkPersistence(g.rdd, g.gIndex, iteration) + } + } + + /** + * Check storage level of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(rdd.getStorageLevel == StorageLevel.NONE) + } else { + assert(rdd.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n") + } + } + + def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = { + rdds.reverse.foreach { g => + checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(rdd: RDD[_]): Unit = { + // Note: We cannot check rdd.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this rdd.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). 
+ val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration) + rdd.getCheckpointFile.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkCheckpoint( + rdd: RDD[_], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd) + // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(rdd.isCheckpointed, "RDD should be checkpointed") + assert(rdd.getCheckpointFile.nonEmpty, "RDD should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(rdd) + } + } else { + // RDD should never be checkpointed + assert(!rdd.isCheckpointed, "RDD should never have been checkpointed") + assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" + + s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 03be4119bdaca..1c37ea5123e82 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -57,6 +57,21 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(vec.values === values) } + test("sparse vector construction with mismatched indices/values array") { + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0)) + } + } + + test("sparse vector construction with too many indices vs size") { + intercept[IllegalArgumentException] { + Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0)) + } + } + test("dense to array") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.toArray.eq(arr)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index b6cb53d0c743e..283ffec1d49d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random +import breeze.numerics.abs import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} import org.apache.spark.SparkFunSuite @@ -238,6 +239,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("QR Decomposition") { + for (mat <- Seq(denseMat, sparseMat)) { + val result = 
mat.tallSkinnyQR(true) + val expected = breeze.linalg.qr.reduced(mat.toBreeze()) + val calcQ = result.Q + val calcR = result.R + assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze()))) + assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze())) + // Decomposition without computing Q + val rOnly = mat.tallSkinnyQR(computeQ = false) + assert(rOnly.Q == null) + assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) + } + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index a2a4c5f6b8b70..34c07ed170816 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { def errorMessage = v1.toString + " did not equal " + v2.toString @@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) inputDStream.count() @@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) // collect the output as (true, estimated) tuples @@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // train and predict - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) 
model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val numBatches = 10 val nPoints = 100 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 2521b3342181a..6fc9e8df621df 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) - (algos zip losses) map { - case (algo, loss) => { - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. - // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. 
+ val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) } } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) + val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + } private object GradientBoostedTreesSuite { diff --git a/pom.xml b/pom.xml index 1f44dc8abe1d4..35fc8c44bc1b0 100644 --- a/pom.xml +++ b/pom.xml @@ -573,7 +573,7 @@ io.netty netty-all - 4.0.28.Final + 4.0.29.Final org.apache.derby diff --git a/pylintrc b/pylintrc index 061775960393b..6a675770da69a 100644 --- a/pylintrc +++ b/pylintrc @@ -84,7 +84,7 @@ enable= # If you would like to improve the code quality of pyspark, remove any of these disabled errors # run ./dev/lint-python and see if the errors raised by pylint can be fixed. 
-disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable +disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable [REPORTS] diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 9ef93071d2e77..3b647985801b7 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ - self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj) + self.save(_load_class) + self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) + d.pop('__doc__', None) + # handle property and staticmethod + dd = {} + for k, v in d.items(): + if isinstance(v, property): + k = ('property', k) + v = (v.fget, v.fset, v.fdel, v.__doc__) + elif isinstance(v, staticmethod) and hasattr(v, '__func__'): + k = ('staticmethod', k) + v = v.__func__ + elif isinstance(v, classmethod) and hasattr(v, '__func__'): + k = ('classmethod', k) + v = v.__func__ + dd[k] = v + self.save(dd) + self.write(pickle.TUPLE2) + self.write(pickle.REDUCE) + else: raise pickle.PicklingError("Can't pickle %r" % obj) @@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None): None, None, closure) +def _load_class(cls, d): + """ + Loads additional properties into class `cls`. 
+ """ + for k, v in d.items(): + if isinstance(k, tuple): + typ, k = k + if typ == 'property': + v = property(*v) + elif typ == 'staticmethod': + v = staticmethod(v) + elif typ == 'classmethod': + v = classmethod(v) + setattr(cls, k, v) + return cls + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 90cd342a6cf7f..60be85e53e2aa 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -52,7 +52,11 @@ def launch_gateway(): script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") if os.environ.get("SPARK_TESTING"): - submit_args = "--conf spark.ui.enabled=false " + submit_args + submit_args = ' '.join([ + "--conf spark.ui.enabled=false", + "--conf spark.buffer.pageSize=4mb", + submit_args + ]) command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 89117e492846b..5a82bc286d1e8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) - >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) + >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) - >>> allclose(model.treeWeights, [1.0, 1.0]) + >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86e654dd0779f..015e7a9d4900a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -525,7 +525,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text - (default) or repeatedly matching the regex (if gaps is true). + (default) or repeatedly matching the regex (if gaps is false). Optional parameters also allow filtering tokens using a minimal length. It returns an array of strings that can be empty. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 58ad99d46e23b..900ade248c386 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -152,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" return KMeansModel([c.toArray() for c in centers]) -class GaussianMixtureModel(object): +@inherit_doc +class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental - """A clustering model derived from the Gaussian Mixture Model method. + A clustering model derived from the Gaussian Mixture Model method. >>> from pyspark.mllib.linalg import Vectors, DenseMatrix + >>> from numpy.testing import assert_equal + >>> from shutil import rmtree + >>> import os, tempfile + >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 
0.9,0.8,0.75,0.935, ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) @@ -169,6 +177,25 @@ class GaussianMixtureModel(object): True >>> labels[4]==labels[5] True + + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = GaussianMixtureModel.load(sc, path) + >>> assert_equal(model.weights, sameModel.weights) + >>> mus, sigmas = list( + ... zip(*[(g.mu, g.sigma) for g in model.gaussians])) + >>> sameMus, sameSigmas = list( + ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians])) + >>> mus == sameMus + True + >>> sigmas == sameSigmas + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + >>> data = array([-5.1971, -2.5359, -3.8220, ... -5.2211, -5.0602, 4.7118, ... 6.8989, 3.4592, 4.6322, @@ -182,25 +209,15 @@ class GaussianMixtureModel(object): True >>> labels[3]==labels[4] True - >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1)) - >>> im = GaussianMixtureModel([0.5, 0.5], - ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])), - ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))]) - >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im) """ - def __init__(self, weights, gaussians): - self._weights = weights - self._gaussians = gaussians - self._k = len(self._weights) - @property def weights(self): """ Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i, and weights.sum == 1. """ - return self._weights + return array(self.call("weights")) @property def gaussians(self): @@ -208,12 +225,14 @@ def gaussians(self): Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i. """ - return self._gaussians + return [ + MultivariateGaussian(gaussian[0], gaussian[1]) + for gaussian in zip(*self.call("gaussians"))] @property def k(self): """Number of gaussians in mixture.""" - return self._k + return len(self.weights) def predict(self, x): """ @@ -238,17 +257,30 @@ def predictSoft(self, x): :return: membership_matrix. RDD of array of double values. """ if isinstance(x, RDD): - means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians]) + means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - _convert_to_vector(self._weights), means, sigmas) + _convert_to_vector(self.weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) else: raise TypeError("x should be represented by an RDD, " "but got %s." % type(x)) + @classmethod + def load(cls, sc, path): + """Load the GaussianMixtureModel from disk. + + :param sc: SparkContext + :param path: str, path to where the model is stored. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.GaussianMixtureModelWrapper(model) + return cls(wrapper) + class GaussianMixture(object): """ + .. note:: Experimental + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. 
:param data: RDD of data points @@ -271,11 +303,10 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] - weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), - k, convergenceTol, maxIterations, seed, - initialModelWeights, initialModelMu, initialModelSigma) - mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] - return GaussianMixtureModel(weight, mvg_obj) + java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) + return GaussianMixtureModel(java_model) class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py similarity index 100% rename from python/pyspark/mllib/linalg.py rename to python/pyspark/mllib/linalg/__init__.py diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 8e90adee5f4c2..5b7afc15ddfba 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -97,9 +97,11 @@ class LinearRegressionModelBase(LinearModel): def predict(self, x): """ - Predict the value of the dependent variable given a vector x - containing values for the independent variables. + Predict the value of the dependent variable given a vector or + an RDD of vectors containing values for the independent variables. """ + if isinstance(x, RDD): + return x.map(self.predict) x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -124,6 +126,8 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) @@ -267,6 +271,8 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) @@ -382,6 +388,8 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 875d3b2d642c6..916de2d6fcdbd 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -21,7 +21,9 @@ if sys.version > '3': xrange = range + basestring = str +from pyspark import SparkContext from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -223,6 +225,10 @@ class JavaSaveable(Saveable): """ def save(self, sc, path): + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) self._java_model.save(sc._jsc.sc(), path) diff --git a/python/pyspark/shell.py 
b/python/pyspark/shell.py index 144cdf0b0cdd5..99331297c19f0 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -40,7 +40,7 @@ if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -sc = SparkContext(appName="PySparkShell", pyFiles=add_files) +sc = SparkContext(pyFiles=add_files) atexit.register(lambda: sc.stop()) try: diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 8fb71bac64a5e..b8118bdb7ca76 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -606,7 +606,7 @@ def _open_file(self): if not os.path.exists(d): os.makedirs(d) p = os.path.join(d, str(id(self))) - self._file = open(p, "wb+", 65536) + self._file = open(p, "w+b", 65536) self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) os.unlink(p) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index abb6522dde7b0..917de24f3536b 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -277,6 +277,66 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + def _createFromRDD(self, rdd, schema, samplingRatio): + """ + Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. + """ + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(struct) + rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + rdd = rdd.map(schema.toInternal) + return rdd, schema + + def _createFromLocal(self, data, schema): + """ + Create an RDD for DataFrame from a list or pandas.DataFrame, returns + the RDD and schema.
+ """ + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + # make sure data could be consumed multiple times + if not isinstance(data, list): + data = list(data) + + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + for row in data: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + data = [schema.toInternal(row) for row in data] + return self._sc.parallelize(data), schema + @since(1.3) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): @@ -340,49 +400,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") - if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] - - if not isinstance(data, RDD): - if not isinstance(data, list): - data = list(data) - try: - # data could be list, tuple, generator ... - rdd = self._sc.parallelize(data) - except Exception: - raise TypeError("cannot create an RDD from type: %s" % type(data)) + if isinstance(data, RDD): + rdd, schema = self._createFromRDD(data, schema, samplingRatio) else: - rdd = data - - if schema is None or isinstance(schema, (list, tuple)): - if isinstance(data, RDD): - struct = self._inferSchema(rdd, samplingRatio) - else: - struct = self._inferSchemaFromList(data) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - schema = struct - converter = _create_converter(schema) - rdd = rdd.map(converter) - - elif isinstance(schema, StructType): - # take the first few rows to verify schema - rows = rdd.take(10) - for row in rows: - _verify_type(row, schema) - - else: - raise TypeError("schema should be StructType or list or None") - - # convert python objects to sql data - rdd = rdd.map(schema.toInternal) - + rdd, schema = self._createFromLocal(data, schema) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + df = DataFrame(jdf, self) + df._schema = schema + return df @since(1.3) def registerDataFrameAsTable(self, df, tableName): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 83e02b85f06f1..0f3480c239187 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -441,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero.
+ :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 3| + | 1| 8| + +---+-----+ + + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. @@ -1130,7 +1166,7 @@ def crosstab(self, col1, col2): non-zero pair frequencies will be returned. The first column of each row will be the distinct values of `col1` and the column names will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. - Pairs that have no occurrences will have `null` as their counts. + Pairs that have no occurrences will have zero as their counts. :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases. :param col1: The name of the first column. Distinct items will make the first item of @@ -1314,6 +1350,11 @@ def freqItems(self, cols, support=None): freqItems.__doc__ = DataFrame.freqItems.__doc__ + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + def _test(): import doctest diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b1de85acc1c21..2bb768c7baa01 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -59,7 +59,7 @@ __all__ += ['lag', 'lead', 'ntile'] __all__ += [ - 'date_format', + 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between', 'year', 'quarter', 'month', 'hour', 'minute', 'second', 'dayofmonth', 'dayofyear', 'weekofyear'] @@ -543,6 +543,16 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +def expr(str): + """Parses the expression string into the column that it represents + + >>> df.select(expr("length(name)")).collect() + [Row('length(name)=5), Row('length(name)=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.expr(str)) + + @ignore_unicode_prefix @since(1.5) def length(col): @@ -708,7 +718,7 @@ def date_format(dateCol, format): [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.date_format(dateCol, format)) + return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format)) @since(1.5) @@ -721,7 +731,7 @@ def year(col): [Row(year=2015)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.year(col)) + return Column(sc._jvm.functions.year(_to_java_column(col))) @since(1.5) @@ -734,7 +744,7 @@ def quarter(col): [Row(quarter=2)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.quarter(col)) + return 
Column(sc._jvm.functions.quarter(_to_java_column(col))) @since(1.5) @@ -747,7 +757,7 @@ def month(col): [Row(month=4)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.month(col)) + return Column(sc._jvm.functions.month(_to_java_column(col))) @since(1.5) @@ -760,7 +770,7 @@ def dayofmonth(col): [Row(day=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.dayofmonth(col)) + return Column(sc._jvm.functions.dayofmonth(_to_java_column(col))) @since(1.5) @@ -773,7 +783,7 @@ def dayofyear(col): [Row(day=98)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.dayofyear(col)) + return Column(sc._jvm.functions.dayofyear(_to_java_column(col))) @since(1.5) @@ -786,7 +796,7 @@ def hour(col): [Row(hour=13)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.hour(col)) + return Column(sc._jvm.functions.hour(_to_java_column(col))) @since(1.5) @@ -799,7 +809,7 @@ def minute(col): [Row(minute=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.minute(col)) + return Column(sc._jvm.functions.minute(_to_java_column(col))) @since(1.5) @@ -812,7 +822,7 @@ def second(col): [Row(second=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.second(col)) + return Column(sc._jvm.functions.second(_to_java_column(col))) @since(1.5) @@ -821,11 +831,63 @@ def weekofyear(col): Extract the week number of a given date as integer. >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(weekofyear('a').alias('week')).collect() + >>> df.select(weekofyear(df.a).alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.weekofyear(col)) + return Column(sc._jvm.functions.weekofyear(_to_java_column(col))) + + +@since(1.5) +def date_add(start, days): + """ + Returns the date that is `days` days after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_add(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 9))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) + + +@since(1.5) +def date_sub(start, days): + """ + Returns the date that is `days` days before `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_sub(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 7))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) + + +@since(1.5) +def add_months(start, months): + """ + Returns the date that is `months` months after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(add_months(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 5, 8))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) + + +@since(1.5) +def months_between(date1, date2): + """ + Returns the number of months between date1 and date2. 
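A short, illustrative sketch of how the date arithmetic helpers added above compose (assuming a SQLContext named sqlContext; the aliases are arbitrary):

    from pyspark.sql.functions import date_add, date_sub, add_months

    df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    df.select(
        date_add(df.d, 7).alias('next_week'),       # 2015-04-15
        date_sub(df.d, 7).alias('last_week'),       # 2015-04-01
        add_months(df.d, 3).alias('next_quarter')   # 2015-07-08
    ).collect()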
+ + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) + >>> df.select(months_between(df.t, df.d).alias('months')).collect() + [Row(months=3.9495967...)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) @since(1.5) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ea821f486f13a..ebd3ea8db6a43 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -75,7 +75,7 @@ def sqlType(self): @classmethod def module(cls): - return 'pyspark.tests' + return 'pyspark.sql.tests' @classmethod def scalaUDT(cls): @@ -106,10 +106,45 @@ def __str__(self): return "(%s,%s)" % (self.x, self.y) def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ + return isinstance(other, self.__class__) and \ other.x == self.x and other.y == self.y +class PythonOnlyUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -395,10 +430,39 @@ def test_convert_row_to_dict(self): self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + + def check_datatype(datatype): + pickled = pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) @@ -406,36 +470,66 @@ def test_infer_schema_with_udt(self): point = self.sqlCtx.sql("SELECT point FROM 
labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), PythonOnlyUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_apply_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = rdd.toDF(schema) + df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.sqlCtx.createDataFrame([row], schema) + point = df.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.sc.parallelize([row]).toDF() + df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) + df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.sqlCtx.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_column_operators(self): ci = self.df.key cs = self.df.value @@ -846,6 +940,13 @@ def test_bitwise_operations(self): result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() self.assertEqual(~75, result['~b']) + def test_expr(self): + from pyspark.sql import functions + row = Row(a="length string", b=75) + df = self.sqlCtx.createDataFrame([row]) + result = df.select(functions.expr("length(a)")).collect()[0].asDict() + self.assertEqual(13, result["'length(a)"]) + def test_replace(self): schema = StructType([ 
StructField("name", StringType(), True), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 10ad89ea14a8d..6f74b7162f7cc 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -22,6 +22,7 @@ import calendar import json import re +import base64 from array import array if sys.version >= "3": @@ -31,6 +32,8 @@ from py4j.protocol import register_input_converter from py4j.java_gateway import JavaClass +from pyspark.serializers import CloudPickleSerializer + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -194,30 +197,33 @@ def fromInternal(self, ts): class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. + + The DecimalType must have fixed precision (the maximum total number of digits) + and scale (the number of digits on the right of dot). For example, (5, 2) can + support the value from [-999.99 to 999.99]. + + The precision can be up to 38, the scale must less or equal to precision. + + When create a DecimalType, the default precision and scale is (10, 0). When infer + schema from decimal.Decimal objects, it will be DecimalType(38, 18). + + :param precision: the maximum total number of digits (default: 10) + :param scale: the number of digits on right side of dot. (default: 0) """ - def __init__(self, precision=None, scale=None): + def __init__(self, precision=10, scale=0): self.precision = precision self.scale = scale - self.hasPrecisionInfo = precision is not None + self.hasPrecisionInfo = True # this is public API def simpleString(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal(10,0)" + return "decimal(%d,%d)" % (self.precision, self.scale) def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" + return "decimal(%d,%d)" % (self.precision, self.scale) def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" + return "DecimalType(%d,%d)" % (self.precision, self.scale) class DoubleType(FractionalType): @@ -455,7 +461,7 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeFields = None + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -498,6 +504,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) return self def simpleString(self): @@ -523,12 +530,9 @@ def toInternal(self, obj): if obj is None: return - if self._needSerializeFields is None: - self._needSerializeFields = any(f.needConversion() for f in self.fields) - - if self._needSerializeFields: + if self._needSerializeAnyField: if isinstance(obj, dict): - return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) else: @@ -547,7 +551,10 @@ def fromInternal(self, 
obj): if isinstance(obj, Row): # it's already converted by pickler return obj - values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + if self._needSerializeAnyField: + values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + else: + values = obj return _create_row(self.names, values) @@ -578,9 +585,10 @@ def module(cls): @classmethod def scalaUDT(cls): """ - The class name of the paired Scala UDT. + The class name of the paired Scala UDT (could be '', if there + is no corresponding one). """ - raise NotImplementedError("UDT must have a paired Scala UDT.") + return '' def needConversion(self): return True @@ -619,22 +627,37 @@ def json(self): return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } + if self.scalaUDT(): + assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT' + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + else: + ser = CloudPickleSerializer() + b = ser.dumps(type(self)) + schema = { + "type": "udt", + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "serializedClass": base64.b64encode(b).decode('utf8'), + "sqlType": self.sqlType().jsonValue() + } return schema @classmethod def fromJson(cls, json): - pyUDT = json["pyClass"] + pyUDT = str(json["pyClass"]) # convert unicode to str split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] m = __import__(pyModule, globals(), locals(), [pyClass]) - UDT = getattr(m, pyClass) + if not hasattr(m, pyClass): + s = base64.b64decode(json['serializedClass'].encode('utf-8')) + UDT = CloudPickleSerializer().loads(s) + else: + UDT = getattr(m, pyClass) return UDT() def __eq__(self, other): @@ -693,11 +716,6 @@ def _parse_datatype_json_string(json_string): >>> complex_maptype = MapType(complex_structtype, ... complex_arraytype, False) >>> check_datatype(complex_maptype) - - >>> check_datatype(ExamplePointUDT()) - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) """ return _parse_datatype_json_value(json.loads(json_string)) @@ -749,10 +767,6 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT """ if obj is None: return NullType() @@ -761,7 +775,10 @@ def _infer_type(obj): return obj.__UDT__ dataType = _type_mappings.get(type(obj)) - if dataType is not None: + if dataType is DecimalType: + # the precision and scale of `obj` may be different from row to row. + return DecimalType(38, 18) + elif dataType is not None: return dataType() if isinstance(obj, dict): @@ -1084,11 +1101,6 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... 
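For the UDT changes above, a rough sketch of the schema JSON that jsonValue() now emits when scalaUDT() returns '' (using PythonOnlyUDT from the test suite; the dict below is hand-written for illustration, not the exact output):

    import base64
    from pyspark.serializers import CloudPickleSerializer
    from pyspark.sql.tests import PythonOnlyUDT

    # With no paired Scala UDT, the Python class itself is cloudpickled and
    # embedded, so fromJson() can rebuild it when the named module does not
    # provide the class (e.g. a UDT defined in __main__).
    ser = CloudPickleSerializer()
    serialized = base64.b64encode(ser.dumps(PythonOnlyUDT)).decode('utf8')
    schema = {
        "type": "udt",
        "pyClass": "__main__.PythonOnlyUDT",
        "serializedClass": serialized,
        "sqlType": PythonOnlyUDT().sqlType().jsonValue(),
    }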
""" # all objects are nullable if obj is None: @@ -1253,18 +1265,12 @@ def convert(self, obj, gateway_client): def _test(): import doctest from pyspark.context import SparkContext - # let doctest run in pyspark.sql.types, so DataTypes can be picklable - import pyspark.sql.types - from pyspark.sql import Row, SQLContext - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.types.__dict__.copy() + from pyspark.sql import SQLContext + globs = globals() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - (failure_count, test_count) = doctest.testmod( - pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java new file mode 100644 index 0000000000000..9e10f27d59d55 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.annotation.DeveloperApi; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * ::DeveloperApi:: + + * A function description type which can be recognized by FunctionRegistry, and will be used to + * show the usage of the function in human language. + * + * `usage()` will be used for the function usage in brief way. + * `extended()` will be used for the function usage in verbose way, suppose + * an example will be provided. + * + * And we can refer the function name by `_FUNC_`, in `usage` and `extended`, as it's + * registered in `FunctionRegistry`. + */ +@DeveloperApi +@Retention(RetentionPolicy.RUNTIME) +public @interface ExpressionDescription { + String usage() default "_FUNC_ is undocumented"; + String extended() default "No example for _FUNC_."; +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java new file mode 100644 index 0000000000000..ba8e9cb4be28b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +/** + * Expression information, will be used to describe a expression. + */ +public class ExpressionInfo { + private String className; + private String usage; + private String name; + private String extended; + + public String getClassName() { + return className; + } + + public String getUsage() { + return usage; + } + + public String getName() { + return name; + } + + public String getExtended() { + return extended; + } + + public ExpressionInfo(String className, String name, String usage, String extended) { + this.className = className; + this.name = name; + this.usage = usage; + this.extended = extended; + } + + public ExpressionInfo(String className, String name) { + this(className, name, null, null); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java new file mode 100644 index 0000000000000..e3d3ba7a9ccc0 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ArrayData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public interface SpecializedGetters { + + boolean isNullAt(int ordinal); + + boolean getBoolean(int ordinal); + + byte getByte(int ordinal); + + short getShort(int ordinal); + + int getInt(int ordinal); + + long getLong(int ordinal); + + float getFloat(int ordinal); + + double getDouble(int ordinal); + + Decimal getDecimal(int ordinal, int precision, int scale); + + UTF8String getUTF8String(int ordinal); + + byte[] getBinary(int ordinal); + + CalendarInterval getInterval(int ordinal); + + InternalRow getStruct(int ordinal, int numFields); + + ArrayData getArray(int ordinal); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 79d55b36dab01..f3b462778dc10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -19,11 +19,11 @@ import java.util.Iterator; -import scala.Function1; - import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.catalyst.util.UniqueObjectPool; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -40,94 +40,78 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. */ - private final byte[] emptyBuffer; + private final byte[] emptyAggregationBuffer; - /** - * An empty row used by `initProjection` - */ - private static final InternalRow emptyRow = new GenericInternalRow(); + private final StructType aggregationBufferSchema; - /** - * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. - */ - private final boolean reuseEmptyBuffer; + private final StructType groupingKeySchema; /** - * The projection used to initialize the emptyBuffer + * Encodes grouping keys as UnsafeRows. */ - private final Function1 initProjection; - - /** - * Encodes grouping keys or buffers as UnsafeRows. - */ - private final UnsafeRowConverter keyConverter; - private final UnsafeRowConverter bufferConverter; + private final UnsafeProjection groupingKeyProjection; /** * A hashmap which maps from opaque bytearray keys to bytearray values. */ private final BytesToBytesMap map; - /** - * An object pool for objects that are used in grouping keys. - */ - private final UniqueObjectPool keyPool; - - /** - * An object pool for objects that are used in aggregation buffers. 
- */ - private final ObjectPool bufferPool; - /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentBuffer = new UnsafeRow(); + private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + + private final boolean enablePerfMetrics; /** - * Scratch space that is used when encoding grouping keys into UnsafeRow format. - * - * By default, this is a 8 kb array, but it will grow as necessary in case larger keys are - * encountered. + * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given + * schema, false otherwise. */ - private byte[] groupingKeyConversionScratchSpace = new byte[1024 * 8]; - - private final boolean enablePerfMetrics; + public static boolean supportsAggregationBufferSchema(StructType schema) { + for (StructField field: schema.fields()) { + if (field.dataType() instanceof DecimalType) { + DecimalType dt = (DecimalType) field.dataType(); + if (dt.precision() > Decimal.MAX_LONG_DIGITS()) { + return false; + } + } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param initProjection the default value for new keys (a "zero" of the agg. function) - * @param keyConverter the converter of the grouping key, used for row conversion. - * @param bufferConverter the converter of the aggregation buffer, used for row conversion. + * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) + * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. + * @param groupingKeySchema the schema of the grouping key, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. 
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - Function1 initProjection, - UnsafeRowConverter keyConverter, - UnsafeRowConverter bufferConverter, + InternalRow emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, TaskMemoryManager memoryManager, int initialCapacity, + long pageSizeBytes, boolean enablePerfMetrics) { - this.initProjection = initProjection; - this.keyConverter = keyConverter; - this.bufferConverter = bufferConverter; + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); + this.groupingKeySchema = groupingKeySchema; + this.map = + new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); this.enablePerfMetrics = enablePerfMetrics; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); - this.keyPool = new UniqueObjectPool(100); - this.bufferPool = new ObjectPool(initialCapacity); - - InternalRow initRow = initProjection.apply(emptyRow); - int emptyBufferSize = bufferConverter.getSizeRequirement(initRow); - this.emptyBuffer = new byte[emptyBufferSize]; - int writtenLength = bufferConverter.writeRow( - initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize, - bufferPool); - assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; - // re-use the empty buffer only when there is no object saved in pool. - reuseEmptyBuffer = bufferPool.size() == 0; + // Initialize the buffer for aggregation value + final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); + this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + + UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); } /** @@ -135,53 +119,35 @@ public UnsafeFixedWidthAggregationMap( * return the same object. */ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { - final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); - // Make sure that the buffer is large enough to hold the key. 
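As a worked check of the size assertion above, under the illustrative assumption of an aggregation buffer schema with three long-backed fields:

    num_buffer_fields = 3                  # e.g. (count: long, sum: long, max: long)
    bitset_bytes = 8                       # one 8-byte word tracks nulls for up to 64 fields
    fixed_width_bytes = num_buffer_fields * 8
    assert fixed_width_bytes + bitset_bytes == 32   # expected emptyAggregationBuffer length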
If it's not, grow it: - if (groupingKeySize > groupingKeyConversionScratchSpace.length) { - groupingKeyConversionScratchSpace = new byte[groupingKeySize]; - } - final int actualGroupingKeySize = keyConverter.writeRow( - groupingKey, - groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize, - keyPool); - assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; + final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( - groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize); + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes()); if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: - if (!reuseEmptyBuffer) { - // There is some objects referenced by emptyBuffer, so generate a new one - InternalRow initRow = initProjection.apply(emptyRow); - bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize, bufferPool); - } loc.putNewKey( - groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize, - emptyBuffer, + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes(), + emptyAggregationBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - emptyBuffer.length + emptyAggregationBuffer.length ); } // Reset the pointer to point to the value that we just stored or looked up: final MemoryLocation address = loc.getValueAddress(); - currentBuffer.pointTo( + currentAggregationBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - bufferConverter.numFields(), - loc.getValueLength(), - bufferPool + aggregationBufferSchema.length(), + loc.getValueLength() ); - return currentBuffer; + return currentAggregationBuffer; } /** @@ -217,16 +183,14 @@ public MapEntry next() { entry.key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - keyConverter.numFields(), - loc.getKeyLength(), - keyPool + groupingKeySchema.length(), + loc.getKeyLength() ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - bufferConverter.numFields(), - loc.getValueLength(), - bufferPool + aggregationBufferSchema.length(), + loc.getValueLength() ); return entry; } @@ -254,8 +218,6 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); - System.out.println("Number of unique objects in keys: " + keyPool.size()); - System.out.println("Number of objects in buffers: " + bufferPool.size()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 7f08bf7b742dc..e7088edced1a1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -19,14 +19,22 @@ import java.io.IOException; import java.io.OutputStream; - -import org.apache.spark.sql.catalyst.util.ObjectPool; +import 
java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import static org.apache.spark.sql.types.DataTypes.*; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -40,35 +48,49 @@ * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the * base address of the row) that points to the beginning of the variable-length field, and length - * (they are combined into a long). For other objects, they are stored in a pool, the indexes of - * them are hold in the the word. - * - * In order to support fast hashing and equality checks for UnsafeRows that contain objects - * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make - * sure all the key have the same index for same object, then we can hash/compare the objects by - * hash/compare the index. - * - * For non-primitive types, the word of a field could be: - * UNION { - * [1] [offset: 31bits] [length: 31bits] // StringType - * [0] [offset: 31bits] [length: 31bits] // BinaryType - * - [index: 63bits] // StringType, Binary, index to object in pool - * } + * (they are combined into a long). * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ public final class UnsafeRow extends MutableRow { - private Object baseObject; - private long baseOffset; + ////////////////////////////////////////////////////////////////////////////// + // Static methods + ////////////////////////////////////////////////////////////////////////////// - /** A pool to hold non-primitive objects */ - private ObjectPool pool; + public static int calculateBitSetWidthInBytes(int numFields) { + return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + } - public Object getBaseObject() { return baseObject; } - public long getBaseOffset() { return baseOffset; } - public int getSizeInBytes() { return sizeInBytes; } - public ObjectPool getPool() { return pool; } + /** + * Field types that can be updated in place in UnsafeRows (e.g. 
we support set() for these types) + */ + public static final Set settableFieldTypes; + + // DecimalType(precision <= 18) is settable + static { + settableFieldTypes = Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList(new DataType[] { + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DateType, + TimestampType + }))); + } + + ////////////////////////////////////////////////////////////////////////////// + // Private fields and methods + ////////////////////////////////////////////////////////////////////////////// + + private Object baseObject; + private long baseOffset; /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; @@ -76,20 +98,21 @@ public final class UnsafeRow extends MutableRow { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - public int length() { return numFields; } + private void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; private long getFieldOffset(int ordinal) { - return baseOffset + bitSetWidthInBytes + ordinal * 8L; - } - - public static int calculateBitSetWidthInBytes(int numFields) { - return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + return baseOffset + bitSetWidthInBytes + ordinal * 8L; } - public static final long OFFSET_BITS = 31L; + ////////////////////////////////////////////////////////////////////////////// + // Public methods + ////////////////////////////////////////////////////////////////////////////// /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, @@ -97,6 +120,13 @@ public static int calculateBitSetWidthInBytes(int numFields) { */ public UnsafeRow() { } + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } + + @Override + public int numFields() { return numFields; } + /** * Update this UnsafeRow to point to different backing data. * @@ -104,17 +134,14 @@ public UnsafeRow() { } * @param baseOffset the offset within the base object * @param numFields the number of fields in this row * @param sizeInBytes the size of this row's backing data, in bytes - * @param pool the object pool to hold arbitrary objects */ - public void pointTo( - Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) { - assert numFields >= 0 : "numFields should >= 0"; + public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) { + assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; this.sizeInBytes = sizeInBytes; - this.pool = pool; } private void assertIndexIsValid(int index) { @@ -132,73 +159,9 @@ public void setNullAt(int i) { PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); } - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - - /** - * Updates the column `i` as Object `value`, which cannot be primitive types. 
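An illustrative Python transcription (not part of the patch) of the fixed-width layout arithmetic used by UnsafeRow above:

    def calculate_bitset_width_in_bytes(num_fields):
        # one null-tracking bit per field, rounded up to whole 8-byte words
        return ((num_fields // 64) + (0 if num_fields % 64 == 0 else 1)) * 8

    def field_offset(base_offset, num_fields, ordinal):
        # every field occupies one 8-byte word after the null bitset
        return base_offset + calculate_bitset_width_in_bytes(num_fields) + ordinal * 8

    calculate_bitset_width_in_bytes(10)   # -> 8
    calculate_bitset_width_in_bytes(65)   # -> 16
    field_offset(0, 10, 3)                # -> 8 + 3 * 8 = 32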
- */ @Override - public void update(int i, Object value) { - if (value == null) { - if (!isNullAt(i)) { - // remove the old value from pool - long idx = getLong(i); - if (idx <= 0) { - // this is the index of old value in pool, remove it - pool.replace((int)-idx, null); - } else { - // there will be some garbage left (UTF8String or byte[]) - } - setNullAt(i); - } - return; - } - - if (isNullAt(i)) { - // there is not an old value, put the new value into pool - int idx = pool.put(value); - setLong(i, (long)-idx); - } else { - // there is an old value, check the type, then replace it or update it - long v = getLong(i); - if (v <= 0) { - // it's the index in the pool, replace old value with new one - int idx = (int)-v; - pool.replace(idx, value); - } else { - // old value is UTF8String or byte[], try to reuse the space - boolean isString; - byte[] newBytes; - if (value instanceof UTF8String) { - newBytes = ((UTF8String) value).getBytes(); - isString = true; - } else { - newBytes = (byte[]) value; - isString = false; - } - int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); - int oldLength = (int) (v & Integer.MAX_VALUE); - if (newBytes.length <= oldLength) { - // the new value can fit in the old buffer, re-use it - PlatformDependent.copyMemory( - newBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - baseOffset + offset, - newBytes.length); - long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L; - setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length); - } else { - // Cannot fit in the buffer - int idx = pool.put(value); - setLong(i, (long) -idx); - } - } - } - setNotNullAt(i); + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); } @Override @@ -256,95 +219,181 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - /** - * Returns the object for column `i`, which should not be primitive type. - */ @Override - public Object get(int i) { - assertIndexIsValid(i); - if (isNullAt(i)) { - return null; - } - long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); - if (v <= 0) { - // It's an index to object in the pool. - int idx = (int)-v; - return pool.get(idx); + public void setDecimal(int ordinal, Decimal value, int precision) { + assertIndexIsValid(ordinal); + if (value == null) { + setNullAt(ordinal); } else { - // The column could be StingType or BinaryType - boolean isString = (v >> (OFFSET_BITS * 2)) > 0; - int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); - int size = (int) (v & Integer.MAX_VALUE); - final byte[] bytes = new byte[size]; - // TODO(davies): Avoid the copy once we can manage the life cycle of Row well. 
- PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size - ); - if (isString) { - return UTF8String.fromBytes(bytes); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + setLong(ordinal, value.toUnscaledLong()); } else { - return bytes; + // TODO(davies): support update decimal (hold a bounded space even it's null) + throw new UnsupportedOperationException(); } } } @Override - public boolean isNullAt(int i) { - assertIndexIsValid(i); - return BitSetMethods.isSet(baseObject, baseOffset, i); + public Object get(int ordinal) { + throw new UnsupportedOperationException(); } @Override - public boolean getBoolean(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i)); + public Object get(int ordinal, DataType dataType) { + if (isNullAt(ordinal) || dataType instanceof NullType) { + return null; + } else if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType dt = (DecimalType) dataType; + return getDecimal(ordinal, dt.precision(), dt.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType) dataType).size()); + } else { + throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); + } } @Override - public byte getByte(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i)); + public boolean isNullAt(int ordinal) { + assertIndexIsValid(ordinal); + return BitSetMethods.isSet(baseObject, baseOffset, ordinal); } @Override - public short getShort(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i)); + public boolean getBoolean(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal)); } @Override - public int getInt(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i)); + public byte getByte(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal)); } @Override - public long getLong(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + public short getShort(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal)); } @Override - public float getFloat(int i) { - assertIndexIsValid(i); - if (isNullAt(i)) { - return Float.NaN; + public int getInt(int ordinal) { + assertIndexIsValid(ordinal); + return 
PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal)); + } + + @Override + public long getLong(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal)); + } + + @Override + public float getFloat(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); + } + + @Override + public double getDouble(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + assertIndexIsValid(ordinal); + if (isNullAt(ordinal)) { + return null; + } + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); } else { - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i)); + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); } } @Override - public double getDouble(int i) { - assertIndexIsValid(i); - if (isNullAt(i)) { - return Float.NaN; + public UTF8String getUTF8String(int ordinal) { + assertIndexIsValid(ordinal); + return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal)); + } + + @Override + public byte[] getBinary(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + assertIndexIsValid(ordinal); + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size + ); + return bytes; + } + } + + @Override + public CalendarInterval getInterval(int ordinal) { + if (isNullAt(ordinal)) { + return null; } else { - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i)); + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long microseconds = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + return new CalendarInterval(months, microseconds); + } + } + + @Override + public UnsafeRow getStruct(int ordinal, int numFields) { + if (isNullAt(ordinal)) { + return null; + } else { + assertIndexIsValid(ordinal); + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final UnsafeRow row = new UnsafeRow(); + row.pointTo(baseObject, baseOffset + offset, numFields, size); + return row; } } @@ -356,23 +405,17 @@ public double getDouble(int i) { */ @Override public UnsafeRow copy() { - if (pool != null) { - throw new UnsupportedOperationException( - "Copy is not supported for UnsafeRows that use object pools"); - } else { - UnsafeRow rowCopy = new UnsafeRow(); - final byte[] rowDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeInBytes - ); - rowCopy.pointTo( - rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null); - return rowCopy; - } + UnsafeRow rowCopy = new UnsafeRow(); + final byte[] 
rowDataCopy = new byte[sizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeInBytes + ); + rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + return rowCopy; } /** @@ -426,7 +469,7 @@ public boolean equals(Object other) { */ public byte[] getBytes() { if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET - && (((byte[]) baseObject).length == sizeInBytes)) { + && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; @@ -452,4 +495,19 @@ public String toString() { public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); } + + /** + * Writes the content of this row into a memory address, identified by an object and an offset. + * The target memory address must already been allocated, and have enough space to hold all the + * bytes in this string. + */ + public void writeToMemory(Object target, long targetOffset) { + PlatformDependent.copyMemory( + baseObject, + baseOffset, + target, + targetOffset, + sizeInBytes + ); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java new file mode 100644 index 0000000000000..f43a285cd6cad --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.ByteArray; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A set of helper methods to write data into {@link UnsafeRow}s, + * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. + */ +public class UnsafeRowWriters { + + /** Writer for Decimal with precision under 18. */ + public static class CompactDecimalWriter { + + public static int getSize(Decimal input) { + return 0; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + target.setLong(ordinal, input.toUnscaledLong()); + return 0; + } + } + + /** Writer for Decimal with precision larger than 18. 
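To make the two decimal paths concrete, a small Python sketch; MAX_LONG_DIGITS mirrors Decimal.MAX_LONG_DIGITS() on the JVM side and the encoding below only approximates the byte layout:

    from decimal import Decimal

    MAX_LONG_DIGITS = 18   # unscaled values of up to 18 digits fit in a signed long

    def encode_decimal(value, precision, scale):
        unscaled = int(value.scaleb(scale))
        if precision <= MAX_LONG_DIGITS:
            # CompactDecimalWriter: stored inline in the 8-byte fixed-width slot
            return unscaled
        # DecimalWriter: minimal big-endian two's-complement bytes, copied into
        # a bounded 16-byte slot of the variable-length region
        n = (unscaled.bit_length() + 8) // 8
        return unscaled.to_bytes(n, byteorder='big', signed=True)

    encode_decimal(Decimal('123.45'), precision=5, scale=2)   # -> 12345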
*/ + public static class DecimalWriter { + + public static int getSize(Decimal input) { + // bounded size + return 16; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final long offset = target.getBaseOffset() + cursor; + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + final int numBytes = bytes.length; + assert(numBytes <= 16); + + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, + target.getBaseObject(), offset, numBytes); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + return 16; + } + } + + /** Writer for UTF8String. */ + public static class UTF8StringWriter { + + public static int getSize(UTF8String input) { + return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.numBytes()); + } + + public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String input) { + final long offset = target.getBaseOffset() + cursor; + final int numBytes = input.numBytes(); + + // zero-out the padding bytes + if ((numBytes & 0x07) > 0) { + PlatformDependent.UNSAFE.putLong( + target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + } + + // Write the bytes to the variable length portion. + input.writeToMemory(target.getBaseObject(), offset); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } + + /** Writer for binary (byte array) type. */ + public static class BinaryWriter { + + public static int getSize(byte[] input) { + return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); + } + + public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) { + final long offset = target.getBaseOffset() + cursor; + final int numBytes = input.length; + + // zero-out the padding bytes + if ((numBytes & 0x07) > 0) { + PlatformDependent.UNSAFE.putLong( + target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + } + + // Write the bytes to the variable length portion. + ByteArray.writeToMemory(input, target.getBaseObject(), offset); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } + + /** + * Writer for struct type where the struct field is backed by an {@link UnsafeRow}. + * + * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}. + * Non-UnsafeRow struct fields are handled directly in + * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection} + * by generating the Java code needed to convert them into UnsafeRow. + */ + public static class StructWriter { + public static int getSize(InternalRow input) { + int numBytes = 0; + if (input instanceof UnsafeRow) { + numBytes = ((UnsafeRow) input).getSizeInBytes(); + } else { + // This is handled directly in GenerateUnsafeProjection. 
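// ---------------------------------------------------------------------------
// Editor's aside (illustrative sketch, not part of this patch): the writers
// above and UnsafeRow.getBinary/getStruct share one encoding for the fixed
// slot of a variable-length field -- the upper 32 bits of the long hold the
// cursor into the row's variable-length region and the lower 32 bits hold
// the byte length. A tiny standalone Scala model of that packing:
object OffsetAndSize {
  def pack(cursor: Int, numBytes: Int): Long =
    (cursor.toLong << 32) | numBytes.toLong

  def unpack(offsetAndSize: Long): (Int, Int) =
    ((offsetAndSize >> 32).toInt, (offsetAndSize & ((1L << 32) - 1)).toInt)
}
// e.g. OffsetAndSize.unpack(OffsetAndSize.pack(64, 13)) == (64, 13)
// ---------------------------------------------------------------------------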
+ throw new UnsupportedOperationException(); + } + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + + public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) { + int numBytes = 0; + final long offset = target.getBaseOffset() + cursor; + if (input instanceof UnsafeRow) { + final UnsafeRow row = (UnsafeRow) input; + numBytes = row.getSizeInBytes(); + + // zero-out the padding bytes + if ((numBytes & 0x07) > 0) { + PlatformDependent.UNSAFE.putLong( + target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + } + + // Write the bytes to the variable length portion. + row.writeToMemory(target.getBaseObject(), offset); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + } else { + // This is handled directly in GenerateUnsafeProjection. + throw new UnsupportedOperationException(); + } + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } + + /** Writer for interval type. */ + public static class IntervalWriter { + + public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInterval input) { + final long offset = target.getBaseOffset() + cursor; + + // Write the months and microseconds fields of Interval to the variable length portion. + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); + + // Set the fixed length portion. + target.setLong(ordinal, ((long) cursor) << 32); + return 16; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java deleted file mode 100644 index 97f89a7d0b758..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util; - -/** - * A object pool stores a collection of objects in array, then they can be referenced by the - * pool plus an index. - */ -public class ObjectPool { - - /** - * An array to hold objects, which will grow as needed. - */ - private Object[] objects; - - /** - * How many objects in the pool. - */ - private int numObj; - - public ObjectPool(int capacity) { - objects = new Object[capacity]; - numObj = 0; - } - - /** - * Returns how many objects in the pool. - */ - public int size() { - return numObj; - } - - /** - * Returns the object at position `idx` in the array. - */ - public Object get(int idx) { - assert (idx < numObj); - return objects[idx]; - } - - /** - * Puts an object `obj` at the end of array, returns the index of it. - *
- * The array will grow as needed. - */ - public int put(Object obj) { - if (numObj >= objects.length) { - Object[] tmp = new Object[objects.length * 2]; - System.arraycopy(objects, 0, tmp, 0, objects.length); - objects = tmp; - } - objects[numObj++] = obj; - return numObj - 1; - } - - /** - * Replaces the object at `idx` with new one `obj`. - */ - public void replace(int idx, Object obj) { - assert (idx < numObj); - objects[idx] = obj; - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java deleted file mode 100644 index d512392dcaacc..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util; - -import java.util.HashMap; - -/** - * An unique object pool stores a collection of unique objects in it. - */ -public class UniqueObjectPool extends ObjectPool { - - /** - * A hash map from objects to their indexes in the array. - */ - private HashMap objIndex; - - public UniqueObjectPool(int capacity) { - super(capacity); - objIndex = new HashMap(); - } - - /** - * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will - * return the index of the existing one. - */ - @Override - public int put(Object obj) { - if (objIndex.containsKey(obj)) { - return objIndex.get(obj); - } else { - int idx = super.put(obj); - objIndex.put(obj, idx); - return idx; - } - } - - /** - * The objects can not be replaced. 
- */ - @Override - public void replace(int idx, Object obj) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 39fd6e1bc6d13..68c49feae938e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; @@ -49,7 +48,6 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; @@ -63,7 +61,6 @@ public UnsafeExternalRowSorter( PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -72,7 +69,7 @@ public UnsafeExternalRowSorter( sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, - new RowComparator(ordering, schema.length(), null), + new RowComparator(ordering, schema.length()), prefixComparator, 4096, sparkEnv.conf() @@ -89,13 +86,12 @@ void setTestSpillFrequency(int frequency) { } @VisibleForTesting - void insertRow(InternalRow row) throws IOException { - UnsafeRow unsafeRow = unsafeProjection.apply(row); + void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - unsafeRow.getBaseObject(), - unsafeRow.getBaseOffset(), - unsafeRow.getSizeInBytes(), + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), prefix ); numRowsInserted++; @@ -114,7 +110,7 @@ private void cleanupResources() { } @VisibleForTesting - Iterator sort() throws IOException { + Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -122,10 +118,10 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. 
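// ---------------------------------------------------------------------------
// Editor's aside (consumer-side sketch, not part of this patch): because the
// iterator returned by UnsafeExternalRowSorter.sort() keeps re-pointing a
// single UnsafeRow at the current record, callers that buffer rows must
// copy() them; the sorter itself only copies the very last row before it
// frees its memory pages. Assumes spark-catalyst on the classpath.
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

def materialize(sorted: Iterator[UnsafeRow]): Vector[UnsafeRow] =
  sorted.map(_.copy()).toVector   // toVector forces eager copies
// ---------------------------------------------------------------------------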
cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); - private final UnsafeRow row = new UnsafeRow(); + private UnsafeRow row = new UnsafeRow(); @Override public boolean hasNext() { @@ -133,20 +129,22 @@ public boolean hasNext() { } @Override - public InternalRow next() { + public UnsafeRow next() { try { sortedIterator.loadNext(); row.pointTo( sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, - sortedIterator.getRecordLength(), - null); + sortedIterator.getRecordLength()); if (!hasNext()) { - row.copy(); // so that we don't have dangling pointers to freed page + UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page + row = null; // so that we don't keep references to the base object cleanupResources(); + return copy; + } else { + return row; } - return row; } catch (IOException e) { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack @@ -163,38 +161,36 @@ public InternalRow next() { } - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); } /** * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. */ public static boolean supportsSchema(StructType schema) { - // TODO: add spilling note to explain why we do this for now: return UnsafeProjection.canSupport(schema); } private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; - private final ObjectPool objPool; private final UnsafeRow row1 = new UnsafeRow(); private final UnsafeRow row2 = new UnsafeRow(); - public RowComparator(Ordering ordering, int numFields, ObjectPool objPool) { + public RowComparator(Ordering ordering, int numFields) { this.numFields = numFields; this.ordering = ordering; - this.objPool = objPool; } @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool); - row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool); + // TODO: Why are the sizes -1? + row1.pointTo(baseObj1, baseOff1, numFields, -1); + row2.pointTo(baseObj2, baseOff2, numFields, -1); return ordering.compare(row1, row2); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index d22ad6794d608..17659d7d960b0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -50,9 +50,9 @@ public class DataTypes { public static final DataType TimestampType = TimestampType$.MODULE$; /** - * Gets the IntervalType object. + * Gets the CalendarIntervalType object. */ - public static final DataType IntervalType = IntervalType$.MODULE$; + public static final DataType CalendarIntervalType = CalendarIntervalType$.MODULE$; /** * Gets the DoubleType object. @@ -111,12 +111,18 @@ public static ArrayType createArrayType(DataType elementType, boolean containsNu return new ArrayType(elementType, containsNull); } + /** + * Creates a DecimalType by specifying the precision and scale. 
+ */ public static DecimalType createDecimalType(int precision, int scale) { return DecimalType$.MODULE$.apply(precision, scale); } + /** + * Creates a DecimalType with default precision and scale, which are 10 and 0. + */ public static DecimalType createDecimalType() { - return DecimalType$.MODULE$.Unlimited(); + return DecimalType$.MODULE$.USER_DEFAULT(); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index ae0ab2f4c63f5..7ca20fe97fbef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -55,7 +55,6 @@ object CatalystTypeConverters { private def isWholePrimitive(dt: DataType): Boolean = dt match { case dt if isPrimitive(dt) => true - case ArrayType(elementType, _) => isWholePrimitive(elementType) case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) case _ => false } @@ -69,7 +68,7 @@ object CatalystTypeConverters { case StringType => StringConverter case DateType => DateConverter case TimestampType => TimestampConverter - case dt: DecimalType => BigDecimalConverter + case dt: DecimalType => new DecimalConverter(dt) case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter @@ -77,7 +76,7 @@ object CatalystTypeConverters { case LongType => LongConverter case FloatType => FloatConverter case DoubleType => DoubleConverter - case _ => IdentityConverter + case dataType: DataType => IdentityConverter(dataType) } converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] } @@ -137,54 +136,58 @@ object CatalystTypeConverters { protected def toScalaImpl(row: InternalRow, column: Int): ScalaOutputType } - private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { + private case class IdentityConverter(dataType: DataType) + extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = scalaValue override def toScala(catalystValue: Any): Any = catalystValue - override def toScalaImpl(row: InternalRow, column: Int): Any = row(column) + override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType) } private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) - override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row(column)) + override def toScalaImpl(row: InternalRow, column: Int): Any = + toScala(row.get(column, udt.sqlType)) } /** Converter for arrays, sequences, and Java iterables. 
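// ---------------------------------------------------------------------------
// Editor's aside (usage sketch, not part of this patch): the Java-side
// factory methods documented above, called from Scala. The no-arg overload
// now yields the user-facing default DecimalType(10, 0) instead of the old
// unlimited type.
import org.apache.spark.sql.types.{DataTypes, DecimalType}

val explicit: DecimalType = DataTypes.createDecimalType(20, 2)
val byDefault: DecimalType = DataTypes.createDecimalType()   // precision 10, scale 0
// ---------------------------------------------------------------------------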
*/ private case class ArrayConverter( - elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] { private[this] val elementConverter = getConverterForType(elementType) private[this] val isNoChange = isWholePrimitive(elementType) - override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + override def toCatalystImpl(scalaValue: Any): ArrayData = { scalaValue match { - case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) - case s: Seq[_] => s.map(elementConverter.toCatalyst) + case a: Array[_] => + new GenericArrayData(a.map(elementConverter.toCatalyst)) + case s: Seq[_] => + new GenericArrayData(s.map(elementConverter.toCatalyst).toArray) case i: JavaIterable[_] => val iter = i.iterator - var convertedIterable: List[Any] = List() + val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any] while (iter.hasNext) { val item = iter.next() - convertedIterable :+= elementConverter.toCatalyst(item) + convertedIterable += elementConverter.toCatalyst(item) } - convertedIterable + new GenericArrayData(convertedIterable.toArray) } } - override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + override def toScala(catalystValue: ArrayData): Seq[Any] = { if (catalystValue == null) { null } else if (isNoChange) { - catalystValue + catalystValue.toArray() } else { - catalystValue.map(elementConverter.toScala) + catalystValue.toArray().map(elementConverter.toScala) } } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row(column).asInstanceOf[Seq[Any]]) + toScala(row.getArray(column)) } private case class MapConverter( @@ -227,7 +230,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = - toScala(row(column).asInstanceOf[Map[Any, Any]]) + toScala(row.get(column, MapType(keyType, valueType)).asInstanceOf[Map[Any, Any]]) } private case class StructConverter( @@ -260,9 +263,9 @@ object CatalystTypeConverters { if (row == null) { null } else { - val ar = new Array[Any](row.size) + val ar = new Array[Any](row.numFields) var idx = 0 - while (idx < row.size) { + while (idx < row.numFields) { ar(idx) = converters(idx).toScala(row, idx) idx += 1 } @@ -271,7 +274,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Row = - toScala(row(column).asInstanceOf[InternalRow]) + toScala(row.getStruct(column, structType.size)) } private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { @@ -281,7 +284,8 @@ object CatalystTypeConverters { } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString - override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString + override def toScalaImpl(row: InternalRow, column: Int): String = + row.getUTF8String(column).toString } private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { @@ -302,7 +306,8 @@ object CatalystTypeConverters { DateTimeUtils.toJavaTimestamp(row.getLong(column)) } - private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + private class DecimalConverter(dataType: DecimalType) + extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) @@ -310,9 +315,11 @@ object CatalystTypeConverters { } 
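// ---------------------------------------------------------------------------
// Editor's aside (round-trip sketch, not part of this patch): mirroring the
// DecimalConverter above, an external java.math.BigDecimal becomes a
// Catalyst Decimal on the way in and is rendered back via toJavaBigDecimal.
import org.apache.spark.sql.types.Decimal

val catalystValue: Decimal = Decimal(new java.math.BigDecimal("12.34"))
val externalValue: java.math.BigDecimal = catalystValue.toJavaBigDecimal
// ---------------------------------------------------------------------------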
override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.get(column).asInstanceOf[Decimal].toJavaBigDecimal + row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal } + private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT) + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { final override def toScala(catalystValue: Any): Any = catalystValue final override def toCatalystImpl(scalaValue: T): Any = scalaValue @@ -399,9 +406,9 @@ object CatalystTypeConverters { case t: Timestamp => TimestampConverter.toCatalyst(t) case d: BigDecimal => BigDecimalConverter.toCatalyst(d) case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) - case seq: Seq[Any] => seq.map(convertToCatalyst) + case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) - case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) case m: Map[_, _] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 024973a6b9fcd..b19bf4386b0ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,53 +19,164 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. 
*/ -abstract class InternalRow extends Row { +abstract class InternalRow extends Serializable with SpecializedGetters { - // This is only use for test - override def getString(i: Int): String = { - val str = getAs[UTF8String](i) - if (str != null) str.toString else null + def numFields: Int + + def get(ordinal: Int): Any = get(ordinal, null) + + def genericGet(ordinal: Int): Any = get(ordinal, null) + + def get(ordinal: Int, dataType: DataType): Any + + def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] + + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + + override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) + + override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) + + override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) + + override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) + + override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) + + override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) + + override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) + + override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) + + override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + getAs[Decimal](ordinal, DecimalType(precision, scale)) + + override def getInterval(ordinal: Int): CalendarInterval = + getAs[CalendarInterval](ordinal, CalendarIntervalType) + + // This is only use for test and will throw a null pointer exception if the position is null. + def getString(ordinal: Int): String = getUTF8String(ordinal).toString + + /** + * Returns a struct from ordinal position. + * + * @param ordinal position to get the struct from. + * @param numFields number of fields the struct type has + */ + override def getStruct(ordinal: Int, numFields: Int): InternalRow = + getAs[InternalRow](ordinal, null) + + override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null) + + override def toString: String = s"[${this.mkString(",")}]" + + /** + * Make a copy of the current [[InternalRow]] object. + */ + def copy(): InternalRow = this + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false } - // These expensive API should not be used internally. 
- final override def getDecimal(i: Int): java.math.BigDecimal = - throw new UnsupportedOperationException - final override def getDate(i: Int): java.sql.Date = - throw new UnsupportedOperationException - final override def getTimestamp(i: Int): java.sql.Timestamp = - throw new UnsupportedOperationException - final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException - final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException - final override def getMap[K, V](i: Int): scala.collection.Map[K, V] = - throw new UnsupportedOperationException - final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] = - throw new UnsupportedOperationException - final override def getStruct(i: Int): Row = throw new UnsupportedOperationException - final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException - final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = - throw new UnsupportedOperationException - - // A default implementation to change the return type - override def copy(): InternalRow = this + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[InternalRow]) { + return false + } + + val other = o.asInstanceOf[InternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + /* ---------------------- utility methods for Scala ---------------------- */ + + /** + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. + */ + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, get(i)) + i += 1 + } + values.toSeq + } + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) /** - * Returns true if we can check equality for these 2 rows. - * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on internal row with external row. + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. */ - protected override def canEqual(other: Row) = other.isInstanceOf[InternalRow] + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) // Custom hashCode function that matches the efficient code generated version. 
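// ---------------------------------------------------------------------------
// Editor's aside (illustrative sketch, not part of this patch): the typed
// getters defined above in action, using the InternalRow(...) factory that
// callers elsewhere in this patch already rely on to build a generic row.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

val row = InternalRow(42, UTF8String.fromString("spark"))
assert(row.numFields == 2)
assert(row.getInt(0) == 42)
assert(row.getString(1) == "spark")   // test-only helper, per the comment above
// ---------------------------------------------------------------------------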
override def hashCode: Int = { var result: Int = 37 var i = 0 - while (i < length) { + val len = numFields + while (i < len) { val update: Int = if (isNullAt(i)) { 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 9a3f9694e4c48..88a457f87ce4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -75,7 +75,7 @@ private [sql] object JavaTypeInference { case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 21b1de1ab9cb1..2442341da106d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -131,10 +131,10 @@ trait ScalaReflection { case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) - case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigDecimal] => - Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true) @@ -167,8 +167,8 @@ trait ScalaReflection { case obj: Float => FloatType case obj: Double => DoubleType case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.Unlimited - case obj: Decimal => DecimalType.Unlimited + case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT + case obj: Decimal => DecimalType.SYSTEM_DEFAULT case obj: java.sql.Timestamp => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 29cfc064da89a..f2498861c9573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.Interval +import org.apache.spark.unsafe.types.CalendarInterval /** * A very simple SQL parser. Based loosely on: @@ -48,6 +48,15 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } + def parseTableIdentifier(input: String): TableIdentifier = { + // Initialize the Keywords. + initLexical + phrase(tableIdentifier)(new lexical.Scanner(input)) match { + case Success(ident, _) => ident + case failureOrError => sys.error(failureOrError.toString) + } + } + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val ALL = Keyword("ALL") @@ -322,7 +331,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ { case s ~ f => Literal((s.getOrElse("") + f).toDouble) } + | sign.? ~ unsignedFloat ^^ { + case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) + } ) protected lazy val unsignedFloat: Parser[String] = @@ -354,32 +365,32 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val millisecond: Parser[Long] = integral <~ intervalUnit("millisecond") ^^ { - case num => num.toLong * Interval.MICROS_PER_MILLI + case num => num.toLong * CalendarInterval.MICROS_PER_MILLI } protected lazy val second: Parser[Long] = integral <~ intervalUnit("second") ^^ { - case num => num.toLong * Interval.MICROS_PER_SECOND + case num => num.toLong * CalendarInterval.MICROS_PER_SECOND } protected lazy val minute: Parser[Long] = integral <~ intervalUnit("minute") ^^ { - case num => num.toLong * Interval.MICROS_PER_MINUTE + case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE } protected lazy val hour: Parser[Long] = integral <~ intervalUnit("hour") ^^ { - case num => num.toLong * Interval.MICROS_PER_HOUR + case num => num.toLong * CalendarInterval.MICROS_PER_HOUR } protected lazy val day: Parser[Long] = integral <~ intervalUnit("day") ^^ { - case num => num.toLong * Interval.MICROS_PER_DAY + case num => num.toLong * CalendarInterval.MICROS_PER_DAY } protected lazy val week: Parser[Long] = integral <~ intervalUnit("week") ^^ { - case num => num.toLong * Interval.MICROS_PER_WEEK + case num => num.toLong * CalendarInterval.MICROS_PER_WEEK } protected lazy val intervalLiteral: Parser[Literal] = @@ -395,7 +406,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { val months = Seq(year, month).map(_.getOrElse(0)).sum val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) .map(_.getOrElse(0L)).sum - Literal.create(new Interval(months, microseconds), IntervalType) + Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType) } private def toNarrowestIntegerType(value: String): Any = { @@ -408,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } + private def toDecimalOrDouble(value: String): Any = { + val decimal = BigDecimal(value) + // follow the behavior in MS SQL Server + // https://msdn.microsoft.com/en-us/library/ms179899.aspx + if (value.contains('E') || value.contains('e')) { + decimal.doubleValue() + } else { + decimal.underlying() + } + } + protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) | ident <~ 
"." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } @@ -441,4 +463,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } + + protected lazy val tableIdentifier: Parser[TableIdentifier] = + (ident <~ ".").? ~ ident ^^ { + case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala similarity index 61% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala index 568b7ac2c5987..aebcdeb9d070f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala @@ -15,9 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.catalyst /** - * Package containing expressions that are specific to Spark runtime. + * Identifies a `table` in `database`. If `database` is not defined, the current database is used. */ -package object expressions +private[sql] case class TableIdentifier(table: String, database: Option[String] = None) { + def withDatabase(database: String): TableIdentifier = this.copy(database = Some(database)) + + def toSeq: Seq[String] = database.toSeq :+ table + + override def toString: String = toSeq.map("`" + _ + "`").mkString(".") + + def unquotedString: String = toSeq.mkString(".") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8cadbc57e87e1..265f3d1e41765 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} -import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ -import scala.collection.mutable.ArrayBuffer /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -77,8 +79,11 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: + RemoveEvaluationFromSort :: HiveTypeCoercion.typeCoercionRules ++ - extendedResolutionRules : _*) + extendedResolutionRules : _*), + Batch("Nondeterministic", Once, + PullOutNondeterministic) ) /** @@ -533,7 +538,7 @@ class Analyzer( case min: Min if isDistinct => min // For other aggregate functions, DISTINCT keyword is not supported for now. // Once we converted to the new code path, we will allow using DISTINCT keyword. 
- case other if isDistinct => + case other: AggregateExpression1 if isDistinct => failAnalysis(s"$name does not support DISTINCT keyword.") // If it does not have DISTINCT keyword, we will return it as is. case other => other @@ -910,6 +915,96 @@ class Analyzer( Project(finalProjectList, withWindow) } } + + /** + * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter, + * put them into an inner Project and finally project them away at the outer Project. + */ + object PullOutNondeterministic extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Project => p + case f: Filter => f + + // todo: It's hard to write a general rule to pull out nondeterministic expressions + // from LogicalPlan, currently we only do it for UnaryNode which has same output + // schema with its child. + case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => + val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr => + val leafNondeterministic = expr.collect { + case n: Nondeterministic => n + } + leafNondeterministic.map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + new TreeNodeRef(e) -> ne + } + }.toMap + val newPlan = p.transformExpressions { case e => + nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) + } + val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child) + Project(p.output, newPlan.withNewChildren(newChild :: Nil)) + } + } + + /** + * Removes all still-need-evaluate ordering expressions from sort and use an inner project to + * materialize them, finally use a outer project to project them away to keep the result same. + * Then we can make sure we only sort by [[AttributeReference]]s. + * + * As an example, + * {{{ + * Sort('a, 'b + 1, + * Relation('a, 'b)) + * }}} + * will be turned into: + * {{{ + * Project('a, 'b, + * Sort('a, '_sortCondition, + * Project('a, 'b, ('b + 1).as("_sortCondition"), + * Relation('a, 'b)))) + * }}} + */ + object RemoveEvaluationFromSort extends Rule[LogicalPlan] { + private def hasAlias(expr: Expression) = { + expr.find { + case a: Alias => true + case _ => false + }.isDefined + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // The ordering expressions have no effect to the output schema of `Sort`, + // so `Alias`s in ordering expressions are unnecessary and we should remove them. + case s @ Sort(ordering, _, _) if ordering.exists(hasAlias) => + val newOrdering = ordering.map(_.transformUp { + case Alias(child, _) => child + }.asInstanceOf[SortOrder]) + s.copy(order = newOrdering) + + case s @ Sort(ordering, global, child) + if s.expressions.forall(_.resolved) && s.childrenResolved && !s.hasNoEvaluation => + + val (ref, needEval) = ordering.partition(_.child.isInstanceOf[AttributeReference]) + + val namedExpr = needEval.map(_.child match { + case n: NamedExpression => n + case e => Alias(e, "_sortCondition")() + }) + + val newOrdering = ref ++ needEval.zip(namedExpr).map { case (order, ne) => + order.copy(child = ne.toAttribute) + } + + // Add still-need-evaluate ordering expressions into inner project and then project + // them away after the sort. 
+ Project(child.output, + Sort(newOrdering, global, + Project(child.output ++ namedExpr, child))) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 1541491608b24..5766e6a2dd51a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -23,8 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.EmptyConf +import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** @@ -54,7 +53,7 @@ trait Catalog { */ def getTables(databaseName: Option[String]): Seq[(String, Boolean)] - def refreshTable(databaseName: String, tableName: String): Unit + def refreshTable(tableIdent: TableIdentifier): Unit def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit @@ -132,7 +131,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { result } - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } } @@ -241,7 +240,7 @@ object EmptyCatalog extends Catalog { override def unregisterAllTables(): Unit = {} - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c203fcecf20fb..0ebc3d180a780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -38,10 +37,10 @@ trait CheckAnalysis { throw new AnalysisException(msg) } - def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { exprs.flatMap(_.collect { - case e: Generator => true - }).nonEmpty + case e: Generator => e + }).length > 1 } def checkAnalysis(plan: LogicalPlan): Unit = { @@ -83,6 +82,23 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + failAnalysis( + s"join condition '${condition.prettyString}' " + + s"of type ${condition.dataType.simpleString} is not a boolean.") + + case j @ Join(_, _, _, Some(condition)) => + def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { + case p: Predicate => + p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) + case e if e.dataType.isInstanceOf[BinaryType] 
=> + failAnalysis(s"expression ${e.prettyString} in join condition " + + s"'${condition.prettyString}' can't be binary type.") + case _ => // OK + } + + checkValidJoinConditionExprs(condition) + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK @@ -96,7 +112,25 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } + def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { + case BinaryType => + failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " + + s"not be binary type.") + case _ => // OK + } + aggregateExprs.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidGroupingExprs) + + case Sort(orders, _, _) => + orders.foreach { order => + order.dataType match { + case t: AtomicType => // OK + case NullType => // OK + case t => + failAnalysis(s"Sorting is not supported for columns of type ${t.simpleString}") + } + } case _ => // Fallbacks to the following checks } @@ -122,13 +156,21 @@ trait CheckAnalysis { s""" |Failure when resolving conflicting references in Join: |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") + case o if o.expressions.exists(!_.deterministic) && + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + failAnalysis( + s"""nondeterministic expressions are only allowed in Project or Filter, found: + | ${o.expressions.map(_.prettyString).mkString(",")} + |in operator ${operator.simpleString} + """.stripMargin) + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index eaa643faa0e82..2af7aa6da721a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -30,26 +30,44 @@ import org.apache.spark.sql.catalyst.util.StringKeyHashMap /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { - def registerFunction(name: String, builder: FunctionBuilder): Unit + final def registerFunction(name: String, builder: FunctionBuilder): Unit = { + registerFunction(name, new ExpressionInfo(builder.getClass.getCanonicalName, name), builder) + } + + def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit @throws[AnalysisException]("If function does not exist") def lookupFunction(name: String, children: Seq[Expression]): Expression + + /* List all of the registered function names. */ + def listFunction(): Seq[String] + + /* Get the class of the registered function by specified name. 
*/ + def lookupFunction(name: String): Option[ExpressionInfo] } class SimpleFunctionRegistry extends FunctionRegistry { - private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) + private val functionBuilders = + StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) - override def registerFunction(name: String, builder: FunctionBuilder): Unit = { - functionBuilders.put(name, builder) + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = { + functionBuilders.put(name, (info, builder)) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - val func = functionBuilders.get(name).getOrElse { + val func = functionBuilders.get(name).map(_._2).getOrElse { throw new AnalysisException(s"undefined function $name") } func(children) } + + override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted + + override def lookupFunction(name: String): Option[ExpressionInfo] = { + functionBuilders.get(name).map(_._1) + } } /** @@ -57,13 +75,22 @@ class SimpleFunctionRegistry extends FunctionRegistry { * functions are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - override def registerFunction(name: String, builder: FunctionBuilder): Unit = { + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = { throw new UnsupportedOperationException } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } + + override def listFunction(): Seq[String] = { + throw new UnsupportedOperationException + } + + override def lookupFunction(name: String): Option[ExpressionInfo] = { + throw new UnsupportedOperationException + } } @@ -71,7 +98,7 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression - val expressions: Map[String, FunctionBuilder] = Map( + val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), expression[CreateArray]("array"), @@ -134,13 +161,6 @@ object FunctionRegistry { expression[ToDegrees]("degrees"), expression[ToRadians]("radians"), - // misc functions - expression[Md5]("md5"), - expression[Sha2]("sha2"), - expression[Sha1]("sha1"), - expression[Sha1]("sha"), - expression[Crc32]("crc32"), - // aggregate functions expression[Average]("avg"), expression[Count]("count"), @@ -186,33 +206,50 @@ object FunctionRegistry { expression[Upper]("upper"), // datetime functions + expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), + expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), + expression[DateSub]("date_sub"), expression[DayOfMonth]("day"), expression[DayOfYear]("dayofyear"), expression[DayOfMonth]("dayofmonth"), + expression[FromUnixTime]("from_unixtime"), expression[Hour]("hour"), - expression[Month]("month"), + expression[LastDay]("last_day"), expression[Minute]("minute"), + expression[Month]("month"), + expression[MonthsBetween]("months_between"), + expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), + expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), // collection functions - expression[Size]("size") + expression[Size]("size"), + + // misc 
functions + expression[Crc32]("crc32"), + expression[Md5]("md5"), + expression[Sha1]("sha"), + expression[Sha1]("sha1"), + expression[Sha2]("sha2"), + expression[SparkPartitionID]("spark_partition_id"), + expression[InputFileName]("input_file_name") ) val builtin: FunctionRegistry = { val fr = new SimpleFunctionRegistry - expressions.foreach { case (name, builder) => fr.registerFunction(name, builder) } + expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) } fr } /** See usage above. */ - private def expression[T <: Expression](name: String) - (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { + def expression[T <: Expression](name: String) + (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // See if we can find a constructor that accepts Seq[Expression] val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption @@ -238,6 +275,15 @@ object FunctionRegistry { } } } - (name, builder) + + val clazz = tag.runtimeClass + val df = clazz.getAnnotation(classOf[ExpressionDescription]) + if (df != null) { + (name, + (new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()), + builder)) + } else { + (name, (new ExpressionInfo(clazz.getCanonicalName, name), builder)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e214545726249..603afc4032a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -36,7 +36,7 @@ object HiveTypeCoercion { val typeCoercionRules = PropagateTypes :: InConversion :: - WidenTypes :: + WidenSetOperationTypes :: PromoteStrings :: DecimalPrecision :: BooleanEquality :: @@ -47,6 +47,7 @@ object HiveTypeCoercion { Division :: PropagateTypes :: ImplicitTypeCasts :: + DateTimeOperations :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -58,8 +59,7 @@ object HiveTypeCoercion { IntegerType, LongType, FloatType, - DoubleType, - DecimalType.Unlimited) + DoubleType) /** * Find the tightest common type of two types that might be used in a binary expression. 
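// ---------------------------------------------------------------------------
// Editor's aside (standalone sketch, not the private Spark helpers): numeric
// promotion picks whichever of the two types sits later in the precedence
// list above, as findTightestCommonTypeOfTwo does with lastIndexWhere.
import org.apache.spark.sql.types._

val precedence: IndexedSeq[DataType] =
  IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

def widest(t1: DataType, t2: DataType): DataType =
  precedence(precedence.lastIndexWhere(t => t == t1 || t == t2))   // assumes both are numeric

// e.g. widest(IntegerType, FloatType) == FloatType
// ---------------------------------------------------------------------------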
@@ -72,15 +72,16 @@ object HiveTypeCoercion { case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) - // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => + Some(t2) + case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => + Some(t1) + + // Promote numeric types to the highest of the two case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) Some(numericPrecedence(index)) - // Fixed-precision decimals can up-cast into unlimited - case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited) - case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited) - case _ => None } @@ -101,7 +102,7 @@ object HiveTypeCoercion { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case None => None case Some(d) => - findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c)) + findTightestCommonTypeToString(d, c) }) } @@ -109,13 +110,35 @@ object HiveTypeCoercion { * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. */ - private def findTightestCommonType(types: Seq[DataType]) = { + private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case None => None case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } + private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (t: FractionalType, d: DecimalType) => + Some(DoubleType) + case (d: DecimalType, t: FractionalType) => + Some(DoubleType) + case _ => + findTightestCommonTypeToString(t1, t2) + } + + private def findWiderCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findWiderTypeForTwo(d, c) + case None => None + }) + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -158,6 +181,9 @@ object HiveTypeCoercion { * converted to DOUBLE. * - TINYINT, SMALLINT, and INT can all be converted to FLOAT. * - BOOLEAN types cannot be converted to any other type. + * - Any integral numeric type can be implicitly converted to decimal type. + * - two different decimal types will be converted into a wider decimal type for both of them. + * - decimal type will be converted into double if there float or double together with it. * * Additionally, all types when UNION-ed with strings will be promoted to strings. * Other string conversions are handled by PromoteStrings. 
@@ -166,55 +192,37 @@ object HiveTypeCoercion { * - IntegerType to FloatType * - LongType to FloatType * - LongType to DoubleType + * - DecimalType to Double + * + * This rule is only applied to Union/Except/Intersect */ - object WidenTypes extends Rule[LogicalPlan] { - - private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan): - (LogicalPlan, LogicalPlan) = { + object WidenSetOperationTypes extends Rule[LogicalPlan] { - // TODO: with fixed-precision decimals - val castedInput = left.output.zip(right.output).map { - // When a string is found on one side, make the other side a string too. - case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => - (lhs, Alias(Cast(rhs, StringType), rhs.name)()) - case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => - (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + private[this] def widenOutputTypes( + planName: String, + left: LogicalPlan, + right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => - logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}") - findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => - val newLeft = - if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() - val newRight = - if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() - - (newLeft, newRight) - }.getOrElse { - // If there is no applicable conversion, leave expression unchanged. - (lhs, rhs) - } - - case other => other + findWiderTypeForTwo(lhs.dataType, rhs.dataType) + case other => None } - val (castedLeft, castedRight) = castedInput.unzip - - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}") - Project(castedLeft, left) - } else { - left + def castOutput(plan: LogicalPlan): LogicalPlan = { + val casted = plan.output.zip(castedTypes).map { + case (e, Some(dt)) if e.dataType != dt => + Alias(Cast(e, dt), e.name)() + case (e, _) => e } + Project(casted, plan) + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logDebug(s"Widening numeric types in $planName $castedRight ${right.output}") - Project(castedRight, right) - } else { - right - } - (newLeft, newRight) + if (castedTypes.exists(_.isDefined)) { + (castOutput(left), castOutput(right)) + } else { + (left, right) + } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -238,8 +246,13 @@ object HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), r) => - a.makeCopy(Array(Cast(left, DoubleType), r)) + case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) => + a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right)) + case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) => + a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT))) + + case a @ BinaryArithmetic(left @ StringType(), right) => + a.makeCopy(Array(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringType()) => a.makeCopy(Array(left, Cast(right, DoubleType))) @@ -334,144 +347,82 @@ object HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE - * 1. Union, Intersect and Except operations: - * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the - * same as Hive) - * 2. Other operation: - * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, - * but note that unlimited decimals are considered bigger than doubles in WidenTypes) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * + * Note: Union/Except/Interact is handled by WidenTypes */ // scalastyle:on object DecimalPrecision extends Rule[LogicalPlan] { import scala.math.{max, min} - // Conversion rules for integer types into fixed-precision decimals - private val intTypeToFixed: Map[DataType, DecimalType] = Map( - ByteType -> DecimalType(3, 0), - ShortType -> DecimalType(5, 0), - IntegerType -> DecimalType(10, 0), - LongType -> DecimalType(20, 0) - ) - private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType - // Conversion rules for float and double into fixed-precision decimals - private val floatTypeToFixed: Map[DataType, DecimalType] = Map( - FloatType -> DecimalType(7, 7), - DoubleType -> DecimalType(15, 15) - ) - - private def castDecimalPrecision( - left: LogicalPlan, - right: LogicalPlan): (LogicalPlan, LogicalPlan) = { - val castedInput = left.output.zip(right.output).map { - case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { - case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => - // Decimals with precision/scale p1/s2 and p2/s2 will be promoted to - // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)) - val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2)) - (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)()) - case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) => - (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs) - case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) => - (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)()) - case _ => (lhs, rhs) - } - case other => other - } - - val (castedLeft, castedRight) = castedInput.unzip - - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - Project(castedLeft, left) - } else { - left - } + // Returns the wider decimal type that's wider than both of them + def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = { + 
widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale) + } + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { + val scale = max(s1, s2) + val range = max(p1 - s1, p2 - s2) + DecimalType.bounded(range + scale, scale) + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - Project(castedRight, right) - } else { - right - } - (newLeft, newRight) + private def changePrecision(e: Expression, dataType: DataType): Expression = { + ChangeDecimalPrecision(Cast(e, dataType)) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // fix decimal precision for union, intersect and except - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val (newLeft, newRight) = castDecimalPrecision(left, right) - Union(newLeft, newRight) - case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => - val (newLeft, newRight) = castDecimalPrecision(left, right) - Intersect(newLeft, newRight) - case e @ Except(left, right) if e.childrenResolved && !e.resolved => - val (newLeft, newRight) = castDecimalPrecision(left, right) - Except(newLeft, newRight) - // fix decimal precision for expressions case q => q.transformExpressions { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e + // Skip nodes who is already promoted + case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + Add(changePrecision(e1, dt), changePrecision(e2, dt)) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - ) + val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + Subtract(changePrecision(e1, dt), changePrecision(e2, dt)) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 + p2 + 1, s1 + s2) - ) + val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + Multiply(changePrecision(e1, dt), changePrecision(e2, dt)) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - ) + val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + Divide(changePrecision(e1, dt), changePrecision(e2, dt)) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) + val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. 
+ val widerType = widerDecimalType(p1, s1, p2, s2) + Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)), + resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - Cast( - Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), - DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) - ) + val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + // resultType may have lower precision, so we cast them into wider type first. + val widerType = widerDecimalType(p1, s1, p2, s2) + Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType) - // When we compare 2 decimal types with different precisions, cast them to the smallest - // common precision. case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - val resultType = DecimalType(max(p1, p2), max(s1, s2)) + val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { - case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => - b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) - case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => - b.makeCopy(Array(left, Cast(right, intTypeToFixed(t)))) + case (t: IntegralType, DecimalType.Fixed(p, s)) => + b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) + case (DecimalType.Fixed(p, s), t: IntegralType) => + b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) case (t, DecimalType.Fixed(p, s)) if isFloat(t) => b.makeCopy(Array(left, Cast(right, DoubleType))) case (DecimalType.Fixed(p, s), t) if isFloat(t) => @@ -485,7 +436,6 @@ object HiveTypeCoercion { // SUM and AVERAGE are handled by the implementations of those expressions } } - } /** @@ -563,7 +513,7 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e case Cast(e @ StringType(), t: IntegralType) => - Cast(Cast(e, DecimalType.Unlimited), t) + Cast(Cast(e, DecimalType.forType(LongType)), t) } } @@ -608,7 +558,7 @@ object HiveTypeCoercion { // compatible with every child column. case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -689,6 +639,27 @@ object HiveTypeCoercion { } } + /** + * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType + * to TimeAdd/TimeSub + */ + object DateTimeOperations extends Rule[LogicalPlan] { + + private val acceptedTypes = Seq(DateType, TimestampType, StringType) + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. 
+ case e if !e.childrenResolved => e + + case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) => + Cast(TimeAdd(r, l), r.dataType) + case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) => + Cast(TimeAdd(l, r), l.dataType) + case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) => + Cast(TimeSub(l, r), l.dataType) + } + } + /** * Casts types according to the expected input types for [[Expression]]s. */ @@ -756,8 +727,8 @@ object HiveTypeCoercion { // Implicit cast among numeric types. When we reach here, input type is not acceptable. // If input is a numeric type but not decimal, and we expect a decimal type, - // cast the input to unlimited precision decimal. - case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited) + // cast the input to decimal. + case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d)) // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long case (_: NumericType, target: NumericType) => Cast(e, target) @@ -766,12 +737,13 @@ object HiveTypeCoercion { case (TimestampType, DateType) => Cast(e, DateType) // Implicit cast from/to string - case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited) + case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT) case (StringType, target: NumericType) => Cast(e, target) case (StringType, DateType) => Cast(e, DateType) case (StringType, TimestampType) => Cast(e, TimestampType) case (StringType, BinaryType) => Cast(e, BinaryType) - case (any, StringType) if any != StringType => Cast(e, StringType) + // Cast any atomic type to string. + case (any: AtomicType, StringType) if any != StringType => Cast(e, StringType) // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. 
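The DateTimeOperations rule above rewrites +/- between a calendar interval and a date, timestamp, or string operand into TimeAdd/TimeSub wrapped in a Cast back to the non-interval operand's type. A toy version of the same rewrite on a simplified AST may make the shape easier to see; all of the case classes below are stand-ins, not Spark's expression nodes.

  object DateTimeOpsSketch {
    sealed trait DT
    case object DateT extends DT
    case object TimestampT extends DT
    case object StringT extends DT
    case object IntervalT extends DT

    sealed trait Expr { def dataType: DT }
    case class Col(name: String, dataType: DT) extends Expr
    case class IntervalLit(months: Int, micros: Long) extends Expr { def dataType = IntervalT }
    case class Add(l: Expr, r: Expr) extends Expr { def dataType = l.dataType }
    case class Subtract(l: Expr, r: Expr) extends Expr { def dataType = l.dataType }
    case class TimeAdd(start: Expr, interval: Expr) extends Expr { def dataType = TimestampT }
    case class TimeSub(start: Expr, interval: Expr) extends Expr { def dataType = TimestampT }
    case class Cast(child: Expr, dataType: DT) extends Expr

    private val accepted = Set[DT](DateT, TimestampT, StringT)

    // Same shape as the rule: interval on either side of +, interval on the right of -.
    def rewrite(e: Expr): Expr = e match {
      case Add(l, r) if l.dataType == IntervalT && accepted(r.dataType) => Cast(TimeAdd(r, l), r.dataType)
      case Add(l, r) if r.dataType == IntervalT && accepted(l.dataType) => Cast(TimeAdd(l, r), l.dataType)
      case Subtract(l, r) if r.dataType == IntervalT && accepted(l.dataType) => Cast(TimeSub(l, r), l.dataType)
      case other => other
    }

    def main(args: Array[String]): Unit = {
      // one hour in microseconds; prints Cast(TimeAdd(Col(ts,TimestampT),...),TimestampT)
      println(rewrite(Add(Col("ts", TimestampT), IntervalLit(0, 3600L * 1000000))))
    }
  }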
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 51821757967d2..a7e3a49327655 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -201,7 +201,7 @@ package object dsl { /** Creates a new AttributeReference of type decimal */ def decimal: AttributeReference = - AttributeReference(s, DecimalType.Unlimited, nullable = true)() + AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)() /** Creates a new AttributeReference of type decimal */ def decimal(precision: Int, scale: Int): AttributeReference = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4a13b687bf4ce..45709c1c8f554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -46,7 +46,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case LongType | TimestampType => input.getLong(ordinal) case FloatType => input.getFloat(ordinal) case DoubleType => input.getDouble(ordinal) - case _ => input.get(ordinal) + case StringType => input.getUTF8String(ordinal) + case BinaryType => input.getBinary(ordinal) + case CalendarIntervalType => input.getInterval(ordinal) + case t: StructType => input.getStruct(ordinal, t.size) + case _ => input.get(ordinal, dataType) } } } @@ -60,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val value = ctx.getValue("i", dataType, ordinal.toString) s""" - boolean ${ev.isNull} = i.isNullAt($ordinal); - ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); + boolean ${ev.isNull} = i.isNullAt($ordinal); + $javaType ${ev.primitive} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : ($value); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 3346d3c9f9e61..43be11c48ae7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{Interval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import scala.collection.mutable object Cast { @@ -53,7 +55,7 @@ object Cast { case (_, DateType) => true - case (StringType, IntervalType) => true + case (StringType, CalendarIntervalType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true @@ -223,7 +225,7 @@ case class Cast(child: Expression, dataType: DataType) // IntervalConverter private[this] def castToInterval(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => Interval.fromString(s.toString)) + buildCast[UTF8String](_, s => CalendarInterval.fromString(s.toString)) case _ => _ => null } @@ -298,12 +300,7 @@ case class Cast(child: Expression, dataType: DataType) * NOTE: this modifies `value` in-place, so don't call it on external data. */ private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { - decimalType match { - case DecimalType.Unlimited => - value - case DecimalType.Fixed(precision, scale) => - if (value.changePrecision(precision, scale)) value else null - } + if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null } private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { @@ -366,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = { val elementCast = cast(from.elementType, to.elementType) - buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v))) + // TODO: Could be faster? + buildCast[ArrayData](_, array => { + val length = array.numElements() + val values = new Array[Any](length) + var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + values(i) = null + } else { + values(i) = elementCast(array.get(i)) + } + i += 1 + } + new GenericArrayData(values) + }) } private[this] def castMap(from: MapType, to: MapType): Any => Any = { @@ -378,15 +389,16 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def castStruct(from: StructType, to: StructType): Any => Any = { - val casts = from.fields.zip(to.fields).map { + val castFuncs: Array[(Any) => Any] = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? 
val newRow = new GenericMutableRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 - while (i < row.length) { - newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row(i))) + while (i < row.numFields) { + newRow.update(i, + if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType))) i += 1 } newRow.copy() @@ -400,7 +412,7 @@ case class Cast(child: Expression, dataType: DataType) case DateType => castToDate(from) case decimal: DecimalType => castToDecimal(from, decimal) case TimestampType => castToTimestamp(from) - case IntervalType => castToInterval(from) + case CalendarIntervalType => castToInterval(from) case BooleanType => castToBoolean(from) case ByteType => castToByte(from) case ShortType => castToShort(from) @@ -418,51 +430,499 @@ case class Cast(child: Expression, dataType: DataType) protected override def nullSafeEval(input: Any): Any = cast(input) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - // TODO: Add support for more data types. - (child.dataType, dataType) match { + val eval = child.gen(ctx) + val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) + eval.code + + castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast) + } + + // three function arguments are: child.primitive, result.primitive and result.isNull + // it returns the code snippets to be put in null safe evaluation region + private[this] type CastFunction = (String, String, String) => String + + private[this] def nullSafeCastFunction( + from: DataType, + to: DataType, + ctx: CodeGenContext): CastFunction = to match { + + case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case StringType => castToStringCode(from, ctx) + case BinaryType => castToBinaryCode(from) + case DateType => castToDateCode(from, ctx) + case decimal: DecimalType => castToDecimalCode(from, decimal) + case TimestampType => castToTimestampCode(from, ctx) + case CalendarIntervalType => castToIntervalCode(from) + case BooleanType => castToBooleanCode(from) + case ByteType => castToByteCode(from) + case ShortType => castToShortCode(from) + case IntegerType => castToIntCode(from) + case FloatType => castToFloatCode(from) + case LongType => castToLongCode(from) + case DoubleType => castToDoubleCode(from) + + case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx) + case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) + case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + } + + // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's + // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. 
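Before the castCode helper that follows, here is a stripped-down sketch of the composition the refactored genCode uses: a CastFunction takes the names of the input variable, the result variable, and the result-null flag, returns a Java snippet, and a shared template wraps that snippet in the null check. The variable names and the boolean-to-int example below are illustrative only.

  object CastCodegenSketch {
    type CastFunction = (String, String, String) => String

    // Roughly what a leaf cast such as castToIntCode produces for a boolean input.
    val booleanToInt: CastFunction =
      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"

    // Roughly what castCode does: declare the result variables, then run the cast
    // only when the input is not null.
    def castCode(childPrim: String, childNull: String,
                 resultPrim: String, resultNull: String,
                 javaType: String, defaultValue: String,
                 cast: CastFunction): String =
      s"""
         |boolean $resultNull = $childNull;
         |$javaType $resultPrim = $defaultValue;
         |if (!$childNull) {
         |  ${cast(childPrim, resultPrim, resultNull)}
         |}""".stripMargin

    def main(args: Array[String]): Unit =
      println(castCode("value1", "isNull1", "value2", "isNull2", "int", "-1", booleanToInt))
  }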
+ private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String, + resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { + s""" + boolean $resultNull = $childNull; + ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; + if (!${childNull}) { + ${cast(childPrim, resultPrim, resultNull)} + } + """ + } + + private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = { + from match { + case BinaryType => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + case DateType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" + case TimestampType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));""" + case _ => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + } + } - case (BinaryType, StringType) => - defineCodeGen (ctx, ev, c => - s"UTF8String.fromBytes($c)") + private[this] def castToBinaryCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + } - case (DateType, StringType) => - defineCodeGen(ctx, ev, c => - s"""UTF8String.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + private[this] def castToDateCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val intOpt = ctx.freshName("intOpt") + (c, evPrim, evNull) => s""" + scala.Option $intOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); + if ($intOpt.isDefined()) { + $evPrim = ((Integer) $intOpt.get()).intValue(); + } else { + $evNull = true; + } + """ + case TimestampType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);"; + case _ => + (c, evPrim, evNull) => s"$evNull = true;" + } - case (TimestampType, StringType) => - defineCodeGen(ctx, ev, c => - s"""UTF8String.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + private[this] def changePrecision(d: String, decimalType: DecimalType, + evPrim: String, evNull: String): String = + s""" + if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { + $evPrim = $d; + } else { + $evNull = true; + } + """ + + private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { + from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + new scala.math.BigDecimal( + new java.math.BigDecimal($c.toString()))); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = null; + if ($c) { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); + } else { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); + } + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DateType => + // date can't cast to decimal in Hive + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + // Note that we lose precision here. 
+ (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DecimalType() => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case LongType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set($c); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case x: NumericType => + // All other numeric types can be represented precisely as Doubles + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + } + } + + private[this] def castToTimestampCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val longOpt = ctx.freshName("longOpt") + (c, evPrim, evNull) => + s""" + scala.Option $longOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c); + if ($longOpt.isDefined()) { + $evPrim = ((Long) $longOpt.get()).longValue(); + } else { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + case _: IntegralType => + (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + case DateType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + case DoubleType => + (c, evPrim, evNull) => + s""" + if (Double.isNaN($c) || Double.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + case FloatType => + (c, evPrim, evNull) => + s""" + if (Float.isNaN($c) || Float.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + } + + private[this] def castToIntervalCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s"$evPrim = CalendarInterval.fromString($c.toString());" + } + + private[this] def decimalToTimestampCode(d: String): String = + s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def timestampToIntegerCode(ts: String): String = + s"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + + private[this] def castToBooleanCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + case DateType => + // Hive would return null when cast from date to boolean + (c, evPrim, evNull) => s"$evNull = true;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + case n: NumericType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + } + + private[this] def castToByteCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = 
Byte.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + } - case (_, StringType) => - defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))") + private[this] def castToShortCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Short.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (short) $c;" + } - case (StringType, IntervalType) => - defineCodeGen(ctx, ev, c => - s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())") + private[this] def castToIntCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Integer.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (int) $c;" + } - // fallback for DecimalType, this must be before other numeric types - case (_, dt: DecimalType) => - super.genCode(ctx, ev) + private[this] def castToLongCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Long.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (long) $c;" + } - case (BooleanType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + private[this] def castToFloatCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Float.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 
1.0f : 0.0f;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (float) $c;" + } - case (dt: DecimalType, BooleanType) => - defineCodeGen(ctx, ev, c => s"!$c.isZero()") + private[this] def castToDoubleCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Double.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1.0d : 0.0d;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (double) $c;" + } - case (dt: NumericType, BooleanType) => - defineCodeGen(ctx, ev, c => s"$c != 0") + private[this] def castArrayCode( + from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { + val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) + val arrayClass = classOf[GenericArrayData].getName + val fromElementNull = ctx.freshName("feNull") + val fromElementPrim = ctx.freshName("fePrim") + val toElementNull = ctx.freshName("teNull") + val toElementPrim = ctx.freshName("tePrim") + val size = ctx.freshName("n") + val j = ctx.freshName("j") + val values = ctx.freshName("values") + + (c, evPrim, evNull) => + s""" + final int $size = $c.numElements(); + final Object[] $values = new Object[$size]; + for (int $j = 0; $j < $size; $j ++) { + if ($c.isNullAt($j)) { + $values[$j] = null; + } else { + boolean $fromElementNull = false; + ${ctx.javaType(from.elementType)} $fromElementPrim = + ${ctx.getValue(c, from.elementType, j)}; + ${castCode(ctx, fromElementPrim, + fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} + if ($toElementNull) { + $values[$j] = null; + } else { + $values[$j] = $toElementPrim; + } + } + } + $evPrim = new $arrayClass($values); + """ + } - case (_: DecimalType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { + val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx) + val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx) + + val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName + val fromKeyPrim = ctx.freshName("fkp") + val fromKeyNull = ctx.freshName("fkn") + val fromValuePrim = ctx.freshName("fvp") + val fromValueNull = ctx.freshName("fvn") + val toKeyPrim = ctx.freshName("tkp") + val toKeyNull = ctx.freshName("tkn") + val toValuePrim = ctx.freshName("tvp") + val toValueNull = ctx.freshName("tvn") + val result = ctx.freshName("result") + + (c, evPrim, evNull) => + s""" + final $hashMapClass $result = new $hashMapClass(); + scala.collection.Iterator iter = $c.iterator(); + while (iter.hasNext()) { + scala.Tuple2 kv = (scala.Tuple2) iter.next(); + boolean $fromKeyNull = false; + ${ctx.javaType(from.keyType)} $fromKeyPrim = + (${ctx.boxedType(from.keyType)}) kv._1(); + ${castCode(ctx, fromKeyPrim, + fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)} + + boolean $fromValueNull = kv._2() == 
null; + if ($fromValueNull) { + $result.put($toKeyPrim, null); + } else { + ${ctx.javaType(from.valueType)} $fromValuePrim = + (${ctx.boxedType(from.valueType)}) kv._2(); + ${castCode(ctx, fromValuePrim, + fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)} + if ($toValueNull) { + $result.put($toKeyPrim, null); + } else { + $result.put($toKeyPrim, $toValuePrim); + } + } + } + $evPrim = $result; + """ + } - case (_: NumericType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") + private[this] def castStructCode( + from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = { - case other => - super.genCode(ctx, ev) + val fieldsCasts = from.fields.zip(to.fields).map { + case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } + val rowClass = classOf[GenericMutableRow].getName + val result = ctx.freshName("result") + val tmpRow = ctx.freshName("tmpRow") + + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fromFieldPrim = ctx.freshName("ffp") + val fromFieldNull = ctx.freshName("ffn") + val toFieldPrim = ctx.freshName("tfp") + val toFieldNull = ctx.freshName("tfn") + val fromType = ctx.javaType(from.fields(i).dataType) + s""" + boolean $fromFieldNull = $tmpRow.isNullAt($i); + if ($fromFieldNull) { + $result.setNullAt($i); + } else { + $fromType $fromFieldPrim = + ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)}; + ${castCode(ctx, fromFieldPrim, + fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} + if ($toFieldNull) { + $result.setNullAt($i); + } else { + ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)}; + } + } + """ + } + }.mkString("\n") + + (c, evPrim, evNull) => + s""" + final $rowClass $result = new $rowClass(${fieldsCasts.size}); + final InternalRow $tmpRow = $c; + $fieldsEvalCode + $evPrim = $result.copy(); + """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 29ae47e842ddb..8fc182607ce68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -184,10 +184,10 @@ abstract class Expression extends TreeNode[Expression] { */ trait Unevaluable extends Expression { - override def eval(input: InternalRow = null): Any = + final override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + final override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } @@ -196,7 +196,24 @@ trait Unevaluable extends Expression { * An expression that is nondeterministic. 
*/ trait Nondeterministic extends Expression { - override def deterministic: Boolean = false + final override def deterministic: Boolean = false + final override def foldable: Boolean = false + + private[this] var initialized = false + + final def setInitialValues(): Unit = { + initInternal() + initialized = true + } + + protected def initInternal(): Unit + + final override def eval(input: InternalRow = null): Any = { + require(initialized, "nondeterministic expression should be initialized before evaluate") + evalInternal(input) + } + + protected def evalInternal(input: InternalRow): Any } @@ -336,9 +353,9 @@ abstract class BinaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.primitive} = ${f(eval1, eval2)};" }) @@ -353,9 +370,9 @@ abstract class BinaryExpression extends Expression { * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val resultCode = f(eval1.primitive, eval2.primitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala new file mode 100644 index 0000000000000..1e74f716955e3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + */ +case class InputFileName() extends LeafExpression with Nondeterministic { + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override val prettyName = "INPUT_FILE_NAME" + + override protected def initInternal(): Unit = {} + + override protected def evalInternal(input: InternalRow): UTF8String = { + SqlNewHadoopRDD.getInputFileName() + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = " + + "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala similarity index 86% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 2645eb1854bce..291b7a5bc3af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.expressions +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} @@ -37,17 +36,22 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with /** * Record ID within each partition. By being transient, count's value is reset to 0 every time - * we serialize and deserialize it. + * we serialize and deserialize and initialize it. 
*/ - @transient private[this] var count: Long = 0L + @transient private[this] var count: Long = _ - @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + @transient private[this] var partitionMask: Long = _ + + override protected def initInternal(): Unit = { + count = 0L + partitionMask = TaskContext.getPartitionId().toLong << 33 + } override def nullable: Boolean = false override def dataType: DataType = LongType - override def eval(input: InternalRow): Long = { + override protected def evalInternal(input: InternalRow): Long = { val currentCount = count count += 1 partitionMask + currentCount diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 69758e653eba0..7c7664e4c1a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types.{Decimal, StructType, DataType} +import org.apache.spark.unsafe.types.UTF8String /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -30,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + }) + // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -43,7 +49,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { new GenericInternalRow(outputArray) } - override def toString: String = s"Row => [${exprArray.mkString(",")}]" + override def toString(): String = s"Row => [${exprArray.mkString(",")}]" } /** @@ -56,8 +62,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + }) + private[this] val exprArray = expressions.toArray - private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) + private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) def currentValue: InternalRow = mutableRow override def target(row: MutableRow): MutableProjection = { @@ -89,8 +100,10 @@ object UnsafeProjection { * Seq[Expression]. */ def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) - def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) + private def canSupport(types: Array[DataType]): Boolean = { + types.forall(GenerateUnsafeProjection.canSupport) + } /** * Returns an UnsafeProjection for given StructType. 
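The Nondeterministic and InterpretedProjection changes above introduce an explicit init-before-eval contract: projections call setInitialValues() on every nondeterministic expression before evaluating it, and MonotonicallyIncreasingID/InputFileName move their per-partition setup into initInternal(). A compact stand-alone sketch of that contract, using a toy trait and class rather than Spark's Expression hierarchy:

  object NondeterministicSketch {
    trait Nondeterministic {
      private var initialized = false
      protected def initInternal(): Unit
      protected def evalInternal(): Any

      final def setInitialValues(): Unit = { initInternal(); initialized = true }
      final def eval(): Any = {
        require(initialized, "nondeterministic expression should be initialized before evaluate")
        evalInternal()
      }
    }

    // Analogous to MonotonicallyIncreasingID: per-partition state set up in initInternal().
    class IncreasingId(partitionId: Int) extends Nondeterministic {
      private var count = 0L
      private var partitionMask = 0L
      override protected def initInternal(): Unit = {
        count = 0L
        partitionMask = partitionId.toLong << 33
      }
      override protected def evalInternal(): Any = {
        val current = count
        count += 1
        partitionMask + current
      }
    }

    def main(args: Array[String]): Unit = {
      val id = new IncreasingId(partitionId = 2)
      id.setInitialValues()          // calling eval() first would fail the require check
      println(id.eval())             // 2L << 33
      println(id.eval())             // (2L << 33) + 1
    }
  }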
@@ -175,487 +188,62 @@ class JoinedRow extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length - - override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) - - override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - - override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - - override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - - override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - - override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - - override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - - override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - - override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - - override def copy(): InternalRow = { - val totalSize = row1.length + row2.length - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = apply(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - * The `JoinedRow` class is used in many performance critical situation. Unfortunately, since there - * are multiple different types of `Rows` that could be stored as `row1` and `row2` most of the - * calls in the critical path are polymorphic. By creating special versions of this class that are - * used in only a single location of the code, we increase the chance that only a single type of - * Row will be referenced, increasing the opportunity for the JIT to play tricks. This sounds - * crazy but in benchmarks it had noticeable effects. - */ -class JoinedRow2 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def length: Int = row1.length + row2.length - - override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) - - override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - - override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - - override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - - override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - - override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - - override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - - override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - - override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - - override def copy(): InternalRow = { - val totalSize = row1.length + row2.length - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = apply(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - */ -class JoinedRow3 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def length: Int = row1.length + row2.length - - override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) - - override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - - override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - - override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - - override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - - override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - - override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - - override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - - override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - - override def copy(): InternalRow = { - val totalSize = row1.length + row2.length - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = apply(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - */ -class JoinedRow4 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def length: Int = row1.length + row2.length - - override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) - - override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - - override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - - override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - - override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - - override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + override def numFields: Int = row1.numFields + row2.numFields - override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - - override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - - override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - - override def copy(): InternalRow = { - val totalSize = row1.length + row2.length - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = apply(i) - i += 1 - } - new GenericInternalRow(copiedValues) + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } -} -/** - * JIT HACK: Replace with macros - */ -class JoinedRow5 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def length: Int = row1.length + row2.length - - override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + override def get(i: Int, dataType: DataType): Any = + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - override def copy(): InternalRow = { - val totalSize = row1.length + row2.length - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = apply(i) - i += 1 - } - new GenericInternalRow(copiedValues) + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) row1.getDecimal(i, precision, scale) + else row2.getDecimal(i - row1.numFields, precision, scale) } - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") + override def getStruct(i: Int, numFields: Int): InternalRow = { + if (i < row1.numFields) { + row1.getStruct(i, numFields) } else { - mkString("[", ",", "]") + row2.getStruct(i - row1.numFields, numFields) } } -} - -/** - * JIT HACK: Replace with macros - */ -class JoinedRow6 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. 
*/ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def length: Int = row1.length + row2.length - - override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) - - override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) - - override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) - - override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) - - override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) - - override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) - - override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) - - override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) - - override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3f436c0eb893c..9fe877f10fa08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator abstract sealed class SortDirection case object Ascending extends SortDirection @@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection) override def nullable: Boolean = child.nullable override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + + def isAscending: Boolean = direction == Ascending +} + +/** + * An expression to generate a 64-bit long prefix used in sorting. + */ +case class SortPrefix(child: SortOrder) extends UnaryExpression { + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childCode = child.child.gen(ctx) + val input = childCode.primitive + val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + + val (nullValue: Long, prefixCode: String) = child.child.dataType match { + case BooleanType => + (Long.MinValue, s"$input ? 
1L : 0L") + case _: IntegralType => + (Long.MinValue, s"(long) $input") + case FloatType | DoubleType => + (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), + s"$DoublePrefixCmp.computePrefix((double)$input)") + case StringType => (0L, s"$input.getPrefix()") + case _ => (0L, "0L") + } + + childCode.code + + s""" + |long ${ev.primitive} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.primitive} = $prefixCode; + |} + """.stripMargin + } + + override def dataType: DataType = LongType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala similarity index 78% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 53ddd47e3e0c1..4b1772a2deed5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.expressions +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} @@ -27,15 +26,21 @@ import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. 
*/ -private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic { +private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = IntegerType - @transient private lazy val partitionId = TaskContext.getPartitionId() + @transient private[this] var partitionId: Int = _ - override def eval(input: InternalRow): Int = partitionId + override val prettyName = "SPARK_PARTITION_ID" + + override protected def initInternal(): Unit = { + partitionId = TaskContext.getPartitionId() + } + + override protected def evalInternal(input: InternalRow): Int = partitionId override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val idTerm = ctx.freshName("partitionId") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 6f291d2c86c1e..b877ce47c083f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -211,7 +211,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this() = this(Seq.empty) - override def length: Int = values.length + override def numFields: Int = values.length override def toSeq: Seq[Any] = values.map(_.boxed).toSeq @@ -219,7 +219,11 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).isNull = true } - override def get(i: Int): Any = values(i).boxed + override def get(i: Int, dataType: DataType): Any = values(i).boxed + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + values(ordinal).boxed.asInstanceOf[InternalRow] + } override def isNullAt(i: Int): Boolean = values(i).isNull @@ -245,8 +249,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String.fromString(value)) - override def getString(ordinal: Int): String = apply(ordinal).toString - override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] currentValue.isNull = false @@ -316,8 +318,4 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def getByte(i: Int): Byte = { values(i).asInstanceOf[MutableByte].value } - - override def getAs[T](i: Int): T = { - values(i).boxed.asInstanceOf[T] - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala deleted file mode 100644 index 885ab091fcdf5..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ /dev/null @@ -1,290 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.ObjectPool -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent -import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.unsafe.types.UTF8String - -/** - * Converts Rows into UnsafeRow format. This class is NOT thread-safe. - * - * @param fieldTypes the data types of the row's columns. - */ -class UnsafeRowConverter(fieldTypes: Array[DataType]) { - - def this(schema: StructType) { - this(schema.fields.map(_.dataType)) - } - - def numFields: Int = fieldTypes.length - - /** Re-used pointer to the unsafe row being written */ - private[this] val unsafeRow = new UnsafeRow() - - /** Functions for encoding each column */ - private[this] val writers: Array[UnsafeColumnWriter] = { - fieldTypes.map(t => UnsafeColumnWriter.forType(t)) - } - - /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ - private[this] val fixedLengthSize: Int = - (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) - - /** - * Compute the amount of space, in bytes, required to encode the given row. - */ - def getSizeRequirement(row: InternalRow): Int = { - var fieldNumber = 0 - var variableLengthFieldSize: Int = 0 - while (fieldNumber < writers.length) { - if (!row.isNullAt(fieldNumber)) { - variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber) - } - fieldNumber += 1 - } - fixedLengthSize + variableLengthFieldSize - } - - /** - * Convert the given row into UnsafeRow format. - * - * @param row the row to convert - * @param baseObject the base object of the destination address - * @param baseOffset the base offset of the destination address - * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)` - * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. - */ - def writeRow( - row: InternalRow, - baseObject: Object, - baseOffset: Long, - rowLengthInBytes: Int, - pool: ObjectPool): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool) - - if (writers.length > 0) { - // zero-out the bitset - var n = writers.length / 64 - while (n >= 0) { - PlatformDependent.UNSAFE.putLong( - unsafeRow.getBaseObject, - unsafeRow.getBaseOffset + n * 8, - 0L) - n -= 1 - } - } - - var fieldNumber = 0 - var cursor: Int = fixedLengthSize - while (fieldNumber < writers.length) { - if (row.isNullAt(fieldNumber)) { - unsafeRow.setNullAt(fieldNumber) - } else { - cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) - } - fieldNumber += 1 - } - cursor - } - -} - -/** - * Function for writing a column into an UnsafeRow. - */ -private abstract class UnsafeColumnWriter { - /** - * Write a value into an UnsafeRow. 
- * - * @param source the row being converted - * @param target a pointer to the converted unsafe row - * @param column the column to write - * @param cursor the offset from the start of the unsafe row to the end of the row; - * used for calculating where variable-length data should be written - * @return the number of variable-length bytes written - */ - def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int - - /** - * Return the number of bytes that are needed to write this variable-length value. - */ - def getSize(source: InternalRow, column: Int): Int -} - -private object UnsafeColumnWriter { - - def forType(dataType: DataType): UnsafeColumnWriter = { - dataType match { - case NullType => NullUnsafeColumnWriter - case BooleanType => BooleanUnsafeColumnWriter - case ByteType => ByteUnsafeColumnWriter - case ShortType => ShortUnsafeColumnWriter - case IntegerType | DateType => IntUnsafeColumnWriter - case LongType | TimestampType => LongUnsafeColumnWriter - case FloatType => FloatUnsafeColumnWriter - case DoubleType => DoubleUnsafeColumnWriter - case StringType => StringUnsafeColumnWriter - case BinaryType => BinaryUnsafeColumnWriter - case t => ObjectUnsafeColumnWriter - } - } - - /** - * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool). - */ - def canEmbed(dataType: DataType): Boolean = { - forType(dataType) != ObjectUnsafeColumnWriter - } -} - -// ------------------------------------------------------------------------------------------------ - - -private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { - // Primitives don't write to the variable-length region: - def getSize(sourceRow: InternalRow, column: Int): Int = 0 -} - -private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setNullAt(column) - 0 - } -} - -private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setBoolean(column, source.getBoolean(column)) - 0 - } -} - -private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setByte(column, source.getByte(column)) - 0 - } -} - -private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setShort(column, source.getShort(column)) - 0 - } -} - -private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setInt(column, source.getInt(column)) - 0 - } -} - -private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setLong(column, source.getLong(column)) - 0 - } -} - -private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setFloat(column, source.getFloat(column)) - 0 - } -} - -private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - 
target.setDouble(column, source.getDouble(column)) - 0 - } -} - -private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { - - protected[this] def isString: Boolean - protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte] - - override def getSize(source: InternalRow, column: Int): Int = { - val numBytes = getBytes(source, column).length - ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - } - - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val bytes = getBytes(source, column) - write(target, bytes, column, cursor) - } - - def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = { - val offset = target.getBaseOffset + cursor - val numBytes = bytes.length - if ((numBytes & 0x07) > 0) { - // zero-out the padding bytes - PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L) - } - PlatformDependent.copyMemory( - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - target.getBaseObject, - offset, - numBytes - ) - val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 - target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) - ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - } -} - -private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter { - protected[this] def isString: Boolean = true - def getBytes(source: InternalRow, column: Int): Array[Byte] = { - source.getAs[UTF8String](column).getBytes - } - // TODO(davies): refactor this - // specialized for codegen - def getSize(value: UTF8String): Int = - ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes()) - def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = { - write(target, value.getBytes, column, cursor) - } -} - -private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter { - protected[this] override def isString: Boolean = false - override def getBytes(source: InternalRow, column: Int): Array[Byte] = { - source.getAs[Array[Byte]](column) - } - // specialized for codegen - def getSize(value: Array[Byte]): Int = - ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length) -} - -private object ObjectUnsafeColumnWriter extends UnsafeColumnWriter { - override def getSize(sourceRow: InternalRow, column: Int): Int = 0 - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val obj = source.get(column) - val idx = target.getPool.put(obj) - target.setLong(column, - idx) - 0 - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index b924af4cc84d8..88fb516e64aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -36,14 +36,13 @@ case class Average(child: Expression) extends AlgebraicAggregate { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) private val resultType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) - case DecimalType.Unlimited => DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } private val sumDataType = child.dataType match { - 
case _ @ DecimalType() => DecimalType.Unlimited + case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } @@ -71,7 +70,14 @@ case class Average(child: Expression) extends AlgebraicAggregate { ) // If all input are nulls, currentCount will be 0 and we will get null after the division. - override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) + override val evaluateExpression = child.dataType match { + case DecimalType.Fixed(p, s) => + // increase the precision and scale to prevent precision loss + val dt = DecimalType.bounded(p + 14, s + 4) + Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) + case _ => + Cast(currentSum, resultType) / Cast(currentCount, resultType) + } } case class Count(child: Expression) extends AlgebraicAggregate { @@ -255,15 +261,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate { private val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) - case DecimalType.Unlimited => DecimalType.Unlimited + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } - private val sumDataType = child.dataType match { - case _ @ DecimalType() => DecimalType.Unlimited - case _ => child.dataType - } + private val sumDataType = resultType private val currentSum = AttributeReference("currentSum", sumDataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 577ede73cb01f..d08f553cefe8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCod import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction1]]. */ +/** The mode of an [[AggregateFunction2]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -42,15 +42,15 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers - * containing intermediate results for this function and the generate final result. + * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers + * containing intermediate results for this function and then generate final result. 
* This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. */ private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly + * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. @@ -63,10 +63,6 @@ private[sql] case object Complete extends AggregateMode */ private[sql] case object NoOp extends Expression with Unevaluable { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = { - throw new TreeNodeException( - this, s"No function to evaluate expression. type: ${this.nodeName}") - } override def dataType: DataType = NullType override def children: Seq[Expression] = Nil } @@ -89,12 +85,12 @@ private[sql] case class AggregateExpression2( override def nullable: Boolean = aggregateFunction.nullable override def references: AttributeSet = { - val childReferemces = mode match { + val childReferences = mode match { case Partial | Complete => aggregateFunction.references.toSeq case PartialMerge | Final => aggregateFunction.bufferAttributes } - AttributeSet(childReferemces) + AttributeSet(childReferences) } override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" @@ -103,15 +99,34 @@ private[sql] case class AggregateExpression2( abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { - self: Product => - /** An aggregate function is not foldable. */ - override def foldable: Boolean = false + final override def foldable: Boolean = false + + /** + * The offset of this function's start buffer value in the + * underlying shared mutable aggregation buffer. + * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share + * the same aggregation buffer. In this shared buffer, the position of the first + * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` + * will be 2. + */ + var mutableBufferOffset: Int = 0 /** - * The offset of this function's buffer in the underlying buffer shared with other functions. + * The offset of this function's start buffer value in the + * underlying shared input aggregation buffer. An input aggregation buffer is used + * when we merge two aggregation buffers and it is basically the immutable one + * (we merge an input aggregation buffer and a mutable aggregation buffer and + * then store the new buffer values to the mutable aggregation buffer). + * Usually, an input aggregation buffer also contain extra elements like grouping + * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often + * different. + * For example, we have a grouping expression `key``, and two aggregate functions + * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first + * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` + * will be 3 (position 0 is used for the value of key`). */ - var bufferOffset: Int = 0 + var inputBufferOffset: Int = 0 /** The schema of the aggregation buffer. 
*/ def bufferSchema: StructType @@ -151,8 +166,7 @@ abstract class AggregateFunction2 /** * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. */ -abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { - self: Product => +abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable with Unevaluable { val initialValues: Seq[Expression] val updateExpressions: Seq[Expression] @@ -183,24 +197,20 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { override def initialize(buffer: MutableRow): Unit = { var i = 0 while (i < bufferAttributes.size) { - buffer(i + bufferOffset) = initialValues(i).eval() + buffer(i + mutableBufferOffset) = initialValues(i).eval() i += 1 } } - override def update(buffer: MutableRow, input: InternalRow): Unit = { + override final def update(buffer: MutableRow, input: InternalRow): Unit = { throw new UnsupportedOperationException( "AlgebraicAggregate's update should not be called directly") } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override final def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { throw new UnsupportedOperationException( "AlgebraicAggregate's merge should not be called directly") } - override def eval(buffer: InternalRow): Any = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's eval should not be called directly") - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala new file mode 100644 index 0000000000000..4a43318a95490 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.types.{StructType, MapType, ArrayType} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object Utils { + // Right now, we do not support complex types in the grouping key schema. 
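As a concrete illustration of the mutableBufferOffset / inputBufferOffset convention documented on AggregateFunction2 above, here is a small standalone sketch. It is not Spark code: FakeAggFunc, the two-slot avg buffer (based on the (currentSum, currentCount) pair used by Average earlier in this patch), and the single grouping key are assumptions made only for the example.

// Illustrative only: assigns offsets the way the AggregateFunction2 scaladoc above describes.
// `bufferWidth` stands in for the number of attributes in each function's bufferSchema.
case class FakeAggFunc(name: String, bufferWidth: Int) {
  var mutableBufferOffset: Int = 0
  var inputBufferOffset: Int = 0
}

object BufferOffsetDemo {
  def assignOffsets(funcs: Seq[FakeAggFunc], numGroupingKeys: Int): Unit = {
    var mutableOffset = 0
    // The input (immutable) buffer starts after the grouping keys, e.g. `key` at position 0.
    var inputOffset = numGroupingKeys
    funcs.foreach { f =>
      f.mutableBufferOffset = mutableOffset
      f.inputBufferOffset = inputOffset
      mutableOffset += f.bufferWidth
      inputOffset += f.bufferWidth
    }
  }

  def main(args: Array[String]): Unit = {
    // avg(x) and avg(y) each keep (sum, count), i.e. two buffer slots apiece.
    val avgX = FakeAggFunc("avg(x)", 2)
    val avgY = FakeAggFunc("avg(y)", 2)
    assignOffsets(Seq(avgX, avgY), numGroupingKeys = 1)
    // Matches the scaladoc example: mutable offsets 0 and 2, input offsets 1 and 3.
    println(s"${avgX.name}: mutable=${avgX.mutableBufferOffset}, input=${avgX.inputBufferOffset}")
    println(s"${avgY.name}: mutable=${avgY.mutableBufferOffset}, input=${avgY.inputBufferOffset}")
  }
}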
+ private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { + val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { + case array: ArrayType => true + case map: MapType => true + case struct: StructType => true + case _ => false + } + + !hasComplexTypes + } + + private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = p.transformExpressionsDown { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. + case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } + // Check if there is any expressions.AggregateExpression1 left. + // If so, we cannot convert this plan. + val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => + // For every expressions, check if it contains AggregateExpression1. + expr.find { + case agg: expressions.AggregateExpression1 => true + case other => false + }.isDefined + } + + // Check if there are multiple distinct columns. + val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val hasMultipleDistinctColumnSets = + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + true + } else { + false + } + + if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None + + case other => None + } + + def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { + // If the plan cannot be converted, we will do a final round check to see if the original + // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, + // we need to throw an exception. 
+ val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg.aggregateFunction + } + }.distinct + if (aggregateFunction2s.nonEmpty) { + // For functions implemented based on the new interface, prepare a list of function names. + val invalidFunctions = { + if (aggregateFunction2s.length > 1) { + s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + + s"and ${aggregateFunction2s.head.nodeName} are" + } else { + s"${aggregateFunction2s.head.nodeName} is" + } + } + val errorMessage = + s"${invalidFunctions} implemented based on the new Aggregate Function " + + s"interface and it cannot be used with functions implemented based on " + + s"the old Aggregate Function interface." + throw new AnalysisException(errorMessage) + } + } + + def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate => + val converted = doConvert(p) + if (converted.isDefined) { + converted + } else { + checkInvalidAggregateFunction2(p) + None + } + case other => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index e07c920a41d0a..5d4b349b1597a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -71,8 +71,7 @@ trait PartialAggregate1 extends AggregateExpression1 { * A specific implementation of an aggregate function. Used to wrap a generic * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. */ -abstract class AggregateFunction1 - extends LeafExpression with AggregateExpression1 with Serializable { +abstract class AggregateFunction1 extends LeafExpression with Serializable { /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression1 @@ -82,9 +81,9 @@ abstract class AggregateFunction1 def update(input: InternalRow): Unit - // Do we really need this? 
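For reference, the "multiple distinct column sets" restriction enforced by doConvert above can be shown with a minimal standalone sketch; DemoAgg and the column names are made up for the example and are not part of the patch.

// Illustrative only: mirrors functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
// on a simplified representation of aggregate calls.
case class DemoAgg(function: String, children: Seq[String], isDistinct: Boolean)

object DistinctColumnSetsDemo {
  def hasMultipleDistinctColumnSets(aggs: Seq[DemoAgg]): Boolean =
    aggs.filter(_.isDistinct).map(_.children).distinct.length > 1

  def main(args: Array[String]): Unit = {
    val ok = Seq(
      DemoAgg("count", Seq("x"), isDistinct = true),
      DemoAgg("sum",   Seq("x"), isDistinct = true),   // same distinct column set: convertible
      DemoAgg("avg",   Seq("y"), isDistinct = false))
    val notOk = ok :+ DemoAgg("count", Seq("y"), isDistinct = true) // a second distinct column set
    println(hasMultipleDistinctColumnSets(ok))    // false
    println(hasMultipleDistinctColumnSets(notOk)) // true, so the plan is not converted
  }
}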
- override def newInstance(): AggregateFunction1 = { - makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + throw new UnsupportedOperationException( + "AggregateFunction1 should not be used for generated aggregates") } } @@ -391,22 +390,21 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited + // Add 4 digits after decimal point, like Hive + DecimalType.bounded(precision + 4, scale + 4) case _ => DoubleType } override def asPartial: SplitEvaluation = { child.dataType match { - case DecimalType.Fixed(_, _) | DecimalType.Unlimited => - // Turn the child to unlimited decimals for calculation, before going back to fixed - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + case DecimalType.Fixed(precision, scale) => + val partialSum = Alias(Sum(child), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() - val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) - val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) + // partialSum already increase the precision by 10 + val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) SplitEvaluation( Cast(Divide(castedSum, castedCount), dataType), partialCount :: partialSum :: Nil) @@ -436,8 +434,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1) private val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) case _ => expr.dataType } @@ -455,10 +453,9 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1) null } else { expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(Divide( - Cast(sum, DecimalType.Unlimited), - Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null) + case DecimalType.Fixed(precision, scale) => + val dt = DecimalType.bounded(precision + 14, scale + 4) + Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null) case _ => Divide( Cast(sum, dataType), @@ -482,9 +479,8 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited + // Add 10 digits left of decimal point, like Hive + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } @@ -492,15 +488,15 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), + Cast(Sum(partialSum.toAttribute), dataType), partialSum :: Nil) case _ => val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - 
CombineSum(partialSum.toAttribute), + Sum(partialSum.toAttribute), partialSum :: Nil) } } @@ -516,8 +512,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg private val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision + 10, scale) case _ => expr.dataType } @@ -526,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg private val sum = MutableLiteral(null, calcType) - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) override def update(input: InternalRow): Unit = { sum.update(addFunction, input) @@ -542,76 +537,14 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg } } -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class CombineSumFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. 
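A small worked example of the decimal widening used by the Sum and Average changes above; DecimalWideningDemo and its bounded helper are illustrative stand-ins for DecimalType.bounded, assuming the usual 38-digit cap on bounded precision and scale.

// Illustrative only: applies the precision/scale adjustments from this patch to DECIMAL(10, 2).
object DecimalWideningDemo {
  val MaxPrecision = 38

  // Stand-in for DecimalType.bounded(precision, scale), assuming a 38-digit cap.
  def bounded(precision: Int, scale: Int): (Int, Int) =
    (math.min(precision, MaxPrecision), math.min(scale, MaxPrecision))

  def main(args: Array[String]): Unit = {
    val (p, s) = (10, 2) // input column: DECIMAL(10, 2)

    val avgResult = bounded(p + 4, s + 4)   // (14, 6): avg result, like Hive
    val sumBuffer = bounded(p + 10, s)      // (20, 2): widened sum / partial sum type
    val avgDivide = bounded(p + 14, s + 4)  // (24, 6): type the sum / count division is carried at

    println(s"avg result type:   DECIMAL$avgResult")
    println(s"sum buffer type:   DECIMAL$sumBuffer")
    println(s"divide carried at: DECIMAL$avgDivide")
  }
}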
- - private val calcType = - expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - - override def update(input: InternalRow): Unit = { - val result = expr.eval(input) - // partial sum result can be null only when no input rows present - if(result != null) { - sum.update(addFunction, input) - } - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) override def nullable: Boolean = true override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited + // Add 10 digits left of decimal point, like Hive + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } @@ -680,7 +613,7 @@ case class CombineSetsAndSumFunction( val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] val inputIterator = inputSetEval.iterator while (inputIterator.hasNext) { - seen.add(inputIterator.next) + seen.add(inputIterator.next()) } } @@ -690,7 +623,7 @@ case class CombineSetsAndSumFunction( null } else { Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( + casted.iterator.map(f => f.genericGet(0)).reduceLeft( base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), base.dataType).eval(null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 05b5ad88fee8f..6f8f4dd230f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.Interval +import org.apache.spark.unsafe.types.CalendarInterval case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -37,12 +37,12 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") - case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input.asInstanceOf[Interval].negate() + if (dataType.isInstanceOf[CalendarIntervalType]) { + input.asInstanceOf[CalendarInterval].negate() } else { numeric.negate(input) } @@ -65,8 +65,10 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with 
Expects /** * A function that get the absolute value of the numeric value. */ -case class Abs(child: Expression) - extends UnaryExpression with ExpectsInputTypes with CodegenFallback { +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", + extended = "> SELECT _FUNC_('-1');\n1") +case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) @@ -88,6 +90,8 @@ abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType + override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -114,14 +118,11 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval]) + if (dataType.isInstanceOf[CalendarIntervalType]) { + input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval]) } else { numeric.plus(input1, input2) } @@ -133,7 +134,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case IntervalType => + case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") @@ -146,14 +147,11 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval]) + if (dataType.isInstanceOf[CalendarIntervalType]) { + input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval]) } else { numeric.minus(input1, input2) } @@ -165,7 +163,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case IntervalType => + case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") @@ -179,9 +177,6 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "*" override def decimalMethod: String = "$times" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) @@ -195,9 +190,6 @@ case 
class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def decimalMethod: String = "$div" override def nullable: Boolean = true - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot @@ -260,9 +252,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def decimalMethod: String = "remainder" override def nullable: Boolean = true - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala new file mode 100644 index 0000000000000..c98182c96b165 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +/** + * An utility class that indents a block of code based on the curly braces and parentheses. + * This is used to prettify generated code when in debug mode (or exceptions). + * + * Written by Matei Zaharia. 
+ */ +object CodeFormatter { + def format(code: String): String = new CodeFormatter().addLines(code).result() +} + +private class CodeFormatter { + private val code = new StringBuilder + private var indentLevel = 0 + private val indentSize = 2 + private var indentString = "" + + private def addLine(line: String): Unit = { + val indentChange = + line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) + val newIndentLevel = math.max(0, indentLevel + indentChange) + // Lines starting with '}' should be de-indented even if they contain '{' after; + // in addition, lines ending with ':' are typically labels + val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) { + " " * (indentSize * (indentLevel - 1)) + } else { + indentString + } + code.append(thisLineIndent) + code.append(line) + code.append("\n") + indentLevel = newIndentLevel + indentString = " " * (indentSize * newIndentLevel) + } + + private def addLines(code: String): CodeFormatter = { + code.split('\n').foreach(s => addLine(s.trim())) + this + } + + private def result(): String = code.result() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 319dcd1c04316..60e2863f7bbb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -79,7 +79,6 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initCode)) } - final val intervalType: String = classOf[Interval].getName final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -101,14 +100,19 @@ class CodeGenContext { } /** - * Returns the code to access a column in Row for a given DataType. + * Returns the code to access a value in `SpecializedGetters` for a given DataType. */ - def getColumn(row: String, dataType: DataType, ordinal: Int): String = { + def getValue(getter: String, dataType: DataType, ordinal: String): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.get${primitiveTypeName(jt)}($ordinal)" - } else { - s"($jt)$row.apply($ordinal)" + dataType match { + case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$getter.getUTF8String($ordinal)" + case BinaryType => s"$getter.getBinary($ordinal)" + case CalendarIntervalType => s"$getter.getInterval($ordinal)" + case t: StructType => s"$getter.getStruct($ordinal, ${t.size})" + case a: ArrayType => s"$getter.getArray($ordinal)" + case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter. 
} } @@ -117,10 +121,10 @@ class CodeGenContext { */ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - } else { - s"$row.update($ordinal, $value)" + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case _ => s"$row.update($ordinal, $value)" } } @@ -148,10 +152,10 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" - case IntervalType => intervalType + case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" - case _: ArrayType => s"scala.collection.Seq" - case _: MapType => s"scala.collection.Map" + case _: ArrayType => "ArrayData" + case _: MapType => "scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" @@ -212,7 +216,9 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" - case other => s"$c1.compare($c2)" + case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case _ => throw new IllegalArgumentException( + "cannot generate compare code for un-comparable type") } /** @@ -248,13 +254,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodeGenContext) = { + protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" }.mkString("\n ") } - protected def initMutableStates(ctx: CodeGenContext) = { + protected def initMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map(_._3).mkString("\n ") } @@ -290,14 +296,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, - classOf[Decimal].getName + classOf[Decimal].getName, + classOf[CalendarInterval].getName, + classOf[ArrayData].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { evaluator.cook(code) } catch { case e: Exception => - val msg = s"failed to compile:\n $code" + val msg = "failed to compile:\n " + CodeFormatter.format(code) logError(msg, e) throw new Exception(msg, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6b187f05604fd..3492d2c6189ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} 
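// Illustrative sketch, not part of the patch: what the ctx.getValue specialisation
// introduced above is expected to emit, assuming a fresh CodeGenContext named ctx
// and an input row variable named "i".
//   ctx.getValue("i", IntegerType, "0")          // -> "i.getInt(0)"
//   ctx.getValue("i", StringType, "1")           // -> "i.getUTF8String(1)"
//   ctx.getValue("i", DecimalType(10, 2), "2")   // -> "i.getDecimal(2, 10, 2)"
//   ctx.getValue("i", ArrayType(LongType), "3")  // -> "i.getArray(3)"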
/** * A trait that can be used to provide a fallback mode for expression code generation. @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression trait CodegenFallback extends Expression { protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + } + ctx.references += this val objectTerm = ctx.freshName("obj") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d838268f46956..825031a4faf5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import scala.collection.mutable.ArrayBuffer - // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -45,10 +45,11 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val evaluationCode = e.gen(ctx) evaluationCode.code + s""" - if(${evaluationCode.isNull}) + if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); - else + } else { ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + } """ } // collect projections into blocks as function has 64kb codesize limit in JVM @@ -119,7 +120,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) () => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 2e6f9e204d813..dbd4616d281c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -107,7 +107,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } }""" - logDebug(s"Generated Ordering: $code") + logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 1dda5992c3654..dfd593fb7c064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -60,7 +60,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool } }""" - logDebug(s"Generated predicate '$predicate':\n$code") + logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") val p = 
compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 405d6b0e3bc76..35920147105ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -178,12 +178,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $initColumns } - public int length() { return ${expressions.length};} + public int numFields() { return ${expressions.length};} protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { + public Object get(int i, ${classOf[DataType].getName} dataType) { if (isNullAt(i)) return null; switch (i) { $getCases @@ -230,7 +230,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } """ - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") + logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + + CodeFormatter.format(code)) compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3a8e8302b24fd..1d223986d9441 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{NullType, BinaryType, StringType} - +import org.apache.spark.sql.types._ /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -32,89 +31,268 @@ import org.apache.spark.sql.types.{NullType, BinaryType, StringType} */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer.execute) + private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName + private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName + private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName + private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName + private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName + private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName - protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + /** Returns true iff we support this data type. 
*/ + def canSupport(dataType: DataType): Boolean = dataType match { + case t: AtomicType if !t.isInstanceOf[DecimalType] => true + case _: CalendarIntervalType => true + case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) + case NullType => true + case t: DecimalType => true + case _ => false + } - protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() - val exprs = expressions.map(_.gen(ctx)) + def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + case StringType => + s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + case BinaryType => + s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + case CalendarIntervalType => + s" + (${ev.isNull} ? 0 : 16)" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case _ => "" + } + + def genFieldWriter( + ctx: CodeGenContext, + fieldType: DataType, + ev: GeneratedExpressionCode, + primitive: String, + index: Int, + cursor: String): String = fieldType match { + case _ if ctx.isPrimitiveType(fieldType) => + s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case StringType => + s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case BinaryType => + s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case CalendarIntervalType => + s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") + } + + /** + * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
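// Illustrative sketch, not part of the patch: how the canSupport check above
// classifies a few schemas (arrays and maps still fall back to the generic path).
//   canSupport(StringType)                                    // true
//   canSupport(DecimalType(20, 2))                            // true
//   canSupport(StructType(Seq(StructField("a", LongType))))   // true
//   canSupport(ArrayType(IntegerType))                        // false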
+ * @param ctx context for code generation + * @param ev specifies the name of the variable for the output [[UnsafeRow]] object + * @param expressions input expressions + * @return generated code to put the expression output into an [[UnsafeRow]] + */ + def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression]) + : String = { + + val ret = ev.primitive + ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + val numBytes = ctx.freshName("numBytes") + + val exprs = expressions.map { e => e.dataType match { + case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st) + case _ => e.gen(ctx) + }} val allExprs = exprs.map(_.code).mkString("\n") + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter" - val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter" - val additionalSize = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case StringType => - s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))" - case BinaryType => - s" + (${exprs(i).isNull} ? 0 : $binaryWriter.getSize(${exprs(i).primitive}))" - case _ => "" - } + val additionalSize = expressions.zipWithIndex.map { + case (e, i) => genAdditionalSize(e.dataType, exprs(i)) }.mkString("") val writers = expressions.zipWithIndex.map { case (e, i) => - val update = e.dataType match { - case dt if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}" - case StringType => - s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)" - case BinaryType => - s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") - } + val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) s"""if (${exprs(i).isNull}) { - target.setNullAt($i); + $ret.setNullAt($i); } else { $update; }""" }.mkString("\n ") - val code = s""" - private $exprType[] expressions; + s""" + $allExprs + int $numBytes = $fixedSize $additionalSize; + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; + } + + $ret.pointTo( + $buffer, + org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + ${expressions.size}, + $numBytes); + int $cursor = $fixedSize; - public Object generate($exprType[] expr) { - this.expressions = expr; - return new SpecificProjection(); + $writers + boolean ${ev.isNull} = false; + """ + } + + /** + * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. + * + * This function also handles nested structs by recursively generating the code to do conversion. + * + * @param ctx code generation context + * @param input the input struct, identified by a [[GeneratedExpressionCode]] + * @param schema schema of the struct field + */ + // TODO: refactor createCode and this function to reduce code duplication. 
+ private def createCodeForStruct( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + schema: StructType): GeneratedExpressionCode = { + + val isNull = input.isNull + val primitive = ctx.freshName("structConvert") + ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + + val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map { + case (dt, i) => dt match { + case st: StructType => + val nestedStructEv = GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" + ) + createCodeForStruct(ctx, nestedStructEv, st) + case _ => + GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" + ) + } } + val allExprs = exprs.map(_.code).mkString("\n") - class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => + genAdditionalSize(dt, ev) + }.mkString("") - private UnsafeRow target = new UnsafeRow(); - private byte[] buffer = new byte[64]; - ${declareMutableStates(ctx)} + val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => + val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor) + s""" + if (${exprs(i).isNull}) { + $primitive.setNullAt($i); + } else { + $update; + } + """ + }.mkString("\n ") - public SpecificProjection() { - ${initMutableStates(ctx)} - } + // Note that we add a shortcut here for performance: if the input is already an UnsafeRow, + // just copy the bytes directly into our buffer space without running any conversion. + // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from + // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow. 
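// Illustrative note, not part of the patch: the fixed-size region computed above is
// 8 bytes per field plus the null-tracking bitset (rounded up to whole 8-byte words),
// so a 3-field row reserves 8 * 3 + 8 = 32 bytes up front; variable-length values
// (strings, binary, intervals, nested structs, large decimals) only contribute to
// numBytes through the genAdditionalSize terms.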
+ val tmp = ctx.freshName("tmp") + val numBytes = ctx.freshName("numBytes") + val code = s""" + |${input.code} + |if (!${input.isNull}) { + | Object $tmp = (Object) ${input.primitive}; + | if ($tmp instanceof UnsafeRow) { + | $primitive = (UnsafeRow) $tmp; + | } else { + | $allExprs + | + | int $numBytes = $fixedSize $additionalSize; + | if ($numBytes > $buffer.length) { + | $buffer = new byte[$numBytes]; + | } + | + | $primitive.pointTo( + | $buffer, + | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + | ${exprs.size}, + | $numBytes); + | int $cursor = $fixedSize; + | + | $writers + | } + |} + """.stripMargin - // Scala.Function1 need this - public Object apply(Object row) { - return apply((InternalRow) row); + GeneratedExpressionCode(code, isNull, primitive) + } + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer.execute) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): UnsafeProjection = { + val ctx = newCodeGenContext() + + val isNull = ctx.freshName("retIsNull") + val primitive = ctx.freshName("retValue") + val eval = GeneratedExpressionCode("", isNull, primitive) + eval.code = createCode(ctx, eval, expressions) + + val code = s""" + public Object generate($exprType[] exprs) { + return new SpecificProjection(exprs); } - public UnsafeRow apply(InternalRow i) { - ${allExprs} + class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + + private $exprType[] expressions; - // additionalSize had '+' in the beginning - int numBytes = $fixedSize $additionalSize; - if (numBytes > buffer.length) { - buffer = new byte[numBytes]; + ${declareMutableStates(ctx)} + + public SpecificProjection($exprType[] expressions) { + this.expressions = expressions; + ${initMutableStates(ctx)} + } + + // Scala.Function1 need this + public Object apply(Object row) { + return apply((InternalRow) row); + } + + public UnsafeRow apply(InternalRow i) { + ${eval.code} + return ${eval.primitive}; } - target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, - ${expressions.size}, numBytes, null); - int cursor = $fixedSize; - $writers - return target; } - } - """ + """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2d92dcf23a86e..1a00dbc254de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) override def nullSafeEval(value: Any): Int = child.dataType match { - case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size - case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size + case _: ArrayType => value.asInstanceOf[ArrayData].numElements() + case _: MapType => value.asInstanceOf[Map[Any, Any]].size } override def 
genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();") + val sizeCall = child.dataType match { + case _: ArrayType => "numElements()" + case _: MapType => "size()" + } + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 20b1eaab8e303..a145dfb4bbf08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable - +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -44,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false override def eval(input: InternalRow): Any = { - children.map(_.eval(input)) + new GenericArrayData(children.map(_.eval(input)).toArray) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val arrayClass = classOf[GenericArrayData].getName s""" - boolean ${ev.isNull} = false; - $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + final boolean ${ev.isNull} = false; + final Object[] values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" } override def prettyName: String = "array" @@ -104,18 +104,19 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ }.mkString("\n") } override def prettyName: String = "struct" } + /** * Creates a struct with the given field names and values * @@ -126,11 +127,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, 
valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -143,14 +145,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { - val invalidNames = - nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Odd position only allow foldable and not-null StringType expressions, got :" + + s"Only foldable StringType expressions are allowed to appear at odd position , got :" + s" ${invalidNames.mkString(",")}") - } else { + } else if (names.forall(_ != null)){ TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Field name should not be null") } } } @@ -168,14 +171,83 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { valExprs.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ }.mkString("\n") } override def prettyName: String = "named_struct" } + +/** + * Returns a Row containing the evaluation of all children expressions. This is a variant that + * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + */ +case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, children) + } + + override def prettyName: String = "struct_unsafe" +} + + +/** + * Creates a struct with the given field names and values. This is a variant that returns + * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + * + * @param children Seq(name1, val1, name2, val2, ...) 
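// Illustrative sketch, not part of the patch: behaviour implied by the tightened
// CreateNamedStruct check above (field names must be foldable, non-null strings).
//   named_struct('a', 1, 'b', 2)            -- ok: struct<a:int,b:int>
//   named_struct(cast(null as string), 1)   -- fails: "Field name should not be null"
//   named_struct('a', 1, 'b')               -- fails: expects an even number of arguments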
+ */ +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, valExprs) + } + + override def prettyName: String = "named_struct_unsafe" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5504781edca1b..99393c9c76ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -57,7 +57,8 @@ object ExtractValue { case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), + ordinal, fields.length, containsNull) case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) @@ -110,7 +111,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow](ordinal) + input.asInstanceOf[InternalRow].get(ordinal, field.dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { @@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)}; + ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)}; } """ }) @@ -134,6 +135,7 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, + numFields: Int, containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) @@ -141,26 +143,45 @@ case class GetArrayStructFields( override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { - input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row(ordinal) + val array = input.asInstanceOf[ArrayData] + val length = array.numElements() + val result = new Array[Any](length) + var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + result(i) = null + } else { + val row = array.getStruct(i, numFields) + if (row.isNullAt(ordinal)) { + result(i) = null + } else { + result(i) = row.get(ordinal, field.dataType) + } + } + i += 1 } + new GenericArrayData(result) } override def 
genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = "scala.collection.mutable.ArraySeq" - // TODO: consider using Array[_] for ArrayType child to avoid - // boxing of primitives + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { s""" - final int n = $eval.size(); - final $arraySeqClass values = new $arraySeqClass(n); + final int n = $eval.numElements(); + final Object[] values = new Object[n]; for (int j = 0; j < n; j++) { - InternalRow row = (InternalRow) $eval.apply(j); - if (row != null && !row.isNullAt($ordinal)) { - values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + if ($eval.isNullAt(j)) { + values[j] = null; + } else { + final InternalRow row = $eval.getStruct(j, $numFields); + if (row.isNullAt($ordinal)) { + values[j] = null; + } else { + values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + } } } - ${ev.primitive} = (${ctx.javaType(dataType)}) values; + ${ev.primitive} = new $arrayClass(values); """ }) } @@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx protected override def nullSafeEval(value: Any, ordinal: Any): Any = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives - val baseValue = value.asInstanceOf[Seq[_]] + val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.size || index < 0) { + if (index >= baseValue.numElements() || index < 0) { null } else { - baseValue(index) + baseValue.get(index) } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - final int index = (int)$eval2; - if (index >= $eval1.size() || index < 0) { + final int index = (int) $eval2; + if (index >= $eval1.numElements() || index < 0) { ${ev.isNull} = true; } else { - ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index); + ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 15b33da884dcb..961b1d8616801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -315,7 +315,6 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW * It takes at least 2 parameters, and returns null iff all parameters are null. 
*/ case class Least(children: Seq[Expression]) extends Expression { - require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -323,7 +322,9 @@ case class Least(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST (${children.map(_.dataType)}).") @@ -369,7 +370,6 @@ case class Least(children: Seq[Expression]) extends Expression { * It takes at least 2 parameters, and returns null iff all parameters are null. */ case class Greatest(children: Seq[Expression]) extends Expression { - require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -377,7 +377,9 @@ case class Greatest(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST (${children.map(_.dataType)}).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 9e55f0546e123..9795673ee0664 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.Date import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} @@ -26,7 +25,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import scala.util.Try /** * Returns the current date at the start of query evaluation. @@ -62,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { } } +/** + * Adds a number of days to startdate. 
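// Illustrative note, not part of the patch: DateType values are carried internally as
// an Int counting days since 1970-01-01, which is why DateAdd below (and DateSub after
// it) reduce to plain integer arithmetic on that day count.
//   date_add('2015-07-27', 5)   -- 2015-08-01
//   date_sub('2015-07-27', 5)   -- 2015-07-22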
+ */ +case class DateAdd(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] + d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd + $d;""" + }) + } +} + +/** + * Subtracts a number of days to startdate. + */ +case class DateSub(startDate: Expression, days: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = startDate + override def right: Expression = days + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, d: Any): Any = { + start.asInstanceOf[Int] - d.asInstanceOf[Int] + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, d) => { + s"""${ev.primitive} = $sd - $d;""" + }) + } +} + case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -74,9 +122,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getHours($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } } @@ -92,9 +138,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getMinutes($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } } @@ -110,9 +154,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getSeconds($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } } @@ -128,9 +170,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getDayInYear($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") } } @@ -147,9 +187,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => - s"""$dtu.getYear($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } } @@ -165,9 +203,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - 
defineCodeGen(ctx, ev, (c) => - s"""$dtu.getQuarter($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } } @@ -183,9 +219,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getMonth($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } } @@ -201,9 +235,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getDayOfMonth($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } } @@ -226,7 +258,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (time) => { + nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") ctx.addMutableState(cal, c, @@ -250,18 +282,417 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) - override def prettyName: String = "date_format" - override protected def nullSafeEval(timestamp: Any, format: Any): Any = { val sdf = new SimpleDateFormat(format.toString) - UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000))) + UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) - .format(new java.sql.Date($timestamp / 1000)))""" + .format(new java.util.Date($timestamp / 1000)))""" + }) + } + + override def prettyName: String = "date_format" +} + +/** + * Converts time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), returns null if fail. + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. + * If the second parameter is missing, use "yyyy-MM-dd HH:mm:ss". + * If no parameters provided, the first parameter will be current_timestamp. + * If the first parameter is a Date or Timestamp instead of String, we will ignore the + * second parameter. 
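// Illustrative sketch with hypothetical results, not part of the patch: expected
// behaviour of unix_timestamp per the semantics documented above.
//   unix_timestamp()                              -- epoch seconds of current_timestamp
//   unix_timestamp('2015-07-27 10:00:00')         -- epoch seconds for that local time
//   unix_timestamp('2015-07-27', 'yyyy-MM-dd')    -- epoch seconds at local midnight
//   unix_timestamp('not a date', 'yyyy-MM-dd')    -- null (parse failures return null, not 0)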
+ */ +case class UnixTimestamp(timeExp: Expression, format: Expression) + extends BinaryExpression with ExpectsInputTypes { + + override def left: Expression = timeExp + override def right: Expression = format + + def this(time: Expression) = { + this(time, Literal("yyyy-MM-dd HH:mm:ss")) + } + + def this() = { + this(CurrentTimestamp()) + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(StringType, DateType, TimestampType), StringType) + + override def dataType: DataType = LongType + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val t = left.eval(input) + if (t == null) { + null + } else { + left.dataType match { + case DateType => + DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L + case TimestampType => + t.asInstanceOf[Long] / 1000000L + case StringType if right.foldable => + if (constFormat != null) { + Try(new SimpleDateFormat(constFormat.toString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } else { + null + } + case StringType => + val f = format.eval(input) + if (f == null) { + null + } else { + val formatString = f.asInstanceOf[UTF8String].toString + Try(new SimpleDateFormat(formatString).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + left.dataType match { + case StringType if right.foldable => + val sdf = classOf[SimpleDateFormat].getName + val fString = if (constFormat == null) null else constFormat.toString + val formatter = ctx.freshName("formatter") + if (fString == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + $sdf $formatter = new $sdf("$fString"); + ${ev.primitive} = + $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + } + """ + } + case StringType => + val sdf = classOf[SimpleDateFormat].getName + nullSafeCodeGen(ctx, ev, (string, format) => { + s""" + try { + ${ev.primitive} = + (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + """ + }) + case TimestampType => + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${eval1.primitive} / 1000000L; + } + """ + case DateType => + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val eval1 = left.gen(ctx) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L; + } + """ + } + } +} + +/** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. If the format is missing, using format like "1970-01-01 00:00:00". + * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. 
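// Illustrative sketch, not part of the patch: from_unixtime is the inverse direction,
// formatting epoch seconds with the given (or default) pattern; a bad pattern likewise
// yields null rather than 0.
//   from_unixtime(0)                 -- '1970-01-01 00:00:00' when the JVM time zone is UTC
//   from_unixtime(0, 'yyyy-MM-dd')   -- '1970-01-01' under the same assumption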
+ */ +case class FromUnixTime(sec: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = sec + override def right: Expression = format + + def this(unix: Expression) = { + this(unix, Literal("yyyy-MM-dd HH:mm:ss")) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + + override def eval(input: InternalRow): Any = { + val time = left.eval(input) + if (time == null) { + null + } else { + if (format.foldable) { + if (constFormat == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format( + new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } else { + val f = format.eval(input) + if (f == null) { + null + } else { + Try(UTF8String.fromString(new SimpleDateFormat( + f.asInstanceOf[UTF8String].toString).format(new java.util.Date( + time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + } + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val sdf = classOf[SimpleDateFormat].getName + if (format.foldable) { + if (constFormat == null) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val t = left.gen(ctx) + s""" + ${t.code} + boolean ${ev.isNull} = ${t.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + try { + ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format( + new java.util.Date(${t.primitive} * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + } + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (seconds, f) => { + s""" + try { + ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format( + new java.util.Date($seconds * 1000L))); + } catch (java.lang.Throwable e) { + ${ev.isNull} = true; + }""".stripMargin + }) + } + } + +} + +/** + * Returns the last day of the month which the date belongs to. + */ +case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { + override def child: Expression = startDate + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = DateType + + override def nullSafeEval(date: Any): Any = { + DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") + } + + override def prettyName: String = "last_day" +} + +/** + * Returns the first date which is later than startDate and named as dayOfWeek. + * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first + * Sunday later than 2015-07-27. + * + * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. 
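// Illustrative sketch, not part of the patch (the next_day case is the example from
// the doc above):
//   last_day('2015-07-27')          -- 2015-07-31
//   next_day('2015-07-27', 'Sun')   -- 2015-08-02
//   next_day('2015-07-27', 'xyz')   -- null (unrecognised day-of-week name)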
+ */ +case class NextDay(startDate: Expression, dayOfWeek: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = dayOfWeek + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, dayOfW: Any): Any = { + val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) + if (dow == -1) { + null + } else { + val sd = start.asInstanceOf[Int] + DateTimeUtils.getNextDateForDayOfWeek(sd, dow) + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (sd, dowS) => { + val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") + val dayOfWeekTerm = ctx.freshName("dayOfWeek") + if (dayOfWeek.foldable) { + val input = dayOfWeek.eval().asInstanceOf[UTF8String] + if ((input eq null) || DateTimeUtils.getDayOfWeekFromString(input) == -1) { + s""" + |${ev.isNull} = true; + """.stripMargin + } else { + val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input) + s""" + |${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); + """.stripMargin + } + } else { + s""" + |int $dayOfWeekTerm = $dateTimeUtilClass.getDayOfWeekFromString($dowS); + |if ($dayOfWeekTerm == -1) { + | ${ev.isNull} = true; + |} else { + | ${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm); + |} + """.stripMargin + } + }) + } + + override def prettyName: String = "next_day" +} + +/** + * Adds an interval to timestamp. + */ +case class TimeAdd(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left + $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], itvl.months, itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" + }) + } +} + +/** + * Subtracts an interval from timestamp. 
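// Illustrative sketch, not part of the patch: TimeAdd above and TimeSub below both go
// through DateTimeUtils.timestampAddInterval, the latter negating the interval's months
// and microseconds components first.
//   timestamp '2015-07-27 00:00:00' + interval 1 month    -- 2015-08-27 00:00:00
//   timestamp '2015-07-27 00:00:00' - interval 12 hours   -- 2015-07-26 12:00:00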
+ */ +case class TimeSub(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left - $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" + }) + } +} + +/** + * Returns the date that is num_months after start_date. + */ +case class AddMonths(startDate: Expression, numMonths: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = numMonths + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, months: Any): Any = { + DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, m) => { + s"""$dtu.dateAddMonths($sd, $m)""" + }) + } +} + +/** + * Returns number of months between dates date1 and date2. + */ +case class MonthsBetween(date1: Expression, date2: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = date1 + override def right: Expression = date2 + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + + override def dataType: DataType = DoubleType + + override def nullSafeEval(t1: Any, t2: Any): Any = { + DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (l, r) => { + s"""$dtu.monthsBetween($l, $r)""" }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index b9d4736a65e26..adb33e4c8d4a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ @@ -60,3 +61,15 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un }) } } + +/** + * An expression used to wrap the children when promote the precision of DecimalType to avoid + * promote multiple times. 
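// Illustrative sketch with hypothetical results, not part of the patch, for the
// AddMonths / MonthsBetween expressions added above:
//   add_months('2015-07-27', 1)                  -- 2015-08-27
//   months_between('2015-08-27', '2015-07-27')   -- 1.0 (same day of month, so a whole number)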
+ */ +case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression { + override def dataType: DataType = child.dataType + override def eval(input: InternalRow): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def prettyName: String = "change_decimal_precision" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 2dbcf2830f876..8064235c64ef9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { case ArrayType(_, _) => - val inputArray = child.eval(input).asInstanceOf[Seq[Any]] - if (inputArray == null) Nil else inputArray.map(v => InternalRow(v)) + val inputArray = child.eval(input).asInstanceOf[ArrayData] + if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v)) case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] if (inputMap == null) Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index f25ac32679587..34bad23802ba4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -36,13 +36,13 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) - case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) - case d: Decimal => Literal(d, DecimalType.Unlimited) + case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale)) + case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale())) + case d: Decimal => Literal(d, DecimalType(d.precision, d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) - case i: Interval => Literal(i, IntervalType) + case i: CalendarInterval => Literal(i, CalendarIntervalType) case null => Literal(null, NullType) case _ => throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) @@ -118,7 +118,7 @@ case class Literal protected (value: Any, dataType: DataType) super.genCode(ctx, ev) } else { ev.isNull = "false" - ev.primitive = s"${value}" + ev.primitive = s"${value}D" "" } case ByteType | ShortType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 68cca0ad3d067..e6d807f6d897b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -646,19 
+646,19 @@ case class Logarithm(left: Expression, right: Expression) /** * Round the `child`'s result to `scale` decimal place when `scale` >= 0 * or round at integral part when `scale` < 0. - * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30. * - * Child of IntegralType would eval to itself when `scale` >= 0. - * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * Child of IntegralType would round to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always round to itself. * - * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], - * which leads to scale update in DecimalType's [[PrecisionInfo]] + * Round's dataType would always equal to `child`'s dataType except for DecimalType, + * which would lead scale decrease from the origin DecimalType. * * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ case class Round(child: Expression, scale: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { import BigDecimal.RoundingMode.HALF_UP @@ -838,6 +838,4 @@ case class Round(child: Expression, scale: Expression) """ } } - - override def prettyName: String = "round" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3f1bd2a925fe7..ab7d3afce8f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -30,6 +30,10 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { + expression.foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index aef24a5486466..62d3d204ca872 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** * Record ID within each partition. By being transient, the Random Number Generator is - * reset every time we serialize and deserialize it. + * reset every time we serialize and deserialize and initialize it. */ - @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + @transient protected var rng: XORShiftRandom = _ + + override protected def initInternal(): Unit = { + rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + } override def nullable: Boolean = false @@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
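The documented Round semantics above mirror plain BigDecimal HALF_UP rounding, the same mode the expression imports; a small REPL-style sketch of the two examples from the doc:

    import scala.math.BigDecimal.RoundingMode.HALF_UP

    BigDecimal(31.415).setScale(2, HALF_UP).toDouble    // 31.42
    BigDecimal(31.415).setScale(-1, HALF_UP).toDouble   // 30.0 (a negative scale rounds the integral part)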
*/ case class Rand(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextDouble() + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() def this() = this(Utils.random.nextLong()) @@ -62,7 +66,7 @@ case class Rand(seed: Long) extends RDG { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") + s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); @@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG { /** Generate a random column with i.i.d. gaussian random distribution. */ case class Randn(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextGaussian() + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() def this() = this(Utils.random.nextLong()) @@ -85,7 +89,7 @@ case class Randn(seed: Long) extends RDG { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") + s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index d78be5a5958f9..df6ea586c87ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, StructType, AtomicType} +import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String /** @@ -39,14 +39,16 @@ abstract class MutableRow extends InternalRow { def setShort(i: Int, value: Short): Unit = { update(i, value) } def setByte(i: Int, value: Byte): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } def setString(i: Int, value: String): Unit = { update(i, UTF8String.fromString(value)) } override def copy(): InternalRow = { - val arr = new Array[Any](length) + val n = numFields + val arr = new Array[Any](n) var i = 0 - while (i < length) { + while (i < n) { arr(i) = get(i) i += 1 } @@ -54,36 +56,23 @@ abstract class MutableRow extends InternalRow { } } -/** - * A row implementation that uses an array of objects as the underlying storage. - */ -trait ArrayBackedRow { - self: Row => - - protected val values: Array[Any] - - override def toSeq: Seq[Any] = values.toSeq - - def length: Int = values.length - - override def get(i: Int): Any = values(i) - - def setNullAt(i: Int): Unit = { values(i) = null} - - def update(i: Int, value: Any): Unit = { values(i) = value } -} - /** * A row implementation that uses an array of objects as the underlying storage. 
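Since the generator is now built in initInternal() rather than in a lazy val, interpreted callers have to initialize the expression first, which is what the InterpretedPredicate change above does. A rough sketch, assuming Nondeterministic exposes setInitialValues() as used there; the seed is arbitrary and, outside a task, the partition id defaults to 0:

    import org.apache.spark.sql.catalyst.expressions.Rand

    val r = Rand(42L)
    r.setInitialValues()                       // creates the XORShiftRandom
    val x = r.eval(null).asInstanceOf[Double]  // uniform in [0, 1)
    assert(x >= 0.0 && x < 1.0)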
Note that, while * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. */ -class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow { +class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def length: Int = values.length + + override def get(i: Int): Any = values(i) + + override def toSeq: Seq[Any] = values.toSeq + override def copy(): Row = this } @@ -101,34 +90,57 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(protected[sql] val values: Array[Any]) - extends InternalRow with ArrayBackedRow { +class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int, dataType: DataType): Any = values(i) + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + values(ordinal).asInstanceOf[InternalRow] + } + override def copy(): InternalRow = this } /** * This is used for serialization of Python DataFrame */ -class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType) +class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) extends GenericInternalRow(values) { /** No-arg constructor for serialization. */ protected def this() = this(null, null) - override def fieldIndex(name: String): Int = schema.fieldIndex(name) + def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { +class GenericMutableRow(val values: Array[Any]) extends MutableRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int, dataType: DataType): Any = values(i) + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + values(ordinal).asInstanceOf[InternalRow] + } + + override def setNullAt(i: Int): Unit = { values(i) = null} + + override def update(i: Int, value: Any): Unit = { values(i) = value } + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index ee0b5ab481158..8611debcc6a92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -51,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? 
null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; @@ -92,7 +92,7 @@ case class ConcatWs(children: Seq[Expression]) val flatInputs = children.flatMap { child => child.eval(input) match { case s: UTF8String => Iterator(s) - case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]] + case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String]) case null => Iterator(null.asInstanceOf[UTF8String]) } } @@ -105,7 +105,7 @@ case class ConcatWs(children: Seq[Expression]) val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" @@ -213,6 +213,9 @@ trait String2StringExpression extends ImplicitCastInputTypes { /** * A function that converts the characters of a string to uppercase. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns str with all characters changed to uppercase", + extended = "> SELECT _FUNC_('SparkSql');\n 'SPARKSQL'") case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { @@ -226,6 +229,9 @@ case class Upper(child: Expression) /** * A function that converts the characters of a string to lowercase. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns str with all characters changed to lowercase", + extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'") case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -659,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression) override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def nullSafeEval(string: Any, regex: Any): Any = { - string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq + val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => - s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer( - java.util.Arrays.asList($str.split($pattern, -1)));""") + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. 
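The interpreted path of StringSplit above now amounts to wrapping the UTF8String[] result in the new Catalyst array representation instead of a Scala Seq; a minimal sketch:

    import org.apache.spark.sql.types.GenericArrayData
    import org.apache.spark.unsafe.types.UTF8String

    val parts = UTF8String.fromString("a,b,c").split(UTF8String.fromString(","), -1)
    val arr = new GenericArrayData(parts.asInstanceOf[Array[Any]])
    assert(arr.numElements() == 3)
    assert(arr.getUTF8String(0).toString == "a")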
+ s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") } override def prettyName: String = "split" @@ -770,7 +778,6 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) override def dataType: DataType = IntegerType - protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) @@ -1018,7 +1025,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio s""" ${evalSubject.code} - boolean ${ev.isNull} = ${evalSubject.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${evalSubject.isNull}) { ${evalRegexp.code} @@ -1113,9 +1120,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val evalIdx = idx.gen(ctx) s""" - ${ctx.javaType(dataType)} ${ev.primitive} = null; - boolean ${ev.isNull} = true; ${evalSubject.code} + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + boolean ${ev.isNull} = true; if (!${evalSubject.isNull}) { ${evalRegexp.code} if (!${evalRegexp.isNull}) { @@ -1127,7 +1134,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + ${termPattern}.matcher(${evalSubject.primitive}.toString()); if (m.find()) { ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); @@ -1149,7 +1156,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio * fractional part. */ case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + extends BinaryExpression with ExpectsInputTypes { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d2db3dd3d078e..29d706dcb39a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,8 +36,9 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. 
Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: - Batch("Distinct", FixedPoint(100), - ReplaceDistinctWithAggregate) :: + Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, @@ -311,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) + case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => + Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) @@ -553,33 +555,27 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe // Split the condition into small conditions by `And`, so that we can push down part of this // condition without nondeterministic expressions. val andConditions = splitConjunctivePredicates(condition) - val nondeterministicConditions = andConditions.filter(hasNondeterministic(_, aliasMap)) + + val (deterministic, nondeterministic) = andConditions.partition(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a) + }.forall(_.deterministic)) // If there is no nondeterministic conditions, push down the whole condition. - if (nondeterministicConditions.isEmpty) { + if (nondeterministic.isEmpty) { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) } else { // If they are all nondeterministic conditions, leave it un-changed. - if (nondeterministicConditions.length == andConditions.length) { + if (deterministic.isEmpty) { filter } else { - val deterministicConditions = andConditions.filterNot(hasNondeterministic(_, aliasMap)) // Push down the small conditions without nondeterministic expressions. - val pushedCondition = deterministicConditions.map(replaceAlias(_, aliasMap)).reduce(And) - Filter(nondeterministicConditions.reduce(And), + val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) + Filter(nondeterministic.reduce(And), project.copy(child = Filter(pushedCondition, grandChild))) } } } - private def hasNondeterministic( - condition: Expression, - sourceAliases: AttributeMap[Expression]) = { - condition.collect { - case a: Attribute if sourceAliases.contains(a) => sourceAliases(a) - }.exists(!_.deterministic) - } - // Substitute any attributes that are produced by the child projection, so that we safely // eliminate it. private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { @@ -805,3 +801,15 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { case Distinct(child) => Aggregate(child.output, child.output, child) } } + +/** + * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result + * but only makes the grouping key bigger. 
+ */ +object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, _, _) => + val newGrouping = grouping.filter(!_.foldable) + a.copy(groupingExpressions = newGrouping) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b8e3b0d53a505..b9ca712c1ee1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -144,14 +144,14 @@ object PartialAggregation { // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + groupingExpressions.map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation @@ -184,7 +184,7 @@ object PartialAggregation { * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ object ExtractEquiJoinKeys extends Logging with PredicateHelper { - /** (joinType, rightKeys, leftKeys, condition, leftChild, rightChild) */ + /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ type ReturnType = (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d06a7a2add754..c610f70d38437 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { self: PlanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 1868f119f0e97..e3e7a11dba973 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import 
org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} import org.apache.spark.sql.types.{StructField, StructType} @@ -28,6 +29,12 @@ object LocalRelation { new LocalRelation(StructType(output1 +: output).toAttributes) } + def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { + val schema = StructType.fromAttributes(output) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) + } + def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 6aefa9f67556a..a67f8de6b733a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -34,7 +34,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend }.nonEmpty ) - !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions + expressions.forall(_.resolved) && childrenResolved && !hasSpecialExpressions } } @@ -68,7 +68,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length && - !generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions @@ -128,7 +128,10 @@ case class Join( // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = { - childrenResolved && expressions.forall(_.resolved) && selfJoinResolved + childrenResolved && + expressions.forall(_.resolved) && + selfJoinResolved && + condition.forall(_.dataType == BooleanType) } } @@ -184,14 +187,8 @@ case class WithWindowDefinition( override def output: Seq[Attribute] = child.output } -case class WriteToFile( - path: String, - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output -} - /** - * @param order The ordering expressions + * @param order The ordering expressions, should all be [[AttributeReference]] * @param global True means global sorting apply for entire data set, * False means sorting only apply within the partition. 
* @param child Child logical plan @@ -201,6 +198,11 @@ case class Sort( global: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + def hasNoEvaluation: Boolean = order.forall(_.child.isInstanceOf[AttributeReference]) + + override lazy val resolved: Boolean = + expressions.forall(_.resolved) && childrenResolved && hasNoEvaluation } case class Aggregate( @@ -215,9 +217,11 @@ case class Aggregate( }.nonEmpty ) - !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions + expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions } + lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } @@ -376,7 +380,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { - val limit = limitExpr.eval(null).asInstanceOf[Int] + val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum Statistics(sizeInBytes = sizeInBytes) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 246f4d7e34d3d..e6621e0f50a9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.types.StringType /** * A logical node that represents a non-query command to be executed by the system. For example, @@ -25,3 +26,28 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * eagerly executed. */ trait Command + +/** + * Returned for the "DESCRIBE [EXTENDED] FUNCTION functionName" command. + * @param functionName The function to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + */ +private[sql] case class DescribeFunction( + functionName: String, + isExtended: Boolean) extends LogicalPlan with Command { + + override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( + AttributeReference("function_desc", StringType, nullable = false)()) +} + +/** + * Returned for the "SHOW FUNCTIONS" command, which will list all of the + * registered function list. 
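The new LocalRelation.fromExternalRows above mirrors fromProduct but starts from external Rows and runs them through the Catalyst converters; a small sketch (column names invented):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val relation = LocalRelation.fromExternalRows(
      output = Seq('a.int, 'b.string),
      data = Seq(Row(1, "x"), Row(2, "y")))   // stored internally as InternalRows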
+ */ +private[sql] case class ShowFunctions( + db: Option[String], pattern: Option[String]) extends LogicalPlan with Command { + override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( + AttributeReference("function", StringType, nullable = false)()) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 07412e73b6a5b..53abdf6618eac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -45,6 +45,7 @@ object DateTimeUtils { final val to2001 = -11323 // this is year -17999, calculation: 50 * daysIn400Year + final val YearZero = -17999 final val toYearZero = to2001 + 7304850 @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -573,4 +574,209 @@ object DateTimeUtils { dayInYear - 334 } } + + /** + * The number of days for each month (not leap year) + */ + private val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31) + + /** + * Returns the date value for the first day of the given month. + * The month is expressed in months since year zero (17999 BC), starting from 0. + */ + private def firstDayOfMonth(absoluteMonth: Int): Int = { + val absoluteYear = absoluteMonth / 12 + var monthInYear = absoluteMonth - absoluteYear * 12 + var date = getDateFromYear(absoluteYear) + if (monthInYear >= 2 && isLeapYear(absoluteYear + YearZero)) { + date += 1 + } + while (monthInYear > 0) { + date += monthDays(monthInYear - 1) + monthInYear -= 1 + } + date + } + + /** + * Returns the date value for January 1 of the given year. + * The year is expressed in years since year zero (17999 BC), starting from 0. + */ + private def getDateFromYear(absoluteYear: Int): Int = { + val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100 + + absoluteYear / 4) + absoluteDays - toYearZero + } + + /** + * Add date and year-month interval. + * Returns a date value, expressed in days since 1.1.1970. + */ + def dateAddMonths(days: Int, months: Int): Int = { + val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months + val currentMonthInYear = absoluteMonth % 12 + val currentYear = absoluteMonth / 12 + val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0 + val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay + + val dayOfMonth = getDayOfMonth(days) + val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) { + // last day of the month + lastDayOfMonth + } else { + dayOfMonth + } + firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1 + } + + /** + * Add timestamp and full interval. + * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00. + */ + def timestampAddInterval(start: Long, months: Int, microseconds: Long): Long = { + val days = millisToDays(start / 1000L) + val newDays = dateAddMonths(days, months) + daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds + } + + /** + * Returns the last dayInMonth in the month it belongs to. The date is expressed + * in days since 1.1.1970. the return value starts from 1. 
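dateAddMonths above follows the add_months convention of clamping to the last day of the target month when the source day does not exist there; a sketch using the existing Java date converters (result shown for the default time zone):

    import java.sql.Date
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    val jan31 = DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-31"))
    val plusOneMonth = DateTimeUtils.dateAddMonths(jan31, 1)
    assert(DateTimeUtils.toJavaDate(plusOneMonth) == Date.valueOf("2015-02-28"))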
+ */ + private def getLastDayInMonthOfMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear > 31 && dayInYear <= 60) { + return 29 + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + if (dayInYear <= 31) { + 31 + } else if (dayInYear <= 59) { + 28 + } else if (dayInYear <= 90) { + 31 + } else if (dayInYear <= 120) { + 30 + } else if (dayInYear <= 151) { + 31 + } else if (dayInYear <= 181) { + 30 + } else if (dayInYear <= 212) { + 31 + } else if (dayInYear <= 243) { + 31 + } else if (dayInYear <= 273) { + 30 + } else if (dayInYear <= 304) { + 31 + } else if (dayInYear <= 334) { + 30 + } else { + 31 + } + } + + /** + * Returns number of months between time1 and time2. time1 and time2 are expressed in + * microseconds since 1.1.1970. + * + * If time1 and time2 having the same day of month, or both are the last day of month, + * it returns an integer (time under a day will be ignored). + * + * Otherwise, the difference is calculated based on 31 days per month, and rounding to + * 8 digits. + */ + def monthsBetween(time1: Long, time2: Long): Double = { + val millis1 = time1 / 1000L + val millis2 = time2 / 1000L + val date1 = millisToDays(millis1) + val date2 = millisToDays(millis2) + // TODO(davies): get year, month, dayOfMonth from single function + val dayInMonth1 = getDayOfMonth(date1) + val dayInMonth2 = getDayOfMonth(date2) + val months1 = getYear(date1) * 12 + getMonth(date1) + val months2 = getYear(date2) * 12 + getMonth(date2) + + if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1) + && dayInMonth2 == getLastDayInMonthOfMonth(date2))) { + return (months1 - months2).toDouble + } + // milliseconds is enough for 8 digits precision on the right side + val timeInDay1 = millis1 - daysToMillis(date1) + val timeInDay2 = millis2 - daysToMillis(date2) + val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY + val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 + // rounding to 8 digits + math.round(diff * 1e8) / 1e8 + } + + /* + * Returns day of week from String. Starting from Thursday, marked as 0. + * (Because 1970-01-01 is Thursday). + */ + def getDayOfWeekFromString(string: UTF8String): Int = { + val dowString = string.toString.toUpperCase + dowString match { + case "SU" | "SUN" | "SUNDAY" => 3 + case "MO" | "MON" | "MONDAY" => 4 + case "TU" | "TUE" | "TUESDAY" => 5 + case "WE" | "WED" | "WEDNESDAY" => 6 + case "TH" | "THU" | "THURSDAY" => 0 + case "FR" | "FRI" | "FRIDAY" => 1 + case "SA" | "SAT" | "SATURDAY" => 2 + case _ => -1 + } + } + + /** + * Returns the first date which is later than startDate and is of the given dayOfWeek. + * dayOfWeek is an integer ranges in [0, 6], and 0 is Thu, 1 is Fri, etc,. + */ + def getNextDateForDayOfWeek(startDate: Int, dayOfWeek: Int): Int = { + startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7 + } + + /** + * Returns last day of the month for the given date. The date is expressed in days + * since 1.1.1970. 
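A couple of worked examples for the helpers above; the monthsBetween figure assumes the default time zone and matches Hive's documented example:

    import java.sql.Timestamp
    import org.apache.spark.sql.catalyst.util.DateTimeUtils
    import org.apache.spark.unsafe.types.UTF8String

    val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1997-02-28 10:30:00"))
    val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1996-10-30 00:00:00"))
    DateTimeUtils.monthsBetween(t1, t2)   // 4 + (-2 days + 10.5 h) / 31, rounded to 8 digits: 3.94959677

    // 1970-01-01 (day 0) is a Thursday, encoded as 0; the next Friday is day 1.
    val fri = DateTimeUtils.getDayOfWeekFromString(UTF8String.fromString("FRI"))   // 1
    DateTimeUtils.getNextDateForDayOfWeek(0, fri)                                  // 1, i.e. 1970-01-02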
+ */ + def getLastDayOfMonth(date: Int): Int = { + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear > 31 && dayInYear <= 60) { + return date + (60 - dayInYear) + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + val lastDayOfMonthInYear = if (dayInYear <= 31) { + 31 + } else if (dayInYear <= 59) { + 59 + } else if (dayInYear <= 90) { + 90 + } else if (dayInYear <= 120) { + 120 + } else if (dayInYear <= 151) { + 151 + } else if (dayInYear <= 181) { + 181 + } else if (dayInYear <= 212) { + 212 + } else if (dayInYear <= 243) { + 243 + } else if (dayInYear <= 273) { + 273 + } else if (dayInYear <= 304) { + 304 + } else if (dayInYear <= 334) { + 334 + } else { + 365 + } + date + (lastDayOfMonthInYear - dayInYear) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 40bf4b299c990..e0667c629486d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -95,7 +95,7 @@ private[sql] object TypeCollection { * Types that include numeric types and interval type. They are only used in unary_minus, * unary_positive, add and subtract operations. */ - val NumericAndInterval = TypeCollection(NumericType, IntervalType) + val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType) def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala new file mode 100644 index 0000000000000..14a7285877622 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters + +abstract class ArrayData extends SpecializedGetters with Serializable { + // todo: remove this after we handle all types.(map type need special getter) + def get(ordinal: Int): Any + + def numElements(): Int + + // todo: need a more efficient way to iterate array type. 
+ def toArray(): Array[Any] = { + val n = numElements() + val values = new Array[Any](n) + var i = 0 + while (i < n) { + if (isNullAt(i)) { + values(i) = null + } else { + values(i) = get(i) + } + i += 1 + } + values + } + + override def toString(): String = toArray.mkString("[", ",", "]") + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[ArrayData]) { + return false + } + + val other = o.asInstanceOf[ArrayData] + if (other eq null) { + return false + } + + val len = numElements() + if (len != other.numElements()) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numElements() + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + get(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala similarity index 64% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala index 87c6e9e6e5e2c..3565f52c21f69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala @@ -22,16 +22,19 @@ import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: - * The data type representing time intervals. + * The data type representing calendar time intervals. The calendar time interval is stored + * internally in two components: number of months the number of microseconds. * - * Please use the singleton [[DataTypes.IntervalType]]. + * Note that calendar intervals are not comparable. + * + * Please use the singleton [[DataTypes.CalendarIntervalType]]. 
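The equality implemented above compares byte arrays by content and treats NaN as equal to NaN, with a matching element-wise hashCode; a sketch using the GenericArrayData implementation added later in this patch:

    import org.apache.spark.sql.types.GenericArrayData

    val a = new GenericArrayData(Array[Any](Array[Byte](1, 2), Double.NaN, null))
    val b = new GenericArrayData(Array[Any](Array[Byte](1, 2), Double.NaN, null))
    assert(a == b)
    assert(a.hashCode == b.hashCode)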
*/ @DeveloperApi -class IntervalType private() extends DataType { +class CalendarIntervalType private() extends DataType { - override def defaultSize: Int = 4096 + override def defaultSize: Int = 16 - private[spark] override def asNullable: IntervalType = this + private[spark] override def asNullable: CalendarIntervalType = this } -case object IntervalType extends IntervalType +case object CalendarIntervalType extends CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e98fd2583b931..f4428c2e8b202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -106,7 +106,7 @@ object DataType { private def nameToType(name: String): DataType = { val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r name match { - case "decimal" => DecimalType.Unlimited + case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) case other => nonDecimalNameToType(other) } @@ -142,12 +142,21 @@ object DataType { ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + // Scala/Java UDT case JSortedObject( ("class", JString(udtClass)), ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + + // Python UDT + case JSortedObject( + ("pyClass", JString(pyClass)), + ("serializedClass", JString(serialized)), + ("sqlType", v: JValue), + ("type", JString("udt"))) => + new PythonUserDefinedType(parseDataType(v), pyClass, serialized) } private def parseStructField(json: JValue): StructField = json match { @@ -177,7 +186,7 @@ object DataType { | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.Unlimited + | "DecimalType()" ^^^ DecimalType.USER_DEFAULT | fixedDecimalType | "TimestampType" ^^^ TimestampType ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 6b43224feb1f2..6e081ea9237bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -48,7 +48,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)binary".r ^^^ BinaryType | "(?i)boolean".r ^^^ BooleanType | fixedDecimalType | - "(?i)decimal".r ^^^ DecimalType.Unlimited | + "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT | "(?i)date".r ^^^ DateType | "(?i)timestamp".r ^^^ TimestampType | varchar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bc689810bc292..c0155eeb450a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { + // fast path for UnsafeProjection + if (precision == this.precision && scale == this.scale) { + return true + } // First, update our longVal if we can, or transfer over to using a BigDecimal if 
(decimalVal.eq(null)) { if (scale < _scale) { @@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { decimalVal = newVal } else { // We're still using Longs, but we should check whether we match the new precision - val p = POW_10(math.min(_precision, MAX_LONG_DIGITS)) + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) if (longVal <= -p || longVal >= p) { // Note that we shouldn't have been able to fix this by switching to BigDecimal return false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 377c75f6e85a5..0cd352d0fa928 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -26,25 +26,46 @@ import org.apache.spark.sql.catalyst.expressions.Expression /** Precision parameters for a Decimal */ +@deprecated("Use DecimalType(precision, scale) directly", "1.5") case class PrecisionInfo(precision: Int, scale: Int) { if (scale > precision) { throw new AnalysisException( s"Decimal scale ($scale) cannot be greater than precision ($precision).") } + if (precision > DecimalType.MAX_PRECISION) { + throw new AnalysisException( + s"DecimalType can only support precision up to 38" + ) + } } /** * :: DeveloperApi :: * The data type representing `java.math.BigDecimal` values. - * A Decimal that might have fixed precision and scale, or unlimited values for these. + * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number + * of digits on right side of dot). + * + * The precision can be up to 38, scale can also be up to 38 (less or equal to precision). + * + * The default precision and scale is (10, 0). * * Please use [[DataTypes.createDecimalType()]] to create a specific instance. */ @DeveloperApi -case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { +case class DecimalType(precision: Int, scale: Int) extends FractionalType { + + // default constructor for Java + def this(precision: Int) = this(precision, 0) + def this() = this(10) + + @deprecated("Use DecimalType(precision, scale) instead", "1.5") + def this(precisionInfo: Option[PrecisionInfo]) { + this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, + precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) + } - /** No-arg constructor for kryo. 
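changePrecision now short-circuits when the target precision and scale already match, and otherwise keeps its contract of returning false on overflow; a small sketch of that contract (values invented):

    import org.apache.spark.sql.types.Decimal

    val d = Decimal(BigDecimal("123.45"))          // precision 5, scale 2
    assert(d.changePrecision(5, 2))                // fast path: nothing to do
    assert(d.changePrecision(10, 4))               // widening always fits
    assert(!Decimal(BigDecimal("123.45")).changePrecision(4, 2))   // 123.45 does not fit (4, 2)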
*/ - protected def this() = this(null) + @deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5") + val precisionInfo = Some(PrecisionInfo(precision, scale)) private[sql] type InternalType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } @@ -53,18 +74,20 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT private[sql] val ordering = Decimal.DecimalIsFractional private[sql] val asIntegral = Decimal.DecimalAsIfIntegral - def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) + override def typeName: String = s"decimal($precision,$scale)" - def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) + override def toString: String = s"DecimalType($precision,$scale)" - override def typeName: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal" - } - - override def toString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" - case None => "DecimalType()" + /** + * Returns whether this DecimalType is wider than `other`. If yes, it means `other` + * can be casted into `this` safely without losing any precision or range. + */ + private[sql] def isWiderThan(other: DataType): Boolean = other match { + case dt: DecimalType => + (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale + case dt: IntegralType => + isWiderThan(DecimalType.forType(dt)) + case _ => false } /** @@ -72,10 +95,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT */ override def defaultSize: Int = 4096 - override def simpleString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal(10,0)" - } + override def simpleString: String = s"decimal($precision,$scale)" private[spark] override def asNullable: DecimalType = this } @@ -83,8 +103,47 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** Extra factory methods and pattern matchers for Decimals */ object DecimalType extends AbstractDataType { + import scala.math.min + + val MAX_PRECISION = 38 + val MAX_SCALE = 38 + val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) + val USER_DEFAULT: DecimalType = DecimalType(10, 0) + + @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5") + val Unlimited: DecimalType = SYSTEM_DEFAULT + + // The decimal types compatible with other numeric types + private[sql] val ByteDecimal = DecimalType(3, 0) + private[sql] val ShortDecimal = DecimalType(5, 0) + private[sql] val IntDecimal = DecimalType(10, 0) + private[sql] val LongDecimal = DecimalType(20, 0) + private[sql] val FloatDecimal = DecimalType(14, 7) + private[sql] val DoubleDecimal = DecimalType(30, 15) + + private[sql] def forType(dataType: DataType): DecimalType = dataType match { + case ByteType => ByteDecimal + case ShortType => ShortDecimal + case IntegerType => IntDecimal + case LongType => LongDecimal + case FloatType => FloatDecimal + case DoubleType => DoubleDecimal + } + + @deprecated("please specify precision and scale", "1.5") + def apply(): DecimalType = USER_DEFAULT + + @deprecated("Use DecimalType(precision, scale) instead", "1.5") + def apply(precisionInfo: Option[PrecisionInfo]) { + this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, + precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) + } 
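A short sketch of the reworked, always-fixed DecimalType; isWiderThan and forType are private[sql], so they are only described in the comment. With the Literal change earlier in this patch, decimal literals also pick up their exact precision and scale:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types._

    assert(DecimalType(12, 2).simpleString == "decimal(12,2)")
    assert(DecimalType.USER_DEFAULT == DecimalType(10, 0))
    assert(DecimalType.SYSTEM_DEFAULT == DecimalType(38, 18))
    assert(Literal(BigDecimal("123.45")).dataType == DecimalType(5, 2))
    // DecimalType(14, 4).isWiderThan(DecimalType(10, 2)) holds because
    // 14 - 4 >= 10 - 2 and 4 >= 2, so every (10, 2) value fits into (14, 4).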
+ + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { + DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) + } - override private[sql] def defaultConcreteType: DataType = Unlimited + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[DecimalType] @@ -92,31 +151,18 @@ object DecimalType extends AbstractDataType { override private[sql] def simpleString: String = "decimal" - val Unlimited: DecimalType = DecimalType(None) - private[sql] object Fixed { - def unapply(t: DecimalType): Option[(Int, Int)] = - t.precisionInfo.map(p => (p.precision, p.scale)) + def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale)) } private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { - case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case t: DecimalType => Some((t.precision, t.scale)) case _ => None } } - def apply(): DecimalType = Unlimited - - def apply(precision: Int, scale: Int): DecimalType = - DecimalType(Some(PrecisionInfo(precision, scale))) - def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] - - def isFixed(dataType: DataType): Boolean = dataType match { - case DecimalType.Fixed(_, _) => true - case _ => false - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala new file mode 100644 index 0000000000000..35ace673fb3da --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval} + +class GenericArrayData(array: Array[Any]) extends ArrayData { + private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T] + + override def toArray(): Array[Any] = array + + override def get(ordinal: Int): Any = array(ordinal) + + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + + override def getByte(ordinal: Int): Byte = getAs(ordinal) + + override def getShort(ordinal: Int): Short = getAs(ordinal) + + override def getInt(ordinal: Int): Int = getAs(ordinal) + + override def getLong(ordinal: Int): Long = getAs(ordinal) + + override def getFloat(ordinal: Int): Float = getAs(ordinal) + + override def getDouble(ordinal: Int): Double = getAs(ordinal) + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + + override def numElements(): Int = array.length +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index e47cfb4833bd8..4305903616bd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Paired Python UDT class, if exists. */ def pyUDT: String = null + /** Serialized Python UDT class, if exists. */ + def serializedPyClass: String = null + /** * Convert the user type to a SQL datum * @@ -82,3 +85,29 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass } + +/** + * ::DeveloperApi:: + * The user defined type in Python. + * + * Note: This can only be accessed via Python UDF, or accessed as serialized object. 
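GenericArrayData is a thin wrapper over the supplied Array[Any]; the typed getters simply cast. Basic usage, as a sketch:

    import org.apache.spark.sql.types.GenericArrayData

    val arr = new GenericArrayData(Array[Any](1, 2, null))
    assert(arr.numElements() == 3)
    assert(arr.getInt(0) == 1)
    assert(arr.isNullAt(2))
    assert(arr.toString == "[1,2,null]")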
+ */ +private[sql] class PythonUserDefinedType( + val sqlType: DataType, + override val pyUDT: String, + override val serializedPyClass: String) extends UserDefinedType[Any] { + + /* The serialization is handled by UDT class in Python */ + override def serialize(obj: Any): Any = obj + override def deserialize(datam: Any): Any = datam + + /* There is no Java class for Python UDT */ + override def userClass: java.lang.Class[Any] = null + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("pyClass" -> pyUDT) ~ + ("serializedClass" -> serializedPyClass) ~ + ("sqlType" -> sqlType.jsonValue) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 13aad467fa578..75ae29d690770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -69,8 +69,7 @@ object RandomDataGenerator { * Returns a function which generates random values for the given [[DataType]], or `None` if no * random data generator is defined for that data type. The generated values will use an external * representation of the data type; for example, the random generator for [[DateType]] will return - * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a - * [[org.apache.spark.Row]]. + * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. * * @param dataType the type to generate values for * @param nullable whether null values should be generated @@ -94,8 +93,8 @@ object RandomDataGenerator { case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt())) case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) - case DecimalType.Unlimited => Some( - () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED)) + case DecimalType.Fixed(precision, scale) => Some( + () => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision))) case DoubleType => randomNumeric[Double]( rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index dbba93dba668e..cccac7efa09e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -32,7 +32,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { */ def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) - val generator = RandomDataGenerator.forType(dataType, nullable).getOrElse { + val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse { fail(s"Random data generator was not defined for $dataType") } if (nullable) { @@ -50,9 +50,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { for ( dataType <- DataTypeTestUtils.atomicTypes; nullable <- Seq(true, false) - if !dataType.isInstanceOf[DecimalType] || - dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty - ) { + if !dataType.isInstanceOf[DecimalType]) { test(s"$dataType 
(nullable=$nullable)") { testRandomDataGeneration(dataType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 878a1bb9b7e6d..01ff84cb56054 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -83,15 +83,5 @@ class RowTest extends FunSpec with Matchers { it("equality check for internal rows") { internalRow shouldEqual internalRow2 } - - it("throws an exception when check equality between external and internal rows") { - def assertError(f: => Unit): Unit = { - val e = intercept[UnsupportedOperationException](f) - e.getMessage.contains("cannot check equality between external and internal rows") - } - - assertError(internalRow.equals(externalRow)) - assertError(externalRow.equals(internalRow)) - } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index b4b00f558463f..3b848cfdf737f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -102,7 +102,7 @@ class ScalaReflectionSuite extends SparkFunSuite { StructField("byteField", ByteType, nullable = true), StructField("booleanField", BooleanType, nullable = true), StructField("stringField", StringType, nullable = true), - StructField("decimalField", DecimalType.Unlimited, nullable = true), + StructField("decimalField", DecimalType.SYSTEM_DEFAULT, nullable = true), StructField("dateField", DateType, nullable = true), StructField("timestampField", TimestampType, nullable = true), StructField("binaryField", BinaryType, nullable = true))), @@ -216,7 +216,7 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(DoubleType === typeOfObject(1.7976931348623157E308)) // DecimalType - assert(DecimalType.Unlimited === + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) // DateType @@ -229,19 +229,19 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(NullType === typeOfObject(null)) def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.Unlimited - case value: java.math.BigDecimal => DecimalType.Unlimited + case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT + case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT case _ => StringType } - assert(DecimalType.Unlimited === typeOfObject1( + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( new BigInteger("92233720368547758070"))) - assert(DecimalType.Unlimited === typeOfObject1( + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( new java.math.BigDecimal("1.7976931348623157E318"))) assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.Unlimited + case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT } intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index dca8c881f21ab..2588df98246dd 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -113,11 +113,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { testRelation.select(Literal(1).cast(BinaryType).as('badCast)), "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + errorTest( + "sorting by unsupported column types", + listRelation.orderBy('list.asc), + "sorting" :: "type" :: "array" :: Nil) + errorTest( "non-boolean filters", testRelation.where(Literal(1)), "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + errorTest( + "non-boolean join conditions", + testRelation.join(testRelation, condition = Some(Literal(1))), + "condition" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + errorTest( "missing group by", testRelation2.groupBy('a)('b), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 58df1de983a09..a86cefe941e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -28,6 +24,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +// todo: remove this and use AnalysisTest instead. 
object AnalysisSuite { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -55,7 +52,7 @@ object AnalysisSuite { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.Unlimited)(), + AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) val nestedRelation = LocalRelation( @@ -81,8 +78,7 @@ object AnalysisSuite { } -class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { - import AnalysisSuite._ +class AnalysisSuite extends AnalysisTest { test("union project *") { val plan = (1 to 100) @@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyzer.execute(plan).resolved) + assertAnalysisSuccess(plan) } test("check project's resolved") { @@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { } test("analyze project") { - assert( - caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) === - Project(testRelation.output, testRelation)) - - assert( - caseSensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - val e = intercept[AnalysisException] { - caseSensitiveAnalyze( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) - } - assert(e.getMessage().toLowerCase.contains("cannot resolve")) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) + checkAnalysis( + Project(Seq(UnresolvedAttribute("a")), testRelation), + Project(testRelation.output, testRelation)) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation)) + + assertAnalysisError( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Seq("cannot resolve")) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) } test("resolve relations") { - val e = intercept[RuntimeException] { - caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) - } - assert(e.getMessage == "Table Not Found: tAbLe") + assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe")) - assert( - caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation) - assert( - caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false) - assert( - caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), 
None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false) } - test("divide should be casted into fractional types") { - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.Unlimited)(), - AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyzer.execute( testRelation2.select( 'a / Literal(2) as 'div1, @@ -173,7 +148,56 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DecimalType.Unlimited) + // StringType will be promoted into Decimal(38, 18) + assert(pl(3).dataType == DecimalType(38, 29)) assert(pl(4).dataType == DoubleType) } + + test("pull out nondeterministic expressions from RepartitionByExpression") { + val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) + val projected = Alias(Rand(33), "_nondeterministic")() + val expected = + Project(testRelation.output, + RepartitionByExpression(Seq(projected.toAttribute), + Project(testRelation.output :+ projected, testRelation))) + checkAnalysis(plan, expected) + } + + test("pull out nondeterministic expressions from Sort") { + val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation) + val analyzed = caseSensitiveAnalyzer.execute(plan) + analyzed.transform { + case s: Sort if s.expressions.exists(!_.deterministic) => + fail("nondeterministic expressions are not allowed in Sort") + } + } + + test("remove still-need-evaluate ordering expressions from sort") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + + def makeOrder(e: Expression): SortOrder = SortOrder(e, Ascending) + + val noEvalOrdering = makeOrder(a) + val noEvalOrderingWithAlias = makeOrder(Alias(Alias(b, "name1")(), "name2")()) + + val needEvalExpr = Coalesce(Seq(a, Literal("1"))) + val needEvalExpr2 = Coalesce(Seq(a, b)) + val needEvalOrdering = makeOrder(needEvalExpr) + val needEvalOrdering2 = makeOrder(needEvalExpr2) + + val plan = Sort( + Seq(noEvalOrdering, noEvalOrderingWithAlias, needEvalOrdering, needEvalOrdering2), + false, testRelation2) + + val evaluatedOrdering = makeOrder(AttributeReference("_sortCondition", StringType)()) + val materializedExprs = Seq(needEvalExpr, needEvalExpr2).map(e => Alias(e, "_sortCondition")()) + + val expected = + Project(testRelation2.output, + Sort(Seq(makeOrder(a), makeOrder(b), evaluatedOrdering, evaluatedOrdering), false, + Project(testRelation2.output ++ materializedExprs, testRelation2))) + + checkAnalysis(plan, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala new file mode 100644 index 0000000000000..fdb4f28950daf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.types._ + +trait AnalysisTest extends PlanTest { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + + val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { + val caseSensitiveConf = new SimpleCatalystConf(true) + val caseInsensitiveConf = new SimpleCatalystConf(false) + + val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) + val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) + + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } -> + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + } + + protected def getAnalyzer(caseSensitive: Boolean) = { + if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer + } + + protected def checkAnalysis( + inputPlan: LogicalPlan, + expectedPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + val actualPlan = analyzer.execute(inputPlan) + analyzer.checkAnalysis(actualPlan) + comparePlans(actualPlan, expectedPlan) + } + + protected def assertAnalysisSuccess( + inputPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + + protected def assertAnalysisError( + inputPlan: LogicalPlan, + expectedErrors: Seq[String], + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + // todo: make sure we throw AnalysisException during analysis + val e = intercept[Exception] { + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + expectedErrors.forall(e.getMessage.contains) + } +} 
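A minimal usage sketch of the AnalysisTest trait introduced above (illustrative only, not part of this patch): checkAnalysis, assertAnalysisSuccess, assertAnalysisError, and testRelation all come from the trait, while the suite name and the unresolved column name below are hypothetical. This is the same pattern AnalysisSuite now follows in place of its hand-rolled analyzer assertions.

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.Project

// Hypothetical suite, for illustration only: it exercises the helpers defined
// by AnalysisTest against the relations the trait registers in its catalogs.
class ExampleAnalysisSuite extends AnalysisTest {

  test("a registered table resolves to its LocalRelation") {
    // Both analyzers in the trait register testRelation under Seq("TaBlE").
    checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation)
  }

  test("an unknown attribute is reported as an analysis error") {
    // "noSuchColumn" is a made-up name that does not exist in testRelation.
    assertAnalysisError(
      Project(Seq(UnresolvedAttribute("noSuchColumn")), testRelation),
      Seq("cannot resolve"))
  }

  test("star expansion over a resolved relation analyzes successfully") {
    assertAnalysisSuccess(testRelation.select(UnresolvedStar(None)))
  }
}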
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 7bac97b7894f5..fc11627da6fd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -34,7 +34,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("i", IntegerType)(), AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), - AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("f", FloatType)(), AttributeReference("b", DoubleType)() ) @@ -92,11 +92,11 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualTo(i, d1), DecimalType(11, 1)) checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) - checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThan(i, d1), DecimalType(11, 1)) checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) - checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) + checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } @@ -106,12 +106,12 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkUnion(i, d2, DecimalType(12, 2)) checkUnion(d1, d2, DecimalType(5, 2)) checkUnion(d2, d1, DecimalType(5, 2)) - checkUnion(d1, f, DecimalType(8, 7)) - checkUnion(f, d2, DecimalType(10, 7)) - checkUnion(d1, b, DecimalType(16, 15)) - checkUnion(b, d2, DecimalType(18, 15)) - checkUnion(d1, u, DecimalType.Unlimited) - checkUnion(u, d2, DecimalType.Unlimited) + checkUnion(d1, f, DoubleType) + checkUnion(f, d2, DoubleType) + checkUnion(d1, b, DoubleType) + checkUnion(b, d2, DoubleType) + checkUnion(d1, u, DecimalType.SYSTEM_DEFAULT) + checkUnion(u, d2, DecimalType.SYSTEM_DEFAULT) } test("bringing in primitive types") { @@ -125,13 +125,59 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Add(d1, Cast(i, DoubleType)), DoubleType) } - test("unlimited decimals make everything else cast up") { - for (expr <- Seq(d1, d2, i, f, u)) { - checkType(Add(expr, u), DecimalType.Unlimited) - checkType(Subtract(expr, u), DecimalType.Unlimited) - checkType(Multiply(expr, u), DecimalType.Unlimited) - checkType(Divide(expr, u), DecimalType.Unlimited) - checkType(Remainder(expr, u), DecimalType.Unlimited) + test("maximum decimals") { + for (expr <- Seq(d1, d2, i, u)) { + checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) } + + checkType(Multiply(d1, u), DecimalType(38, 19)) + checkType(Multiply(d2, u), DecimalType(38, 20)) + checkType(Multiply(i, u), DecimalType(38, 18)) + checkType(Multiply(u, u), DecimalType(38, 36)) + + checkType(Divide(u, d1), DecimalType(38, 21)) + checkType(Divide(u, d2), DecimalType(38, 24)) + checkType(Divide(u, i), DecimalType(38, 29)) + checkType(Divide(u, u), DecimalType(38, 38)) + + checkType(Remainder(d1, u), DecimalType(19, 18)) + checkType(Remainder(d2, u), DecimalType(21, 18)) + checkType(Remainder(i, u), DecimalType(28, 18)) + checkType(Remainder(u, u), 
DecimalType.SYSTEM_DEFAULT) + + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + checkType(Remainder(expr, u), DoubleType) + } + } + + test("DecimalType.isWiderThan") { + val d0 = DecimalType(2, 0) + val d1 = DecimalType(2, 1) + val d2 = DecimalType(5, 2) + val d3 = DecimalType(15, 3) + val d4 = DecimalType(25, 4) + + assert(d0.isWiderThan(d1) === false) + assert(d1.isWiderThan(d0) === false) + assert(d1.isWiderThan(d2) === false) + assert(d2.isWiderThan(d1) === true) + assert(d2.isWiderThan(d3) === false) + assert(d3.isWiderThan(d2) === true) + assert(d4.isWiderThan(d3) === true) + + assert(d1.isWiderThan(ByteType) === false) + assert(d2.isWiderThan(ByteType) === true) + assert(d2.isWiderThan(ShortType) === false) + assert(d3.isWiderThan(ShortType) === true) + assert(d3.isWiderThan(IntegerType) === true) + assert(d3.isWiderThan(LongType) === false) + assert(d4.isWiderThan(LongType) === true) + assert(d4.isWiderThan(FloatType) === false) + assert(d4.isWiderThan(DoubleType) === false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ad15136ee9a2f..a52e4cb4dfd9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "type (numeric or interval)") + assertError(UnaryMinus('stringField), "type (numeric or calendarinterval)") assertError(Abs('stringField), "expected to be of type numeric") assertError(BitwiseNot('stringField), "expected to be of type integral") } @@ -78,8 +78,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type") - assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type") + assertError(Add('booleanField, 'booleanField), "accepts (numeric or calendarinterval) type") + assertError(Subtract('booleanField, 'booleanField), + "accepts (numeric or calendarinterval) type") assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") assertError(Divide('booleanField, 'booleanField), "accepts numeric type") assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") @@ -166,10 +167,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") + assertError( + CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), + "Field name should not be null") } test("check types for 
ROUND") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 835220c563f41..70608771dd110 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.analysis +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval class HiveTypeCoercionSuite extends PlanTest { @@ -35,14 +38,14 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(NullType, NullType, NullType) shouldCast(NullType, IntegerType, IntegerType) - shouldCast(NullType, DecimalType, DecimalType.Unlimited) + shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT) shouldCast(ByteType, IntegerType, IntegerType) shouldCast(IntegerType, IntegerType, IntegerType) shouldCast(IntegerType, LongType, LongType) - shouldCast(IntegerType, DecimalType, DecimalType.Unlimited) + shouldCast(IntegerType, DecimalType, DecimalType(10, 0)) shouldCast(LongType, IntegerType, IntegerType) - shouldCast(LongType, DecimalType, DecimalType.Unlimited) + shouldCast(LongType, DecimalType, DecimalType(20, 0)) shouldCast(DateType, TimestampType, TimestampType) shouldCast(TimestampType, DateType, DateType) @@ -71,8 +74,8 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType) shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType) - shouldCast( - DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited) + shouldCast(DecimalType.SYSTEM_DEFAULT, + TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT) shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) @@ -82,7 +85,7 @@ class HiveTypeCoercionSuite extends PlanTest { // NumericType should not be changed when function accepts any of them. Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => + DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe => shouldCast(tpe, NumericType, tpe) } @@ -107,14 +110,22 @@ class HiveTypeCoercionSuite extends PlanTest { shouldNotCast(IntegerType, TimestampType) shouldNotCast(LongType, DateType) shouldNotCast(LongType, TimestampType) - shouldNotCast(DecimalType.Unlimited, DateType) - shouldNotCast(DecimalType.Unlimited, TimestampType) + shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType) + shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType) shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) shouldNotCast(IntegerType, ArrayType) shouldNotCast(IntegerType, MapType) shouldNotCast(IntegerType, StructType) + + shouldNotCast(CalendarIntervalType, StringType) + + // Don't implicitly cast complex types to string. 
+ shouldNotCast(ArrayType(StringType), StringType) + shouldNotCast(MapType(StringType, StringType), StringType) + shouldNotCast(new StructType().add("a1", StringType), StringType) + shouldNotCast(MapType(StringType, StringType), StringType) } test("tightest common bound for types") { @@ -160,14 +171,6 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) - // Casting up to unlimited-precision decimal - widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited)) - // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) widenTest(DecimalType(2, 1), DecimalType(3, 2), None) widenTest(DecimalType(2, 1), DoubleType, None) @@ -242,9 +245,9 @@ class HiveTypeCoercionSuite extends PlanTest { :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil), - Coalesce(Cast(Literal(1L), DecimalType()) - :: Cast(Literal(1), DecimalType()) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) + Coalesce(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) } @@ -314,7 +317,7 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenTypes for union except and intersect") { + test("WidenSetOperationTypes for union except and intersect") { def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { logical.output.zip(expectTypes).foreach { case (attr, dt) => assert(attr.dataType === dt) @@ -323,7 +326,7 @@ class HiveTypeCoercionSuite extends PlanTest { val left = LocalRelation( AttributeReference("i", IntegerType)(), - AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("b", ByteType)(), AttributeReference("d", DoubleType)()) val right = LocalRelation( @@ -332,8 +335,8 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = HiveTypeCoercion.WidenTypes - val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType) + val wt = HiveTypeCoercion.WidenSetOperationTypes + val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) val r1 = wt(Union(left, right)).asInstanceOf[Union] val r2 = wt(Except(left, right)).asInstanceOf[Except] @@ -353,13 +356,13 @@ class HiveTypeCoercionSuite extends PlanTest { } } - val dp = HiveTypeCoercion.DecimalPrecision + val dp = HiveTypeCoercion.WidenSetOperationTypes val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) - val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5))) + val expectedType1 = Seq(DecimalType(10, 8)) val r1 = dp(Union(left1, right1)).asInstanceOf[Union] val r2 = dp(Except(left1, right1)).asInstanceOf[Except] @@ -372,12 +375,11 @@ class HiveTypeCoercionSuite extends PlanTest { checkOutput(r3.left, 
expectedType1) checkOutput(r3.right, expectedType1) - val plan1 = LocalRelation( - AttributeReference("l", DecimalType(10, 10))()) + val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))()) val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) - val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0), - DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15)) + val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5), + DecimalType(25, 5), DoubleType, DoubleType) rightTypes.zip(expectedTypes).map { case (rType, expectedType) => val plan2 = LocalRelation( @@ -401,6 +403,33 @@ class HiveTypeCoercionSuite extends PlanTest { } } + test("rule for date/timestamp operations") { + val dateTimeOperations = HiveTypeCoercion.DateTimeOperations + val date = Literal(new java.sql.Date(0L)) + val timestamp = Literal(new Timestamp(0L)) + val interval = Literal(new CalendarInterval(0, 0)) + val str = Literal("2015-01-01") + + ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType)) + ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType)) + ruleTest(dateTimeOperations, Add(timestamp, interval), + Cast(TimeAdd(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Add(interval, timestamp), + Cast(TimeAdd(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType)) + ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType)) + + ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType)) + ruleTest(dateTimeOperations, Subtract(timestamp, interval), + Cast(TimeSub(timestamp, interval), TimestampType)) + ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType)) + + // interval operations should not be effected + ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval)) + ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval)) + } + + /** * There are rules that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. 
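The fixed precision/scale results asserted in DecimalPrecisionSuite above follow a small set of arithmetic rules with results capped at 38 digits, and a system default of (38, 18) (the "Decimal(38, 18)" promotion noted in AnalysisSuite). The following is a standalone sketch of that arithmetic, not code from this patch; the object and method names are invented, and only the formulas and the 38-digit cap mirror what the tests assert.

// Standalone sketch: the precision/scale arithmetic implied by the
// DecimalPrecisionSuite expectations above, with precision and scale capped at
// 38 digits. Dec, bound, etc. are made-up names for this example.
object DecimalResultTypeSketch {
  final case class Dec(precision: Int, scale: Int)

  val MaxPrecision = 38
  val SystemDefault = Dec(38, 18) // what the suite calls DecimalType.SYSTEM_DEFAULT

  private def bound(precision: Int, scale: Int): Dec =
    Dec(math.min(precision, MaxPrecision), math.min(scale, MaxPrecision))

  def add(l: Dec, r: Dec): Dec = {
    val scale = math.max(l.scale, r.scale)
    val intDigits = math.max(l.precision - l.scale, r.precision - r.scale)
    bound(intDigits + scale + 1, scale)
  }

  def multiply(l: Dec, r: Dec): Dec =
    bound(l.precision + r.precision + 1, l.scale + r.scale)

  def divide(l: Dec, r: Dec): Dec = {
    val scale = math.max(6, l.scale + r.precision + 1)
    bound(l.precision - l.scale + r.scale + scale, scale)
  }

  def remainder(l: Dec, r: Dec): Dec =
    bound(math.min(l.precision - l.scale, r.precision - r.scale) + math.max(l.scale, r.scale),
      math.max(l.scale, r.scale))

  def main(args: Array[String]): Unit = {
    val d1 = Dec(2, 1)    // "d1" in DecimalPrecisionSuite
    val u = SystemDefault // "u" in DecimalPrecisionSuite

    assert(add(d1, u) == Dec(38, 18))       // Add(d1, u) -> DecimalType.SYSTEM_DEFAULT
    assert(multiply(d1, u) == Dec(38, 19))  // Multiply(d1, u) -> DecimalType(38, 19)
    assert(divide(u, d1) == Dec(38, 21))    // Divide(u, d1) -> DecimalType(38, 21)
    assert(remainder(d1, u) == Dec(19, 18)) // Remainder(d1, u) -> DecimalType(19, 18)
    assert(divide(u, u) == Dec(38, 38))     // Divide(u, u) -> DecimalType(38, 38)
    println("precision/scale arithmetic matches the expectations above")
  }
}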
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e7e5231d32c9e..d03b0fbbfb2b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -116,9 +116,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("Abs") { testNumericDataTypes { convert => + val input = Literal(convert(1)) + val dataType = input.dataType checkEvaluation(Abs(Literal(convert(0))), convert(0)) checkEvaluation(Abs(Literal(convert(1))), convert(1)) checkEvaluation(Abs(Literal(convert(-1))), convert(1)) + checkEvaluation(Abs(Literal.create(null, dataType)), null) } } @@ -170,6 +173,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(-7, 3), 2) checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) - checkEvaluation(Pmod(2L, Long.MaxValue), 2) + checkEvaluation(Pmod(2L, Long.MaxValue), 2L) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 648fbf5a4c30b..fa30fbe528479 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -30,8 +30,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, ~1.toByte) - check(1000.toShort, ~1000.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, (~1.toByte).toByte) + check(1000.toShort, (~1000.toShort).toShort) check(1000000, ~1000000) check(123456789123L, ~123456789123L) @@ -45,8 +46,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte & 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort & 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort) check(1000000, 4, 1000000 & 4) check(123456789123L, 5L, 123456789123L & 5L) @@ -63,8 +65,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte | 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort | 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort) check(1000000, 4, 1000000 | 4) check(123456789123L, 5L, 123456789123L | 5L) @@ -81,8 +84,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte ^ 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort ^ 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. 
+ check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort) check(1000000, 4, 1000000 ^ 4) check(123456789123L, 5L, 123456789123L ^ 5L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ccf448eee0688..1ad70733eae03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -21,9 +21,11 @@ import java.sql.{Timestamp, Date} import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for data type casting expression [[Cast]]. @@ -42,6 +44,42 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(v, Literal(expected).dataType), expected) } + private def checkNullCast(from: DataType, to: DataType): Unit = { + checkEvaluation(Cast(Literal.create(null, from), to), null) + } + + test("null cast") { + import DataTypeTestUtils._ + + // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic + // to ensure we test every possible cast situation here + atomicTypes.zip(atomicTypes).foreach { case (from, to) => + checkNullCast(from, to) + } + + atomicTypes.foreach(dt => checkNullCast(NullType, dt)) + atomicTypes.foreach(dt => checkNullCast(dt, StringType)) + checkNullCast(StringType, BinaryType) + checkNullCast(StringType, BooleanType) + checkNullCast(DateType, BooleanType) + checkNullCast(TimestampType, BooleanType) + numericTypes.foreach(dt => checkNullCast(dt, BooleanType)) + + checkNullCast(StringType, TimestampType) + checkNullCast(BooleanType, TimestampType) + checkNullCast(DateType, TimestampType) + numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) + + atomicTypes.foreach(dt => checkNullCast(dt, DateType)) + + checkNullCast(StringType, CalendarIntervalType) + numericTypes.foreach(dt => checkNullCast(StringType, dt)) + numericTypes.foreach(dt => checkNullCast(BooleanType, dt)) + numericTypes.foreach(dt => checkNullCast(DateType, dt)) + numericTypes.foreach(dt => checkNullCast(TimestampType, dt)) + for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to) + } + test("cast string to date") { var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -68,8 +106,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - checkEvaluation(Cast(Literal("123"), TimestampType), - null) + checkEvaluation(Cast(Literal("123"), TimestampType), null) var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -185,7 +222,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1, 1.0) checkCast(123, "123") - checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) @@ -203,12 +240,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1L, 1.0) checkCast(123L, "123") - checkEvaluation(cast(123L, 
DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) - checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + checkEvaluation(cast(123L, DecimalType(3, 1)), null) - // TODO: Fix the following bug and re-enable it. - // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + checkEvaluation(cast(123L, DecimalType(2, 0)), null) } test("cast from boolean") { @@ -225,7 +261,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) - checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) @@ -267,7 +303,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast("abcdef", IntegerType).nullable === true) assert(cast("abcdef", ShortType).nullable === true) assert(cast("abcdef", ByteType).nullable === true) - assert(cast("abcdef", DecimalType.Unlimited).nullable === true) + assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable === true) assert(cast("abcdef", DecimalType(4, 2)).nullable === true) assert(cast("abcdef", DoubleType).nullable === true) assert(cast("abcdef", FloatType).nullable === true) @@ -291,9 +327,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { c.getTimeInMillis * 1000) checkEvaluation(cast("abdef", StringType), "abdef") - checkEvaluation(cast("abdef", DecimalType.Unlimited), null) + checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) checkEvaluation(cast("abdef", TimestampType), null) - checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65)) + checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65)) checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) @@ -311,20 +347,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { 5.toLong) checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), - DecimalType.Unlimited), LongType), StringType), ShortType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), 0.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), - DecimalType.Unlimited), LongType), StringType), ShortType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) - checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited), + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), 0.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) checkEvaluation(cast("23", FloatType), 23f) - checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23)) + checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23)) checkEvaluation(cast("23", ByteType), 23.toByte) checkEvaluation(cast("23", ShortType), 23.toShort) checkEvaluation(cast("2012-12-11", DoubleType), null) @@ -338,7 +374,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d) checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 
24) checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f) - checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24)) + checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24)) checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte) checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) } @@ -362,10 +398,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { // - Values that would overflow the target precision should turn into null // - Because of this, casts to fixed-precision decimals should be nullable - assert(cast(123, DecimalType.Unlimited).nullable === false) - assert(cast(10.03f, DecimalType.Unlimited).nullable === true) - assert(cast(10.03, DecimalType.Unlimited).nullable === true) - assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false) + assert(cast(123, DecimalType.USER_DEFAULT).nullable === true) + assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true) assert(cast(123, DecimalType(2, 1)).nullable === true) assert(cast(10.03f, DecimalType(2, 1)).nullable === true) @@ -373,7 +409,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) - checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0)) checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10)) @@ -383,7 +419,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0)) checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null) - checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05)) checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05)) checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1)) checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10)) @@ -409,10 +445,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) - checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null) - checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null) - checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null) - checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null) + checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null) checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) @@ -427,7 +463,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(d, LongType), null) checkEvaluation(cast(d, FloatType), null) checkEvaluation(cast(d, DoubleType), null) - checkEvaluation(cast(d, DecimalType.Unlimited), null) + checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null) 
checkEvaluation(cast(d, DecimalType(10, 2)), null) checkEvaluation(cast(d, StringType), "1970-01-01") checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00") @@ -454,7 +490,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast(cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) checkEvaluation( - cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited), + cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), Decimal(1)) // A test for higher precision than millis @@ -472,6 +508,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val array_notNull = Literal.create(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) + { val ret = cast(array, ArrayType(IntegerType, containsNull = true)) assert(ret.resolved === true) @@ -525,6 +563,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) + checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) + { val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) assert(ret.resolved === true) @@ -579,15 +619,30 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from struct") { + checkNullCast( + StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))), + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType)))) + val struct = Literal.create( - InternalRow("123", "abc", "", null), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString(""), + null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) val struct_notNull = Literal.create( - InternalRow("123", "abc", ""), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString("")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -675,10 +730,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( - InternalRow( + Row( Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), - InternalRow(0)), + Map("a" ->"123", "b" -> "abc", "c" -> ""), + Row(0)), StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false), nullable = true), @@ -698,20 +753,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("l", LongType, nullable = true))))))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow( + checkEvaluation(ret, Row( Seq(123, null, null), Map("a" -> true, "b" -> true, "c" -> false), - InternalRow(0L))) + Row(0L))) } test("case between string and interval") { - import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.CalendarInterval - checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType), - new Interval(-3, 7 * Interval.MICROS_PER_HOUR)) + checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), + new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR)) checkEvaluation(Cast(Literal.create( - new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), 
StringType), + new CalendarInterval(15, -3 * CalendarInterval.MICROS_PER_DAY), CalendarIntervalType), + StringType), "interval 1 years 3 months -3 days") } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index a8aee8f634e03..3fa246b69d1f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -109,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { expr.dataType match { case ArrayType(StructType(fields), containsNull) => val field = fields.find(_.name == fieldName).get - GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull) } } @@ -131,6 +132,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) } test("CreateStruct") { @@ -138,24 +140,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null)) } test("CreateNamedStruct") { - val row = InternalRow(1, 2, 3) + val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) - } - - test("CreateNamedStruct with literal field") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) - } - - test("CreateNamedStruct from all literal fields") { - checkEvaluation( - CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), + create_row(1, UTF8String.fromString("y")), row) + checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)), + create_row(UTF8String.fromString("x"), 2.0)) + checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))), + create_row(null)) } test("test dsl for complex type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index afa143bd5f331..d26bcdb2902ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ 
-60,7 +60,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toFloat, FloatType) testIf(_.toDouble, DoubleType) - testIf(Decimal(_), DecimalType.Unlimited) + testIf(Decimal(_), DecimalType.USER_DEFAULT) testIf(identity, DateType) testIf(_.toLong, TimestampType) @@ -149,6 +149,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row) checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Least(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty) @@ -188,6 +190,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row) checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Greatest(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f724bab4d8839..887e43621a941 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Calendar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{StringType, TimestampType, DateType} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.sql.types._ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -31,82 +33,48 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + test("datetime function current_date") { + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) + } + + test("datetime function current_timestamp") { + val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) + val t1 = System.currentTimeMillis() + assert(math.abs(t1 - ct.getTime) < 5000) + } + test("DayOfYear") { val sdfDay = new SimpleDateFormat("D") - (2002 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), - 
sdfDay.format(c.getTime).toInt) - } - } - } - (1998 to 2002).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), - sdfDay.format(c.getTime).toInt) - } - } - } - - (1969 to 1970).foreach { y => - (0 to 11).foreach { m => + (0 to 3).foreach { m => (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2402 to 2404).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2398 to 2402).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } } + checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) } test("Year") { checkEvaluation(Year(Literal.create(null, DateType)), null) - checkEvaluation(Year(Cast(Literal(d), DateType)), 2015) + checkEvaluation(Year(Literal(d)), 2015) checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015) checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) val c = Calendar.getInstance() - (2000 to 2010).foreach { y => + (2000 to 2002).foreach { y => (0 to 11 by 11).foreach { m => c.set(y, m, 28) (0 to 5 * 24).foreach { i => c.add(Calendar.HOUR_OF_DAY, 1) - checkEvaluation(Year(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Year(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.YEAR)) } } @@ -115,7 +83,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Quarter") { checkEvaluation(Quarter(Literal.create(null, DateType)), null) - checkEvaluation(Quarter(Cast(Literal(d), DateType)), 2) + checkEvaluation(Quarter(Literal(d)), 2) checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2) checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4) @@ -125,7 +93,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, m, 28, 0, 0, 0) (0 to 5 * 24).foreach { i => c.add(Calendar.HOUR_OF_DAY, 1) - checkEvaluation(Quarter(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Quarter(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) / 3 + 1) } } @@ -134,29 +102,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Month") { checkEvaluation(Month(Literal.create(null, DateType)), null) - checkEvaluation(Month(Cast(Literal(d), DateType)), 4) + checkEvaluation(Month(Literal(d)), 4) checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 4) checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) (2003 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => + (0 to 3).foreach { m => + (0 to 2 * 24).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), - 
c.get(Calendar.MONTH) + 1) - } - } - } - - (1999 to 2000).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) + 1) } } @@ -166,7 +122,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Day / DayOfMonth") { checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29) checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null) - checkEvaluation(DayOfMonth(Cast(Literal(d), DateType)), 8) + checkEvaluation(DayOfMonth(Literal(d)), 8) checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8) checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8) @@ -175,7 +131,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, 0, 1, 0, 0, 0) (0 to 365).foreach { d => c.add(Calendar.DATE, 1) - checkEvaluation(DayOfMonth(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfMonth(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.DAY_OF_MONTH)) } } @@ -190,14 +146,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() (0 to 60 by 5).foreach { s => c.set(2015, 18, 3, 3, 5, s) - checkEvaluation(Second(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } } test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) - checkEvaluation(WeekOfYear(Cast(Literal(d), DateType)), 15) + checkEvaluation(WeekOfYear(Literal(d)), 15) checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) @@ -223,7 +179,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 60 by 15).foreach { m => (0 to 60 by 15).foreach { s => c.set(2015, 18, 3, h, m, s) - checkEvaluation(Hour(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Hour(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.HOUR_OF_DAY)) } } @@ -240,10 +196,214 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 60 by 5).foreach { m => (0 to 60 by 15).foreach { s => c.set(2015, 18, 3, 3, m, s) - checkEvaluation(Minute(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Minute(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.MINUTE)) } } } + test("date_add") { + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(-365)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateAdd(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("date_sub") { + checkEvaluation( + DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2014-12-31"))) + checkEvaluation( 
+ DateSub(Literal(Date.valueOf("2015-01-01")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02"))) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateSub(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("time_add") { + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal(new CalendarInterval(1, 123000L))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00.123"))) + + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("time_sub") { + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-03-31 10:00:00")), + Literal(new CalendarInterval(1, 0))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00"))) + checkEvaluation( + TimeSub( + Literal(Timestamp.valueOf("2016-03-30 00:00:01")), + Literal(new CalendarInterval(1, 2000000.toLong))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-28 23:59:59"))) + + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("add_months") { + checkEvaluation(AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(AddMonths(Literal(Date.valueOf("2016-03-30")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + AddMonths(Literal(Date.valueOf("2015-01-30")), Literal.create(null, IntegerType)), + null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("months_between") { + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("1997-02-28 10:30:00")), + Literal(Timestamp.valueOf("1996-10-30 00:00:00"))), + 3.94959677) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-01-30 11:52:00")), + Literal(Timestamp.valueOf("2015-01-30 11:50:00"))), + 0.0) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-01-31 00:00:00")), + Literal(Timestamp.valueOf("2015-03-31 22:00:00"))), + -2.0) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-03-31 22:00:00")), + Literal(Timestamp.valueOf("2015-02-28 00:00:00"))), + 1.0) + val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) + val tnull = Literal.create(null, TimestampType) + checkEvaluation(MonthsBetween(t, tnull), null) + checkEvaluation(MonthsBetween(tnull, t), null) + checkEvaluation(MonthsBetween(tnull, tnull), null) + } + + test("last_day") { + checkEvaluation(LastDay(Literal(Date.valueOf("2015-02-28"))), Date.valueOf("2015-02-28")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-03-27"))), Date.valueOf("2015-03-31")) + 
checkEvaluation(LastDay(Literal(Date.valueOf("2015-04-26"))), Date.valueOf("2015-04-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-05-25"))), Date.valueOf("2015-05-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-06-24"))), Date.valueOf("2015-06-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-07-23"))), Date.valueOf("2015-07-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-08-01"))), Date.valueOf("2015-08-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-09-02"))), Date.valueOf("2015-09-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-10-03"))), Date.valueOf("2015-10-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-11-04"))), Date.valueOf("2015-11-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-12-05"))), Date.valueOf("2015-12-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) + checkEvaluation(LastDay(Literal.create(null, DateType)), null) + } + + test("next_day") { + def testNextDay(input: String, dayOfWeek: String, output: String): Unit = { + checkEvaluation( + NextDay(Literal(Date.valueOf(input)), NonFoldableLiteral(dayOfWeek)), + DateTimeUtils.fromJavaDate(Date.valueOf(output))) + checkEvaluation( + NextDay(Literal(Date.valueOf(input)), Literal(dayOfWeek)), + DateTimeUtils.fromJavaDate(Date.valueOf(output))) + } + testNextDay("2015-07-23", "Mon", "2015-07-27") + testNextDay("2015-07-23", "mo", "2015-07-27") + testNextDay("2015-07-23", "Tue", "2015-07-28") + testNextDay("2015-07-23", "tu", "2015-07-28") + testNextDay("2015-07-23", "we", "2015-07-29") + testNextDay("2015-07-23", "wed", "2015-07-29") + testNextDay("2015-07-23", "Thu", "2015-07-30") + testNextDay("2015-07-23", "TH", "2015-07-30") + testNextDay("2015-07-23", "Fri", "2015-07-24") + testNextDay("2015-07-23", "fr", "2015-07-24") + + checkEvaluation(NextDay(Literal(Date.valueOf("2015-07-23")), Literal("xx")), null) + checkEvaluation(NextDay(Literal.create(null, DateType), Literal("xx")), null) + checkEvaluation( + NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) + } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) + checkEvaluation(FromUnixTime( + Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(1000000))) + checkEvaluation( + FromUnixTime(Literal(-1000L), Literal(fmt2)), sdf2.format(new Timestamp(-1000000))) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType)), null) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(FromUnixTime(Literal(1000L), Literal.create(null, StringType)), null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("not a valid format")), null) + } + + test("unix_timestamp") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3) + val date1 = Date.valueOf("2015-07-24") + checkEvaluation( + UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new 
Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) + checkEvaluation( + UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) + checkEvaluation(UnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) + val t1 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(UnixTimestamp( + Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L) + checkEvaluation( + UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 6e17ffcda9dc4..3c05e5c3b833c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -43,7 +43,7 @@ trait ExpressionEvalHelper { checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) - if (UnsafeColumnWriter.canEmbed(expression.dataType)) { + if (GenerateUnsafeProjection.canSupport(expression.dataType)) { checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) } checkEvaluationWithOptimization(expression, catalystValue, inputRow) @@ -64,6 +64,10 @@ trait ExpressionEvalHelper { } protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + expression.foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + } expression.eval(inputRow) } @@ -74,12 +78,11 @@ trait ExpressionEvalHelper { generator } catch { case e: Throwable => - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) fail( s""" |Code generation of $expression failed: |$e + |${e.getStackTraceString} """.stripMargin) } } @@ -109,10 +112,10 @@ trait ExpressionEvalHelper { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) - val actual = plan(inputRow).get(0) + val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") } } @@ -144,7 +147,8 @@ trait ExpressionEvalHelper { if (actual != expectedRow) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, 
expected: $expectedRow$input") + fail("Incorrect Evaluation in codegen mode: " + + s"$expression, actual: $actual, expected: $expectedRow$input") } if (actual.copy() != expectedRow) { fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") @@ -161,12 +165,21 @@ trait ExpressionEvalHelper { expression) val unsafeRow = plan(inputRow) - // UnsafeRow cannot be compared with GenericInternalRow directly - val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) - val expectedRow = InternalRow(expected) - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected) + val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } } } @@ -190,13 +203,14 @@ trait ExpressionEvalHelper { var plan = generateProject( GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - var actual = plan(inputRow).get(0) + var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - actual = FromUnsafeProjection(expression.dataType :: Nil)(plan(inputRow)).get(0) + actual = FromUnsafeProjection(expression.dataType :: Nil)( + plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index d924ff7a102f6..f6404d21611e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -33,7 +33,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, LongType), null) checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, BinaryType), null) - checkEvaluation(Literal.create(null, DecimalType()), null) + checkEvaluation(Literal.create(null, DecimalType.USER_DEFAULT), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index a2b0fad7b7a04..9fcb548af6bbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -110,35 +110,17 @@ class 
MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } - test("conv") { - checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") - checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") - checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) - checkEvaluation( - Conv(Literal("1234"), Literal(10), Literal(37)), null) - checkEvaluation( - Conv(Literal(""), Literal(10), Literal(16)), null) - checkEvaluation( - Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") - // If there is an invalid digit in the number, the longest valid prefix should be converted. - checkEvaluation( - Conv(Literal("11abc"), Literal(10), Literal(16)), "B") - } - private def checkNaN( - expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { checkNaNWithoutCodegen(expression, inputRow) checkNaNWithGeneratedProjection(expression, inputRow) checkNaNWithOptimization(expression, inputRow) } private def checkNaNWithoutCodegen( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -149,7 +131,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - private def checkNaNWithGeneratedProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { @@ -158,7 +139,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) - val actual = plan(inputRow).apply(0) + val actual = plan(inputRow).get(0, expression.dataType) if (!actual.asInstanceOf[Double].isNaN) { fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") } @@ -172,6 +153,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be converted. 
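+ // For example, "11abc" in base 10 is read as its longest valid prefix "11", which is "B" in base 16: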
+ checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -417,7 +417,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("round") { - val domain = -6 to 6 + val scales = -6 to 6 val doublePi: Double = math.Pi val shortPi: Short = 31415 val intPi: Int = 314159265 @@ -437,17 +437,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ Seq.fill(7)(31415926535897932L) - val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), - BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), - BigDecimal(3.141593), BigDecimal(3.1415927)) - - domain.zipWithIndex.foreach { case (scale, i) => + scales.zipWithIndex.foreach { case (scale, i) => checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) } + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) // round_scale > current_scale would result in precision increase // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null (0 to 7).foreach { i => @@ -456,5 +455,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) } + + DataTypeTestUtils.numericTypes.foreach { dataType => + checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(Round(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala new file mode 100644 index 0000000000000..0559fb80e7fce --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types._ + + +/** + * A literal value that is not foldable. Used in expression codegen testing to test code path + * that behave differently based on foldable values. 
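+ *
+ * For illustration, a test can pair it with a plain [[Literal]] to exercise both the foldable
+ * and non-foldable code paths, as the next_day test in DateExpressionsSuite does:
+ * {{{
+ *   checkEvaluation(NextDay(Literal(Date.valueOf("2015-07-23")), NonFoldableLiteral("Mon")),
+ *     DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-27")))
+ * }}}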
+ */ +case class NonFoldableLiteral(value: Any, dataType: DataType) + extends LeafExpression with CodegenFallback { + + override def foldable: Boolean = false + override def nullable: Boolean = true + + override def toString: String = if (value != null) value.toString else "null" + + override def eval(input: InternalRow): Any = value + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + Literal.create(value, dataType).genCode(ctx, ev) + } +} + + +object NonFoldableLiteral { + def apply(value: Any): NonFoldableLiteral = { + val lit = Literal(value) + NonFoldableLiteral(lit.value, lit.dataType) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala similarity index 76% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala index 99e11fd64b2b9..bf1c930c0bd0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala @@ -15,18 +15,20 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.expression +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions. ExpressionEvalHelper -import org.apache.spark.sql.execution.expressions.{SparkPartitionID, MonotonicallyIncreasingID} class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("MonotonicallyIncreasingID") { - checkEvaluation(MonotonicallyIncreasingID(), 0) + checkEvaluation(MonotonicallyIncreasingID(), 0L) } test("SparkPartitionID") { - checkEvaluation(SparkPartitionID, 0) + checkEvaluation(SparkPartitionID(), 0) + } + + test("InputFileName") { + checkEvaluation(InputFileName(), "") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index 0728f6695c39d..ace6c15dc8418 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -30,7 +30,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testFunc(1L, LongType) testFunc(1.0F, FloatType) testFunc(1.0, DoubleType) - testFunc(Decimal(1.5), DecimalType.Unlimited) + testFunc(Decimal(1.5), DecimalType(2, 1)) testFunc(new java.sql.Date(10), DateType) testFunc(new java.sql.Timestamp(10), TimestampType) testFunc("abcd", StringType) @@ -92,7 +92,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val nullOnly = Seq(Literal("x"), Literal.create(null, DoubleType), - Literal.create(null, DecimalType.Unlimited), + Literal.create(null, DecimalType.USER_DEFAULT), Literal(Float.MaxValue), Literal(false)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 698c81ba24482..4a644d136f09c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -20,9 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.DoubleType - class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -30,4 +27,9 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) } + + test("SPARK-9127 codegen with long seed") { + checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001) + checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 0c70a0bf34420..248b0d02fb47d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -332,32 +332,31 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("soundex unit test") { - checkEvaluation(SoundEx(Literal("ZIN")), "Z500", create_row("s1")) - checkEvaluation(SoundEx(Literal("SU")), "S000", create_row("s2")) - checkEvaluation(SoundEx(Literal("")), "", create_row("s3")) - checkEvaluation(SoundEx(Literal.create(null, StringType)), null, create_row("s4")) + checkEvaluation(SoundEx(Literal("ZIN")), "Z500") + checkEvaluation(SoundEx(Literal("SU")), "S000") + checkEvaluation(SoundEx(Literal("")), "") + checkEvaluation(SoundEx(Literal.create(null, StringType)), null) // scalastyle:off // non ascii characters are not allowed in the code, so we disable the scalastyle here. 
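+ // As the cases below show, input that does not start with a letter the soundex code can map
+ // ("测试", "!!") is returned unchanged, while "Tschüss" starts with 'T' and encodes to "T220".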
- checkEvaluation(SoundEx(Literal("测试")), "测试", create_row("s5")) - checkEvaluation(SoundEx(Literal("z測試")), "z測試", create_row("s6")) - checkEvaluation(SoundEx(Literal("Tschüss")), "Tschüss", create_row("s7")) + checkEvaluation(SoundEx(Literal("测试")), "测试") + checkEvaluation(SoundEx(Literal("Tschüss")), "T220") // scalastyle:on - checkEvaluation(SoundEx(Literal("zZ")), "z000", create_row("s8")) - checkEvaluation(SoundEx(Literal("RAGSSEEESSSVEEWE")), "R221", create_row("s9")) - checkEvaluation(SoundEx(Literal("Ashcraft")), "A261", create_row("s10")) - checkEvaluation(SoundEx(Literal("Aswcraft")), "A261", create_row("s11")) - checkEvaluation(SoundEx(Literal("Tymczak")), "T522", create_row("s12")) - checkEvaluation(SoundEx(Literal("Pfister")), "P236", create_row("s13")) - checkEvaluation(SoundEx(Literal("Miller")), "M460", create_row("s14")) - checkEvaluation(SoundEx(Literal("Peterson")), "P362", create_row("s15")) - checkEvaluation(SoundEx(Literal("Peters")), "P362", create_row("s16")) - checkEvaluation(SoundEx(Literal("Auerbach")), "A612", create_row("s17")) - checkEvaluation(SoundEx(Literal("Uhrbach")), "U612", create_row("s18")) - checkEvaluation(SoundEx(Literal("Moskowitz")), "M232", create_row("s19")) - checkEvaluation(SoundEx(Literal("Moskovitz")), "M213", create_row("s20")) - checkEvaluation(SoundEx(Literal("relyheewsgeessg")), "r422", create_row("s21")) - checkEvaluation(SoundEx(Literal("!!")), "!!", create_row("s22")) + checkEvaluation(SoundEx(Literal("zZ")), "Z000", create_row("s8")) + checkEvaluation(SoundEx(Literal("RAGSSEEESSSVEEWE")), "R221") + checkEvaluation(SoundEx(Literal("Ashcraft")), "A261") + checkEvaluation(SoundEx(Literal("Aswcraft")), "A261") + checkEvaluation(SoundEx(Literal("Tymczak")), "T522") + checkEvaluation(SoundEx(Literal("Pfister")), "P236") + checkEvaluation(SoundEx(Literal("Miller")), "M460") + checkEvaluation(SoundEx(Literal("Peterson")), "P362") + checkEvaluation(SoundEx(Literal("Peters")), "P362") + checkEvaluation(SoundEx(Literal("Auerbach")), "A612") + checkEvaluation(SoundEx(Literal("Uhrbach")), "U612") + checkEvaluation(SoundEx(Literal("Moskowitz")), "M232") + checkEvaluation(SoundEx(Literal("Moskovitz")), "M213") + checkEvaluation(SoundEx(Literal("relyheewsgeessg")), "R422") + checkEvaluation(SoundEx(Literal("!!")), "!!") } test("TRIM/LTRIM/RTRIM") { @@ -377,6 +376,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) // scalastyle:on + checkEvaluation(StringTrim(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null) } test("FORMAT") { @@ -420,6 +422,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s3 = 'c.string.at(2) val s4 = 'd.int.at(3) val row1 = create_row("aaads", "aa", "zz", 1) + val row2 = create_row(null, "aa", "zz", 0) + val row3 = create_row("aaads", null, "zz", 0) + val row4 = create_row(null, null, null, 0) checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) @@ -431,6 +436,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringLocate(s2, s1, s4), 2, row1) checkEvaluation(new StringLocate(s3, s1), 0, row1) checkEvaluation(StringLocate(s3, s1, 
Literal.create(null, IntegerType)), 0, row1) + checkEvaluation(new StringLocate(s2, s1), null, row2) + checkEvaluation(new StringLocate(s2, s1), null, row3) + checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4) } test("LPAD/RPAD") { @@ -477,6 +485,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("abccc") checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) checkEvaluation(StringReverse(s), "cccba", row1) + checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { @@ -495,6 +504,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("100-200", "(\\d+)", "num") val row2 = create_row("100-200", "(\\d+)", "###") val row3 = create_row("100-200", "(-)", "###") + val row4 = create_row(null, "(\\d+)", "###") + val row5 = create_row("100-200", null, "###") + val row6 = create_row("100-200", "(-)", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -504,6 +516,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "num-num", row1) checkEvaluation(expr, "###-###", row2) checkEvaluation(expr, "100###200", row3) + checkEvaluation(expr, null, row4) + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) } test("RegexExtract") { @@ -511,6 +526,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) val row3 = create_row("100-200", "(\\d+).*", 1) val row4 = create_row("100-200", "([a-z])", 1) + val row5 = create_row(null, "([a-z])", 1) + val row6 = create_row("100-200", null, 1) + val row7 = create_row("100-200", "([a-z])", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -521,6 +539,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "200", row2) checkEvaluation(expr, "100", row3) checkEvaluation(expr, "", row4) // will not match anything, empty string get + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) + checkEvaluation(expr, null, row7) val expr1 = new RegExpExtract(s, p) checkEvaluation(expr1, "100", row1) @@ -530,11 +551,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) val row1 = create_row("aa2bb3cc", "[1-9]+") + val row2 = create_row(null, "[1-9]+") + val row3 = create_row("aa2bb3cc", null) checkEvaluation( StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2), null, row2) + checkEvaluation(StringSplit(s1, s2), null, row3) } test("length for string / binary") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index c9667e90a0aaa..c6b4c729de2f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -24,9 +24,8 @@ import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} import org.apache.spark.unsafe.types.UTF8String @@ -35,11 +34,12 @@ class UnsafeFixedWidthAggregationMapSuite with Matchers with BeforeAndAfterEach { + import UnsafeFixedWidthAggregationMap._ + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) - private def emptyProjection: Projection = - GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) private def emptyAggregationBuffer: InternalRow = InternalRow(0) + private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes private var memoryManager: TaskMemoryManager = null @@ -54,13 +54,24 @@ class UnsafeFixedWidthAggregationMapSuite } } + test("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, - 1024, // initial capacity + 1024, // initial capacity, + PAGE_SIZE_BYTES, false // disable perf metrics ) assert(!map.iterator().hasNext) @@ -69,11 +80,12 @@ class UnsafeFixedWidthAggregationMapSuite test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 1024, // initial capacity + PAGE_SIZE_BYTES, false // disable perf metrics ) val groupKey = InternalRow(UTF8String.fromString("cats")) @@ -95,11 +107,12 @@ class UnsafeFixedWidthAggregationMapSuite test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 128, // initial capacity + PAGE_SIZE_BYTES, false // disable perf metrics ) val rand = new Random(42) @@ -116,32 +129,4 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } - test("with decimal in the key and values") { - val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) - val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) - val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), - Seq(AttributeReference("price", DecimalType.Unlimited)())) - val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), - memoryManager, - 1, // initial capacity - false // disable perf metrics - ) - - (0 until 100).foreach { i => - val groupKey = InternalRow(Decimal(i % 10)) - val row = 
map.getAggregationBuffer(groupKey) - row.update(0, Decimal(i)) - } - val seenKeys: Set[Int] = map.iterator().asScala.map { entry => - entry.key.getAs[Decimal](0).toInt - }.toSet - seenKeys.size should be (10) - seenKeys should be ((0 until 10).toSet) - - map.free() - } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index dff5faf9f6ec8..a0e1701339ea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,9 +24,8 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -34,28 +33,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) row.setInt(2, 2) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (3 * 8)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) - assert(numBytesWritten === sizeRequired) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) + val unsafeRow: UnsafeRow = converter.apply(row) + assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) - // We can copy UnsafeRows as long as they don't reference ObjectPools val unsafeRowCopy = unsafeRow.copy() assert(unsafeRowCopy.getLong(0) === 0) assert(unsafeRowCopy.getLong(1) === 1) @@ -74,85 +64,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with primitive, string and binary types") { val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (8 * 3) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow( - row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) - assert(numBytesWritten === sizeRequired) - - val unsafeRow = new 
UnsafeRow() - val pool = new ObjectPool(10) - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) - assert(unsafeRow.getLong(0) === 0) - assert(unsafeRow.getString(1) === "Hello") - assert(unsafeRow.get(2) === "World".getBytes) - - unsafeRow.update(1, UTF8String.fromString("World")) - assert(unsafeRow.getString(1) === "World") - assert(pool.size === 0) - unsafeRow.update(1, UTF8String.fromString("Hello World")) - assert(unsafeRow.getString(1) === "Hello World") - assert(pool.size === 1) - - unsafeRow.update(2, "World".getBytes) - assert(unsafeRow.get(2) === "World".getBytes) - assert(pool.size === 1) - unsafeRow.update(2, "Hello World".getBytes) - assert(unsafeRow.get(2) === "Hello World".getBytes) - assert(pool.size === 2) - - // We do not support copy() for UnsafeRows that reference ObjectPools - intercept[UnsupportedOperationException] { - unsafeRow.copy() - } - } - - test("basic conversion with primitive, decimal and array") { - val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) - val converter = new UnsafeRowConverter(fieldTypes) - val row = new SpecificMutableRow(fieldTypes) - row.setLong(0, 0) - row.update(1, Decimal(1)) - row.update(2, Array(2)) - - val pool = new ObjectPool(10) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (8 * 3)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool) - assert(numBytesWritten === sizeRequired) - assert(pool.size === 2) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) assert(unsafeRow.getLong(0) === 0) - assert(unsafeRow.get(1) === Decimal(1)) - assert(unsafeRow.get(2) === Array(2)) - - unsafeRow.update(1, Decimal(2)) - assert(unsafeRow.get(1) === Decimal(2)) - unsafeRow.update(2, Array(3, 4)) - assert(unsafeRow.get(2) === Array(3, 4)) - assert(pool.size === 2) + assert(unsafeRow.getString(1) === "Hello") + assert(unsafeRow.getBinary(2) === "World".getBytes) } test("basic conversion with primitive, string, date and timestamp types") { val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) @@ -160,30 +91,23 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))) row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (8 * 4) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) - assert(numBytesWritten === sizeRequired) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) + assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow 
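// (a Date is stored as an Int day count since 1970-01-01, which the assertion below
// round-trips through DateTimeUtils.toJavaDate)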
assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) // Timestamp is represented as Long in unsafeRow DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be - (Timestamp.valueOf("2015-05-08 08:10:25")) + (Timestamp.valueOf("2015-05-08 08:10:25")) unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be - (Timestamp.valueOf("2015-06-22 08:10:25")) + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -198,31 +122,22 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DoubleType, StringType, BinaryType, - DecimalType.Unlimited, - ArrayType(IntegerType) + DecimalType.USER_DEFAULT + // ArrayType(IntegerType) ) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { val r = new SpecificMutableRow(fieldTypes) - for (i <- 0 to fieldTypes.length - 1) { + for (i <- fieldTypes.indices) { r.setNullAt(i) } r } - val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) - val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired, null) - assert(numBytesWritten === sizeRequired) - - val createdFromNull = new UnsafeRow() - createdFromNull.pointTo( - createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, - sizeRequired, null) - for (i <- 0 to fieldTypes.length - 1) { + val createdFromNull: UnsafeRow = converter.apply(rowWithAllNullColumns) + + for (i <- fieldTypes.indices) { assert(createdFromNull.isNullAt(i)) } assert(createdFromNull.getBoolean(1) === false) @@ -230,12 +145,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getShort(3) === 0) assert(createdFromNull.getInt(4) === 0) assert(createdFromNull.getLong(5) === 0) - assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) - assert(createdFromNull.getString(8) === null) - assert(createdFromNull.get(9) === null) - assert(createdFromNull.get(10) === null) - assert(createdFromNull.get(11) === null) + assert(createdFromNull.getFloat(6) === 0.0f) + assert(createdFromNull.getDouble(7) === 0.0d) + assert(createdFromNull.getUTF8String(8) === null) + assert(createdFromNull.getBinary(9) === null) + assert(createdFromNull.getDecimal(10, 10, 0) === null) + // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by @@ -252,20 +167,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) - r.update(10, Decimal(10)) - r.update(11, Array(11)) + r.setDecimal(10, Decimal(10), 10) + // r.update(11, Array(11)) r } - val pool = new ObjectPool(1) - val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) - converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired, 
pool) - val setToNullAfterCreation = new UnsafeRow() - setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, - sizeRequired, pool) + val setToNullAfterCreation = converter.apply(rowWithNoNullColumns) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) @@ -275,19 +182,16 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) - assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) - assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) + assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) + // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) - for (i <- 0 to fieldTypes.length - 1) { - if (i >= 8) { - setToNullAfterCreation.update(i, null) - } + for (i <- fieldTypes.indices) { setToNullAfterCreation.setNullAt(i) } // There are some garbage left in the var-length area - assert(Arrays.equals(createdFromNullBuffer, - java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8))) + assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes())) setToNullAfterCreation.setNullAt(0) setToNullAfterCreation.setBoolean(1, false) @@ -297,10 +201,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setLong(5, 500) setToNullAfterCreation.setFloat(6, 600) setToNullAfterCreation.setDouble(7, 700) - setToNullAfterCreation.update(8, UTF8String.fromString("hello")) - setToNullAfterCreation.update(9, "world".getBytes) - setToNullAfterCreation.update(10, Decimal(10)) - setToNullAfterCreation.update(11, Array(11)) + // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + // setToNullAfterCreation.update(9, "world".getBytes) + setToNullAfterCreation.setDecimal(10, Decimal(10), 10) + // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -310,10 +214,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) - assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) - assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) - assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) + // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + 
rowWithNoNullColumns.getDecimal(10, 10, 0)) + // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } test("NaN canonicalization") { @@ -327,15 +232,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - val converter = new UnsafeRowConverter(fieldTypes) - val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1)) - val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2)) - converter.writeRow( - row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null) - converter.writeRow( - row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null) - - assert(row1Buffer.toSeq === row2Buffer.toSeq) + val converter = UnsafeProjection.create(fieldTypes) + assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala new file mode 100644 index 0000000000000..46daa3eb8bf80 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + + +class CodeFormatterSuite extends SparkFunSuite { + + def testCase(name: String)(input: String)(expected: String): Unit = { + test(name) { + assert(CodeFormatter.format(input).trim === expected.trim) + } + } + + testCase("basic example") { + """ + |class A { + |blahblah; + |} + """.stripMargin + }{ + """ + |class A { + | blahblah; + |} + """.stripMargin + } + + testCase("nested example") { + """ + |class A { + | if (c) { + |duh; + |} + |} + """.stripMargin + } { + """ + |class A { + | if (c) { + | duh; + | } + |} + """.stripMargin + } + + testCase("single line") { + """ + |class A { + | if (c) {duh;} + |} + """.stripMargin + }{ + """ + |class A { + | if (c) {duh;} + |} + """.stripMargin + } + + testCase("if else on the same line") { + """ + |class A { + | if (c) {duh;} else {boo;} + |} + """.stripMargin + }{ + """ + |class A { + | if (c) {duh;} else {boo;} + |} + """.stripMargin + } + + testCase("function calls") { + """ + |foo( + |a, + |b, + |c) + """.stripMargin + }{ + """ + |foo( + | a, + | b, + | c) + """.stripMargin + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala new file mode 100644 index 0000000000000..2d3f98dbbd3d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * A test suite that makes sure code generation handles expression internally states correctly. + */ +class CodegenExpressionCachingSuite extends SparkFunSuite { + + test("GenerateUnsafeProjection should initialize expressions") { + // Use an Add to wrap two of them together in case we only initialize the top level expressions. 
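    // [Editor's note, not part of the patch] Despite the wording above ("an Add"), the wrapper
    // used below is And; any binary expression would do. The point is that both nondeterministic
    // children sit one level below the root, so these tests fail if code generation initializes
    // only the top-level expression.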
+ val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = UnsafeProjection.create(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateProjection.generate(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateMutableProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateMutableProjection.generate(Seq(expr))() + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GeneratePredicate should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GeneratePredicate.generate(expr) + assert(instance.apply(null) === false) + } + + test("GenerateUnsafeProjection should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = UnsafeProjection.create(Seq(expr1)) + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = UnsafeProjection.create(Seq(expr2)) + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GenerateProjection should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = GenerateProjection.generate(Seq(expr1)) + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GenerateProjection.generate(Seq(expr2)) + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GenerateMutableProjection should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = GenerateMutableProjection.generate(Seq(expr1))() + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GenerateMutableProjection.generate(Seq(expr2))() + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GeneratePredicate should not share expression instances") { + val expr1 = MutableExpression() + val instance1 = GeneratePredicate.generate(expr1) + assert(instance1.apply(null) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GeneratePredicate.generate(expr2) + assert(instance1.apply(null) === false) + assert(instance2.apply(null) === true) + } + +} + +/** + * An expression that's non-deterministic and doesn't support codegen. + */ +case class NondeterministicExpression() + extends LeafExpression with Nondeterministic with CodegenFallback { + override protected def initInternal(): Unit = { } + override protected def evalInternal(input: InternalRow): Any = false + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} + + +/** + * An expression with mutable state so we can change it freely in our test suite. 
+ */ +case class MutableExpression() extends LeafExpression with CodegenFallback { + var mutableState: Boolean = false + override def eval(input: InternalRow): Any = mutableState + + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala similarity index 72% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index df29a62ff0e15..2d080b95b1292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -19,14 +19,17 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class ReplaceDistinctWithAggregateSuite extends PlanTest { +class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil + val batches = Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Nil } test("replace distinct with aggregate") { @@ -39,4 +42,16 @@ class ReplaceDistinctWithAggregateSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("remove literals in grouping expression") { + val input = LocalRelation('a.int, 'b.int) + + val query = + input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(query) + + val correctAnswer = input.groupBy('a)(sum('b)) + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index fab9eb9cd4c9f..60d2bcfe13757 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,47 +19,48 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{TimeZone, Calendar} +import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ class DateTimeUtilsSuite extends SparkFunSuite { private[this] def getInUTCDays(timestamp: Long): Int = { val tz = TimeZone.getDefault - ((timestamp + tz.getOffset(timestamp)) / DateTimeUtils.MILLIS_PER_DAY).toInt + ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) - val ns = DateTimeUtils.fromJavaTimestamp(now) + val ns = fromJavaTimestamp(now) assert(ns % 1000000L 
=== 1) - assert(DateTimeUtils.toJavaTimestamp(ns) === now) + assert(toJavaTimestamp(ns) === now) List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => - val ts = DateTimeUtils.toJavaTimestamp(t) - assert(DateTimeUtils.fromJavaTimestamp(ts) === t) - assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts) + val ts = toJavaTimestamp(t) + assert(fromJavaTimestamp(ts) === t) + assert(toJavaTimestamp(fromJavaTimestamp(ts)) === ts) } } test("us and julian day") { - val (d, ns) = DateTimeUtils.toJulianDay(0) - assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) - assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) - assert(DateTimeUtils.fromJulianDay(d, ns) == 0L) + val (d, ns) = toJulianDay(0) + assert(d === JULIAN_DAY_OF_EPOCH) + assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND) + assert(fromJulianDay(d, ns) == 0L) val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) - val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t)) - val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) + val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) + val t2 = toJavaTimestamp(fromJulianDay(d1, ns1)) assert(t.equals(t2)) } test("SPARK-6785: java date conversion before and after epoch") { def checkFromToJavaDate(d1: Date): Unit = { - val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1)) + val d2 = toJavaDate(fromJavaDate(d1)) assert(d2.toString === d1.toString) } @@ -95,157 +96,156 @@ class DateTimeUtilsSuite extends SparkFunSuite { } test("string to date") { - import DateTimeUtils.millisToDays var c = Calendar.getInstance() c.set(2015, 0, 28, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get === + assert(stringToDate(UTF8String.fromString("2015-01-28")).get === millisToDays(c.getTimeInMillis)) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get === + assert(stringToDate(UTF8String.fromString("2015")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get === + assert(stringToDate(UTF8String.fromString("2015-03")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 ")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 123142")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T123123")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) - 
assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("20150318")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToDate(UTF8String.fromString("20150318")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === + assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === c.getTimeInMillis * 1000) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015")).get === + assert(stringToTimestamp(UTF8String.fromString("2015")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - 
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 456) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get === c.getTimeInMillis * 1000 + 121) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -254,7 +254,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( 
UTF8String.fromString("18:12:15")).get === c.getTimeInMillis * 1000) @@ -263,7 +263,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("T18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -272,93 +272,130 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) c = Calendar.getInstance() c.set(2011, 4, 6, 7, 8, 9) c.set(Calendar.MILLISECOND, 100) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) } test("hours") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 13) + assert(getHours(c.getTimeInMillis * 1000) === 13) c.set(2015, 12, 8, 2, 7, 9) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 2) + assert(getHours(c.getTimeInMillis * 1000) === 2) } test("minutes") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 2) + assert(getMinutes(c.getTimeInMillis * 1000) === 2) c.set(2015, 2, 8, 2, 7, 9) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 7) + assert(getMinutes(c.getTimeInMillis * 1000) === 7) } test("seconds") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 11) + assert(getSeconds(c.getTimeInMillis * 1000) === 11) c.set(2015, 2, 8, 2, 7, 9) - 
assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 9) + assert(getSeconds(c.getTimeInMillis * 1000) === 9) } test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) } test("get year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2015) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2015) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2012) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2012) } test("get quarter") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) c.set(2012, 11, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) } test("get month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 3) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 3) c.set(2012, 11, 18, 0, 0, 0) - assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 12) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 12) } test("get day of month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) c.set(2012, 11, 24, 0, 0, 0) - assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + } + + test("date add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val days1 = millisToDays(c1.getTimeInMillis) + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29) + assert(dateAddMonths(days1, 36) === millisToDays(c2.getTimeInMillis)) + c2.set(1996, 0, 31) + assert(dateAddMonths(days1, -13) === millisToDays(c2.getTimeInMillis)) + } + + test("timestamp add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + c1.set(Calendar.MILLISECOND, 0) + val ts1 = c1.getTimeInMillis * 1000L + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29, 10, 30, 0) + c2.set(Calendar.MILLISECOND, 123) + val ts2 = c2.getTimeInMillis * 1000L + assert(timestampAddInterval(ts1, 36, 123000) === ts2) + } + + test("monthsBetween") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val c2 = Calendar.getInstance() + c2.set(1996, 9, 30, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala deleted file mode 100644 index 94764df4b9cdb..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import org.scalatest.Matchers - -import org.apache.spark.SparkFunSuite - -class ObjectPoolSuite extends SparkFunSuite with Matchers { - - test("pool") { - val pool = new ObjectPool(1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - assert(pool.put(false) === 2) - - assert(pool.get(0) === 1) - assert(pool.get(1) === "hello") - assert(pool.get(2) === false) - assert(pool.size() === 3) - - pool.replace(1, "world") - assert(pool.get(1) === "world") - assert(pool.size() === 3) - } - - test("unique pool") { - val pool = new UniqueObjectPool(1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - - assert(pool.get(0) === 1) - assert(pool.get(1) === "hello") - assert(pool.size() === 2) - - intercept[UnsupportedOperationException] { - pool.replace(1, "world") - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index c6171b7b6916d..1ba290753ce48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -44,7 +44,7 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("float", FloatType) checkDataType("dOUBle", DoubleType) checkDataType("decimal(10, 5)", DecimalType(10, 5)) - checkDataType("decimal", DecimalType.Unlimited) + checkDataType("decimal", DecimalType.USER_DEFAULT) checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) @@ -87,7 +87,7 @@ class DataTypeParserSuite extends SparkFunSuite { StructType( StructField("struct", StructType( - StructField("deciMal", DecimalType.Unlimited, true) :: + StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 14e7b4a9561b6..88b221cd81d74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -185,7 +185,7 @@ class DataTypeSuite extends SparkFunSuite { checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) checkDataTypeJsonRepr(DecimalType(10, 5)) - checkDataTypeJsonRepr(DecimalType.Unlimited) + checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT) checkDataTypeJsonRepr(DateType) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) @@ -219,7 +219,7 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(FloatType, 4) checkDefaultSize(DoubleType, 8) checkDefaultSize(DecimalType(10, 5), 4096) - checkDefaultSize(DecimalType.Unlimited, 4096) + checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096) checkDefaultSize(DateType, 4) checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 32632b5d6e342..0ee9ddac815b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -34,7 +34,7 @@ object DataTypeTestUtils { * decimal types. */ val fractionalTypes: Set[FractionalType] = Set( - DecimalType(precisionInfo = None), + DecimalType.SYSTEM_DEFAULT, DecimalType(2, 1), DoubleType, FloatType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6e2a6525bf17e..b25dcbca82b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -996,7 +996,7 @@ class ColumnName(name: String) extends Column(name) { * Creates a new [[StructField]] of type decimal. * @since 1.3.0 */ - def decimal: StructField = StructField(name, DecimalType.Unlimited) + def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT) /** * Creates a new [[StructField]] of type decimal. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 323ff17357fda..3ea0f9ed3bddd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties +import org.apache.spark.unsafe.types.UTF8String + import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -38,8 +40,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} -import org.apache.spark.sql.execution.datasources.CreateTableUsingAsSelect -import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} +import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation} +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -137,8 +140,7 @@ class DataFrame private[sql]( // happen right away to let these side effects take place eagerly. 
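// [Editor's note, not part of the patch] A hedged sketch of what the DecimalType.Unlimited
// replacements in the hunks above amount to: the single "unlimited" decimal is split into two
// named defaults, one for a user-written "decimal" with no arguments and one for the system-wide
// fallback. The concrete precision/scale values below are assumptions for illustration only; they
// are not stated anywhere in this diff.
import org.apache.spark.sql.types.DecimalType

object DecimalDefaultsSketch {
  // Assumed shape of DecimalType.USER_DEFAULT, i.e. what bare "decimal" parses to in DDL.
  val userDefault = DecimalType(10, 0)
  // Assumed shape of DecimalType.SYSTEM_DEFAULT, i.e. the widest decimal used internally.
  val systemDefault = DecimalType(38, 18)
}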
case _: Command | _: InsertIntoTable | - _: CreateTableUsingAsSelect | - _: WriteToFile => + _: CreateTableUsingAsSelect => LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => queryExecution.analyzed @@ -1282,7 +1284,7 @@ class DataFrame private[sql]( val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList - val ret: Seq[InternalRow] = if (outputCols.nonEmpty) { + val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } @@ -1290,19 +1292,18 @@ class DataFrame private[sql]( val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq // Pivot the data so each summary is one row - row.grouped(outputCols.size).toSeq.zip(statistics).map { - case (aggregation, (statistic, _)) => - InternalRow(statistic :: aggregation.toList: _*) + row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) } } else { // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => InternalRow(name) } + statistics.map { case (name, _) => Row(name) } } // All columns are string type val schema = StructType( StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - LocalRelation(schema, ret) + LocalRelation.fromExternalRows(schema, ret) } /** @@ -1546,6 +1547,21 @@ class DataFrame private[sql]( } } + /** + * Returns a best-effort snapshot of the files that compose this DataFrame. This method simply + * asks each constituent BaseRelation for its respective files and takes the union of all results. + * Depending on the source relations, this may not find all input files. Duplicates are removed. 
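// [Editor's note, not part of the patch] A minimal usage sketch for the inputFiles method defined
// right below. The path is hypothetical, and whether files are reported at all depends on the
// underlying relation; only HadoopFsRelation- and JSONRelation-backed plans are handled here.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object InputFilesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("input-files-sketch").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    // Hypothetical Parquet directory; Parquet is backed by HadoopFsRelation, so paths are reported.
    val df = sqlContext.read.parquet("/data/events")
    df.inputFiles.foreach(println)  // best-effort listing, duplicates removed
    sc.stop()
  }
}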
+ */ + def inputFiles: Array[String] = { + val files: Seq[String] = logicalPlan.collect { + case LogicalRelation(fsBasedRelation: HadoopFsRelation) => + fsBasedRelation.paths.toSeq + case LogicalRelation(jsonRelation: JSONRelation) => + jsonRelation.path.toSeq + }.flatten + files.toSet.toArray + } + //////////////////////////////////////////////////////////////////////////// // for Python API //////////////////////////////////////////////////////////////////////////// @@ -1614,11 +1630,7 @@ class DataFrame private[sql]( */ @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { - if (sqlContext.conf.parquetUseDataSourceApi) { - write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - } else { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } + write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e9d782cdcd667..eb09807f9d9c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,16 +21,16 @@ import java.util.Properties import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types.StructType +import org.apache.spark.{Logging, Partition} /** * :: Experimental :: @@ -259,7 +259,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { }.toArray sqlContext.baseRelationToDataFrame( - new ParquetRelation2( + new ParquetRelation( globbedPaths.map(_.toString), None, None, extraOptions.toMap)(sqlContext)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 587869e57f96e..2e68e358f2f1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.{util => ju, lang => jl} + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -77,7 +81,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * pair frequencies will be returned. * The first column of each row will be the distinct values of `col1` and the column names will * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts - * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts. + * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. * Null elements will be replaced by "null", and back ticks will be dropped from elements if they * exist. 
* @@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), + s"Fractions must be in [0, 1], but got $fractions.") + import org.apache.spark.sql.functions.{rand, udf} + val c = Column(col) + val r = rand(seed) + val f = udf { (stratum: Any, x: Double) => + x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0) + } + df.filter(f(c, r)) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 05da05d7b8050..7e3318cefe62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.Properties import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} @@ -159,15 +160,19 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - val partitions = - partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) - val overwrite = (mode == SaveMode.Overwrite) - df.sqlContext.executePlan(InsertIntoTable( - UnresolvedRelation(Seq(tableName)), - partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, - overwrite, - ifNotExists = false)).toRdd + insertInto(new SqlParser().parseTableIdentifier(tableName)) + } + + private def insertInto(tableIdent: TableIdentifier): Unit = { + val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val overwrite = mode == SaveMode.Overwrite + df.sqlContext.executePlan( + InsertIntoTable( + UnresolvedRelation(tableIdent.toSeq), + partitions.getOrElse(Map.empty[String, Option[String]]), + df.logicalPlan, + overwrite, + ifNotExists = false)).toRdd } /** @@ -183,35 +188,37 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - if (df.sqlContext.catalog.tableExists(tableName :: 
Nil) && mode != SaveMode.Overwrite) { - mode match { - case SaveMode.Ignore => - // Do nothing - - case SaveMode.ErrorIfExists => - throw new AnalysisException(s"Table $tableName already exists.") - - case SaveMode.Append => - // If it is Append, we just ask insertInto to handle it. We will not use insertInto - // to handle saveAsTable with Overwrite because saveAsTable can change the schema of - // the table. But, insertInto with Overwrite requires the schema of data be the same - // the schema of the table. - insertInto(tableName) - - case SaveMode.Overwrite => - throw new UnsupportedOperationException("overwrite mode unsupported.") - } - } else { - val cmd = - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), - mode, - extraOptions.toMap, - df.logicalPlan) - df.sqlContext.executePlan(cmd).toRdd + saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + } + + private def saveAsTable(tableIdent: TableIdentifier): Unit = { + val tableExists = df.sqlContext.catalog.tableExists(tableIdent.toSeq) + + (tableExists, mode) match { + case (true, SaveMode.Ignore) => + // Do nothing + + case (true, SaveMode.ErrorIfExists) => + throw new AnalysisException(s"Table $tableIdent already exists.") + + case (true, SaveMode.Append) => + // If it is Append, we just ask insertInto to handle it. We will not use insertInto + // to handle saveAsTable with Overwrite because saveAsTable can change the schema of + // the table. But, insertInto with Overwrite requires the schema of data be the same + // the schema of the table. + insertInto(tableIdent) + + case _ => + val cmd = + CreateTableUsingAsSelect( + tableIdent.unquotedString, + source, + temporary = false, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df.logicalPlan) + df.sqlContext.executePlan(cmd).toRdd } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 1474b170ba896..6644e85d4a037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -229,7 +229,7 @@ private[spark] object SQLConf { " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( @@ -247,6 +247,13 @@ private[spark] object SQLConf { "otherwise the schema is picked from the summary file or a random data file " + "if no summary file is available.") + val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles", + defaultValue = Some(false), + doc = "When true, we make assumption that all part-files of Parquet are consistent with " + + "summary files and we will ignore them when merging schema. Otherwise, if this is " + + "false, which is the default, we will merge all part-files. 
This should be considered " + + "as expert-only option, and shouldn't be enabled before knowing what it means exactly.") + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", defaultValue = Some(false), doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + @@ -273,16 +280,8 @@ private[spark] object SQLConf { "uncompressed, snappy, gzip, lzo.") val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", - defaultValue = Some(false), - doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default " + - "because of a known bug in Parquet 1.6.0rc3 " + - "(PARQUET-136, https://issues.apache.org/jira/browse/PARQUET-136). However, " + - "if your table doesn't contain any nullable string or binary columns, it's still safe to " + - "turn this feature on.") - - val PARQUET_USE_DATA_SOURCE_API = booleanConf("spark.sql.parquet.useDataSourceApi", defaultValue = Some(true), - doc = "") + doc = "Enables Parquet filter push-down optimization when set to true.") val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( key = "spark.sql.parquet.followParquetFormatSpec", @@ -309,6 +308,11 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") + val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", + defaultValue = Some(false), + doc = "When true, some predicates will be pushed down into the Hive metastore so that " + + "unmatching partitions can be eliminated earlier.") + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), doc = "") @@ -325,7 +329,7 @@ private[spark] object SQLConf { " memory.") val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") // This is only used for the thriftserver @@ -460,12 +464,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) - private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API) - private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) + private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 49bfe74b680af..dbb2a09846548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -798,8 +798,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group ddl_ops * @since 1.3.0 */ - def table(tableName: String): DataFrame = - DataFrame(this, catalog.lookupRelation(Seq(tableName))) + def table(tableName: String): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + DataFrame(this, catalog.lookupRelation(tableIdent.toSeq)) + } /** * Returns a [[DataFrame]] containing names of existing tables in the current database. 
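// [Editor's note, not part of the patch] A sketch of exercising the SQLConf changes above from
// user code. The keys are taken verbatim from this diff; the values are only examples, and for
// the flipped defaults the explicit setConf calls are redundant.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object SqlConfSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sqlconf-sketch").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    // New option: trust Parquet summary files when merging schemas (expert-only, default false).
    sqlContext.setConf("spark.sql.parquet.respectSummaryFiles", "false")
    // New option: push predicates into the Hive metastore so partitions are pruned earlier.
    sqlContext.setConf("spark.sql.hive.metastorePartitionPruning", "true")
    // Defaults flipped to true by this patch; set explicitly here only for clarity.
    sqlContext.setConf("spark.sql.unsafe.enabled", "true")
    sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")
    sqlContext.setConf("spark.sql.parquet.filterPushdown", "true")
    sc.stop()
  }
}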
@@ -870,7 +872,6 @@ class SQLContext(@transient val sparkContext: SparkContext) LeftSemiJoin :: HashJoin :: InMemoryScans :: - ParquetOperations :: BasicOperators :: CartesianProduct :: BroadcastNestedLoopJoin :: Nil) @@ -1115,11 +1116,8 @@ class SQLContext(@transient val sparkContext: SparkContext) def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { emptyDataFrame - } else if (conf.parquetUseDataSourceApi) { - read.parquet(paths : _*) } else { - DataFrame(this, parquet.ParquetRelation( - paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + read.parquet(paths : _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index e59fa6e162900..ea8fce6ca9cf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -21,7 +21,7 @@ import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} import org.apache.spark.sql.execution._ import org.apache.spark.sql.types.StringType @@ -57,6 +57,10 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val AS = Keyword("AS") protected val CACHE = Keyword("CACHE") protected val CLEAR = Keyword("CLEAR") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val FUNCTION = Keyword("FUNCTION") + protected val FUNCTIONS = Keyword("FUNCTIONS") protected val IN = Keyword("IN") protected val LAZY = Keyword("LAZY") protected val SET = Keyword("SET") @@ -65,7 +69,8 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others + override protected lazy val start: Parser[LogicalPlan] = + cache | uncache | set | show | desc | others private lazy val cache: Parser[LogicalPlan] = CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { @@ -85,9 +90,24 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case input => SetCommandParser(input) } + // It can be the following patterns: + // SHOW FUNCTIONS; + // SHOW FUNCTIONS mydb.func1; + // SHOW FUNCTIONS func1; + // SHOW FUNCTIONS `mydb.a`.`func1.aa`; private lazy val show: Parser[LogicalPlan] = - SHOW ~> TABLES ~ (IN ~> ident).? ^^ { - case _ ~ dbName => ShowTablesCommand(dbName) + ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ { + case _ ~ dbName => ShowTablesCommand(dbName) + } + | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { + case Some(f) => ShowFunctions(f._1, Some(f._2)) + case None => ShowFunctions(None, None) + } + ) + + private lazy val desc: Parser[LogicalPlan] = + DESCRIBE ~ FUNCTION ~> EXTENDED.? 
~ (ident | stringLit) ^^ { + case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) } private lazy val others: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala index 5b872f5e3eecd..0d4e30f29255e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Expression} -import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 931469bed634a..4c29a093218a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -41,9 +41,9 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( +private[sql] abstract class BasicColumnAccessor[JvmType]( protected val buffer: ByteBuffer, - protected val columnType: ColumnType[T, JvmType]) + protected val columnType: ColumnType[JvmType]) extends ColumnAccessor { protected def initialize() {} @@ -93,14 +93,14 @@ private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) + extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) -private[sql] class GenericColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) +private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType) + extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType)) with NullableColumnAccessor private[sql] class DateColumnAccessor(buffer: ByteBuffer) @@ -131,7 +131,7 @@ private[sql] object ColumnAccessor { case BinaryType => new BinaryColumnAccessor(dup) case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnAccessor(dup, precision, scale) - case _ => new GenericColumnAccessor(dup) + case other => new GenericColumnAccessor(dup, other) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 087c52239713d..1620fc401ba6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -46,9 +46,9 @@ private[sql] trait ColumnBuilder { def build(): ByteBuffer } -private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( +private[sql] class BasicColumnBuilder[JvmType]( val columnStats: ColumnStats, - val columnType: 
ColumnType[T, JvmType]) + val columnType: ColumnType[JvmType]) extends ColumnBuilder { protected var columnName: String = _ @@ -78,16 +78,16 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( } } -private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( +private[sql] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, - columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](columnStats, columnType) + columnType: ColumnType[JvmType]) + extends BasicColumnBuilder[JvmType](columnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType) + extends BasicColumnBuilder[T#InternalType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] @@ -114,12 +114,12 @@ private[sql] class FixedDecimalColumnBuilder( precision: Int, scale: Int) extends NativeColumnBuilder( - new FixedDecimalColumnStats, + new FixedDecimalColumnStats(precision, scale), FIXED_DECIMAL(precision, scale)) // TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) +private[sql] class GenericColumnBuilder(dataType: DataType) + extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType)) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) @@ -164,7 +164,7 @@ private[sql] object ColumnBuilder { case BinaryType => new BinaryColumnBuilder case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnBuilder(precision, scale) - case _ => new GenericColumnBuilder + case other => new GenericColumnBuilder(other) } builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 00374d1fa3ef1..af1a8ecca9b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -211,7 +211,7 @@ private[sql] class StringColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[UTF8String] + val value = row.getUTF8String(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) @@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { +private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { protected var upper: Decimal = null protected var lower: Decimal = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Decimal] + val value = row.getDecimal(ordinal, precision, scale) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += 
FIXED_DECIMAL.defaultSize @@ -252,11 +252,13 @@ private[sql] class FixedDecimalColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class GenericColumnStats extends ColumnStats { +private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { + val columnType = GENERIC(dataType) + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += GENERIC.actualSize(row, ordinal) + sizeInBytes += columnType.actualSize(row, ordinal) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index fc72360c88fe1..30f8fe320db3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -31,14 +31,18 @@ import org.apache.spark.unsafe.types.UTF8String * An abstract class that represents type of a column. Used to append/extract Java objects into/from * the underlying [[ByteBuffer]] of a column. * - * @param typeId A unique ID representing the type. - * @param defaultSize Default size in bytes for one element of type T (e.g. 4 for `Int`). - * @tparam T Scala data type for the column. * @tparam JvmType Underlying Java type to represent the elements. */ -private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( - val typeId: Int, - val defaultSize: Int) { +private[sql] sealed abstract class ColumnType[JvmType] { + + // The catalyst data type of this column. + def dataType: DataType + + // A unique ID representing the type. + def typeId: Int + + // Default size in bytes for one element of type T (e.g. 4 for `Int`). + def defaultSize: Int /** * Extracts a value out of the buffer at the buffer's current position. @@ -90,7 +94,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * boxing/unboxing costs whenever possible. */ def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to(toOrdinal) = from(fromOrdinal) + to.update(toOrdinal, from.get(fromOrdinal, dataType)) } /** @@ -103,9 +107,9 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( private[sql] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, - typeId: Int, - defaultSize: Int) - extends ColumnType[T, T#InternalType](typeId, defaultSize) { + val typeId: Int, + val defaultSize: Int) + extends ColumnType[T#InternalType] { /** * Scala TypeTag. Can be used to create primitive arrays and hash tables. 
@@ -309,7 +313,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { override def actualSize(row: InternalRow, ordinal: Int): Int = { - row.getString(ordinal).getBytes("utf-8").length + 4 + row.getUTF8String(ordinal).numBytes() + 4 } override def append(v: UTF8String, buffer: ByteBuffer): Unit = { @@ -329,11 +333,11 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } override def getField(row: InternalRow, ordinal: Int): UTF8String = { - row(ordinal).asInstanceOf[UTF8String] + row.getUTF8String(ordinal) } override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.update(toOrdinal, from(fromOrdinal)) + to.update(toOrdinal, from.getUTF8String(fromOrdinal)) } } @@ -347,7 +351,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { } override def getField(row: InternalRow, ordinal: Int): Int = { - row(ordinal).asInstanceOf[Int] + row.getInt(ordinal) } def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { @@ -365,7 +369,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { } override def getField(row: InternalRow, ordinal: Int): Long = { - row(ordinal).asInstanceOf[Long] + row.getLong(ordinal) } override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { @@ -375,7 +379,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) extends NativeColumnType( - DecimalType(Some(PrecisionInfo(precision, scale))), + DecimalType(precision, scale), 10, FIXED_DECIMAL.defaultSize) { @@ -388,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) } override def getField(row: InternalRow, ordinal: Int): Decimal = { - row(ordinal).asInstanceOf[Decimal] + row.getDecimal(ordinal, precision, scale) } override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { @@ -400,10 +404,10 @@ private[sql] object FIXED_DECIMAL { val defaultSize = 8 } -private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( - typeId: Int, - defaultSize: Int) - extends ColumnType[T, Array[Byte]](typeId, defaultSize) { +private[sql] sealed abstract class ByteArrayColumnType( + val typeId: Int, + val defaultSize: Int) + extends ColumnType[Array[Byte]] { override def actualSize(row: InternalRow, ordinal: Int): Int = { getField(row, ordinal).length + 4 @@ -421,31 +425,34 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) { +private[sql] object BINARY extends ByteArrayColumnType(11, 16) { + + def dataType: DataType = BinaryType + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { - row(ordinal) = value + row.update(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - row(ordinal).asInstanceOf[Array[Byte]] + row.getBinary(ordinal) } } // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { +private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(12, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { - row(ordinal) = SparkSqlSerializer.deserialize[Any](value) + row.update(ordinal, SparkSqlSerializer.deserialize[Any](value)) } override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - SparkSqlSerializer.serialize(row(ordinal)) + SparkSqlSerializer.serialize(row.get(ordinal, dataType)) } } private[sql] object ColumnType { - def apply(dataType: DataType): ColumnType[_, _] = { + def apply(dataType: DataType): ColumnType[_] = { dataType match { case BooleanType => BOOLEAN case ByteType => BYTE @@ -460,7 +467,7 @@ private[sql] object ColumnType { case BinaryType => BINARY case DecimalType.Fixed(precision, scale) if precision < 19 => FIXED_DECIMAL(precision, scale) - case _ => GENERIC + case other => GENERIC(other) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 38720968c1313..5d5b0697d7016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -134,13 +134,13 @@ private[sql] case class InMemoryRelation( // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat // hard to decipher. assert( - row.size == columnBuilders.size, - s"""Row column number mismatch, expected ${output.size} columns, but got ${row.size}. - |Row content: $row - """.stripMargin) + row.numFields == columnBuilders.size, + s"Row column number mismatch, expected ${output.size} columns, " + + s"but got ${row.numFields}." 
+ + s"\nRow content: $row") var i = 0 - while (i < row.length) { + while (i < row.numFields) { columnBuilders(i).appendFrom(row, i) i += 1 } @@ -304,7 +304,7 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[InternalRow] { - private[this] val rowLen = nextRow.length + private[this] val rowLen = nextRow.numFields override def next(): InternalRow = { var i = 0 while (i < rowLen) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 4eaec6d853d4d..b1ef9b2ef7849 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -46,7 +46,7 @@ private[sql] trait Decoder[T <: AtomicType] { private[sql] trait CompressionScheme { def typeId: Int - def supports(columnType: ColumnType[_, _]): Boolean + def supports(columnType: ColumnType[_]): Boolean def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 5abc1259a19ab..c91d960a0932b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils private[sql] case object PassThrough extends CompressionScheme { override val typeId = 0 - override def supports(columnType: ColumnType[_, _]): Boolean = true + override def supports(columnType: ColumnType[_]): Boolean = true override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) @@ -78,7 +78,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { new this.Decoder(buffer, columnType) } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { + override def supports(columnType: ColumnType[_]): Boolean = columnType match { case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true case _ => false } @@ -128,7 +128,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { while (from.hasRemaining) { columnType.extract(from, value, 0) - if (value(0) == currentValue(0)) { + if (value.get(0, columnType.dataType) == currentValue.get(0, columnType.dataType)) { currentRun += 1 } else { // Writes current run @@ -189,7 +189,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { new this.Encoder[T](columnType) } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { + override def supports(columnType: ColumnType[_]): Boolean = columnType match { case INT | LONG | STRING => true case _ => false } @@ -304,7 +304,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { (new this.Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType == BOOLEAN + override def supports(columnType: ColumnType[_]): Boolean = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 @@ -392,7 +392,7 @@ private[sql] case object IntDelta extends CompressionScheme { (new 
Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType == INT + override def supports(columnType: ColumnType[_]): Boolean = columnType == INT class Encoder extends compression.Encoder[IntegerType.type] { protected var _compressedSize: Int = 0 @@ -472,7 +472,7 @@ private[sql] case object LongDelta extends CompressionScheme { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType == LONG + override def supports(columnType: ColumnType[_]): Boolean = columnType == LONG class Encoder extends compression.Encoder[LongType.type] { protected var _compressedSize: Int = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c2c945321db95..e8c6a0f8f801d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -172,7 +172,7 @@ case class Aggregate( private[this] val resultProjection = new InterpretedMutableProjection( resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow4 + private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a4b38d364d54a..d3e5c378d037d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -84,7 +84,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { } /** - * Caches the data produced by the logical representation of the given schema rdd. Unlike + * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing * the in-memory columnar representation of the underlying table is expensive. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d31e265a293e9..70e5031fb63c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = { + // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to + // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to + // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. + !newPartitioning.isInstanceOf[RangePartitioning] + } /** * Determines whether records must be defensively copied before being sent to the shuffle. @@ -224,13 +229,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // compatible. // TODO: ASSUMES TRANSITIVITY? 
def compatible: Boolean = - !operator.children + operator.children .map(_.outputPartitioning) .sliding(2) - .map { + .forall { case Seq(a) => true case Seq(a, b) => a.compatibleWith(b) - }.exists(!_) + } // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 0e63f2fe29cb3..d851eae3fcc71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.execution -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.types._ case class AggregateEvaluation( @@ -92,8 +92,8 @@ case class GeneratedAggregate( case s @ Sum(expr) => val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 10, s) case _ => expr.dataType } @@ -108,7 +108,7 @@ case class GeneratedAggregate( Add( Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType) - ) :: currentSum :: zero :: Nil) + ) :: currentSum :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -118,45 +118,6 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case cs @ CombineSum(expr) => - val calcType = - expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited - case _ => - expr.dataType - } - - val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal.create(null, calcType) - - // Coalesce avoids double calculation... - // but really, common sub expression elimination would be better.... 
- val zero = Cast(Literal(0), calcType) - // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its - // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val actualExpr = expr match { - case UnscaledValue(e) => e - case _ => expr - } - // partial sum result can be null only when no input rows present - val updateFunction = If( - IsNotNull(actualExpr), - Coalesce( - Add( - Coalesce(currentSum :: zero :: Nil), - Cast(expr, calcType)) :: currentSum :: zero :: Nil), - currentSum) - - val result = - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(currentSum, cs.dataType) - case _ => currentSum - } - - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) @@ -239,6 +200,11 @@ case class GeneratedAggregate( StructType(fields) } + val schemaSupportsUnsafe: Boolean = { + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupKeySchema) + } + child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -264,7 +230,7 @@ case class GeneratedAggregate( namedGroups.map(_._2) ++ computationSchema) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") - val joinedRow = new JoinedRow3 + val joinedRow = new JoinedRow if (!iter.hasNext) { // This is an empty input, so return early so that we do not allocate data structures @@ -290,15 +256,18 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled) { + + } else if (unsafeEnabled && schemaSupportsUnsafe) { assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") + val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggregationBufferSchema), + newAggregationBuffer(EmptyRow), + aggregationBufferSchema, + groupKeySchema, TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity + pageSizeBytes, false // disable tracking of performance metrics ) @@ -331,6 +300,9 @@ case class GeneratedAggregate( } } } else { + if (unsafeEnabled) { + log.info("Not using Unsafe-based aggregator because it is not supported for this schema") + } val buffers = new java.util.HashMap[InternalRow, MutableRow]() var currentRow: InternalRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index cd341180b6100..34e926e4582be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -34,13 +34,11 @@ private[sql] case class LocalTableScan( protected override def doExecute(): RDD[InternalRow] = rdd - override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]).toArray } - override def executeTake(limit: Int): Array[Row] = { val converter = 
CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6101..a2145b185ce90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -37,61 +35,15 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { - case StringType => PrefixComparators.STRING - case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case StringType if sortOrder.isAscending => PrefixComparators.STRING + case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC + case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending => + PrefixComparators.LONG + case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending => + PrefixComparators.LONG_DESC + case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE + case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC case _ => NoOpPrefixComparator } } - - def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } - case BooleanType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 - else 0 - } - case ByteType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] - } - case ShortType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] - } - case IntegerType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] - } - case LongType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] - } - case FloatType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) - } - case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else 
PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) - } - case _ => (row: InternalRow) => 0L - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c87e2064a8f33..e5bbd0aaed0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -25,7 +25,6 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.serializer._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ @@ -53,7 +52,7 @@ private[sql] class Serializer2SerializationStream( private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[Row, Row]] + val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] writeKey(kv._1) writeValue(kv._2) @@ -66,7 +65,7 @@ private[sql] class Serializer2SerializationStream( } override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[Row]) + writeRowFunc(t.asInstanceOf[InternalRow]) this } @@ -205,8 +204,9 @@ private[sql] object SparkSqlSerializer2 { /** * The util function to create the serialization function based on the given schema. */ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { - (row: Row) => + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) + : InternalRow => Unit = { + (row: InternalRow) => // If the schema is null, the returned function does nothing when it get called. if (schema != null) { var i = 0 @@ -278,7 +278,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[UTF8String](i).getBytes + val bytes = row.getUTF8String(i).getBytes out.writeInt(bytes.length) out.write(bytes) } @@ -288,7 +288,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[Array[Byte]](i) + val bytes = row.getBinary(i) out.writeInt(bytes.length) out.write(bytes) } @@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.apply(i).asInstanceOf[Decimal] + val value = row.getDecimal(i, decimal.precision, decimal.scale) val javaBigDecimal = value.toJavaBigDecimal // First, write out the unscaled value. 
val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f54aa2027f6a6..03d24a88d4ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,28 +17,26 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} -import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + case ExtractEquiJoinKeys( + LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastLeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys @@ -91,6 +89,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } + private[this] def isValidSort( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression]): Boolean = { + leftKeys.zip(rightKeys).forall { keys => + (keys._1.dataType, keys._2.dataType) match { + case (l: AtomicType, r: AtomicType) => true + case (NullType, NullType) => true + case _ => false + } + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) @@ -101,7 +111,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // If the sort merge join option is set, we want to use sort merge join prior to hashjoin // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled => + if sqlContext.conf.sortMergeJoinEnabled && isValidSort(leftKeys, rightKeys) => val mergeJoin = 
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil @@ -183,19 +193,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { - aggregate.Utils.tryConvert( - plan, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { + case a: logical.Aggregate => + if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { + a.newAggregation.isDefined + } else { + Utils.checkInvalidAggregateFunction2(a) + false + } + case _ => false } - def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { - case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { + case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && - Seq(IntegerType, LongType).contains(exprs.head.dataType) => false - case _ => true + Seq(IntegerType, LongType).contains(exprs.head.dataType) => true + case _ => false } def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = @@ -207,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate => - val converted = - aggregate.Utils.tryConvert( - p, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled) + case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && + sqlContext.conf.codegenEnabled => + val converted = p.newAggregation converted match { case None => Nil // Cannot convert to new aggregation code path. case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => @@ -306,57 +317,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object ParquetOperations extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // TODO: need to support writing to other types of files. Unify the below code paths. - case logical.WriteToFile(path, child) => - val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) - // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil - case logical.InsertIntoTable( - table: ParquetRelation, partition, child, overwrite, ifNotExists) => - InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil - case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => - val partitionColNames = relation.partitioningAttributes.map(_.name).toSet - val filtersToPush = filters.filter { pred => - val referencedColNames = pred.references.map(_.name).toSet - referencedColNames.intersect(partitionColNames).isEmpty - } - val prunePushedDownFilters = - if (sqlContext.conf.parquetFilterPushDown) { - (predicates: Seq[Expression]) => { - // Note: filters cannot be pushed down to Parquet if they contain more complex - // expressions than simple "Attribute cmp Literal" comparisons. Here we remove all - // filters that have been pushed down. 
Note that a predicate such as "(A AND B) OR C" - // can result in "A OR C" being pushed down. Here we are conservative in the sense - // that even if "A" was pushed and we check for "A AND B" we still want to keep - // "A AND B" in the higher-level filter, not just "B". - predicates.map(p => p -> ParquetFilters.createFilter(p)).collect { - case (predicate, None) => predicate - // Filter needs to be applied above when it contains partitioning - // columns - case (predicate, _) - if !predicate.references.map(_.name).toSet.intersect(partitionColNames).isEmpty => - predicate - } - } - } else { - identity[Seq[Expression]] _ - } - pruneFilterProject( - projectList, - filters, - prunePushedDownFilters, - ParquetTableScan( - _, - relation, - if (sqlContext.conf.parquetFilterPushDown) filtersToPush else Nil)) :: Nil - - case _ => Nil - } - } - object InMemoryScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => @@ -380,8 +340,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * if necessary. */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { @@ -404,23 +365,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Sort(sortExprs, global, child) => getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => - execution.Project(projectList, planLater(child)) :: Nil + // If unsafe mode is enabled and we support these data types in Unsafe, use the + // Tungsten project. Otherwise, use the normal project. + if (sqlContext.conf.unsafeEnabled && + UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { + execution.TungstenProject(projectList, planLater(child)) :: Nil + } else { + execution.Project(projectList, planLater(child)) :: Nil + } case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = - aggregate.Utils.tryConvert( - a, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined - if (useNewAggregation) { + val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled + if (useNewAggregation && a.newAggregation.isDefined) { // If this logical.Aggregate can be planned to use new aggregation code path // (i.e. it can be planned by the Strategy Aggregation), we will not use the old // aggregation code path. 
Nil } else { + Utils.checkInvalidAggregateFunction2(a) execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } @@ -480,6 +445,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ExecutedCommand( RunnableDescribeCommand(resultPlan, describe.output, isExtended)) :: Nil + case logical.ShowFunctions(db, pattern) => ExecutedCommand(ShowFunctions(db, pattern)) :: Nil + + case logical.DescribeFunction(function, extended) => + ExecutedCommand(DescribeFunction(function, extended)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 318550e5ed899..16498da080c88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -37,9 +37,6 @@ import org.apache.spark.unsafe.PlatformDependent * Note that this serializer implements only the [[Serializer]] methods that are used during * shuffle, so certain [[SerializerInstance]] methods will throw UnsupportedOperationException. * - * This serializer does not support UnsafeRows that use - * [[org.apache.spark.sql.catalyst.util.ObjectPool]]. - * * @param numFields the number of fields in the row being serialized. */ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable { @@ -65,7 +62,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool") dOut.writeInt(row.getSizeInBytes) row.writeToStream(out, writeBuffer) this @@ -118,7 +114,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(in, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream val _rowTuple = rowTuple @@ -152,7 +148,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(in, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index de04132eb1104..91c8a02e2b5bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -298,7 +298,7 @@ case class Window( var rowsSize = 0 override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable - val join = new JoinedRow6 + val join = new JoinedRow val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. 
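A brief usage sketch, not part of the patch, showing the SHOW FUNCTIONS and DESCRIBE FUNCTION statements that the SparkSQLParser rules and the SparkStrategies planner cases above wire up. It assumes an already-constructed SQLContext named `sqlContext` and uses only a built-in function (`upper`) as the lookup target; both names are illustrative assumptions, not defined by this patch.

// Hypothetical illustration only; `sqlContext` and the chosen function name are assumptions.
sqlContext.sql("SHOW FUNCTIONS").show()                    // list all registered functions
sqlContext.sql("SHOW FUNCTIONS upper").show()              // look up a single function by name
sqlContext.sql("DESCRIBE FUNCTION upper").show()           // short usage string for the function
sqlContext.sql("DESCRIBE FUNCTION EXTENDED upper").show()  // usage plus extended description

Each statement is parsed by the new `show`/`desc` parsers into a ShowFunctions or DescribeFunction logical plan, which the BasicOperators strategy then turns into an ExecutedCommand.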
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala index 0c9082897f390..98538c462bc89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -72,8 +72,10 @@ case class Aggregate2Sort( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => if (aggregateExpressions.length == 0) { - new GroupingIterator( + new FinalSortAggregationIterator( groupingExpressions, + Nil, + Nil, resultExpressions, newMutableProjection, child.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala index ce1cbdc9cb090..2ca0cb82c1aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -41,7 +41,8 @@ private[sql] abstract class SortAggregationIterator( /////////////////////////////////////////////////////////////////////////// protected val aggregateFunctions: Array[AggregateFunction2] = { - var bufferOffset = initialBufferOffset + var mutableBufferOffset = 0 + var inputBufferOffset: Int = initialInputBufferOffset val functions = new Array[AggregateFunction2](aggregateExpressions.length) var i = 0 while (i < aggregateExpressions.length) { @@ -54,26 +55,24 @@ private[sql] abstract class SortAggregationIterator( // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. BindReferences.bindReference(func, inputAttributes) - case _ => func + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + func.inputBufferOffset = inputBufferOffset + inputBufferOffset += func.bufferSchema.length + func } - // Set bufferOffset for this function. It is important that setting bufferOffset - // happens after all potential bindReference operations because bindReference - // will create a new instance of the function. - funcWithBoundReferences.bufferOffset = bufferOffset - bufferOffset += funcWithBoundReferences.bufferSchema.length + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset + mutableBufferOffset += funcWithBoundReferences.bufferSchema.length functions(i) = funcWithBoundReferences i += 1 } functions } - // All non-algebraic aggregate functions. - protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { - aggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } - // Positions of those non-algebraic aggregate functions in aggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are non-algebraic aggregate functions. @@ -91,6 +90,10 @@ private[sql] abstract class SortAggregationIterator( positions.toArray } + // All non-algebraic aggregate functions. 
+ protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions) + // This is used to project expressions for the grouping expressions. protected val groupGenerator = newMutableProjection(groupingExpressions, inputAttributes)() @@ -100,25 +103,24 @@ private[sql] abstract class SortAggregationIterator( // The number of elements of the underlying buffer of this operator. // All aggregate functions are sharing this underlying buffer and they find their // buffer values through bufferOffset. - var size = initialBufferOffset - var i = 0 - while (i < aggregateFunctions.length) { - size += aggregateFunctions(i).bufferSchema.length - i += 1 - } - new GenericMutableRow(size) + // var size = 0 + // var i = 0 + // while (i < aggregateFunctions.length) { + // size += aggregateFunctions(i).bufferSchema.length + // i += 1 + // } + new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum) } - protected val joinedRow = new JoinedRow4 - - protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) + protected val joinedRow = new JoinedRow // This projection is used to initialize buffer values for all AlgebraicAggregates. protected val algebraicInitialProjection = { - val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val initExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.initialValues case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } + newMutableProjection(initExpressions, Nil)().target(buffer) } @@ -135,10 +137,6 @@ private[sql] abstract class SortAggregationIterator( // Indicates if we has new group of rows to process. protected var hasNewGroup: Boolean = true - /////////////////////////////////////////////////////////////////////////// - // Private methods - /////////////////////////////////////////////////////////////////////////// - /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(): Unit = { algebraicInitialProjection(EmptyRow) @@ -163,6 +161,10 @@ private[sql] abstract class SortAggregationIterator( } } + /////////////////////////////////////////////////////////////////////////// + // Private methods + /////////////////////////////////////////////////////////////////////////// + /** Processes rows in the current group. It will stop when it find a new group. */ private def processCurrentGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -179,8 +181,6 @@ private[sql] abstract class SortAggregationIterator( // For the below compare method, we do not need to make a copy of groupingKey. val groupingKey = groupGenerator(currentRow) // Check if the current row belongs the current input row. - currentGroupingKey.equals(groupingKey) - if (currentGroupingKey == groupingKey) { processRow(currentRow) } else { @@ -223,10 +223,13 @@ private[sql] abstract class SortAggregationIterator( // Methods that need to be implemented /////////////////////////////////////////////////////////////////////////// - protected def initialBufferOffset: Int + /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */ + protected def initialInputBufferOffset: Int + /** The function used to process an input row. */ protected def processRow(row: InternalRow): Unit + /** The function used to generate the result row. 
*/ protected def generateOutput(): InternalRow /////////////////////////////////////////////////////////////////////////// @@ -236,37 +239,6 @@ private[sql] abstract class SortAggregationIterator( initialize() } -/** - * An iterator only used to group input rows according to values of `groupingExpressions`. - * It assumes that input rows are already grouped by values of `groupingExpressions`. - */ -class GroupingIterator( - groupingExpressions: Seq[NamedExpression], - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - Nil, - newMutableProjection, - inputAttributes, - inputIter) { - - private val resultProjection = - newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))() - - override protected def initialBufferOffset: Int = 0 - - override protected def processRow(row: InternalRow): Unit = { - // Since we only do grouping, there is nothing to do at here. - } - - override protected def generateOutput(): InternalRow = { - resultProjection(currentGroupingKey) - } -} - /** * An iterator used to do partial aggregations (for those aggregate functions with mode Partial). * It assumes that input rows are already grouped by values of `groupingExpressions`. @@ -288,10 +260,7 @@ class PartialSortAggregationIterator( // This projection is used to update buffer values for all AlgebraicAggregates. private val algebraicUpdateProjection = { - val bufferSchema = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } + val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) @@ -299,7 +268,7 @@ class PartialSortAggregationIterator( newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) } - override protected def initialBufferOffset: Int = 0 + override protected def initialInputBufferOffset: Int = 0 override protected def processRow(row: InternalRow): Unit = { // Process all algebraic aggregate functions. @@ -326,11 +295,7 @@ class PartialSortAggregationIterator( * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| * * The format of its internal buffer is: - * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN| - * Every placeholder is for a grouping expression. - * The actual buffers are stored after placeholderN. - * The reason that we have placeholders at here is to make our underlying buffer have the same - * length with a input row. + * |aggregationBuffer1|...|aggregationBufferN| * * The format of its output rows is: * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| @@ -348,38 +313,21 @@ class PartialMergeSortAggregationIterator( inputAttributes, inputIter) { - private val placeholderAttribtues = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // This projection is used to merge buffer values for all AlgebraicAggregates. 
private val algebraicMergeProjection = { - val bufferSchemata = - placeholderAttribtues ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ placeholderAttribtues ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } - val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val mergeInputSchema = + aggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingExpressions.map(_.toAttribute) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } - // This projection is used to extract aggregation buffers from the underlying buffer. - // We need it because the underlying buffer has placeholders at its beginning. - private val extractsBufferValues = { - val expressions = aggregateFunctions.flatMap { - case agg => agg.bufferAttributes - } - - newMutableProjection(expressions, inputAttributes)() - } - - override protected def initialBufferOffset: Int = groupingExpressions.length + override protected def initialInputBufferOffset: Int = groupingExpressions.length override protected def processRow(row: InternalRow): Unit = { // Process all algebraic aggregate functions. @@ -394,7 +342,7 @@ class PartialMergeSortAggregationIterator( override protected def generateOutput(): InternalRow = { // We output grouping expressions and aggregation buffers. - joinedRow(currentGroupingKey, extractsBufferValues(buffer)) + joinedRow(currentGroupingKey, buffer).copy() } } @@ -406,11 +354,7 @@ class PartialMergeSortAggregationIterator( * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| * * The format of its internal buffer is: - * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN| - * Every placeholder is for a grouping expression. - * The actual buffers are stored after placeholderN. - * The reason that we have placeholders at here is to make our underlying buffer have the same - * length with a input row. + * |aggregationBuffer1|...|aggregationBufferN| * * The format of its output rows is represented by the schema of `resultExpressions`. */ @@ -438,37 +382,23 @@ class FinalSortAggregationIterator( newMutableProjection( resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() - private val offsetAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // This projection is used to merge buffer values for all AlgebraicAggregates. 
private val algebraicMergeProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } - val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val mergeInputSchema = + aggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingExpressions.map(_.toAttribute) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -477,7 +407,7 @@ class FinalSortAggregationIterator( newMutableProjection(evalExpressions, bufferSchemata)() } - override protected def initialBufferOffset: Int = groupingExpressions.length + override protected def initialInputBufferOffset: Int = groupingExpressions.length override def initialize(): Unit = { if (inputIter.hasNext) { @@ -494,7 +424,10 @@ class FinalSortAggregationIterator( // Right now, the buffer only contains initial buffer values. Because // merging two buffers with initial values will generate a row that // still store initial values. We set the currentRow as the copy of the current buffer. - val currentRow = buffer.copy() + // Because input aggregation buffer has initialInputBufferOffset extra values at the + // beginning, we create a dummy row for this part. + val currentRow = + joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() nextGroupingKey = groupGenerator(currentRow).copy() firstRowInNextGroup = currentRow } else { @@ -541,18 +474,15 @@ class FinalSortAggregationIterator( * Final mode. * * The format of its internal buffer is: - * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)| - * The first N placeholders represent slots of grouping expressions. - * Then, next M placeholders represent slots of col1 to colM. + * |aggregationBuffer1|...|aggregationBuffer(N+M)| * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode - * Complete. The reason that we have placeholders at here is to make our underlying buffer - * have the same length with a input row. + * Complete. * * The format of its output rows is represented by the schema of `resultExpressions`. 
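Aside: the iterators above drop the placeholder slots from their internal buffers. Grouping columns now appear only in the incoming rows (the first `initialInputBufferOffset` fields), while the internal buffer stores nothing but aggregation values, and the merge projection is bound to `bufferAttributes ++ grouping attributes ++ cloneBufferAttributes`. A framework-free sketch of that arrangement, using a hypothetical `(sum, count)` buffer and a made-up row layout rather than real `InternalRow`s:

```scala
// Simplified model of the new layout, with no placeholder slots.
// Incoming rows from the partial side look like: |groupingCol|sum|count|
// The internal buffer holds only:                |sum|count|
object CompactBufferMergeSketch {
  def main(args: Array[String]): Unit = {
    val initialInputBufferOffset = 1              // one grouping column precedes the buffer
    val buffer = Array[Any](0.0, 0L)              // compact internal buffer: sum, count

    val partialRows = Seq(                        // |group|sum|count|
      Array[Any]("a", 3.0, 2L),
      Array[Any]("a", 7.0, 3L))

    for (row <- partialRows) {
      // Read the incoming buffer at the offset instead of padding our own buffer.
      val incomingSum   = row(initialInputBufferOffset).asInstanceOf[Double]
      val incomingCount = row(initialInputBufferOffset + 1).asInstanceOf[Long]
      buffer(0) = buffer(0).asInstanceOf[Double] + incomingSum
      buffer(1) = buffer(1).asInstanceOf[Long] + incomingCount
    }

    println(buffer.mkString("|"))                 // 10.0|5
  }
}
```

The real iterators achieve the same effect through `newMutableProjection` over the concatenated schema; the sketch only illustrates why the internal buffer no longer needs padding to match the input-row length.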
*/ class FinalAndCompleteSortAggregationIterator( - override protected val initialBufferOffset: Int, + override protected val initialInputBufferOffset: Int, groupingExpressions: Seq[NamedExpression], finalAggregateExpressions: Seq[AggregateExpression2], finalAggregateAttributes: Seq[Attribute], @@ -584,9 +514,6 @@ class FinalAndCompleteSortAggregationIterator( newMutableProjection(resultExpressions, inputSchema)() } - private val offsetAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // All aggregate functions with mode Final. private val finalAggregateFunctions: Array[AggregateFunction2] = { val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) @@ -599,11 +526,10 @@ class FinalAndCompleteSortAggregationIterator( } // All non-algebraic aggregate functions with mode Final. - private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = finalAggregateFunctions.collect { case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } + } // All aggregate functions with mode Complete. private val completeAggregateFunctions: Array[AggregateFunction2] = { @@ -617,53 +543,46 @@ class FinalAndCompleteSortAggregationIterator( } // All non-algebraic aggregate functions with mode Complete. - private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = completeAggregateFunctions.collect { case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } + } // This projection is used to merge buffer values for all AlgebraicAggregates with mode // Final. private val finalAlgebraicMergeProjection = { - val numCompleteOffsetAttributes = - completeAggregateFunctions.map(_.bufferAttributes.length).sum - val completeOffsetAttributes = - Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)()) - val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) - - val bufferSchemata = - offsetAttributes ++ finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } ++ completeOffsetAttributes + // The first initialInputBufferOffset values of the input aggregation buffer is + // for grouping expressions and distinct columns. 
+ val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + + val mergeInputSchema = + finalAggregateFunctions.flatMap(_.bufferAttributes) ++ + completeAggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingAttributesAndDistinctColumns ++ + finalAggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = - placeholderExpressions ++ finalAggregateFunctions.flatMap { + finalAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } ++ completeOffsetExpressions - - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } // This projection is used to update buffer values for all AlgebraicAggregates with mode // Complete. private val completeAlgebraicUpdateProjection = { - val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum - val finalOffsetAttributes = - Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)()) - val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) + // We do not touch buffer values of aggregate functions with the Final mode. + val finalOffsetExpressions = + Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) val bufferSchema = - offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } + finalAggregateFunctions.flatMap(_.bufferAttributes) ++ + completeAggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = - placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } @@ -672,14 +591,7 @@ class FinalAndCompleteSortAggregationIterator( // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -703,7 +615,10 @@ class FinalAndCompleteSortAggregationIterator( // Right now, the buffer only contains initial buffer values. Because // merging two buffers with initial values will generate a row that // still store initial values. We set the currentRow as the copy of the current buffer. - val currentRow = buffer.copy() + // Because input aggregation buffer has initialInputBufferOffset extra values at the + // beginning, we create a dummy row for this part. 
+ val currentRow = + joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() nextGroupingKey = groupGenerator(currentRow).copy() firstRowInNextGroup = currentRow } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 6c49a906c848a..cc54319171bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -15,87 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.expressions.aggregate +package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType} /** - * The abstract class for implementing user-defined aggregate function. + * A Mutable [[Row]] representing an mutable aggregation buffer. */ -abstract class UserDefinedAggregateFunction extends Serializable { - - /** - * A [[StructType]] represents data types of input arguments of this aggregate function. - * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments - * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * input argument. Users can choose names to identify the input arguments. - */ - def inputSchema: StructType - - /** - * A [[StructType]] represents data types of values in the aggregation buffer. - * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values - * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], - * the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * buffer value. Users can choose names to identify the input arguments. - */ - def bufferSchema: StructType - - /** - * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. - */ - def returnDataType: DataType - - /** Indicates if this function is deterministic. */ - def deterministic: Boolean - - /** - * Initializes the given aggregation buffer. 
Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer should - * still store initial values. - */ - def initialize(buffer: MutableAggregationBuffer): Unit - - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ - def update(buffer: MutableAggregationBuffer, input: Row): Unit - - /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */ - def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit - - /** - * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given - * aggregation buffer. - */ - def evaluate(buffer: Row): Any -} - -private[sql] abstract class AggregationBuffer( +private[sql] class MutableAggregationBufferImpl ( + schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], - bufferOffset: Int) - extends Row { - - override def length: Int = toCatalystConverters.length + bufferOffset: Int, + var underlyingBuffer: MutableRow) + extends MutableAggregationBuffer { - protected val offsets: Array[Int] = { + private[this] val offsets: Array[Int] = { val newOffsets = new Array[Int](length) var i = 0 while (i < newOffsets.length) { @@ -104,24 +46,15 @@ private[sql] abstract class AggregationBuffer( } newOffsets } -} -/** - * A Mutable [[Row]] representing an mutable aggregation buffer. - */ -class MutableAggregationBuffer private[sql] ( - toCatalystConverters: Array[Any => Any], - toScalaConverters: Array[Any => Any], - bufferOffset: Int, - var underlyingBuffer: MutableRow) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { if (i >= length || i < 0) { throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } - toScalaConverters(i)(underlyingBuffer(offsets(i))) + toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType)) } def update(i: Int, value: Any): Unit = { @@ -132,8 +65,9 @@ class MutableAggregationBuffer private[sql] ( underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) } - override def copy(): MutableAggregationBuffer = { - new MutableAggregationBuffer( + override def copy(): MutableAggregationBufferImpl = { + new MutableAggregationBufferImpl( + schema, toCatalystConverters, toScalaConverters, bufferOffset, @@ -144,23 +78,38 @@ class MutableAggregationBuffer private[sql] ( /** * A [[Row]] representing an immutable aggregation buffer. 
*/ -class InputAggregationBuffer private[sql] ( +private[sql] class InputAggregationBuffer private[sql] ( + schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, - var underlyingInputBuffer: Row) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + var underlyingInputBuffer: InternalRow) + extends Row { + + private[this] val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } + + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { if (i >= length || i < 0) { throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } - toScalaConverters(i)(underlyingInputBuffer(offsets(i))) + // TODO: Use buffer schema to avoid using generic getter. + toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType)) } override def copy(): InputAggregationBuffer = { new InputAggregationBuffer( + schema, toCatalystConverters, toScalaConverters, bufferOffset, @@ -174,7 +123,7 @@ class InputAggregationBuffer private[sql] ( * @param children * @param udaf */ -case class ScalaUDAF( +private[sql] case class ScalaUDAF( children: Seq[Expression], udaf: UserDefinedAggregateFunction) extends AggregateFunction2 with Logging { @@ -232,18 +181,27 @@ case class ScalaUDAF( lazy val inputAggregateBuffer: InputAggregationBuffer = new InputAggregationBuffer( + bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, - bufferOffset, + inputBufferOffset, null) - lazy val mutableAggregateBuffer: MutableAggregationBuffer = - new MutableAggregationBuffer( + lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = + new MutableAggregationBufferImpl( + bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, - bufferOffset, + mutableBufferOffset, null) + lazy val evalAggregateBuffer: InputAggregationBuffer = + new InputAggregationBuffer( + bufferSchema, + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + mutableBufferOffset, + null) override def initialize(buffer: MutableRow): Unit = { mutableAggregateBuffer.underlyingBuffer = buffer @@ -266,10 +224,10 @@ case class ScalaUDAF( udaf.merge(mutableAggregateBuffer, inputAggregateBuffer) } - override def eval(buffer: InternalRow = null): Any = { - inputAggregateBuffer.underlyingInputBuffer = buffer + override def eval(buffer: InternalRow): Any = { + evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(inputAggregateBuffer) + udaf.evaluate(evalAggregateBuffer) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 1cb27710e0480..03635baae4a5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType} * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - // Right now, we do not support complex types in the grouping key schema. 
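For context on the `ScalaUDAF` plumbing above, here is a hedged sketch of a user-defined aggregate function written against the public contract documented in the moved `UserDefinedAggregateFunction` class (`inputSchema`, `bufferSchema`, `returnDataType`, `deterministic`, `initialize`/`update`/`merge`/`evaluate`, with the method names exactly as they appear in this diff). The `DoubleSum` class and its field names are illustrative only, and how such a function would be registered (something like `sqlContext.udf.register("doubleSum", new DoubleSum)`) is an assumption not shown in this diff:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// A simple UDAF computing the sum of a double column.
class DoubleSum extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)
  def returnDataType: DataType = DoubleType
  def deterministic: Boolean = true

  // Merging two freshly initialized buffers must again yield an initial buffer.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0.0)

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.get(0).asInstanceOf[Double] + input.get(0).asInstanceOf[Double])
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.get(0).asInstanceOf[Double] + buffer2.get(0).asInstanceOf[Double])
  }

  def evaluate(buffer: Row): Any = buffer.get(0)
}
```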
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - - !hasComplexTypes - } - - private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. - case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - } - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. 
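The removed `tryConvert` above also bails out when DISTINCT aggregates reference more than one set of columns (`functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1`). A minimal stand-alone sketch of that check, using a hypothetical `AggCall` record in place of `AggregateExpression2`:

```scala
// Illustrative stand-in for the planner's aggregate metadata.
case class AggCall(name: String, columns: Seq[String], isDistinct: Boolean)

object DistinctColumnSetCheck {
  // Conversion is rejected when DISTINCT aggregates reference more than one
  // distinct set of columns.
  def hasMultipleDistinctColumnSets(aggs: Seq[AggCall]): Boolean =
    aggs.filter(_.isDistinct).map(_.columns).distinct.length > 1

  def main(args: Array[String]): Unit = {
    val ok  = Seq(AggCall("count", Seq("a"), isDistinct = true),
                  AggCall("sum",   Seq("a"), isDistinct = true))
    val bad = Seq(AggCall("count", Seq("a"), isDistinct = true),
                  AggCall("count", Seq("b"), isDistinct = true))
    println(hasMultipleDistinctColumnSets(ok))   // false
    println(hasMultipleDistinctColumnSets(bad))  // true
  }
}
```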
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." - throw new AnalysisException(errorMessage) - } - } - - def tryConvert( - plan: LogicalPlan, - useNewAggregation: Boolean, - codeGenEnabled: Boolean): Option[Aggregate] = plan match { - case p: Aggregate if useNewAggregation && codeGenEnabled => - val converted = tryConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case p: Aggregate => - checkInvalidAggregateFunction2(p) - None - case other => None - } - def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], @@ -191,10 +47,7 @@ object Utils { } val groupExpressionMap = namedGroupingExpressions.toMap val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Partial, isDistinct) - } + val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } @@ -208,10 +61,7 @@ object Utils { child) // 2. Create an Aggregate Operator for final aggregations. 
- val finalAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Final, isDistinct) - } + val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) @@ -298,8 +148,8 @@ object Utils { AggregateExpression2(aggregateFunction, PartialMerge, false) } val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + partialMergeAggregateExpressions.flatMap { agg => + agg.aggregateFunction.bufferAttributes } val partialMergeAggregate = Aggregate2Sort( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fdd7ad59aba50..2294a670c735f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.types.StructType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.util.{CompletionIterator, MutablePair} @@ -49,6 +49,31 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override def outputOrdering: Seq[SortOrder] = child.outputOrdering } + +/** + * A variant of [[Project]] that returns [[UnsafeRow]]s. + */ +case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + this.transformAllExpressions { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + } + val project = UnsafeProjection.create(projectList, child.output) + iter.map(project) + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + + /** * :: DeveloperApi :: */ @@ -195,137 +220,6 @@ case class TakeOrderedAndProject( override def outputOrdering: Seq[SortOrder] = sortOrder } -/** - * :: DeveloperApi :: - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. 
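`planAggregateWithoutDistinct` above plans the same aggregate expressions twice, first with mode `Partial` and then with mode `Final`. A framework-free sketch of that two-phase shape, computing per-group averages over hypothetical partitions of `(key, value)` pairs instead of real physical operators:

```scala
// Phase 1 (Partial): each partition produces a (sum, count) buffer per group.
// Phase 2 (Final): buffers are merged across partitions and then evaluated.
object TwoPhaseAverageSketch {
  def main(args: Array[String]): Unit = {
    val partitions: Seq[Seq[(String, Double)]] = Seq(
      Seq(("a", 1.0), ("b", 2.0), ("a", 3.0)),
      Seq(("a", 5.0), ("b", 7.0)))

    val partial: Seq[Map[String, (Double, Long)]] = partitions.map { part =>
      part.groupBy(_._1).map { case (k, rows) =>
        k -> (rows.map(_._2).sum, rows.size.toLong)
      }
    }

    val merged = partial.flatten.groupBy(_._1).map { case (k, buffers) =>
      val (sum, count) = buffers.map(_._2).reduce((b1, b2) => (b1._1 + b2._1, b1._2 + b2._2))
      k -> sum / count
    }
    println(merged)   // Map(a -> 3.0, b -> 4.5)
  }
}
```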
- */ -@DeveloperApi -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * :: DeveloperApi :: - * Performs a sort, spilling to disk as needed. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -@DeveloperApi -case class ExternalSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r.copy, null))) - val baseIterator = sorter.iterator.map(_._1) - // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * :: DeveloperApi :: - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of - * Project Tungsten). - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will - * spill every `frequency` records. - */ -@DeveloperApi -case class UnsafeExternalSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends UnaryNode { - - private[this] val schema: StructType = child.schema - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") - def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val ordering = newOrdering(sortOrder, child.output) - val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - // Hack until we generate separate comparator implementations for ascending vs. 
descending - // (or choose to codegen them): - val prefixComparator = { - val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) - if (sortOrder.head.direction == Descending) { - new PrefixComparator { - override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) - } - } else { - comp - } - } - val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = prefixComputer(row) - } - } - val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } - sorter.sort(iterator) - } - child.execute().mapPartitions(doSort, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def outputsUnsafeRows: Boolean = true -} - -@DeveloperApi -object UnsafeExternalSort { - /** - * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. - */ - def supportsSchema(schema: StructType): Boolean = { - UnsafeExternalRowSorter.supportsSchema(schema) - } -} - /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index bace3f8a9c8d4..6b83025d5a153 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, Expression, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -298,3 +298,78 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma rows } } + +/** + * A command for users to list all of the registered functions. + * The syntax of using this command in SQL is: + * {{{ + * SHOW FUNCTIONS + * }}} + * TODO currently we are simply ignore the db + */ +case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("function", StringType, nullable = false) :: Nil) + + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = pattern match { + case Some(p) => + try { + val regex = java.util.regex.Pattern.compile(p) + sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + } catch { + // probably will failed in the regex that user provided, then returns empty row. + case _: Throwable => Seq.empty[Row] + } + case None => + sqlContext.functionRegistry.listFunction().map(Row(_)) + } +} + +/** + * A command for users to get the usage of a registered function. 
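The `ShowFunctions` command above filters the registry with `java.util.regex.Pattern` and returns an empty result when the user-supplied pattern does not compile. A stand-alone sketch of that filtering, over an illustrative function list rather than the real `functionRegistry`:

```scala
import java.util.regex.Pattern
import scala.util.control.NonFatal

object ShowFunctionsSketch {
  def listFunctions(registered: Seq[String], pattern: Option[String]): Seq[String] =
    pattern match {
      case Some(p) =>
        try {
          val regex = Pattern.compile(p)
          registered.filter(name => regex.matcher(name).matches())
        } catch {
          // An invalid user-supplied pattern yields an empty result,
          // mirroring the command's behaviour.
          case NonFatal(_) => Seq.empty
        }
      case None => registered
    }

  def main(args: Array[String]): Unit = {
    val fns = Seq("abs", "acos", "upper", "lower")
    println(listFunctions(fns, Some("a.*")))   // List(abs, acos)
    println(listFunctions(fns, None))          // List(abs, acos, upper, lower)
  }
}
```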
+ * The syntax of using this command in SQL is + * {{{ + * DESCRIBE FUNCTION [EXTENDED] upper; + * }}} + */ +case class DescribeFunction( + functionName: String, + isExtended: Boolean) extends RunnableCommand { + + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("function_desc", StringType, nullable = false) :: Nil) + + schema.toAttributes + } + + private def replaceFunctionName(usage: String, functionName: String): String = { + if (usage == null) { + "To be added." + } else { + usage.replaceAll("_FUNC_", functionName) + } + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.functionRegistry.lookupFunction(functionName) match { + case Some(info) => + val result = + Row(s"Function: ${info.getName}") :: + Row(s"Class: ${info.getClassName}") :: + Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil + + if (isExtended) { + result :+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") + } else { + result + } + + case None => Seq(Row(s"Function: $functionName is not found.")) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2b400926177fe..6b91e51ca52fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -170,6 +170,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows) } + // TODO: refactor this thing. It is very complicated because it does projection internally. + // We should just put a project on top of this. private def mergeWithPartitionValues( schema: StructType, requiredColumns: Array[String], @@ -187,13 +189,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (i != -1) { // If yes, gets column value from partition values. (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = partitionValues(i) + mutableRow(ordinal) = partitionValues.genericGet(i) } } else { // Otherwise, inherits the value from scanned data. 
val i = nonPartitionColumns.indexOf(name) (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = dataRow(i) + mutableRow(ordinal) = dataRow.genericGet(i) } } } @@ -206,7 +208,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val mutableRow = new SpecificMutableRow(dataTypes) iterator.map { dataRow => var i = 0 - while (i < mutableRow.length) { + while (i < mutableRow.numFields) { mergers(i)(mutableRow, dataRow, i) i += 1 } @@ -315,7 +317,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (relation.relation.needConversion) { execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 6b4a359db22d1..66dfcc308ceca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -25,6 +25,7 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -178,8 +179,7 @@ private[sql] object PartitioningUtils { * {{{ * NullType -> * IntegerType -> LongType -> - * DoubleType -> DecimalType.Unlimited -> - * StringType + * DoubleType -> StringType * }}} */ private[sql] def resolvePartitions( @@ -236,7 +236,7 @@ private[sql] object PartitioningUtils { /** * Converts a string to a [[Literal]] with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and * [[StringType]]. 
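Going back to `mergeWithPartitionValues` above: each required column is served either from the directory's partition values or from the scanned data row, and that choice is made once per column rather than once per row. A simplified, framework-free sketch with made-up column names and plain arrays in place of `MutableRow`/`InternalRow`:

```scala
object MergePartitionValuesSketch {
  def main(args: Array[String]): Unit = {
    val requiredColumns     = Array("year", "name", "score")
    val partitionColumns    = Array("year")
    val partitionValues     = Array[Any](2015)
    val nonPartitionColumns = requiredColumns.filterNot(c => partitionColumns.contains(c))

    // One merger per output ordinal, chosen once instead of per row.
    val mergers: Array[Array[Any] => Any] = requiredColumns.map { name =>
      val p = partitionColumns.indexOf(name)
      if (p != -1) {
        (_: Array[Any]) => partitionValues(p)          // from partition values
      } else {
        val d = nonPartitionColumns.indexOf(name)
        (dataRow: Array[Any]) => dataRow(d)            // from the scanned data row
      }
    }

    val dataRows = Seq(Array[Any]("alice", 1.5), Array[Any]("bob", 2.5))
    dataRows.foreach { row =>
      println(mergers.map(_(row)).mkString("|"))       // 2015|alice|1.5 then 2015|bob|2.5
    }
  }
}
```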
*/ private[sql] def inferPartitionColumnValue( @@ -249,7 +249,7 @@ private[sql] object PartitioningUtils { .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + .orElse(Try(Literal(new JBigDecimal(raw)))) // Then falls back to string .getOrElse { if (raw == defaultPartitionName) { @@ -268,7 +268,7 @@ private[sql] object PartitioningUtils { } private val upCastingOrder: Seq[DataType] = - Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala index cd2aa7f7433c5..d551f386eee6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala @@ -174,14 +174,19 @@ private[sql] case class InsertIntoHadoopFsRelation( try { writerContainer.executorSideSetup(taskContext) - val converter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow) + .asInstanceOf[OutputWriterInternal].writeInternal(internalRow) + } } writerContainer.commitTask() @@ -248,17 +253,23 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionProj = newProjection(codegenEnabled, partitionCasts, output) val dataProj = newProjection(codegenEnabled, dataOutput, output) - val dataConverter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = converter(dataProj(internalRow)) + writerContainer.outputWriterForRow(partitionPart).write(dataPart) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } - - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = dataConverter(dataProj(internalRow)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = dataProj(internalRow) + writerContainer.outputWriterForRow(partitionPart) + .asInstanceOf[OutputWriterInternal].writeInternal(dataPart) + } } 
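`inferPartitionColumnValue` above tries integral types first, then fractional types, and finally falls back to a string (or a null literal for the default partition name). A stand-alone sketch of that chain, returning a `(value, type name)` pair instead of a `Literal`; the default partition marker used here is an assumption for illustration:

```scala
import java.lang.{Double => JDouble, Integer => JInteger, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import scala.util.Try

object PartitionValueInferenceSketch {
  val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"   // assumed default marker

  def infer(raw: String): (Any, String) =
    Try((JInteger.parseInt(raw), "int"))
      .orElse(Try((JLong.parseLong(raw), "bigint")))
      .orElse(Try((JDouble.parseDouble(raw), "double")))
      .orElse(Try((new JBigDecimal(raw), "decimal")))
      .getOrElse {
        if (raw == defaultPartitionName) (null, "null") else (raw, "string")
      }

  def main(args: Array[String]): Unit = {
    Seq("42", "4000000000", "1.5", "abc", defaultPartitionName)
      .foreach(r => println(s"$r -> ${infer(r)}"))
  }
}
```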
writerContainer.commitTask() @@ -530,8 +541,12 @@ private[sql] class DynamicPartitionWriterContainer( while (i < partitionColumns.length) { val col = partitionColumns(i) val partitionValueString = { - val string = row.getString(i) - if (string.eq(null)) defaultPartitionName else PartitioningUtils.escapePathName(string) + val string = row.getUTF8String(i) + if (string.eq(null)) { + defaultPartitionName + } else { + PartitioningUtils.escapePathName(string.toString) + } } if (i > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index c8033d3c0470a..0cdb407ad57b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -21,16 +21,17 @@ import scala.language.{existentials, implicitConversions} import scala.util.matching.Regex import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, InternalRow} +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -151,7 +152,7 @@ private[sql] class DDLParser( protected lazy val refreshTable: Parser[LogicalPlan] = REFRESH ~> TABLE ~> (ident <~ ".").? 
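Before the `ddl.scala` changes, the hunk above also touches how `DynamicPartitionWriterContainer` builds `col=value` path segments, substituting the default partition name for nulls and escaping characters that are unsafe in paths. A much-simplified sketch of that construction; both the escaping rules and the default marker here are assumptions, not the real `PartitioningUtils` behaviour:

```scala
object PartitionPathSketch {
  val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"   // assumed default marker

  // Only a handful of characters are escaped here for illustration.
  private def escapePathName(s: String): String =
    s.flatMap {
      case c if "/=: ".contains(c) => "%%%02X".format(c.toInt)
      case c => c.toString
    }

  def partitionPath(columns: Seq[String], values: Seq[Any]): String =
    columns.zip(values).map { case (col, v) =>
      val str = if (v == null) defaultPartitionName else escapePathName(v.toString)
      s"$col=$str"
    }.mkString("/")

  def main(args: Array[String]): Unit = {
    println(partitionPath(Seq("year", "city"), Seq(2015, "New York")))
    // year=2015/city=New%20York
    println(partitionPath(Seq("year", "city"), Seq(2015, null)))
    // year=2015/city=__HIVE_DEFAULT_PARTITION__
  }
}
```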
~ ident ^^ { case maybeDatabaseName ~ tableName => - RefreshTable(maybeDatabaseName.getOrElse("default"), tableName) + RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) } protected lazy val options: Parser[Map[String, String]] = @@ -307,7 +308,7 @@ private[sql] object ResolvedDataSource { mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[IntervalType])) { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } val clazz: Class[_] = lookupDataSource(provider) @@ -415,12 +416,12 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext): Seq[InternalRow] = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } @@ -432,26 +433,26 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } -private[sql] case class RefreshTable(databaseName: String, tableName: String) +private[sql] case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. - sqlContext.catalog.refreshTable(databaseName, tableName) + sqlContext.catalog.refreshTable(tableIdent) // If this table is cached as a InMemoryColumnarRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sqlContext.catalog.lookupRelation(Seq(databaseName, tableName)) + val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent.toSeq) // Use lookupCachedData directly since RefreshTable also takes databaseName. val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty if (isCached) { @@ -461,10 +462,10 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) // Uncache the logicalPlan. sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) // Cache it again. 
- sqlContext.cacheManager.cacheQuery(df, Some(tableName)) + sqlContext.cacheManager.cacheQuery(df, Some(tableIdent.table)) } - Seq.empty[InternalRow] + Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e6081cb05bc2d..f26f41fb75d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -136,7 +136,7 @@ package object debug { tupleCount += 1 var i = 0 while (i < numColumns) { - val value = currentRow(i) + val value = currentRow.get(i, output(i).dataType) if (value != null) { columnStats(i).elementTypes += HashSet(value.getClass.getName) } @@ -158,8 +158,8 @@ package object debug { case (row: InternalRow, StructType(fields)) => row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } - case (s: Seq[_], ArrayType(elemType, _)) => - s.foreach(typeCheck(_, elemType)) + case (a: ArrayData, ArrayType(elemType, _)) => + a.toArray().foreach(typeCheck(_, elemType)) case (m: Map[_, _], MapType(keyType, valueType, _)) => m.keys.foreach(typeCheck(_, keyType)) m.values.foreach(typeCheck(_, valueType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index abaa4a6ce86a2..624efc1b1d734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index c9d1a880f4ef4..77e7fe71009b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index f71c0ce352904..a60593911f94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -37,17 +37,17 @@ case class 
BroadcastLeftSemiJoinHash( condition: Option[Expression]) extends BinaryNode with HashSemiJoin { protected override def doExecute(): RDD[InternalRow] = { - val buildIter = right.execute().map(_.copy()).collect().toIterator + val input = right.execute().map(_.copy()).collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) + val hashSet = buildKeyHashSet(input.toIterator) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 700636966f8be..83b726a8e2897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: Projection = { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin( var streamRowMatched = false while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. 
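The broadcast joins above all follow the same shape: collect the small build side once, turn it into a hashed structure (a plain key set when `BroadcastLeftSemiJoinHash` has no extra condition), share it, and probe it from every streamed partition. A framework-free sketch of the condition-free semi-join path, using Scala collections in place of RDDs and broadcast variables:

```scala
object BroadcastSemiJoinSketch {
  def main(args: Array[String]): Unit = {
    val right: Seq[(Int, String)] = Seq((1, "r1"), (3, "r3"))
    val leftPartitions: Seq[Seq[(Int, String)]] =
      Seq(Seq((1, "l1"), (2, "l2")), Seq((3, "l3"), (4, "l4")))

    // "Collect" the small side and build the key set once (the broadcast value).
    val broadcastKeySet: Set[Int] = right.map(_._1).toSet

    // Each streamed partition probes the shared key set.
    val joined = leftPartitions.flatMap(_.filter { case (k, _) => broadcastKeySet.contains(k) })
    println(joined)   // List((1,l1), (3,l3))
  }
}
```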
val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => @@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin( val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - buf += resultProjection(new JoinedRow(leftNulls, rel(i))) - case (LeftOuter | FullOuter, BuildLeft) => - buf += resultProjection(new JoinedRow(rel(i), rightNulls)) - case _ => + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + val joinedRow = new JoinedRow + joinedRow.withLeft(leftNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withRight(rel(i))).copy() + } + i += 1 } - } - i += 1 + case (LeftOuter | FullOuter, BuildLeft) => + val joinedRow = new JoinedRow + joinedRow.withRight(rightNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withLeft(rel(i))).copy() + } + i += 1 + } + case _ => } buf.toSeq } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ae34409bcfcca..6b3d1652923fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.util.collection.CompactBuffer trait HashJoin { @@ -44,16 +43,24 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + @transient protected lazy val buildSideKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } @transient protected lazy val streamSideKeyGenerator: Projection = - if (supportUnsafe) { + if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newMutableProjection(streamedKeys, streamedPlan.output)() @@ -65,18 +72,16 @@ trait HashJoin { { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: CompactBuffer[InternalRow] = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. 
- private[this] val joinRow = new JoinedRow2 - private[this] val resultProjection: Projection = { - if (supportUnsafe) { + private[this] val joinRow = new JoinedRow + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -122,12 +127,4 @@ trait HashJoin { } } } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 6bf2f82954046..7e671e7914f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -75,30 +75,36 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode - protected[this] def streamedKeyGenerator(): Projection = { - if (supportUnsafe) { + @transient protected lazy val buildKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } + + @transient protected[this] lazy val streamedKeyGenerator: Projection = { + if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) } } - @transient private[this] lazy val resultProjection: Projection = { - if (supportUnsafe) { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -230,12 +236,4 @@ trait HashOuterJoin { hashTable } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 7f49264d40354..97fde8f975bfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -35,11 +35,13 @@ trait HashSemiJoin { protected[this] def supportUnsafe: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema)) + 
&& UnsafeProjection.canSupport(left.schema) + && UnsafeProjection.canSupport(right.schema)) } - override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def outputsUnsafeRows: Boolean = supportUnsafe override def canProcessUnsafeRows: Boolean = supportUnsafe + override def canProcessSafeRows: Boolean = !supportUnsafe @transient protected lazy val leftKeyGenerator: Projection = if (supportUnsafe) { @@ -87,14 +89,6 @@ trait HashSemiJoin { }) } - protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, rightKeys, right) - } else { - HashedRelation(buildIter, newProjection(rightKeys, right.output)) - } - } - protected def hashSemiJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8d5731afd59b8..f88a45f48aee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,12 +18,16 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.collection.CompactBuffer @@ -32,7 +36,7 @@ import org.apache.spark.util.collection.CompactBuffer * object. 
*/ private[joins] sealed trait HashedRelation { - def get(key: InternalRow): CompactBuffer[InternalRow] + def get(key: InternalRow): Seq[InternalRow] // This is a helper method to implement Externalizable, and is used by // GeneralHashedRelation and UniqueKeyHashedRelation @@ -59,9 +63,9 @@ private[joins] final class GeneralHashedRelation( private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private def this() = this(null) // Needed for serialization - override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key) + override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key) override def writeExternal(out: ObjectOutput): Unit = { writeBytes(out, SparkSqlSerializer.serialize(hashTable)) @@ -81,9 +85,9 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private def this() = this(null) // Needed for serialization - override def get(key: InternalRow): CompactBuffer[InternalRow] = { + override def get(key: InternalRow): Seq[InternalRow] = { val v = hashTable.get(key) if (v eq null) null else CompactBuffer(v) } @@ -109,6 +113,10 @@ private[joins] object HashedRelation { keyGenerator: Projection, sizeEstimate: Int = 64): HashedRelation = { + if (keyGenerator.isInstanceOf[UnsafeProjection]) { + return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + } + // TODO: Use Spark's HashMap implementation. val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate) var currentRow: InternalRow = null @@ -149,31 +157,140 @@ private[joins] object HashedRelation { } } - /** - * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a - * sequence of values. + * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key + * into a sequence of values. + * + * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use + * BytesToBytesMap for better memory performance (multiple values for the same are stored as a + * continuous byte array. * - * TODO(davies): use BytesToBytesMap + * It's serialized in the following format: + * [number of keys] + * [size of key] [size of all values in bytes] [key bytes] [bytes for all values] + * ... + * + * All the values are serialized as following: + * [number of fields] [number of bytes] [underlying bytes of UnsafeRow] + * ... 
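+ *
+ * For illustration only (all sizes here are made up), one key with two matching rows would be
+ * written roughly as:
+ *   [1]                          <- number of keys
+ *   [16] [56] [16 key bytes]     <- key size, total size of both values (2 * (4 + 4 + 20))
+ *   [3] [20] [20 row bytes]      <- first value: numFields, sizeInBytes, UnsafeRow bytes
+ *   [3] [20] [20 row bytes]      <- second value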
*/ private[joins] final class UnsafeHashedRelation( private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) extends HashedRelation with Externalizable { - def this() = this(null) // Needed for serialization + private[joins] def this() = this(null) // Needed for serialization + + // Use BytesToBytesMap in executor for better performance (it's created when deserialization) + @transient private[this] var binaryMap: BytesToBytesMap = _ - override def get(key: InternalRow): CompactBuffer[InternalRow] = { + override def get(key: InternalRow): Seq[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] - // Thanks to type eraser - hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + + if (binaryMap != null) { + // Used in Broadcast join + val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes) + if (loc.isDefined) { + val buffer = CompactBuffer[UnsafeRow]() + + val base = loc.getValueAddress.getBaseObject + var offset = loc.getValueAddress.getBaseOffset + val last = loc.getValueAddress.getBaseOffset + loc.getValueLength + while (offset < last) { + val numFields = PlatformDependent.UNSAFE.getInt(base, offset) + val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4) + offset += 8 + + val row = new UnsafeRow + row.pointTo(base, offset, numFields, sizeInBytes) + buffer += row + offset += sizeInBytes + } + buffer + } else { + null + } + + } else { + // Use the JavaHashMap in Local mode or ShuffleHashJoin + hashTable.get(unsafeKey) + } } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + out.writeInt(hashTable.size()) + + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + val key = entry.getKey + val values = entry.getValue + + // write all the values as single byte array + var totalSize = 0L + var i = 0 + while (i < values.length) { + totalSize += values(i).getSizeInBytes + 4 + 4 + i += 1 + } + assert(totalSize < Integer.MAX_VALUE, "values are too big") + + // [key size] [values size] [key bytes] [values bytes] + out.writeInt(key.getSizeInBytes) + out.writeInt(totalSize.toInt) + out.write(key.getBytes) + i = 0 + while (i < values.length) { + // [num of fields] [num of bytes] [row bytes] + // write the integer in native order, so they can be read by UNSAFE.getInt() + if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { + out.writeInt(values(i).numFields()) + out.writeInt(values(i).getSizeInBytes) + } else { + out.writeInt(Integer.reverseBytes(values(i).numFields())) + out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes)) + } + out.write(values(i).getBytes) + i += 1 + } + } } override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + val nKeys = in.readInt() + // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory + val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + + val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + .getSizeAsBytes("spark.buffer.pageSize", "64m") + + binaryMap = new BytesToBytesMap( + memoryManager, + nKeys * 2, // reduce hash collision + pageSizeBytes) + + var i = 0 + var keyBuffer = new Array[Byte](1024) + var valuesBuffer = new Array[Byte](1024) + while (i < nKeys) { + val keySize = in.readInt() + val valuesSize = in.readInt() + if (keySize > keyBuffer.size) { + keyBuffer = new Array[Byte](keySize) + } + 
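+        // Read the serialized key and its packed values into the reusable buffers, then copy
+        // them into the BytesToBytesMap below.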
in.readFully(keyBuffer, 0, keySize) + if (valuesSize > valuesBuffer.size) { + valuesBuffer = new Array[Byte](valuesSize) + } + in.readFully(valuesBuffer, 0, valuesSize) + + // put it into binary map + val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + assert(!loc.isDefined, "Duplicated key found!") + loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + i += 1 + } } } @@ -181,33 +298,14 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - buildKeys: Seq[Expression], - buildPlan: SparkPlan, - sizeEstimate: Int = 64): HashedRelation = { - val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) - apply(input, boundedKeys, buildPlan.schema, sizeEstimate) - } - - // Used for tests - def apply( - input: Iterator[InternalRow], - buildKeys: Seq[Expression], - rowSchema: StructType, + keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { - // TODO: Use BytesToBytesMap. val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) - val toUnsafe = UnsafeProjection.create(rowSchema) - val keyGenerator = UnsafeProjection.create(buildKeys) // Create a mapping of buildKeys -> rows while (input.hasNext) { - val currentRow = input.next() - val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { - currentRow.asInstanceOf[UnsafeRow] - } else { - toUnsafe(currentRow) - } + val unsafeRow = input.next().asInstanceOf[UnsafeRow] val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 874712a4e739f..26a664104d6fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -46,7 +46,7 @@ case class LeftSemiJoinHash( val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 948d0ccebceb0..5439e10a60b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = buildHashRelation(buildIter) + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index f54f1edd38ec8..d29b593207c4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,8 +50,8 @@ case class 
ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val hashed = buildHashRelation(rightIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(rightIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) @@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin( }) case RightOuter => - val hashed = buildHashRelation(leftIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(leftIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 981447eacad74..bb18b5403f8e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -66,7 +66,7 @@ case class SortMergeJoin( leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => new Iterator[InternalRow] { // Mutable per row objects. - private[this] val joinRow = new JoinedRow5 + private[this] val joinRow = new JoinedRow private[this] var leftElement: InternalRow = _ private[this] var rightElement: InternalRow = _ private[this] var leftKey: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index e6e27a87c7151..ef1c6e57dc08a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -126,16 +126,27 @@ object EvaluatePython { case (null, _) => null case (row: InternalRow, struct: StructType) => - val values = new Array[Any](row.size) + val values = new Array[Any](row.numFields) var i = 0 - while (i < row.size) { - values(i) = toJava(row(i), struct.fields(i).dataType) + while (i < row.numFields) { + values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) i += 1 } new GenericInternalRowWithSchema(values, struct) - case (seq: Seq[Any], array: ArrayType) => - seq.map(x => toJava(x, array.elementType)).asJava + case (a: ArrayData, array: ArrayType) => + val length = a.numElements() + val values = new java.util.ArrayList[Any](length) + var i = 0 + while (i < length) { + if (a.isNullAt(i)) { + values.add(null) + } else { + values.add(toJava(a.get(i), array.elementType)) + } + i += 1 + } + values case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) @@ -190,10 +201,10 @@ object EvaluatePython { case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}.toSeq + new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray) case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { case (key, value) => (fromJava(key, keyType), 
fromJava(value, valueType)) @@ -267,7 +278,6 @@ object EvaluatePython { pickler.save(row.values(i)) i += 1 } - row.values.foreach(pickler.save) out.write(Opcodes.TUPLE) out.write(Opcodes.REDUCE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala new file mode 100644 index 0000000000000..6d903ab23c57f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines various sort operators. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +/** + * Performs a sort on-heap. + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + */ +case class Sort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + +/** + * Performs a sort, spilling to disk as needed. + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. 
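+ *
+ * Unlike [[Sort]] above, each row is copied and fed into an
+ * [[org.apache.spark.util.collection.ExternalSorter]], so a partition does not need to fit in
+ * memory at once.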
+ */ +case class ExternalSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) + sorter.insertAll(iterator.map(r => (r.copy(), null))) + val baseIterator = sorter.iterator.map(_._1) + // TODO(marmbrus): The complex type signature below thwarts inference for no reason. + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + +/** + * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Project Tungsten). + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will + * spill every `frequency` records. + */ +case class TungstenSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val schema = child.schema + val childOutput = child.output + child.execute().mapPartitions({ iter => + val ordering = newOrdering(sortOrder, childOutput) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + }, preservesPartitioning = true) + } + +} + +object TungstenSort { + /** + * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
+ */ + def supportsSchema(schema: StructType): Boolean = { + UnsafeExternalRowSorter.supportsSchema(schema) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index ec5c6950f37ad..9329148aa233c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{ArrayType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -85,17 +85,17 @@ private[sql] object FrequentItems extends Logging { val sizeOfMap = (1 / support).toInt val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap)) val originalSchema = df.schema - val colInfo = cols.map { name => + val colInfo: Array[(String, DataType)] = cols.map { name => val index = originalSchema.fieldIndex(name) (name, originalSchema.fields(index).dataType) - } + }.toArray val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { val thisMap = counts(i) - val key = row.get(i) + val key = row.get(i, colInfo(i)._2) thisMap.add(key, 1L) i += 1 } @@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toSeq) + val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_)) val resultRow = InternalRow(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala new file mode 100644 index 0000000000000..278dd438fab4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * The abstract class for implementing user-defined aggregate functions. 
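+ *
+ * As a rough sketch (class, field, and column names below are illustrative), a "sum of squares"
+ * aggregate over a single [[DoubleType]] input could be implemented as:
+ *
+ * ```
+ *   class SumOfSquares extends UserDefinedAggregateFunction {
+ *     def inputSchema: StructType = new StructType().add("value", DoubleType)
+ *     def bufferSchema: StructType = new StructType().add("sum", DoubleType)
+ *     def returnDataType: DataType = DoubleType
+ *     def deterministic: Boolean = true
+ *     def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0.0)
+ *     def update(buffer: MutableAggregationBuffer, input: Row): Unit =
+ *       buffer.update(0, buffer.getDouble(0) + input.getDouble(0) * input.getDouble(0))
+ *     def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
+ *       buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
+ *     def evaluate(buffer: Row): Any = buffer.getDouble(0)
+ *   }
+ * ```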
+ */ +@Experimental +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the input arguments. + */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def returnDataType: DataType + + /** Indicates if this function is deterministic. */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer. Initial values set by this method should satisfy + * the condition that when merging two buffers with initial values, the new buffer + * still store initial values. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any +} + +/** + * :: Experimental :: + * A [[Row]] representing an mutable aggregation buffer. + */ +@Experimental +trait MutableAggregationBuffer extends Row { + + /** Update the ith value of this buffer. 
*/ + def update(i: Int, value: Any): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10cd1796410f6..631cfa504833d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -634,7 +634,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() /** * Return an alternative value `r` if `l` is NaN. @@ -741,7 +741,16 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = execution.expressions.SparkPartitionID + def sparkPartitionId(): Column = SparkPartitionID() + + /** + * The file name of the current Spark task + * + * Note that this is indeterministic becuase it depends on what is currently being read in. + * + * @group normal_funcs + */ + def inputFileName(): Column = InputFileName() /** * Computes the square root of the specified float value. @@ -792,6 +801,18 @@ object functions { */ def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + /** + * Parses the expression string into the column that it represents, similar to + * DataFrame.selectExpr + * {{{ + * // get the number of words of each length + * df.groupBy(expr("length(word)")).count() + * }}} + * + * @group normal_funcs + */ + def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1411,7 +1432,8 @@ object functions { def round(columnName: String): Column = round(Column(columnName), 0) /** - * Returns the value of `e` rounded to `scale` decimal places. + * Round the value of `e` to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 @@ -1419,7 +1441,8 @@ object functions { def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) /** - * Returns the value of the given column rounded to `scale` decimal places. + * Round the value of the given column to `scale` decimal places if `scale` >= 0 + * or at integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 @@ -1912,6 +1935,14 @@ object functions { // DateTime functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns the date that is numMonths after startDate. + * @group datetime_funcs + * @since 1.5.0 + */ + def add_months(startDate: Column, numMonths: Int): Column = + AddMonths(startDate.expr, Literal(numMonths)) + /** * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. 
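As a quick usage sketch for two of the helpers added above, `expr` and `add_months` (the DataFrame `df` and its columns are hypothetical, not part of this file):

    import org.apache.spark.sql.functions._

    // group rows by a SQL expression given as a string
    df.groupBy(expr("length(word)")).count()

    // compute the date three months after each value of an `eventDate` column
    df.select(add_months(df("eventDate"), 3))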
@@ -1944,6 +1975,20 @@ object functions { def date_format(dateColumnName: String, format: String): Column = date_format(Column(dateColumnName), format) + /** + * Returns the date that is `days` days after `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + + /** + * Returns the date that is `days` days before `start` + * @group datetime_funcs + * @since 1.5.0 + */ + def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2028,6 +2073,16 @@ object functions { */ def hour(columnName: String): Column = hour(Column(columnName)) + /** + * Given a date column, returns the last day of the month which the given date belongs to. + * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the + * month in July 2015. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def last_day(e: Column): Column = LastDay(e.expr) + /** * Extracts the minutes as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2042,6 +2097,28 @@ object functions { */ def minute(columnName: String): Column = minute(Column(columnName)) + /* + * Returns number of months between dates `date1` and `date2`. + * @group datetime_funcs + * @since 1.5.0 + */ + def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + + /** + * Given a date column, returns the first date which is later than the value of the date column + * that is on the specified day of the week. + * + * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first + * Sunday after 2015-07-27. + * + * Day of the week parameter is case insensitive, and accepts: + * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". + * + * @group datetime_funcs + * @since 1.5.0 + */ + def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + /** * Extracts the seconds as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2070,6 +2147,48 @@ object functions { */ def weekofyear(columnName: String): Column = weekofyear(Column(columnName)) + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + * representing the timestamp of that moment in the current system time zone in the given + * format. + * @group datetime_funcs + * @since 1.5.0 + */ + def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + + /** + * Gets current Unix timestamp in seconds. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), + * using the default timezone and the default locale, return null if fail. 
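+ * For example (the column name is hypothetical), `unix_timestamp(df("ts"))` turns strings such
+ * as "2015-07-27 10:00:00" into epoch seconds, interpreted in the JVM's default time zone.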
+ * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + + /** + * Convert time string with given pattern + * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) + * to Unix time stamp (in seconds), return null if fail. + * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2459,5 +2578,4 @@ object functions { } UnresolvedFunction(udfName, exprs, isDistinct = false) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 7a27fba1780b9..3cf70db6b7b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -66,8 +66,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.DATALINK => null case java.sql.Types.DATE => DateType case java.sql.Types.DECIMAL - if precision != 0 || scale != 0 => DecimalType(precision, scale) - case java.sql.Types.DECIMAL => DecimalType.Unlimited + if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) + case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType case java.sql.Types.FLOAT => FloatType @@ -80,8 +80,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.NCLOB => StringType case java.sql.Types.NULL => null case java.sql.Types.NUMERIC - if precision != 0 || scale != 0 => DecimalType(precision, scale) - case java.sql.Types.NUMERIC => DecimalType.Unlimited + if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) + case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT case java.sql.Types.NVARCHAR => StringType case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType @@ -314,7 +314,7 @@ private[sql] class JDBCRDD( abstract class JDBCConversion case object BooleanConversion extends JDBCConversion case object DateConversion extends JDBCConversion - case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion + case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion case object DoubleConversion extends JDBCConversion case object FloatConversion extends JDBCConversion case object IntegerConversion extends JDBCConversion @@ -331,8 +331,7 @@ private[sql] class JDBCRDD( schema.fields.map(sf => sf.dataType match { case BooleanType => BooleanConversion case DateType => DateConversion - case DecimalType.Unlimited => DecimalConversion(None) - case DecimalType.Fixed(d) => DecimalConversion(Some(d)) + case DecimalType.Fixed(p, s) => DecimalConversion(p, s) case DoubleType => DoubleConversion case FloatType => FloatConversion case IntegerType => IntegerConversion @@ -399,20 +398,13 @@ private[sql] class JDBCRDD( // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then // retrieve it, you will get wrong result 199.99. // So it is needed to set precision and scale for Decimal based on JDBC metadata. 
- case DecimalConversion(Some((p, s))) => + case DecimalConversion(p, s) => val decimalVal = rs.getBigDecimal(pos) if (decimalVal == null) { mutableRow.update(i, null) } else { mutableRow.update(i, Decimal(decimalVal, p, s)) } - case DecimalConversion(None) => - val decimalVal = rs.getBigDecimal(pos) - if (decimalVal == null) { - mutableRow.update(i, null) - } else { - mutableRow.update(i, Decimal(decimalVal)) - } case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4d3aac464c538..41d0ecb4bbfbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -128,6 +128,7 @@ private[sql] case class JDBCRelation( override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverRegistry.getDriverClassName(url) + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, @@ -137,7 +138,7 @@ private[sql] case class JDBCRelation( table, requiredColumns, filters, - parts).map(_.asInstanceOf[Row]) + parts).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index f7ea852fe7f58..035e0510080ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -89,8 +89,7 @@ package object jdbc { case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) - case DecimalType.Unlimited => stmt.setBigDecimal(i + 1, - row.getAs[java.math.BigDecimal](i)) + case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) case _ => throw new IllegalArgumentException( s"Can't translate non-null value for field $i") } @@ -145,7 +144,7 @@ package object jdbc { case BinaryType => "BLOB" case TimestampType => "TIMESTAMP" case DateType => "DATE" - case DecimalType.Unlimited => "DECIMAL(40,20)" + case t: DecimalType => s"DECIMAL(${t.precision}},${t.scale}})" case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") }) val nullable = if (field.nullable) "" else "NOT NULL" @@ -177,7 +176,7 @@ package object jdbc { case BinaryType => java.sql.Types.BLOB case TimestampType => java.sql.Types.TIMESTAMP case DateType => java.sql.Types.DATE - case DecimalType.Unlimited => java.sql.Types.DECIMAL + case t: DecimalType => java.sql.Types.DECIMAL case _ => throw new IllegalArgumentException( s"Can't translate null value for field $field") }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index afe2c6c11ac69..04ab5e2217882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -113,7 +113,7 @@ private[sql] object InferSchema { case INT | LONG => LongType // Since we do not have a data 
type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. - case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited + case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT case FLOAT | DOUBLE => DoubleType } @@ -125,7 +125,7 @@ private[sql] object InferSchema { * Convert NullType to StringType and remove StructTypes with no fields */ private def canonicalizeType: DataType => Option[DataType] = { - case at@ArrayType(elementType, _) => + case at @ ArrayType(elementType, _) => for { canonicalType <- canonicalizeType(elementType) } yield { @@ -168,8 +168,13 @@ private[sql] object InferSchema { HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { - case (other: DataType, NullType) => other - case (NullType, other: DataType) => other + // Double support larger range than fixed decimal, DecimalType.Maximum should be enough + // in most case, also have better precision. + case (DoubleType, t: DecimalType) => + if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + case (t: DecimalType, DoubleType) => + if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + case (StructType(fields1), StructType(fields2)) => val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { case (name, fieldTypes) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 922794ac9aac5..562b058414d07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -154,17 +154,19 @@ private[sql] class JSONRelation( } override def buildScan(): RDD[Row] = { + // Rely on type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 381e7ed54428f..1c309f8794ef3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -110,8 +110,13 @@ private[sql] object JacksonParser { case (START_OBJECT, st: StructType) => convertObject(factory, parser, st) + case (START_ARRAY, st: StructType) => + // SPARK-3308: support reading top level JSON arrays and take every element + // in such an array as a row + convertArray(factory, parser, st) + case (START_ARRAY, ArrayType(st, _)) => - convertList(factory, parser, st) + convertArray(factory, parser, st) case (START_OBJECT, ArrayType(st, _)) => // the business end of SPARK-3308: @@ -165,16 +170,16 @@ private[sql] object JacksonParser { builder.result() } - private def convertList( + private def convertArray( factory: JsonFactory, parser: JsonParser, - schema: DataType): Seq[Any] = { - val builder = 
Seq.newBuilder[Any] + elementType: DataType): ArrayData = { + val values = scala.collection.mutable.ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_ARRAY)) { - builder += convertField(factory, parser, schema) + values += convertField(factory, parser, elementType) } - builder.result() + new GenericArrayData(values.toArray) } private def parseJson( @@ -201,12 +206,15 @@ private[sql] object JacksonParser { val parser = factory.createParser(record) parser.nextToken() - // to support both object and arrays (see SPARK-3308) we'll start - // by converting the StructType schema to an ArrayType and let - // convertField wrap an object into a single value array when necessary. - convertField(factory, parser, ArrayType(schema)) match { + convertField(factory, parser, schema) match { case null => failedRecord(record) - case list: Seq[InternalRow @unchecked] => list + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray().map(_.asInstanceOf[InternalRow]) + } case _ => sys.error( s"Failed to parse record $record. Please make sure that each line of the file " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala new file mode 100644 index 0000000000000..975fec101d9c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import java.util.{Map => JMap} + +import scala.collection.JavaConversions.{iterableAsScalaIterable, mapAsJavaMap, mapAsScalaMap} + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext +import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.io.api.RecordMaterializer +import org.apache.parquet.schema.MessageType + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType + +private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging { + override def prepareForRead( + conf: Configuration, + keyValueMetaData: JMap[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[InternalRow] = { + log.debug(s"Preparing for read Parquet file with message type: $fileSchema") + + val toCatalyst = new CatalystSchemaConverter(conf) + val parquetRequestedSchema = readContext.getRequestedSchema + + val catalystRequestedSchema = + Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => + metadata + // First tries to read requested schema, which may result from projections + .get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + // If not available, tries to read Catalyst schema from file metadata. It's only + // available if the target file is written by Spark SQL. + .orElse(metadata.get(CatalystReadSupport.SPARK_METADATA_KEY)) + }.map(StructType.fromString).getOrElse { + logDebug("Catalyst schema not available, falling back to Parquet schema") + toCatalyst.convert(parquetRequestedSchema) + } + + logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema") + new CatalystRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) + } + + override def init(context: InitContext): ReadContext = { + val conf = context.getConfiguration + + // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst + // schema of this file from its the metadata. + val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA)) + + // Optional schema of requested columns, in the form of a string serialized from a Catalyst + // `StructType` containing all requested columns. + val maybeRequestedSchema = Option(conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) + + // Below we construct a Parquet schema containing all requested columns. This schema tells + // Parquet which columns to read. + // + // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, + // we have to fallback to the full file schema which contains all columns in the file. + // Obviously this may waste IO bandwidth since it may read more columns than requested. + // + // Two things to note: + // + // 1. It's possible that some requested columns don't exist in the target Parquet file. For + // example, in the case of schema merging, the globally merged schema may contain extra + // columns gathered from other Parquet files. These columns will be simply filled with nulls + // when actually reading the target Parquet file. + // + // 2. When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to + // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to + // non-standard behaviors of some Parquet libraries/tools. 
For example, a Parquet file + // containing a single integer array field `f1` may have the following legacy 2-level + // structure: + // + // message root { + // optional group f1 (LIST) { + // required INT32 element; + // } + // } + // + // while `CatalystSchemaConverter` may generate a standard 3-level structure: + // + // message root { + // optional group f1 (LIST) { + // repeated group list { + // required INT32 element; + // } + // } + // } + // + // Apparently, we can't use the 2nd schema to read the target Parquet file as they have + // different physical structures. + val parquetRequestedSchema = + maybeRequestedSchema.fold(context.getFileSchema) { schemaString => + val toParquet = new CatalystSchemaConverter(conf) + val fileSchema = context.getFileSchema.asGroupType() + val fileFieldNames = fileSchema.getFields.map(_.getName).toSet + + StructType + // Deserializes the Catalyst schema of requested columns + .fromString(schemaString) + .map { field => + if (fileFieldNames.contains(field.name)) { + // If the field exists in the target Parquet file, extracts the field type from the + // full file schema and makes a single-field Parquet schema + new MessageType("root", fileSchema.getType(field.name)) + } else { + // Otherwise, just resorts to `CatalystSchemaConverter` + toParquet.convert(StructType(Array(field))) + } + } + // Merges all single-field Parquet schemas to form a complete schema for all requested + // columns. Note that it's possible that no columns are requested at all (e.g., count + // some partition column of a partitioned Parquet table). That's why `fold` is used here + // and always fallback to an empty Parquet schema. + .fold(new MessageType("root")) { + _ union _ + } + } + + val metadata = + Map.empty[String, String] ++ + maybeRequestedSchema.map(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ + maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) + + logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema") + new ReadContext(parquetRequestedSchema, metadata) + } +} + +private[parquet] object CatalystReadSupport { + val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" + + val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala new file mode 100644 index 0000000000000..84f1dccfeb788 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} +import org.apache.parquet.schema.MessageType + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType + +/** + * A [[RecordMaterializer]] for Catalyst rows. + * + * @param parquetSchema Parquet schema of the records to be read + * @param catalystSchema Catalyst schema of the rows to be constructed + */ +private[parquet] class CatalystRecordMaterializer( + parquetSchema: MessageType, catalystSchema: StructType) + extends RecordMaterializer[InternalRow] { + + private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) + + override def getCurrentRecord: InternalRow = rootConverter.currentRow + + override def getRootConverter: GroupConverter = rootConverter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index 0c3d8fdab6bd2..172db8362afb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder import scala.collection.JavaConversions._ @@ -28,7 +29,7 @@ import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveCo import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -55,8 +56,8 @@ private[parquet] trait ParentContainerUpdater { private[parquet] object NoopUpdater extends ParentContainerUpdater /** - * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[Row]]s. Since - * any Parquet record is also a struct, this converter can also be used as root converter. + * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. + * Since any Parquet record is also a struct, this converter can also be used as root converter. * * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. @@ -108,7 +109,7 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = { var i = 0 - while (i < currentRow.length) { + while (i < currentRow.numFields) { currentRow.setNullAt(i) i += 1 } @@ -178,7 +179,7 @@ private[parquet] class CatalystRowConverter( case t: StructType => new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { - override def set(value: Any): Unit = updater.set(value.asInstanceOf[Row].copy()) + override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) case t: UserDefinedType[_] => @@ -263,17 +264,23 @@ private[parquet] class CatalystRowConverter( val scale = decimalType.scale val bytes = value.getBytes - var unscaled = 0L - var i = 0 + if (precision <= 8) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. 
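+        // The fixed-length binary stores the unscaled value as big-endian two's complement,
+        // so fold the bytes into a Long in the loop below and sign-extend the result afterwards.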
+ var unscaled = 0L + var i = 0 - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } - val bits = 8 * bytes.length - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - Decimal(unscaled, precision, scale) + val bits = 8 * bytes.length + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + Decimal(unscaled, precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. + Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale) + } } } @@ -318,7 +325,7 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = elementConverter - override def end(): Unit = updater.set(currentArray) + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) // NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the // next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 1ea6926af6d5b..d43ca95b4eea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -387,24 +387,18 @@ private[parquet] class CatalystSchemaConverter( // ===================================== // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and - // always store decimals in fixed-length byte arrays. - case DecimalType.Fixed(precision, scale) - if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec => + // always store decimals in fixed-length byte arrays. To keep compatibility with these older + // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated + // by `DECIMAL`. + case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec => Types .primitive(FIXED_LEN_BYTE_ARRAY, repetition) .as(DECIMAL) .precision(precision) .scale(scale) - .length(minBytesForPrecision(precision)) + .length(CatalystSchemaConverter.minBytesForPrecision(precision)) .named(field.name) - case dec @ DecimalType() if !followParquetFormatSpec => - throw new AnalysisException( - s"Data type $dec is not supported. " + - s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," + - "decimal precision and scale must be specified, " + - "and precision must be less than or equal to 18.") - // ===================================== // Decimals (follow Parquet format spec) // ===================================== @@ -436,13 +430,9 @@ private[parquet] class CatalystSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(minBytesForPrecision(precision)) + .length(CatalystSchemaConverter.minBytesForPrecision(precision)) .named(field.name) - case dec @ DecimalType.Unlimited if followParquetFormatSpec => - throw new AnalysisException( - s"Data type $dec is not supported. 
Decimal precision and scale must be specified.") - // =================================================== // ArrayType and MapType (for Spark versions <= 1.4.x) // =================================================== @@ -552,15 +542,6 @@ private[parquet] class CatalystSchemaConverter( Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes .asInstanceOf[Int] } - - // Min byte counts needed to store decimals with various precisions - private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision => - var numBytes = 1 - while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { - numBytes += 1 - } - numBytes - } } @@ -574,9 +555,33 @@ private[parquet] object CatalystSchemaConverter { """.stripMargin.split("\n").mkString(" ")) } + def checkFieldNames(schema: StructType): StructType = { + schema.fieldNames.foreach(checkFieldName) + schema + } + def analysisRequire(f: => Boolean, message: String): Unit = { if (!f) { throw new AnalysisException(message) } } + + private def computeMinBytesForPrecision(precision : Int) : Int = { + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } + + private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision) + + // Returns the minimum number of bytes needed to store a decimal with a given `precision`. + def minBytesForPrecision(precision : Int) : Int = { + if (precision < MIN_BYTES_FOR_PRECISION.length) { + MIN_BYTES_FOR_PRECISION(precision) + } else { + computeMinBytesForPrecision(precision) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index be0a2029d233b..2332a36468dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.parquet import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.ArrayData +// TODO Removes this while fixing SPARK-8848 private[sql] object CatalystConverter { // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). // Note that "array" for the array elements is chosen by ParquetAvro. 
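As a side note on the byte-count logic above: `minBytesForPrecision` finds the smallest number of bytes whose signed two's-complement range covers every unscaled value with the given number of decimal digits, memoized for the precisions Spark supports (up to 38). A standalone Scala sketch of the same calculation; the object name is illustrative only.

object MinBytesForPrecisionSketch {
  // Smallest n such that an n-byte signed integer can hold any unscaled value
  // with `precision` decimal digits, i.e. 2^(8n - 1) >= 10^precision.
  private def computeMinBytesForPrecision(precision: Int): Int = {
    var numBytes = 1
    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
      numBytes += 1
    }
    numBytes
  }

  // Precomputed for precisions 0 to 38, mirroring the memoization in the diff above.
  private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision)

  def minBytesForPrecision(precision: Int): Int =
    if (precision < MIN_BYTES_FOR_PRECISION.length) MIN_BYTES_FOR_PRECISION(precision)
    else computeMinBytesForPrecision(precision)

  def main(args: Array[String]): Unit = {
    // For example: precision 9 needs 4 bytes, 18 needs 8 bytes, 38 needs 16 bytes.
    Seq(9, 18, 38).foreach(p => println(s"precision $p -> ${minBytesForPrecision(p)} bytes"))
  }
}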
@@ -31,7 +33,7 @@ private[sql] object CatalystConverter { val MAP_SCHEMA_NAME = "map" // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType[T] = Seq[T] + type ArrayScalaType[T] = ArrayData type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 086559e9f7658..b4337a48dbd80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -17,81 +17,736 @@ package org.apache.spark.sql.parquet -import java.io.IOException +import java.net.URI import java.util.logging.{Level, Logger => JLogger} +import java.util.{List => JList} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.permission.FsAction +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.{Failure, Try} + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetRecordReader, _} import org.apache.parquet.schema.MessageType import org.apache.parquet.{Log => ParquetLog} -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} +import org.apache.spark.rdd.RDD._ +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} + + +private[sql] class DefaultSource extends HadoopFsRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) + } +} + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) + extends OutputWriterInternal { + + private val recordWriter: RecordWriter[Void, InternalRow] = { + val outputFormat = { + new ParquetOutputFormat[InternalRow]() { + // Here we override `getDefaultWorkFile` for two reasons: + // + // 1. To allow appending. 
We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). + // + // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses + // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all + // partitions in the case of dynamic partitioning. + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val split = context.getTaskAttemptID.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + } + } + + outputFormat.getRecordWriter(context) + } + + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + + override def close(): Unit = recordWriter.close(context) +} + +private[sql] class ParquetRelation( + override val paths: Array[String], + private val maybeDataSchema: Option[StructType], + // This is for metastore conversion. + private val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) + with Logging { + + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + + // Should we merge schemas from all Parquet part-files? + private val shouldMergeSchemas = + parameters + .get(ParquetRelation.MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + + private val mergeRespectSummaries = + sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) + + private val maybeMetastoreSchema = parameters + .get(ParquetRelation.METASTORE_SCHEMA) + .map(DataType.fromJson(_).asInstanceOf[StructType]) + + private lazy val metadataCache: MetadataCache = { + val meta = new MetadataCache + meta.refresh() + meta + } -/** - * Relation that consists of data stored in a Parquet columnar format. - * - * Users should interact with parquet files though a [[DataFrame]], created by a [[SQLContext]] - * instead of using this class directly. - * - * {{{ - * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") - * }}} - * - * @param path The path to the Parquet file. 
- */ -private[sql] case class ParquetRelation( - path: String, - @transient conf: Option[Configuration], - @transient sqlContext: SQLContext, - partitioningAttributes: Seq[Attribute] = Nil) - extends LeafNode with MultiInstanceRelation { - - /** Schema derived from ParquetFile */ - def parquetSchema: MessageType = - ParquetTypesConverter - .readMetaData(new Path(path), conf) - .getFileMetaData - .getSchema - - /** Attributes */ - override val output = - partitioningAttributes ++ - ParquetTypesConverter.readSchemaFromFile( - new Path(path.split(",").head), - conf, - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp) - lazy val attributeMap = AttributeMap(output.map(o => o -> o)) - - override def newInstance(): this.type = { - ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] - } - - // Equals must also take into account the output attributes so that we can distinguish between - // different instances of the same relation, override def equals(other: Any): Boolean = other match { - case p: ParquetRelation => - p.path == path && p.output == output + case that: ParquetRelation => + val schemaEquality = if (shouldMergeSchemas) { + this.shouldMergeSchemas == that.shouldMergeSchemas + } else { + this.dataSchema == that.dataSchema && + this.schema == that.schema + } + + this.paths.toSet == that.paths.toSet && + schemaEquality && + this.maybeDataSchema == that.maybeDataSchema && + this.partitionColumns == that.partitionColumns + case _ => false } - override def hashCode: Int = { - com.google.common.base.Objects.hashCode(path, output) + override def hashCode(): Int = { + if (shouldMergeSchemas) { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + maybeDataSchema, + partitionColumns) + } else { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + dataSchema, + schema, + maybeDataSchema, + partitionColumns) + } + } + + /** Constraints on schema of dataframe to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to parquet format") + } + } + + override def dataSchema: StructType = { + val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) + // check if schema satisfies the constraints + // before moving forward + checkConstraints(schema) + schema + } + + override private[sql] def refresh(): Unit = { + super.refresh() + metadataCache.refresh() + } + + // Parquet data source always uses Catalyst internal representations. 
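The `checkConstraints` helper above rejects schemas with duplicate column names before anything is written. A minimal self-contained sketch of that check; the object and method names are illustrative only.

object DuplicateColumnCheckSketch {
  // Returns the duplicated names, quoted for an error message, or None if the schema is clean.
  def findDuplicates(fieldNames: Seq[String]): Option[String] = {
    val duplicateColumns = fieldNames.groupBy(identity).collect {
      case (name, occurrences) if occurrences.length > 1 => "\"" + name + "\""
    }
    if (duplicateColumns.isEmpty) None else Some(duplicateColumns.mkString(", "))
  }

  def main(args: Array[String]): Unit = {
    println(findDuplicates(Seq("id", "name", "id")))   // Some("id")
    println(findDuplicates(Seq("id", "name", "age")))  // None
  }
}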
+ override val needConversion: Boolean = false + + override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[ParquetOutputCommitter], + classOf[ParquetOutputCommitter]) + + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo("Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + + conf.setClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + committerClass, + classOf[ParquetOutputCommitter]) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + // TODO There's no need to use two kinds of WriteSupport + // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and + // complex types. + val writeSupportClass = + if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) + RowWriteSupport.setSchema(dataSchema.toAttributes, conf) + + // Sets compression scheme + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + + new OutputWriterFactory { + override def newInstance( + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, context) + } + } + } + + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) + val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown + val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + + // Create the function to set variable Parquet confs at both driver and executor side. + val initLocalJobFuncOpt = + ParquetRelation.initializeLocalJobFunc( + requiredColumns, + filters, + dataSchema, + useMetadataCache, + parquetFilterPushDown, + assumeBinaryIsString, + assumeInt96IsTimestamp, + followParquetFormatSpec) _ + + // Create the function to set input paths at the driver side. 
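The compression codec handling in `prepareJobForWrite` above upper-cases the user-facing short name and falls back to UNCOMPRESSED for anything unrecognized. A minimal sketch of that lookup, using plain strings in place of Parquet's `CompressionCodecName` enum so it stays self-contained:

object CompressionCodecLookupSketch {
  // Short, user-facing names mapped to Parquet codec names (strings stand in for
  // CompressionCodecName in this sketch).
  private val shortParquetCompressionCodecNames = Map(
    "NONE" -> "UNCOMPRESSED",
    "SNAPPY" -> "SNAPPY",
    "GZIP" -> "GZIP",
    "LZO" -> "LZO")

  // Unknown values fall back to UNCOMPRESSED, mirroring the getOrElse in the diff above.
  def resolve(userValue: String): String =
    shortParquetCompressionCodecNames.getOrElse(userValue.toUpperCase, "UNCOMPRESSED")

  def main(args: Array[String]): Unit = {
    println(resolve("snappy"))  // SNAPPY
    println(resolve("zstd"))    // UNCOMPRESSED (not covered by this mapping)
  }
}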
+ val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles) _ + + Utils.withDummyCallSite(sqlContext.sparkContext) { + new SqlNewHadoopRDD( + sc = sqlContext.sparkContext, + broadcastedConf = broadcastedConf, + initDriverSideJobFuncOpt = Some(setInputPaths), + initLocalJobFuncOpt = Some(initLocalJobFuncOpt), + inputFormatClass = classOf[ParquetInputFormat[InternalRow]], + keyClass = classOf[Void], + valueClass = classOf[InternalRow]) { + + val cacheMetadata = useMetadataCache + + @transient val cachedStatuses = inputFiles.map { f => + // In order to encode the authority of a Path containing special characters such as '/' + // (which does happen in some S3N credentials), we need to use the string returned by the + // URI of the path to create a new Path. + val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) + new FileStatus( + f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) + }.toSeq + + private def escapePathUserInfo(path: Path): Path = { + val uri = path.toUri + new Path(new URI( + uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, + uri.getQuery, uri.getFragment)) + } + + // Overridden so we can inject our own cached files statuses. + override def getPartitions: Array[SparkPartition] = { + val inputFormat = new ParquetInputFormat[InternalRow] { + override def listStatus(jobContext: JobContext): JList[FileStatus] = { + if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) + } + } + + val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + } + }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] + } } - // TODO: Use data from the footers. - override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) + private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. + private var metadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all "_common_metadata" files. + private var commonMetadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all data files (Parquet part-files). + var dataStatuses: Array[FileStatus] = _ + + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var dataSchema: StructType = null + + // Schema of the whole table, including partition columns. + var schema: StructType = _ + + // Cached leaves + var cachedLeaves: Set[FileStatus] = null + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ + def refresh(): Unit = { + val currentLeafStatuses = cachedLeafStatuses() + + // Check if cachedLeafStatuses is changed or not + val leafStatusesChanged = (cachedLeaves == null) || + !cachedLeaves.equals(currentLeafStatuses) + + if (leafStatusesChanged) { + cachedLeaves = currentLeafStatuses.toIterator.toSet + + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. 
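The `escapePathUserInfo` trick above rebuilds the path's URI from its raw (still-escaped) user info so that escaped characters in S3N-style credentials are not decoded and then re-interpreted as path separators. A small sketch of the idea using only `java.net.URI`; the URI below is hypothetical.

import java.net.URI

object EscapePathUserInfoSketch {
  // Rebuild a URI from its raw (still-escaped) user info so escaped characters in the
  // credentials part are carried over instead of being decoded first.
  def escapeUserInfo(uri: URI): URI =
    new URI(
      uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath,
      uri.getQuery, uri.getFragment)

  def main(args: Array[String]): Unit = {
    // Hypothetical S3N-style URI whose secret key contains an escaped '/' (%2F).
    val original = new URI("s3n://ACCESS:SECRET%2FKEY@bucket/data/part-00000.parquet")
    println(original.getUserInfo)      // the decoded user info ("ACCESS:SECRET/KEY")
    println(escapeUserInfo(original))  // the URI rebuilt from the raw user info
  }
}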
+ val leaves = currentLeafStatuses.filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + }.toArray + + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + + dataSchema = { + val dataSchema0 = maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(throw new AnalysisException( + s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + + paths.mkString("\n\t"))) + + // If this Parquet relation is converted from a Hive Metastore table, we must reconcile + // case insensitivity issues and possible schema mismatches (probably caused by schema + // evolution). + maybeMetastoreSchema + .map(ParquetRelation.mergeMetastoreParquetSchema(_, dataSchema0)) + .getOrElse(dataSchema0) + } + } + } + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + private def readSchema(): Option[StructType] = { + // Sees which file(s) we need to touch in order to figure out the schema. + // + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is preferable to "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. If no summary file is available, falls back to some random part-file. + // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. This + // implies that if a summary file is present, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + val filesToTouch = + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + + // If the mergeRespectSummaries config is true, we assume that all part-files have the same + // schema as the summary files, so we skip the part-files and only merge the schemas + // contained in the summary files. + // If the config is disabled, which is the default setting, we merge all part-files as well. + // You should enable this configuration only if you are very sure that the Parquet + // part-files to be read have corresponding summary files containing the correct schema. + + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq() + } else { + dataStatuses + } + (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions of Parquet + // don't have this. 
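Stepping back to the schema discovery order at the top of `refresh()` above: a user-supplied schema wins, then the schema read from Parquet footers, then the Hive Metastore schema, and the relation fails to resolve if none is available. A minimal sketch of that fallback chain, with plain strings standing in for `StructType`:

object SchemaResolutionOrderSketch {
  // Resolution order: user-provided schema, then schema read from Parquet footers,
  // then the Hive Metastore schema; fail if none is available.
  def resolveSchema(
      maybeDataSchema: Option[String],
      schemaFromFooters: => Option[String],
      maybeMetastoreSchema: Option[String]): String = {
    maybeDataSchema
      .orElse(schemaFromFooters)
      .orElse(maybeMetastoreSchema)
      .getOrElse(throw new IllegalArgumentException(
        "Failed to discover schema of Parquet file(s)"))
  }

  def main(args: Array[String]): Unit = {
    // The footer schema wins over the Metastore schema when no schema was supplied explicitly.
    println(resolveSchema(None, Some("struct<a:int>"), Some("struct<a:int,b:string>")))
  }
}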
+ commonMetadataStatuses.headOption + // Falls back to "_metadata" + .orElse(metadataStatuses.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with the same key in different files). In either case, we fall back to the + // first part-file, and just assume all schemas are consistent. + .orElse(dataStatuses.headOption) + .toSeq + } + + assert( + filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, + "No predefined schema found, " + + s"and no Parquet data files or summary files found under ${paths.mkString(", ")}.") + + ParquetRelation.mergeSchemasInParallel(filesToTouch, sqlContext) + } + } } -private[sql] object ParquetRelation { +private[sql] object ParquetRelation extends Logging { + // Whether we should merge schemas collected from all Parquet part-files. + private[sql] val MERGE_SCHEMA = "mergeSchema" + + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" + + /** This closure sets various Parquet configurations at both driver side and executor side. */ + private[parquet] def initializeLocalJobFunc( + requiredColumns: Array[String], + filters: Array[Filter], + dataSchema: StructType, + useMetadataCache: Boolean, + parquetFilterPushDown: Boolean, + assumeBinaryIsString: Boolean, + assumeInt96IsTimestamp: Boolean, + followParquetFormatSpec: Boolean)(job: Job): Unit = { + val conf = job.getConfiguration + conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) + + // Try to push down filters when filter push-down is enabled. + if (parquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(dataSchema, _)) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) + } + + conf.set(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { + val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) + CatalystSchemaConverter.checkFieldNames(requestedSchema).json + }) + + conf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + CatalystSchemaConverter.checkFieldNames(dataSchema).json) + + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) + + // Sets flags for Parquet schema conversion + conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) + conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) + conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) + } + + /** This closure sets input paths at the driver side. */ + private[parquet] def initializeDriverSideJobFunc( + inputFiles: Array[FileStatus])(job: Job): Unit = { + // We set the input paths at the driver side. 
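The filter push-down step in `initializeLocalJobFunc` above only pushes the predicates that can actually be converted to Parquet filters and ANDs them together; anything unconvertible is simply left out of the pushed predicate. A self-contained sketch of that `flatMap` + `reduceOption` pattern with stand-in predicate types (the names are illustrative only):

object FilterPushDownCombineSketch {
  sealed trait Pred
  case class Gt(column: String, value: Int) extends Pred
  case class And(left: Pred, right: Pred) extends Pred

  // Stand-in for ParquetFilters.createFilter: only "greater than" filters are
  // convertible in this sketch; everything else yields None.
  def createFilter(filter: (String, String, Int)): Option[Pred] = filter match {
    case (column, ">", value) => Some(Gt(column, value))
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    val filters = Seq(("a", ">", 1), ("b", "like", 0), ("c", ">", 5))
    // Unconvertible filters are dropped from the pushed-down predicate (Spark still
    // evaluates them after the scan); the rest are ANDed together.
    val pushed = filters.flatMap(createFilter).reduceOption(And(_, _))
    println(pushed)  // Some(And(Gt(a,1),Gt(c,5)))
  }
}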
+ logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}") + if (inputFiles.nonEmpty) { + FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) + } + } + + private[parquet] def readSchema( + footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { + + def parseParquetSchema(schema: MessageType): StructType = { + val converter = new CatalystSchemaConverter( + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.followParquetFormatSpec) + + converter.convert(schema) + } + + val seen = mutable.HashSet[String]() + val finalSchemas: Seq[StructType] = footers.flatMap { footer => + val metadata = footer.getParquetMetadata.getFileMetaData + val serializedSchema = metadata + .getKeyValueMetaData + .toMap + .get(CatalystReadSupport.SPARK_METADATA_KEY) + if (serializedSchema.isEmpty) { + // Falls back to Parquet schema if no Spark SQL schema found. + Some(parseParquetSchema(metadata.getSchema)) + } else if (!seen.contains(serializedSchema.get)) { + seen += serializedSchema.get + + // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to + // whatever is available. + Some(Try(DataType.fromJson(serializedSchema.get)) + .recover { case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(serializedSchema.get) + } + .recover { case cause: Throwable => + logWarning( + s"""Failed to parse serialized Spark schema in Parquet key-value metadata: + |\t$serializedSchema + """.stripMargin, + cause) + } + .map(_.asInstanceOf[StructType]) + .getOrElse { + // Falls back to Parquet schema if Spark SQL schema can't be parsed. + parseParquetSchema(metadata.getSchema) + }) + } else { + None + } + } + + finalSchemas.reduceOption { (left, right) => + try left.merge(right) catch { case e: Throwable => + throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + } + } + } + + /** + * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore + * schema and Parquet schema. + * + * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the + * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't + * distinguish binary and string). This method generates a correct schema by merging Metastore + * schema data types and Parquet schema field names. + */ + private[parquet] def mergeMetastoreParquetSchema( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + def schemaConflictMessage: String = + s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Parquet schema: + |${parquetSchema.prettyJson} + """.stripMargin + + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) + + val ordinalMap = metastoreSchema.zipWithIndex.map { + case (field, index) => field.name.toLowerCase -> index + }.toMap + + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) + + StructType(metastoreSchema.zip(reorderedParquetSchema).map { + // Uses Parquet field names but retains Metastore data types. 
+ case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => + mSchema.copy(name = pSchema.name) + case _ => + throw new SparkException(schemaConflictMessage) + }) + } + + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. + * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. + */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } + + /** + * Figures out a merged Parquet schema with a distributed Spark job. + * + * Note that locality is not taken into consideration here because: + * + * 1. For a single Parquet part-file, in most cases the footer only resides in the last block of + * that file. Thus we only need to retrieve the location of the last block. However, Hadoop + * `FileSystem` only provides an API to retrieve locations of all blocks, which can be + * potentially expensive. + * + * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty + * slow. And basically locality is not available when using S3 (you can't run computation on + * S3 nodes). + */ + def mergeSchemasInParallel( + filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = { + val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) + + // HACK ALERT: + // + // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es + // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` + // but only `Writable`. What makes it worse, for some reason, `FileStatus` doesn't play well + // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These + // facts virtually prevent us from serializing `FileStatus`es. + // + // Since Parquet only relies on path and length information of those `FileStatus`es to read + // footers, here we just extract them (which can be easily serialized), send them to executor + // side, and resemble fake `FileStatus`es there. + val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen)) + + // Issues a Spark job to read Parquet schema in parallel. 
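Before the schemas are merged, `mergeMissingNullableFields` above appends any nullable Metastore-only fields to the Parquet schema so the later case-insensitive zip can pair every Metastore field with a Parquet field. A standalone sketch with a tiny `Field` stand-in for `StructField` (all names and values below are illustrative):

object MetastoreParquetMergeSketch {
  // Tiny stand-in for StructField.
  case class Field(name: String, dataType: String, nullable: Boolean = true)

  // Nullable fields that exist only in the Metastore schema are appended to the
  // Parquet schema; matching is case-insensitive, mirroring the method above.
  def mergeMissingNullableFields(metastore: Seq[Field], parquet: Seq[Field]): Seq[Field] = {
    val fieldMap = metastore.map(f => f.name.toLowerCase -> f).toMap
    val missingFields = metastore
      .map(_.name.toLowerCase)
      .diff(parquet.map(_.name.toLowerCase))
      .map(fieldMap(_))
      .filter(_.nullable)
    parquet ++ missingFields
  }

  def main(args: Array[String]): Unit = {
    val metastore = Seq(Field("id", "bigint"), Field("name", "string"), Field("extra", "int"))
    val parquet = Seq(Field("ID", "bigint"), Field("Name", "binary"))
    // Prints List(Field(ID,bigint,true), Field(Name,binary,true), Field(extra,int,true)):
    // `extra` is nullable and missing from the Parquet schema, so it is merged back in.
    println(mergeMissingNullableFields(metastore, parquet))
  }
}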
+ val partiallyMergedSchemas = + sqlContext + .sparkContext + .parallelize(partialFileStatusInfo) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + // Skips row group information since we only need the schema + val skipRowGroups = true + + // Reads footers in multi-threaded manner within each task + val footers = + ParquetFileReader.readAllFootersInParallel( + serializedConf.value, fakeFileStatuses, skipRowGroups) + + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val converter = + new CatalystSchemaConverter( + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + footers.map { footer => + ParquetRelation.readSchemaFromFooter(footer, converter) + }.reduceOption(_ merge _).iterator + }.collect() + + partiallyMergedSchemas.reduceOption(_ merge _) + } + + /** + * Reads Spark SQL schema from a Parquet footer. If a valid serialized Spark SQL schema string + * can be found in the file metadata, returns the deserialized [[StructType]], otherwise, returns + * a [[StructType]] converted from the [[MessageType]] stored in this footer. + */ + def readSchemaFromFooter( + footer: Footer, converter: CatalystSchemaConverter): StructType = { + val fileMetaData = footer.getParquetMetadata.getFileMetaData + fileMetaData + .getKeyValueMetaData + .toMap + .get(CatalystReadSupport.SPARK_METADATA_KEY) + .flatMap(deserializeSchemaString) + .getOrElse(converter.convert(fileMetaData.getSchema)) + } + + private def deserializeSchemaString(schemaString: String): Option[StructType] = { + // Tries to deserialize the schema string as JSON first, then falls back to the case class + // string parser (data generated by older versions of Spark SQL uses this format). + Try(DataType.fromJson(schemaString).asInstanceOf[StructType]).recover { + case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] + }.recoverWith { + case cause: Throwable => + logWarning( + "Failed to parse and ignored serialized Spark schema in " + + s"Parquet key-value metadata:\n\t$schemaString", cause) + Failure(cause) + }.toOption + } def enableLogForwarding() { // Note: the org.apache.parquet.Log class has a static initializer that @@ -127,12 +782,6 @@ private[sql] object ParquetRelation { JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) } - // The element type for the RDDs that this relation maps to. - type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow - - // The compression type - type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName - // The parquet compression short names val shortParquetCompressionCodecNames = Map( "NONE" -> CompressionCodecName.UNCOMPRESSED, @@ -140,82 +789,4 @@ private[sql] object ParquetRelation { "SNAPPY" -> CompressionCodecName.SNAPPY, "GZIP" -> CompressionCodecName.GZIP, "LZO" -> CompressionCodecName.LZO) - - /** - * Creates a new ParquetRelation and underlying Parquetfile for the given LogicalPlan. 
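The `deserializeSchemaString` helper above layers its fallbacks with `Try`: the JSON form is tried first, then the deprecated case-class-string format, and finally it gives up with `None`. A sketch of the same shape with hypothetical stand-in parsers in place of `DataType.fromJson` and `DataType.fromCaseClassString`:

import scala.util.{Failure, Try}

object SchemaStringFallbackSketch {
  // Hypothetical stand-ins for DataType.fromJson and DataType.fromCaseClassString.
  def parseJson(s: String): String =
    if (s.startsWith("{")) "schema-from-json" else sys.error("not JSON")
  def parseLegacy(s: String): String =
    if (s.startsWith("StructType(")) "schema-from-legacy" else sys.error("not legacy format")

  def deserializeSchemaString(schemaString: String): Option[String] =
    Try(parseJson(schemaString)).recover {
      // First fallback: the deprecated case-class-string format.
      case _: Throwable => parseLegacy(schemaString)
    }.recoverWith {
      // Final fallback: give up (the real code logs a warning here); toOption yields None.
      case cause: Throwable => Failure(cause)
    }.toOption

  def main(args: Array[String]): Unit = {
    println(deserializeSchemaString("""{"type":"struct"}"""))  // Some(schema-from-json)
    println(deserializeSchemaString("StructType(List())"))     // Some(schema-from-legacy)
    println(deserializeSchemaString("garbage"))                // None
  }
}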
Note that - * this is used inside [[org.apache.spark.sql.execution.SparkStrategies SparkStrategies]] to - * create a resolved relation as a data sink for writing to a Parquetfile. The relation is empty - * but is initialized with ParquetMetadata and can be inserted into. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param child The child node that will be used for extracting the schema. - * @param conf A configuration to be used. - * @return An empty ParquetRelation with inferred metadata. - */ - def create(pathString: String, - child: LogicalPlan, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - if (!child.resolved) { - throw new UnresolvedException[LogicalPlan]( - child, - "Attempt to create Parquet table from unresolved child (when schema is not available)") - } - createEmpty(pathString, child.output, false, conf, sqlContext) - } - - /** - * Creates an empty ParquetRelation and underlying Parquetfile that only - * consists of the Metadata for the given schema. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param attributes The schema of the relation. - * @param conf A configuration to be used. - * @return An empty ParquetRelation. - */ - def createEmpty(pathString: String, - attributes: Seq[Attribute], - allowExisting: Boolean, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - val path = checkPath(pathString, allowExisting, conf) - conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED) - .name()) - ParquetRelation.enableLogForwarding() - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. - val schema = StructType.fromAttributes(attributes).asNullable - val newAttributes = schema.toAttributes - ParquetTypesConverter.writeMetaData(newAttributes, path, conf) - new ParquetRelation(path.toString, Some(conf), sqlContext) { - override val output = newAttributes - } - } - - private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = { - if (pathStr == null) { - throw new IllegalArgumentException("Unable to create ParquetRelation: path is null") - } - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to create ParquetRelation: incorrectly formatted path $pathStr") - } - val path = origPath.makeQualified(fs) - if (!allowExisting && fs.exists(path)) { - sys.error(s"File $pathStr already exists.") - } - - if (fs.exists(path) && - !fs.getFileStatus(path) - .getPermission - .getUserAction - .implies(FsAction.READ_WRITE)) { - throw new IOException( - s"Unable to create ParquetRelation: path $path not read-writable") - } - path - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala deleted file mode 100644 index 28cba5e54d69e..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ /dev/null @@ -1,492 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.IOException -import java.text.{NumberFormat, SimpleDateFormat} -import java.util.concurrent.TimeUnit -import java.util.Date - -import scala.collection.JavaConversions._ -import scala.util.Try - -import com.google.common.cache.CacheBuilder -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.api.ReadSupport -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.schema.MessageType - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, _} -import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.util.SerializableConfiguration - -/** - * :: DeveloperApi :: - * Parquet table scan operator. Imports the file that backs the given - * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[InternalRow]``. - */ -private[sql] case class ParquetTableScan( - attributes: Seq[Attribute], - relation: ParquetRelation, - columnPruningPred: Seq[Expression]) - extends LeafNode { - - // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes - // by exprId. note: output cannot be transient, see - // https://issues.apache.org/jira/browse/SPARK-1367 - val output = attributes.map(relation.attributeMap) - - // A mapping of ordinals partitionRow -> finalOutput. 
- val requestedPartitionOrdinals = { - val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex) - - attributes.zipWithIndex.flatMap { - case (attribute, finalOrdinal) => - partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal) - } - }.toArray - - protected override def doExecute(): RDD[InternalRow] = { - import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat - - val sc = sqlContext.sparkContext - val job = new Job(sc.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - - val conf: Configuration = ContextUtil.getConfiguration(job) - - relation.path.split(",").foreach { curPath => - val qualifiedPath = { - val path = new Path(curPath) - path.getFileSystem(conf).makeQualified(path) - } - NewFileInputFormat.addInputPath(job, qualifiedPath) - } - - // Store both requested and original schema in `Configuration` - conf.set( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(output)) - conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(relation.output)) - - // Store record filtering predicate in `Configuration` - // Note 1: the input format ignores all predicates that cannot be expressed - // as simple column predicate filters in Parquet. Here we just record - // the whole pruning predicate. - ParquetFilters - .createRecordFilter(columnPruningPred) - .map(_.asInstanceOf[FilterPredicateCompat].getFilterPredicate) - // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering - .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.setBoolean( - SQLConf.PARQUET_CACHE_METADATA.key, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, true)) - - // Use task side metadata in parquet - conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - - val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( - sc, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[InternalRow], - conf) - - if (requestedPartitionOrdinals.nonEmpty) { - // This check is based on CatalystConverter.createRootConverter. - val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) - - // Uses temporary variable to avoid the whole `ParquetTableScan` object being captured into - // the `mapPartitionsWithInputSplit` closure below. - val outputSize = output.size - - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - // Convert the partitioning attributes into the correct types - val partitionRowValues = - relation.partitioningAttributes - .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) - - if (primitiveRow) { - new Iterator[InternalRow] { - def hasNext: Boolean = iter.hasNext - def next(): InternalRow = { - // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. - val row = iter.next()._2.asInstanceOf[SpecificMutableRow] - - // Parquet will leave partitioning columns empty, so we fill them in here. 
- var i = 0 - while (i < requestedPartitionOrdinals.size) { - row(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 - } - row - } - } - } else { - // Create a mutable row since we need to fill in values from partition columns. - val mutableRow = new GenericMutableRow(outputSize) - new Iterator[InternalRow] { - def hasNext: Boolean = iter.hasNext - def next(): InternalRow = { - // We are using CatalystGroupConverter and it returns a GenericRow. - // Since GenericRow is not mutable, we just cast it to a Row. - val row = iter.next()._2.asInstanceOf[InternalRow] - - var i = 0 - while (i < row.size) { - mutableRow(i) = row(i) - i += 1 - } - // Parquet will leave partitioning columns empty, so we fill them in here. - i = 0 - while (i < requestedPartitionOrdinals.size) { - mutableRow(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 - } - mutableRow - } - } - } - } - } else { - baseRDD.map(_._2) - } - } - - /** - * Applies a (candidate) projection. - * - * @param prunedAttributes The list of attributes to be used in the projection. - * @return Pruned TableScan. - */ - def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { - val success = validateProjection(prunedAttributes) - if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred) - } else { - sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") - } - } - - /** - * Evaluates a candidate projection by checking whether the candidate is a subtype - * of the original type. - * - * @param projection The candidate projection. - * @return True if the projection is valid, false otherwise. - */ - private def validateProjection(projection: Seq[Attribute]): Boolean = { - val original: MessageType = relation.parquetSchema - val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) - Try(original.checkContains(candidate)).isSuccess - } -} - -/** - * :: DeveloperApi :: - * Operator that acts as a sink for queries on RDDs and can be used to - * store the output inside a directory of Parquet files. This operator - * is similar to Hive's INSERT INTO TABLE operation in the sense that - * one can choose to either overwrite or append to a directory. Note - * that consecutive insertions to the same table must have compatible - * (source) schemas. - * - * WARNING: EXPERIMENTAL! InsertIntoParquetTable with overwrite=false may - * cause data corruption in the case that multiple users try to append to - * the same table simultaneously. Inserting into a table that was - * previously generated by other means (e.g., by creating an HDFS - * directory and importing Parquet files generated by other tools) may - * cause unpredicted behaviour and therefore results in a RuntimeException - * (only detected via filename pattern so will not catch all cases). - */ -@DeveloperApi -private[sql] case class InsertIntoParquetTable( - relation: ParquetRelation, - child: SparkPlan, - overwrite: Boolean = false) - extends UnaryNode with SparkHadoopMapReduceUtil { - - /** - * Inserts all rows into the Parquet file. - */ - protected override def doExecute(): RDD[InternalRow] = { - // TODO: currently we do not check whether the "schema"s are compatible - // That means if one first creates a table and then INSERTs data with - // and incompatible schema the execution will fail. It would be nice - // to catch this early one, maybe having the planner validate the schema - // before calling execute(). 
- - val childRdd = child.execute() - assert(childRdd != null) - - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - - val writeSupport = - if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] - } else { - classOf[org.apache.spark.sql.parquet.RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - - val conf = ContextUtil.getConfiguration(job) - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. - val schema = StructType.fromAttributes(relation.output).asNullable - RowWriteSupport.setSchema(schema.toAttributes, conf) - - val fspath = new Path(relation.path) - val fs = fspath.getFileSystem(conf) - - if (overwrite) { - try { - fs.delete(fspath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${fspath.toString} prior" - + s" to InsertIntoParquetTable:\n${e.toString}") - } - } - saveAsHadoopFile(childRdd, relation.path.toString, conf) - - // We return the child RDD to allow chaining (alternatively, one could return nothing). - childRdd - } - - override def output: Seq[Attribute] = child.output - - /** - * Stores the given Row RDD as a Hadoop file. - * - * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]] - * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses - * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing - * directory and need to determine which was the largest written file index before starting to - * write. - * - * @param rdd The [[org.apache.spark.rdd.RDD]] to writer - * @param path The directory to write to. - * @param conf A [[org.apache.hadoop.conf.Configuration]]. 
- */ - private def saveAsHadoopFile( - rdd: RDD[InternalRow], - path: String, - conf: Configuration) { - val job = new Job(conf) - val keyType = classOf[Void] - job.setOutputKeyClass(keyType) - job.setOutputValueClass(classOf[InternalRow]) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableConfiguration(job.getConfiguration) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = sqlContext.sparkContext.newRddId() - - val taskIdOffset = - if (overwrite) { - 1 - } else { - FileSystemHelper - .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 - } - - def writeShard(context: TaskContext, iter: Iterator[InternalRow]): Int = { - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - context.attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = new AppendingParquetOutputFormat(taskIdOffset) - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext) - try { - while (iter.hasNext) { - val row = iter.next() - writer.write(null, row) - } - } finally { - writer.close(hadoopContext) - } - SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) - 1 - } - val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - sqlContext.sparkContext.runJob(rdd, writeShard _) - jobCommitter.commitJob(jobTaskContext) - } -} - -/** - * TODO: this will be able to append to directories it created itself, not necessarily - * to imported ones. - */ -private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends org.apache.parquet.hadoop.ParquetOutputFormat[InternalRow] { - // override to accept existing directories as valid output directory - override def checkOutputSpecs(job: JobContext): Unit = {} - var committer: OutputCommitter = null - - // override to choose output filename so not overwrite existing ones - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val taskId: TaskID = getTaskAttemptID(context).getTaskID - val partition: Int = taskId.getId - val filename = "part-r-" + numfmt.format(partition + offset) + ".parquet" - val committer: FileOutputCommitter = - getOutputCommitter(context).asInstanceOf[FileOutputCommitter] - new Path(committer.getWorkPath, filename) - } - - // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2. - // The signatures of the method TaskAttemptContext.getTaskAttemptID for the both versions - // are the same, so the method calls are source-compatible but NOT binary-compatible because - // the opcode of method call for class is INVOKEVIRTUAL and for interface is INVOKEINTERFACE. 
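The reflective `getTaskAttemptID` call in this removed code works around the class-vs-interface change of `TaskAttemptContext` between hadoop-1 and hadoop-2: resolving the method at runtime avoids hard-coding an INVOKEVIRTUAL or INVOKEINTERFACE opcode into the compiled bytecode. A toy sketch of the pattern; the `Ctx` class and the returned string are hypothetical.

object ReflectiveCallSketch {
  // Hypothetical stand-in; the real target is a class in hadoop-1 and an interface in hadoop-2.
  class Ctx {
    def getTaskAttemptID: String = "attempt_201507200000_0001_r_000000_0"
  }

  // Looking the method up by name at runtime means the call site is opcode-agnostic,
  // so one compiled jar can call either shape of the target type.
  def callGetTaskAttemptID(context: AnyRef): String =
    context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[String]

  def main(args: Array[String]): Unit = {
    println(callGetTaskAttemptID(new Ctx))
  }
}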
- private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { - context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] - } - - // override to create output committer from configuration - override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - if (committer == null) { - val output = getOutputPath(context) - val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", - classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) - val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] - } - committer - } - - // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 - private def getOutputPath(context: TaskAttemptContext): Path = { - context.getConfiguration().get("mapred.output.dir") match { - case null => null - case name => new Path(name) - } - } -} - -// TODO Removes this class after removing old Parquet support code -/** - * We extend ParquetInputFormat in order to have more control over which - * RecordFilter we want to use. - */ -private[parquet] class FilteringParquetRowInputFormat - extends org.apache.parquet.hadoop.ParquetInputFormat[InternalRow] with Logging { - - override def createRecordReader( - inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { - - import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter - - val readSupport: ReadSupport[InternalRow] = new RowReadSupport() - - val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) - if (!filter.isInstanceOf[NoOpFilter]) { - new ParquetRecordReader[InternalRow]( - readSupport, - filter) - } else { - new ParquetRecordReader[InternalRow](readSupport) - } - } - -} - -private[parquet] object FileSystemHelper { - def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"ParquetTableOperations: Path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (!fs.exists(path) || !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"ParquetTableOperations: path $path does not exist or is not a directory") - } - fs.globStatus(path) - .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } - .map(_.getPath) - } - - /** - * Finds the maximum taskid in the output file names at the given path. 
- */ - def findMaxTaskId(pathStr: String, conf: Configuration): Int = { - val files = FileSystemHelper.listFiles(pathStr, conf) - // filename pattern is part-r-.parquet - val nameP = new scala.util.matching.Regex("""part-.-(\d{1,}).*""", "taskid") - val hiddenFileP = new scala.util.matching.Regex("_.*") - files.map(_.getName).map { - case nameP(taskid) => taskid.toInt - case hiddenFileP() => 0 - case other: String => - sys.error("ERROR: attempting to append to set of Parquet files and found file" + - s"that does not match name pattern: $other") - case _ => 0 - }.reduceOption(_ max _).getOrElse(0) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e8851ddb68026..ec8da38a3d427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,19 +17,15 @@ package org.apache.spark.sql.parquet +import java.math.BigInteger import java.nio.{ByteBuffer, ByteOrder} -import java.util import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ - import org.apache.hadoop.conf.Configuration import org.apache.parquet.column.ParquetProperties import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{InitContext, ReadSupport, WriteSupport} +import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.io.api._ -import org.apache.parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -38,147 +34,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -/** - * A [[RecordMaterializer]] for Catalyst rows. - * - * @param parquetSchema Parquet schema of the records to be read - * @param catalystSchema Catalyst schema of the rows to be constructed - */ -private[parquet] class RowRecordMaterializer(parquetSchema: MessageType, catalystSchema: StructType) - extends RecordMaterializer[InternalRow] { - - private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) - - override def getCurrentRecord: InternalRow = rootConverter.currentRow - - override def getRootConverter: GroupConverter = rootConverter -} - -private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logging { - override def prepareForRead( - conf: Configuration, - keyValueMetaData: util.Map[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - log.debug(s"Preparing for read Parquet file with message type: $fileSchema") - - val toCatalyst = new CatalystSchemaConverter(conf) - val parquetRequestedSchema = readContext.getRequestedSchema - - val catalystRequestedSchema = - Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => - metadata - // First tries to read requested schema, which may result from projections - .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) - // If not available, tries to read Catalyst schema from file metadata. It's only - // available if the target file is written by Spark SQL. 
- .orElse(metadata.get(RowReadSupport.SPARK_METADATA_KEY)) - }.map(StructType.fromString).getOrElse { - logDebug("Catalyst schema not available, falling back to Parquet schema") - toCatalyst.convert(parquetRequestedSchema) - } - - logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema") - new RowRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) - } - - override def init(context: InitContext): ReadContext = { - val conf = context.getConfiguration - - // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst - // schema of this file from its the metadata. - val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA)) - - // Optional schema of requested columns, in the form of a string serialized from a Catalyst - // `StructType` containing all requested columns. - val maybeRequestedSchema = Option(conf.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - - // Below we construct a Parquet schema containing all requested columns. This schema tells - // Parquet which columns to read. - // - // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, - // we have to fallback to the full file schema which contains all columns in the file. - // Obviously this may waste IO bandwidth since it may read more columns than requested. - // - // Two things to note: - // - // 1. It's possible that some requested columns don't exist in the target Parquet file. For - // example, in the case of schema merging, the globally merged schema may contain extra - // columns gathered from other Parquet files. These columns will be simply filled with nulls - // when actually reading the target Parquet file. - // - // 2. When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to - // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to - // non-standard behaviors of some Parquet libraries/tools. For example, a Parquet file - // containing a single integer array field `f1` may have the following legacy 2-level - // structure: - // - // message root { - // optional group f1 (LIST) { - // required INT32 element; - // } - // } - // - // while `CatalystSchemaConverter` may generate a standard 3-level structure: - // - // message root { - // optional group f1 (LIST) { - // repeated group list { - // required INT32 element; - // } - // } - // } - // - // Apparently, we can't use the 2nd schema to read the target Parquet file as they have - // different physical structures. - val parquetRequestedSchema = - maybeRequestedSchema.fold(context.getFileSchema) { schemaString => - val toParquet = new CatalystSchemaConverter(conf) - val fileSchema = context.getFileSchema.asGroupType() - val fileFieldNames = fileSchema.getFields.map(_.getName).toSet - - StructType - // Deserializes the Catalyst schema of requested columns - .fromString(schemaString) - .map { field => - if (fileFieldNames.contains(field.name)) { - // If the field exists in the target Parquet file, extracts the field type from the - // full file schema and makes a single-field Parquet schema - new MessageType("root", fileSchema.getType(field.name)) - } else { - // Otherwise, just resorts to `CatalystSchemaConverter` - toParquet.convert(StructType(Array(field))) - } - } - // Merges all single-field Parquet schemas to form a complete schema for all requested - // columns. 
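The per-column assembly above (reuse the file's own Parquet node when the requested column exists in the file, otherwise convert the Catalyst field) can be pictured with a toy model. In the sketch below, plain Map[String, String] values stand in for Parquet MessageTypes, and the INT32 fallback for the missing column is purely illustrative:

    object RequestedSchemaSketch {
      type MiniSchema = Map[String, String]   // fieldName -> physical type

      def main(args: Array[String]): Unit = {
        val fileSchema: MiniSchema = Map("a" -> "INT32", "b" -> "BINARY")
        val requestedColumns = Seq("a", "c")  // "c" exists only in the merged table schema

        val singleFieldSchemas: Seq[MiniSchema] = requestedColumns.map { name =>
          fileSchema.get(name) match {
            case Some(tpe) => Map(name -> tpe)      // reuse the file's physical layout
            case None      => Map(name -> "INT32")  // pretend this came from a Catalyst conversion
          }
        }

        // fold (not reduce) so an empty projection still yields a valid empty schema.
        val complete = singleFieldSchemas.fold(Map.empty: MiniSchema)(_ ++ _)
        println(complete)   // Map(a -> INT32, c -> INT32)
      }
    }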
Note that it's possible that no columns are requested at all (e.g., count - // some partition column of a partitioned Parquet table). That's why `fold` is used here - // and always fallback to an empty Parquet schema. - .fold(new MessageType("root")) { - _ union _ - } - } - - val metadata = - Map.empty[String, String] ++ - maybeRequestedSchema.map(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ - maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - - logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema") - new ReadContext(parquetRequestedSchema, metadata) - } -} - -private[parquet] object RowReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" - - private def getRequestedSchema(configuration: Configuration): Seq[Attribute] = { - val schemaString = configuration.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) - if (schemaString == null) null else ParquetTypesConverter.convertFromString(schemaString) - } -} - /** * A `parquet.hadoop.api.WriteSupport` for Row objects. */ @@ -190,7 +45,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo override def init(configuration: Configuration): WriteSupport.WriteContext = { val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) val metadata = new JHashMap[String, String]() - metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) + metadata.put(CatalystReadSupport.SPARK_METADATA_KEY, origAttributesStr) if (attributes == null) { attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray @@ -208,18 +63,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 writer.startMessage() while(index < attributesSize) { // null values indicate optional fields but we do not check currently - if (record(index) != null) { + if (!record.isNullAt(index)) { writer.startField(attributes(index).name, index) - writeValue(attributes(index).dataType, record(index)) + writeValue(attributes(index).dataType, record.get(index, attributes(index).dataType)) writer.endField(attributes(index).name, index) } index = index + 1 @@ -260,11 +115,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) + case DecimalType.Fixed(precision, _) => + writeDecimal(value.asInstanceOf[Decimal], precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -277,10 +129,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo val fields = schema.fields.toArray writer.startGroup() var i = 0 - 
while(i < fields.size) { - if (struct(i) != null) { + while(i < fields.length) { + if (!struct.isNullAt(i)) { writer.startField(fields(i).name, i) - writeValue(fields(i).dataType, struct(i)) + writeValue(fields(i).dataType, struct.get(i, fields(i).dataType)) writer.endField(fields(i).name, i) } i = i + 1 @@ -294,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo array: CatalystConverter.ArrayScalaType[_]): Unit = { val elementType = schema.elementType writer.startGroup() - if (array.size > 0) { + if (array.numElements() > 0) { if (schema.containsNull) { writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { + while (i < array.numElements()) { writer.startGroup() - if (array(i) != null) { + if (!array.isNullAt(i)) { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - writeValue(elementType, array(i)) + writeValue(elementType, array.get(i)) writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } writer.endGroup() @@ -312,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo } else { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { - writeValue(elementType, array(i)) + while (i < array.numElements()) { + writeValue(elementType, array.get(i)) i = i + 1 } writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) @@ -345,20 +197,47 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo writer.endGroup() } - // Scratch array used to write decimals as fixed-length binary - private[this] val scratchBytes = new Array[Byte](8) + // Scratch array used to write decimals as fixed-length byte array + private[this] var reusableDecimalBytes = new Array[Byte](16) private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { - val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) - val unscaledLong = decimal.toUnscaledLong - var i = 0 - var shift = 8 * (numBytes - 1) - while (i < numBytes) { - scratchBytes(i) = (unscaledLong >> shift).toByte - i += 1 - shift -= 8 + val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision) + + def longToBinary(unscaled: Long): Binary = { + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + reusableDecimalBytes(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) } - writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) + + def bigIntegerToBinary(unscaled: BigInteger): Binary = { + unscaled.toByteArray match { + case bytes if bytes.length == numBytes => + Binary.fromByteArray(bytes) + + case bytes if bytes.length <= reusableDecimalBytes.length => + val signedByte = (if (bytes.head < 0) -1 else 0).toByte + java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte) + System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length) + Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) + + case bytes => + reusableDecimalBytes = new Array[Byte](bytes.length) + bigIntegerToBinary(unscaled) + } + } + + val binary = if (numBytes <= 8) { + longToBinary(decimal.toUnscaledLong) + } else { + bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue()) + } + + writer.addBinary(binary) } // array used to write Timestamp as Int96 (fixed-length binary) @@ -378,16 +257,16 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with 
Lo private[parquet] class MutableRowWriteSupport extends RowWriteSupport { override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 writer.startMessage() while(index < attributesSize) { // null values indicate optional fields but we do not check currently - if (record(index) != null && record(index) != Nil) { + if (!record.isNullAt(index) && !record.isNullAt(index)) { writer.startField(attributes(index).name, index) consumeType(attributes(index).dataType, record, index) writer.endField(attributes(index).name, index) @@ -410,15 +289,12 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case TimestampType => writeTimestamp(record.getLong(index)) case FloatType => writer.addFloat(record.getFloat(index)) case DoubleType => writer.addDouble(record.getDouble(index)) - case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) + case StringType => + writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) + case BinaryType => + writer.addBinary(Binary.fromByteArray(record.getBinary(index))) + case DecimalType.Fixed(precision, scale) => + writeDecimal(record.getDecimal(index, precision, scale), precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } @@ -443,4 +319,3 @@ private[parquet] object RowWriteSupport { ParquetProperties.WriterVersion.PARQUET_1_0.toString) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e748bd7857bd8..3854f5bd39fb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -53,15 +53,6 @@ private[parquet] object ParquetTypesConverter extends Logging { length } - def convertToAttributes( - parquetSchema: MessageType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - val converter = new CatalystSchemaConverter( - isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false) - converter.convert(parquetSchema).toAttributes - } - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { val converter = new CatalystSchemaConverter() converter.convert(StructType.fromAttributes(attributes)) @@ -103,7 +94,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } val extraMetadata = new java.util.HashMap[String, String]() extraMetadata.put( - RowReadSupport.SPARK_METADATA_KEY, + CatalystReadSupport.SPARK_METADATA_KEY, ParquetTypesConverter.convertToString(attributes)) // TODO: add extra data, e.g., table name, date, etc.? 
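The writeDecimal changes above encode a decimal's unscaled value as a sign-extended, big-endian byte array whose fixed width is derived from the declared precision. A self-contained sketch of that encoding (not the Spark code itself; the main method is only an illustration):

    import java.math.BigInteger

    object DecimalBinarySketch {
      // Smallest signed byte width that can hold every unscaled value of the given precision.
      def minBytesForPrecision(precision: Int): Int = {
        val maxUnscaled = BigInteger.TEN.pow(precision).subtract(BigInteger.ONE)
        var numBytes = 1
        // bitLength() excludes the sign bit, so a signed n-byte field holds values
        // whose bitLength is at most 8 * n - 1.
        while (maxUnscaled.bitLength > 8 * numBytes - 1) numBytes += 1
        numBytes
      }

      // Big-endian two's-complement, padded on the left with the sign byte.
      def toFixedLengthBytes(unscaled: BigInteger, numBytes: Int): Array[Byte] = {
        val raw = unscaled.toByteArray   // minimal two's-complement representation
        require(raw.length <= numBytes, s"needs ${raw.length} bytes, field only has $numBytes")
        val out = new Array[Byte](numBytes)
        val pad = (if (unscaled.signum < 0) -1 else 0).toByte
        java.util.Arrays.fill(out, 0, numBytes - raw.length, pad)
        System.arraycopy(raw, 0, out, numBytes - raw.length, raw.length)
        out
      }

      def main(args: Array[String]): Unit = {
        val width = minBytesForPrecision(10)   // 5 bytes are enough for precision 10
        val bytes = toFixedLengthBytes(new BigInteger("-12345"), width)
        println(bytes.map(b => f"${b & 0xff}%02x").mkString(" "))   // ff ff ff cf c7
      }
    }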
@@ -165,35 +156,4 @@ private[parquet] object ParquetTypesConverter extends Logging { .getOrElse( throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) } - - /** - * Reads in Parquet Metadata from the given path and tries to extract the schema - * (Catalyst attributes) from the application-specific key-value map. If this - * is empty it falls back to converting from the Parquet file schema which - * may lead to an upcast of types (e.g., {byte, short} to int). - * - * @param origPath The path at which we expect one (or more) Parquet files. - * @param conf The Hadoop configuration to use. - * @return A list of attributes that make up the schema. - */ - def readSchemaFromFile( - origPath: Path, - conf: Option[Configuration], - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - val keyValueMetadata: java.util.Map[String, String] = - readMetaData(origPath, conf) - .getFileMetaData - .getKeyValueMetaData - if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { - convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) - } else { - val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema, - isBinaryAsString, - isInt96AsTimestamp) - log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") - attributes - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala deleted file mode 100644 index 2f9f880c70690..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ /dev/null @@ -1,722 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.net.URI -import java.util.{List => JList} - -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.{Failure, Try} - -import com.google.common.base.Objects -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.schema.MessageType - -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDD._ -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} -import org.apache.spark.sql.execution.datasources.PartitionSpec -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} - - -private[sql] class DefaultSource extends HadoopFsRelationProvider { - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - schema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation2(paths, schema, None, partitionColumns, parameters)(sqlContext) - } -} - -// NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriter { - - private val recordWriter: RecordWriter[Void, InternalRow] = { - val outputFormat = { - new ParquetOutputFormat[InternalRow]() { - // Here we override `getDefaultWorkFile` for two reasons: - // - // 1. To allow appending. We need to generate unique output file names to avoid - // overwriting existing files (either exist before the write job, or are just written - // by other tasks within the same write job). - // - // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses - // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all - // partitions in the case of dynamic partitioning. - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") - } - } - } - - outputFormat.getRecordWriter(context) - } - - override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) - - override def close(): Unit = recordWriter.close(context) -} - -private[sql] class ParquetRelation2( - override val paths: Array[String], - private val maybeDataSchema: Option[StructType], - // This is for metastore conversion. 
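The getDefaultWorkFile override in the removed ParquetOutputWriter above keeps appending jobs and dynamically partitioned writes from colliding by combining the task's partition id with a per-job UUID. A tiny sketch of just the naming scheme (the UUID and extension here are arbitrary examples):

    import java.util.UUID

    object UniquePartFileSketch {
      // The partition id keeps names unique within one job; the job UUID keeps them
      // unique across jobs that append to the same directory.
      def workFileName(split: Int, jobUUID: String, extension: String): String =
        f"part-r-$split%05d-$jobUUID$extension"

      def main(args: Array[String]): Unit = {
        val jobUUID = UUID.randomUUID().toString
        println(workFileName(3, jobUUID, ".gz.parquet"))   // e.g. part-r-00003-<uuid>.gz.parquet
      }
    }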
- private val maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - parameters: Map[String, String])( - val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) - with Logging { - - private[sql] def this( - paths: Array[String], - maybeDataSchema: Option[StructType], - maybePartitionSpec: Option[PartitionSpec], - parameters: Map[String, String])( - sqlContext: SQLContext) = { - this( - paths, - maybeDataSchema, - maybePartitionSpec, - maybePartitionSpec.map(_.partitionColumns), - parameters)(sqlContext) - } - - // Should we merge schemas from all Parquet part-files? - private val shouldMergeSchemas = - parameters - .get(ParquetRelation2.MERGE_SCHEMA) - .map(_.toBoolean) - .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) - - private val maybeMetastoreSchema = parameters - .get(ParquetRelation2.METASTORE_SCHEMA) - .map(DataType.fromJson(_).asInstanceOf[StructType]) - - private lazy val metadataCache: MetadataCache = { - val meta = new MetadataCache - meta.refresh() - meta - } - - override def equals(other: Any): Boolean = other match { - case that: ParquetRelation2 => - val schemaEquality = if (shouldMergeSchemas) { - this.shouldMergeSchemas == that.shouldMergeSchemas - } else { - this.dataSchema == that.dataSchema && - this.schema == that.schema - } - - this.paths.toSet == that.paths.toSet && - schemaEquality && - this.maybeDataSchema == that.maybeDataSchema && - this.partitionColumns == that.partitionColumns - - case _ => false - } - - override def hashCode(): Int = { - if (shouldMergeSchemas) { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - maybeDataSchema, - partitionColumns) - } else { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - dataSchema, - schema, - maybeDataSchema, - partitionColumns) - } - } - - /** Constraints on schema of dataframe to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to parquet format") - } - } - - override def dataSchema: StructType = { - val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) - // check if schema satisfies the constraints - // before moving forward - checkConstraints(schema) - schema - } - - override private[sql] def refresh(): Unit = { - super.refresh() - metadataCache.refresh() - } - - // Parquet data source always uses Catalyst internal representations. 
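The checkConstraints helper above rejects duplicate column names before anything is written. The grouping idiom it relies on is easy to exercise on its own; the column names below are hypothetical:

    object DuplicateColumnCheckSketch {
      def duplicateColumns(fieldNames: Seq[String]): Seq[String] =
        fieldNames.groupBy(identity).collect {
          case (name, occurrences) if occurrences.length > 1 => name
        }.toSeq

      def main(args: Array[String]): Unit = {
        println(duplicateColumns(Seq("id", "name", "id", "ts")))   // List(id)
        println(duplicateColumns(Seq("a", "b")))                   // List()
      }
    }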
- override val needConversion: Boolean = false - - override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) - - val committerClass = - conf.getClass( - SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, - classOf[ParquetOutputCommitter], - classOf[ParquetOutputCommitter]) - - if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { - logInfo("Using default output committer for Parquet: " + - classOf[ParquetOutputCommitter].getCanonicalName) - } else { - logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) - } - - conf.setClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, - committerClass, - classOf[ParquetOutputCommitter]) - - // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override - // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why - // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is - // bundled with `ParquetOutputFormat[Row]`. - job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - - // TODO There's no need to use two kinds of WriteSupport - // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and - // complex types. - val writeSupportClass = - if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) - RowWriteSupport.setSchema(dataSchema.toAttributes, conf) - - // Sets compression scheme - conf.set( - ParquetOutputFormat.COMPRESSION, - ParquetRelation - .shortParquetCompressionCodecNames - .getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, - CompressionCodecName.UNCOMPRESSED).name()) - - new OutputWriterFactory { - override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) - } - } - } - - override def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) - val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown - val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString - val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec - - // Create the function to set variable Parquet confs at both driver and executor side. - val initLocalJobFuncOpt = - ParquetRelation2.initializeLocalJobFunc( - requiredColumns, - filters, - dataSchema, - useMetadataCache, - parquetFilterPushDown, - assumeBinaryIsString, - assumeInt96IsTimestamp, - followParquetFormatSpec) _ - - // Create the function to set input paths at the driver side. 
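The initializeLocalJobFunc(...) _ call above is ordinary eta-expansion: the per-query settings are captured immediately and the resulting Job => Unit is applied later, on the driver or on each executor. A standalone sketch of the pattern with a made-up Job class and made-up setting names:

    object JobInitSketch {
      final case class Job(var settings: Map[String, String] = Map.empty)

      def initializeLocalJobFunc(requiredColumns: Seq[String], pushDown: Boolean)(job: Job): Unit = {
        job.settings += ("columns" -> requiredColumns.mkString(","))
        job.settings += ("filterPushdown" -> pushDown.toString)
      }

      def main(args: Array[String]): Unit = {
        // The trailing underscore turns the partially applied method into a plain Job => Unit.
        val initLocalJobFunc: Job => Unit = initializeLocalJobFunc(Seq("a", "b"), pushDown = true) _
        val job = Job()
        initLocalJobFunc(job)
        println(job.settings)   // Map(columns -> a,b, filterPushdown -> true)
      }
    }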
- val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ - - Utils.withDummyCallSite(sqlContext.sparkContext) { - new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, - broadcastedConf = broadcastedConf, - initDriverSideJobFuncOpt = Some(setInputPaths), - initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = classOf[ParquetInputFormat[InternalRow]], - keyClass = classOf[Void], - valueClass = classOf[InternalRow]) { - - val cacheMetadata = useMetadataCache - - @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as '/' - // (which does happen in some S3N credentials), we need to use the string returned by the - // URI of the path to create a new Path. - val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) - new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) - }.toSeq - - private def escapePathUserInfo(path: Path): Path = { - val uri = path.toUri - new Path(new URI( - uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, - uri.getQuery, uri.getFragment)) - } - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = new ParquetInputFormat[InternalRow] { - override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) - } - } - - val jobContext = newJobContext(getConf(isDriverSide = true), jobId) - val rawSplits = inputFormat.getSplits(jobContext) - - Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - } - }.values.map(_.asInstanceOf[Row]) - } - } - - private class MetadataCache { - // `FileStatus` objects of all "_metadata" files. - private var metadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all "_common_metadata" files. - private var commonMetadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all data files (Parquet part-files). - var dataStatuses: Array[FileStatus] = _ - - // Schema of the actual Parquet files, without partition columns discovered from partition - // directory paths. - var dataSchema: StructType = null - - // Schema of the whole table, including partition columns. - var schema: StructType = _ - - /** - * Refreshes `FileStatus`es, footers, partition spec, and table schema. - */ - def refresh(): Unit = { - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. - val leaves = cachedLeafStatuses().filter { f => - isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - }.toArray - - dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) - metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) - commonMetadataStatuses = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - - // If we already get the schema, don't need to re-compute it since the schema merging is - // time-consuming. 
- if (dataSchema == null) { - dataSchema = { - val dataSchema0 = maybeDataSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(throw new AnalysisException( - s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + - paths.mkString("\n\t"))) - - // If this Parquet relation is converted from a Hive Metastore table, must reconcile case - // case insensitivity issue and possible schema mismatch (probably caused by schema - // evolution). - maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) - .getOrElse(dataSchema0) - } - } - } - - private def isSummaryFile(file: Path): Boolean = { - file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || - file.getName == ParquetFileWriter.PARQUET_METADATA_FILE - } - - private def readSchema(): Option[StructType] = { - // Sees which file(s) we need to touch in order to figure out the schema. - // - // Always tries the summary files first if users don't require a merged schema. In this case, - // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row - // groups information, and could be much smaller for large Parquet files with lots of row - // groups. If no summary file is available, falls back to some random part-file. - // - // NOTE: Metadata stored in the summary files are merged from all part-files. However, for - // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know - // how to merge them correctly if some key is associated with different values in different - // part-files. When this happens, Parquet simply gives up generating the summary file. This - // implies that if a summary file presents, then: - // - // 1. Either all part-files have exactly the same Spark SQL schema, or - // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus - // their schemas may differ from each other). - // - // Here we tend to be pessimistic and take the second case into account. Basically this means - // we can't trust the summary files if users require a merged schema, and must touch all part- - // files to do the merge. - val filesToTouch = - if (shouldMergeSchemas) { - // Also includes summary files, 'cause there might be empty partition directories. - (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq - } else { - // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet - // don't have this. - commonMetadataStatuses.headOption - // Falls back to "_metadata" - .orElse(metadataStatuses.headOption) - // Summary file(s) not found, the Parquet file is either corrupted, or different part- - // files contain conflicting user defined metadata (two or more values are associated - // with a same key in different files). In either case, we fall back to any of the - // first part-file, and just assume all schemas are consistent. - .orElse(dataStatuses.headOption) - .toSeq - } - - assert( - filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, - "No predefined schema found, " + - s"and no Parquet data files or summary files found under ${paths.mkString(", ")}.") - - ParquetRelation2.mergeSchemasInParallel(filesToTouch, sqlContext) - } - } -} - -private[sql] object ParquetRelation2 extends Logging { - // Whether we should merge schemas collected from all Parquet part-files. - private[sql] val MERGE_SCHEMA = "mergeSchema" - - // Hive Metastore schema, used when converting Metastore Parquet tables. 
This option is only used - // internally. - private[sql] val METASTORE_SCHEMA = "metastoreSchema" - - /** This closure sets various Parquet configurations at both driver side and executor side. */ - private[parquet] def initializeLocalJobFunc( - requiredColumns: Array[String], - filters: Array[Filter], - dataSchema: StructType, - useMetadataCache: Boolean, - parquetFilterPushDown: Boolean, - assumeBinaryIsString: Boolean, - assumeInt96IsTimestamp: Boolean, - followParquetFormatSpec: Boolean)(job: Job): Unit = { - val conf = job.getConfiguration - conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName) - - // Try to push down filters when filter push-down is enabled. - if (parquetFilterPushDown) { - filters - // Collects all converted Parquet filter predicates. Notice that not all predicates can be - // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` - // is used here. - .flatMap(ParquetFilters.createFilter(dataSchema, _)) - .reduceOption(FilterApi.and) - .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) - } - - conf.set(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { - val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) - ParquetTypesConverter.convertToString(requestedSchema.toAttributes) - }) - - conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(dataSchema.toAttributes)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) - - // Sets flags for Parquet schema conversion - conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) - conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) - conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) - } - - /** This closure sets input paths at the driver side. */ - private[parquet] def initializeDriverSideJobFunc( - inputFiles: Array[FileStatus])(job: Job): Unit = { - // We side the input paths at the driver side. - logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}") - if (inputFiles.nonEmpty) { - FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) - } - } - - private[parquet] def readSchema( - footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { - - def parseParquetSchema(schema: MessageType): StructType = { - StructType.fromAttributes( - // TODO Really no need to use `Attribute` here, we only need to know the data type. - ParquetTypesConverter.convertToAttributes( - schema, - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp)) - } - - val seen = mutable.HashSet[String]() - val finalSchemas: Seq[StructType] = footers.flatMap { footer => - val metadata = footer.getParquetMetadata.getFileMetaData - val serializedSchema = metadata - .getKeyValueMetaData - .toMap - .get(RowReadSupport.SPARK_METADATA_KEY) - if (serializedSchema.isEmpty) { - // Falls back to Parquet schema if no Spark SQL schema found. - Some(parseParquetSchema(metadata.getSchema)) - } else if (!seen.contains(serializedSchema.get)) { - seen += serializedSchema.get - - // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to - // whatever is available. 
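Earlier in this hunk the convertible source filters are collected with flatMap and combined with reduceOption(FilterApi.and), so nothing is pushed down when no filter converts. The same shape, sketched with toy predicate types instead of Parquet's filter API:

    object FilterCombineSketch {
      sealed trait Pred
      final case class Gt(col: String, value: Int) extends Pred
      final case class And(left: Pred, right: Pred) extends Pred

      // Hypothetical converter: only columns the file actually has can be pushed down.
      def createFilter(knownCols: Set[String], filter: (String, Int)): Option[Pred] =
        if (knownCols(filter._1)) Some(Gt(filter._1, filter._2)) else None

      def main(args: Array[String]): Unit = {
        val filters = Seq(("a", 1), ("mystery", 2), ("b", 3))
        val pushed = filters
          .flatMap(f => createFilter(Set("a", "b"), f))   // drop the ones that cannot convert
          .reduceOption(And(_, _))                        // None when nothing converted
        println(pushed)   // Some(And(Gt(a,1),Gt(b,3)))
      }
    }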
- Some(Try(DataType.fromJson(serializedSchema.get)) - .recover { case _: Throwable => - logInfo( - s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + - "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(serializedSchema.get) - } - .recover { case cause: Throwable => - logWarning( - s"""Failed to parse serialized Spark schema in Parquet key-value metadata: - |\t$serializedSchema - """.stripMargin, - cause) - } - .map(_.asInstanceOf[StructType]) - .getOrElse { - // Falls back to Parquet schema if Spark SQL schema can't be parsed. - parseParquetSchema(metadata.getSchema) - }) - } else { - None - } - } - - finalSchemas.reduceOption { (left, right) => - try left.merge(right) catch { case e: Throwable => - throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) - } - } - } - - /** - * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore - * schema and Parquet schema. - * - * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the - * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't - * distinguish binary and string). This method generates a correct schema by merging Metastore - * schema data types and Parquet schema field names. - */ - private[parquet] def mergeMetastoreParquetSchema( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - def schemaConflictMessage: String = - s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: - |${metastoreSchema.prettyJson} - | - |Parquet schema: - |${parquetSchema.prettyJson} - """.stripMargin - - val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) - - assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) - - val ordinalMap = metastoreSchema.zipWithIndex.map { - case (field, index) => field.name.toLowerCase -> index - }.toMap - - val reorderedParquetSchema = mergedParquetSchema.sortBy(f => - ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) - - StructType(metastoreSchema.zip(reorderedParquetSchema).map { - // Uses Parquet field names but retains Metastore data types. - case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => - mSchema.copy(name = pSchema.name) - case _ => - throw new SparkException(schemaConflictMessage) - }) - } - - /** - * Returns the original schema from the Parquet file with any missing nullable fields from the - * Hive Metastore schema merged in. - * - * When constructing a DataFrame from a collection of structured data, the resulting object has - * a schema corresponding to the union of the fields present in each element of the collection. - * Spark SQL simply assigns a null value to any field that isn't present for a particular row. - * In some cases, it is possible that a given table partition stored as a Parquet file doesn't - * contain a particular nullable field in its schema despite that field being present in the - * table schema obtained from the Hive Metastore. This method returns a schema representing the - * Parquet file schema along with any additional nullable fields from the Metastore schema - * merged in. 
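The reconciliation above keeps the Metastore data types (they are authoritative) and the Parquet field names (they preserve case), after ordering the Parquet fields by their Metastore position. A toy version with a hypothetical Field(name, dataType) pair shows the effect:

    object MergeSchemaNamesSketch {
      final case class Field(name: String, dataType: String)

      def reconcile(metastore: Seq[Field], parquet: Seq[Field]): Seq[Field] = {
        val ordinal = metastore.map(_.name.toLowerCase).zipWithIndex.toMap
        val reordered = parquet.sortBy(f => ordinal.getOrElse(f.name.toLowerCase, metastore.size + 1))
        metastore.zip(reordered).map { case (m, p) =>
          require(m.name.toLowerCase == p.name.toLowerCase, s"conflicting schemas: $m vs $p")
          m.copy(name = p.name)   // Parquet casing, Metastore type
        }
      }

      def main(args: Array[String]): Unit = {
        val metastore = Seq(Field("userid", "bigint"), Field("name", "string"))
        val parquet   = Seq(Field("Name", "binary"), Field("UserId", "int64"))
        println(reconcile(metastore, parquet))
        // List(Field(UserId,bigint), Field(Name,string))
      }
    }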
- */ - private[parquet] def mergeMissingNullableFields( - metastoreSchema: StructType, - parquetSchema: StructType): StructType = { - val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap - val missingFields = metastoreSchema - .map(_.name.toLowerCase) - .diff(parquetSchema.map(_.name.toLowerCase)) - .map(fieldMap(_)) - .filter(_.nullable) - StructType(parquetSchema ++ missingFields) - } - - /** - * Figures out a merged Parquet schema with a distributed Spark job. - * - * Note that locality is not taken into consideration here because: - * - * 1. For a single Parquet part-file, in most cases the footer only resides in the last block of - * that file. Thus we only need to retrieve the location of the last block. However, Hadoop - * `FileSystem` only provides API to retrieve locations of all blocks, which can be - * potentially expensive. - * - * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty - * slow. And basically locality is not available when using S3 (you can't run computation on - * S3 nodes). - */ - def mergeSchemasInParallel( - filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = { - val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString - val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec - val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) - - // HACK ALERT: - // - // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es - // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` - // but only `Writable`. What makes it worth, for some reason, `FileStatus` doesn't play well - // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These - // facts virtually prevents us to serialize `FileStatus`es. - // - // Since Parquet only relies on path and length information of those `FileStatus`es to read - // footers, here we just extract them (which can be easily serialized), send them to executor - // side, and resemble fake `FileStatus`es there. - val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen)) - - // Issues a Spark job to read Parquet schema in parallel. - val partiallyMergedSchemas = - sqlContext - .sparkContext - .parallelize(partialFileStatusInfo) - .mapPartitions { iterator => - // Resembles fake `FileStatus`es with serialized path and length information. - val fakeFileStatuses = iterator.map { case (path, length) => - new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) - }.toSeq - - // Skips row group information since we only need the schema - val skipRowGroups = true - - // Reads footers in multi-threaded manner within each task - val footers = - ParquetFileReader.readAllFootersInParallel( - serializedConf.value, fakeFileStatuses, skipRowGroups) - - // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` - val converter = - new CatalystSchemaConverter( - assumeBinaryIsString = assumeBinaryIsString, - assumeInt96IsTimestamp = assumeInt96IsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) - - footers.map { footer => - ParquetRelation2.readSchemaFromFooter(footer, converter) - }.reduceOption(_ merge _).iterator - }.collect() - - partiallyMergedSchemas.reduceOption(_ merge _) - } - - /** - * Reads Spark SQL schema from a Parquet footer. 
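mergeSchemasInParallel above reduces footer schemas inside each task before a final reduce on the driver, so only one partial schema per partition travels back. The shape of that two-level merge, sketched with sets of field names standing in for real schemas:

    object TwoLevelMergeSketch {
      type Schema = Set[String]                  // field names stand in for a real schema
      def merge(a: Schema, b: Schema): Schema = a ++ b

      def main(args: Array[String]): Unit = {
        val footerSchemas: Seq[Schema] = Seq(Set("a"), Set("a", "b"), Set("c"), Set("a", "d"))

        // "mapPartitions": merge within each of two simulated partitions.
        val partial = footerSchemas.grouped(2).flatMap(_.reduceOption(merge)).toSeq

        // "collect" plus a driver-side merge across the partial results.
        println(partial.reduceOption(merge))   // Some(Set(a, b, c, d))
      }
    }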
If a valid serialized Spark SQL schema string - * can be found in the file metadata, returns the deserialized [[StructType]], otherwise, returns - * a [[StructType]] converted from the [[MessageType]] stored in this footer. - */ - def readSchemaFromFooter( - footer: Footer, converter: CatalystSchemaConverter): StructType = { - val fileMetaData = footer.getParquetMetadata.getFileMetaData - fileMetaData - .getKeyValueMetaData - .toMap - .get(RowReadSupport.SPARK_METADATA_KEY) - .flatMap(deserializeSchemaString) - .getOrElse(converter.convert(fileMetaData.getSchema)) - } - - private def deserializeSchemaString(schemaString: String): Option[StructType] = { - // Tries to deserialize the schema string as JSON first, then falls back to the case class - // string parser (data generated by older versions of Spark SQL uses this format). - Try(DataType.fromJson(schemaString).asInstanceOf[StructType]).recover { - case _: Throwable => - logInfo( - s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + - "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] - }.recoverWith { - case cause: Throwable => - logWarning( - "Failed to parse and ignored serialized Spark schema in " + - s"Parquet key-value metadata:\n\t$schemaString", cause) - Failure(cause) - }.toOption - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7cd005b959488..7126145ddc010 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.RDDConversions @@ -344,6 +344,18 @@ abstract class OutputWriter { def close(): Unit } +/** + * This is an internal, private version of [[OutputWriter]] with an writeInternal method that + * accepts an [[InternalRow]] rather than an [[Row]]. Data sources that return this must have + * the conversion flag set to false. + */ +private[sql] abstract class OutputWriterInternal extends OutputWriter { + + override def write(row: Row): Unit = throw new UnsupportedOperationException + + def writeInternal(row: InternalRow): Unit +} + /** * ::Experimental:: * A [[BaseRelation]] that provides much of the common code required for formats that store their @@ -581,6 +593,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio * * @since 1.4.0 */ + // TODO Tries to eliminate the extra Catalyst-to-Scala conversion when `needConversion` is true + // + // PR #7626 separated `Row` and `InternalRow` completely. One of the consequences is that we can + // no longer treat an `InternalRow` containing Catalyst values as a `Row`. Thus we have to + // introduce another row value conversion for data sources whose `needConversion` is true. def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { // Yeah, to workaround serialization... 
val dataSchema = this.dataSchema @@ -592,22 +609,34 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - val rdd = buildScan(inputFiles) - val converted = + val rdd: RDD[Row] = buildScan(inputFiles) + val converted: RDD[InternalRow] = if (needConversion) { RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } + converted.mapPartitions { rows => val buildProjection = if (codegenEnabled) { GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) } - val mutableProjection = buildProjection() - rows.map(r => mutableProjection(r).asInstanceOf[Row]) - } + + val projectedRows = { + val mutableProjection = buildProjection() + rows.map(r => mutableProjection(r)) + } + + if (needConversion) { + val requiredSchema = StructType(requiredColumns.map(dataSchema(_))) + val toScala = CatalystTypeConverters.createToScalaConverter(requiredSchema) + projectedRows.map(toScala(_).asInstanceOf[Row]) + } else { + projectedRows + } + }.asInstanceOf[RDD[Row]] } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index fcb8f5499cf84..cb84e78d628ca 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -22,7 +22,6 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.sql.test.TestSQLContext$; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -31,8 +30,14 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.*; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -159,7 +164,8 @@ public void applySchemaToJSON() { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List fields = new ArrayList(7); - fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18), + true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 72c42f4fe376b..2c669bb59a0b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -30,7 +30,6 @@ import scala.collection.JavaConversions; import scala.collection.Seq; -import scala.collection.mutable.Buffer; import java.io.Serializable; import java.util.Arrays; @@ -168,10 +167,10 @@ public void testCreateDataFrameFromJavaBeans() { for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } - Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); + Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer))); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { @@ -227,4 +226,13 @@ public void testCovariance() { Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1e-6); } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1f9f7118c3f04..eb64684ae0fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,16 +19,19 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest { +class ColumnExpressionSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("InputFileName") { + withTempPath { dir => + val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + .head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(inputFileName()).limit(1), Row("")) + } + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, @@ -523,6 +538,7 @@ class ColumnExpressionSuite extends QueryTest { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { case project: Project => project + case tungstenProject: TungstenProject => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
index b26d3ab253a1d..228ece8065151 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.types.{BinaryType, DecimalType} class DataFrameAggregateSuite extends QueryTest { @@ -191,4 +191,13 @@ class DataFrameAggregateSuite extends QueryTest { Row(null)) } + test("aggregation can't work on binary type") { + val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + intercept[AnalysisException] { + df.groupBy("c").agg(count("*")) + } + intercept[AnalysisException] { + df.distinct + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7ba4ba73e0cc9..07a675e64f527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -21,9 +21,9 @@ import java.util.Random import org.scalatest.Matchers._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col -class DataFrameStatSuite extends SparkFunSuite { +class DataFrameStatSuite extends QueryTest { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } + + test("sampleBy") { + val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f67f2c60c0e16..97beae2f85c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -23,40 +23,38 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ +import org.apache.spark.sql.json.JSONRelation +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} class DataFrameSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - def sqlContext: SQLContext = ctx + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ test("analysis error should be eagerly reported") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. 
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) - } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) - } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + } + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - testData.select('nonExistentName) - - // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { + testData.select('nonExistentName) + } } test("dataframe toString") { @@ -74,21 +72,18 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("invalid plan toString, debug mode") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - ctx.debug() - val badPlan = testData.select('badColumn) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + sqlContext.debug() - assert(badPlan.toString contains badPlan.queryExecution.toString, - "toString on bad query plans should include the query execution but was:\n" + - badPlan.toString) + val badPlan = testData.select('badColumn) - // Set the flag back to original value before this test. 
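The refactoring above replaces manual setConf/restore bookkeeping with SQLTestUtils.withSQLConf. A rough sketch of the pattern that helper encapsulates follows; the name and details are illustrative, not the actual implementation.

    import org.apache.spark.sql.SQLContext

    // Set the given SQL conf entries, run the body, then restore whatever values
    // (or absence of values) were in place before, even if the body throws.
    def withSQLConfSketch[T](ctx: SQLContext)(pairs: (String, String)*)(body: => T): T = {
      val keys = pairs.map(_._1)
      val previous = keys.map(k => k -> ctx.getAllConfs.get(k))
      pairs.foreach { case (k, v) => ctx.setConf(k, v) }
      try body finally {
        previous.foreach {
          case (k, Some(v)) => ctx.setConf(k, v)
          case (k, None)    => ctx.conf.unsetConf(k) // conf is package-private; usable from org.apache.spark.sql tests
        }
      }
    }

    // e.g. withSQLConfSketch(ctx)(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { ... }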
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) + } } test("access complex data") { @@ -104,8 +99,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("empty data frame") { - assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(ctx.emptyDataFrame.count() === 0) + assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(sqlContext.emptyDataFrame.count() === 0) } test("head and take") { @@ -341,7 +336,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("replace column using withColumn") { - val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -422,7 +417,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -491,6 +486,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) } + test("inputFiles") { + val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), + Some(testData.schema), None, Map.empty)(sqlContext) + val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) + assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + + val fakeRelation2 = new JSONRelation("/json/path", 1, Some(testData.schema), sqlContext) + val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) + assert(df2.inputFiles.toSet == fakeRelation2.path.toSet) + + val unionDF = df1.unionAll(df2) + assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + + val filtered = df1.filter("false").unionAll(df2.intersect(df2)) + assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + } + ignore("show") { // This test case is intended ignored, but to make sure it compiles correctly testData.select($"*").show() @@ -499,7 +511,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -589,21 +601,17 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() } - test("SPARK-6899") { - val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, true) - try{ + test("SPARK-6899: type should match when using codegen") { + 
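A short sketch of the DataFrame.inputFiles API covered by the new "inputFiles" test above. The paths are placeholders, and both directories are assumed to hold data with the same schema so the union is analyzable.

    // inputFiles surfaces the files backing a file-based relation (Parquet, JSON, ...).
    val parquetDF = sqlContext.read.parquet("/data/events-parquet")   // hypothetical path
    val jsonDF    = sqlContext.read.json("/data/events-json")         // hypothetical path

    val parquetFiles: Array[String] = parquetDF.inputFiles
    val combinedFiles: Array[String] = parquetDF.unionAll(jsonDF).inputFiles

    // Operators that don't change the underlying sources (filter, intersect, ...)
    // keep reporting the same input files, as the test above asserts.
    assert(parquetDF.filter("false").inputFiles.toSet == parquetFiles.toSet)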
withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) - } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -615,14 +623,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = ctx.read.json(ctx.sparkContext.makeRDD( + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = ctx.read.json(ctx.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -642,7 +650,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7324 dropDuplicates") { - val testData = ctx.sparkContext.parallelize( + val testData = sqlContext.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -690,49 +698,49 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = ctx.range(0, 10, 1, 15).select("id") + val res1 = sqlContext.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = ctx.range(3, 15, 3, 2).select("id") + val res2 = sqlContext.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = ctx.range(1, -2).select("id") + val res3 = sqlContext.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = ctx.range(1, -2, -2, 6).select("id") + val res4 = sqlContext.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = ctx.range(-3, -8, -2, 1).select("id") + val res5 = sqlContext.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = ctx.range(-8, -4, 2, 1).select("id") + val res6 = sqlContext.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = ctx.range(-10, -9, -20, 1).select("id") + val res7 = sqlContext.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = ctx.range(10).select("id") + val res10 = sqlContext.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = ctx.range(-1).select("id") + val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) } @@ -799,13 +807,13 @@ class DataFrameSuite 
extends QueryTest with SQLTestUtils { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = ctx.read.parquet(tempParquetFile.getCanonicalPath) + val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) pdf.registerTempTable("parquet_base") insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = ctx.read.json(tempJsonFile.getCanonicalPath) + val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) jdf.registerTempTable("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") @@ -825,11 +833,54 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - new DataFrame(ctx, OneRowRelation).registerTempTable("one_row") + new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } } + + test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + test("SPARK-8609: local DataFrame with random columns should return same value after sort") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF() + checkAnswer(df.sort(rand(33)), df.sort(rand(33))) + } + + test("SPARK-9083: sort with non-deterministic expressions") { + import org.apache.spark.util.random.XORShiftRandom + + val seed = 33 + val df = (1 to 100).map(Tuple1.apply).toDF("i") + val random = new XORShiftRandom(seed) + val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) + val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) + assert(expected === actual) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala new file mode 100644 index 0000000000000..bf8ef9a97bc60 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
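A Scala sketch of the behaviour pinned down by the SPARK-8608/8609/9083 tests above: with an explicit seed, rand(seed) is deterministic, so sorting by it twice yields the same ordering. Context and variable names are illustrative.

    import org.apache.spark.sql.functions.rand

    val sqlContext = org.apache.spark.sql.test.TestSQLContext
    import sqlContext.implicits._

    val df = (1 to 100).map(Tuple1.apply).toDF("i")

    // Sorting by a seeded rand column is a reproducible shuffle: the same seed
    // produces the same ordering on every evaluation of the plan.
    val order1 = df.sort(rand(33)).collect().map(_.getInt(0))
    val order2 = df.sort(rand(33)).collect().map(_.getInt(0))
    assert(order1.sameElements(order2))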
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +/** + * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode. + * + * This is here for now so I can make sure Tungsten project is tested without refactoring existing + * end-to-end test infra. In the long run this should just go away. + */ +class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { + + override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + + test("test simple types") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + } + } + + test("test struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) + } + } + + test("test nested struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 9e80ae86920d9..b7267c413165a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -20,13 +20,36 @@ package org.apache.spark.sql import java.sql.{Timestamp, Date} import java.text.SimpleDateFormat +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.unsafe.types.CalendarInterval class DateFunctionsSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + test("function current_date") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d1 = 
DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) + val d2 = DateTimeUtils.fromJavaDate( + ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) + } + + // This is a bad test. SPARK-9196 will fix it and re-enable it. + ignore("function current_timestamp") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value + checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + Row(true)) + assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + 0).getTime - System.currentTimeMillis()) < 5000) + } + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val sdfDate = new SimpleDateFormat("yyyy-MM-dd") val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) @@ -184,4 +207,198 @@ class DateFunctionsSuite extends QueryTest { Row(15, 15, 15)) } + test("function date_add") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_add(col("d"), 1)), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + checkAnswer( + df.select(date_add(col("t"), 3)), + Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05")))) + checkAnswer( + df.select(date_add(col("s"), 5)), + Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07")))) + checkAnswer( + df.select(date_add(col("ss"), 7)), + Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09")))) + + checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_ADD(d, 1)"""), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + } + + test("function date_sub") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_sub(col("d"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("t"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("s"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("ss"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(lit(null), 1)).limit(1), Row(null)) + + checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_SUB(d, 1)"""), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + } + + test("time_add") { + val t1 = Timestamp.valueOf("2015-07-31 23:59:59") + val t2 = Timestamp.valueOf("2015-12-31 00:00:00") + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-12-31") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, 
t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d + $i"), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2016-02-29")))) + checkAnswer( + df.selectExpr(s"t + $i"), + Seq(Row(Timestamp.valueOf("2015-10-01 00:00:01")), + Row(Timestamp.valueOf("2016-02-29 00:00:02")))) + } + + test("time_sub") { + val t1 = Timestamp.valueOf("2015-10-01 00:00:01") + val t2 = Timestamp.valueOf("2016-02-29 00:00:02") + val d1 = Date.valueOf("2015-09-30") + val d2 = Date.valueOf("2016-02-29") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d - $i"), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30")))) + checkAnswer( + df.selectExpr(s"t - $i"), + Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")), + Row(Timestamp.valueOf("2015-12-31 00:00:00")))) + } + + test("function add_months") { + val d1 = Date.valueOf("2015-08-31") + val d2 = Date.valueOf("2015-02-28") + val df = Seq((1, d1), (2, d2)).toDF("n", "d") + checkAnswer( + df.select(add_months(col("d"), 1)), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31")))) + checkAnswer( + df.selectExpr("add_months(d, -1)"), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31")))) + } + + test("function months_between") { + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-02-16") + val t1 = Timestamp.valueOf("2014-09-30 23:30:00") + val t2 = Timestamp.valueOf("2015-09-16 12:00:00") + val s1 = "2014-09-15 11:30:00" + val s2 = "2015-10-01 00:00:00" + val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") + checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + } + + test("function last_day") { + val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d") + val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t") + checkAnswer( + df1.select(last_day(col("d"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + checkAnswer( + df2.select(last_day(col("t"))), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31")))) + } + + test("function next_day") { + val df1 = Seq(("mon", "2015-07-23"), ("tuesday", "2015-07-20")).toDF("dow", "d") + val df2 = Seq(("th", "2015-07-23 00:11:22"), ("xx", "2015-07-24 11:22:33")).toDF("dow", "t") + checkAnswer( + df1.select(next_day(col("d"), "MONDAY")), + Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27")))) + checkAnswer( + df2.select(next_day(col("t"), "th")), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) + } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + 
Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + + test("unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala deleted file mode 100644 index 44b915304533c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
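A brief Scala sketch of the two conversions exercised by the from_unixtime and unix_timestamp tests above. Sample values are illustrative; from_unixtime turns epoch seconds into a formatted string (default pattern "yyyy-MM-dd HH:mm:ss"), while unix_timestamp parses timestamp columns or custom-formatted strings back into epoch seconds.

    import java.sql.Timestamp
    import org.apache.spark.sql.functions.{col, from_unixtime, unix_timestamp}

    val ctx = org.apache.spark.sql.test.TestSQLContext
    import ctx.implicits._

    val df = Seq((1000L, Timestamp.valueOf("2015-07-24 10:00:00"), "2015/07/24 10:00:00.5"))
      .toDF("epoch", "ts", "s")

    // Epoch seconds -> formatted string, with the default or an explicit pattern.
    df.select(from_unixtime(col("epoch")), from_unixtime(col("epoch"), "yyyy/MM/dd")).show()

    // Timestamp column -> epoch seconds, and custom-formatted string -> epoch seconds.
    df.select(unix_timestamp(col("ts")), unix_timestamp(col("s"), "yyyy/MM/dd HH:mm:ss.S")).show()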
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions._ - -class DatetimeExpressionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ - - lazy val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") - - test("function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) - val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) - val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) - } - - test("function current_timestamp") { - checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) - // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), - Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( - 0).getTime - System.currentTimeMillis()) < 5000) - } - -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8953889d1fae9..27c08f64649ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.types.BinaryType class JoinSuite extends QueryTest with BeforeAndAfterEach { @@ -79,9 +80,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashOuterJoin]), @@ -108,6 +109,18 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } } + test("SortMergeJoin shouldn't work on unsortable columns") { + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + } + test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") @@ -477,4 +490,12 @@ 
class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(3, 2) :: Nil) } + + test("Join can't work on binary type") { + val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType) + val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType) + intercept[AnalysisException] { + left.join(right, ($"left.N" === $"right.N"), "full") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 21256704a5b16..8cf2ef5957d8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest { checkAnswer( ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), - Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 7cc6ffd7548d0..01b7c21e84159 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -30,23 +30,24 @@ class RowSuite extends SparkFunSuite { test("create row") { val expected = new GenericMutableRow(4) - expected.update(0, 2147483647) + expected.setInt(0, 2147483647) expected.setString(1, "this is a string") - expected.update(2, false) - expected.update(3, null) + expected.setBoolean(2, false) + expected.setNullAt(3) + val actual1 = Row(2147483647, "this is a string", false, null) - assert(expected.size === actual1.size) + assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) - assert(expected(3) === actual1(3)) + assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) - assert(expected.size === actual2.size) + assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) - assert(expected(3) === actual2(3)) + assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ab8dce603c117..535011fe3db5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.scalatest.BeforeAndAfterAll import java.sql.Timestamp @@ -58,6 +59,31 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(queryCoalesce, Row("1") :: Nil) } + test("show functions") { + checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + } + + test("describe functions") { + checkExistence(sql("describe function 
extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql');", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38), @@ -112,6 +138,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } + test("SPARK-8668 expr function") { + checkAnswer(Seq((1, "Bobby G.")) + .toDF("id", "name") + .select(expr("length(name)"), expr("abs(id)")), Row(8, 1)) + + checkAnswer(Seq((1, "building burrito tunnels"), (1, "major projects")) + .toDF("id", "saying") + .groupBy(expr("length(saying)")) + .count(), Row(24, 1) :: Row(14, 1) :: Nil) + } + test("SQL Dialect Switching to a new SQL parser") { val newContext = new SQLContext(sqlContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) @@ -190,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Seq(Row("1"), Row("2"))) } + test("SPARK-8828 sum should return null if all input values are null") { + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -300,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(1)) checkAnswer( sql("SELECT COALESCE(null, 1, 1.5)"), - Row(1.toDouble)) + Row(BigDecimal(1))) checkAnswer( sql("SELECT COALESCE(null, null, null)"), Row(null)) @@ -426,12 +494,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("literal in agg grouping expressions") { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + def literalInAggTest(): Unit = { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP 
BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) + } + + literalInAggTest() + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + literalInAggTest() + } } test("aggregates with nulls") { @@ -1149,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(0.3) + sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - sql("SELECT -0.8"), Row(-0.8) + sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - sql("SELECT .5"), Row(0.5) + sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - sql("SELECT -.18"), Row(-0.18) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } @@ -1194,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) checkAnswer( - sql("SELECT -5.2"), Row(-5.2) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(6.8) + sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( @@ -1492,10 +1577,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-8753: add interval type") { - import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.CalendarInterval val df = sql("select interval 3 years -3 month 7 week 123 microseconds") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) withTempPath(f => { // Currently we don't yet support saving out values of interval data type. val e = intercept[AnalysisException] { @@ -1517,19 +1602,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-8945: add and subtract expressions for interval type") { - import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.CalendarInterval + import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123))) + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) - checkAnswer(df.select(df("i") + new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123))) + checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) - checkAnswer(df.select(df("i") - new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + checkAnswer(df.select(df("i") - new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) // unary minus checkAnswer(df.select(-df("i")), - Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))) + Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 5ea9f97609e36..b7f073cccb6ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -57,19 
+57,27 @@ class StringFunctionsSuite extends QueryTest { } test("string regex_replace / regex_extract") { - val df = Seq(("100-200", "")).toDF("a", "b") + val df = Seq( + ("100-200", "(\\d+)-(\\d+)", "300"), + ("100-200", "(\\d+)-(\\d+)", "400"), + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") checkAnswer( df.select( regexp_replace($"a", "(\\d+)", "num"), regexp_extract($"a", "(\\d+)-(\\d+)", 1)), - Row("num-num", "100")) - - checkAnswer( - df.selectExpr( - "regexp_replace(a, '(\\d+)', 'num')", - "regexp_extract(a, '(\\d+)-(\\d+)', 2)"), - Row("num-num", "200")) + Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil) + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection followed by a LocalRelation, + // hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + checkAnswer( + df.filter("isnotnull(a)").selectExpr( + "regexp_replace(a, b, c)", + "regexp_extract(a, b, 1)"), + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } test("string ascii function") { @@ -137,7 +145,7 @@ class StringFunctionsSuite extends QueryTest { test("soundex function") { val df = Seq(("MARY", "SU")).toDF("l", "r") checkAnswer( - df.select(soundex("l"), soundex($"r")), Row("M600", "S000")) + df.select(soundex($"l"), soundex($"r")), Row("M600", "S000")) checkAnswer( df.selectExpr("SoundEx(l)", "SoundEx(r)"), Row("M600", "S000")) @@ -299,5 +307,15 @@ class StringFunctionsSuite extends QueryTest { df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable Row("5.0000")) } + + // for testing the mutable state of the expression in code gen. + // This is a hack way to enable the codegen, thus the codegen is enable by default, + // it will still use the interpretProjection if projection follows by a LocalRelation, + // hence we add a filter operator. 
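A small Scala sketch of the regexp behaviour covered above: the DataFrame functions take a literal pattern, while the SQL expressions also accept the pattern and replacement from other columns, which is what the per-row pattern test exercises. Sample data mirrors the test.

    import org.apache.spark.sql.functions.{regexp_extract, regexp_replace}

    val ctx = org.apache.spark.sql.test.TestSQLContext
    import ctx.implicits._

    val df = Seq(("100-200", "(\\d+)-(\\d+)", "300")).toDF("a", "b", "c")

    // Literal pattern via the DataFrame API: replace every digit run, extract group 1.
    df.select(regexp_replace($"a", "(\\d+)", "num"),
              regexp_extract($"a", "(\\d+)-(\\d+)", 1)).show()   // -> "num-num", "100"

    // Pattern and replacement taken from columns via SQL expressions, one regex per row.
    df.selectExpr("regexp_replace(a, b, c)", "regexp_extract(a, b, 1)").show()  // -> "300", "100"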
+ // See the optimizer rule `ConvertToLocalRelation` + val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b") + checkAnswer( + df2.filter("b>0").selectExpr("format_number(a, b)"), + Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 207d7a352c7b3..e340f54850bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.sql.Timestamp - import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index c1516b450cbd4..183dc3407b3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { +class UDFSuite extends QueryTest with SQLTestUtils { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") @@ -51,6 +54,25 @@ class UDFSuite extends QueryTest { df.selectExpr("count(distinct a)") } + test("SPARK-8003 spark_partition_id") { + val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") + df.registerTempTable("tmp_table") + checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) + ctx.dropTempTable("tmp_table") + } + + test("SPARK-8005 input_file_name") { + withTempPath { dir => + val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + data.write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + assert(answer.contains(dir.getCanonicalPath)) + assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + ctx.dropTempTable("test_table") + } + } + test("error reporting for incorrect number of arguments") { val df = ctx.emptyDataFrame val e = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index d36e2639376e7..e72a1bc6c4e20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -53,8 +53,7 @@ class UnsafeRowSuite extends SparkFunSuite { offheapRowPage.getBaseObject, offheapRowPage.getBaseOffset, 3, // num fields - arrayBackedUnsafeRow.getSizeInBytes, - null // 
object pool + arrayBackedUnsafeRow.getSizeInBytes ) assert(offheapUnsafeRow.getBaseObject === null) val baos = new ByteArrayOutputStream() @@ -68,4 +67,19 @@ class UnsafeRowSuite extends SparkFunSuite { assert(bytesFromArrayBackedRow === bytesFromOffheapRow) } + + test("calling getDouble() and getFloat() on null columns") { + val row = InternalRow.apply(null, null) + val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row) + assert(unsafeRow.getFloat(0) === row.getFloat(0)) + assert(unsafeRow.getDouble(1) === row.getDouble(1)) + } + + test("calling get(ordinal, datatype) on null columns") { + val row = InternalRow.apply(null) + val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row) + for (dataType <- DataTypeTestUtils.atomicTypes) { + assert(unsafeRow.get(0, dataType) === null) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 45c9f06941c10..77ed4a9c0d5ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): ArrayData = { obj match { case features: MyDenseVector => - features.data.toSeq + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) } } override def deserialize(datum: Any): MyDenseVector = { datum match { - case data: Seq[_] => - new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + case data: ArrayData => + new MyDenseVector(data.toArray.map(_.asInstanceOf[Double])) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 3333fee6711c0..66014ddca0596 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[DoubleColumnStats], DOUBLE, InternalRow(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], - FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], @@ -52,21 +51,56 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType]) + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) - assertResult(10, "Wrong 
null count")(stats(2)) - assertResult(20, "Wrong row count")(stats(3)) - assertResult(stats(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) + assertResult(10, "Wrong null count")(stats.genericGet(2)) + assertResult(20, "Wrong row count")(stats.genericGet(3)) + assertResult(stats.genericGet(4), "Wrong size in bytes") { + rows.map { row => + if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) + }.sum + } + } + } + + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + + val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName + val columnType = FIXED_DECIMAL(15, 10) + + test(s"$columnStatsName: empty") { + val columnStats = new FixedDecimalColumnStats(15, 10) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName: non-empty") { + import org.apache.spark.sql.columnar.ColumnarTestUtils._ + + val columnStats = new FixedDecimalColumnStats(15, 10) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] + val stats = columnStats.collectedStatistics + + assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) + assertResult(10, "Wrong null count")(stats.genericGet(2)) + assertResult(20, "Wrong row count")(stats.genericGet(3)) + assertResult(stats.genericGet(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 4d46a657056e0..8f024690efd0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -32,13 +32,15 @@ import org.apache.spark.unsafe.types.UTF8String class ColumnTypeSuite extends SparkFunSuite with Logging { - val DEFAULT_BUFFER_SIZE = 512 + private val DEFAULT_BUFFER_SIZE = 512 + private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType)) test("defaultSize") { val checks = Map( BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, - STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16) + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, + MAP_GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -48,8 +50,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } test("actualSize") { - def checkActualSize[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def checkActualSize[JvmType]( + columnType: ColumnType[JvmType], value: JvmType, expected: Int): Unit = { @@ -74,7 +76,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) + 
checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } testNativeColumnType(BOOLEAN)( @@ -123,7 +125,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { UTF8String.fromBytes(bytes) }) - testColumnType[BinaryType.type, Array[Byte]]( + testColumnType[Array[Byte]]( BINARY, (buffer: ByteBuffer, bytes: Array[Byte]) => { buffer.putInt(bytes.length).put(bytes) @@ -140,7 +142,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val obj = Map(1 -> "spark", 2 -> "sql") val serializedObj = SparkSqlSerializer.serialize(obj) - GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) + MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) buffer.rewind() val length = buffer.getInt() @@ -157,7 +159,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(obj, "Deserialized object didn't equal to the original object") { buffer.rewind() - SparkSqlSerializer.deserialize(GENERIC.extract(buffer)) + SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer)) } } @@ -170,7 +172,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val obj = CustomClass(Int.MaxValue, Long.MaxValue) val serializedObj = serializer.serialize(obj).array() - GENERIC.append(serializer.serialize(obj).array(), buffer) + MAP_GENERIC.append(serializer.serialize(obj).array(), buffer) buffer.rewind() val length = buffer.getInt @@ -192,7 +194,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(obj, "Custom deserialized object didn't equal the original object") { buffer.rewind() - serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer))) } } @@ -201,11 +203,11 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { - testColumnType[T, T#InternalType](columnType, putter, getter) + testColumnType[T#InternalType](columnType, putter, getter) } - def testColumnType[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def testColumnType[JvmType]( + columnType: ColumnType[JvmType], putter: (ByteBuffer, JvmType) => Unit, getter: (ByteBuffer) => JvmType): Unit = { @@ -262,7 +264,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - assertResult(GENERIC) { + assertResult(GENERIC(DecimalType(19, 0))) { ColumnType(DecimalType(19, 0)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index d9861339739c9..79bb7d072feb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -31,7 +31,7 @@ object ColumnarTestUtils { row } - def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = { + def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) @@ -58,15 +58,15 @@ object ColumnarTestUtils { } def makeRandomValues( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) - def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = { + def makeRandomValues(columnTypes: 
Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } - def makeUniqueRandomValues[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def makeUniqueRandomValues[JvmType]( + columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => @@ -75,10 +75,10 @@ object ColumnarTestUtils { } def makeRandomRow( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) - def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = { + def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericMutableRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 01bc23277fa88..037e2048a8631 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -148,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 9eaa769846088..f4f6c7649bfa8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -21,17 +21,17 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} -class TestNullableColumnAccessor[T <: DataType, JvmType]( +class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, - columnType: ColumnType[T, JvmType]) + columnType: ColumnType[JvmType]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) - : TestNullableColumnAccessor[T, JvmType] = { + def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) + : TestNullableColumnAccessor[JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) @@ -43,13 +43,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite { Seq( BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) .foreach { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType]( - columnType: ColumnType[T, 
JvmType]): Unit = { + def testNullableColumnAccessor[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) @@ -75,7 +75,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite { (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row(0) === randomRow(0)) + assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 17e9ae464bcc0..241d09ea205e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) +class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) + extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) - : TestNullableColumnBuilder[T, JvmType] = { + def apply[JvmType](columnType: ColumnType[JvmType], initialSize: Int = 0) + : TestNullableColumnBuilder[JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder @@ -39,13 +39,13 @@ class NullableColumnBuilderSuite extends SparkFunSuite { Seq( BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) .foreach { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType]): Unit = { + def testNullableColumnBuilder[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -92,13 +92,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite { // For non-null values (0 until 4).foreach { _ => - val actual = if (columnType == GENERIC) { - SparkSqlSerializer.deserialize[Any](GENERIC.extract(buffer)) + val actual = if (columnType.isInstanceOf[GENERIC]) { + SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) } else { columnType.extract(buffer) } - assert(actual === randomRow(0), "Extracted value didn't equal to the original one") + assert(actual === randomRow.get(0, columnType.dataType), + "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index f606e2133bedc..9a2948c59ba42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -33,7 +33,7 @@ class BooleanBitSetSuite extends SparkFunSuite { val 
builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_(0)) + val values = rows.map(_.getBoolean(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3d71deb13e884..845ce669f0b33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -109,7 +109,7 @@ class PlannerSuite extends SparkFunSuite { FloatType :: DoubleType :: DecimalType(10, 5) :: - DecimalType.Unlimited :: + DecimalType.SYSTEM_DEFAULT :: DateType :: TimestampType :: StringType :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f755918c1..707cd9c6d939b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext class RowFormatConvertersSuite extends SparkPlanTest { @@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { @@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest { } test("filter can process unsafe rows") { - val plan = Filter(IsNull(null), outputsUnsafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) + assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { - val plan = Filter(IsNull(null), outputsSafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 6a8f394545816..f46855edfe0de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row} +import 
org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -33,11 +33,13 @@ import scala.util.control.NonFatal */ class SparkPlanTest extends SparkFunSuite { + protected def sqlContext: SQLContext = TestSQLContext + /** * Creates a DataFrame from a local Seq of Product. */ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - TestSQLContext.implicits.localSeqToDataFrameHolder(data) + sqlContext.implicits.localSeqToDataFrameHolder(data) } /** @@ -98,7 +100,7 @@ class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -121,7 +123,8 @@ class SparkPlanTest extends SparkFunSuite { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { + SparkPlanTest.checkAnswer( + input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -147,13 +150,14 @@ object SparkPlanTest { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean): Option[String] = { + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan) + executePlan(expectedOutputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -168,7 +172,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -207,12 +211,13 @@ object SparkPlanTest { input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -275,10 +280,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = TestSQLContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 4a53fadd7e099..54f82f89ed18a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -54,7 +54,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { checkSupported(StringType, isSupported = true) checkSupported(BinaryType, isSupported = true) checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.Unlimited, isSupported = true) + checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true) // If NullType is the only data type in the schema, we do not support it. checkSupported(NullType, isSupported = false) @@ -86,7 +86,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), DateType, TimestampType) val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala similarity index 70% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 5fe73f7e0b072..450963547c798 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -36,39 +36,21 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } - ignore("sort followed by limit should not leak memory") { - // TODO: this test is going to fail until we implement a proper iterator interface - // with a close() method. 
- TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) } - test("sort followed by limit") { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") - try { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } finally { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") - - } - } - test("sorting does not crash for large inputs") { val sortOrder = 'a.asc :: Nil val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -88,10 +70,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23), + plan => ConvertToSafe( + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index bd788ec8c14b1..40b47ae18d648 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -22,33 +22,22 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter} -import org.apache.spark.sql.catalyst.util.ObjectPool +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent class UnsafeRowSerializerSuite extends SparkFunSuite { - private def toUnsafeRow( - row: Row, - schema: Array[DataType], - objPool: ObjectPool = null): UnsafeRow = { + private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] - val rowConverter = new UnsafeRowConverter(schema) - val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow) - val byteArray = new Array[Byte](rowSizeInBytes) - rowConverter.writeRow( - internalRow, byteArray, 
PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes, objPool) - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes, objPool) - unsafeRow + val converter = UnsafeProjection.create(schema) + converter.apply(internalRow) } test("toUnsafeRow() test helper method") { + // This currently doesn't work because the generic getter throws an exception. val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) - assert(row.getString(0) === unsafeRow.get(0).toString) + assert(row.getString(0) === unsafeRow.getUTF8String(0).toString) assert(row.getInt(1) === unsafeRow.getInt(1)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9dd2220f0967e..8b1a9b21a96b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.joins +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.types.{StructField, StructType, IntegerType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite { } test("UnsafeHashedRelation") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val toUnsafe = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafe(_).copy()).toArray + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val schema = StructType(StructField("a", IntegerType, true) :: Nil) - val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - val toUnsafeKey = UnsafeProjection.create(schema) - val unsafeData = data.map(toUnsafeKey(_).copy()).toArray assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) data2 += unsafeData(2).copy() assert(hashed.get(unsafeData(2)) === data2) - val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) .asInstanceOf[UnsafeHashedRelation] + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - 
assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed2.get(toUnsafe(InternalRow(10))) === null) assert(hashed2.get(unsafeData(2)) === data2) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0f82f13088d39..42f2449afb0f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -134,7 +134,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))" + conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() conn.prepareStatement("insert into test.flttypes values (" + "1.0000000000000002220446049250313080847263336181640625, " @@ -152,7 +152,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), |f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP, - |m DOUBLE, n REAL, o DECIMAL(40, 20)) + |m DOUBLE, n REAL, o DECIMAL(38, 18)) """.stripMargin.replaceAll("\n", " ")).executeUpdate() conn.prepareStatement("insert into test.nulltypes values (" + "null, null, null, null, null, null, null, null, null, " @@ -357,14 +357,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. - assert(rows(0).getAs[BigDecimal](2) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) - assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20)) - val compareDecimal = sql("SELECT C FROM flttypes where C > C - 1").collect() - assert(compareDecimal(0).getAs[BigDecimal](0) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) + assert(rows(0).getDouble(0) === 1.00000000000000022) + assert(rows(0).getDouble(1) === 1.00000011920928955) + assert(rows(0).getAs[BigDecimal](2) === + new BigDecimal("123456789012345.543215432154321000")) + assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18)) + val result = sql("SELECT C FROM flttypes where C > C - 1").collect() + assert(result(0).getAs[BigDecimal](0) === + new BigDecimal("123456789012345.543215432154321000")) } test("SQL query as table name") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 1d04513a44672..f19f22fca7d54 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -63,18 +63,18 @@ class JsonSuite extends QueryTest with TestJsonData { checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) checkTypePromotion( - Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT)) val longNumber: Long = 9223372036854775807L checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, 
DoubleType)) checkTypePromotion( - Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT)) val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) checkTypePromotion( - Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) + Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT)) checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), enforceCorrectType(intNumber, TimestampType)) @@ -115,7 +115,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(NullType, IntegerType, IntegerType) checkDataType(NullType, LongType, LongType) checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(NullType, StringType, StringType) checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) checkDataType(NullType, StructType(Nil), StructType(Nil)) @@ -126,7 +126,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(BooleanType, IntegerType, StringType) checkDataType(BooleanType, LongType, StringType) checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType.Unlimited, StringType) + checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType) checkDataType(BooleanType, StringType, StringType) checkDataType(BooleanType, ArrayType(IntegerType), StringType) checkDataType(BooleanType, StructType(Nil), StringType) @@ -135,7 +135,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(IntegerType, IntegerType, IntegerType) checkDataType(IntegerType, LongType, LongType) checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(IntegerType, StringType, StringType) checkDataType(IntegerType, ArrayType(IntegerType), StringType) checkDataType(IntegerType, StructType(Nil), StringType) @@ -143,23 +143,24 @@ class JsonSuite extends QueryTest with TestJsonData { // LongType checkDataType(LongType, LongType, LongType) checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(LongType, StringType, StringType) checkDataType(LongType, ArrayType(IntegerType), StringType) checkDataType(LongType, StructType(Nil), StringType) // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) - // DoubleType - checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(DecimalType.Unlimited, StringType, StringType) - checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) - checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) + // DecimalType + checkDataType(DecimalType.SYSTEM_DEFAULT, 
DecimalType.SYSTEM_DEFAULT, + DecimalType.SYSTEM_DEFAULT) + checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType) // StringType checkDataType(StringType, StringType, StringType) @@ -213,7 +214,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType( StructType( StructField("f1", IntegerType, true) :: Nil), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, StringType) } @@ -240,7 +241,7 @@ class JsonSuite extends QueryTest with TestJsonData { val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -270,7 +271,7 @@ class JsonSuite extends QueryTest with TestJsonData { val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: StructField("arrayOfInteger", ArrayType(LongType, true), true) :: @@ -284,7 +285,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("field3", StringType, true) :: Nil), true), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: + StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(LongType, true), true) :: StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) @@ -385,7 +386,7 @@ class JsonSuite extends QueryTest with TestJsonData { val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType.Unlimited, true) :: + StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -423,12 +424,12 @@ class JsonSuite extends QueryTest with TestJsonData { // Widening to DecimalType checkAnswer( - sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: - Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), + Row(BigDecimal("21474836472.2")) :: + Row(BigDecimal("92233720368547758071.3")) :: Nil ) - // Widening to DoubleType + // Widening to Double checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), Row(101.2) :: Row(21474836471.2) :: Nil @@ -437,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. 
checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(92233720368547758071.2) + Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + Row(new java.math.BigDecimal("92233720368547758071.2")) ) // String and Boolean conflict: resolve the type as string. @@ -489,9 +490,9 @@ class JsonSuite extends QueryTest with TestJsonData { // in the Project. checkAnswer( jsonDF. - where('num_str > BigDecimal("92233720368547758060")). + where('num_str >= BigDecimal("92233720368547758060")). select(('num_str + 1.2).as("num")), - Row(new java.math.BigDecimal("92233720368547758061.2")) + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue()) ) // The following test will fail. The type of num_str is StringType. @@ -502,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Row(14.3) :: Row(92233720368547758071.2) :: Nil + Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil ) } @@ -610,7 +611,7 @@ class JsonSuite extends QueryTest with TestJsonData { val jsonDF = ctx.read.json(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -668,7 +669,7 @@ class JsonSuite extends QueryTest with TestJsonData { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 23df102cd951d..b6a7c4fbddbdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.parquet -import org.scalatest.BeforeAndAfterAll import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} @@ -40,7 +39,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. 
*/ -class ParquetFilterSuiteBase extends QueryTest with ParquetTest { +class ParquetFilterSuite extends QueryTest with ParquetTest { lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext private def checkFilterPredicate( @@ -56,17 +55,9 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = { - val forParquetTableScan = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) - - val forParquetDataSource = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters - }.flatten.reduceOption(_ && _) - - forParquetTableScan.orElse(forParquetDataSource) - } + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters + }.flatten.reduceOption(_ && _) assert(maybeAnalyzedPredicate.isDefined) maybeAnalyzedPredicate.foreach { pred => @@ -98,7 +89,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) (implicit df: DataFrame): Unit = { def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } @@ -308,18 +299,6 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } -} - -class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("SPARK-6554: don't push down predicates which reference partition columns") { import sqlContext.implicits._ @@ -338,37 +317,3 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } } } - -class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } - - test("SPARK-6742: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ - - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - - // If the "part = 1" filter gets pushed down, this query will throw an exception since - // "part" is not a valid column in the actual Parquet file - val df = DataFrame(sqlContext, org.apache.spark.sql.parquet.ParquetRelation( - path, - Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, - Seq(AttributeReference("part", IntegerType, false)()) )) - - checkAnswer( - df.filter("a = 1 or part = 1"), 
- (1 to 3).map(i => Row(1, i, i.toString))) - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7b16eba00d6fb..b415da5b8c136 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -32,7 +32,6 @@ import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, P import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql._ @@ -63,7 +62,7 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuiteBase extends QueryTest with ParquetTest { +class ParquetIOSuite extends QueryTest with ParquetTest { lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ @@ -107,29 +106,13 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { // Parquet doesn't allow column names with spaces, have to add an alias here .select($"_1" cast decimal as "dec") - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } - - // Decimals with precision above 18 are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } - - // Unlimited-length decimals are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } } test("date type") { @@ -365,7 +348,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { """.stripMargin) withTempPath { location => - val extraMetadata = Map(RowReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) @@ -430,26 +413,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } -} - -class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - - override def commitJob(jobContext: JobContext): Unit = { - sys.error("Intentional exception for testing purposes") - } -} - -class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key, originalConf.toString) - } test("SPARK-6330 regression test") { // 
In 1.3.0, save to fs other than file: without configuring core-site.xml would get: @@ -464,14 +427,10 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA } } -class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } +class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 4f98776b91160..2eef10189f11c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -467,7 +467,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: ParquetRelation2) => + case LogicalRelation(relation: ParquetRelation) => assert(relation.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") @@ -509,7 +509,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { FloatType, DoubleType, DecimalType(10, 5), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, DateType, TimestampType, StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 21007d95ed752..a95f70f2bba69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. */ -class ParquetQuerySuiteBase extends QueryTest with ParquetTest { +class ParquetQuerySuite extends QueryTest with ParquetTest { lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.sql @@ -124,6 +126,30 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { } } + test("Enabling/disabling merging partfiles when merging parquet schema") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + // delete summary files, so if we don't merge part-files, one column will not be included. 
+ Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) + Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + testSchemaMerging(2) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + testSchemaMerging(3) + } + } + test("Enabling/disabling schema merging") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => @@ -164,27 +190,3 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { } } } - -class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} - -class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index fa629392674bd..4a0b3b60f419d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -378,7 +378,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("lowerCase", StringType), StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("lowercase", StringType), StructField("uppercase", DoubleType, nullable = false))), @@ -393,7 +393,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructType(Seq( StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -404,7 +404,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Metastore schema contains additional non-nullable fields. 
assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), StructField("lowerCase", BinaryType, nullable = false))), @@ -415,7 +415,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Conflicting non-nullable field names intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } @@ -429,7 +429,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("firstField", StringType, nullable = true), StructField("secondField", StringType, nullable = true), StructField("thirdfield", StringType, nullable = true)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), @@ -442,7 +442,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Merge should fail if the Metastore contains any additional fields that are not // nullable. assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index eb15a1609f1d0..64e94056f209a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -22,6 +22,7 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{DataFrame, SaveMode} @@ -32,8 +33,7 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { - +private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. 
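For reference, a minimal sketch of a suite that mixes in ParquetTest after the self-type change above: the suite must ultimately be a SparkFunSuite, which the QueryTest-based suites touched by this patch already are. This is not part of the patch; the withParquetDataFrame helper name and behavior are assumed from ParquetTest's elided body (write `data` to a temporary Parquet file, pass the resulting DataFrame to the closure, delete the file afterwards).

package org.apache.spark.sql.parquet

import org.apache.spark.sql.{QueryTest, Row}

// Example only: mirrors how ParquetFilterSuite/ParquetIOSuite consume ParquetTest above.
class ParquetRoundTripExampleSuite extends QueryTest with ParquetTest {
  // ParquetTest still requires an SQLContext, as in the suites modified by this patch.
  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext

  test("round-trips a single-column tuple dataset") {
    // Tuple1.apply wraps a single value, per the ParquetTest doc comment.
    val data = (1 to 4).map(i => Tuple1(i.toString))
    // Assumed helper from ParquetTest (not shown in the hunk above).
    withParquetDataFrame(data) { df =>
      checkAnswer(df, data.map(t => Row(t._1)))
    }
  }
}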
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 54e1efb6e36e7..84855ce45e918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -44,7 +44,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("doubleType", DoubleType, nullable = false), StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), - StructField("decimalType", DecimalType.Unlimited, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), @@ -61,9 +61,10 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sqlContext.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row - } + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 257526feab945..0d5183444af78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -131,7 +131,7 @@ class PrunedScanSuite extends DataSourceTest { queryExecution) } - if (rawOutput.size != expectedColumns.size) { + if (rawOutput.numFields != expectedColumns.size) { fail(s"Wrong output row. 
Got $rawOutput\n$queryExecution") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 2c916f3322b6d..cfb03ff485b7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -67,12 +67,12 @@ case class AllDataTypesScan( override def schema: StructType = userSpecifiedSchema - override def needConversion: Boolean = false + override def needConversion: Boolean = true override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - InternalRow( - UTF8String.fromString(s"str_$i"), + Row( + s"str_$i", s"str_$i".getBytes(), i % 2 == 0, i.toByte, @@ -81,18 +81,18 @@ case class AllDataTypesScan( i.toLong, i.toFloat, i.toDouble, - Decimal(new java.math.BigDecimal(i)), - Decimal(new java.math.BigDecimal(i)), - DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)), - DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)), - UTF8String.fromString(s"varchar_$i"), + new java.math.BigDecimal(i), + new java.math.BigDecimal(i), + new Date(1970, 1, 1), + new Timestamp(20000 + i), + s"varchar_$i", Seq(i, i + 1), - Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), - Map(i -> UTF8String.fromString(i.toString)), - Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - InternalRow(i, UTF8String.fromString(i.toString)), - InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), - InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) + Seq(Map(s"str_$i" -> Row(i.toLong))), + Map(i -> i.toString), + Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), + Row(i, i.toString), + Row(Seq(s"str_$i", s"str_${i + 1}"), + Row(Seq(new Date(1970, 1, i + 1))))) } } } @@ -202,7 +202,7 @@ class TableScanSuite extends DataSourceTest { StructField("longField_:,<>=+/~^", LongType, true) :: StructField("floatField", FloatType, true) :: StructField("doubleField", DoubleType, true) :: - StructField("decimalField1", DecimalType.Unlimited, true) :: + StructField("decimalField1", DecimalType.USER_DEFAULT, true) :: StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index fa01823e9417c..4c11acdab9ec0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.test import java.io.File +import java.util.UUID import scala.util.Try +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils -trait SQLTestUtils { +trait SQLTestUtils { this: SparkFunSuite => def sqlContext: SQLContext protected def configuration = sqlContext.sparkContext.hadoopConfiguration @@ -87,4 +89,29 @@ trait SQLTestUtils { } } } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. 
+ */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + sqlContext.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. + */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + sqlContext.sql(s"USE $db") + try f finally sqlContext.sql(s"USE default") + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala similarity index 97% rename from sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala index 1fe4fe9629c02..1a5ba20404c4e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala @@ -23,16 +23,16 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive /** - * Runs the test cases that are included in the hive distribution with sort merge join is true. + * Runs the test cases that are included in the hive distribution with hash joins. */ -class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { +class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) } override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) super.afterAll() } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b12b3838e615c..ec959cb2194b0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", - "udaf_number_format", "udf2", "udf5", "udf6", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4cdb83c5116f9..110f51a305861 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -40,7 +40,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} @@ -267,7 +267,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - catalog.refreshTable(catalog.client.currentDatabase, tableName) + val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase) + catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { @@ -444,9 +445,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HiveDDLStrategy, DDLStrategy, TakeOrderedAndProject, - ParquetOperations, InMemoryScans, - ParquetConversion, // Must be before HiveTableScans HiveTableScans, DataSinks, Scripts, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index a8f2ee37cb8ed..5926ef9aa388b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -46,15 +46,14 @@ import scala.collection.JavaConversions._ * long / scala.Long * short / scala.Short * byte / scala.Byte - * org.apache.spark.sql.types.Decimal + * [[org.apache.spark.sql.types.Decimal]] * Array[Byte] * java.sql.Date * java.sql.Timestamp * Complex Types => * Map: scala.collection.immutable.Map - * List: scala.collection.immutable.Seq - * Struct: - * org.apache.spark.sql.catalyst.expression.Row + * List: [[org.apache.spark.sql.types.ArrayData]] + * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. 
* @@ -179,7 +178,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -195,8 +194,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -297,7 +296,10 @@ private[hive] trait HiveInspectors { }.toMap case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq + val values = li.getWritableConstantValue + .map(unwrap(_, li.getListElementObjectInspector)) + .toArray + new GenericArrayData(values) // if the value is null, we don't care about the object inspector type case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector @@ -339,7 +341,10 @@ private[hive] trait HiveInspectors { } case li: ListObjectInspector => Option(li.getList(data)) - .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) + .map { l => + val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray + new GenericArrayData(values) + } .orNull case mi: MapObjectInspector => Option(mi.getMap(data)).map( @@ -391,7 +396,13 @@ private[hive] trait HiveInspectors { case loi: ListObjectInspector => val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null + (o: Any) => { + if (o != null) { + seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper)) + } else { + null + } + } case moi: MapObjectInspector => // The Predef.Map is scala.collection.immutable.Map. @@ -454,7 +465,7 @@ private[hive] trait HiveInspectors { * * NOTICE: the complex data type requires recursive wrapping. */ - def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match { + def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match { case x: ConstantObjectInspector => x.getWritableConstantValue case _ if a == null => null case x: PrimitiveObjectInspector => x match { @@ -488,43 +499,50 @@ private[hive] trait HiveInspectors { } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] // 1. create the pojo (most likely) object val result = x.create() var i = 0 while (i < fieldRefs.length) { // 2. 
set the property for the pojo + val tpe = structType(i).dataType x.setStructFieldData( result, fieldRefs.get(i), - wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { - result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + val tpe = structType(i).dataType + result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: ListObjectInspector => val list = new java.util.ArrayList[Object] - a.asInstanceOf[Seq[_]].foreach { - v => list.add(wrap(v, x.getListElementObjectInspector)) + val tpe = dataType.asInstanceOf[ArrayType].elementType + a.asInstanceOf[ArrayData].toArray().foreach { + v => list.add(wrap(v, x.getListElementObjectInspector, tpe)) } list case x: MapObjectInspector => + val keyType = dataType.asInstanceOf[MapType].keyType + val valueType = dataType.asInstanceOf[MapType].valueType // Some UDFs seem to assume we pass in a HashMap. val hashMap = new java.util.HashMap[AnyRef, AnyRef]() - hashMap.putAll(a.asInstanceOf[Map[_, _]].map { - case (k, v) => - wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) + hashMap.putAll(a.asInstanceOf[Map[_, _]].map { case (k, v) => + wrap(k, x.getMapKeyObjectInspector, keyType) -> + wrap(v, x.getMapValueObjectInspector, valueType) }) hashMap @@ -533,22 +551,24 @@ private[hive] trait HiveInspectors { def wrap( row: InternalRow, inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) i += 1 } cache } def wrap( - row: Seq[Any], - inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) i += 1 } cache @@ -625,7 +645,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + value.asInstanceOf[ArrayData].toArray() + .foreach(v => list.add(wrap(v, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -636,7 +657,7 @@ private[hive] trait HiveInspectors { } else { val map = new java.util.HashMap[Object, Object]() value.asInstanceOf[Map[_, _]].foreach (entry => { - map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + map.put(wrap(entry._1, keyOI, keyType), wrap(entry._2, valueOI, valueType)) }) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) } @@ -813,9 +834,6 @@ private[hive] trait HiveInspectors { private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { case DecimalType.Fixed(precision, scale) => new 
DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo( - HiveShim.UNLIMITED_DECIMAL_PRECISION, - HiveShim.UNLIMITED_DECIMAL_SCALE) } def toTypeInfo: TypeInfo = dt match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0a2121c955871..a8c9b4fa71b99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConversions._ import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.metastore.Warehouse @@ -30,20 +29,19 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} import org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, PartitionSpec, CreateTableUsingAsSelect, ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ - +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -116,7 +114,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. // Since we also cache ParquetRelations converted from Hive Parquet tables and @@ -125,7 +123,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). 
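(Editor's note: the decimal hunks above drop `DecimalType.Unlimited` in favour of the new bounded defaults. A minimal, hedged sketch of how the two constants are meant to be used — the exact precision/scale values are an assumption about Spark 1.5, not something stated in this diff:)

    import org.apache.spark.sql.types.DecimalType

    // SYSTEM_DEFAULT: bounded default for decimals Spark derives internally (assumed (38, 18)).
    // USER_DEFAULT: default when SQL says DECIMAL with no precision/scale (assumed (10, 0)).
    val systemDefault = DecimalType.SYSTEM_DEFAULT
    val userDefault = DecimalType.USER_DEFAULT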
- invalidateTable(databaseName, tableName) + invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -145,7 +143,27 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = processDatabaseAndTableName(client.currentDatabase, tableName) + createDataSourceTable( + new SqlParser().parseTableIdentifier(tableName), + userSpecifiedSchema, + partitionColumns, + provider, + options, + isExternal) + } + + private def createDataSourceTable( + tableIdent: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { + val (dbName, tblName) = { + val database = tableIdent.database.getOrElse(client.currentDatabase) + processDatabaseAndTableName(database, tableIdent.table) + } + val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) @@ -178,7 +196,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // partitions when we load the table. However, if there are specified partition columns, // we simplily ignore them and provide a warning message.. logWarning( - s"The schema and partitions of table $tableName will be inferred when it is loaded. " + + s"The schema and partitions of table $tableIdent will be inferred when it is loaded. " + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } Seq.empty[HiveColumn] @@ -256,12 +274,12 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the - // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + // evil case insensitivity issue, which is reconciled within `ParquetRelation`. val parquetOptions = Map( - ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) @@ -272,7 +290,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. 
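(Editor's note: several catalog methods above now accept a `TableIdentifier` and fall back to the client's current database when none is given. A self-contained sketch of that resolution step, using a hypothetical stand-in rather than Spark's real `TableIdentifier`:)

    // Hypothetical stand-in for org.apache.spark.sql.catalyst.TableIdentifier.
    case class TableIdent(table: String, database: Option[String] = None)

    def resolve(ident: TableIdent, currentDatabase: String): (String, String) =
      (ident.database.getOrElse(currentDatabase), ident.table)

    // resolve(TableIdent("t"), "default")              == ("default", "t")
    // resolve(TableIdent("t", Some("db1")), "default") == ("db1", "t")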
val useCached = @@ -317,7 +335,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new ParquetRelation2( + new ParquetRelation( paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created @@ -330,7 +348,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, None) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new ParquetRelation2(paths.toArray, None, None, parquetOptions)(hive)) + new ParquetRelation(paths.toArray, None, None, parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } @@ -370,8 +388,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. - * - * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right. */ object ParquetConversions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { @@ -386,7 +402,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Inserting into partitioned table is not supported in Parquet data source (yet). if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) @@ -397,7 +412,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Inserting into partitioned table is not supported in Parquet data source (yet). 
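(Editor's note: with the old strategy removed further down, the `ParquetConversions` rule alone decides when a Metastore Parquet table is converted, and the `parquetUseDataSourceApi` flag is no longer part of the guard. A simplified sketch of the write-path condition shown above, with names assumed:)

    // Convert on the write path only when the table is unpartitioned, conversion is enabled,
    // and the table's SerDe is a Parquet SerDe.
    def convertOnWrite(isPartitioned: Boolean,
                       convertMetastoreParquet: Boolean,
                       serdeClassName: String): Boolean =
      !isPartitioned && convertMetastoreParquet &&
        serdeClassName.toLowerCase.contains("parquet")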
if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) @@ -406,7 +420,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) if hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) @@ -665,8 +678,18 @@ private[hive] case class MetastoreRelation } ) + // When metastore partition pruning is turned off, we cache the list of all partitions to + // mimic the behavior of Spark < 1.5 + lazy val allPartitions = table.getAllPartitions + def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - table.getPartitions(predicates).map { p => + val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) { + table.getPartitions(predicates) + } else { + allPartitions + } + + rawPartitions.map { p => val tPartition = new org.apache.hadoop.hive.metastore.api.Partition tPartition.setDbName(databaseName) tPartition.setTableName(tableName) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 8518e333e8058..e6df64d2642bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -377,7 +377,7 @@ private[hive] object HiveQl extends Logging { DecimalType(precision.getText.toInt, scale.getText.toInt) case Token("TOK_DECIMAL", precision :: Nil) => DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited + case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -874,15 +874,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { + : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, "", Nil) + (rowFormat, None, Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, serdeClass, Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -891,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => (name, value) } - (Nil, serdeClass, serdeProps) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, "", Nil) + case Nil => (Nil, None, Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) @@ -1321,11 +1321,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* Attribute References */ case 
Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => - UnresolvedAttribute(cleanIdentifier(name)) + UnresolvedAttribute.quoted(cleanIdentifier(name)) case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { - case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) case other => UnresolvedExtractValue(other, Literal(attr)) } @@ -1369,7 +1369,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.Unlimited) + Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a22c3292eff94..cd6cd322c94ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,23 +17,14 @@ package org.apache.spark.sql.hive -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.types.StringType private[hive] trait HiveStrategies { @@ -42,136 +33,6 @@ private[hive] trait HiveStrategies { val hiveContext: HiveContext - /** - * :: Experimental :: - * Finds table scans that would use the Hive SerDe and replaces them with our own native parquet - * table scan operator. - * - * TODO: Much of this logic is duplicated in HiveTableScan. Ideally we would do some refactoring - * but since this is after the code freeze for 1.1 all logic is here to minimize disruption. - * - * Other issues: - * - Much of this logic assumes case insensitive resolution. - */ - @Experimental - object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: DataFrame) { - def lowerCase: DataFrame = DataFrame(s.sqlContext, s.logicalPlan) - - def addPartitioningAttributes(attrs: Seq[Attribute]): DataFrame = { - // Don't add the partitioning key if its already present in the data. 
- if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { - s - } else { - DataFrame( - s.sqlContext, - s.logicalPlan transform { - case p: ParquetRelation => p.copy(partitioningAttributes = attrs) - }) - } - } - } - - implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { - def fakeOutput(newOutput: Seq[Attribute]): OutputFaker = - OutputFaker( - originalPlan.output.map(a => - newOutput.find(a.name.toLowerCase == _.name.toLowerCase) - .getOrElse( - sys.error(s"Can't find attribute $a to fake in set ${newOutput.mkString(",")}"))), - originalPlan) - } - - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) - if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet && - !hiveContext.conf.parquetUseDataSourceApi => - - // Filter out all predicates that only deal with partition keys - val partitionsKeys = AttributeSet(relation.partitionKeys) - val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.subsetOf(partitionsKeys) - } - - // We are going to throw the predicates and projection back at the whole optimization - // sequence so lets unresolve all the attributes, allowing them to be rebound to the - // matching parquet attributes. - val unresolvedOtherPredicates = Column(otherPredicates.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true))) - - val unresolvedProjection: Seq[Column] = projectList.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).map(Column(_)) - - try { - if (relation.hiveQlTable.isPartitioned) { - val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true)) - // Translate the predicate so that it automatically casts the input values to the - // correct data types during evaluation. - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId) - val key = relation.partitionKeys(idx) - Cast(BoundReference(idx, StringType, nullable = true), key.dataType) - } - - val inputData = new GenericMutableRow(relation.partitionKeys.size) - val pruningCondition = - if (codegenEnabled) { - GeneratePredicate.generate(castedPredicate) - } else { - InterpretedPredicate.create(castedPredicate) - } - - val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part => - val partitionValues = part.getValues - var i = 0 - while (i < partitionValues.size()) { - inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) - i += 1 - } - pruningCondition(inputData) - } - - val partitionLocations = partitions.map(_.getLocation) - - if (partitionLocations.isEmpty) { - PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil - } else { - hiveContext - .read.parquet(partitionLocations: _*) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } - - } else { - hiveContext - .read.parquet(relation.hiveQlTable.getDataLocation.toString) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } - } catch { - // parquetFile will throw an exception when there is no data. 
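(Editor's note: the deleted code above split the query's predicates into partition-pruning predicates — those referencing only partition keys — and residual filters. A deliberately simplified illustration of that split, with each predicate reduced to the set of columns it references rather than Catalyst's `Expression` API:)

    def splitPredicates(predicates: Seq[Set[String]],
                        partitionKeys: Set[String]): (Seq[Set[String]], Seq[Set[String]]) =
      predicates.partition(_.subsetOf(partitionKeys))

    // val (pruning, residual) = splitPredicates(Seq(Set("ds"), Set("a", "ds")), Set("ds"))
    // pruning  == Seq(Set("ds"))
    // residual == Seq(Set("a", "ds"))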
- // TODO: Remove this hack for Spark 1.3. - case iae: java.lang.IllegalArgumentException - if iae.getMessage.contains("Can not create a Path from an empty string") => - PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil - } - case _ => Nil - } - } - object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 1656587d14835..d834b4e83e043 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -72,12 +72,10 @@ private[hive] case class HiveTable( def isPartitioned: Boolean = partitionColumns.nonEmpty - def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = { - predicates match { - case Nil => client.getAllPartitions(this) - case _ => client.getPartitionsByFilter(this, predicates) - } - } + def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this) + + def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = + client.getPartitionsByFilter(this, predicates) // Hive does not support backticks when passing names to the client. def qualifiedName: String = s"$database.$name" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 8202e553afbfe..e4944caeff924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ +import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkException, TaskContext} import scala.collection.JavaConversions._ @@ -96,13 +97,14 @@ case class InsertIntoHiveTable( val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) + val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) i += 1 } @@ -122,7 +124,7 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. 
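(Editor's note: `InsertIntoHiveTable` above now resolves each column's `DataType` once, outside the row loop, and wraps values null-safely. A hedged sketch of that loop shape using the public `Row` API; the wrapper functions and the example schema are assumptions:)

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    // Resolve the column DataTypes once, not per row.
    val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
    val dataTypes: Array[DataType] = schema.map(_.dataType).toArray

    // wrappers(i) converts column i into the object Hive's SerDe expects (assumed to exist).
    def toHiveRow(row: Row, wrappers: Array[Any => AnyRef]): Array[AnyRef] = {
      val out = new Array[AnyRef](dataTypes.length)
      var i = 0
      while (i < dataTypes.length) {
        out(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i))
        i += 1
      }
      out
    }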
val tableDesc = table.tableDesc @@ -252,13 +254,12 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - Seq.empty[InternalRow] + Seq.empty[Row] } - override def executeCollect(): Array[Row] = - sideEffectResult.toArray + override def executeCollect(): Array[Row] = sideEffectResult.toArray protected override def doExecute(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(sideEffectResult, 1) + sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 205e622195f09..7e3342cc84c0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,15 +17,18 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} +import java.io._ import java.util.Properties +import javax.annotation.Nullable import scala.collection.JavaConversions._ +import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.spark.{TaskContext, Logging} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -56,21 +59,53 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil protected override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) - // We need to start threads connected to the process pipeline: - // 1) The error msg generated by the script process would be hidden. - // 2) If the error msg is too big to chock up the buffer, the input logic would be hung + val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream - val reader = new BufferedReader(new InputStreamReader(inputStream)) - val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() + + val outputProjection = new InterpretedProjection(input, child.output) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. 
That process's output will be read by this current thread. + val writerThread = new ScriptTransformationWriterThread( + inputIterator, + outputProjection, + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get() + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + ioschema.initOutputSerDe(output).getOrElse((null, null)) + } - val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + val reader = new BufferedReader(new InputStreamReader(inputStream)) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var cacheRow: InternalRow = null var curLine: String = null var eof: Boolean = false @@ -79,12 +114,26 @@ case class ScriptTransformation( if (outputSerde == null) { if (curLine == null) { curLine = reader.readLine() - curLine != null + if (curLine == null) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } else { true } } else { - !eof + if (eof) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } } @@ -110,11 +159,11 @@ case class ScriptTransformation( } i += 1 }) - return mutableRow + mutableRow } catch { case e: EOFException => eof = true - return null + null } } @@ -127,13 +176,13 @@ case class ScriptTransformation( val prevLine = curLine curLine = reader.readLine() if (!ioschema.schemaLess) { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .map(CatalystTypeConverters.convertToCatalyst)) } else { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) } } else { val ret = deserialize() @@ -146,49 +195,83 @@ case class ScriptTransformation( } } - val (inputSerde, inputSoi) = ioschema.initInputSerDe(input) - val dataOutputStream = new DataOutputStream(outputStream) - val outputProjection = new InterpretedProjection(input, child.output) + writerThread.start() - // TODO make the 2048 configurable? - val stderrBuffer = new CircularBuffer(2048) - // Consume the error stream from the pipeline, otherwise it will be blocked if - // the pipeline is full. - new RedirectThread(errorStream, // input stream from the pipeline - stderrBuffer, // output to a circular buffer - "Thread-ScriptTransformation-STDERR-Consumer").start() + outputIterator + } - // Put the write(output to the pipeline) into a single thread - // and keep the collector as remain in the main thread. - // otherwise it will causes deadlock if the data size greater than - // the pipeline / buffer capacity. 
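(Editor's note: the rewrite above moves feeding of the child process onto a dedicated writer thread and drains its stderr into a bounded `CircularBuffer`, so neither pipe can fill up and hang the task (SPARK-7862). A minimal, self-contained sketch of the stderr-draining idea — this is not Spark's `RedirectThread`:)

    import java.io.InputStream

    // Continuously consume a child process's stderr on a daemon thread so the process never
    // blocks on a full pipe; a real implementation keeps the last N bytes for error reporting.
    def drainStderr(stderr: InputStream): Thread = {
      val t = new Thread("stderr-consumer") {
        override def run(): Unit = {
          val buf = new Array[Byte](1024)
          while (stderr.read(buf) != -1) { /* discard (or copy into a bounded buffer) */ }
        }
      }
      t.setDaemon(true)
      t.start()
      t
    }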
- new Thread(new Runnable() { - override def run(): Unit = { - Utils.tryWithSafeFinally { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) - } - } - outputStream.close() - } { - if (proc.waitFor() != 0) { - logError(stderrBuffer.toString) // log the stderr circular buffer - } - } - } - }, "Thread-ScriptTransformation-Feed").start() + child.execute().mapPartitions { iter => + if (iter.hasNext) { + processIterator(iter) + } else { + // If the input iterator has no rows then do not launch the external script. + Iterator.empty + } + } + } +} - iterator +private class ScriptTransformationWriterThread( + iter: Iterator[InternalRow], + outputProjection: Projection, + @Nullable inputSerde: AbstractSerDe, + @Nullable inputSoi: ObjectInspector, + ioschema: HiveScriptIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext + ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { + + setDaemon(true) + + @volatile private var _exception: Throwable = null + + /** Contains the exception thrown while writing the parent iterator to the external process. */ + def exception: Option[Throwable] = Option(_exception) + + override def run(): Unit = Utils.logUncaughtExceptions { + TaskContext.setTaskContext(taskContext) + + val dataOutputStream = new DataOutputStream(outputStream) + + // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so + // let's use a variable to record whether the `finally` block was hit due to an exception + var threwException: Boolean = true + try { + iter.map(outputProjection).foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + threwException = false + } catch { + case NonFatal(e) => + // An error occurred while writing input, so kill the child process. 
According to the + // Javadoc this call will not throw an exception: + _exception = e + proc.destroy() + throw e + } finally { + try { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + } + } catch { + case NonFatal(exceptionFromFinallyBlock) => + if (!threwException) { + throw exceptionFromFinallyBlock + } else { + log.error("Exception in finally block", exceptionFromFinallyBlock) + } + } } } } @@ -200,33 +283,43 @@ private[hive] case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], outputRowFormat: Seq[(String, String)], - inputSerdeClass: String, - outputSerdeClass: String, + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { - val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n")) + private val defaultFormat = Map( + ("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n") + ) val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { - val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) - (serde, initInputSoi(serde, columns, columnTypes)) + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns, fieldObjectInspectors) + .asInstanceOf[ObjectInspector] + (serde, objectInspector) + } } - def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = { - val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps) - (serde, initOutputputSoi(serde)) + def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) + val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) + } } - def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name @@ -242,52 +335,25 @@ case class HiveScriptIOSchema ( (columns, columnTypes) } - def initSerDe(serdeClassName: String, columns: Seq[String], - columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { + private def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { - val serde: AbstractSerDe = if (serdeClassName != "") { - val trimed_class = serdeClassName.split("'")(1) - Utils.classForName(trimed_class) - .newInstance.asInstanceOf[AbstractSerDe] - } else { - null - } + val serde = 
Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] - if (serde != null) { - val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) - propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + var propsMap = serdeProps.map(kv => { + (kv._1.split("'")(1), kv._2.split("'")(1)) + }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - val properties = new Properties() - properties.putAll(propsMap) - serde.initialize(null, properties) - } + val properties = new Properties() + properties.putAll(propsMap) + serde.initialize(null, properties) serde } - - def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType]) - : ObjectInspector = { - - if (inputSerde != null) { - val fieldObjectInspectors = columnTypes.map(toInspector(_)) - ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) - .asInstanceOf[ObjectInspector] - } else { - null - } - } - - def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { - if (outputSerde != null) { - outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] - } else { - null - } - } } - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 3259b50acc765..4a13022eddf60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -76,31 +76,53 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) } } - override def registerFunction(name: String, builder: FunctionBuilder): Unit = - underlying.registerFunction(name, builder) + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = underlying.registerFunction(name, info, builder) + + /* List all of the registered function names. */ + override def listFunction(): Seq[String] = { + val a = FunctionRegistry.getFunctionNames ++ underlying.listFunction() + a.toList.sorted + } + + /* Get the class of the registered function by specified name. 
*/ + override def lookupFunction(name: String): Option[ExpressionInfo] = { + underlying.lookupFunction(name).orElse( + Try { + val info = FunctionRegistry.getFunctionInfo(name) + val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) + if (annotation != null) { + Some(new ExpressionInfo( + info.getFunctionClass.getCanonicalName, + annotation.name(), + annotation.value(), + annotation.extended())) + } else { + None + } + }.getOrElse(None)) + } } private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - type UDFType = UDF - override def deterministic: Boolean = isUDFDeterministic override def nullable: Boolean = true @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[UDF]() @transient - protected lazy val method = + private lazy val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) @transient - protected lazy val arguments = children.map(toInspector).toArray + private lazy val arguments = children.map(toInspector).toArray @transient - protected lazy val isUDFDeterministic = { + private lazy val isUDFDeterministic = { val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -109,7 +131,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre // Create parameter converters @transient - protected lazy val conversionHelper = new ConversionHelper(method, arguments) + private lazy val conversionHelper = new ConversionHelper(method, arguments) @transient lazy val dataType = javaClassToDataType(method.getReturnType) @@ -119,14 +141,19 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre method.getGenericReturnType(), ObjectInspectorOptions.JAVA) @transient - protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray // TODO: Finish input output types. 
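(Editor's note: `HiveFunctionRegistry.lookupFunction` above consults Spark's native registry first and only then falls back to Hive's `FunctionRegistry`, treating any failure from the fallback as "not found". The underlying pattern, sketched with simplified types:)

    import scala.util.Try

    // Try the primary lookup first; on a miss, attempt the fallback and swallow its exceptions.
    def lookupWithFallback[A](primary: String => Option[A],
                              fallback: String => A)(name: String): Option[A] =
      primary(name).orElse(Try(fallback(name)).toOption)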
override def eval(input: InternalRow): Any = { - unwrap( - FunctionRegistry.invoke(method, function, conversionHelper - .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), - returnInspector) + val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes) + val ret = FunctionRegistry.invoke( + method, + function, + conversionHelper.convertIfNecessary(inputs : _*): _*) + unwrap(ret, returnInspector) } override def toString: String = { @@ -135,47 +162,48 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre } // Adapter from Catalyst ExpressionResult to Hive DeferredObject -private[hive] class DeferredObjectAdapter(oi: ObjectInspector) +private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) extends DeferredObject with HiveInspectors { + private var func: () => Any = _ def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrap(func(), oi) + override def get(): AnyRef = wrap(func(), oi, dataType) } private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - type UDFType = GenericUDF + + override def nullable: Boolean = true override def deterministic: Boolean = isUDFDeterministic - override def nullable: Boolean = true + override def foldable: Boolean = + isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[GenericUDF]() @transient - protected lazy val argumentInspectors = children.map(toInspector) + private lazy val argumentInspectors = children.map(toInspector) @transient - protected lazy val returnInspector = { + private lazy val returnInspector = { function.initializeAndFoldConstants(argumentInspectors.toArray) } @transient - protected lazy val isUDFDeterministic = { + private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } - override def foldable: Boolean = - isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] - @transient - protected lazy val deferedObjects = - argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] + private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + new DeferredObjectAdapter(inspect, child.dataType) + }.toArray[DeferredObject] lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -354,6 +382,9 @@ private[hive] case class HiveWindowFunction( // Output buffer. private var outputBuffer: Any = _ + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def init(): Unit = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -368,8 +399,13 @@ private[hive] case class HiveWindowFunction( } override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length)) + wrap( + inputProjection(input), + inputInspectors, + new Array[AnyRef](children.length), + inputDataTypes) } + // Add input parameters for a single row. 
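(Editor's note: `DeferredObjectAdapter` above now carries the argument's Catalyst `DataType` alongside its inspector so that `wrap` can convert lazily and with type information. A stripped-down sketch of that deferred-argument idea — not Hive's `DeferredObject` API:)

    import org.apache.spark.sql.types.DataType

    // Hold a thunk plus the value's DataType; evaluate only when the UDF actually asks for it.
    class DeferredArg(compute: () => Any, val dataType: DataType) {
      def get(): Any = compute()
    }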
override def update(input: AnyRef): Unit = { evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) @@ -395,7 +431,7 @@ private[hive] case class HiveWindowFunction( // if pivotResult is true, we will get a Seq having the same size with the size // of the window frame. At here, we will return the result at the position of // index in the output buffer. - outputBuffer.asInstanceOf[Seq[Any]].get(index) + outputBuffer.asInstanceOf[ArrayData].get(index) } } @@ -510,12 +546,15 @@ private[hive] case class HiveGenericUDTF( field => (inspectorToDataType(field.getFieldObjectInspector), true) } + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) - function.process(wrap(inputProjection(input), inputInspectors, udtInput)) + function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) collector.collectRows() } @@ -584,9 +623,12 @@ private[hive] case class HiveUDAFFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) + @transient + private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray + def update(input: InternalRow): Unit = { val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached)) + function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ecc78a5f8d321..8850e060d2a73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ @@ -94,7 +95,9 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { + writer + } def close() { // Seems the boolean value passed into close does not matter. 
@@ -197,7 +200,8 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: InternalRow, schema: StructType) + : FileSinkOperator.RecordWriter = { def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index de63ee56dd8e6..924f4d37ce21f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -66,7 +66,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriterInternal with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -119,10 +119,10 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = { + override def writeInternal(row: InternalRow): Unit = { var i = 0 - while (i < row.length) { - reusableOutputBuffer(i) = wrappers(i)(row(i)) + while (i < row.numFields) { + reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) i += 1 } @@ -192,7 +192,7 @@ private[sql] class OrcRelation( filters: Array[Filter], inputPaths: Array[FileStatus]): RDD[Row] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row]) + OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] } override def prepareJobForWrite(job: Job): OutputWriterFactory = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 3662a4352f55d..7bbdef90cd6b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -56,6 +56,7 @@ object TestHive .set("spark.sql.test", "") .set("spark.sql.hive.metastore.barrierPrefixes", "org.apache.spark.sql.hive.execution.PairSerDe") + .set("spark.buffer.pageSize", "4m") // SPARK-8910 .set("spark.ui.enabled", "false"))) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index 5c9d0e97a99c6..a2247e3da1554 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -21,13 +21,18 @@ import java.util.List; import org.apache.spark.sql.Row; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; 
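(Editor's note: the ORC writer above switches from the external `Row` API to `InternalRow`, iterating with `numFields` and the type-aware `get(i, dataType)`. A hedged sketch of that access pattern against Spark's internal API — internal classes, so subject to change:)

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.StructType

    // Read every field of an InternalRow, using the schema to pick the right accessor.
    def fieldValues(row: InternalRow, schema: StructType): Array[Any] = {
      val out = new Array[Any](row.numFields)
      var i = 0
      while (i < row.numFields) {
        out(i) = row.get(i, schema(i).dataType)
        i += 1
      }
      out
    }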
+/** + * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a + * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum + * of the average value of input values and 100.0. + */ public class MyDoubleAvg extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleAvg() { - List inputfields = new ArrayList(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + // The buffer has two values, bufferSum for storing the current sum and + // bufferCount for storing the number of non-null input values that have been contributed + // to the current sum. List bufferFields = new ArrayList(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); @@ -66,16 +74,23 @@ public MyDoubleAvg() { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); + // The initial value of the count is 0. buffer.update(1, 0L); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer and set the bufferCount to 1. if (buffer.isNullAt(0)) { buffer.update(0, input.getDouble(0)); buffer.update(1, 1L); } else { + // Otherwise, update the bufferSum and increment bufferCount. Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); buffer.update(1, buffer.getLong(1) + 1L); @@ -84,11 +99,16 @@ public MyDoubleAvg() { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update buffer1 when the input buffer2's sum value is not null. if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set it to the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); buffer1.update(1, buffer2.getLong(1)); } else { + // Otherwise, we update the bufferSum and bufferCount. Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); @@ -98,10 +118,12 @@ public MyDoubleAvg() { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the bufferSum is still null, we return null because this function has not received + // any input rows. return null; } else { + // Otherwise, we calculate the special average value.
return buffer.getDouble(0) / buffer.getLong(1) + 100.0; } } } - diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index 1d4587a27c787..da29e24d267dd 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -20,14 +20,18 @@ import java.util.ArrayList; import java.util.List; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.Row; +/** + * An example {@link UserDefinedAggregateFunction} to calculate the sum of a + * {@link org.apache.spark.sql.types.DoubleType} column. + */ public class MyDoubleSum extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleSum() { - List inputfields = new ArrayList(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); List bufferFields = new ArrayList(); bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); @@ -65,14 +69,20 @@ public MyDoubleSum() { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { if (buffer.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer. buffer.update(0, input.getDouble(0)); } else { + // Otherwise, we add the input value to the buffer value. Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); } @@ -80,10 +90,16 @@ public MyDoubleSum() { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update buffer1 when the input buffer2's value is not null. if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set it to the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); } else { + // Otherwise, we add the input buffer's value (buffer2) to the mutable + // buffer's value (buffer1). Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); } @@ -92,8 +108,10 @@ public MyDoubleSum() { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the buffer value is still null, we return null.
return null; } else { + // Otherwise, the intermediate sum is the final result. return buffer.getDouble(0); } } diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 deleted file mode 100644 index c6f275a0db131..0000000000000 --- a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 +++ /dev/null @@ -1 +0,0 @@ -0.0 NULL NULL NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 8bb498a06fc9e..f719f2e06ab63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -48,7 +48,11 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] val a = unwrap(state, soi).asInstanceOf[InternalRow] - val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] + + val dt = new StructType() + .add("counts", MapType(LongType, LongType)) + .add("percentiles", ArrayType(DoubleType)) + val b = wrap(a, soi, dt).asInstanceOf[UDAFPercentile.State] val sfCounts = soi.getStructFieldRef("counts") val sfPercentiles = soi.getStructFieldRef("percentiles") @@ -158,44 +162,45 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val writableOIs = dataTypes.map(toWritableInspector) val nullRow = data.map(d => null) - checkValues(nullRow, nullRow.zip(writableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(nullRow, nullRow.zip(writableOIs).zip(dataTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) // struct couldn't be constant, sweep it out val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType]) + val constantTypes = constantExprs.map(_.dataType) val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) - checkValues(constantData, constantData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantData, constantData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantData.zip(constantNullWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantNullData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) } test("wrap / unwrap primitive writable object inspector") { val writableOIs = dataTypes.map(toWritableInspector) - checkValues(row, row.zip(writableOIs).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + 
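[Editor's aside, not part of the patch] The two Java classes above, MyDoubleSum and MyDoubleAvg, are example UserDefinedAggregateFunctions used by the Hive aggregation tests. A minimal, hedged sketch of how such a UDAF is typically registered and invoked through SQL follows; the table and function names are hypothetical, and the udf.register(name, udaf) overload is assumed to be available on this branch:

    import org.apache.spark.sql.hive.test.TestHive
    import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}

    val ctx = TestHive
    import ctx.implicits._

    // Register the example UDAFs under SQL-callable names (names are illustrative).
    ctx.udf.register("mydoublesum", new MyDoubleSum)
    ctx.udf.register("mydoubleavg", new MyDoubleAvg)

    // A small DataFrame with a single nullable double column; null inputs are skipped by update().
    val df = Seq(Tuple1(1.0), Tuple1(2.0), Tuple1(3.0)).toDF("value")
    df.registerTempTable("values_table")

    // mydoublesum should yield 6.0; mydoubleavg should yield avg(value) + 100.0 = 102.0.
    ctx.sql("SELECT mydoublesum(value), mydoubleavg(value) FROM values_table").show()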
checkValues(row, row.zip(writableOIs).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } test("wrap / unwrap primitive java object inspector") { val ois = dataTypes.map(toInspector) - checkValues(row, row.zip(ois).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(ois).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } @@ -205,31 +210,33 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { }) val inspector = toInspector(dt) checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow]) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) - val d = row(0) :: row(0) :: Nil - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + val d = new GenericArrayData(Array(row(0), row(0))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { val dt = MapType(dataTypes(0), dataTypes(1)) val d = Map(row(0) -> row(1)) - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index af68615e8e9d6..a45c2d957278f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) @@ -28,64 +28,54 @@ class HiveParquetSuite extends QueryTest with ParquetTest { import sqlContext._ - def run(prefix: String): Unit = { - test(s"$prefix: Case insensitive attribute names") { - withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { - val expected = (1 to 4).map(i => Row(i.toString)) - 
checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) - } + test("Case insensitive attribute names") { + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { + val expected = (1 to 4).map(i => Row(i.toString)) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) } + } - test(s"$prefix: SELECT on Parquet table") { - val data = (1 to 4).map(i => (i, s"val_$i")) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) - } + test("SELECT on Parquet table") { + val data = (1 to 4).map(i => (i, s"val_$i")) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) } + } - test(s"$prefix: Simple column projection + filter on Parquet table") { - withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { - checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), - Seq(Row(true, "val_2"), Row(true, "val_4"))) - } + test("Simple column projection + filter on Parquet table") { + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { + checkAnswer( + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + Seq(Row(true, "val_2"), Row(true, "val_4"))) } + } - test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { - withTempPath { dir => - sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) - } + test("Converting Hive to Parquet Table via saveAsParquetFile") { + withTempPath { dir => + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + checkAnswer( + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) } } + } - test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { - withTempPath { file => - sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) - } + test("INSERT OVERWRITE TABLE Parquet table") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + withTempPath { file => + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + read.parquet(file.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + // let's do three overwrites for good measure + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) } } } } - - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { - run("Parquet data source enabled") - } - - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "false") { - run("Parquet data source disabled") - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e403f32efaf91..4fdf774ead75e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -21,10 +21,9 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterAll - import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException +import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging import org.apache.spark.sql._ @@ -33,7 +32,7 @@ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -564,10 +563,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } test("scan a parquet table created through a CTAS statement") { - withSQLConf( - HiveContext.CONVERT_METASTORE_PARQUET.key -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { - + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "true") { withTempTable("jt") { (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") @@ -582,9 +578,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK + case LogicalRelation(p: ParquetRelation) => // OK case _ => - fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + fail(s"test_parquet_ctas should have been converted to ${classOf[ParquetRelation]}") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala new file mode 100644 index 0000000000000..73852f13ad20d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} + +class MultiDatabaseSuite extends QueryTest with SQLTestUtils { + override val sqlContext: SQLContext = TestHive + + import sqlContext.sql + + private val df = sqlContext.range(10).coalesce(1) + + test(s"saveAsTable() to non-default database - with USE - Overwrite") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + } + } + + test(s"saveAsTable() to non-default database - without USE - Overwrite") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + } + } + + test(s"saveAsTable() to non-default database - with USE - Append") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + df.write.mode(SaveMode.Append).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df.unionAll(df)) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test(s"saveAsTable() to non-default database - without USE - Append") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test(s"insertInto() non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + } + + test(s"insertInto() non-default database - without USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + } + + assert(sqlContext.tableNames(db).contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test("Looks up tables in non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql("CREATE TABLE t (key INT)") + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + } + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + } + } + + test("Drops a table in a non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql(s"CREATE TABLE t (key INT)") + assert(sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(sqlContext.tableNames(db).contains("t")) + + activateDatabase(db) { + sql(s"DROP TABLE t") + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames(db).contains("t")) + } + } + + test("Refreshes a table in a non-default database") { + import org.apache.spark.sql.functions.lit + + withTempDatabase 
{ db => + withTempPath { dir => + val path = dir.getCanonicalPath + + activateDatabase(db) { + sql( + s"""CREATE EXTERNAL TABLE t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql("ALTER TABLE t ADD PARTITION (p=1)") + sql("REFRESH TABLE t") + checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f067ea0d4fc75..bc72b0172a467 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -172,7 +172,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 4056dee777574..9b3ede43ee2d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{Row, QueryTest} case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index efb04bf3d5097..638b9c810372a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -370,7 +370,11 @@ abstract class HiveComparisonTest // Check that the results match unless its an EXPLAIN query. 
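[Editor's aside, not part of the patch] MultiDatabaseSuite above relies on two SQLTestUtils helpers, withTempDatabase and activateDatabase, whose definitions are not included in this excerpt. The sketch below is only a plausible reconstruction of their shape (create or switch to a database, then restore state in a finally block); it is an assumption, not the actual SQLTestUtils code:

    import org.apache.spark.sql.SQLContext

    // Hypothetical stand-ins for the SQLTestUtils helpers used by MultiDatabaseSuite.
    def withTempDatabase(sqlContext: SQLContext)(f: String => Unit): Unit = {
      val db = "db_" + java.util.UUID.randomUUID().toString.replace("-", "")
      sqlContext.sql(s"CREATE DATABASE $db")
      try f(db) finally sqlContext.sql(s"DROP DATABASE $db CASCADE")
    }

    def activateDatabase(sqlContext: SQLContext)(db: String)(f: => Unit): Unit = {
      sqlContext.sql(s"USE $db")
      try f finally sqlContext.sql("USE default")
    }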
val preparedHive = prepareAnswer(hiveQuery, hive) - if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { + // We will ignore the ExplainCommand, ShowFunctions, and DescribeFunction commands + if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && + (!hiveQuery.logical.isInstanceOf[ShowFunctions]) && + (!hiveQuery.logical.isInstanceOf[DescribeFunction]) && + preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 03428265422e6..c4923d83e48f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,16 +19,19 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} +import scala.collection.JavaConversions._ + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ case class Nested1(f1: Nested2) @@ -61,7 +64,9 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest { +class SQLQuerySuite extends QueryTest with SQLTestUtils { + override def sqlContext: SQLContext = TestHive + test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") @@ -135,6 +140,50 @@ class SQLQuerySuite extends QueryTest { (1 to 6).map(_ => Row("CA", 20151))) } + test("show functions") { + val allFunctions = + (FunctionRegistry.builtin.listFunction().toSet[String] ++ + org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames).toList.sorted + checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) + checkAnswer(sql("SHOW functions abs"), Row("abs")) + checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) + checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `~`"), Row("~")) + checkAnswer(sql("SHOW functions `a function doesn't exist`"), Nil) + checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + // this will probably fail if we add more functions with the `sha` prefix.
+ checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + } + + test("describe functions") { + // The Spark SQL built-in functions + checkExistence(sql("describe function extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql')", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + + checkExistence(sql("describe functioN `~`"), true, + "Function: ~", + "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", + "Usage: ~ n - Bitwise not") + } + test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") @@ -195,17 +244,17 @@ class SQLQuerySuite extends QueryTest { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { - case LogicalRelation(r: ParquetRelation2) => + case LogicalRelation(r: ParquetRelation) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + - s"${ParquetRelation2.getClass.getCanonicalName}.") + s"${ParquetRelation.getClass.getCanonicalName}.") } case r: MetastoreRelation => if (isDataSourceParquet) { fail( - s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " + + s"${ParquetRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") } } @@ -350,16 +399,14 @@ class SQLQuerySuite extends QueryTest { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { checkExistence(sql("DESC EXTENDED ctas5"), true, "name:key", "type:string", "name:value", "ctas5", "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", @@ -367,16 +414,13 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", "MANAGED_TABLE" ) + } - val default = convertMetastoreParquet - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") + // use the Hive SerDe for parquet tables + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( sql("SELECT key, value FROM ctas5 ORDER BY key, value"), sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource) } } @@ -1025,4 +1069,12 @@ class SQLQuerySuite extends QueryTest { ) TestHive.dropTempTable("test_SPARK8588") } + + test("SPARK-9371: fix the 
support for special chars in column names for hive context") { + TestHive.read.json(TestHive.sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala new file mode 100644 index 0000000000000..0875232aede3e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.types.StringType + +class ScriptTransformationSuite extends SparkPlanTest { + + override def sqlContext: SQLContext = TestHive + + private val noSerdeIOSchema = HiveScriptIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + schemaLess = false + ) + + private val serdeIOSchema = noSerdeIOSchema.copy( + inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), + outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) + ) + + test("cat without SerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("cat with LazySimpleSerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("script transformation should not swallow errors from upstream operators (no serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + 
val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } + + test("script transformation should not swallow errors from upstream operators (with serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } +} + +private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. + Thread.sleep(1000) // This sleep gives the external process time to start. + throw new IllegalArgumentException("intentional exception") + } + } + override def output: Seq[Attribute] = child.output +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 9d76d6503a3e6..145965388da01 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,14 +22,15 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SQLTestUtils -private[sql] trait OrcTest extends SQLTestUtils { +private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - import sqlContext.sparkContext import sqlContext.implicits._ + import sqlContext.sparkContext /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 82a8daf8b4b09..f56fb96c52d37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -22,13 +22,13 @@ import java.io.File import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ -import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import 
org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -57,7 +57,7 @@ case class ParquetDataWithKeyAndComplexTypes( * A suite to test the automatic conversion of metastore tables with parquet data to use the * built in parquet support. */ -class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { +class ParquetMetastoreSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -134,6 +134,19 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' """) + sql( + """ + |create table test_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + (1 to 10).foreach { p => sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") } @@ -166,6 +179,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { sql("DROP TABLE normal_parquet") sql("DROP TABLE IF EXISTS jt") sql("DROP TABLE IF EXISTS jt_array") + sql("DROP TABLE IF EXISTS test_parquet") setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } @@ -176,40 +190,9 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { }.isEmpty) assert( sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: ParquetTableScan => true case _: PhysicalRDD => true }.nonEmpty) } -} - -class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - - sql( - """ - |create table test_parquet - |( - | intField INT, - | stringField STRING - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override def afterAll(): Unit = { - super.afterAll() - sql("DROP TABLE IF EXISTS test_parquet") - - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("scan an empty parquet table") { checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) @@ -292,10 +275,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation2) => // OK + case LogicalRelation(_: ParquetRelation) => // OK case _ => fail( "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") + s"${classOf[ParquetRelation].getCanonicalName}") } sql("DROP TABLE IF EXISTS test_parquet_ctas") @@ -316,9 +299,9 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[ParquetRelation].getCanonicalName} and " + 
s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -346,9 +329,9 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[ParquetRelation].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + s"However, found a ${o.toString} ") } @@ -379,17 +362,17 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation2) => r + case r @ LogicalRelation(_: ParquetRelation) => r }.size } sql("DROP TABLE ms_convert") } - def collectParquetRelation(df: DataFrame): ParquetRelation2 = { + def collectParquetRelation(df: DataFrame): ParquetRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: ParquetRelation2) => r + case LogicalRelation(r: ParquetRelation) => r }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$plan") } @@ -439,7 +422,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + @@ -543,81 +526,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { } } -class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } - - test("MetastoreRelation in InsertIntoTable will not be converted") { - sql( - """ - |create table test_insert_parquet - |( - | intField INT - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case insert: execution.InsertIntoHiveTable => // OK - case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + - s"However, found ${o.toString}.") - } - - checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") - } - - // TODO: enable it after the fix of SPARK-5950. 
- ignore("MetastoreRelation in InsertIntoHiveTable will not be converted") { - sql( - """ - |create table test_insert_parquet - |( - | int_array array - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case insert: execution.InsertIntoHiveTable => // OK - case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + - s"However, found ${o.toString}.") - } - - checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") - } -} - /** * A suite of tests for the Parquet support through the data sources API. */ -class ParquetSourceSuiteBase extends ParquetPartitioningTest { +class ParquetSourceSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -712,20 +624,6 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { } } } -} - -class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("values in arrays and maps stored in parquet are always nullable") { val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") @@ -734,7 +632,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val expectedSchema1 = StructType( StructField("m", mapType1, nullable = true) :: - StructField("a", arrayType1, nullable = true) :: Nil) + StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) df.write.format("parquet").saveAsTable("alwaysNullable") @@ -772,20 +670,6 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { } } -class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} - /** * A collection of tests for parquet data with various forms of partitioning. */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala new file mode 100644 index 0000000000000..e976125b3706d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils + + +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { + override val sqlContext = TestHive + + // When committing a task, `CommitFailureTestSource` throws an exception for testing purposes. + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + // Here we coalesce the partition number to 1 to ensure that only a single task is issued. This + // prevents a race condition that happens when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details. + val df = sqlContext.range(0, 10).coalesce(1) + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..d280543a071d9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.sources + +import java.io.File + +import com.google.common.io.Files +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + + +class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write.parquet(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[AnalysisException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. 
+ range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..e8975e5f5cd08 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2a8748d913569..dd274023a1cf5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,18 +17,14 @@ package org.apache.spark.sql.sources -import java.io.File - import scala.collection.JavaConversions._ -import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.parquet.hadoop.ParquetOutputCommitter -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -581,165 +577,3 @@ class AlwaysFailParquetOutputCommitter( sys.error("Intentional job commitment failure for testing purpose.") } } - -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - - import sqlContext._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } -} - -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive - - // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. 
- val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName - - test("SPARK-7684: commitTask() failure should fallback to abortTask()") { - withTempPath { file => - // Here we coalesce partition number to 1 to ensure that only a single task is issued. This - // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` - // directory while committing/aborting the job. See SPARK-8513 for more details. - val df = sqlContext.range(0, 10).coalesce(1) - intercept[SparkException] { - df.write.format(dataSourceName).save(file.getCanonicalPath) - } - - val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) - assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) - } - } -} - -class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName - - import sqlContext._ - import sqlContext.implicits._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) - .toDF("a", "b", "p1") - .write.parquet(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } - - test("SPARK-7868: _temporary directories should be ignored") { - withTempPath { dir => - val df = Seq("a", "b", "c").zipWithIndex.toDF() - - df.write - .format("parquet") - .save(dir.getCanonicalPath) - - df.write - .format("parquet") - .save(s"${dir.getCanonicalPath}/_temporary") - - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) - } - } - - test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { - withTempDir { dir => - val path = dir.getCanonicalPath - val df = Seq(1 -> "a").toDF() - - // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw - // since it's not a valid Parquet file. - val emptyFile = new File(path, "empty") - Files.createParentDirs(emptyFile) - Files.touch(emptyFile) - - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Ignore).save(path) - - // This should only complain that the destination directory already exists, rather than file - // "empty" is not a Parquet file. - assert { - intercept[AnalysisException] { - df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - }.getMessage.contains("already exists") - } - - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) - } - } - - test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { - withTempPath { dir => - intercept[AnalysisException] { - // Parquet doesn't allow field names with spaces. Here we are intentionally making an - // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger - // the bug. Please refer to spark-8079 for more details. 
- range(1, 10) - .withColumnRenamed("id", "a b") - .write - .format("parquet") - .save(dir.getCanonicalPath) - } - } - } - - test("SPARK-8604: Parquet data source should write summary file while doing appending") { - withTempPath { dir => - val path = dir.getCanonicalPath - val df = sqlContext.range(0, 5) - df.write.mode(SaveMode.Overwrite).parquet(path) - - val summaryPath = new Path(path, "_metadata") - val commonSummaryPath = new Path(path, "_common_metadata") - - val fs = summaryPath.getFileSystem(configuration) - fs.delete(summaryPath, true) - fs.delete(commonSummaryPath, true) - - df.write.mode(SaveMode.Append).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) - - assert(fs.exists(summaryPath)) - assert(fs.exists(commonSummaryPath)) - } - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index d58c99a8ff321..a6c4cd220e42f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.scheduler.RateController +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils /** @@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + // Keep track of the freshest rate for this stream using the rateEstimator + protected[streaming] val rateController: Option[RateController] = None + /** A human-readable name of this InputDStream */ private[streaming] def name: String = { // e.g. FlumePollingDStream -> "Flume polling stream" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a50f0efc030ce..646a8c3530a62 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,10 +21,11 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId -import org.apache.spark.streaming._ +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -40,6 +41,17 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. 
+ */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) } + } else { + None + } + } + /** * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation @@ -110,4 +122,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } + + /** + * A RateController that sends the new rate to receivers, via the receiver tracker. + */ + private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = + ssc.scheduler.receiverTracker.sendRateUpdate(id, rate) + } } + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index a7c220f426ecf..e98017a63756e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkEnv, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{Utils, ThreadUtils} /** * Abstract class that is responsible for supervising a Receiver in the worker. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 2f6841ee8879c..0d802f83549af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.RpcUtils import org.apache.spark.{Logging, SparkEnv, SparkException} /** @@ -46,6 +46,8 @@ private[streaming] class ReceiverSupervisorImpl( checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { + private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort + private val receivedBlockHandler: ReceivedBlockHandler = { if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { if (checkpointDirOption.isEmpty) { @@ -170,7 +172,7 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + streamId, receiver.getClass.getSimpleName, hostPort, endpoint) trackerEndpoint.askWithRetry[Boolean](msg) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 4af9b6d3b56ab..58bdda7794bf2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -66,6 +66,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } eventLoop.start() + // attach rate controllers of input streams to receive batch completion updates + for { + inputDStream <- ssc.graph.getInputStreams + rateController <- inputDStream.rateController + } ssc.addStreamingListener(rateController) + listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala new file mode 100644 index 0000000000000..882ca0676b6ad --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.io.ObjectInputStream +import java.util.concurrent.atomic.AtomicLong + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A StreamingListener that receives batch completion updates, and maintains + * an estimate of the speed at which this stream should ingest messages, + * given an estimate computation from a `RateEstimator` + */ +private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { + + init() + + protected def publish(rate: Long): Unit + + @transient + implicit private var executionContext: ExecutionContext = _ + + @transient + private var rateLimit: AtomicLong = _ + + /** + * An initialization method called both from the constructor and Serialization code. + */ + private def init() { + executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")) + rateLimit = new AtomicLong(-1L) + } + + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + init() + } + + /** + * Compute the new rate limit and publish it asynchronously. 
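 + * The estimated rate is cached in `rateLimit` and then pushed to the receivers through the
 + * subclass-provided `publish` method.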
+ */ + private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = + Future[Unit] { + val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay) + newRate.foreach { s => + rateLimit.set(s.toLong) + publish(getLatestRate()) + } + } + + def getLatestRate(): Long = rateLimit.get() + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + val elements = batchCompleted.batchInfo.streamIdToInputInfo + + for { + processingEnd <- batchCompleted.batchInfo.processingEndTime; + workDelay <- batchCompleted.batchInfo.processingDelay; + waitDelay <- batchCompleted.batchInfo.schedulingDelay; + elems <- elements.get(streamUID).map(_.numRecords) + } computeAndPublish(processingEnd, elems, workDelay, waitDelay) + } +} + +object RateController { + def isBackPressureEnabled(conf: SparkConf): Boolean = + conf.getBoolean("spark.streaming.backpressure.enable", false) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index de85f24dd988d..59df892397fe0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -28,7 +28,6 @@ import org.apache.spark.rpc.RpcEndpointRef case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val endpoint: RpcEndpointRef, active: Boolean, location: String, lastErrorMessage: String = "", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala new file mode 100644 index 0000000000000..ef5b687b5831a --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.Map +import scala.collection.mutable + +import org.apache.spark.streaming.receiver.Receiver + +private[streaming] class ReceiverSchedulingPolicy { + + /** + * Try our best to schedule receivers with evenly distributed. However, if the + * `preferredLocation`s of receivers are not even, we may not be able to schedule them evenly + * because we have to respect them. + * + * Here is the approach to schedule executors: + *
 + *   1. First, schedule all the receivers with preferred locations (hosts), evenly among the
 + *      executors running on those hosts.
 + *   2. Then, schedule all other receivers evenly among all the executors so that the overall
 + *      distribution over all the receivers is even.
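 + *
 + * For example (an illustrative sketch, not part of the original patch; `r0` and `r1` stand for
 + * hypothetical receivers with stream ids 0 and 1, where `r0` prefers host "a", `r1` has no
 + * preference, and executors are given as "host:port" strings):
 + * {{{
 + *   val policy = new ReceiverSchedulingPolicy()
 + *   policy.scheduleReceivers(Seq(r0, r1), Seq("a:1", "b:1"))
 + *   // r0 is pinned to the executor on its preferred host, r1 gets the least loaded executor:
 + *   // Map(0 -> Seq("a:1"), 1 -> Seq("b:1"))
 + * }}}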
+ * + * This method is called when we start to launch receivers at the first time. + */ + def scheduleReceivers( + receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = { + if (receivers.isEmpty) { + return Map.empty + } + + if (executors.isEmpty) { + return receivers.map(_.streamId -> Seq.empty).toMap + } + + val hostToExecutors = executors.groupBy(_.split(":")(0)) + val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // Set the initial value to 0 + executors.foreach(e => numReceiversOnExecutor(e) = 0) + + // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation", + // we need to make sure the "preferredLocation" is in the candidate scheduled executor list. + for (i <- 0 until receivers.length) { + // Note: preferredLocation is host but executors are host:port + receivers(i).preferredLocation.foreach { host => + hostToExecutors.get(host) match { + case Some(executorsOnHost) => + // preferredLocation is a known host. Select an executor that has the least receivers in + // this host + val leastScheduledExecutor = + executorsOnHost.minBy(executor => numReceiversOnExecutor(executor)) + scheduledExecutors(i) += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = + numReceiversOnExecutor(leastScheduledExecutor) + 1 + case None => + // preferredLocation is an unknown host. + // Note: There are two cases: + // 1. This executor is not up. But it may be up later. + // 2. This executor is dead, or it's not a host in the cluster. + // Currently, simply add host to the scheduled executors. + scheduledExecutors(i) += host + } + } + } + + // For those receivers that don't have preferredLocation, make sure we assign at least one + // executor to them. + for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) { + // Select the executor that has the least receivers + val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2) + scheduledExecutorsForOneReceiver += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1 + } + + // Assign idle executors to receivers that have less executors + val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1) + for (executor <- idleExecutors) { + // Assign an idle executor to the receiver that has least candidate executors. + val leastScheduledExecutors = scheduledExecutors.minBy(_.size) + leastScheduledExecutors += executor + } + + receivers.map(_.streamId).zip(scheduledExecutors).toMap + } + + /** + * Return a list of candidate executors to run the receiver. If the list is empty, the caller can + * run this receiver in arbitrary executor. The caller can use `preferredNumExecutors` to require + * returning `preferredNumExecutors` executors if possible. + * + * This method tries to balance executors' load. Here is the approach to schedule executors + * for a receiver. + *
 + *   1. If preferredLocation is set, preferredLocation should be one of the candidate executors.
 + *   2. Every executor is assigned a weight according to the receivers running or scheduled on it:
 + *      - If a receiver is running on an executor, it contributes 1.0 to the executor's weight.
 + *      - If a receiver is scheduled to an executor but has not yet run, it contributes
 + *        `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.
 + *   3. Finally, if there are more than `preferredNumExecutors` idle executors (weight = 0),
 + *      return all idle executors. Otherwise, return the `preferredNumExecutors` best options
 + *      according to the weights.
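 + *
 + * For example (an illustrative sketch with hypothetical executors "a:1" to "d:1"): if one
 + * receiver is running on "a:1" (weight 1.0) and another is scheduled onto candidates "b:1" and
 + * "c:1" (weight 0.5 each), then rescheduling a receiver with no preferred location returns the
 + * idle executor "d:1" together with "b:1" and "c:1", while the busiest executor "a:1" is
 + * left out.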
+ * + * This method is called when a receiver is registering with ReceiverTracker or is restarting. + */ + def rescheduleReceiver( + receiverId: Int, + preferredLocation: Option[String], + receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], + executors: Seq[String], + preferredNumExecutors: Int = 3): Seq[String] = { + if (executors.isEmpty) { + return Seq.empty + } + + // Always try to schedule to the preferred locations + val scheduledExecutors = mutable.Set[String]() + scheduledExecutors ++= preferredLocation + + val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo => + receiverTrackingInfo.state match { + case ReceiverState.INACTIVE => Nil + case ReceiverState.SCHEDULED => + val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get + // The probability that a scheduled receiver will run in an executor is + // 1.0 / scheduledLocations.size + scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size)) + case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) + } + }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + + val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq + if (idleExecutors.size >= preferredNumExecutors) { + // If there are more than `preferredNumExecutors` idle executors, return all of them + scheduledExecutors ++= idleExecutors + } else { + // If there are less than `preferredNumExecutors` idle executors, return 3 best options + scheduledExecutors ++= idleExecutors + val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1) + scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors) + } + scheduledExecutors.toSeq + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 9cc6ffcd12f61..e076fb5ea174b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,17 +17,27 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} +import java.util.concurrent.{TimeUnit, CountDownLatch} + +import scala.collection.mutable.HashMap +import scala.concurrent.ExecutionContext import scala.language.existentials -import scala.math.max +import scala.util.{Failure, Success} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, - StopReceiver, UpdateRateLimit} -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.{ThreadUtils, SerializableConfiguration} + + +/** Enumeration to identify current state of a Receiver */ +private[streaming] object ReceiverState extends Enumeration { + type ReceiverState = Value + val INACTIVE, SCHEDULED, ACTIVE = Value +} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -37,7 +47,7 @@ private[streaming] sealed trait ReceiverTrackerMessage private[streaming] case class RegisterReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, 
receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) @@ -46,7 +56,38 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage -private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage +/** + * Messages used by the driver and ReceiverTrackerEndpoint to communicate locally. + */ +private[streaming] sealed trait ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver. + */ +private[streaming] case class RestartReceiver(receiver: Receiver[_]) + extends ReceiverTrackerLocalMessage + +/** + * This message is sent to ReceiverTrackerEndpoint when we start to launch Spark jobs for receivers + * at the first time. + */ +private[streaming] case class StartAllReceivers(receiver: Seq[Receiver[_]]) + extends ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to send stop signals to all registered + * receivers. + */ +private[streaming] case object StopAllReceivers extends ReceiverTrackerLocalMessage + +/** + * A message used by ReceiverTracker to ask all receiver's ids still stored in + * ReceiverTrackerEndpoint. + */ +private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessage + +private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long) + extends ReceiverTrackerLocalMessage /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of @@ -60,8 +101,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private val receiverInputStreams = ssc.graph.getReceiverInputStreams() private val receiverInputStreamIds = receiverInputStreams.map { _.id } - private val receiverExecutor = new ReceiverLauncher() - private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] private val receivedBlockTracker = new ReceivedBlockTracker( ssc.sparkContext.conf, ssc.sparkContext.hadoopConfiguration, @@ -86,6 +125,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null + private val schedulingPolicy = new ReceiverSchedulingPolicy() + + // Track the active receiver job number. When a receiver job exits ultimately, countDown will + // be called. + private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size) + + /** + * Track all receivers' information. The key is the receiver id, the value is the receiver info. + * It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverTrackingInfos = new HashMap[Int, ReceiverTrackingInfo] + + /** + * Store all preferred locations for all receivers. We need this information to schedule + * receivers. It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverPreferredLocations = new HashMap[Int, Option[String]] + /** Start the endpoint and receiver execution thread. 
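 + * When `skipReceiverLaunch` is false, this also launches all the receivers, each one in its own
 + * Spark job submitted via `launchReceivers()`.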
*/ def start(): Unit = synchronized { if (isTrackerStarted) { @@ -95,7 +152,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (!receiverInputStreams.isEmpty) { endpoint = ssc.env.rpcEnv.setupEndpoint( "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) - if (!skipReceiverLaunch) receiverExecutor.start() + if (!skipReceiverLaunch) launchReceivers() logInfo("ReceiverTracker started") trackerState = Started } @@ -112,20 +169,18 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Wait for the Spark job that runs the receivers to be over // That is, for the receivers to quit gracefully. - receiverExecutor.awaitTermination(10000) + receiverJobExitLatch.await(10, TimeUnit.SECONDS) if (graceful) { - val pollTime = 100 logInfo("Waiting for receiver job to terminate gracefully") - while (receiverInfo.nonEmpty || receiverExecutor.running) { - Thread.sleep(pollTime) - } + receiverJobExitLatch.await() logInfo("Waited for receiver job to terminate gracefully") } // Check if all the receivers have been deregistered or not - if (receiverInfo.nonEmpty) { - logWarning("Not all of the receivers have deregistered, " + receiverInfo) + val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds) + if (receivers.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receivers) } else { logInfo("All of the receivers have deregistered successfully") } @@ -154,9 +209,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Get the blocks allocated to the given batch and stream. */ def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { - synchronized { - receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) - } + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) } /** @@ -170,8 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.endpoint) } - .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } + synchronized { + if (isTrackerStarted) { + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + } + } } } @@ -179,7 +235,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private def registerReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress ): Boolean = { @@ -189,13 +245,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (isTrackerStopping || isTrackerStopped) { false + } else if (!scheduleReceiver(streamId).contains(hostPort)) { + // Refuse it since it's scheduled to a wrong executor + false } else { - // "stopReceivers" won't happen at the same time because both "registerReceiver" and are - // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If - // "stopReceivers" is called later, it should be able to see this receiver. 
- receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + val name = s"${typ}-${streamId}" + val receiverTrackingInfo = ReceiverTrackingInfo( + streamId, + ReceiverState.ACTIVE, + scheduledExecutors = None, + runningExecutor = Some(hostPort), + name = Some(name), + endpoint = Some(receiverEndpoint)) + receiverTrackingInfos.put(streamId, receiverTrackingInfo) + listenerBus.post(StreamingListenerReceiverStarted(receiverTrackingInfo.toReceiverInfo)) logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) true } @@ -203,21 +266,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val lastErrorTime = + if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() + val errorInfo = ReceiverErrorInfo( + lastErrorMessage = message, lastError = error, lastErrorTime = lastErrorTime) + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + oldInfo.copy(state = ReceiverState.INACTIVE, errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo -= streamId - listenerBus.post(StreamingListenerReceiverStopped(newReceiverInfo)) + receiverTrackingInfos -= streamId + listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -227,9 +289,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Update a receiver's maximum ingestion rate */ - def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { - for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) { - eP.send(UpdateRateLimit(newRate)) + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized { + if (isTrackerStarted) { + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) } } @@ -240,16 +302,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Report error sent by a receiver */ private def reportError(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(lastErrorMessage = message, lastError = error) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = oldInfo.errorInfo.map(_.lastErrorTime).getOrElse(-1L)) + oldInfo.copy(errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = 
ssc.scheduler.clock.getTimeMillis()) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = ssc.scheduler.clock.getTimeMillis()) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo(streamId) = newReceiverInfo - listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + + receiverTrackingInfos(streamId) = newReceiverTrackingInfo + listenerBus.post(StreamingListenerReceiverError(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -258,171 +325,242 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logWarning(s"Error reported by receiver for stream $streamId: $messageWithError") } + private def scheduleReceiver(receiverId: Int): Seq[String] = { + val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None) + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiverId, preferredLocation, receiverTrackingInfos, getExecutors) + updateReceiverScheduledExecutors(receiverId, scheduledExecutors) + scheduledExecutors + } + + private def updateReceiverScheduledExecutors( + receiverId: Int, scheduledExecutors: Seq[String]): Unit = { + val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match { + case Some(oldInfo) => + oldInfo.copy(state = ReceiverState.SCHEDULED, + scheduledExecutors = Some(scheduledExecutors)) + case None => + ReceiverTrackingInfo( + receiverId, + ReceiverState.SCHEDULED, + Some(scheduledExecutors), + runningExecutor = None) + } + receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo) + } + /** Check if any blocks are left to be processed */ def hasUnallocatedBlocks: Boolean = { receivedBlockTracker.hasUnallocatedReceivedBlocks } + /** + * Get the list of executors excluding driver + */ + private def getExecutors: Seq[String] = { + if (ssc.sc.isLocal) { + Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort) + } else { + ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => + blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location + }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq + } + } + + /** + * Run the dummy Spark job to ensure that all slaves have registered. This avoids all the + * receivers to be scheduled on the same node. + * + * TODO Should poll the executor number and wait for executors according to + * "spark.scheduler.minRegisteredResourcesRatio" and + * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job. + */ + private def runDummySparkJob(): Unit = { + if (!ssc.sparkContext.isLocal) { + ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() + } + assert(getExecutors.nonEmpty) + } + + /** + * Get the receivers from the ReceiverInputDStreams, distributes them to the + * worker nodes as a parallel collection, and runs them. 
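 + * A dummy Spark job is run first (see `runDummySparkJob()`) so that executors have registered
 + * before receivers are scheduled; the receivers are then started by sending a
 + * `StartAllReceivers` message to the tracker endpoint.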
+ */ + private def launchReceivers(): Unit = { + val receivers = receiverInputStreams.map(nis => { + val rcvr = nis.getReceiver() + rcvr.setReceiverId(nis.id) + rcvr + }) + + runDummySparkJob() + + logInfo("Starting " + receivers.length + " receivers") + endpoint.send(StartAllReceivers(receivers)) + } + + /** Check if tracker has been marked for starting */ + private def isTrackerStarted: Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping: Boolean = trackerState == Stopping + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped: Boolean = trackerState == Stopped + /** RpcEndpoint to receive messages from the receivers. */ private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged + private val submitJobThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + override def receive: PartialFunction[Any, Unit] = { + // Local messages + case StartAllReceivers(receivers) => + val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors) + for (receiver <- receivers) { + val executors = scheduledExecutors(receiver.streamId) + updateReceiverScheduledExecutors(receiver.streamId, executors) + receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation + startReceiver(receiver, executors) + } + case RestartReceiver(receiver) => + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors) + startReceiver(receiver, scheduledExecutors) + case c: CleanupOldBlocks => + receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) + case UpdateReceiverRateLimit(streamUID, newRate) => + for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) { + eP.send(UpdateRateLimit(newRate)) + } + // Remote messages case ReportError(streamId, message, error) => reportError(streamId, message, error) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterReceiver(streamId, typ, host, receiverEndpoint) => + // Remote messages + case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + // Local messages + case AllReceiverIds => + context.reply(receiverTrackingInfos.keys.toSeq) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() context.reply(true) } - /** Send stop signal to the receivers. */ - private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.endpoint)} - .foreach { _.send(StopReceiver) } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") - } - } - - /** This thread class runs all the receivers on the cluster. 
*/ - class ReceiverLauncher { - @transient val env = ssc.env - @volatile @transient var running = false - @transient val thread = new Thread() { - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") - } - } - } - - def start() { - thread.start() - } - /** - * Get the list of executors excluding driver - */ - private def getExecutors(ssc: StreamingContext): List[String] = { - val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList - val driver = ssc.sparkContext.getConf.get("spark.driver.host") - executors.diff(List(driver)) - } - - /** Set host location(s) for each receiver so as to distribute them over - * executors in a round-robin fashion taking into account preferredLocation if set + * Start a receiver along with its scheduled executors */ - private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]], - executors: List[String]): Array[ArrayBuffer[String]] = { - val locations = new Array[ArrayBuffer[String]](receivers.length) - var i = 0 - for (i <- 0 until receivers.length) { - locations(i) = new ArrayBuffer[String]() - if (receivers(i).preferredLocation.isDefined) { - locations(i) += receivers(i).preferredLocation.get - } + private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + val receiverId = receiver.streamId + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + return } - var count = 0 - for (i <- 0 until max(receivers.length, executors.length)) { - if (!receivers(i % receivers.length).preferredLocation.isDefined) { - locations(i % receivers.length) += executors(count) - count += 1 - if (count == executors.length) { - count = 0 - } - } - } - locations - } - - /** - * Get the receivers from the ReceiverInputDStreams, distributes them to the - * worker nodes as a parallel collection, and runs them. - */ - private def startReceivers() { - val receivers = receiverInputStreams.map(nis => { - val rcvr = nis.getReceiver() - rcvr.setReceiverId(nis.id) - rcvr - }) val checkpointDirOption = Option(ssc.checkpointDir) val serializableHadoopConf = new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiver = (iterator: Iterator[Receiver[_]]) => { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") - } - val receiver = iterator.next() - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } - - // Run the dummy Spark job to ensure that all slaves have registered. - // This avoids all the receivers to be scheduled on the same node. 
- if (!ssc.sparkContext.isLocal) { - ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() - } + val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf) - // Get the list of executors and schedule receivers - val executors = getExecutors(ssc) - val tempRDD = - if (!executors.isEmpty) { - val locations = scheduleReceivers(receivers, executors) - val roundRobinReceivers = (0 until receivers.length).map(i => - (receivers(i), locations(i))) - ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers) + // Create the RDD using the scheduledExecutors to run the receiver in a Spark job + val receiverRDD: RDD[Receiver[_]] = + if (scheduledExecutors.isEmpty) { + ssc.sc.makeRDD(Seq(receiver), 1) } else { - ssc.sc.makeRDD(receivers, receivers.size) + ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) } + receiverRDD.setName(s"Receiver $receiverId") + val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit]( + receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ()) + // We will keep restarting the receiver job until ReceiverTracker is stopped + future.onComplete { + case Success(_) => + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + } else { + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + case Failure(e) => + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + } else { + logError("Receiver has been stopped. Try to restart it.", e) + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + }(submitJobThreadPool) + logInfo(s"Receiver ${receiver.streamId} started") + } - // Distribute the receivers and start them - logInfo("Starting " + receivers.length + " receivers") - running = true - try { - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - logInfo("All of the receivers have been terminated") - } finally { - running = false - } + override def onStop(): Unit = { + submitJobThreadPool.shutdownNow() } /** - * Wait until the Spark job that runs the receivers is terminated, or return when - * `milliseconds` elapses + * Call when a receiver is terminated. It means we won't restart its Spark job. */ - def awaitTermination(milliseconds: Long): Unit = { - thread.join(milliseconds) + private def onReceiverJobFinish(receiverId: Int): Unit = { + receiverJobExitLatch.countDown() + receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo => + if (receiverTrackingInfo.state == ReceiverState.ACTIVE) { + logWarning(s"Receiver $receiverId exited but didn't deregister") + } + } } - } - /** Check if tracker has been marked for starting */ - private def isTrackerStarted(): Boolean = trackerState == Started + /** Send stop signal to the receivers. */ + private def stopReceivers() { + receiverTrackingInfos.values.flatMap(_.endpoint).foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverTrackingInfos.size + " receivers") + } + } - /** Check if tracker has been marked for stopping */ - private def isTrackerStopping(): Boolean = trackerState == Stopping +} - /** Check if tracker has been marked for stopped */ - private def isTrackerStopped(): Boolean = trackerState == Stopped +/** + * Function to start the receiver on the worker node. Use a class instead of closure to avoid + * the serialization issue. 
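 + * Only the first task attempt (attemptNumber == 0) actually starts the receiver; retried
 + * attempts exit immediately so that the ReceiverTracker, rather than the TaskScheduler,
 + * decides where to restart the receiver.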
+ */ +private class StartReceiverFunc( + checkpointDirOption: Option[String], + serializableHadoopConf: SerializableConfiguration) + extends (Iterator[Receiver[_]] => Unit) with Serializable { + + override def apply(iterator: Iterator[Receiver[_]]): Unit = { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala new file mode 100644 index 0000000000000..043ff4d0ff054 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.streaming.scheduler.ReceiverState._ + +private[streaming] case class ReceiverErrorInfo( + lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L) + +/** + * Class having information about a receiver. + * + * @param receiverId the unique receiver id + * @param state the current Receiver state + * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy + * @param runningExecutor the running executor if the receiver is active + * @param name the receiver name + * @param endpoint the receiver endpoint. 
It can be used to send messages to the receiver + * @param errorInfo the receiver error information if it fails + */ +private[streaming] case class ReceiverTrackingInfo( + receiverId: Int, + state: ReceiverState, + scheduledExecutors: Option[Seq[String]], + runningExecutor: Option[String], + name: Option[String] = None, + endpoint: Option[RpcEndpointRef] = None, + errorInfo: Option[ReceiverErrorInfo] = None) { + + def toReceiverInfo: ReceiverInfo = ReceiverInfo( + receiverId, + name.getOrElse(""), + state == ReceiverState.ACTIVE, + location = runningExecutor.getOrElse(""), + lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), + lastError = errorInfo.map(_.lastError).getOrElse(""), + lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) + ) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala new file mode 100644 index 0000000000000..a08685119e5d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.SparkConf +import org.apache.spark.SparkException + +/** + * A component that estimates the rate at wich an InputDStream should ingest + * elements, based on updates at every batch completion. + */ +private[streaming] trait RateEstimator extends Serializable { + + /** + * Computes the number of elements the stream attached to this `RateEstimator` + * should ingest per second, given an update on the size and completion + * times of the latest batch. + * + * @param time The timetamp of the current batch interval that just finished + * @param elements The number of elements that were processed in this batch + * @param processingDelay The time in ms that took for the job to complete + * @param schedulingDelay The time in ms that the job spent in the scheduling queue + */ + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] +} + +object RateEstimator { + + /** + * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. + * + * @return None if there is no configured estimator, otherwise an instance of RateEstimator + * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any + * known estimators. 
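 + *
 + * For example (illustrative; the estimator name below is hypothetical, and since no estimator
 + * implementations exist yet, any configured value currently fails):
 + * {{{
 + *   conf.set("spark.streaming.backpressure.rateEstimator", "my-estimator")
 + *   RateEstimator.create(conf)  // throws IllegalArgumentException
 + * }}}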
+ */ + def create(conf: SparkConf): Option[RateEstimator] = + conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator => + throw new IllegalArgumentException(s"Unknown rate estimator: $estimator") + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index e0c0f57212f55..bc53f2a31f6d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming.ui -import org.eclipse.jetty.servlet.ServletContextHandler - import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.{JettyUtils, SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} import StreamingTab._ @@ -42,18 +40,14 @@ private[spark] class StreamingTab(val ssc: StreamingContext) attachPage(new StreamingPage(this)) attachPage(new BatchPage(this)) - var staticHandler: ServletContextHandler = null - def attach() { getSparkUI(ssc).attachTab(this) - staticHandler = JettyUtils.createStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") - getSparkUI(ssc).attachHandler(staticHandler) + getSparkUI(ssc).addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") } def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).detachHandler(staticHandler) - staticHandler = null + getSparkUI(ssc).removeStaticHandler("/static/streaming") } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index a34f23475804a..e0718f73aa13f 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1735,6 +1735,7 @@ public Integer call(String s) throws Exception { @SuppressWarnings("unchecked") @Test public void testContextGetOrCreate() throws InterruptedException { + ssc.stop(); final SparkConf conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 08faeaa58f419..255376807c957 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase { test("repartition (more partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(5) - val ssc = setupStreams(input, operation, 2) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 5) - assert(second.size === 5) - assert(third.size === 5) - - assert(first.flatten.toSet.equals((1 to 100).toSet) ) - assert(second.flatten.toSet.equals((101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 2)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + 
assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("repartition (fewer partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(2) - val ssc = setupStreams(input, operation, 5) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 2) - assert(second.size === 2) - assert(third.size === 2) - - assert(first.flatten.toSet.equals((1 to 100).toSet)) - assert(second.flatten.toSet.equals( (101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 5)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("groupByKey") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index d308ac05a54fe..67c2d900940ab 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -391,6 +393,32 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } + test("recovery maintains rate controller") { + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + SingletonTestRateReceiver.reset() + + val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) + output.register() + runStreams(ssc, 5, 5) + + SingletonTestRateReceiver.reset() + ssc = new StreamingContext(checkpointDir) + ssc.start() + val outputNew = advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(5.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + ssc.stop() + ssc = null + } + // This tests whether file input stream remembers what files were seen before // the master failure and uses them again to process a large window operation. 
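(Aside on the RateEstimator contract exercised by the "recovery maintains rate controller" test above: an estimator only has to implement compute() and return the ingestion rate it wants receivers to honour. The sketch below is illustrative only and assumes nothing beyond the RateEstimator trait added in this patch; FixedRateEstimator is a hypothetical name, and the patch's own test helper, ConstantEstimator in RateControllerSuite further down, plays the same role.)

package org.apache.spark.streaming.scheduler.rate

// Hypothetical sketch: always propose the same rate, ignoring the batch statistics.
// ConstantEstimator in RateControllerSuite does essentially this with a rotating
// sequence of rates.
private[streaming] class FixedRateEstimator(rate: Double) extends RateEstimator {
  def compute(
      time: Long,
      elements: Long,
      processingDelay: Long,
      schedulingDelay: Long): Option[Double] = Some(rate)
}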
// It also tests whether batches, whose processing was incomplete due to the diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index b74d67c63a788..ec2852d9a0206 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -325,27 +325,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("test track the number of input stream") { - val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - class TestInputDStream extends InputDStream[String](ssc) { - def start() { } - def stop() { } - def compute(validTime: Time): Option[RDD[String]] = None - } + class TestInputDStream extends InputDStream[String](ssc) { + def start() {} - class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { - def getReceiver: Receiver[String] = null - } + def stop() {} + + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } - // Register input streams - val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) - val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) - assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) - assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) - assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) - assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) - assert(receiverInputStreams.map(_.id) === Array(0, 1)) + assert(ssc.graph.getInputStreams().length == + receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } } def testFileStream(newFilesOnly: Boolean) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 6e9d4431090a2..0e64b57e0ffd8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -244,7 +244,13 @@ object MasterFailureTest extends Logging { } catch { case e: Exception => logError("Error running streaming context", e) } - if (killingThread.isAlive) killingThread.interrupt() + if (killingThread.isAlive) { + killingThread.interrupt() + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is + // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env + // to null after the next test creates the new SparkContext and fail the test. 
+ killingThread.join() + } ssc.stop() logInfo("Has been killed = " + killed) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4bba9691f8aa5..84a5fbb3d95eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -120,7 +120,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) - val ssc = new StreamingContext(myConf, batchDuration) + ssc = new StreamingContext(myConf, batchDuration) assert(ssc.checkpointDir != null) } @@ -369,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + var t: Thread = null // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() ssc.awaitTermination() } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() } test("awaitTermination after stop") { @@ -430,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.awaitTerminationOrTimeout(500) === false) } + var t: Thread = null // test whether awaitTerminationOrTimeout() return true if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() assert(ssc.awaitTerminationOrTimeout(10000) === true) } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. 
+ t.join() } test("getOrCreate") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 4bc1dd4a30fc4..d840c349bbbc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // To make sure that the processing start and end times in collected // information are different for successive batches override def batchDuration: Duration = Milliseconds(100) override def actuallyWait: Boolean = true test("batch info reporting") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) @@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } test("receiver info reporting") { - val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) inputStream.foreachRDD(_.count) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala new file mode 100644 index 0000000000000..921da773f6c11 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +class RateControllerSuite extends TestSuiteBase { + + override def useManualClock: Boolean = false + + test("rate controller publishes updates") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) + dstream.register() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.publishCalls > 0) + } + } + } + + test("publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + dstream.register() + SingletonTestRateReceiver.reset() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + } + } + + test("multiple publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val rates = Seq(100L, 200L, 300L) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*))) + } + SingletonTestRateReceiver.reset() + dstream.register() + + val observedRates = mutable.HashSet.empty[Long] + ssc.start() + + eventually(timeout(20.seconds)) { + dstream.getCurrentRateLimit.foreach(observedRates += _) + // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver + observedRates should contain theSameElementsAs (rates :+ Long.MaxValue) + } + } + } +} + +private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator { + private var idx: Int = 0 + + private def nextRate(): Double = { + val rate = rates(idx) + idx = (idx + 1) % rates.size + rate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(nextRate()) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala new file mode 100644 index 0000000000000..0418d776ecc9a --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite + +class ReceiverSchedulingPolicySuite extends SparkFunSuite { + + val receiverSchedulingPolicy = new ReceiverSchedulingPolicy + + test("rescheduleReceiver: empty executors") { + val scheduledExecutors = + receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty) + assert(scheduledExecutors === Seq.empty) + } + + test("rescheduleReceiver: receiver preferredLocation") { + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2")) + assert(scheduledExecutors.toSet === Set("host1", "host2")) + } + + test("rescheduleReceiver: return all idle executors if more than 3 idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // host3 is idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1"))) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 1, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) + } + + test("rescheduleReceiver: return 3 best options if less than 3 idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0 + // host4 and host5 are idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), + 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), + 2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 3, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) + } + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more receivers than executors") { + val receivers = (0 until 6).map(new RateTestReceiver(_)) + val executors = (10000 until 10003).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 2 receivers running on each executor and each receiver has one executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1 + } + assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap) + } + + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more executors than receivers") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) + val executors = (10000 until 10006).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has two executors + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 2) + executors.foreach { l => + numReceiversOnExecutor(l) = 
numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + } + + test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) ++ + (3 until 6).map(new RateTestReceiver(_, Some("localhost"))) + val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ + (10003 until 10006).map(port => s"localhost2:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has 1 executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + // Make sure we schedule the receivers to their preferredLocations + val executorsForReceiversWithPreferredLocation = + scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) + // We can simply check the executor set because we only know each receiver only has 1 executor + assert(executorsForReceiversWithPreferredLocation.toSet === + (10000 until 10003).map(port => s"localhost:${port}").toSet) + } + + test("scheduleReceivers: return empty if no receiver") { + assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty) + } + + test("scheduleReceivers: return empty scheduled executors if no executors") { + val receivers = (0 until 3).map(new RateTestReceiver(_)) + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.isEmpty) + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index aadb7231757b8..afad5f16dbc71 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -18,135 +18,118 @@ package org.apache.spark.streaming.scheduler import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.streaming._ + import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ import org.apache.spark.streaming.receiver._ -import org.apache.spark.util.Utils -import org.apache.spark.streaming.dstream.InputDStream -import scala.reflect.ClassTag import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.storage.StorageLevel /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - val ssc = new StreamingContext(sparkConf, Milliseconds(100)) - val tracker = new ReceiverTracker(ssc) - val launcher = new tracker.ReceiverLauncher() - val executors: List[String] = List("0", "1", "2", "3") - - test("receiver scheduling - all or none have preferred location") { - - def parse(s: String): Array[Array[String]] = { - val outerSplit = s.split("\\|") - val loc = new Array[Array[String]](outerSplit.length) 
- var i = 0 - for (i <- 0 until outerSplit.length) { - loc(i) = outerSplit(i).split("\\,") - } - loc - } - - def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { - val receivers = - if (preferredLocation) { - Array.tabulate(numReceivers)(i => new DummyReceiver(host = - Some(((i + 1) % executors.length).toString))) - } else { - Array.tabulate(numReceivers)(_ => new DummyReceiver) - } - val locations = launcher.scheduleReceivers(receivers, executors) - val expectedLocations = parse(allocation) - assert(locations.deep === expectedLocations.deep) - } - - testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") - testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") - testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") - } - - test("receiver scheduling - some have preferred location") { - val numReceivers = 4; - val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), - new DummyReceiver, new DummyReceiver, new DummyReceiver) - val locations = launcher.scheduleReceivers(receivers, executors) - assert(locations(0)(0) === "1") - assert(locations(1)(0) === "0") - assert(locations(2)(0) === "1") - assert(locations(0).length === 1) - assert(locations(3).length === 1) - } test("Receiver tracker - propagates rate limit") { - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false + withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => + object ReceiverStartedWaiter extends StreamingListener { + @volatile + var started = false - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + started = true + } } - } - - ssc.addStreamingListener(ReceiverStartedWaiter) - ssc.scheduler.listenerBus.start(ssc.sc) - - val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) - val tracker = new ReceiverTracker(ssc) - tracker.start() - // we wait until the Receiver has registered with the tracker, - // otherwise our rate update is lost - eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) - } - tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + ssc.addStreamingListener(ReceiverStartedWaiter) + ssc.scheduler.listenerBus.start(ssc.sc) + SingletonTestRateReceiver.reset() + + val newRateLimit = 100L + val inputDStream = new RateLimitInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + try { + // we wait until the Receiver has registered with the tracker, + // otherwise our rate update is lost + eventually(timeout(5 seconds)) { + assert(ReceiverStartedWaiter.started) + } + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + // this is an async message, we need to wait a bit for it to be processed + eventually(timeout(3 seconds)) { + assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + } + } finally { + tracker.stop(false) + } } } } -/** An input DStream with a hard-coded receiver that gives access to internals for testing. */ -private class RateLimitInputDStream(@transient ssc_ : StreamingContext) +/** + * An input DStream with a hard-coded receiver that gives access to internals for testing. 
+ * + * @note Make sure to call {{{SingletonTestRateReceiver.reset()}}} before using this in a test, + * otherwise you may get {{{NotSerializableException}}} when trying to serialize + * the receiver. + * @see [[SingletonTestRateReceiver]]. + */ +private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext) extends ReceiverInputDStream[Int](ssc_) { - override def getReceiver(): DummyReceiver = SingletonDummyReceiver + override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver def getCurrentRateLimit: Option[Long] = { invokeExecutorMethod.getCurrentRateLimit } + @volatile + var publishCalls = 0 + + override val rateController: Option[RateController] = { + Some(new RateController(id, new ConstantEstimator(100.0)) { + override def publish(rate: Long): Unit = { + publishCalls += 1 + } + }) + } + private def invokeExecutorMethod: ReceiverSupervisor = { val c = classOf[Receiver[_]] val ex = c.getDeclaredMethod("executor") ex.setAccessible(true) - ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor] + ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor] } } /** - * A Receiver as an object so we can read its rate limit. + * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when + * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being + * serialized when receivers are installed on executors. * * @note It's necessary to be a top-level object, or else serialization would create another * one on the executor side and we won't be able to read its rate limit. */ -private object SingletonDummyReceiver extends DummyReceiver +private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) { + + /** Reset the object to be usable in another test.
*/ + def reset(): Unit = { + executor_ = null + } +} /** * Dummy receiver implementation */ -private class DummyReceiver(host: Option[String] = None) +private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - def onStart() { - } + setReceiverId(receiverId) - def onStop() { - } + override def onStart(): Unit = {} + + override def onStop(): Unit = {} override def preferredLocation: Option[String] = host } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 40dc1fb601bd0..995f1197ccdfd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -22,15 +22,24 @@ import java.util.Properties import org.scalatest.Matchers import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + private def createJobStart( batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { val properties = new Properties() @@ -46,7 +55,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + "onReceiverStarted, onReceiverError, onReceiverStopped") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val streamIdToInputInfo = Map( @@ -119,20 +128,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) @@ -141,7 +150,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("Remove the old completed batches when exceeding the 
limit") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -158,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("out-of-order onJobStart and onBatchXXX") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -209,7 +218,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("detect memory leak") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d0bde69cc1068..198e0684f32f8 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -74,12 +74,6 @@ public final class BytesToBytesMap { */ private long pageCursor = 0; - /** - * The size of the data pages that hold key and value data. Map entries cannot span multiple - * pages, so this limits the maximum entry size. - */ - private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes - /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since @@ -117,6 +111,12 @@ public final class BytesToBytesMap { private final double loadFactor; + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private final long pageSizeBytes; + /** * Number of keys defined in the map. 
*/ @@ -153,10 +153,12 @@ public BytesToBytesMap( TaskMemoryManager memoryManager, int initialCapacity, double loadFactor, + long pageSizeBytes, boolean enablePerfMetrics) { this.memoryManager = memoryManager; this.loadFactor = loadFactor; this.loc = new Location(); + this.pageSizeBytes = pageSizeBytes; this.enablePerfMetrics = enablePerfMetrics; if (initialCapacity <= 0) { throw new IllegalArgumentException("Initial capacity must be greater than 0"); @@ -165,18 +167,26 @@ public BytesToBytesMap( throw new IllegalArgumentException( "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); } + if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) { + throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); + } allocate(initialCapacity); } - public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { - this(memoryManager, initialCapacity, 0.70, false); + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + long pageSizeBytes) { + this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false); } public BytesToBytesMap( TaskMemoryManager memoryManager, int initialCapacity, + long pageSizeBytes, boolean enablePerfMetrics) { - this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics); } /** @@ -443,20 +453,20 @@ public void putNewKey( // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. + assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker. size++; bitset.set(pos); // If there's not enough space in the current page, allocate a new page (8 bytes are reserved // for the end-of-page marker). - if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) { + if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { if (currentDataPage != null) { // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); + MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes); dataPages.add(newPage); pageCursor = 0; currentDataPage = newPage; @@ -538,10 +548,11 @@ public void free() { /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. 
*/ public long getTotalMemoryConsumption() { - return ( - dataPages.size() * PAGE_SIZE_BYTES + - bitset.memoryBlock().size() + - longArray.memoryBlock().size()); + long totalDataPagesSize = 0L; + for (MemoryBlock dataPage : dataPages) { + totalDataPagesSize += dataPage.size(); + } + return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size(); } /** diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 10881969dbc78..dd70df3b1f791 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -58,8 +58,13 @@ public class TaskMemoryManager { /** The number of entries in the page table. */ private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; - /** Maximum supported data page size */ - private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS); + /** + * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page + * size is limited by the maximum amount of data that can be stored in a long[] array, which is + * (2^31 - 1) * 8 bytes (or about 16 gigabytes). Therefore, we cap this at 16 gigabytes. + */ + public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; @@ -110,9 +115,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { - if (size > MAXIMUM_PAGE_SIZE) { + if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( - "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } final int pageNumber; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java similarity index 52% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala rename to unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index 1618c24871c60..69b0e206cef18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -15,23 +15,24 @@ * limitations under the License.
*/ -package org.apache.spark.sql.catalyst.expressions +package org.apache.spark.unsafe.types; -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.PlatformDependent; -class DatetimeFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("datetime function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] - val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) - } +public class ByteArray { - test("datetime function current_timestamp") { - val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) - val t1 = System.currentTimeMillis() - assert(math.abs(t1 - ct.getTime) < 5000) + /** + * Writes the content of a byte array into a memory address, identified by an object and an + * offset. The target memory address must already have been allocated, and have enough space to + * hold all the bytes of the source array. + */ + public static void writeToMemory(byte[] src, Object target, long targetOffset) { + PlatformDependent.copyMemory( + src, + PlatformDependent.BYTE_ARRAY_OFFSET, + target, + targetOffset, + src.length + ); } - } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java similarity index 87% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java rename to unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 71b1a85a818ea..92a5e4f86f234 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -24,7 +24,7 @@ /** * The internal representation of interval type.
*/ -public final class Interval implements Serializable { +public final class CalendarInterval implements Serializable { public static final long MICROS_PER_MILLI = 1000L; public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; @@ -58,7 +58,7 @@ private static long toLong(String s) { } } - public static Interval fromString(String s) { + public static CalendarInterval fromString(String s) { if (s == null) { return null; } @@ -75,40 +75,40 @@ public static Interval fromString(String s) { microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; microseconds += toLong(m.group(9)); - return new Interval((int) months, microseconds); + return new CalendarInterval((int) months, microseconds); } } public final int months; public final long microseconds; - public Interval(int months, long microseconds) { + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; } - public Interval add(Interval that) { + public CalendarInterval add(CalendarInterval that) { int months = this.months + that.months; long microseconds = this.microseconds + that.microseconds; - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } - public Interval subtract(Interval that) { + public CalendarInterval subtract(CalendarInterval that) { int months = this.months - that.months; long microseconds = this.microseconds - that.microseconds; - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } - public Interval negate() { - return new Interval(-this.months, -this.microseconds); + public CalendarInterval negate() { + return new CalendarInterval(-this.months, -this.microseconds); } @Override public boolean equals(Object other) { if (this == other) return true; - if (other == null || !(other instanceof Interval)) return false; + if (other == null || !(other instanceof CalendarInterval)) return false; - Interval o = (Interval) other; + CalendarInterval o = (CalendarInterval) other; return this.months == o.months && this.microseconds == o.microseconds; } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 87b4b6f90a2c0..9d4998fd48a38 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -48,7 +48,7 @@ public final class UTF8String implements Comparable, Serializable { 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, - 6, 6, 6, 6}; + 6, 6}; public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); @@ -65,6 +65,19 @@ public static UTF8String fromBytes(byte[] bytes) { } } + /** + * Creates an UTF8String from byte array, which should be encoded in UTF-8. + * + * Note: `bytes` will be hold by returned UTF8String. + */ + public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { + if (bytes != null) { + return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + } else { + return null; + } + } + /** * Creates an UTF8String from String. 
*/ @@ -89,10 +102,25 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int size) { + protected UTF8String(Object base, long offset, int numBytes) { this.base = base; this.offset = offset; - this.numBytes = size; + this.numBytes = numBytes; + } + + /** + * Writes the content of this string into a memory address, identified by an object and an offset. + * The target memory address must already been allocated, and have enough space to hold all the + * bytes in this string. + */ + public void writeToMemory(Object target, long targetOffset) { + PlatformDependent.copyMemory( + base, + offset, + target, + targetOffset, + numBytes + ); } /** @@ -122,6 +150,32 @@ public int numChars() { return len; } + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. + */ + public long getPrefix() { + // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string. + // If size is 0, just return 0. + // If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and + // use a getInt to fetch the prefix. + // If size is greater than 4, assume we have at least 8 bytes of data to fetch. + // After getting the data, we use a mask to mask out data that is not part of the string. + long p; + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else if (numBytes > 0) { + p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); + return p; + } + /** * Returns the underline bytes, will be a copy of it if it's part of another array. 
*/ @@ -285,13 +339,13 @@ public UTF8String trimRight() { } public UTF8String reverse() { - byte[] bytes = getBytes(); - byte[] result = new byte[bytes.length]; + byte[] result = new byte[this.numBytes]; int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - System.arraycopy(bytes, i, result, result.length - i - len, len); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -301,11 +355,11 @@ public UTF8String reverse() { public UTF8String repeat(int times) { if (times <=0) { - return fromBytes(new byte[0]); + return EMPTY_UTF8; } byte[] newBytes = new byte[numBytes * times]; - System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -370,16 +424,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -406,15 +459,14 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - System.arraycopy(getBytes(), 0, data, offset, numBytes()); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -439,9 +491,9 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; } @@ -479,7 +531,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, len); @@ -488,7 +540,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. 
if (j < numInputs) { - PlatformDependent.copyMemory( + copyMemory( separator.base, separator.offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, separator.numBytes); diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index dae47e4bab0cb..0be94ad371255 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -43,6 +43,7 @@ public abstract class AbstractBytesToBytesMapSuite { private TaskMemoryManager memoryManager; private TaskMemoryManager sizeLimitedMemoryManager; + private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes @Before public void setup() { @@ -110,7 +111,7 @@ private static boolean arrayEquals( @Test public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); try { Assert.assertEquals(0, map.size()); final int keyLengthInWords = 10; @@ -125,7 +126,7 @@ public void emptyMap() { @Test public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; final byte[] keyData = getRandomByteArray(recordLengthWords); @@ -177,7 +178,7 @@ public void setAndRetrieveAKey() { @Test public void iteratorTest() throws Exception { final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -235,7 +236,7 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; final int KEY_LENGTH = 16; final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. This is necessary in order for us to take the end-of-record @@ -304,7 +305,7 @@ public void randomizedStressTest() { // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing @@ -353,14 +354,15 @@ public void randomizedStressTest() { @Test public void initialCapacityBoundsChecking() { try { - new BytesToBytesMap(sizeLimitedMemoryManager, 0); + new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception } try { - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1); + new BytesToBytesMap( + sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception @@ -368,15 +370,15 @@ public void initialCapacityBoundsChecking() { // Can allocate _at_ the max capacity BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY); + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES); map.free(); } @Test public void resizingLargeMap() { // As long as a map's capacity is below the max, we should be able to resize up to the max - BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64); + BytesToBytesMap map = new BytesToBytesMap( + sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES); map.growAndRehash(); map.free(); } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index d29517cda66a3..e6733a7aae6f5 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -20,16 +20,16 @@ import org.junit.Test; import static junit.framework.Assert.*; -import static org.apache.spark.unsafe.types.Interval.*; +import static org.apache.spark.unsafe.types.CalendarInterval.*; public class IntervalSuite { @Test public void equalsTest() { - Interval i1 = new Interval(3, 123); - Interval i2 = new Interval(3, 321); - Interval i3 = new Interval(1, 123); - Interval i4 = new Interval(3, 123); + CalendarInterval i1 = new CalendarInterval(3, 123); + CalendarInterval i2 = new CalendarInterval(3, 321); + CalendarInterval i3 = new CalendarInterval(1, 123); + CalendarInterval i4 = new CalendarInterval(3, 123); assertNotSame(i1, i2); assertNotSame(i1, i3); @@ -39,21 +39,21 @@ public void equalsTest() { @Test public void toStringTest() { - Interval i; + CalendarInterval i; - i = new Interval(34, 0); + i = new CalendarInterval(34, 0); assertEquals(i.toString(), "interval 2 years 10 months"); - i = new Interval(-34, 0); + i = new CalendarInterval(-34, 0); assertEquals(i.toString(), "interval -2 years -10 months"); - i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); - i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); - i = new Interval(34, 3 * MICROS_PER_WEEK 
+ 13 * MICROS_PER_HOUR + 123); + i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); } @@ -72,33 +72,33 @@ public void fromStringTest() { String input; input = "interval -5 years 23 month"; - Interval result = new Interval(-5 * 12 + 23, 0); - assertEquals(Interval.fromString(input), result); + CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); + assertEquals(CalendarInterval.fromString(input), result); input = "interval -5 years 23 month "; - assertEquals(Interval.fromString(input), result); + assertEquals(CalendarInterval.fromString(input), result); input = " interval -5 years 23 month "; - assertEquals(Interval.fromString(input), result); + assertEquals(CalendarInterval.fromString(input), result); // Error cases input = "interval 3month 1 hour"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "interval 3 moth 1 hour"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "interval"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "int"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = ""; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = null; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); } @Test @@ -106,18 +106,18 @@ public void addTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - Interval interval = Interval.fromString(input); - Interval interval2 = Interval.fromString(input2); + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR)); + assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = Interval.fromString(input); - interval2 = Interval.fromString(input2); + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR)); + assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); } @Test @@ -125,25 +125,25 @@ public void subtractTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - Interval interval = Interval.fromString(input); - Interval interval2 = Interval.fromString(input2); + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR)); + assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = Interval.fromString(input); - interval2 = Interval.fromString(input2); + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR)); + 
assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); } private void testSingleUnit(String unit, int number, int months, long microseconds) { String input1 = "interval " + number + " " + unit; String input2 = "interval " + number + " " + unit + "s"; - Interval result = new Interval(months, microseconds); - assertEquals(Interval.fromString(input1), result); - assertEquals(Interval.fromString(input2), result); + CalendarInterval result = new CalendarInterval(months, microseconds); + assertEquals(CalendarInterval.fromString(input1), result); + assertEquals(CalendarInterval.fromString(input2), result); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 5601048aec344..c565210872322 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -63,8 +63,27 @@ public void emptyStringTest() { assertEquals(0, EMPTY_UTF8.numBytes()); } + @Test + public void prefix() { + assertTrue(fromString("a").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue(fromString("ab").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue( + fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); + assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); + assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + + byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + byte[] buf2 = {1, 2, 3}; + UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); + UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8); + UTF8String str3 = UTF8String.fromBytes(buf2); + assertTrue(str1.getPrefix() - str2.getPrefix() < 0); + assertEquals(str1.getPrefix(), str3.getPrefix()); + } + @Test public void compareTo() { + assertTrue(fromString("").compareTo(fromString("a")) < 0); assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 83dafa4a125d2..1d67b3ebb51b7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -229,7 +229,11 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM( + _rpcEnv: RpcEnv, + driverRef: RpcEndpointRef, + uiAddress: String, + securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -246,6 +250,7 @@ private[spark] class ApplicationMaster( RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) allocator = client.register(driverUrl, + driverRef, yarnConf, _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), @@ -262,17 +267,20 @@ private[spark] class ApplicationMaster( * * In cluster mode, the AM and the driver belong to same process * so the AMEndpoint need not monitor lifecycle of the driver. 
+ * + * @return A reference to the driver's RPC endpoint. */ private def runAMEndpoint( host: String, port: String, - isClusterMode: Boolean): Unit = { + isClusterMode: Boolean): RpcEndpointRef = { val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) + driverEndpoint } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -290,11 +298,11 @@ private[spark] class ApplicationMaster( "Timed out waiting for SparkContext.") } else { rpcEnv = sc.env.rpcEnv - runAMEndpoint( + val driverRef = runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -302,9 +310,9 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) - waitForSparkDriver() + val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -428,7 +436,7 @@ private[spark] class ApplicationMaster( } } - private def waitForSparkDriver(): Unit = { + private def waitForSparkDriver(): RpcEndpointRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) @@ -555,11 +563,12 @@ private[spark] class ApplicationMaster( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal) => + case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) => Option(allocator) match { case Some(a) => allocatorLock.synchronized { - if (a.requestTotalExecutors(requestedTotal)) { + if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, + localityAwareTasks, hostToLocalTaskCount)) { allocatorLock.notifyAll() } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bc28ce5eeae72..4ac3397f1ad28 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -767,7 +767,7 @@ private[spark] class Client( amContainer.setCommands(printableCommands) logDebug("===============================================================================") - logDebug("Yarn AM launch context:") + logDebug("YARN AM launch context:") logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") logDebug(" env:") launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 78e27fb7f3337..52580deb372c2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -86,10 +86,17 @@ class ExecutorRunnable( val commands = 
prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, appId, localResources) - logInfo(s"Setting up executor with environment: $env") - logInfo("Setting up executor with commands: " + commands) - ctx.setCommands(commands) + logInfo(s""" + |=============================================================================== + |YARN executor launch context: + | env: + |${env.map { case (k, v) => s" $k -> $v\n" }.mkString} + | command: + | ${commands.mkString(" ")} + |=============================================================================== + """.stripMargin) + ctx.setCommands(commands) ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) // If external shuffle service is enabled, register with the Yarn shuffle service already diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala new file mode 100644 index 0000000000000..081780204e424 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.mutable.{ArrayBuffer, HashMap, Set} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} +import org.apache.hadoop.yarn.util.RackResolver + +import org.apache.spark.SparkConf + +private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String]) + +/** + * This strategy calculates the optimal locality preferences of YARN containers by considering + * the node ratio of pending tasks, the number of required cores/containers, and the locality of + * currently existing containers. The goal of this algorithm is to maximize the number of tasks + * that would run locally. + * + * Consider a situation in which we have 20 tasks that require (host1, host2, host3) + * and 10 tasks that require (host1, host2, host4), where each container has 2 cores + * and each task needs 1 cpu, so the required container number is 15, + * and the host ratio is (host1: 30, host2: 30, host3: 20, host4: 10). + * + * 1. If the requested container number (18) is more than the required container number (15): + * + * requests for 5 containers with nodes: (host1, host2, host3, host4) + * requests for 5 containers with nodes: (host1, host2, host3) + * requests for 5 containers with nodes: (host1, host2) + * requests for 3 containers with no locality preferences. + * + * The placement ratio is 3 : 3 : 2 : 1, and the additional containers are requested with no + * locality preferences. + * + * 2. 
If the requested container number (10) is less than or equal to the required container number + * (15): + * + * requests for 4 containers with nodes: (host1, host2, host3, host4) + * requests for 3 containers with nodes: (host1, host2, host3) + * requests for 3 containers with nodes: (host1, host2) + * + * The placement ratio is 10 : 10 : 7 : 4, close to the expected ratio (3 : 3 : 2 : 1). + * + * 3. If containers exist but none of them can match the requested localities, + * follow methods 1 and 2. + * + * 4. If containers exist and some of them can match the requested localities. + * For example, if we have 1 container on each node (host1: 1, host2: 1, host3: 1, host4: 1) + * and the expected containers on each node would be (host1: 5, host2: 5, host3: 4, host4: 2), + * then the newly requested containers on each node would be updated to (host1: 4, host2: 4, + * host3: 3, host4: 1), 12 containers in total. + * + * 4.1 If the requested container number (18) is more than the newly required containers (12), + * follow method 1 with the updated ratio 4 : 4 : 3 : 1. + * + * 4.2 If the requested container number (10) is less than the newly required containers (12), + * follow method 2 with the updated ratio 4 : 4 : 3 : 1. + * + * 5. If containers exist and their localities can fully cover the requested localities. + * For example, if we have 5 containers on each node (host1: 5, host2: 5, host3: 5, host4: 5), + * which covers all of the currently requested localities, this algorithm will allocate all of + * the requested containers with no locality preferences. + */ +private[yarn] class LocalityPreferredContainerPlacementStrategy( + val sparkConf: SparkConf, + val yarnConf: Configuration, + val resource: Resource) { + + // Number of CPUs per task + private val CPUS_PER_TASK = sparkConf.getInt("spark.task.cpus", 1) + + /** + * Calculate each container's node locality and rack locality. + * @param numContainer number of containers to calculate + * @param numLocalityAwareTasks number of tasks that have locality preferences + * @param hostToLocalTaskCount a map from each preferred hostname to the number of tasks that + * may run on it, used as hints for container allocation + * @return node localities and rack localities; each locality is an array of strings, + * and the number of localities is the same as the number of containers + */ + def localityOfRequestedContainers( + numContainer: Int, + numLocalityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Array[ContainerLocalityPreferences] = { + val updatedHostToContainerCount = expectedHostToContainerCount( + numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap) + val updatedLocalityAwareContainerNum = updatedHostToContainerCount.values.sum + + // The number of containers to allocate, divided into two groups, one with preferred locality, + // and the other without locality preference.
+ val requiredLocalityFreeContainerNum = + math.max(0, numContainer - updatedLocalityAwareContainerNum) + val requiredLocalityAwareContainerNum = numContainer - requiredLocalityFreeContainerNum + + val containerLocalityPreferences = ArrayBuffer[ContainerLocalityPreferences]() + if (requiredLocalityFreeContainerNum > 0) { + for (i <- 0 until requiredLocalityFreeContainerNum) { + containerLocalityPreferences += ContainerLocalityPreferences( + null.asInstanceOf[Array[String]], null.asInstanceOf[Array[String]]) + } + } + + if (requiredLocalityAwareContainerNum > 0) { + val largestRatio = updatedHostToContainerCount.values.max + // Scale each host's expected container count to the number of required locality-aware + // containers, rounding up; this ratio is used to pick the preferred hosts for each request. + var preferredLocalityRatio = updatedHostToContainerCount.mapValues { ratio => + val adjustedRatio = ratio.toDouble * requiredLocalityAwareContainerNum / largestRatio + adjustedRatio.ceil.toInt + } + + for (i <- 0 until requiredLocalityAwareContainerNum) { + // Only keep hosts whose remaining ratio is larger than 0, meaning the host can still be + // preferred by a new container request. + val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray + val racks = hosts.map { h => + RackResolver.resolve(yarnConf, h).getNetworkLocation + }.toSet + containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) + + // Subtract 1 each time a host is used. Once its ratio reaches 0, the host's requirement + // is satisfied and it will not be preferred again. + preferredLocalityRatio = preferredLocalityRatio.mapValues(_ - 1) + } + } + + containerLocalityPreferences.toArray + } + + /** + * Calculate the number of executors needed to satisfy the given number of pending tasks. + */ + private def numExecutorsPending(numTasksPending: Int): Int = { + val coresPerExecutor = resource.getVirtualCores + (numTasksPending * CPUS_PER_TASK + coresPerExecutor - 1) / coresPerExecutor + } + + /** + * Calculate the expected number of containers on each host, taking already allocated + * containers into account. + * @param localityAwareTasks number of locality-aware tasks + * @param hostToLocalTaskCount a map from each preferred hostname to the number of tasks that + * may run on it, used as hints for container allocation + * @return a map with hostname as key and the required number of containers on that host as value + */ + private def expectedHostToContainerCount( + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Map[String, Int] = { + val totalLocalTaskNum = hostToLocalTaskCount.values.sum + hostToLocalTaskCount.map { case (host, count) => + val expectedCount = + count.toDouble * numExecutorsPending(localityAwareTasks) / totalLocalTaskNum + val existedCount = allocatedHostToContainersMap.get(host) + .map(_.size) + .getOrElse(0) + + // If the existing containers cannot fully satisfy the expected count, the number of + // containers still required is the expected count minus the existing count; otherwise + // it is 0.
+ (host, math.max(0, (expectedCount - existedCount).ceil.toInt)) + } + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 940873fbd046c..59caa787b6e20 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -36,6 +36,9 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -52,6 +55,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ */ private[yarn] class YarnAllocator( driverUrl: String, + driverRef: RpcEndpointRef, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -88,6 +92,9 @@ private[yarn] class YarnAllocator( // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] + private var numUnexpectedContainerRelease = 0L + private val containerIdToExecutorId = new HashMap[ContainerId, String] + // Executor memory in MB. protected val executorMemory = args.executorMemory // Additional memory overhead. @@ -96,7 +103,7 @@ private[yarn] class YarnAllocator( // Number of cores per executor. protected val executorCores = args.executorCores // Resource capability requested for each executors - private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue @@ -127,6 +134,16 @@ private[yarn] class YarnAllocator( } } + // A map to store preferred hostname and possible task numbers running on it. + private var hostToLocalTaskCounts: Map[String, Int] = Map.empty + + // Number of tasks that have locality preferences in active stages + private var numLocalityAwareTasks: Int = 0 + + // A container placement strategy based on pending tasks' locality preference + private[yarn] val containerPlacementStrategy = + new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource) + def getNumExecutorsRunning: Int = numExecutorsRunning def getNumExecutorsFailed: Int = numExecutorsFailed @@ -146,10 +163,19 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. - * + * @param requestedTotal total number of containers requested + * @param localityAwareTasks number of locality aware tasks to be used as container placement hint + * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as + * container placement hint. * @return Whether the new requested total is different than the old value. 
*/ - def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { + def requestTotalExecutorsWithPreferredLocalities( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean = synchronized { + this.numLocalityAwareTasks = localityAwareTasks + this.hostToLocalTaskCounts = hostToLocalTaskCount + if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal @@ -165,6 +191,7 @@ private[yarn] class YarnAllocator( def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get + containerIdToExecutorId.remove(container.getId) internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -221,12 +248,20 @@ private[yarn] class YarnAllocator( val numPendingAllocate = getNumPendingAllocate val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning + // TODO. Consider locality preferences of pending container requests. + // Since the last time we made container requests, stages have completed and been submitted, + // and that the localities at which we requested our pending executors + // no longer apply to our current needs. We should consider to remove all outstanding + // container requests and add requests anew each time to avoid this. if (missing > 0) { logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") - for (i <- 0 until missing) { - val request = createContainerRequest(resource) + val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( + missing, numLocalityAwareTasks, hostToLocalTaskCounts, allocatedHostToContainersMap) + + for (locality <- containerLocalityPreferences) { + val request = createContainerRequest(resource, locality.nodes, locality.racks) amClient.addContainerRequest(request) val nodes = request.getNodes val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last @@ -249,11 +284,14 @@ private[yarn] class YarnAllocator( * Creates a container request, handling the reflection required to use YARN features that were * added in recent versions. 
*/ - private def createContainerRequest(resource: Resource): ContainerRequest = { + protected def createContainerRequest( + resource: Resource, + nodes: Array[String], + racks: Array[String]): ContainerRequest = { nodeLabelConstructor.map { constructor => - constructor.newInstance(resource, null, null, RM_REQUEST_PRIORITY, true: java.lang.Boolean, + constructor.newInstance(resource, nodes, racks, RM_REQUEST_PRIORITY, true: java.lang.Boolean, labelExpression.orNull) - }.getOrElse(new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY)) + }.getOrElse(new ContainerRequest(resource, nodes, racks, RM_REQUEST_PRIORITY)) } /** @@ -353,6 +391,7 @@ private[yarn] class YarnAllocator( logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) executorIdToContainer(executorId) = container + containerIdToExecutorId(container.getId) = executorId val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -383,12 +422,8 @@ private[yarn] class YarnAllocator( private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId - - if (releasedContainers.contains(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { + val alreadyReleased = releasedContainers.remove(containerId) + if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. numExecutorsRunning -= 1 @@ -430,6 +465,18 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.remove(containerId) } + + containerIdToExecutorId.remove(containerId).foreach { eid => + executorIdToContainer.remove(eid) + + if (!alreadyReleased) { + // The executor could have gone away (like no route to host, node failure, etc) + // Notify backend about the failure of the executor + numUnexpectedContainerRelease += 1 + driverRef.send(RemoveExecutor(eid, + s"Yarn deallocated the executor $eid (container $containerId)")) + } + } } } @@ -438,6 +485,8 @@ private[yarn] class YarnAllocator( amClient.releaseAssignedContainer(container.getId()) } + private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + } private object YarnAllocator { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 7f533ee55e8bb..4999f9c06210a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -56,6 +57,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg */ def register( driverUrl: String, + driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -73,7 +75,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - 
new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, + securityMgr) } /** diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 33f580aaebdc0..1aed5a1675075 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler.cluster import java.net.NetworkInterface +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + import scala.collection.JavaConverters._ import org.apache.hadoop.yarn.api.records.NodeState @@ -64,68 +66,29 @@ private[spark] class YarnClusterSchedulerBackend( } override def getDriverLogUrls: Option[Map[String, String]] = { - var yarnClientOpt: Option[YarnClient] = None var driverLogs: Option[Map[String, String]] = None try { val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) val containerId = YarnSparkHadoopUtil.get.getContainerId - yarnClientOpt = Some(YarnClient.createYarnClient()) - yarnClientOpt.foreach { yarnClient => - yarnClient.init(yarnConf) - yarnClient.start() - - // For newer versions of YARN, we can find the HTTP address for a given node by getting a - // container report for a given container. But container reports came only in Hadoop 2.4, - // so we basically have to get the node reports for all nodes and find the one which runs - // this container. For that we have to compare the node's host against the current host. - // Since the host can have multiple addresses, we need to compare against all of them to - // find out if one matches. - - // Get all the addresses of this node. - val addresses = - NetworkInterface.getNetworkInterfaces.asScala - .flatMap(_.getInetAddresses.asScala) - .toSeq - - // Find a node report that matches one of the addresses - val nodeReport = - yarnClient.getNodeReports(NodeState.RUNNING).asScala.find { x => - val host = x.getNodeId.getHost - addresses.exists { address => - address.getHostAddress == host || - address.getHostName == host || - address.getCanonicalHostName == host - } - } - // Now that we have found the report for the Node Manager that the AM is running on, we - // can get the base HTTP address for the Node manager from the report. - // The format used for the logs for each container is well-known and can be constructed - // using the NM's HTTP address and the container ID. - // The NM may be running several containers, but we can build the URL for the AM using - // the AM's container ID, which we already know. 
- nodeReport.foreach { report => - val httpAddress = report.getHttpAddress - // lookup appropriate http scheme for container log urls - val yarnHttpPolicy = yarnConf.get( - YarnConfiguration.YARN_HTTP_POLICY_KEY, - YarnConfiguration.YARN_HTTP_POLICY_DEFAULT - ) - val user = Utils.getCurrentUserName() - val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" - val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" - logDebug(s"Base URL for logs: $baseUrl") - driverLogs = Some(Map( - "stderr" -> s"$baseUrl/stderr?start=-4096", - "stdout" -> s"$baseUrl/stdout?start=-4096")) - } - } + val httpAddress = System.getenv(Environment.NM_HOST.name()) + + ":" + System.getenv(Environment.NM_HTTP_PORT.name()) + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val user = Utils.getCurrentUserName() + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" + logDebug(s"Base URL for logs: $baseUrl") + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) } catch { case e: Exception => - logInfo("Node Report API is not available in the version of YARN being used, so AM" + + logInfo("Error while building AM log links, so AM" + " logs link will not appear in application UI", e) - } finally { - yarnClientOpt.foreach(_.close()) } driverLogs } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala new file mode 100644 index 0000000000000..b7fe4ccc67a38 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite + +class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + + private val yarnAllocatorSuite = new YarnAllocatorSuite + import yarnAllocatorSuite._ + + override def beforeEach() { + yarnAllocatorSuite.beforeEach() + } + + override def afterEach() { + yarnAllocatorSuite.afterEach() + } + + test("allocate locality preferred containers with enough resource and no matched existed " + + "containers") { + // 1. All the locations of current containers cannot satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. 
+ + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array( + Array("host3", "host4", "host5"), + Array("host3", "host4", "host5"), + Array("host3", "host4"))) + } + + test("allocate locality preferred containers with enough resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. + + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === + Array(null, Array("host2", "host3"), Array("host2", "host3"))) + } + + test("allocate locality preferred containers with limited resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number cannot fully satisfy the pending tasks. + + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(Array("host2", "host3"))) + } + + test("allocate locality preferred containers with fully matched containers") { + // Current containers' locations can fully satisfy the new requirements + + val handler = createAllocator(5) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2"), + createContainer("host2"), + createContainer("host3") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null, null, null)) + } + + test("allocate containers with no locality preference") { + // Request new container without locality preference + + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 0, Map.empty, handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null)) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 7509000771d94..58318bf9bcc08 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -25,15 +25,18 @@ 
import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.mockito.Mockito._ import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, Matchers} - class MockResolver extends DNSToSwitchMapping { override def resolve(names: JList[String]): JList[String] = { @@ -91,6 +94,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter "--class", "SomeClass") new YarnAllocator( "not used", + mock(classOf[RpcEndpointRef]), conf, sparkConf, rmClient, @@ -171,7 +175,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) - handler.requestTotalExecutors(3) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (3) @@ -182,7 +186,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - handler.requestTotalExecutors(2) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (1) } @@ -193,7 +197,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) - handler.requestTotalExecutors(3) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (3) @@ -203,7 +207,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (2) - handler.requestTotalExecutors(1) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (0) handler.getNumExecutorsRunning should be (2) @@ -219,7 +223,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val container2 = createContainer("host2") handler.handleAllocatedContainers(Array(container1, container2)) - handler.requestTotalExecutors(1) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) } val statuses = Seq(container1, container2).map { c => @@ -231,6 +235,30 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumPendingAllocate should be (1) } + test("lost executor removed from backend") { + val handler = createAllocator(4) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, 
container2)) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map()) + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (2) + handler.getNumExecutorsFailed should be (2) + handler.getNumUnexpectedContainerRelease should be (2) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + @@ -241,5 +269,4 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) } - }
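The Interval to CalendarInterval rename above only changes the class name; the constructor, fromString, add, subtract, and toString keep the shapes exercised in IntervalSuite. A minimal Scala sketch of how calling code migrates, assuming only the members shown in the suite:

    import org.apache.spark.unsafe.types.CalendarInterval
    import org.apache.spark.unsafe.types.CalendarInterval.{MICROS_PER_HOUR, MICROS_PER_WEEK}

    object CalendarIntervalExample {
      def main(args: Array[String]): Unit = {
        // fromString parses the "interval ..." syntax; malformed input yields null, not an exception
        val a = CalendarInterval.fromString("interval 3 month 1 hour")
        val b = CalendarInterval.fromString("interval 2 month 100 hour")
        assert(CalendarInterval.fromString("interval 3 moth 1 hour") == null)

        // add/subtract combine the months and microseconds fields independently
        assert(a.add(b) == new CalendarInterval(5, 101 * MICROS_PER_HOUR))
        assert(a.subtract(b) == new CalendarInterval(1, -99 * MICROS_PER_HOUR))

        // toString renders both fields, as in the suite's combined case
        val c = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123)
        assert(c.toString == "interval 2 years 10 months 3 weeks 13 hours 123 microseconds")
      }
    }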
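The new prefix() cases in UTF8StringSuite check that the 8-byte value returned by getPrefix() orders these strings the same way a full compareTo() does. A small standalone sketch of that property, mirroring the suite's subtraction-based check; the comparator Spark ultimately uses for sort prefixes is not part of this patch, so this is illustrative only:

    import org.apache.spark.unsafe.types.UTF8String.fromString

    object PrefixOrderingExample {
      def main(args: Array[String]): Unit = {
        // For each pair, the sign of the prefix difference should agree with the full comparison
        val pairs = Seq(("", "a"), ("a", "b"), ("ab", "b"), ("abbbbbbbbbbbasdf", "bbbbbbbbbbbbasdf"))
        pairs.foreach { case (s1, s2) =>
          val prefixSaysLess = fromString(s1).getPrefix - fromString(s2).getPrefix < 0
          val compareSaysLess = fromString(s1).compareTo(fromString(s2)) < 0
          assert(prefixSaysLess == compareSaysLess)
        }
      }
    }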
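The container-count arithmetic in LocalityPreferredContainerPlacementStrategy is easiest to follow with the numbers from its class comment. Below is a small standalone sketch (not the Spark class itself, with rack resolution left out) of the expectedHostToContainerCount step for the partially matched case: 30 locality-aware tasks, host weights host1: 30, host2: 30, host3: 20, host4: 10, 2-core containers, 1 cpu per task, and one container already running on each host.

    object ExpectedContainerCountSketch {
      def main(args: Array[String]): Unit = {
        val coresPerExecutor = 2
        val cpusPerTask = 1
        val localityAwareTasks = 30
        val hostToLocalTaskCount = Map("host1" -> 30, "host2" -> 30, "host3" -> 20, "host4" -> 10)
        val existingContainers = Map("host1" -> 1, "host2" -> 1, "host3" -> 1, "host4" -> 1)

        // Containers needed for the pending tasks, rounded up: (30 * 1 + 2 - 1) / 2 = 15
        val pendingExecutors =
          (localityAwareTasks * cpusPerTask + coresPerExecutor - 1) / coresPerExecutor

        // Expected containers per host are proportional to that host's task weight; subtract what
        // is already running there and never go below zero.
        val totalWeight = hostToLocalTaskCount.values.sum
        val newlyNeeded = hostToLocalTaskCount.map { case (host, weight) =>
          val expected = weight.toDouble * pendingExecutors / totalWeight
          host -> math.max(0, math.ceil(expected - existingContainers.getOrElse(host, 0)).toInt)
        }

        // Prints host1 -> 4, host2 -> 4, host3 -> 3, host4 -> 1 (12 in total), matching case 4
        // of the class comment.
        newlyNeeded.toSeq.sortBy(_._1).foreach { case (host, n) => println(s"$host -> $n") }
      }
    }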
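The rewritten getDriverLogUrls no longer walks node reports through a YarnClient; it builds the NodeManager's HTTP address directly from the NM_HOST and NM_HTTP_PORT environment variables that YARN sets inside the container. A minimal sketch of the resulting URL shape, with placeholder values standing in for those environment variables, the container id, and the user:

    object DriverLogUrlSketch {
      def main(args: Array[String]): Unit = {
        // Placeholders: in the AM these come from Environment.NM_HOST / NM_HTTP_PORT, the running
        // container's id, and the current user.
        val nmHost = "nm1.example.com"
        val nmHttpPort = "8042"
        val containerId = "container_1412887393566_0003_01_000001"
        val user = "spark"
        val httpScheme = "http://" // "https://" when yarn.http.policy is HTTPS_ONLY

        val baseUrl = s"$httpScheme$nmHost:$nmHttpPort/node/containerlogs/$containerId/$user"
        println(s"stderr -> $baseUrl/stderr?start=-4096")
        println(s"stdout -> $baseUrl/stdout?start=-4096")
      }
    }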